From f28e4df666b1a900b4cdb467c8be05ad3acf7a9b Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Thu, 1 Nov 2018 07:26:00 -0700 Subject: Adds a dependence check to test whether two accesses to the same memref access the same element. - Builds access functions and iterations domains for each access. - Builds dependence polyhedron constraint system which has equality constraints for equated access functions and inequality constraints for iteration domain loop bounds. - Runs elimination on the dependence polyhedron to test if no dependence exists between the accesses. - Adds a trivial LoopFusion transformation pass with a simple test policy to test dependence between accesses to the same memref in adjacent loops. - The LoopFusion pass will be extended in subsequent CLs. PiperOrigin-RevId: 219630898 --- mlir/lib/Transforms/LoopFusion.cpp | 244 +++++++++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 mlir/lib/Transforms/LoopFusion.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp new file mode 100644 index 00000000000..d9cdf9d919b --- /dev/null +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -0,0 +1,244 @@ +//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements loop fusion. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/StmtVisitor.h" +#include "mlir/Pass.h" +#include "mlir/StandardOps/StandardOps.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/DenseMap.h" + +using namespace mlir; + +namespace { + +/// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which +/// access the same memref with no dependences. +// See MatchTestPattern for details on candidate loop selection. +// TODO(andydavis) Extend this pass to check for fusion preventing dependences, +// and add support for more general loop fusion algorithms. +struct LoopFusion : public FunctionPass { + LoopFusion() {} + + PassResult runOnMLFunction(MLFunction *f) override; +}; + +// LoopCollector walks the statements in an MLFunction and builds a map from +// StmtBlocks to a list of loops within the StmtBlock, and a map from ForStmts +// to the list of loads and stores with its StmtBlock. 
+class LoopCollector : public StmtWalker { +public: + DenseMap> loopMap; + DenseMap> loadsAndStoresMap; + bool hasIfStmt = false; + + void visitForStmt(ForStmt *forStmt) { + loopMap[forStmt->getBlock()].push_back(forStmt); + } + + void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } + + void visitOperationStmt(OperationStmt *opStmt) { + if (auto *parentStmt = opStmt->getParentStmt()) { + if (auto *parentForStmt = dyn_cast(parentStmt)) { + if (opStmt->isa() || opStmt->isa()) { + loadsAndStoresMap[parentForStmt].push_back(opStmt); + } + } + } + } +}; + +} // end anonymous namespace + +FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } + +// TODO(andydavis) Remove the following test code when more general loop +// fusion is supported. +struct FusionCandidate { + // Loop nest of ForStmts with 'accessA' in the inner-most loop. + SmallVector forStmtsA; + // Load or store operation within loop nest 'forStmtsA'. + MemRefAccess accessA; + // Loop nest of ForStmts with 'accessB' in the inner-most loop. + SmallVector forStmtsB; + // Load or store operation within loop nest 'forStmtsB'. + MemRefAccess accessB; +}; + +static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, + MemRefAccess *access) { + if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { + access->memref = cast(loadOp->getMemRef()); + access->opStmt = loadOrStoreOpStmt; + auto loadMemrefType = loadOp->getMemRefType(); + access->indices.reserve(loadMemrefType.getRank()); + for (auto *index : loadOp->getIndices()) { + access->indices.push_back(cast(index)); + } + } else { + assert(loadOrStoreOpStmt->isa()); + auto storeOp = loadOrStoreOpStmt->dyn_cast(); + access->opStmt = loadOrStoreOpStmt; + access->memref = cast(storeOp->getMemRef()); + auto storeMemrefType = storeOp->getMemRefType(); + access->indices.reserve(storeMemrefType.getRank()); + for (auto *index : storeOp->getIndices()) { + access->indices.push_back(cast(index)); + } + } +} + +// Checks if 'forStmtA' and 'forStmtB' match specific test criterion: +// constant loop bounds, no nested loops, single StoreOp in 'forStmtA' and +// a single LoadOp in 'forStmtB'. +// Returns true if the test pattern matches, false otherwise. +static bool MatchTestPatternLoopPair(LoopCollector *lc, + FusionCandidate *candidate, + ForStmt *forStmtA, ForStmt *forStmtB) { + if (forStmtA == nullptr || forStmtB == nullptr) + return false; + // Return if 'forStmtA' and 'forStmtB' do not have matching constant + // bounds and step. + if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() || + forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() || + forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() || + forStmtA->getStep() != forStmtB->getStep()) + return false; + + // Return if 'forStmtA' or 'forStmtB' have nested loops. + if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB)) + return false; + + // Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store. + if (lc->loadsAndStoresMap[forStmtA].size() != 1 || + lc->loadsAndStoresMap[forStmtB].size() != 1) + return false; + + // Get load/store access for forStmtA. + getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0], + &candidate->accessA); + // Return if 'accessA' is not a store. + if (!candidate->accessA.opStmt->isa()) + return false; + + // Get load/store access for forStmtB. + getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0], + &candidate->accessB); + + // Return if accesses do not access the same memref. 
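In plain C++ terms (an illustration only — the pass itself matches ForStmt/OperationStmt structures, and the buffer sizes, bounds and names below are made up), the shape accepted by this test pattern and the result of fusing it look roughly like this:

// Before: two adjacent loops with identical constant bounds, a single store
// in the first and a single load in the second, both on the same buffer 'A'.
void beforeFusion(float A[200], float B[100]) {
  for (int i = 0; i < 100; ++i)
    A[i] = 7.0f;         // store loop -> candidate.accessA
  for (int j = 0; j < 100; ++j)
    B[j] = A[j + 100];   // load loop -> candidate.accessB, reads A[100..199]
}

// After fuseLoops(): the second loop's body is cloned into the first with its
// induction variable remapped, then the second loop nest is erased. This is
// only done when the two accesses provably never touch the same element.
void afterFusion(float A[200], float B[100]) {
  for (int i = 0; i < 100; ++i) {
    A[i] = 7.0f;
    B[i] = A[i + 100];
  }
}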
+ if (candidate->accessA.memref != candidate->accessB.memref) + return false; + + candidate->forStmtsA.push_back(forStmtA); + candidate->forStmtsB.push_back(forStmtB); + return true; +} + +// Returns the child ForStmt of 'parent' if unique, returns false otherwise. +ForStmt *getSingleForStmtChild(ForStmt *parent) { + if (parent->getStatements().size() == 1 && isa(parent->front())) + return dyn_cast(&parent->front()); + return nullptr; +} + +// Checks for a specific ForStmt/OpStatment test pattern in 'f', returns true +// on success and resturns fusion candidate in 'candidate'. Returns false +// otherwise. +// Currently supported test patterns: +// *) Adjacent loops with a StoreOp the only op in first loop, and a LoadOp the +// only op in the second loop (both load/store accessing the same memref). +// *) As above, but with one level of perfect loop nesting. +// +// TODO(andydavis) Look into using ntv@ pattern matcher here. +static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) { + LoopCollector lc; + lc.walk(f); + // Return if an IfStmt was found or if less than two ForStmts were found. + if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2) + return false; + auto *forStmtA = lc.loopMap[f][0]; + auto *forStmtB = lc.loopMap[f][1]; + if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) { + // Check for one level of loop nesting. + candidate->forStmtsA.push_back(forStmtA); + candidate->forStmtsB.push_back(forStmtB); + return MatchTestPatternLoopPair(&lc, candidate, + getSingleForStmtChild(forStmtA), + getSingleForStmtChild(forStmtB)); + } + return true; +} + +// FuseLoops implements the code generation mechanics of loop fusion. +// Fuses the operations statments from the inner-most loop in 'c.forStmtsB', +// by cloning them into the inner-most loop in 'c.forStmtsA', then erasing +// old statements and loops. +static void fuseLoops(const FusionCandidate &c) { + MLFuncBuilder builder(c.forStmtsA.back(), + StmtBlock::iterator(c.forStmtsA.back()->end())); + DenseMap operandMap; + assert(c.forStmtsA.size() == c.forStmtsB.size()); + for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) { + // Map loop IVs to 'forStmtB[i]' to loop IV for 'forStmtA[i]'. + operandMap[c.forStmtsB[i]] = c.forStmtsA[i]; + } + // Clone the body of inner-most loop in 'forStmtsB', into the body of + // inner-most loop in 'forStmtsA'. + SmallVector stmtsToErase; + auto *innerForStmtB = c.forStmtsB.back(); + for (auto &stmt : *innerForStmtB) { + builder.clone(stmt, operandMap); + stmtsToErase.push_back(&stmt); + } + // Erase 'forStmtB' and its statement list. + for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it) + (*it)->erase(); + // Erase 'forStmtsB' loop nest. + for (int i = static_cast(c.forStmtsB.size()) - 1; i >= 0; --i) + c.forStmtsB[i]->erase(); +} + +PassResult LoopFusion::runOnMLFunction(MLFunction *f) { + FusionCandidate candidate; + if (!MatchTestPattern(f, &candidate)) + return failure(); + + // TODO(andydavis) Add checks for fusion-preventing dependences and ordering + // constraints which would prevent fusion. + // TODO(andydavis) This check if overly conservative for now. Support fusing + // statements with compatible dependences (i.e. statements where the + // dependence between the statements does not reverse direction when the + // statements are fused into the same loop). 
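A minimal, self-contained sketch of what the dependence test decides for the example loops above. This is a simplification: it brute-forces the integer feasibility of the system { srcIndex == dstIndex, loop bounds } for constant-bound 1-d accesses, whereas the real check builds a FlatAffineConstraints dependence polyhedron and runs elimination on it.

#include <cstdio>

// Minimal stand-in for the dependence test, assuming constant loop bounds
// and 1-d affine accesses of the form coeff * iv + shift.
struct Access1D {
  int coeff;  // multiplier of the loop IV
  int shift;  // constant offset
  int lb, ub; // inclusive constant bounds of the enclosing loop
};

static bool mayDepend(const Access1D &src, const Access1D &dst) {
  for (int i = src.lb; i <= src.ub; ++i)
    for (int j = dst.lb; j <= dst.ub; ++j)
      if (src.coeff * i + src.shift == dst.coeff * j + dst.shift)
        return true; // some (i, j) satisfies the equality -> dependence
  return false;      // system is infeasible -> accesses are disjoint
}

int main() {
  Access1D store = {1, 0, 0, 99};   // A[i],       i in [0, 99]
  Access1D load = {1, 100, 0, 99};  // A[j + 100], j in [0, 99]
  std::printf("dependence: %s\n", mayDepend(store, load) ? "yes" : "no");
  // Prints "no": under the test policy above, the two loops may be fused.
}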
+ if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB)) { + // Current conservatinve test policy: No dependence exists between accesses + // in different loop nests -> fuse loops. + fuseLoops(candidate); + } + + return success(); +} -- cgit v1.2.3 From 6f0fb2272344bf7528066e1554c8cbb78078ae2a Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 6 Nov 2018 18:34:18 -0800 Subject: Add static pass registration Add static pass registration and change mlir-opt to use it. Future work is needed to refactor the registration for PassManager usage. Change build targets to alwayslink to enforce registration. PiperOrigin-RevId: 220390178 --- mlir/include/mlir/Pass.h | 62 ++++++++++++++ mlir/include/mlir/Support/PassNameParser.h | 40 +++++++++ mlir/lib/Analysis/MemRefBoundCheck.cpp | 8 ++ mlir/lib/Analysis/MemRefDependenceCheck.cpp | 7 ++ mlir/lib/Analysis/Pass.cpp | 37 +++++++++ mlir/lib/Transforms/CFGFunctionViewGraph.cpp | 10 ++- mlir/lib/Transforms/Canonicalizer.cpp | 7 ++ mlir/lib/Transforms/ComposeAffineMaps.cpp | 7 ++ mlir/lib/Transforms/ConstantFold.cpp | 7 ++ mlir/lib/Transforms/ConvertToCFG.cpp | 8 ++ mlir/lib/Transforms/LoopFusion.cpp | 5 ++ mlir/lib/Transforms/LoopTiling.cpp | 6 ++ mlir/lib/Transforms/LoopUnroll.cpp | 20 +++-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 9 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 9 ++ mlir/lib/Transforms/SimplifyAffineExpr.cpp | 7 ++ mlir/lib/Transforms/Vectorize.cpp | 8 ++ mlir/tools/mlir-opt/mlir-opt.cpp | 118 ++------------------------- 18 files changed, 256 insertions(+), 119 deletions(-) create mode 100644 mlir/include/mlir/Support/PassNameParser.h (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Pass.h b/mlir/include/mlir/Pass.h index cd7b7027907..d1610bf4d08 100644 --- a/mlir/include/mlir/Pass.h +++ b/mlir/include/mlir/Pass.h @@ -18,7 +18,10 @@ #ifndef MLIR_PASS_H #define MLIR_PASS_H +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Compiler.h" +#include namespace mlir { class Function; @@ -81,6 +84,65 @@ public: virtual PassResult runOnModule(Module *m) override; }; +using PassAllocatorFunction = std::function; + +/// Structure to group information about a pass (argument to invoke via +/// mlir-opt, description, pass allocator and unique ID). +class PassInfo { +public: + /// PassInfo constructor should not be invoked directly, instead use + /// PassRegistration or registerPass. + PassInfo(StringRef arg, StringRef description, const void *passID, + PassAllocatorFunction allocator) + : arg(arg), description(description), allocator(allocator), + passID(passID){}; + + /// Returns an allocated instance of this pass. + Pass *createPass() const { + assert(allocator && + "Cannot call createPass on PassInfo without default allocator"); + return allocator(); + } + + /// Returns the command line option that may be passed to 'mlir-opt' that will + /// cause this pass to run or null if there is no such argument. + StringRef getPassArgument() const { return arg; } + + /// Returns a description for the pass, this never returns null. + StringRef getPassDescription() const { return description; } + +private: + // The argument with which to invoke the pass via mlir-opt. + StringRef arg; + + // Description of the pass. + StringRef description; + + // Allocator to construct an instance of this pass. + PassAllocatorFunction allocator; + + // Unique identifier for pass. 
+ const void *passID; +}; + +/// Register a specific dialect creation function with the system, typically +/// used through the PassRegistration template. +void registerPass(StringRef arg, StringRef description, const void *passID, + const PassAllocatorFunction &function); + +/// PassRegistration provides a global initializer that registers a Pass +/// allocation routine. +/// +/// Usage: +/// +/// // At namespace scope. +/// static PassRegistration Unused("unused", "Unused pass"); +template struct PassRegistration { + PassRegistration(StringRef arg, StringRef description) { + registerPass(arg, description, &ConcretePass::passID, + [&]() { return new ConcretePass(); }); + } +}; } // end namespace mlir #endif // MLIR_PASS_H diff --git a/mlir/include/mlir/Support/PassNameParser.h b/mlir/include/mlir/Support/PassNameParser.h new file mode 100644 index 00000000000..bbdf433b9ab --- /dev/null +++ b/mlir/include/mlir/Support/PassNameParser.h @@ -0,0 +1,40 @@ +//===- PassNameParser.h - Base classes for compiler passes ------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// The PassNameParser class adds all passes linked in to the system that are +// creatable to the tool. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_PASSNAMEPARSER_H_ +#define MLIR_SUPPORT_PASSNAMEPARSER_H_ + +#include "llvm/Support/CommandLine.h" + +namespace mlir { +class PassInfo; + +/// Adds command line option for each registered pass. 
+struct PassNameParser : public llvm::cl::parser { + PassNameParser(llvm::cl::Option &opt); + + void printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const override; +}; +} // end namespace mlir + +#endif // MLIR_SUPPORT_PASSNAMEPARSER_H_ diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 0725cea7086..a7f0ebf4936 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -45,10 +45,14 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker { PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } void visitOperationStmt(OperationStmt *opStmt); + + static char passID; }; } // end anonymous namespace +char MemRefBoundCheck::passID = 0; + FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } @@ -164,3 +168,7 @@ void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) { PassResult MemRefBoundCheck::runOnMLFunction(MLFunction *f) { return walk(f), success(); } + +static PassRegistration + memRefBoundCheck("memref-bound-check", + "Check memref accesses in an MLFunction"); diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 3ca669c5c85..7a620c1a3a8 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -51,10 +51,13 @@ struct MemRefDependenceCheck : public FunctionPass, loadsAndStores.push_back(opStmt); } } + static char passID; }; } // end anonymous namespace +char MemRefDependenceCheck::passID = 0; + FunctionPass *mlir::createMemRefDependenceCheckPass() { return new MemRefDependenceCheck(); } @@ -132,3 +135,7 @@ PassResult MemRefDependenceCheck::runOnMLFunction(MLFunction *f) { checkDependences(loadsAndStores); return success(); } + +static PassRegistration + pass("memref-dependence-check", + "Checks dependences between all pairs of memref accesses."); diff --git a/mlir/lib/Analysis/Pass.cpp b/mlir/lib/Analysis/Pass.cpp index 1249c18c07e..ea9da5b0e80 100644 --- a/mlir/lib/Analysis/Pass.cpp +++ b/mlir/lib/Analysis/Pass.cpp @@ -23,6 +23,9 @@ #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Module.h" +#include "mlir/Support/PassNameParser.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" using namespace mlir; /// Out of line virtual method to ensure vtables and metadata are emitted to a @@ -51,3 +54,37 @@ PassResult FunctionPass::runOnFunction(Function *fn) { return success(); } + +// TODO: The pass registry and pass name parsing should be moved out. 
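The registration machinery that follows can be summarized with an MLIR-independent sketch. Names such as MyPass, MyPassInfo and Registration are placeholders for Pass, PassInfo and PassRegistration, and a function-local static stands in for llvm::ManagedStatic:

#include <cassert>
#include <functional>
#include <map>
#include <string>

// Each pass exposes a 'static char passID'; its address is the registry key.
struct MyPass { static char passID; };
char MyPass::passID = 0;

struct MyPassInfo {
  std::string arg, description;
  std::function<MyPass *()> allocator;
};

static std::map<const void *, MyPassInfo> &registry() {
  static std::map<const void *, MyPassInfo> r; // stands in for ManagedStatic
  return r;
}

template <typename ConcretePass> struct Registration {
  Registration(const char *arg, const char *description) {
    bool inserted =
        registry()
            .emplace(&ConcretePass::passID,
                     MyPassInfo{arg, description,
                                [] { return new ConcretePass(); }})
            .second;
    assert(inserted && "pass registered multiple times");
    (void)inserted;
  }
};

// At namespace scope in the pass's .cpp file:
static Registration<MyPass> reg("my-pass", "Example pass");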
+static llvm::ManagedStatic> passRegistry; + +void mlir::registerPass(StringRef arg, StringRef description, + const void *passID, + const PassAllocatorFunction &function) { + bool inserted = passRegistry + ->insert(std::make_pair( + passID, PassInfo(arg, description, passID, function))) + .second; + assert(inserted && "Pass registered multiple times"); + (void)inserted; +} + +PassNameParser::PassNameParser(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { + for (const auto &kv : *passRegistry) { + addLiteralOption(kv.second.getPassArgument(), &kv.second, + kv.second.getPassDescription()); + } +} + +void PassNameParser::printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const { + PassNameParser *TP = const_cast(this); + llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(), + [](const PassNameParser::OptionInfo *VT1, + const PassNameParser::OptionInfo *VT2) { + return VT1->Name.compare(VT2->Name); + }); + using llvm::cl::parser; + parser::printOptionInfo(O, GlobalWidth); +} diff --git a/mlir/lib/Transforms/CFGFunctionViewGraph.cpp b/mlir/lib/Transforms/CFGFunctionViewGraph.cpp index a75d26c1fbc..810264cb35e 100644 --- a/mlir/lib/Transforms/CFGFunctionViewGraph.cpp +++ b/mlir/lib/Transforms/CFGFunctionViewGraph.cpp @@ -74,13 +74,16 @@ void mlir::CFGFunction::viewGraph() const { namespace { struct PrintCFGPass : public FunctionPass { - PrintCFGPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title) + PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, + const llvm::Twine &title = "") : os(os), shortNames(shortNames), title(title) {} PassResult runOnCFGFunction(CFGFunction *function) override { mlir::writeGraph(os, function, shortNames, title); return success(); } + static char passID; + private: llvm::raw_ostream &os; bool shortNames; @@ -88,8 +91,13 @@ private: }; } // namespace +char PrintCFGPass::passID = 0; + FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title) { return new PrintCFGPass(os, shortNames, title); } + +static PassRegistration pass("print-cfg-graph", + "Print CFG graph per function"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index f34118ce21a..3a62f132459 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -35,9 +35,13 @@ namespace { /// Canonicalize operations in functions. struct Canonicalizer : public FunctionPass { PassResult runOnFunction(Function *fn) override; + + static char passID; }; } // end anonymous namespace +char Canonicalizer::passID = 0; + PassResult Canonicalizer::runOnFunction(Function *fn) { auto *context = fn->getContext(); OwningPatternList patterns; @@ -54,3 +58,6 @@ PassResult Canonicalizer::runOnFunction(Function *fn) { /// Create a Canonicalizer pass. 
FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); } + +static PassRegistration pass("canonicalize", + "Canonicalize operations"); diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index af4a5d11521..61e2e8f2e83 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -50,10 +50,14 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker { void visitOperationStmt(OperationStmt *stmt); PassResult runOnMLFunction(MLFunction *f) override; using StmtWalker::walk; + + static char passID; }; } // end anonymous namespace +char ComposeAffineMaps::passID = 0; + FunctionPass *mlir::createComposeAffineMapsPass() { return new ComposeAffineMaps(); } @@ -92,3 +96,6 @@ PassResult ComposeAffineMaps::runOnMLFunction(MLFunction *f) { } return success(); } + +static PassRegistration pass("compose-affine-maps", + "Compose affine maps"); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 411d1caae29..9005c2bbf48 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -40,9 +40,13 @@ struct ConstantFold : public FunctionPass, StmtWalker { void visitForStmt(ForStmt *stmt); PassResult runOnCFGFunction(CFGFunction *f) override; PassResult runOnMLFunction(MLFunction *f) override; + + static char passID; }; } // end anonymous namespace +char ConstantFold::passID = 0; + /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// @@ -174,3 +178,6 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) { /// Creates a constant folding pass. FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); } + +static PassRegistration + pass("constant-fold", "Constant fold operations in functions"); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 52687da65ba..b36717d272f 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -70,6 +70,8 @@ public: PassResult runOnModule(Module *m) override; + static char passID; + private: // Generates CFG functions for all ML functions in the module. void convertMLFunctions(); @@ -90,6 +92,8 @@ private: }; } // end anonymous namespace +char ModuleConverter::passID = 0; + // Iterates over all functions in the module generating CFG functions // equivalent to ML functions and replacing references to ML functions // with references to the generated ML functions. @@ -163,3 +167,7 @@ void ModuleConverter::removeMLFunctions() { /// Function references are appropriately patched to refer to the newly /// generated CFG functions. 
ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); } + +static PassRegistration + pass("convert-to-cfg", + "Convert all ML functions in the module to CFG ones"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d9cdf9d919b..ae4647e143d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -45,6 +45,7 @@ struct LoopFusion : public FunctionPass { LoopFusion() {} PassResult runOnMLFunction(MLFunction *f) override; + static char passID; }; // LoopCollector walks the statements in an MLFunction and builds a map from @@ -75,6 +76,8 @@ public: } // end anonymous namespace +char LoopFusion::passID = 0; + FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } // TODO(andydavis) Remove the following test code when more general loop @@ -242,3 +245,5 @@ PassResult LoopFusion::runOnMLFunction(MLFunction *f) { return success(); } + +static PassRegistration pass("loop-fusion", "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index bd66e337609..3bff008942c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -42,10 +42,14 @@ namespace { struct LoopTiling : public FunctionPass { PassResult runOnMLFunction(MLFunction *f) override; constexpr static unsigned kDefaultTileSize = 32; + + static char passID; }; } // end anonymous namespace +char LoopTiling::passID = 0; + /// Creates a pass to perform loop tiling on all suitable loop nests of an /// MLFunction. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } @@ -238,3 +242,5 @@ PassResult LoopTiling::runOnMLFunction(MLFunction *f) { } return success(); } + +static PassRegistration pass("loop-tile", "Tile loop nests"); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 15c7014dc42..ae09098f9d5 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -56,22 +56,20 @@ struct LoopUnroll : public FunctionPass { Optional unrollFactor; Optional unrollFull; - explicit LoopUnroll(Optional unrollFactor, - Optional unrollFull) + explicit LoopUnroll(Optional unrollFactor = None, + Optional unrollFull = None) : unrollFactor(unrollFactor), unrollFull(unrollFull) {} PassResult runOnMLFunction(MLFunction *f) override; /// Unroll this for stmt. Returns false if nothing was done. bool runOnForStmt(ForStmt *forStmt); + + static char passID; }; } // end anonymous namespace -FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) { - return new LoopUnroll(unrollFactor == -1 ? None - : Optional(unrollFactor), - unrollFull == -1 ? None : Optional(unrollFull)); -} +char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { // Gathers all innermost loops through a post order pruned walk. @@ -286,3 +284,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { return true; } + +FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) { + return new LoopUnroll(unrollFactor == -1 ? None + : Optional(unrollFactor), + unrollFull == -1 ? 
None : Optional(unrollFull)); +} + +static PassRegistration pass("loop-unroll", "Unroll loops"); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index f437b44ae26..ce6e939fae8 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -70,14 +70,18 @@ struct LoopUnrollAndJam : public FunctionPass { Optional unrollJamFactor; static const unsigned kDefaultUnrollJamFactor = 4; - explicit LoopUnrollAndJam(Optional unrollJamFactor) + explicit LoopUnrollAndJam(Optional unrollJamFactor = None) : unrollJamFactor(unrollJamFactor) {} PassResult runOnMLFunction(MLFunction *f) override; bool runOnForStmt(ForStmt *forStmt); + + static char passID; }; } // end anonymous namespace +char LoopUnrollAndJam::passID = 0; + FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return new LoopUnrollAndJam( unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); @@ -239,3 +243,6 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { return true; } + +static PassRegistration pass("loop-unroll-jam", + "Unroll and jam loops"); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index c59e007e543..52052e09d7b 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -47,10 +47,14 @@ struct PipelineDataTransfer : public FunctionPass, // Collect all 'for' statements. void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } std::vector forStmts; + + static char passID; }; } // end anonymous namespace +char PipelineDataTransfer::passID = 0; + /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. FunctionPass *mlir::createPipelineDataTransferPass() { @@ -306,3 +310,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { return success(); } + +static PassRegistration pass( + "pipeline-data-transfer", + "Pipeline non-blocking data transfers between explicitly managed levels of " + "the memory hierarchy"); diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index a412a83f66c..92d585f31bc 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -47,10 +47,14 @@ struct SimplifyAffineStructures : public FunctionPass, void visitIfStmt(IfStmt *ifStmt); void visitOperationStmt(OperationStmt *opStmt); + + static char passID; }; } // end anonymous namespace +char SimplifyAffineStructures::passID = 0; + FunctionPass *mlir::createSimplifyAffineStructuresPass() { return new SimplifyAffineStructures(); } @@ -83,3 +87,6 @@ PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) { walk(f); return success(); } + +static PassRegistration + pass("simplify-affine-structures", "Simplify affine expressions"); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index fa97b7025d4..63969af451f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -199,10 +199,14 @@ struct Vectorize : public FunctionPass { // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext MLContext; + + static char passID; }; } // end anonymous namespace +char Vectorize::passID = 0; + /////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. 
////// namespace { @@ -669,3 +673,7 @@ PassResult Vectorize::runOnMLFunction(MLFunction *f) { } FunctionPass *mlir::createVectorizePass() { return new Vectorize(); } + +static PassRegistration + pass("vectorize", + "Vectorize to a target independent n-D vector abstraction"); diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 700436e8998..3225860c693 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/Module.h" #include "mlir/Parser.h" #include "mlir/Pass.h" +#include "mlir/Support/PassNameParser.h" #include "mlir/TensorFlow/ControlFlowOps.h" #include "mlir/TensorFlow/Passes.h" #include "mlir/TensorFlowLite/Passes.h" @@ -67,58 +68,7 @@ static cl::opt "expected-* lines on the corresponding line"), cl::init(false)); -enum Passes { - Canonicalize, - ComposeAffineMaps, - ConstantFold, - ConvertToCFG, - TFLiteLegaize, - LoopFusion, - LoopTiling, - LoopUnroll, - LoopUnrollAndJam, - MemRefBoundCheck, - MemRefDependenceCheck, - PipelineDataTransfer, - PrintCFGGraph, - SimplifyAffineStructures, - TFRaiseControlFlow, - Vectorize, - XLALower, -}; - -static cl::list passList( - "", cl::desc("Compiler passes to run"), - cl::values( - clEnumValN(Canonicalize, "canonicalize", "Canonicalize operations"), - clEnumValN(ComposeAffineMaps, "compose-affine-maps", - "Compose affine maps"), - clEnumValN(ConstantFold, "constant-fold", - "Constant fold operations in functions"), - clEnumValN(ConvertToCFG, "convert-to-cfg", - "Convert all ML functions in the module to CFG ones"), - clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"), - clEnumValN(LoopTiling, "loop-tile", "Tile loop nests"), - clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"), - clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"), - clEnumValN(MemRefBoundCheck, "memref-bound-check", - "Convert all ML functions in the module to CFG ones"), - clEnumValN(MemRefDependenceCheck, "memref-dependence-check", - "Checks dependences between all pairs of memref accesses."), - clEnumValN(PipelineDataTransfer, "pipeline-data-transfer", - "Pipeline non-blocking data transfers between" - "explicitly managed levels of the memory hierarchy"), - clEnumValN(PrintCFGGraph, "print-cfg-graph", - "Print CFG graph per function"), - clEnumValN(SimplifyAffineStructures, "simplify-affine-structures", - "Simplify affine expressions"), - clEnumValN(TFLiteLegaize, "tfl-legalize", - "Legalize operations to TensorFlow Lite dialect"), - clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow", - "Dynamic TensorFlow Switch/Match nodes to a CFG"), - clEnumValN(Vectorize, "vectorize", - "Vectorize to a target independent n-D vector abstraction."), - clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect"))); +static std::vector *passList; enum OptResult { OptSuccess, OptFailure }; @@ -190,65 +140,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { return OptFailure; // Run each of the passes that were selected. 
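What the new registry-driven driver loop boils down to, as an isolated sketch (hypothetical names; the real code obtains the selected PassInfo objects from the PassNameParser command-line option and calls passInfo->createPass()):

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Pass { virtual ~Pass() = default; virtual bool runOnModule() = 0; };

// Runs the requested passes in order by looking each name up in a registry,
// instead of dispatching through a hand-written switch over an enum.
int runSelectedPasses(
    const std::vector<std::string> &requested,
    const std::map<std::string, std::function<Pass *()>> &registryByName) {
  for (const auto &name : requested) {
    auto it = registryByName.find(name);
    if (it == registryByName.end())
      return 1;                               // unknown pass name
    std::unique_ptr<Pass> pass(it->second()); // allocate via the registry
    if (pass->runOnModule())
      return 1;                               // pass signalled failure
  }
  return 0;
}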
- for (unsigned i = 0, e = passList.size(); i != e; ++i) { - auto passKind = passList[i]; - Pass *pass = nullptr; - switch (passKind) { - case Canonicalize: - pass = createCanonicalizerPass(); - break; - case ComposeAffineMaps: - pass = createComposeAffineMapsPass(); - break; - case ConstantFold: - pass = createConstantFoldPass(); - break; - case ConvertToCFG: - pass = createConvertToCFGPass(); - break; - case LoopFusion: - pass = createLoopFusionPass(); - break; - case LoopTiling: - pass = createLoopTilingPass(); - break; - case LoopUnroll: - pass = createLoopUnrollPass(); - break; - case LoopUnrollAndJam: - pass = createLoopUnrollAndJamPass(); - break; - case MemRefBoundCheck: - pass = createMemRefBoundCheckPass(); - break; - case MemRefDependenceCheck: - pass = createMemRefDependenceCheckPass(); - break; - case PipelineDataTransfer: - pass = createPipelineDataTransferPass(); - break; - case PrintCFGGraph: - pass = createPrintCFGGraphPass(); - break; - case SimplifyAffineStructures: - pass = createSimplifyAffineStructuresPass(); - break; - case TFLiteLegaize: - pass = tfl::createLegalizer(); - break; - case TFRaiseControlFlow: - pass = createRaiseTFControlFlowPass(); - break; - case Vectorize: - pass = createVectorizePass(); - break; - case XLALower: - pass = createXLALowerPass(); - break; - } - + for (const auto *passInfo : *passList) { + std::unique_ptr pass(passInfo->createPass()); PassResult result = pass->runOnModule(module.get()); - delete pass; if (result) return OptFailure; @@ -468,6 +362,10 @@ int main(int argc, char **argv) { llvm::PrettyStackTraceProgram x(argc, argv); InitLLVM y(argc, argv); + // Parse pass names in main to ensure static initialization completed. + llvm::cl::list passList( + "", llvm::cl::desc("Compiler passes to run")); + ::passList = &passList; cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n"); // Set up the input file. -- cgit v1.2.3 From cc9a6ed09ddf75ad3964af5344b6fb74729d7a19 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 7 Nov 2018 10:24:03 -0800 Subject: Initialize Pass with PassID. The passID is not currently stored in Pass but this avoids the unused variable warning. The passID is used to uniquely identify passes, currently this is only stored/used in PassInfo. 
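A standalone sketch of the idiom (PassA/PassB are hypothetical; the same address-of-a-static-char trick is used by LLVM's legacy pass manager):

#include <cassert>

// Illustration of why '&SomePass::passID' works as a unique ID: each class
// has its own static char object, and distinct objects have distinct
// addresses, so the address identifies the pass type without RTTI.
struct PassA { static char passID; };
struct PassB { static char passID; };
char PassA::passID = 0;
char PassB::passID = 0;

int main() {
  const void *idA = &PassA::passID;
  const void *idB = &PassB::passID;
  assert(idA != idB); // distinct pass classes yield distinct identifiers
  return 0;
}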
PiperOrigin-RevId: 220485662 --- mlir/include/mlir/Pass.h | 5 +++++ mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 3 ++- mlir/lib/Transforms/CFGFunctionViewGraph.cpp | 3 ++- mlir/lib/Transforms/Canonicalizer.cpp | 1 + mlir/lib/Transforms/ComposeAffineMaps.cpp | 2 +- mlir/lib/Transforms/ConstantFold.cpp | 2 ++ mlir/lib/Transforms/ConvertToCFG.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 1 + mlir/lib/Transforms/LoopUnroll.cpp | 3 ++- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 3 ++- mlir/lib/Transforms/PipelineDataTransfer.cpp | 1 + mlir/lib/Transforms/SimplifyAffineExpr.cpp | 3 ++- mlir/lib/Transforms/Vectorize.cpp | 2 ++ 15 files changed, 26 insertions(+), 9 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Pass.h b/mlir/include/mlir/Pass.h index d1610bf4d08..f9802ac9250 100644 --- a/mlir/include/mlir/Pass.h +++ b/mlir/include/mlir/Pass.h @@ -40,6 +40,7 @@ struct LLVM_NODISCARD PassResult { class Pass { public: + explicit Pass(const void *passID) {} virtual ~Pass() = default; virtual PassResult runOnModule(Module *m) = 0; @@ -54,6 +55,8 @@ private: class ModulePass : public Pass { public: + explicit ModulePass(const void *passID) : Pass(passID) {} + virtual PassResult runOnModule(Module *m) override = 0; private: @@ -69,6 +72,8 @@ private: /// module. class FunctionPass : public Pass { public: + explicit FunctionPass(const void *passID) : Pass(passID) {} + /// Implement this function to be run on every function in the module. If you /// do not implement this, the default implementation will dispatch to /// runOnCFGFunction or runOnMLFunction. diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index a7f0ebf4936..eb4ce56d429 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -38,7 +38,7 @@ namespace { /// Checks for out of bound memef access subscripts.. struct MemRefBoundCheck : public FunctionPass, StmtWalker { - explicit MemRefBoundCheck() {} + explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} PassResult runOnMLFunction(MLFunction *f) override; // Not applicable to CFG functions. diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 7a620c1a3a8..28a80762b94 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -40,7 +40,8 @@ namespace { struct MemRefDependenceCheck : public FunctionPass, StmtWalker { SmallVector loadsAndStores; - explicit MemRefDependenceCheck() {} + explicit MemRefDependenceCheck() + : FunctionPass(&MemRefDependenceCheck::passID) {} PassResult runOnMLFunction(MLFunction *f) override; // Not applicable to CFG functions. 
diff --git a/mlir/lib/Transforms/CFGFunctionViewGraph.cpp b/mlir/lib/Transforms/CFGFunctionViewGraph.cpp index 810264cb35e..d29708d4fef 100644 --- a/mlir/lib/Transforms/CFGFunctionViewGraph.cpp +++ b/mlir/lib/Transforms/CFGFunctionViewGraph.cpp @@ -76,7 +76,8 @@ namespace { struct PrintCFGPass : public FunctionPass { PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, const llvm::Twine &title = "") - : os(os), shortNames(shortNames), title(title) {} + : FunctionPass(&PrintCFGPass::passID), os(os), shortNames(shortNames), + title(title) {} PassResult runOnCFGFunction(CFGFunction *function) override { mlir::writeGraph(os, function, shortNames, title); return success(); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 3a62f132459..e6f3a2b0550 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -34,6 +34,7 @@ namespace { /// Canonicalize operations in functions. struct Canonicalizer : public FunctionPass { + Canonicalizer() : FunctionPass(&Canonicalizer::passID) {} PassResult runOnFunction(Function *fn) override; static char passID; diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 61e2e8f2e83..84507b91703 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -44,7 +44,7 @@ namespace { struct ComposeAffineMaps : public FunctionPass, StmtWalker { std::vector affineApplyOpsToErase; - explicit ComposeAffineMaps() {} + explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} using StmtListType = llvm::iplist; void walk(StmtListType::iterator Start, StmtListType::iterator End); void visitOperationStmt(OperationStmt *stmt); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 9005c2bbf48..15a5db15d73 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -27,6 +27,8 @@ using namespace mlir; namespace { /// Simple constant folding pass. struct ConstantFold : public FunctionPass, StmtWalker { + ConstantFold() : FunctionPass(&ConstantFold::passID) {} + // All constants in the function post folding. SmallVector existingConstants; // Operation statements that were folded and that need to be erased. diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index b36717d272f..8ffa23ca102 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -66,7 +66,7 @@ namespace { // ModuleConverter class does CFG conversion for the whole module. class ModuleConverter : public ModulePass { public: - explicit ModuleConverter() {} + explicit ModuleConverter() : ModulePass(&ModuleConverter::passID) {} PassResult runOnModule(Module *m) override; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index ae4647e143d..87657aeb359 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -42,7 +42,7 @@ namespace { // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. 
struct LoopFusion : public FunctionPass { - LoopFusion() {} + LoopFusion() : FunctionPass(&LoopFusion::passID) {} PassResult runOnMLFunction(MLFunction *f) override; static char passID; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 3bff008942c..8efadfc44bb 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -40,6 +40,7 @@ namespace { /// A pass to perform loop tiling on all suitable loop nests of an MLFunction. struct LoopTiling : public FunctionPass { + LoopTiling() : FunctionPass(&LoopTiling::passID) {} PassResult runOnMLFunction(MLFunction *f) override; constexpr static unsigned kDefaultTileSize = 32; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index ae09098f9d5..76e484c6a10 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -58,7 +58,8 @@ struct LoopUnroll : public FunctionPass { explicit LoopUnroll(Optional unrollFactor = None, Optional unrollFull = None) - : unrollFactor(unrollFactor), unrollFull(unrollFull) {} + : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), + unrollFull(unrollFull) {} PassResult runOnMLFunction(MLFunction *f) override; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index ce6e939fae8..45ca9dd98df 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -71,7 +71,8 @@ struct LoopUnrollAndJam : public FunctionPass { static const unsigned kDefaultUnrollJamFactor = 4; explicit LoopUnrollAndJam(Optional unrollJamFactor = None) - : unrollJamFactor(unrollJamFactor) {} + : FunctionPass(&LoopUnrollAndJam::passID), + unrollJamFactor(unrollJamFactor) {} PassResult runOnMLFunction(MLFunction *f) override; bool runOnForStmt(ForStmt *forStmt); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 52052e09d7b..c3f131f407f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -41,6 +41,7 @@ namespace { struct PipelineDataTransfer : public FunctionPass, StmtWalker { + PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnMLFunction(MLFunction *f) override; PassResult runOnForStmt(ForStmt *forStmt); diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 92d585f31bc..06f3f8f44e2 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -38,7 +38,8 @@ namespace { // ML functions and CFG functions. struct SimplifyAffineStructures : public FunctionPass, StmtWalker { - explicit SimplifyAffineStructures() {} + explicit SimplifyAffineStructures() + : FunctionPass(&SimplifyAffineStructures::passID) {} PassResult runOnMLFunction(MLFunction *f) override; // Does nothing on CFG functions for now. No reusable walkers/visitors exist diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 63969af451f..90ee87563de 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -195,6 +195,8 @@ static std::vector makePatterns() { namespace { struct Vectorize : public FunctionPass { + Vectorize() : FunctionPass(&Vectorize::passID) {} + PassResult runOnMLFunction(MLFunction *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. 
-- cgit v1.2.3 From b5424dd0cb3245e8151489ac6b2d02b45391470a Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 9 Nov 2018 09:42:24 -0800 Subject: Adds support for returning the direction of the dependence between memref accesses (distance/direction vectors). Updates MemRefDependenceCheck to check and report on all memref access pairs at all loop nest depths. Updates old and adds new memref dependence check tests. Resolves multiple TODOs. PiperOrigin-RevId: 220816515 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 30 +- mlir/lib/Analysis/AffineAnalysis.cpp | 562 ++++++++++++++++------ mlir/lib/Analysis/MemRefDependenceCheck.cpp | 91 +++- mlir/lib/Transforms/LoopFusion.cpp | 8 +- mlir/test/Transforms/memref-dependence-check.mlir | 403 ++++++++++++---- 5 files changed, 817 insertions(+), 277 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index a3fe813c548..64859f7a63b 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -72,11 +72,6 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, bool addIndexSet(llvm::ArrayRef indices, FlatAffineConstraints *domain); -/// Checks whether two accesses to the same memref access the same element. -/// Each access is specified using the MemRefAccess structure, which contains -/// the operation statement, indices and memref associated with the access. -/// Returns 'false' if it can be determined conclusively that the accesses do -/// not access the same memref element. Returns 'true' otherwise. struct MemRefAccess { const MLValue *memref; const OperationStmt *opStmt; @@ -85,9 +80,30 @@ struct MemRefAccess { // 'indices'. void getAccessMap(AffineValueMap *accessMap) const; }; -bool checkMemrefAccessDependence(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess); +// DependenceComponent contains state about the direction of a dependence as an +// interval [lb, ub]. +// Distance vectors components are represented by the interval [lb, ub] with +// lb == ub. +// Direction vectors components are represented by the interval [lb, ub] with +// lb < ub. Note that ub/lb == None means unbounded. +struct DependenceComponent { + // The lower bound of the dependence distance. + llvm::Optional lb; + // The upper bound of the dependence distance (inclusive). + llvm::Optional ub; + DependenceComponent() : lb(llvm::None), ub(llvm::None) {} +}; + +/// Checks whether two accesses to the same memref access the same element. +/// Each access is specified using the MemRefAccess structure, which contains +/// the operation statement, indices and memref associated with the access. +/// Returns 'false' if it can be determined conclusively that the accesses do +/// not access the same memref element. Returns 'true' otherwise. 
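As an illustration of what the returned dependence components encode, here is a simplified, MLIR-independent sketch: the interval of (dst iteration - src iteration) per loop is recovered by brute force over a small constant-bound nest, whereas the real implementation derives it from the dependence polyhedron.

#include <algorithm>
#include <cstdio>
#include <optional>

// For the nest
//   for i in [1, 9]:
//     for j in [0, 8]:
//       A[i][j] = A[i - 1][j + 1]
// every element written at iteration (i, j) is read again at (i+1, j-1),
// so the dependence distance vector is (1, -1), i.e. lb == ub per loop.
struct Component { std::optional<long> lb, ub; };

static void update(Component &c, long d) {
  c.lb = c.lb ? std::min(*c.lb, d) : d;
  c.ub = c.ub ? std::max(*c.ub, d) : d;
}

int main() {
  Component comp[2];
  for (int si = 1; si <= 9; ++si)
    for (int sj = 0; sj <= 8; ++sj)        // store to A[si][sj]
      for (int di = 1; di <= 9; ++di)
        for (int dj = 0; dj <= 8; ++dj)    // load from A[di - 1][dj + 1]
          if (si == di - 1 && sj == dj + 1) {
            update(comp[0], di - si);      // distance along i
            update(comp[1], dj - sj);      // distance along j
          }
  std::printf("i: [%ld, %ld]  j: [%ld, %ld]\n",
              *comp[0].lb, *comp[0].ub, *comp[1].lb, *comp[1].ub);
  // Prints i: [1, 1]  j: [-1, -1], i.e. distance vector (1, -1).
}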
+bool checkMemrefAccessDependence( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned loopDepth, + llvm::SmallVector *dependenceComponents); } // end namespace mlir #endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index a668b8eb2e6..5539cbc2c87 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -22,10 +22,12 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Statements.h" #include "mlir/StandardOps/StandardOps.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/raw_ostream.h" @@ -441,6 +443,9 @@ bool mlir::addIndexSet(ArrayRef indices, // IterationDomainContext encapsulates the state required to represent // the iteration domain of an OperationStmt. +// TODO(andydavis) Move this into FlatAffineConstraints when we have shared +// code to manage the operand values and positions to use FlatAffineConstraints +// and AffineValueMap. struct IterationDomainContext { // Set of inequality constraint pairs, where each pair represents the // upper/lower bounds of a ForStmt in the iteration domain. @@ -452,8 +457,8 @@ struct IterationDomainContext { // [numDims, values.size()) representing symbol identifiers. SmallVector values; IterationDomainContext() : numDims(0) {} - unsigned getNumDims() { return numDims; } - unsigned getNumSymbols() { return values.size() - numDims; } + unsigned getNumDims() const { return numDims; } + unsigned getNumSymbols() const { return values.size() - numDims; } }; // Computes the iteration domain for 'opStmt' and populates 'ctx', which @@ -465,11 +470,8 @@ struct IterationDomainContext { // TODO(andydavis) Handle non-constant loop bounds by composing affine maps // for each ForStmt loop bound and adding de-duped ids/symbols to iteration // domain context. -// TODO(andydavis) Capture the context of the symbols. For example, check -// if a symbol is the result of a constant operation, and set the symbol to -// that value in FlatAffineConstraints (using setIdToConstant). -bool getIterationDomainContext(const Statement *stmt, - IterationDomainContext *ctx) { +static bool getIterationDomainContext(const Statement *stmt, + IterationDomainContext *ctx) { // Walk up tree storing parent statements in 'loops'. // TODO(andydavis) Extend this to gather enclosing IfStmts and consider // factoring it out into a utility function. @@ -505,73 +507,121 @@ bool getIterationDomainContext(const Statement *stmt, return addIndexSet(ctx->values, &ctx->domain); } +// ValuePositionMap manages the mapping from MLValues which represent dimension +// and symbol identifiers from 'src' and 'dst' access functions to positions +// in new space where some MLValues are kept separate (using addSrc/DstValue) +// and some MLValues are merged (addSymbolValue). +// Position lookups return the absolute position in the new space which +// has the following format: +// +// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifers] +// +// Note: access function non-IV dimension identifiers (that have 'dimension' +// positions in the access function position space) are assigned as symbols +// in the output position space. Convienience access functions which lookup +// an MLValue in multiple maps are provided (i.e. 
getSrcDimOrSymPos) to handle +// the common case of resolving positions for all access function operands. +// +// TODO(andydavis) Generalize this: could take a template parameter for +// the number of maps (3 in the current case), and lookups could take indices +// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". +class ValuePositionMap { +public: + void addSrcValue(const MLValue *value) { + if (addValueAt(value, &srcDimPosMap, numSrcDims)) + ++numSrcDims; + } + void addDstValue(const MLValue *value) { + if (addValueAt(value, &dstDimPosMap, numDstDims)) + ++numDstDims; + } + void addSymbolValue(const MLValue *value) { + if (addValueAt(value, &symbolPosMap, numSymbols)) + ++numSymbols; + } + unsigned getSrcDimOrSymPos(const MLValue *value) const { + return getDimOrSymPos(value, srcDimPosMap, 0); + } + unsigned getDstDimOrSymPos(const MLValue *value) const { + return getDimOrSymPos(value, dstDimPosMap, numSrcDims); + } + unsigned getSymPos(const MLValue *value) const { + auto it = symbolPosMap.find(value); + assert(it != symbolPosMap.end()); + return numSrcDims + numDstDims + it->second; + } + + unsigned getNumSrcDims() const { return numSrcDims; } + unsigned getNumDstDims() const { return numDstDims; } + unsigned getNumDims() const { return numSrcDims + numDstDims; } + unsigned getNumSymbols() const { return numSymbols; } + +private: + bool addValueAt(const MLValue *value, + DenseMap *posMap, + unsigned position) { + auto it = posMap->find(value); + if (it == posMap->end()) { + (*posMap)[value] = position; + return true; + } + return false; + } + unsigned getDimOrSymPos(const MLValue *value, + const DenseMap &dimPosMap, + unsigned dimPosOffset) const { + auto it = dimPosMap.find(value); + if (it != dimPosMap.end()) { + return dimPosOffset + it->second; + } + it = symbolPosMap.find(value); + assert(it != symbolPosMap.end()); + return numSrcDims + numDstDims + it->second; + } + + unsigned numSrcDims = 0; + unsigned numDstDims = 0; + unsigned numSymbols = 0; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; +}; + // Builds a map from MLValue to identifier position in a new merged identifier // list, which is the result of merging dim/symbol lists from src/dst // iteration domains. The format of the new merged list is as follows: // // [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers] // -// This method populates 'srcDimPosMap' and 'dstDimPosMap' with mappings from -// operand MLValues in 'srcAccessMap'/'dstAccessMap' to the position of these -// values in the merged list. -// In addition, this method populates 'symbolPosMap' with mappings from -// operand MLValues in both 'srcIterationDomainContext' and -// 'dstIterationDomainContext' to position of these values in the merged list. +// This method populates 'valuePosMap' with mappings from operand MLValues in +// 'srcAccessMap'/'dstAccessMap' (as well as those in +// 'srcIterationDomainContext'/'dstIterationDomainContext') to the position of +// these values in the merged list. 
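For example (hypothetical identifiers), with one src IV 'i', two dst IVs 'j0' and 'j1', and one symbol 'N', the merged column order is [i, j0, j1, N, const]; the access equality i - j0 - 1 = 0 becomes the row (1, -1, 0, 0, -1), and the src bound i <= N - 1 becomes the inequality row (-1, 0, 0, 1, -1) >= 0. A tiny sketch of building such rows:

#include <array>
#include <cstdio>

// One constraint row in the merged space used by the dependence system:
// [src dims | dst dims | symbols | constant]. Illustrative layout only.
enum Col { kI = 0, kJ0, kJ1, kN, kConst, kNumCols };
using Row = std::array<int, kNumCols>;

int main() {
  Row accessEq = {};                 // i - j0 - 1 == 0 (equated access fns)
  accessEq[kI] = 1; accessEq[kJ0] = -1; accessEq[kConst] = -1;

  Row srcUpperBound = {};            // N - 1 - i >= 0 (src loop upper bound)
  srcUpperBound[kI] = -1; srcUpperBound[kN] = 1; srcUpperBound[kConst] = -1;

  for (int c : accessEq) std::printf("%2d ", c);
  std::printf("(== 0)\n");
  for (int c : srcUpperBound) std::printf("%2d ", c);
  std::printf("(>= 0)\n");
}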
static void buildDimAndSymbolPositionMaps( const IterationDomainContext &srcIterationDomainContext, const IterationDomainContext &dstIterationDomainContext, const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, - DenseMap *srcDimPosMap, - DenseMap *dstDimPosMap, - DenseMap *symbolPosMap) { - unsigned pos = 0; - - auto updatePosMap = [&](DenseMap *posMap, - ArrayRef values, unsigned start, - unsigned limit) { - for (unsigned i = start; i < limit; ++i) { + ValuePositionMap *valuePosMap) { + auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { + for (unsigned i = 0, e = values.size(); i < e; ++i) { auto *value = values[i]; - auto it = posMap->find(value); - if (it == posMap->end()) { - (*posMap)[value] = pos++; - } + if (!isa(values[i])) + valuePosMap->addSymbolValue(value); + else if (isSrc) + valuePosMap->addSrcValue(value); + else + valuePosMap->addDstValue(value); } }; - AffineMap srcMap = srcAccessMap.getAffineMap(); - AffineMap dstMap = dstAccessMap.getAffineMap(); - - // Update position map with src dimension identifiers from iteration domain - // and access function. - updatePosMap(srcDimPosMap, srcIterationDomainContext.values, 0, - srcIterationDomainContext.numDims); - // Update position map with 'srcAccessMap' operands not in iteration domain. - updatePosMap(srcDimPosMap, srcAccessMap.getOperands(), 0, - srcMap.getNumDims()); - - // Update position map with dst dimension identifiers from iteration domain - // and access function. - updatePosMap(dstDimPosMap, dstIterationDomainContext.values, 0, - dstIterationDomainContext.numDims); - // Update position map with 'dstAccessMap' operands not in iteration domain. - updatePosMap(dstDimPosMap, dstAccessMap.getOperands(), 0, - dstMap.getNumDims()); - - // Update position map with src symbol identifiers from iteration domain - // and access function. - updatePosMap(symbolPosMap, srcIterationDomainContext.values, - dstIterationDomainContext.numDims, - srcIterationDomainContext.values.size()); - updatePosMap(symbolPosMap, srcAccessMap.getOperands(), srcMap.getNumDims(), - srcMap.getNumDims() + srcMap.getNumSymbols()); - - // Update position map with dst symbol identifiers from iteration domain - // and access function. - updatePosMap(symbolPosMap, dstIterationDomainContext.values, - dstIterationDomainContext.numDims, - dstIterationDomainContext.values.size()); - updatePosMap(symbolPosMap, dstAccessMap.getOperands(), dstMap.getNumDims(), - dstMap.getNumDims() + dstMap.getNumSymbols()); + // Update value position map with identifiers from src iteration domain. + updateValuePosMap(srcIterationDomainContext.values, /*isSrc=*/true); + // Update value position map with identifiers from dst iteration domain. + updateValuePosMap(dstIterationDomainContext.values, /*isSrc=*/false); + // Update value position map with identifiers from src access function. + updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true); + // Update value position map with identifiers from dst access function. + updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); } static unsigned getPos(const DenseMap &posMap, @@ -581,41 +631,55 @@ static unsigned getPos(const DenseMap &posMap, return it->second; } -// Adds iteration domain constraints from 'ctx.domain' into 'dependenceDomain'. -// Uses 'dimPosMap' to map from dim operand value in 'ctx.values', to dim -// position in 'dependenceDomain'. -// Uses 'symbolPosMap' to map from symbol operand value in 'ctx.values', to -// symbol position in 'dependenceDomain'. 
-static void -addDomainConstraints(const IterationDomainContext &ctx, - const DenseMap &dimPosMap, - const DenseMap &symbolPosMap, - FlatAffineConstraints *dependenceDomain) { - unsigned inputNumIneq = ctx.domain.getNumInequalities(); - unsigned inputNumDims = ctx.domain.getNumDimIds(); - unsigned inputNumSymbols = ctx.domain.getNumSymbolIds(); - unsigned inputNumIds = inputNumDims + inputNumSymbols; +// Adds iteration domain constraints from 'srcCtx' and 'dstCtx' into +// 'dependenceDomain'. +// Uses 'valuePosMap' to map from operand values in 'ctx.values' to position in +// 'dependenceDomain'. +static void addDomainConstraints(const IterationDomainContext &srcCtx, + const IterationDomainContext &dstCtx, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceDomain) { + unsigned srcNumIneq = srcCtx.domain.getNumInequalities(); + unsigned srcNumDims = srcCtx.domain.getNumDimIds(); + unsigned srcNumSymbols = srcCtx.domain.getNumSymbolIds(); + unsigned srcNumIds = srcNumDims + srcNumSymbols; + + unsigned dstNumIneq = dstCtx.domain.getNumInequalities(); + unsigned dstNumDims = dstCtx.domain.getNumDimIds(); + unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds(); + unsigned dstNumIds = dstNumDims + dstNumSymbols; unsigned outputNumDims = dependenceDomain->getNumDimIds(); unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds(); unsigned outputNumIds = outputNumDims + outputNumSymbols; - SmallVector eq; - eq.resize(outputNumIds + 1); - for (unsigned i = 0; i < inputNumIneq; ++i) { + SmallVector ineq; + ineq.resize(outputNumIds + 1); + // Add inequalities from src domain. + for (unsigned i = 0; i < srcNumIneq; ++i) { // Zero fill. - std::fill(eq.begin(), eq.end(), 0); - // Add dim identifiers. - for (unsigned j = 0; j < inputNumDims; ++j) - eq[getPos(dimPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j); - // Add symbol identifiers. - for (unsigned j = inputNumDims; j < inputNumIds; ++j) { - eq[getPos(symbolPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j); - } - // Add constant term. - eq[outputNumIds] = ctx.domain.atIneq(i, inputNumIds); + std::fill(ineq.begin(), ineq.end(), 0); + // Set coefficients for identifiers corresponding to src domain. + for (unsigned j = 0; j < srcNumIds; ++j) + ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] = + srcCtx.domain.atIneq(i, j); + // Set constant term. + ineq[outputNumIds] = srcCtx.domain.atIneq(i, srcNumIds); + // Add inequality constraint. + dependenceDomain->addInequality(ineq); + } + // Add inequalities from dst domain. + for (unsigned i = 0; i < dstNumIneq; ++i) { + // Zero fill. + std::fill(ineq.begin(), ineq.end(), 0); + // Set coefficients for identifiers corresponding to dst domain. + for (unsigned j = 0; j < dstNumIds; ++j) + ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] = + dstCtx.domain.atIneq(i, j); + // Set constant term. + ineq[outputNumIds] = dstCtx.domain.atIneq(i, dstNumIds); // Add inequality constraint. 
- dependenceDomain->addInequality(eq); + dependenceDomain->addInequality(ineq); } } @@ -640,12 +704,13 @@ addDomainConstraints(const IterationDomainContext &ctx, // a0 -c0 (a1 - c1) (a1 - c2) = 0 // b0 -f0 (b1 - f1) (b1 - f2) = 0 // -bool addMemRefAccessConstraints( - const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, - const DenseMap &srcDimPosMap, - const DenseMap &dstDimPosMap, - const DenseMap &symbolPosMap, - FlatAffineConstraints *dependenceDomain) { +// Returns false if any AffineExpr cannot be flattened (which will be removed +// when mod/floor/ceil support is added). Returns true otherwise. +static bool +addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, + const AffineValueMap &dstAccessMap, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceDomain) { AffineMap srcMap = srcAccessMap.getAffineMap(); AffineMap dstMap = dstAccessMap.getAffineMap(); assert(srcMap.getNumResults() == dstMap.getNumResults()); @@ -665,8 +730,7 @@ bool addMemRefAccessConstraints( unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds(); unsigned outputNumIds = outputNumDims + outputNumSymbols; - SmallVector eq; - eq.resize(outputNumIds + 1); + SmallVector eq(outputNumIds + 1); SmallVector flattenedExpr; for (unsigned i = 0; i < numResults; ++i) { // Zero fill. @@ -677,13 +741,10 @@ bool addMemRefAccessConstraints( if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols, &flattenedExpr)) return false; - // Add dim identifier coefficients from src access function. - for (unsigned j = 0, e = srcNumDims; j < e; ++j) - eq[getPos(srcDimPosMap, srcOperands[j])] = flattenedExpr[j]; - // Add symbol identifiers from src access function. - for (unsigned j = srcNumDims; j < srcNumIds; ++j) - eq[getPos(symbolPosMap, srcOperands[j])] = flattenedExpr[j]; - // Add constant term. + // Set identifier coefficients from src access function. + for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) + eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = flattenedExpr[j]; + // Set constant term. eq[outputNumIds] = flattenedExpr[srcNumIds]; // Get flattened AffineExpr for result 'i' from dst access function. @@ -692,49 +753,218 @@ bool addMemRefAccessConstraints( if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols, &flattenedExpr)) return false; - // Add dim identifier coefficients from dst access function. - for (unsigned j = 0, e = dstNumDims; j < e; ++j) - eq[getPos(dstDimPosMap, dstOperands[j])] = -flattenedExpr[j]; - // Add symbol identifiers from dst access function. - for (unsigned j = dstNumDims; j < dstNumIds; ++j) - eq[getPos(symbolPosMap, dstOperands[j])] -= flattenedExpr[j]; - // Add constant term. + // Set identifier coefficients from dst access function. + for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) + eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= flattenedExpr[j]; + // Set constant term. eq[outputNumIds] -= flattenedExpr[dstNumIds]; // Add equality constraint. dependenceDomain->addEquality(eq); } // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = - [&](const DenseMap &posMap, - ArrayRef operands, unsigned start, unsigned limit) { - for (unsigned i = start; i < limit; ++i) { - if (isa(operands[i])) - continue; - auto *symbol = operands[i]; - assert(symbol->isValidSymbol()); - // Check if the symbols is a constant. 
- if (auto *opStmt = symbol->getDefiningStmt()) { - if (auto constOp = opStmt->dyn_cast()) { - dependenceDomain->setIdToConstant(getPos(posMap, symbol), - constOp->getValue()); - } - } + auto addEqForConstOperands = [&](ArrayRef operands) { + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (isa(operands[i])) + continue; + auto *symbol = operands[i]; + assert(symbol->isValidSymbol()); + // Check if the symbol is a constant. + if (auto *opStmt = symbol->getDefiningStmt()) { + if (auto constOp = opStmt->dyn_cast()) { + dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), + constOp->getValue()); } - }; + } + } + }; - // Add equality constraints for any src dims defined by constant ops. - addEqForConstOperands(srcDimPosMap, srcOperands, 0, srcNumDims); // Add equality constraints for any src symbols defined by constant ops. - addEqForConstOperands(symbolPosMap, srcOperands, srcNumDims, srcNumIds); - // Add equality constraints for any dst dims defined by constant ops. - addEqForConstOperands(dstDimPosMap, dstOperands, 0, dstNumDims); + addEqForConstOperands(srcOperands); // Add equality constraints for any dst symbols defined by constant ops. - addEqForConstOperands(symbolPosMap, dstOperands, dstNumDims, dstNumIds); + addEqForConstOperands(dstOperands); + return true; +} + +// Returns the number of outer loop common to 'src/dstIterationDomainContext'. +static unsigned +getNumCommonLoops(const IterationDomainContext &srcIterationDomainContext, + const IterationDomainContext &dstIterationDomainContext) { + // Find the number of common loops shared by src and dst accesses. + unsigned minNumLoops = std::min(srcIterationDomainContext.getNumDims(), + dstIterationDomainContext.getNumDims()); + unsigned numCommonLoops = 0; + for (unsigned i = 0; i < minNumLoops; ++i) { + if (!isa(srcIterationDomainContext.values[i]) || + !isa(dstIterationDomainContext.values[i]) || + srcIterationDomainContext.values[i] != + dstIterationDomainContext.values[i]) + break; + ++numCommonLoops; + } + return numCommonLoops; +} + +// Returns true if the operation statement in 'srcAccess' properly dominates +// the operation statement in 'dstAccess'. Returns false otherwise. +// Note that 'numCommonLoops' is the number of contiguous surrounding outer +// loops. +static bool +srcHappensBeforeDst(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + const IterationDomainContext &srcIterationDomainContext, + unsigned numCommonLoops) { + if (numCommonLoops == 0) { + return mlir::properlyDominates(*srcAccess.opStmt, *dstAccess.opStmt); + } + auto *commonForValue = srcIterationDomainContext.values[numCommonLoops - 1]; + assert(isa(commonForValue)); + auto *commonForStmt = dyn_cast(commonForValue); + // Check the dominance relationship between the respective ancestors of the + // src and dst in the StmtBlock of the innermost among the common loops. + auto *srcStmt = commonForStmt->findAncestorStmtInBlock(*srcAccess.opStmt); + assert(srcStmt != nullptr); + auto *dstStmt = commonForStmt->findAncestorStmtInBlock(*dstAccess.opStmt); + assert(dstStmt != nullptr); + return mlir::properlyDominates(*srcStmt, *dstStmt); +} + +// Adds ordering constraints to 'dependenceDomain' based on number of loops +// common to 'src/dstIterationDomainContext' and requested 'loopDepth'. +// Note that 'loopDepth' cannot exceed the number of common loops plus one. 
+// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
+// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
+// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' >= j + 1
+// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
+static void
+addOrderingConstraints(const IterationDomainContext &srcIterationDomainContext,
+ const IterationDomainContext &dstIterationDomainContext,
+ const ValuePositionMap &valuePosMap, unsigned loopDepth,
+ FlatAffineConstraints *dependenceDomain) {
+ unsigned numCols = dependenceDomain->getNumCols();
+ SmallVector eq(numCols);
+ unsigned numSrcDims = valuePosMap.getNumSrcDims();
+ unsigned numCommonLoops =
+ getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext);
+ unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
+ for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
+ std::fill(eq.begin(), eq.end(), 0);
+ eq[i] = -1;
+ eq[i + numSrcDims] = 1;
+ if (i == loopDepth - 1) {
+ eq[numCols - 1] = -1;
+ dependenceDomain->addInequality(eq);
+ } else {
+ dependenceDomain->addEquality(eq);
+ }
+ }
+}
+// Returns true if 'isEq' constraint in 'dependenceDomain' has a single
+// non-zero coefficient at (rowIdx, idPos). Returns false otherwise.
+// TODO(andydavis) Move this function to FlatAffineConstraints.
+static bool hasSingleNonZeroAt(unsigned idPos, unsigned rowIdx, bool isEq,
+ FlatAffineConstraints *dependenceDomain) {
+ unsigned numCols = dependenceDomain->getNumCols();
+ for (unsigned j = 0; j < numCols - 1; ++j) {
+ int64_t v = isEq ? dependenceDomain->atEq(rowIdx, j)
+ : dependenceDomain->atIneq(rowIdx, j);
+ if ((j == idPos && v == 0) || (j != idPos && v != 0)) + return false; + } return true; }
+// Computes distance and direction vectors in 'dependenceComponents', by adding
+// variables to 'dependenceDomain' which represent the difference of the IVs,
+// eliminating all other variables, and reading off distance vectors from
+// equality constraints (if possible), and direction vectors from inequalities.
+static void computeDirectionVector(
+ const IterationDomainContext &srcIterationDomainContext,
+ const IterationDomainContext &dstIterationDomainContext, unsigned loopDepth,
+ FlatAffineConstraints *dependenceDomain,
+ llvm::SmallVector *dependenceComponents) {
+ // Find the number of common loops shared by src and dst accesses.
+ unsigned numCommonLoops =
+ getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext);
+ if (numCommonLoops == 0)
+ return;
+ // Compute direction vectors for requested loop depth.
+ unsigned numIdsToEliminate = dependenceDomain->getNumIds();
+ // Add new variables to 'dependenceDomain' to represent the direction
+ // constraints for each shared loop.
+ for (unsigned j = 0; j < numCommonLoops; ++j) {
+ dependenceDomain->addDimId(j);
+ }
+
+ // Add equality constraints for each common loop, setting the newly introduced
+ // variable at column 'j' to the 'dst' IV minus the 'src' IV.
+ SmallVector eq;
+ eq.resize(dependenceDomain->getNumCols());
+ for (unsigned j = 0; j < numCommonLoops; ++j) {
+ std::fill(eq.begin(), eq.end(), 0);
+ eq[j] = 1;
+ eq[j + numCommonLoops] = 1;
+ eq[j + 2 * numCommonLoops] = -1;
+ dependenceDomain->addEquality(eq);
+ }
+
+ // Eliminate all variables other than the direction variables just added.
+ dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);
+
+ // Scan each common loop variable column and add direction vectors based
+ // on eliminated constraint system.
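// For example (illustrative numbers): if, after projection, the direction
// variable z0 of the outermost common loop appears in the equality
// z0 - 2 = 0 (coefficient 1 in column 0, constant column -2), the equality
// scan below records the exact distance lb = ub = 2. If instead only the
// inequalities z0 - 1 >= 0 and -z0 + 9 >= 0 remain, the inequality scan
// records the direction range lb = 1, ub = 9.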
+ unsigned numCols = dependenceDomain->getNumCols(); + dependenceComponents->reserve(numCommonLoops); + for (unsigned j = 0; j < numCommonLoops; ++j) { + DependenceComponent depComp; + for (unsigned i = 0, e = dependenceDomain->getNumEqualities(); i < e; ++i) { + // Check for equality constraint with single non-zero in column 'j'. + if (!hasSingleNonZeroAt(j, i, /*isEq=*/true, dependenceDomain)) + continue; + // Get direction variable coefficient at (i, j). + int64_t d = dependenceDomain->atEq(i, j); + // Get constant coefficient at (i, numCols - 1). + int64_t c = -dependenceDomain->atEq(i, numCols - 1); + assert(c % d == 0 && "No dependence should have existed"); + depComp.lb = depComp.ub = c / d; + dependenceComponents->push_back(depComp); + break; + } + // Skip checking inequalities if we set 'depComp' based on equalities. + if (depComp.lb.hasValue() || depComp.ub.hasValue()) + continue; + // TODO(andydavis) Call FlatAffineConstraints::getConstantLower/UpperBound + // Check inequalities to track direction range for each 'j'. + for (unsigned i = 0, e = dependenceDomain->getNumInequalities(); i < e; + ++i) { + // Check for inequality constraint with single non-zero in column 'j'. + if (!hasSingleNonZeroAt(j, i, /*isEq=*/false, dependenceDomain)) + continue; + // Get direction variable coefficient at (i, j). + int64_t d = dependenceDomain->atIneq(i, j); + // Get constant coefficient at (i, numCols - 1). + int64_t c = dependenceDomain->atIneq(i, numCols - 1); + if (d < 0) { + // Upper bound: add tightest upper bound. + auto ub = mlir::floorDiv(c, -d); + if (!depComp.ub.hasValue() || ub < depComp.ub.getValue()) + depComp.ub = ub; + } else { + // Lower bound: add tightest lower bound. + auto lb = mlir::ceilDiv(-c, d); + if (!depComp.lb.hasValue() || lb > depComp.lb.getValue()) + depComp.lb = lb; + } + } + if (depComp.lb.hasValue() || depComp.ub.hasValue()) { + if (depComp.lb.hasValue() && depComp.ub.hasValue()) + assert(depComp.lb.getValue() <= depComp.ub.getValue()); + dependenceComponents->push_back(depComp); + } + } +} + // Populates 'accessMap' with composition of AffineApplyOps reachable from // indices of MemRefAccess. void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { @@ -752,6 +982,9 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // between memref accesses 'srcAccess' and 'dstAccess'. // Returns 'false' if the accesses can be definitively shown not to access the // same element. Returns 'true' otherwise. +// If a dependence exists, returns in 'dependenceComponents' a direction +// vector for the dependence, with a component for each loop IV in loops +// common to both accesses (see Dependence in AffineAnalysis.h for details). // // The memref access dependence check is comprised of the following steps: // *) Compute access functions for each access. Access functions are computed @@ -830,13 +1063,10 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // // // TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv. -// TODO(andydavis) Add precedence order constraints for accesses that -// share a common loop. -// TODO(andydavis) Add support for returning the direction of the dependence. -// For example, this function may return that there is a dependence between -// 'srcAccess' and 'dstAccess' but the dependence may be from dst to src. 
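// A minimal caller-side sketch of the extended interface below (the names and
// the SmallVector element count are illustrative; see MemRefDependenceCheck.cpp
// for the real driver):
//   for (unsigned d = 1; d <= numCommonSurroundingLoops + 1; ++d) {
//     llvm::SmallVector<DependenceComponent, 2> components;
//     if (checkMemrefAccessDependence(srcAccess, dstAccess, d, &components))
//       // A dependence exists at depth 'd'; 'components' holds one [lb, ub]
//       // range per loop IV common to both accesses.
//       recordDependence(d, components);
//   }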
-bool mlir::checkMemrefAccessDependence(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess) { +bool mlir::checkMemrefAccessDependence( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned loopDepth, + llvm::SmallVector *dependenceComponents) { // Return 'false' if these accesses do not acces the same memref. if (srcAccess.memref != dstAccess.memref) return false; @@ -862,23 +1092,32 @@ bool mlir::checkMemrefAccessDependence(const MemRefAccess &srcAccess, if (!getIterationDomainContext(dstAccess.opStmt, &dstIterationDomainContext)) return false; + // Return if loopDepth > numCommonLoops and 'srcAccess' does not properly + // dominate 'dstAccess' (i.e. no execution path from src to dst access). + unsigned numCommonLoops = + getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext); + assert(loopDepth <= numCommonLoops + 1); + if (loopDepth > numCommonLoops && + !srcHappensBeforeDst(srcAccess, dstAccess, srcIterationDomainContext, + numCommonLoops)) { + return false; + } // Build dim and symbol position maps for each access from access operand // MLValue to position in merged contstraint system. - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; - buildDimAndSymbolPositionMaps( - srcIterationDomainContext, dstIterationDomainContext, srcAccessMap, - dstAccessMap, &srcDimPosMap, &dstDimPosMap, &symbolPosMap); + ValuePositionMap valuePosMap; + buildDimAndSymbolPositionMaps(srcIterationDomainContext, + dstIterationDomainContext, srcAccessMap, + dstAccessMap, &valuePosMap); - // TODO(andydavis) Add documentation. + // Calculate number of equalities/inequalities and columns required to + // initialize FlatAffineConstraints for 'dependenceDomain'. unsigned numIneq = srcIterationDomainContext.domain.getNumInequalities() + dstIterationDomainContext.domain.getNumInequalities(); AffineMap srcMap = srcAccessMap.getAffineMap(); assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); unsigned numEq = srcMap.getNumResults(); - unsigned numDims = srcDimPosMap.size() + dstDimPosMap.size(); - unsigned numSymbols = symbolPosMap.size(); + unsigned numDims = valuePosMap.getNumDims(); + unsigned numSymbols = valuePosMap.getNumSymbols(); unsigned numIds = numDims + numSymbols; unsigned numCols = numIds + 1; @@ -888,18 +1127,23 @@ bool mlir::checkMemrefAccessDependence(const MemRefAccess &srcAccess, // Create memref access constraint by equating src/dst access functions. // Note that this check is conservative, and will failure in the future // when local variables for mod/div exprs are supported. - if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, srcDimPosMap, - dstDimPosMap, symbolPosMap, + if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, &dependenceDomain)) return true; - // Add domain constraints for src access function. - addDomainConstraints(srcIterationDomainContext, srcDimPosMap, symbolPosMap, - &dependenceDomain); - // Add equality constraints from 'dstConstraints'. - addDomainConstraints(dstIterationDomainContext, dstDimPosMap, symbolPosMap, - &dependenceDomain); - bool isEmpty = dependenceDomain.isEmpty(); - // Return false if the solution space is empty. - return !isEmpty; + // Add 'src' happens before 'dst' ordering constraints. + addOrderingConstraints(srcIterationDomainContext, dstIterationDomainContext, + valuePosMap, loopDepth, &dependenceDomain); + // Add src and dst domain constraints. 
+ addDomainConstraints(srcIterationDomainContext, dstIterationDomainContext, + valuePosMap, &dependenceDomain); + + // Return false if the solution space is empty: no dependence. + if (dependenceDomain.isEmpty()) { + return false; + } + // Compute dependence direction vector and return true. + computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext, + loopDepth, &dependenceDomain, dependenceComponents); + return true; } diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 28a80762b94..fc63a41c848 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -90,40 +90,87 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt, } } +// Populates 'loops' with the loop nest surrounding 'stmt', ordered from +// outer-most ForStmt to inner-most. +static void getLoopNest(Statement *stmt, + SmallVector *loops) { + const auto *currStmt = stmt->getParentStmt(); + while (currStmt != nullptr && isa(currStmt)) { + loops->push_back(dyn_cast(currStmt)); + currStmt = currStmt->getParentStmt(); + } + std::reverse(loops->begin(), loops->end()); +} + +// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', +// where each lists loops from outer-most to inner-most in loop nest. +static unsigned getNumCommonSurroundingLoops(ArrayRef loopsA, + ArrayRef loopsB) { + unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); + unsigned numCommonLoops = 0; + for (unsigned i = 0; i < minNumLoops; ++i) { + if (loopsA[i] != loopsB[i]) + break; + ++numCommonLoops; + } + return numCommonLoops; +} + +// Returns a result string which represents the direction vector (if there was +// a dependence), returns the string "false" otherwise. +static string +getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, + ArrayRef dependenceComponents) { + if (!ret) + return "false"; + if (dependenceComponents.empty() || loopNestDepth > numCommonLoops) + return "true"; + string result; + for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) { + string lbStr = dependenceComponents[i].lb.hasValue() + ? std::to_string(dependenceComponents[i].lb.getValue()) + : "-inf"; + string ubStr = dependenceComponents[i].ub.hasValue() + ? std::to_string(dependenceComponents[i].ub.getValue()) + : "+inf"; + result += "[" + lbStr + ", " + ubStr + "]"; + } + return result; +} + // For each access in 'loadsAndStores', runs a depence check between this // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. -// TODO(andydavis) Clarify expected-note logs. In particular we may want to -// drop the 'i' from the note string, tag dependence destination accesses -// with a note with their 'j' index. In addition, we may want a schedme that -// first assigned unique ids to each access, then emits a note for each access -// with its id, and emits a note for each dependence check with a pair of ids. 
-// For example, given this code: -// -// memref_access0 -// // emit note: "this op is memref access 0' -// // emit note: "dependence from memref access 0 to access 1 = false" -// // emit note: "dependence from memref access 0 to access 2 = true" -// memref_access1 -// // emit note: "this op is memref access 1' -// // emit note: "dependence from memref access 1 to access 2 = false" -// memref_access2 -// // emit note: "this op is memref access 2' -// static void checkDependences(ArrayRef loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpStmt = loadsAndStores[i]; MemRefAccess srcAccess; getMemRefAccess(srcOpStmt, &srcAccess); - for (unsigned j = i + 1; j < e; ++j) { + SmallVector srcLoops; + getLoopNest(srcOpStmt, &srcLoops); + for (unsigned j = 0; j < e; ++j) { auto *dstOpStmt = loadsAndStores[j]; MemRefAccess dstAccess; getMemRefAccess(dstOpStmt, &dstAccess); - bool ret = checkMemrefAccessDependence(srcAccess, dstAccess); - srcOpStmt->emitNote("dependence from memref access " + Twine(i) + - " to access " + Twine(j) + " = " + - (ret ? "true" : "false")); + + SmallVector dstLoops; + getLoopNest(dstOpStmt, &dstLoops); + unsigned numCommonLoops = + getNumCommonSurroundingLoops(srcLoops, dstLoops); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + llvm::SmallVector dependenceComponents; + bool ret = checkMemrefAccessDependence(srcAccess, dstAccess, d, + &dependenceComponents); + // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print + // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance + // vectors from ([1, 1], [3, 3]) to (1, 3). + srcOpStmt->emitNote( + "dependence from " + Twine(i) + " to " + Twine(j) + " at depth " + + Twine(d) + " = " + + getDirectionVectorStr(ret, numCommonLoops, d, dependenceComponents) + .c_str()); + } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 87657aeb359..3db290fac03 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -233,11 +233,15 @@ PassResult LoopFusion::runOnMLFunction(MLFunction *f) { // TODO(andydavis) Add checks for fusion-preventing dependences and ordering // constraints which would prevent fusion. - // TODO(andydavis) This check if overly conservative for now. Support fusing + // TODO(andydavis) This check is overly conservative for now. Support fusing // statements with compatible dependences (i.e. statements where the // dependence between the statements does not reverse direction when the // statements are fused into the same loop). - if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB)) { + llvm::SmallVector dependenceComponents; + // TODO(andydavis) Check dependences at differnt loop nest depths. + if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB, + /*loopNestDepth=*/0, + &dependenceComponents)) { // Current conservatinve test policy: No dependence exists between accesses // in different loop nests -> fuse loops. fuseLoops(candidate); diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index ca6d9fb4143..26f2738c994 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -1,7 +1,5 @@ // RUN: mlir-opt %s -memref-dependence-check -split-input-file -verify | FileCheck %s -// TODO(andydavis) Add test cases for self-edges and a dependence cycle. 
- // ----- // CHECK-LABEL: mlfunc @different_memrefs() { mlfunc @different_memrefs() { @@ -10,8 +8,11 @@ mlfunc @different_memrefs() { %c0 = constant 0 : index %c1 = constant 1.0 : f32 store %c1, %m.a[%c0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} %v0 = load %m.b[%c0] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -23,8 +24,11 @@ mlfunc @store_load_different_elements() { %c1 = constant 1 : index %c7 = constant 7.0 : f32 store %c7, %m[%c0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} %v0 = load %m[%c1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -36,8 +40,11 @@ mlfunc @load_store_different_elements() { %c1 = constant 1 : index %c7 = constant 7.0 : f32 %v0 = load %m[%c1] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} store %c7, %m[%c0] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -48,20 +55,11 @@ mlfunc @store_load_same_element() { %c11 = constant 11 : index %c7 = constant 7.0 : f32 store %c7, %m[%c11] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = true}} %v0 = load %m[%c11] : memref<100xf32> - return -} - -// ----- -// CHECK-LABEL: mlfunc @load_store_same_element() { -mlfunc @load_store_same_element() { - %m = alloc() : memref<100xf32> - %c11 = constant 11 : index - %c7 = constant 7.0 : f32 - %v0 = load %m[%c11] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} - store %c7, %m[%c11] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -72,8 +70,11 @@ mlfunc @load_load_same_element() { %c11 = constant 11 : index %c7 = constant 7.0 : f32 %v0 = load %m[%c11] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} %v1 = load %m[%c11] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -83,8 +84,11 @@ mlfunc @store_load_same_symbol(%arg0 : index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 store %c7, %m[%arg0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = true}} %v0 = load %m[%arg0] : 
memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -94,8 +98,11 @@ mlfunc @store_load_different_symbols(%arg0 : index, %arg1 : index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 store %c7, %m[%arg0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = true}} %v0 = load %m[%arg1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -107,9 +114,12 @@ mlfunc @store_load_diff_element_affine_apply_const() { %c7 = constant 7.0 : f32 %a0 = affine_apply (d0) -> (d0) (%c1) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} %a1 = affine_apply (d0) -> (d0 + 1) (%c1) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -122,9 +132,12 @@ mlfunc @store_load_same_element_affine_apply_const() { %c11 = constant 11 : index %a0 = affine_apply (d0) -> (d0 + 1) (%c9) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = true}} %a1 = affine_apply (d0) -> (d0 - 1) (%c11) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -135,23 +148,28 @@ mlfunc @store_load_affine_apply_symbol(%arg0 : index) { %c7 = constant 7.0 : f32 %a0 = affine_apply (d0) -> (d0) (%arg0) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = true}} %a1 = affine_apply (d0) -> (d0) (%arg0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } // ----- -// Note: has single equality x - y - 1 = 0, which has solns for (1, 0) (0, -1) // CHECK-LABEL: mlfunc @store_load_affine_apply_symbol_offset(%arg0 : index) { mlfunc @store_load_affine_apply_symbol_offset(%arg0 : index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %a0 = affine_apply (d0) -> (d0) (%arg0) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} %a1 = affine_apply (d0) -> (d0 + 1) (%arg0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 1 at depth 1 = false}} return } @@ -164,9 +182,39 @@ mlfunc @store_range_load_after_range() { for %i0 = 0 to 10 { %a0 = affine_apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> - // 
expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine_apply (d0) -> (d0) (%c10) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} + } + return +} + +// ----- +// CHECK-LABEL: mlfunc @store_load_func_symbol(%arg0 : index) { +mlfunc @store_load_func_symbol(%arg0 : index) { + %m = alloc() : memref<100xf32> + %c7 = constant 7.0 : f32 + %c10 = constant 10 : index + for %i0 = 0 to 10 { + %a0 = affine_apply (d0) -> (d0) (%arg0) + store %c7, %m[%a0] : memref<100xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, 9]}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = true}} + %a1 = affine_apply (d0) -> (d0) (%arg0) + %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } @@ -179,10 +227,22 @@ mlfunc @store_range_load_last_in_range() { %c10 = constant 10 : index for %i0 = 0 to 10 { %a0 = affine_apply (d0) -> (d0) (%i0) + // For dependence from 0 to 1, we do not have a loop carried dependence + // because only the final write in the loop accesses the same element as the + // load, so this dependence appears only at depth 2 (loop independent). store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = true}} %a1 = affine_apply (d0) -> (d0 - 1) (%c10) + // For dependence from 1 to 0, we have write-after-read (WAR) dependences + // for all loads in the loop to the store on the last iteration. 
%v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } @@ -196,9 +256,16 @@ mlfunc @store_range_load_before_range() { for %i0 = 1 to 11 { %a0 = affine_apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine_apply (d0) -> (d0) (%c0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } @@ -211,127 +278,289 @@ mlfunc @store_range_load_first_in_range() { %c0 = constant 0 : index for %i0 = 1 to 11 { %a0 = affine_apply (d0) -> (d0) (%i0) + // Dependence from 0 to 1 at depth 1 is a range because all loads at + // constant index zero are reads after first store at index zero during + // first iteration of the loop. store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = true}} %a1 = affine_apply (d0) -> (d0 + 1) (%c0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } // ----- -// CHECK-LABEL: mlfunc @store_load_diff_ranges_diff_1d_loop_nests() { -mlfunc @store_load_diff_ranges_diff_1d_loop_nests() { +// CHECK-LABEL: mlfunc @store_plus_3() { +mlfunc @store_plus_3() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { - %a0 = affine_apply (d0) -> (d0) (%i0) + for %i0 = 1 to 11 { + %a0 = affine_apply (d0) -> (d0 + 3) (%i0) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} - } - for %i1 = 5 to 11 { - %a1 = affine_apply (d0) -> (d0) (%i1) + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = [3, 3]}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}} + %a1 = affine_apply (d0) -> (d0) (%i0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } // ----- -// CHECK-LABEL: mlfunc @store_load_overlapping_ranges_diff_1d_loop_nests() { 
-mlfunc @store_load_overlapping_ranges_diff_1d_loop_nests() { +// CHECK-LABEL: mlfunc @load_minus_2() { +mlfunc @load_minus_2() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { + for %i0 = 2 to 11 { %a0 = affine_apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} - } - for %i1 = 5 to 11 { - %a1 = affine_apply (d0) -> (d0 - 1) (%i1) + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = [2, 2]}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}} + %a1 = affine_apply (d0) -> (d0 - 2) (%i0) %v0 = load %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} } return } // ----- -// CHECK-LABEL: mlfunc @store_load_diff_inner_ranges_diff_2d_loop_nests() { -mlfunc @store_load_diff_inner_ranges_diff_2d_loop_nests() { +// CHECK-LABEL: mlfunc @perfectly_nested_loops_loop_independent() { +mlfunc @perfectly_nested_loops_loop_independent() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { - for %i1 = 0 to 5 { + for %i0 = 0 to 11 { + for %i1 = 0 to 11 { + // Dependence from access 0 to 1 is loop independent at depth = 3. %a0 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) - store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} - } - } - for %i2 = 0 to 5 { - for %i3 = 5 to 7 { - %a1 = affine_apply (d0, d1) -> (d0, d1) (%i2, %i3) + store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 0 to 1 at depth 3 = true}} + %a1 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 1 to 1 at depth 3 = false}} } } return } // ----- -// CHECK-LABEL: mlfunc @store_load_overlapping_inner_ranges_diff_2d_loop_nests() { -mlfunc @store_load_overlapping_inner_ranges_diff_2d_loop_nests() { +// CHECK-LABEL: mlfunc @perfectly_nested_loops_loop_carried_at_depth1() { +mlfunc @perfectly_nested_loops_loop_carried_at_depth1() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { - for %i1 = 0 to 5 { - %a0 = affine_apply (d0, d1) -> (d0, d1 + 1) (%i0, %i1) - store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + for %i0 = 0 to 9 { + for %i1 = 0 to 9 { + // Dependence from access 0 to 1 is loop carried at depth 1. 
+ %a0 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) + store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = [2, 2][0, 0]}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 0 to 1 at depth 3 = false}} + %a1 = affine_apply (d0, d1) -> (d0 - 2, d1) (%i0, %i1) + %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 1 to 1 at depth 3 = false}} } } - for %i2 = 0 to 5 { - for %i3 = 5 to 7 { - %a1 = affine_apply (d0, d1) -> (d0, d1) (%i2, %i3) + return +} + +// ----- +// CHECK-LABEL: mlfunc @perfectly_nested_loops_loop_carried_at_depth2() { +mlfunc @perfectly_nested_loops_loop_carried_at_depth2() { + %m = alloc() : memref<10x10xf32> + %c7 = constant 7.0 : f32 + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + // Dependence from access 0 to 1 is loop carried at depth 2. + %a0 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) + store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = [0, 0][3, 3]}} + // expected-note@-6 {{dependence from 0 to 1 at depth 3 = false}} + %a1 = affine_apply (d0, d1) -> (d0, d1 - 3) (%i0, %i1) %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 1 to 1 at depth 3 = false}} } } return } // ----- -// CHECK-LABEL: mlfunc @store_load_diff_outer_ranges_diff_2d_loop_nests() { -mlfunc @store_load_diff_outer_ranges_diff_2d_loop_nests() { +// CHECK-LABEL: mlfunc @one_common_loop() { +mlfunc @one_common_loop() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { - for %i1 = 0 to 5 { + // There is a loop-independent dependence from access 0 to 1 at depth 2. 
+ for %i0 = 0 to 10 { + for %i1 = 0 to 10 { %a0 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) - store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} + store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = true}} } - } - for %i2 = 5 to 8 { - for %i3 = 0 to 5 { - %a1 = affine_apply (d0, d1) -> (d0, d1) (%i2, %i3) + for %i2 = 0 to 9 { + %a1 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i2) %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-5 {{dependence from 1 to 1 at depth 3 = false}} } } return } // ----- -// CHECK-LABEL: mlfunc @store_load_overlapping_outer_ranges_diff_2d_loop_nests() { -mlfunc @store_load_overlapping_outer_ranges_diff_2d_loop_nests() { +// CHECK-LABEL: mlfunc @dependence_cycle() { +mlfunc @dependence_cycle() { + %m.a = alloc() : memref<100xf32> + %m.b = alloc() : memref<100xf32> + + // Dependences: + // *) loop-independent dependence from access 1 to 2 at depth 2. + // *) loop-carried dependence from access 3 to 0 at depth 1. + for %i0 = 0 to 9 { + %a0 = affine_apply (d0) -> (d0) (%i0) + %v0 = load %m.a[%a0] : memref<100xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}} + // expected-note@-5 {{dependence from 0 to 2 at depth 1 = false}} + // expected-note@-6 {{dependence from 0 to 2 at depth 2 = false}} + // expected-note@-7 {{dependence from 0 to 3 at depth 1 = false}} + // expected-note@-8 {{dependence from 0 to 3 at depth 2 = false}} + %a1 = affine_apply (d0) -> (d0) (%i0) + store %v0, %m.b[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-5 {{dependence from 1 to 2 at depth 1 = false}} + // expected-note@-6 {{dependence from 1 to 2 at depth 2 = true}} + // expected-note@-7 {{dependence from 1 to 3 at depth 1 = false}} + // expected-note@-8 {{dependence from 1 to 3 at depth 2 = false}} + %a2 = affine_apply (d0) -> (d0) (%i0) + %v1 = load %m.b[%a2] : memref<100xf32> + // expected-note@-1 {{dependence from 2 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 2 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 2 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 2 to 1 at depth 2 = false}} + // expected-note@-5 {{dependence from 2 to 2 at depth 1 = false}} + // expected-note@-6 {{dependence from 2 to 2 at depth 2 = false}} + // expected-note@-7 {{dependence from 2 to 3 at depth 1 = false}} + // expected-note@-8 {{dependence from 2 to 3 at depth 2 = false}} + %a3 = 
affine_apply (d0) -> (d0 + 1) (%i0) + store %v1, %m.a[%a3] : memref<100xf32> + // expected-note@-1 {{dependence from 3 to 0 at depth 1 = [1, 1]}} + // expected-note@-2 {{dependence from 3 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 3 to 1 at depth 1 = false}} + // expected-note@-4 {{dependence from 3 to 1 at depth 2 = false}} + // expected-note@-5 {{dependence from 3 to 2 at depth 1 = false}} + // expected-note@-6 {{dependence from 3 to 2 at depth 2 = false}} + // expected-note@-7 {{dependence from 3 to 3 at depth 1 = false}} + // expected-note@-8 {{dependence from 3 to 3 at depth 2 = false}} + } + return +} + +// ----- +// CHECK-LABEL: mlfunc @negative_and_positive_direction_vectors() { +mlfunc @negative_and_positive_direction_vectors() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 5 { - for %i1 = 0 to 5 { - %a0 = affine_apply (d0, d1) -> (d0 + 1, d1) (%i0, %i1) - store %c7, %m[%a0#0, %a0#1] : memref<10x10xf32> - // expected-note@-1 {{dependence from memref access 0 to access 1 = true}} + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + %a0 = affine_apply (d0, d1) -> (d0 - 1, d1 + 1) (%i0, %i1) + %v0 = load %m[%a0#0, %a0#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 0 to 1 at depth 3 = false}} + %a1 = affine_apply (d0, d1) -> (d0, d1) (%i0, %i1) + store %c7, %m[%a1#0, %a1#1] : memref<10x10xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 1][-1, -1]}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 1 = false}} + // expected-note@-5 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 1 to 1 at depth 3 = false}} } } - for %i2 = 5 to 8 { - for %i3 = 0 to 5 { - %a1 = affine_apply (d0, d1) -> (d0, d1) (%i2, %i3) - %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + return +} + +// ----- +// CHECK-LABEL: mlfunc @war_raw_waw_deps() { +mlfunc @war_raw_waw_deps() { + %m = alloc() : memref<100xf32> + %c7 = constant 7.0 : f32 + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + %a0 = affine_apply (d0) -> (d0 + 1) (%i1) + %v0 = load %m[%a0] : memref<100xf32> + // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} + // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 0 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 0 to 1 at depth 1 = [1, 9][1, 1]}} + // expected-note@-5 {{dependence from 0 to 1 at depth 2 = [0, 0][1, 1]}} + // expected-note@-6 {{dependence from 0 to 1 at depth 3 = false}} + %a1 = affine_apply (d0) -> (d0) (%i1) + store %c7, %m[%a1] : memref<100xf32> + // expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9][-1, -1]}} + // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}} + // expected-note@-3 {{dependence from 1 to 0 at depth 3 = false}} + // expected-note@-4 {{dependence from 1 to 1 at depth 1 = [1, 9][0, 0]}} + // expected-note@-5 {{dependence from 1 to 1 at depth 2 = false}} + // expected-note@-6 {{dependence from 1 to 1 at depth 3 = false}} } } return -- cgit v1.2.3 From 
3b69230b3a7e156150d349d139d4b52172585e50 Mon Sep 17 00:00:00 2001
From: MLIR Team
Date: Mon, 17 Dec 2018 09:57:14 -0800
Subject: Loop Fusion pass update: introduce utilities to perform generalized loop fusion based on slicing; encompasses standard loop fusion.

*) Adds simple greedy fusion algorithm to drive experimentation. This algorithm greedily fuses loop nests with single-writer/single-reader memref dependences to improve locality.
*) Adds support for fusing slices of a loop nest computation: fusing one loop nest into another by adjusting the source loop nest's iteration bounds (after it is fused into the destination loop nest). This is accomplished by solving for the source loop nest's IVs in terms of the destination loop nest's IVs and symbols using the dependence polyhedron, then creating AffineMaps of these functions for the loop bounds of the fused source loop.
*) Adds utility function 'insertMemRefComputationSlice' which computes and inserts a computation slice from the loop nest surrounding a source memref access into the loop nest surrounding the destination memref access.
*) Adds FlatAffineConstraints::toAffineMap function which returns an AffineMap which represents an equality constraint where one dimension identifier is represented as a function of all others in the equality constraint.
*) Adds multiple fusion unit tests.

PiperOrigin-RevId: 225842944
---
mlir/include/mlir/Analysis/AffineAnalysis.h | 5 +-
mlir/include/mlir/Analysis/AffineStructures.h | 15 +
mlir/include/mlir/Analysis/Utils.h | 17 +
mlir/lib/Analysis/AffineAnalysis.cpp | 25 +-
mlir/lib/Analysis/AffineStructures.cpp | 49 +-
mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +
mlir/lib/Analysis/Utils.cpp | 129 +++++
mlir/lib/Transforms/LoopFusion.cpp | 497 +++++++++++++------
mlir/test/Transforms/loop-fusion.mlir | 672 +++++++++++++++++++++-----
9 files changed, 1112 insertions(+), 299 deletions(-)
(limited to 'mlir/lib/Transforms/LoopFusion.cpp')
diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index a5bc3738d92..bc671272a75 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -145,9 +145,12 @@ struct DependenceComponent {
/// the operation statement, indices and memref associated with the access.
/// Returns 'false' if it can be determined conclusively that the accesses do
/// not access the same memref element. Returns 'true' otherwise.
+// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
+// a single struct.
+// TODO(andydavis) Make 'dependenceConstraints' optional arg.
bool checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
- unsigned loopDepth,
+ unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector *dependenceComponents);
} // end namespace mlir
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index ce53eaf228a..3261d13caf8 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -384,6 +384,21 @@ public:
AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
+ // Returns an AffineMap that expresses the identifier at pos as a function of
+ // other dimensional and symbolic identifiers using the 'idx^th' equality
+ // constraint.
+ // If 'nonZeroDimIds' and 'nonZeroSymbolIds' are non-null, they are populated + // with the positions of the non-zero equality constraint coefficients which + // were used to build the returned AffineMap. + // Returns AffineMap::Null on error (i.e. if coefficient is zero or does + // not divide other coefficients in the equality constraint). + // TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this + // API when we can manage the mapping of MLValues and ids in the constraint + // system. + AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context, + SmallVectorImpl *nonZeroDimIds, + SmallVectorImpl *nonZeroSymbolIds); + // Adds an inequality (>= 0) from the coefficients specified in inEq. void addInequality(ArrayRef inEq); // Adds an equality from the coefficients specified in eq. diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 197edb28a01..796a7aa1453 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -33,7 +33,9 @@ namespace mlir { class FlatAffineConstraints; +class ForStmt; class MLValue; +class MemRefAccess; class OperationStmt; class Statement; @@ -139,6 +141,21 @@ template bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); +/// Creates a clone of the computation contained in the loop nest surrounding +/// 'srcAccess', and inserts it at the beginning of the statement block of the +/// loop containing 'dstAccess'. Returns the top-level loop of the computation +/// slice on success, returns nullptr otherwise. +// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the +// dependence constraint system to create AffineMaps with which to adjust the +// loop bounds of the inserted compution slice so that they are functions of the +// loop IVs and symbols of the loops surrounding 'dstAccess'. +// TODO(andydavis) Add 'dstLoopDepth' argument for computation slice insertion. +// Loop depth is a crucial optimization choice that determines where to +// materialize the results of the backward slice - presenting a trade-off b/w +// storage and redundant computation in several cases +// TODO(andydavis) Support computation slices with common surrounding loops. +ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess, + MemRefAccess *dstAccess); } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 7f53a148a57..80da93d4262 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -1152,7 +1152,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // The access functions would be the following: // // src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M) -// src: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) +// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) // // The iteration domains for the src/dst accesses would be the following: // @@ -1166,7 +1166,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // symbol pos: 0 1 2 // // Equality constraints are built by equating each result of src/destination -// access functions. For this example, the folloing two equality constraints +// access functions. 
For this example, the following two equality constraints // will be added to the dependence constraint system: // // [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] @@ -1190,7 +1190,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv. bool mlir::checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, - unsigned loopDepth, + unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, llvm::SmallVector *dependenceComponents) { // Return 'false' if these accesses do not acces the same memref. if (srcAccess.memref != dstAccess.memref) @@ -1247,28 +1247,31 @@ bool mlir::checkMemrefAccessDependence( unsigned numCols = numIds + 1; // Create flat affine constraints reserving space for 'numEq' and 'numIneq'. - FlatAffineConstraints dependenceDomain(numIneq, numEq, numCols, numDims, - numSymbols); + dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, + /*numLocals=*/0); // Create memref access constraint by equating src/dst access functions. // Note that this check is conservative, and will failure in the future // when local variables for mod/div exprs are supported. if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, - &dependenceDomain)) + dependenceConstraints)) return true; // Add 'src' happens before 'dst' ordering constraints. addOrderingConstraints(srcIterationDomainContext, dstIterationDomainContext, - valuePosMap, loopDepth, &dependenceDomain); + valuePosMap, loopDepth, dependenceConstraints); // Add src and dst domain constraints. addDomainConstraints(srcIterationDomainContext, dstIterationDomainContext, - valuePosMap, &dependenceDomain); + valuePosMap, dependenceConstraints); // Return false if the solution space is empty: no dependence. - if (dependenceDomain.isEmpty()) { + if (dependenceConstraints->isEmpty()) { return false; } // Compute dependence direction vector and return true. - computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext, - loopDepth, &dependenceDomain, dependenceComponents); + if (dependenceComponents != nullptr) { + computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext, + loopDepth, dependenceConstraints, + dependenceComponents); + } return true; } diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4a344a17698..9d14405427a 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -745,7 +745,6 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { const auto &flatExpr = flatExprs[r]; - // eqToAdd is the equality corresponding to the flattened affine expression. SmallVector eqToAdd(getNumCols(), 0); // Set the coefficient for this result to one. @@ -1100,6 +1099,54 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, return posLimit - posStart; } +// Returns an AffineMap which represents 'pos' in equality constraint 'idx', +// as a function of dim and symbols identifers in all other positions. +// TODO(andydavis) Add local variable support to this function. 
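+//
+// As a rough illustration (hypothetical coefficients, not drawn from any
+// test): given identifiers (d0, d1, s0) and the equality
+//   d0 - 2 * d1 - s0 - 1 == 0
+// solving for position 0 (d0) yields the single-result map
+//   (d0)[s0] -> (d0 * 2 + s0 + 1)
+// where the map's d0/s0 stand for the remaining identifiers d1/s0, so
+// 'nonZeroDimIds' = {1} and 'nonZeroSymbolIds' = {2}. Solving for position 1
+// instead returns AffineMap::Null(), since its coefficient -2 does not divide
+// the constant term.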
+AffineMap FlatAffineConstraints::toAffineMapFromEq( + unsigned idx, unsigned pos, MLIRContext *context, + SmallVectorImpl *nonZeroDimIds, + SmallVectorImpl *nonZeroSymbolIds) { + assert(getNumLocalIds() == 0); + assert(idx < getNumEqualities()); + int64_t v = atEq(idx, pos); + // Return if coefficient at (idx, pos) is zero or does not divide constant. + if (v == 0 || (atEq(idx, getNumIds()) % v != 0)) + return AffineMap::Null(); + // Check that coefficient at 'pos' divides all other coefficient in row 'idx'. + for (unsigned j = 0, e = getNumIds(); j < e; ++j) { + if (j != pos && (atEq(idx, j) % v != 0)) + return AffineMap::Null(); + } + // Build AffineExpr solving for identifier 'pos' in terms of all others. + auto expr = getAffineConstantExpr(0, context); + unsigned mapNumDims = 0; + unsigned mapNumSymbols = 0; + for (unsigned j = 0, e = getNumIds(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = atEq(idx, j); + if (c == 0) + continue; + // Divide 'c' by 'v' from 'pos' for which we are solving. + c /= v; + if (j < numDims) { + expr = expr + getAffineDimExpr(mapNumDims++, context) * c; + nonZeroDimIds->push_back(j); + } else { + expr = + expr + getAffineSymbolExpr(mapNumDims + mapNumSymbols++, context) * c; + nonZeroSymbolIds->push_back(j); + } + expr = expr * (-1); + } + // Add constant term to AffineExpr. + int64_t c = atEq(idx, getNumIds()); + if (c > 0) { + expr = expr + (c / v) * (-1); + } + return AffineMap::get(mapNumDims, mapNumSymbols, {expr}, {}); +} + void FlatAffineConstraints::addEquality(ArrayRef eq) { assert(eq.size() == getNumCols()); unsigned offset = equalities.size(); diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 7a5881854cf..2e3df2d61f4 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -148,8 +148,10 @@ static void checkDependences(ArrayRef loadsAndStores) { unsigned numCommonLoops = getNumCommonSurroundingLoops(srcLoops, dstLoops); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + FlatAffineConstraints dependenceConstraints; llvm::SmallVector dependenceComponents; bool ret = checkMemrefAccessDependence(srcAccess, dstAccess, d, + &dependenceConstraints, &dependenceComponents); // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 293d91201ac..3fe22e9dd9a 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -310,3 +310,132 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, bool emitError); template bool mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, bool emitError); + +// Returns in 'positions' the StmtBlock positions of 'stmt' in each ancestor +// StmtBlock from the StmtBlock containing statement, stopping at 'limitBlock'. 
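+//
+// For example (hypothetical nesting): if 'limitBlock' directly contains an
+// outer ForStmt, the first statement of that loop's body is an inner ForStmt,
+// and 'stmt' is the second statement of the inner loop's body, then
+// 'positions' is filled with {0, 1}: one position per enclosing block
+// strictly below 'limitBlock', outermost first.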
+static void findStmtPosition(const Statement *stmt, StmtBlock *limitBlock, + SmallVectorImpl *positions) { + StmtBlock *block = stmt->getBlock(); + while (block != limitBlock) { + int stmtPosInBlock = block->findStmtPosInBlock(*stmt); + assert(stmtPosInBlock >= 0); + positions->push_back(stmtPosInBlock); + stmt = block->getContainingStmt(); + block = stmt->getBlock(); + } + std::reverse(positions->begin(), positions->end()); +} + +// Returns the Statement in a possibly nested set of StmtBlocks, where the +// position of the statement is represented by 'positions', which has a +// StmtBlock position for each level of nesting. +static Statement *getStmtAtPosition(ArrayRef positions, + unsigned level, StmtBlock *block) { + unsigned i = 0; + for (auto &stmt : *block) { + if (i != positions[level]) { + ++i; + continue; + } + if (level == positions.size() - 1) + return &stmt; + if (auto *childForStmt = dyn_cast(&stmt)) + return getStmtAtPosition(positions, level + 1, childForStmt); + + if (auto *ifStmt = dyn_cast(&stmt)) { + auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen()); + if (ret != nullptr) + return ret; + if (auto *elseClause = ifStmt->getElse()) + return getStmtAtPosition(positions, level + 1, elseClause); + } + } + return nullptr; +} + +// TODO(andydavis) Support a 'dstLoopDepth' argument for computation slice +// insertion (currently the computation slice is inserted at the same +// loop depth as 'dstAccess.opStmt'. +ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, + MemRefAccess *dstAccess) { + FlatAffineConstraints dependenceConstraints; + if (!checkMemrefAccessDependence(*srcAccess, *dstAccess, /*loopDepth=*/0, + &dependenceConstraints, + /*dependenceComponents=*/nullptr)) { + return nullptr; + } + // Get loop nest surrounding src operation. + SmallVector srcLoopNest; + getLoopIVs(*srcAccess->opStmt, &srcLoopNest); + unsigned srcLoopNestSize = srcLoopNest.size(); + + // Get loop nest surrounding dst operation. + SmallVector dstLoopNest; + getLoopIVs(*dstAccess->opStmt, &dstLoopNest); + unsigned dstLoopNestSize = dstLoopNest.size(); + + // Solve for src IVs in terms of dst IVs, symbols and constants. + SmallVector srcIvMaps(srcLoopNestSize, AffineMap::Null()); + std::vector> srcIvOperands(srcLoopNestSize); + for (unsigned i = 0; i < srcLoopNestSize; ++i) { + auto cst = dependenceConstraints.clone(); + for (int j = srcLoopNestSize - 1; j >= 0; --j) { + if (i != j) + cst->projectOut(j); + } + if (cst->getNumEqualities() != 1) { + srcIvMaps[i] = AffineMap::Null(); + continue; + } + SmallVector nonZeroDimIds; + SmallVector nonZeroSymbolIds; + srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(), + &nonZeroDimIds, &nonZeroSymbolIds); + if (srcIvMaps[i] == AffineMap::Null()) + continue; + // Add operands for all non-zero dst dims and symbols. + // TODO(andydavis) Add local variable support. + for (auto dimId : nonZeroDimIds) { + srcIvOperands[i].push_back(dstLoopNest[dimId - 1]); + } + // TODO(andydavis) Add symbols from the access function. Ideally, we + // should be able to query the constaint system for the MLValue associated + // with a symbol identifiers in 'nonZeroSymbolIds'. + } + + // Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'. + SmallVector positions; + findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions); + + // Clone src loop nest and insert it a the beginning of the statement block + // of the same loop in which containts 'dstAccess->opStmt'. 
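+  //
+  // As a rough sketch (hypothetical IR): for a source loop storing to %m and
+  // a destination loop on %i1 loading from %m, the clone inserted here ends
+  // up, after the bound rewriting below, as a single-iteration loop such as
+  //   for %i0 = #lb(%i1) to #lb(%i1) + 1 { store ..., %m[...] }
+  // placed at the top of the destination loop's body ('#lb' names the lower
+  // bound map computed above).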
+ auto *dstForStmt = dstLoopNest[dstLoopNestSize - 1]; + MLFuncBuilder b(dstForStmt, dstForStmt->begin()); + DenseMap operandMap; + auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); + + // Lookup stmt in cloned 'sliceLoopNest' at 'positions'. + Statement *sliceStmt = + getStmtAtPosition(positions, /*level=*/0, sliceLoopNest); + // Get loop nest surrounding 'sliceStmt'. + SmallVector sliceSurroundingLoops; + getLoopIVs(*sliceStmt, &sliceSurroundingLoops); + unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size(); + + // Update loop bounds for loops in 'sliceLoopNest'. + for (unsigned i = dstLoopNestSize; i < sliceSurroundingLoopsSize; ++i) { + auto *forStmt = sliceSurroundingLoops[i]; + unsigned index = i - dstLoopNestSize; + AffineMap lbMap = srcIvMaps[index]; + if (lbMap == AffineMap::Null()) + continue; + forStmt->setLowerBound(srcIvOperands[index], lbMap); + // Create upper bound map with is lower bound map + 1; + assert(lbMap.getNumResults() == 1); + AffineExpr ubResultExpr = lbMap.getResult(0) + 1; + AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), + {ubResultExpr}, {}); + forStmt->setUpperBound(srcIvOperands[index], ubMap); + } + return sliceLoopNest; +} diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 3db290fac03..521fca8979f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -20,7 +20,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -31,16 +33,25 @@ #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::SetVector; using namespace mlir; namespace { -/// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which -/// access the same memref with no dependences. -// See MatchTestPattern for details on candidate loop selection. +/// Loop fusion pass. This pass currently supports a greedy fusion policy, +/// which fuses loop nests with single-writer/single-reader memref dependences +/// with the goal of improving locality. + +// TODO(andydavis) Support fusion of source loop nests which write to multiple +// memrefs, where each memref can have multiple users (if profitable). // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. + struct LoopFusion : public FunctionPass { LoopFusion() : FunctionPass(&LoopFusion::passID) {} @@ -48,51 +59,12 @@ struct LoopFusion : public FunctionPass { static char passID; }; -// LoopCollector walks the statements in an MLFunction and builds a map from -// StmtBlocks to a list of loops within the StmtBlock, and a map from ForStmts -// to the list of loads and stores with its StmtBlock. 
-class LoopCollector : public StmtWalker { -public: - DenseMap> loopMap; - DenseMap> loadsAndStoresMap; - bool hasIfStmt = false; - - void visitForStmt(ForStmt *forStmt) { - loopMap[forStmt->getBlock()].push_back(forStmt); - } - - void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } - - void visitOperationStmt(OperationStmt *opStmt) { - if (auto *parentStmt = opStmt->getParentStmt()) { - if (auto *parentForStmt = dyn_cast(parentStmt)) { - if (opStmt->isa() || opStmt->isa()) { - loadsAndStoresMap[parentForStmt].push_back(opStmt); - } - } - } - } -}; - } // end anonymous namespace char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -// TODO(andydavis) Remove the following test code when more general loop -// fusion is supported. -struct FusionCandidate { - // Loop nest of ForStmts with 'accessA' in the inner-most loop. - SmallVector forStmtsA; - // Load or store operation within loop nest 'forStmtsA'. - MemRefAccess accessA; - // Loop nest of ForStmts with 'accessB' in the inner-most loop. - SmallVector forStmtsB; - // Load or store operation within loop nest 'forStmtsB'. - MemRefAccess accessB; -}; - static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { @@ -116,137 +88,348 @@ static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, } } -// Checks if 'forStmtA' and 'forStmtB' match specific test criterion: -// constant loop bounds, no nested loops, single StoreOp in 'forStmtA' and -// a single LoadOp in 'forStmtB'. -// Returns true if the test pattern matches, false otherwise. -static bool MatchTestPatternLoopPair(LoopCollector *lc, - FusionCandidate *candidate, - ForStmt *forStmtA, ForStmt *forStmtB) { - if (forStmtA == nullptr || forStmtB == nullptr) - return false; - // Return if 'forStmtA' and 'forStmtB' do not have matching constant - // bounds and step. - if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() || - forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() || - forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() || - forStmtA->getStep() != forStmtB->getStep()) - return false; - - // Return if 'forStmtA' or 'forStmtB' have nested loops. - if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB)) - return false; +// FusionCandidate encapsulates source and destination memref access within +// loop nests which are candidates for loop fusion. +struct FusionCandidate { + // Load or store access within src loop nest to be fused into dst loop nest. + MemRefAccess srcAccess; + // Load or store access within dst loop nest. + MemRefAccess dstAccess; +}; - // Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store. - if (lc->loadsAndStoresMap[forStmtA].size() != 1 || - lc->loadsAndStoresMap[forStmtB].size() != 1) - return false; +static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, + OperationStmt *dstLoadOpStmt) { + FusionCandidate candidate; + // Get store access for src loop nest. + getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess); + // Get load access for dst loop nest. + getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess); + return candidate; +} - // Get load/store access for forStmtA. - getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0], - &candidate->accessA); - // Return if 'accessA' is not a store. - if (!candidate->accessA.opStmt->isa()) - return false; +namespace { - // Get load/store access for forStmtB. 
- getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0], - &candidate->accessB); +// LoopNestStateCollector walks loop nests and collects load and store +// operations, and whether or not an IfStmt was encountered in the loop nest. +class LoopNestStateCollector : public StmtWalker { +public: + SmallVector forStmts; + SmallVector loadOpStmts; + SmallVector storeOpStmts; + bool hasIfStmt = false; - // Return if accesses do not access the same memref. - if (candidate->accessA.memref != candidate->accessB.memref) - return false; + void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } - candidate->forStmtsA.push_back(forStmtA); - candidate->forStmtsB.push_back(forStmtB); - return true; -} + void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } -// Returns the child ForStmt of 'parent' if unique, returns false otherwise. -ForStmt *getSingleForStmtChild(ForStmt *parent) { - if (parent->getStatements().size() == 1 && isa(parent->front())) - return dyn_cast(&parent->front()); - return nullptr; -} + void visitOperationStmt(OperationStmt *opStmt) { + if (opStmt->isa()) + loadOpStmts.push_back(opStmt); + if (opStmt->isa()) + storeOpStmts.push_back(opStmt); + } +}; -// Checks for a specific ForStmt/OpStatment test pattern in 'f', returns true -// on success and resturns fusion candidate in 'candidate'. Returns false -// otherwise. -// Currently supported test patterns: -// *) Adjacent loops with a StoreOp the only op in first loop, and a LoadOp the -// only op in the second loop (both load/store accessing the same memref). -// *) As above, but with one level of perfect loop nesting. +// GreedyFusionPolicy greedily fuses loop nests which have a producer/consumer +// relationship on a memref, with the goal of improving locality. Currently, +// this the producer/consumer relationship is required to be unique in the +// MLFunction (there are TODOs to relax this constraint in the future). +// +// The steps of the algorithm are as follows: +// +// *) Initialize. While visiting each statement in the MLFunction do: +// *) Assign each top-level ForStmt a 'position' which is its initial +// position in the MLFunction's StmtBlock at the start of the pass. +// *) Gather memref load/store state aggregated by top-level statement. For +// example, all loads and stores contained in a loop nest are aggregated +// under the loop nest's top-level ForStmt. +// *) Add each top-level ForStmt to a worklist. // -// TODO(andydavis) Look into using ntv@ pattern matcher here. -static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) { - LoopCollector lc; - lc.walk(f); - // Return if an IfStmt was found or if less than two ForStmts were found. - if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2) +// *) Run. The algorithm processes the worklist with the following steps: +// *) The worklist is processed in reverse order (starting from the last +// top-level ForStmt in the MLFunction). +// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate +// destination ForStmt into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'. +// *) For each LoadOp in 'dstLoadOps' do: +// *) Lookup dependent loop nests at earlier positions in the MLFunction +// which have a single store op to the same memref. +// *) Check if dependences would be violated by the fusion. For example, +// the src loop nest may load from memrefs which are different than +// the producer-consumer memref between src and dest loop nests. 
+// *) Get a computation slice of 'srcLoopNest', which adjust its loop +// bounds to be functions of 'dstLoopNest' IVs and symbols. +// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', +// just before the dst load op user. +// *) Add the newly fused load/store operation statements to the state, +// and also add newly fuse load ops to 'dstLoopOps' to be considered +// as fusion dst load ops in another iteration. +// *) Remove old src loop nest and its associated state. +// +// Given a graph where top-level statements are vertices in the set 'V' and +// edges in the set 'E' are dependences between vertices, this algorithm +// takes O(V) time for initialization, and has runtime O(V * E). +// TODO(andydavis) Reduce this time complexity to O(V + E). +// +// This greedy algorithm is not 'maximally' but there is a TODO to fix this. +// +// TODO(andydavis) Experiment with other fusion policies. +struct GreedyFusionPolicy { + // Convenience wrapper with information about 'stmt' ready to access. + struct StmtInfo { + Statement *stmt; + bool isOrContainsIfStmt = false; + }; + // The worklist of top-level loop nest positions. + SmallVector worklist; + // Mapping from top-level position to StmtInfo. + DenseMap posToStmtInfo; + // Mapping from memref MLValue to set of top-level positions of loop nests + // which contain load ops on that memref. + DenseMap> memrefToLoadPosSet; + // Mapping from memref MLValue to set of top-level positions of loop nests + // which contain store ops on that memref. + DenseMap> memrefToStorePosSet; + // Mapping from top-level loop nest to the set of load ops it contains. + DenseMap> forStmtToLoadOps; + // Mapping from top-level loop nest to the set of store ops it contains. + DenseMap> forStmtToStoreOps; + + GreedyFusionPolicy(MLFunction *f) { init(f); } + + void run() { + if (hasIfStmts()) + return; + + while (!worklist.empty()) { + // Pop the position of a loop nest into which fusion will be attempted. + unsigned dstPos = worklist.back(); + worklist.pop_back(); + // Skip if 'dstPos' is not tracked (was fused into another loop nest). + if (posToStmtInfo.count(dstPos) == 0) + continue; + // Get the top-level ForStmt at 'dstPos'. + auto *dstForStmt = getForStmtAtPos(dstPos); + // Skip if this ForStmt contains no load ops. + if (forStmtToLoadOps.count(dstForStmt) == 0) + continue; + + // Greedy Policy: iterate through load ops in 'dstForStmt', greedily + // fusing in src loop nests which have a single store op on the same + // memref, until a fixed point is reached where there is nothing left to + // fuse. + SetVector dstLoadOps = forStmtToLoadOps[dstForStmt]; + while (!dstLoadOps.empty()) { + auto *dstLoadOpStmt = dstLoadOps.pop_back_val(); + + auto dstLoadOp = dstLoadOpStmt->cast(); + auto *memref = cast(dstLoadOp->getMemRef()); + // Skip if not single src store / dst load pair on 'memref'. + if (memrefToLoadPosSet[memref].size() != 1 || + memrefToStorePosSet[memref].size() != 1) + continue; + unsigned srcPos = *memrefToStorePosSet[memref].begin(); + if (srcPos >= dstPos) + continue; + auto *srcForStmt = getForStmtAtPos(srcPos); + // Skip if 'srcForStmt' has more than one store op. + if (forStmtToStoreOps[srcForStmt].size() > 1) + continue; + // Skip if fusion would violated dependences between 'memref' access + // for loop nests between 'srcPos' and 'dstPos': + // For each src load op: check for store ops in range (srcPos, dstPos). + // For each src store op: check for load ops in range (srcPos, dstPos). 
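+        // For example (hypothetical case): if the nest at 'srcPos' stores to
+        // %t and a nest strictly between 'srcPos' and 'dstPos' loads %t,
+        // moving the src nest down to 'dstPos' would place that store after
+        // the load it feeds, so the fusion is skipped.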
+ if (moveWouldViolateDependences(srcPos, dstPos)) + continue; + auto *srcStoreOpStmt = forStmtToStoreOps[srcForStmt].front(); + // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'. + FusionCandidate candidate = + buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. + auto *sliceLoopNest = mlir::insertBackwardComputationSlice( + &candidate.srcAccess, &candidate.dstAccess); + if (sliceLoopNest != nullptr) { + // Remove 'srcPos' mappings from 'state'. + moveAccessesAndRemovePos(srcPos, dstPos); + // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. + LoopNestStateCollector collector; + collector.walkForStmt(sliceLoopNest); + // Record mappings for loads and stores from 'collector'. + for (auto *opStmt : collector.loadOpStmts) { + addLoadOpStmtAt(dstPos, opStmt, dstForStmt); + // Add newly fused load ops to 'dstLoadOps' to be considered for + // fusion on subsequent iterations. + dstLoadOps.insert(opStmt); + } + for (auto *opStmt : collector.storeOpStmts) { + addStoreOpStmtAt(dstPos, opStmt, dstForStmt); + } + for (auto *forStmt : collector.forStmts) { + promoteIfSingleIteration(forStmt); + } + // Remove old src loop nest. + srcForStmt->erase(); + } + } + } + } + + // Walk MLFunction 'f' assigning each top-level statement a position, and + // gathering state on load and store ops. + void init(MLFunction *f) { + unsigned pos = 0; + for (auto &stmt : *f) { + if (auto *forStmt = dyn_cast(&stmt)) { + // Record all loads and store accesses in 'forStmt' at 'pos'. + LoopNestStateCollector collector; + collector.walkForStmt(forStmt); + // Create StmtInfo for 'forStmt' for top-level loop nests. + addStmtInfoAt(pos, forStmt, collector.hasIfStmt); + // Record mappings for loads and stores from 'collector'. + for (auto *opStmt : collector.loadOpStmts) { + addLoadOpStmtAt(pos, opStmt, forStmt); + } + for (auto *opStmt : collector.storeOpStmts) { + addStoreOpStmtAt(pos, opStmt, forStmt); + } + // Add 'pos' associated with 'forStmt' to worklist. + worklist.push_back(pos); + } + if (auto *opStmt = dyn_cast(&stmt)) { + if (auto loadOp = opStmt->dyn_cast()) { + // Create StmtInfo for top-level load op. + addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false); + addLoadOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr); + } + if (auto storeOp = opStmt->dyn_cast()) { + // Create StmtInfo for top-level store op. + addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false); + addStoreOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr); + } + } + if (auto *ifStmt = dyn_cast(&stmt)) { + addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/true); + } + ++pos; + } + } + + // Check if fusing loop nest at 'srcPos' into the loop nest at 'dstPos' + // would violated any dependences w.r.t other loop nests in that range. + bool moveWouldViolateDependences(unsigned srcPos, unsigned dstPos) { + // Lookup src ForStmt at 'srcPos'. + auto *srcForStmt = getForStmtAtPos(srcPos); + // For each src load op: check for store ops in range (srcPos, dstPos). + if (forStmtToLoadOps.count(srcForStmt) > 0) { + for (auto *opStmt : forStmtToLoadOps[srcForStmt]) { + auto loadOp = opStmt->cast(); + auto *memref = cast(loadOp->getMemRef()); + for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) { + if (memrefToStorePosSet.count(memref) > 0 && + memrefToStorePosSet[memref].count(pos) > 0) + return true; + } + } + } + // For each src store op: check for load ops in range (srcPos, dstPos). 
+ if (forStmtToStoreOps.count(srcForStmt) > 0) { + for (auto *opStmt : forStmtToStoreOps[srcForStmt]) { + auto storeOp = opStmt->cast(); + auto *memref = cast(storeOp->getMemRef()); + for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) { + if (memrefToLoadPosSet.count(memref) > 0 && + memrefToLoadPosSet[memref].count(pos) > 0) + return true; + } + } + } return false; - auto *forStmtA = lc.loopMap[f][0]; - auto *forStmtB = lc.loopMap[f][1]; - if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) { - // Check for one level of loop nesting. - candidate->forStmtsA.push_back(forStmtA); - candidate->forStmtsB.push_back(forStmtB); - return MatchTestPatternLoopPair(&lc, candidate, - getSingleForStmtChild(forStmtA), - getSingleForStmtChild(forStmtB)); } - return true; -} -// FuseLoops implements the code generation mechanics of loop fusion. -// Fuses the operations statments from the inner-most loop in 'c.forStmtsB', -// by cloning them into the inner-most loop in 'c.forStmtsA', then erasing -// old statements and loops. -static void fuseLoops(const FusionCandidate &c) { - MLFuncBuilder builder(c.forStmtsA.back(), - StmtBlock::iterator(c.forStmtsA.back()->end())); - DenseMap operandMap; - assert(c.forStmtsA.size() == c.forStmtsB.size()); - for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) { - // Map loop IVs to 'forStmtB[i]' to loop IV for 'forStmtA[i]'. - operandMap[c.forStmtsB[i]] = c.forStmtsA[i]; + // Update mappings of memref loads and stores at 'srcPos' to 'dstPos'. + void moveAccessesAndRemovePos(unsigned srcPos, unsigned dstPos) { + // Lookup ForStmt at 'srcPos'. + auto *srcForStmt = getForStmtAtPos(srcPos); + // Move load op accesses from src to dst. + if (forStmtToLoadOps.count(srcForStmt) > 0) { + for (auto *opStmt : forStmtToLoadOps[srcForStmt]) { + auto loadOp = opStmt->cast(); + auto *memref = cast(loadOp->getMemRef()); + // Remove 'memref' to 'srcPos' mapping. + memrefToLoadPosSet[memref].erase(srcPos); + } + } + // Move store op accesses from src to dst. + if (forStmtToStoreOps.count(srcForStmt) > 0) { + for (auto *opStmt : forStmtToStoreOps[srcForStmt]) { + auto storeOp = opStmt->cast(); + auto *memref = cast(storeOp->getMemRef()); + // Remove 'memref' to 'srcPos' mapping. + memrefToStorePosSet[memref].erase(srcPos); + } + } + // Remove old state. + forStmtToLoadOps.erase(srcForStmt); + forStmtToStoreOps.erase(srcForStmt); + posToStmtInfo.erase(srcPos); } - // Clone the body of inner-most loop in 'forStmtsB', into the body of - // inner-most loop in 'forStmtsA'. - SmallVector stmtsToErase; - auto *innerForStmtB = c.forStmtsB.back(); - for (auto &stmt : *innerForStmtB) { - builder.clone(stmt, operandMap); - stmtsToErase.push_back(&stmt); + + ForStmt *getForStmtAtPos(unsigned pos) { + assert(posToStmtInfo.count(pos) > 0); + assert(isa(posToStmtInfo[pos].stmt)); + return cast(posToStmtInfo[pos].stmt); } - // Erase 'forStmtB' and its statement list. - for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it) - (*it)->erase(); - // Erase 'forStmtsB' loop nest. - for (int i = static_cast(c.forStmtsB.size()) - 1; i >= 0; --i) - c.forStmtsB[i]->erase(); -} -PassResult LoopFusion::runOnMLFunction(MLFunction *f) { - FusionCandidate candidate; - if (!MatchTestPattern(f, &candidate)) - return failure(); - - // TODO(andydavis) Add checks for fusion-preventing dependences and ordering - // constraints which would prevent fusion. - // TODO(andydavis) This check is overly conservative for now. Support fusing - // statements with compatible dependences (i.e. 
statements where the - // dependence between the statements does not reverse direction when the - // statements are fused into the same loop). - llvm::SmallVector dependenceComponents; - // TODO(andydavis) Check dependences at differnt loop nest depths. - if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB, - /*loopNestDepth=*/0, - &dependenceComponents)) { - // Current conservatinve test policy: No dependence exists between accesses - // in different loop nests -> fuse loops. - fuseLoops(candidate); + void addStmtInfoAt(unsigned pos, Statement *stmt, bool hasIfStmt) { + StmtInfo stmtInfo; + stmtInfo.stmt = stmt; + stmtInfo.isOrContainsIfStmt = hasIfStmt; + // Add mapping from 'pos' to StmtInfo for 'forStmt'. + posToStmtInfo[pos] = stmtInfo; } + // Adds the following mappings: + // *) 'containingForStmt' to load 'opStmt' + // *) 'memref' of load 'opStmt' to 'topLevelPos'. + void addLoadOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt, + ForStmt *containingForStmt) { + if (containingForStmt != nullptr) { + // Add mapping from 'containingForStmt' to 'opStmt' for load op. + forStmtToLoadOps[containingForStmt].insert(opStmt); + } + auto loadOp = opStmt->cast(); + auto *memref = cast(loadOp->getMemRef()); + // Add mapping from 'memref' to 'topLevelPos' for load. + memrefToLoadPosSet[memref].insert(topLevelPos); + } + + // Adds the following mappings: + // *) 'containingForStmt' to store 'opStmt' + // *) 'memref' of store 'opStmt' to 'topLevelPos'. + void addStoreOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt, + ForStmt *containingForStmt) { + if (containingForStmt != nullptr) { + // Add mapping from 'forStmt' to 'opStmt' for store op. + forStmtToStoreOps[containingForStmt].insert(opStmt); + } + auto storeOp = opStmt->cast(); + auto *memref = cast(storeOp->getMemRef()); + // Add mapping from 'memref' to 'topLevelPos' for store. + memrefToStorePosSet[memref].insert(topLevelPos); + } + + bool hasIfStmts() { + for (auto &pair : posToStmtInfo) + if (pair.second.isOrContainsIfStmt) + return true; + return false; + } +}; + +} // end anonymous namespace + +PassResult LoopFusion::runOnMLFunction(MLFunction *f) { + GreedyFusionPolicy(f).run(); return success(); } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 2b8ce07b240..d0de62e8a06 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,141 +1,555 @@ -// RUN: mlir-opt %s -loop-fusion | FileCheck %s - -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 * 2 + 2) -// CHECK: [[MAP1:#map[0-9]+]] = (d0) -> (d0 * 3 + 1) -// CHECK: [[MAP2:#map[0-9]+]] = (d0) -> (d0 * 2) -// CHECK: [[MAP3:#map[0-9]+]] = (d0) -> (d0 * 2 + 1) -// CHECK: [[MAP4:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10) -// CHECK: [[MAP6:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3) - -// The dependence check for this test builds the following set of constraints, -// where the equality contraint equates the two accesses to the memref (from -// different loops), and the inequality constraints represent the upper and -// lower bounds for each loop. After elimination, this linear system can be -// shown to be non-empty (i.e. x0 = x1 = 1 is a solution). As such, the -// dependence check between accesses in the two loops will return true, and -// the loops (according to the current test loop fusion algorithm) should not be -// fused. 
-// -// x0 x1 x2 -// 2 -3 1 = 0 -// 1 0 0 >= 0 -// -1 0 100 >= 0 -// 0 1 0 >= 0 -// 0 -1 100 >= 0 -// -// CHECK-LABEL: mlfunc @loop_fusion_1d_should_not_fuse_loops() { -mlfunc @loop_fusion_1d_should_not_fuse_loops() { - %m = alloc() : memref<100xf32, (d0) -> (d0)> - // Check that the first loop remains unfused. - // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP0]](%i0) - // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}} - // CHECK-NEXT: } - for %i0 = 0 to 100 { - %a0 = affine_apply (d0) -> (d0 * 2 + 2) (%i0) - %c1 = constant 1.0 : f32 - store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)> - } - // Check that the second loop remains unfused. - // CHECK: for %i1 = 0 to 100 { - // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP1]](%i1) - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}} - // CHECK-NEXT: } - for %i1 = 0 to 100 { - %a1 = affine_apply (d0) -> (d0 * 3 + 1) (%i1) - %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)> - } - return -} - -// The dependence check for this test builds the following set of constraints: -// -// x0 x1 x2 -// 2 -2 -1 = 0 -// 1 0 0 >= 0 -// -1 0 100 >= 0 -// 0 1 0 >= 0 -// 0 -1 100 >= 0 -// -// After elimination, this linear system can be shown to have no solutions, and -// so no dependence exists and the loops should be fused in this test (according -// to the current trivial test loop fusion policy). -// -// -// CHECK-LABEL: mlfunc @loop_fusion_1d_should_fuse_loops() { -mlfunc @loop_fusion_1d_should_fuse_loops() { - %m = alloc() : memref<100xf32, (d0) -> (d0)> - // Should fuse statements from the second loop into the first loop. - // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP2]](%i0) - // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}} - // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP3]](%i0) - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}} +// RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s + +// TODO(andydavis) Add more tests: +// *) Add nested fusion test cases when non-constant loop bound support is +// added to iteration domain in dependence check. +// *) Add a test w/ floordiv/ceildiv/mod when supported in dependence check. +// *) Add tests which check fused computation slice indexing and loop bounds. +// TODO(andydavis) Test clean up: move memref allocs to mlfunc args. + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_raw_dep_for_locality() { +mlfunc @should_fuse_raw_dep_for_locality() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> + // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// TODO(andydavis) Turn this into a proper reduction when constraints on +// the current greedy fusion policy are relaxed. 
+// CHECK-LABEL: mlfunc @should_fuse_reduction_to_pointwise() { +mlfunc @should_fuse_reduction_to_pointwise() { + %a = alloc() : memref<10x10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + %d = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + %v0 = load %d[%i0] : memref<10xf32> + %v1 = load %a[%i0, %i1] : memref<10x10xf32> + %v3 = addf %v0, %v1 : f32 + store %v3, %b[%i0] : memref<10xf32> + } + } + for %i2 = 0 to 10 { + %v4 = load %b[%i2] : memref<10xf32> + store %v4, %c[%i2] : memref<10xf32> + } + + // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 + // is not used in the access function of the store/load on %b. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %5 = load %3[%4] : memref<10xf32> + // CHECK-NEXT: %6 = load %0[%4, %i1] : memref<10x10xf32> + // CHECK-NEXT: %7 = addf %5, %6 : f32 + // CHECK-NEXT: store %7, %1[%4] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %8 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %8, %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) +// CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) + +// CHECK-LABEL: mlfunc @should_fuse_loop_nests_with_shifts() { +mlfunc @should_fuse_loop_nests_with_shifts() { + %a = alloc() : memref<10x10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + %a0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 1) (%i0, %i1) + store %cf7, %a[%a0#0, %a0#1] : memref<10x10xf32> + } + } + for %i2 = 0 to 10 { + for %i3 = 0 to 10 { + %v0 = load %a[%i2, %i3] : memref<10x10xf32> + } + } + + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) + // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1) + // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) + // CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32> + // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return - for %i0 = 0 to 100 { - %a0 = affine_apply (d0) -> (d0 * 2) (%i0) - %c1 = constant 1.0 : f32 - store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)> - } - - for %i1 = 0 to 100 { - %a1 = affine_apply (d0) -> (d0 * 2 + 1) (%i1) - - %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)> - } - return -} - -// TODO(andydavis) Add LoopFusion tests based on fusion policy and cost model. - -// The dependence check for this test builds the following set of -// equality constraints (one for each memref dimension). Note: inequality -// constraints for loop bounds not shown. -// -// i0 i1 i2 i3 s0 s1 s2 c -// 2 -1 -2 0 -7 0 0 4 = 0 -// 9 3 0 -3 0 12 -3 -10 = 0 -// -// The second equality will fail the GCD test and so the system has no solution, -// so the loops should be fused under the current test policy. -// -// CHECK-LABEL: mlfunc @loop_fusion_2d_should_fuse_loops() { -mlfunc @loop_fusion_2d_should_fuse_loops() { - %m = alloc() : memref<10x10xf32> - - %s0 = constant 7 : index - %s1 = constant 11 : index - %s2 = constant 13 : index - // Should fuse statements from the second loop into the first loop. 
- // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: for %i1 = 0 to 50 { - // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11] - // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}} - // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0, %i1)[%c11, %c13] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}} + return +} + +// ----- + +// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_loop_nest() { +mlfunc @should_fuse_loop_nest() { + %a = alloc() : memref<10x10xf32> + %b = alloc() : memref<10x10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + store %cf7, %a[%i0, %i1] : memref<10x10xf32> + } + } + for %i2 = 0 to 10 { + for %i3 = 0 to 10 { + %v0 = load %a[%i3, %i2] : memref<10x10xf32> + store %v0, %b[%i2, %i3] : memref<10x10xf32> + } + } + for %i4 = 0 to 10 { + for %i5 = 0 to 10 { + %v1 = load %b[%i4, %i5] : memref<10x10xf32> + } + } + + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply [[MAP_IDENTITY]](%i1) + // CHECK-NEXT: %3 = affine_apply [[MAP_IDENTITY]](%i0) + // CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%i0) + // CHECK-NEXT: %5 = affine_apply [[MAP_IDENTITY]](%i1) + // CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32> + // CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32> + // CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return - for %i0 = 0 to 100 { - for %i1 = 0 to 50 { - %a0 = affine_apply - (d0, d1)[s0, s1] -> - (d0 * 2 -d1 + -7 * s0 + 3 , d0 * 9 + d1 * 3 + 13 * s1 - 10) - (%i0, %i1)[%s0, %s1] - %c1 = constant 1.0 : f32 - store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32> + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_across_intermediate_loop_with_no_deps() { +mlfunc @should_fuse_across_intermediate_loop_with_no_deps() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %v0, %b[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %c[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %b[%i2] : memref<10xf32> + } + + // Should fuse first loop (past second loop with no dependences) into third. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1) + // CHECK-NEXT: %4 = load %0[%3] : memref<10xf32> + // CHECK-NEXT: store %4, %1[%3] : memref<10xf32> + // CHECK-NEXT: %5 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_all_loops() { +mlfunc @should_fuse_all_loops() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + // Set up flow dependences from first and second loops to third. + for %i0 = 0 to 10 { + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %b[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v0 = load %a[%i2] : memref<10xf32> + %v1 = load %b[%i2] : memref<10xf32> + } + + // Should fuse first and second loops into third. 
+ // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: store %cst, %0[%2] : memref<10xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: store %cst, %1[%3] : memref<10xf32> + // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %5 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_first_and_second_loops() { +mlfunc @should_fuse_first_and_second_loops() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %a[%i1] : memref<10xf32> + store %cf7, %b[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %c[%i2] : memref<10xf32> + } + + // Should fuse first loop into the second (last loop should not be fused). + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: store %cst, %0[%3] : memref<10xf32> + // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: %5 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + + return +} + +// ----- + +// CHECK-LABEL: mlfunc @should_not_fuse_would_create_cycle() { +mlfunc @should_not_fuse_would_create_cycle() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + // Set up the following dependences: + // 1) loop0 -> loop1 on memref '%a' + // 2) loop0 -> loop2 on memref '%b' + // 3) loop1 -> loop2 on memref '%c' + for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %cf7, %b[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %a[%i1] : memref<10xf32> + %v1 = load %c[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v2 = load %b[%i2] : memref<10xf32> + store %cf7, %c[%i2] : memref<10xf32> + } + // Should not fuse: fusing loop first loop into last would create a cycle. 
+ // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i2 = 0 to 10 { + // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> + // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: mlfunc @should_not_fuse_raw_dep_would_be_violated() { +mlfunc @should_not_fuse_raw_dep_would_be_violated() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %m[%i2] : memref<10xf32> + } + // Fusing loop %i0 to %i2 would violate the RAW dependence between %i0 and %i1 + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i2 = 0 to 10 { + // CHECK-NEXT: %2 = load %0[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: mlfunc @should_not_fuse_waw_dep_would_be_violated() { +mlfunc @should_not_fuse_waw_dep_would_be_violated() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %m[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %m[%i2] : memref<10xf32> + } + // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i2 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: mlfunc @should_not_fuse_war_dep_would_be_violated() { +mlfunc @should_not_fuse_war_dep_would_be_violated() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %v0, %b[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %a[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %b[%i2] : memref<10xf32> + } + // Fusing loop %i0 to %i2 would violate the WAR dependence between %i0 and %i1 + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %2, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i2 = 0 to 10 { + // CHECK-NEXT: %3 = load %1[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: mlfunc @should_not_fuse_if_top_level_access() { +mlfunc @should_not_fuse_if_top_level_access() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + + %c0 = constant 4 : index + %v1 = load %m[%c0] : memref<10xf32> + // Top-level load to '%m' should prevent fusion. 
+ // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) + +// CHECK-LABEL: mlfunc @should_fuse_no_top_level_access() { +mlfunc @should_fuse_no_top_level_access() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %1 = affine_apply #map0(%i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> + // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +#set0 = (d0) : (1 == 0) + +// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_at_top_level() { +mlfunc @should_not_fuse_if_stmt_at_top_level() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + %c0 = constant 4 : index + if #set0(%c0) { + } + // Top-level IfStmt should prevent fusion. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + return +} + +// ----- + +#set0 = (d0) : (1 == 0) + +// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_in_loop_nest() { +mlfunc @should_not_fuse_if_stmt_in_loop_nest() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + %c4 = constant 4 : index + + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + if #set0(%c4) { } + %v0 = load %m[%i1] : memref<10xf32> } - for %i2 = 0 to 100 { - for %i3 = 0 to 50 { - %a1 = affine_apply - (d0, d1)[s0, s1] -> - (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3) (%i2, %i3)[%s1, %s2] - %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32> + // IfStmt in ForStmt should prevent fusion. 
+ // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: for %i1 = 0 to 10 { + // CHECK-NEXT: if #set0(%c4) { + // CHECK-NEXT: } + // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) +// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) + +// CHECK-LABEL: mlfunc @remap_ivs() { +mlfunc @remap_ivs() { + %m = alloc() : memref<10x20x30xf32> + + %cf7 = constant 7.0 : f32 + for %i0 = 0 to 10 { + for %i1 = 0 to 20 { + for %i2 = 0 to 30 { + %a0 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i0, %i1, %i2) + store %cf7, %m[%a0#0, %a0#1, %a0#2] : memref<10x20x30xf32> + } } } + for %i3 = 0 to 30 { + for %i4 = 0 to 10 { + for %i5 = 0 to 20 { + %a1 = affine_apply (d0, d1, d2) -> (d1, d2, d0) (%i3, %i4, %i5) + %v0 = load %m[%a1#0, %a1#1, %a1#2] : memref<10x20x30xf32> + } + } + } +// CHECK: for %i0 = 0 to 30 { +// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: for %i2 = 0 to 20 { +// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1) +// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2) +// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) +// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3) +// CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32> +// CHECK-NEXT: %5 = affine_apply [[MAP2]](%i0, %i1, %i2) +// CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return return -} \ No newline at end of file +} -- cgit v1.2.3 From 6892ffb8965c2849565fedfbbf60f05e475c9858 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 19 Dec 2018 20:42:55 -0800 Subject: Improve loop fusion algorithm by using a memref dependence graph. Fixed TODO for reduction fusion unit test. PiperOrigin-RevId: 226277226 --- mlir/lib/Transforms/LoopFusion.cpp | 594 +++++++++++++++++++--------------- mlir/test/Transforms/loop-fusion.mlir | 21 +- 2 files changed, 351 insertions(+), 264 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 521fca8979f..6393fa6069d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -130,24 +130,270 @@ public: } }; -// GreedyFusionPolicy greedily fuses loop nests which have a producer/consumer +// MemRefDependenceGraph is a graph data structure where graph nodes are +// top-level statements in an MLFunction which contain load/store ops, and edges +// are memref dependences between the nodes. +// TODO(andydavis) Add a depth parameter to dependence graph construction. +struct MemRefDependenceGraph { +public: + // Node represents a node in the graph. A Node is either an entire loop nest + // rooted at the top level which contains loads/stores, or a top level + // load/store. + struct Node { + // The unique identifier of this node in the graph. + unsigned id; + // The top-level statment which is (or contains) loads/stores. + Statement *stmt; + // List of load op stmts. + SmallVector loads; + // List of store op stmts. + SmallVector stores; + Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} + + // Returns the load op count for 'memref'. 
+ unsigned getLoadOpCount(MLValue *memref) { + unsigned loadOpCount = 0; + for (auto *loadOpStmt : loads) { + if (memref == cast(loadOpStmt->cast()->getMemRef())) + ++loadOpCount; + } + return loadOpCount; + } + + // Returns the store op count for 'memref'. + unsigned getStoreOpCount(MLValue *memref) { + unsigned storeOpCount = 0; + for (auto *storeOpStmt : stores) { + if (memref == cast(storeOpStmt->cast()->getMemRef())) + ++storeOpCount; + } + return storeOpCount; + } + }; + + // Edge represents a memref data dependece between nodes in the graph. + struct Edge { + // The id of the node at the other end of the edge. + unsigned id; + // The memref on which this edge represents a dependence. + MLValue *memref; + }; + + // Map from node id to Node. + DenseMap nodes; + // Map from node id to list of input edges. + DenseMap> inEdges; + // Map from node id to list of output edges. + DenseMap> outEdges; + + MemRefDependenceGraph() {} + + // Initializes the dependence graph based on operations in 'f'. + // Returns true on success, false otherwise. + bool init(MLFunction *f); + + // Returns the graph node for 'id'. + Node *getNode(unsigned id) { + auto it = nodes.find(id); + assert(it != nodes.end()); + return &it->second; + } + + // Adds an edge from node 'srcId' to node 'dstId' for 'memref'. + void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + outEdges[srcId].push_back({dstId, memref}); + inEdges[dstId].push_back({srcId, memref}); + } + + // Removes an edge from node 'srcId' to node 'dstId' for 'memref'. + void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + assert(inEdges.count(dstId) > 0); + assert(outEdges.count(srcId) > 0); + // Remove 'srcId' from 'inEdges[dstId]'. + for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { + if ((*it).id == srcId && (*it).memref == memref) { + inEdges[dstId].erase(it); + break; + } + } + // Remove 'dstId' from 'outEdges[srcId]'. + for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { + if ((*it).id == dstId && (*it).memref == memref) { + outEdges[srcId].erase(it); + break; + } + } + } + + // Returns the input edge count for node 'id' and 'memref'. + unsigned getInEdgeCount(unsigned id, MLValue *memref) { + unsigned inEdgeCount = 0; + if (inEdges.count(id) > 0) + for (auto &inEdge : inEdges[id]) + if (inEdge.memref == memref) + ++inEdgeCount; + return inEdgeCount; + } + + // Returns the output edge count for node 'id' and 'memref'. + unsigned getOutEdgeCount(unsigned id, MLValue *memref) { + unsigned outEdgeCount = 0; + if (outEdges.count(id) > 0) + for (auto &outEdge : outEdges[id]) + if (outEdge.memref == memref) + ++outEdgeCount; + return outEdgeCount; + } + + // Returns the min node id of all output edges from node 'id'. + unsigned getMinOutEdgeNodeId(unsigned id) { + unsigned minId = std::numeric_limits::max(); + if (outEdges.count(id) > 0) + for (auto &outEdge : outEdges[id]) + minId = std::min(minId, outEdge.id); + return minId; + } + + // Updates edge mappings from node 'srcId' to node 'dstId' and removes + // state associated with node 'srcId'. + void updateEdgesAndRemoveSrcNode(unsigned srcId, unsigned dstId) { + // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. + if (inEdges.count(srcId) > 0) { + SmallVector oldInEdges = inEdges[srcId]; + for (auto &inEdge : oldInEdges) { + // Remove edge from 'inEdge.id' to 'srcId'. + removeEdge(inEdge.id, srcId, inEdge.memref); + // Add edge from 'inEdge.id' to 'dstId'. 
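+        // (For example, hypothetically: if node 1 is being fused into node 2
+        // and node 0 had an edge 0->1 on %m, this produces an edge 0->2 on
+        // %m.)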
+ addEdge(inEdge.id, dstId, inEdge.memref); + } + } + // For each edge in 'outEdges[srcId]': add new edge remaping to 'dstId'. + if (outEdges.count(srcId) > 0) { + SmallVector oldOutEdges = outEdges[srcId]; + for (auto &outEdge : oldOutEdges) { + // Remove edge from 'srcId' to 'outEdge.id'. + removeEdge(srcId, outEdge.id, outEdge.memref); + // Add edge from 'dstId' to 'outEdge.id' (if 'outEdge.id' != 'dstId'). + if (outEdge.id != dstId) + addEdge(dstId, outEdge.id, outEdge.memref); + } + } + // Remove 'srcId' from graph state. + inEdges.erase(srcId); + outEdges.erase(srcId); + nodes.erase(srcId); + } + + // Adds ops in 'loads' and 'stores' to node at 'id'. + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { + Node *node = getNode(id); + for (auto *loadOpStmt : loads) + node->loads.push_back(loadOpStmt); + for (auto *storeOpStmt : stores) + node->stores.push_back(storeOpStmt); + } + + void print(raw_ostream &os) const { + os << "\nMemRefDependenceGraph\n"; + os << "\nNodes:\n"; + for (auto &idAndNode : nodes) { + os << "Node: " << idAndNode.first << "\n"; + auto it = inEdges.find(idAndNode.first); + if (it != inEdges.end()) { + for (const auto &e : it->second) + os << " InEdge: " << e.id << " " << e.memref << "\n"; + } + it = outEdges.find(idAndNode.first); + if (it != outEdges.end()) { + for (const auto &e : it->second) + os << " OutEdge: " << e.id << " " << e.memref << "\n"; + } + } + } + void dump() const { print(llvm::errs()); } +}; + +// Intializes the data dependence graph by walking statements in 'f'. +// Assigns each node in the graph a node id based on program order in 'f'. +// TODO(andydavis) Add support for taking a StmtBlock arg to construct the +// dependence graph at a different depth. +bool MemRefDependenceGraph::init(MLFunction *f) { + unsigned id = 0; + DenseMap> memrefAccesses; + for (auto &stmt : *f) { + if (auto *forStmt = dyn_cast(&stmt)) { + // Create graph node 'id' to represent top-level 'forStmt' and record + // all loads and store accesses it contains. + LoopNestStateCollector collector; + collector.walkForStmt(forStmt); + // Return false if IfStmts are found (not currently supported). + if (collector.hasIfStmt) + return false; + Node node(id++, &stmt); + for (auto *opStmt : collector.loadOpStmts) { + node.loads.push_back(opStmt); + auto *memref = cast(opStmt->cast()->getMemRef()); + memrefAccesses[memref].insert(node.id); + } + for (auto *opStmt : collector.storeOpStmts) { + node.stores.push_back(opStmt); + auto *memref = cast(opStmt->cast()->getMemRef()); + memrefAccesses[memref].insert(node.id); + } + nodes.insert({node.id, node}); + } + if (auto *opStmt = dyn_cast(&stmt)) { + if (auto loadOp = opStmt->dyn_cast()) { + // Create graph node for top-level load op. + Node node(id++, &stmt); + node.loads.push_back(opStmt); + auto *memref = cast(opStmt->cast()->getMemRef()); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } + if (auto storeOp = opStmt->dyn_cast()) { + // Create graph node for top-level store op. + Node node(id++, &stmt); + node.stores.push_back(opStmt); + auto *memref = cast(opStmt->cast()->getMemRef()); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } + } + // Return false if IfStmts are found (not currently supported). + if (isa(&stmt)) + return false; + } + + // Walk memref access lists and add graph edges between dependent nodes. 
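// As an illustration of the rule applied below: if nodes 0, 1, and 2 (in
// program order) all access memref %m and only node 0 contains a store,
// edges 0->1 and 0->2 are added (read-after-write dependences), but no edge
// is added between 1 and 2 since neither writes %m; if all three nodes
// stored to %m, every ordered pair (0->1, 0->2, 1->2) would get an edge.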
+ for (auto &memrefAndList : memrefAccesses) { + unsigned n = memrefAndList.second.size(); + for (unsigned i = 0; i < n; ++i) { + unsigned srcId = memrefAndList.second[i]; + bool srcHasStore = + getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; + for (unsigned j = i + 1; j < n; ++j) { + unsigned dstId = memrefAndList.second[j]; + bool dstHasStore = + getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; + if (srcHasStore || dstHasStore) + addEdge(srcId, dstId, memrefAndList.first); + } + } + } + return true; +} + +// GreedyFusion greedily fuses loop nests which have a producer/consumer // relationship on a memref, with the goal of improving locality. Currently, // this the producer/consumer relationship is required to be unique in the // MLFunction (there are TODOs to relax this constraint in the future). // // The steps of the algorithm are as follows: // -// *) Initialize. While visiting each statement in the MLFunction do: -// *) Assign each top-level ForStmt a 'position' which is its initial -// position in the MLFunction's StmtBlock at the start of the pass. -// *) Gather memref load/store state aggregated by top-level statement. For -// example, all loads and stores contained in a loop nest are aggregated -// under the loop nest's top-level ForStmt. -// *) Add each top-level ForStmt to a worklist. -// -// *) Run. The algorithm processes the worklist with the following steps: -// *) The worklist is processed in reverse order (starting from the last -// top-level ForStmt in the MLFunction). +// *) A worklist is initialized with node ids from the dependence graph. +// *) For each node id in the worklist: // *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate // destination ForStmt into which fusion will be attempted. // *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'. @@ -157,7 +403,7 @@ public: // *) Check if dependences would be violated by the fusion. For example, // the src loop nest may load from memrefs which are different than // the producer-consumer memref between src and dest loop nests. -// *) Get a computation slice of 'srcLoopNest', which adjust its loop +// *) Get a computation slice of 'srcLoopNest', which adjusts its loop // bounds to be functions of 'dstLoopNest' IVs and symbols. // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', // just before the dst load op user. @@ -168,268 +414,112 @@ public: // // Given a graph where top-level statements are vertices in the set 'V' and // edges in the set 'E' are dependences between vertices, this algorithm -// takes O(V) time for initialization, and has runtime O(V * E). -// TODO(andydavis) Reduce this time complexity to O(V + E). +// takes O(V) time for initialization, and has runtime O(V + E). // -// This greedy algorithm is not 'maximally' but there is a TODO to fix this. +// This greedy algorithm is not 'maximal' due to the current restriction of +// fusing along single producer consumer edges, but there is a TODO to fix this. // // TODO(andydavis) Experiment with other fusion policies. -struct GreedyFusionPolicy { - // Convenience wrapper with information about 'stmt' ready to access. - struct StmtInfo { - Statement *stmt; - bool isOrContainsIfStmt = false; - }; - // The worklist of top-level loop nest positions. +// TODO(andydavis) Add support for fusing for input reuse (perhaps by +// constructing a graph with edges which represent loads from the same memref +// in two different loop nestst. 
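// As a concrete example of the steps above, consider a function with two
// top-level loop nests:
//
//   for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> }   // node 0
//   for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> }    // node 1
//
// The dependence graph has the single edge 0 -> 1 on %m. GreedyFusion pops
// node 1 from the worklist, finds the load on %m, and follows its lone
// in-edge back to node 0, which has exactly one store to %m, no other out
// edges, and no in edges, so every check below passes. A backward
// computation slice of node 0's loop nest is then inserted into node 1's
// loop body, node 0's accesses are merged into node 1, single-iteration
// loops in the slice are promoted, and node 0's original loop nest is
// erased.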
+struct GreedyFusion { +public: + MemRefDependenceGraph *mdg; SmallVector worklist; - // Mapping from top-level position to StmtInfo. - DenseMap posToStmtInfo; - // Mapping from memref MLValue to set of top-level positions of loop nests - // which contain load ops on that memref. - DenseMap> memrefToLoadPosSet; - // Mapping from memref MLValue to set of top-level positions of loop nests - // which contain store ops on that memref. - DenseMap> memrefToStorePosSet; - // Mapping from top-level loop nest to the set of load ops it contains. - DenseMap> forStmtToLoadOps; - // Mapping from top-level loop nest to the set of store ops it contains. - DenseMap> forStmtToStoreOps; - - GreedyFusionPolicy(MLFunction *f) { init(f); } - void run() { - if (hasIfStmts()) - return; + GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) { + // Initialize worklist with nodes from 'mdg'. + worklist.resize(mdg->nodes.size()); + std::iota(worklist.begin(), worklist.end(), 0); + } + void run() { while (!worklist.empty()) { - // Pop the position of a loop nest into which fusion will be attempted. - unsigned dstPos = worklist.back(); + unsigned dstId = worklist.back(); worklist.pop_back(); - // Skip if 'dstPos' is not tracked (was fused into another loop nest). - if (posToStmtInfo.count(dstPos) == 0) + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(dstId) == 0) continue; - // Get the top-level ForStmt at 'dstPos'. - auto *dstForStmt = getForStmtAtPos(dstPos); - // Skip if this ForStmt contains no load ops. - if (forStmtToLoadOps.count(dstForStmt) == 0) + // Get 'dstNode' into which to attempt fusion. + auto *dstNode = mdg->getNode(dstId); + // Skip if 'dstNode' is not a loop nest. + if (!isa(dstNode->stmt)) continue; - // Greedy Policy: iterate through load ops in 'dstForStmt', greedily - // fusing in src loop nests which have a single store op on the same - // memref, until a fixed point is reached where there is nothing left to - // fuse. - SetVector dstLoadOps = forStmtToLoadOps[dstForStmt]; - while (!dstLoadOps.empty()) { - auto *dstLoadOpStmt = dstLoadOps.pop_back_val(); - - auto dstLoadOp = dstLoadOpStmt->cast(); - auto *memref = cast(dstLoadOp->getMemRef()); - // Skip if not single src store / dst load pair on 'memref'. - if (memrefToLoadPosSet[memref].size() != 1 || - memrefToStorePosSet[memref].size() != 1) - continue; - unsigned srcPos = *memrefToStorePosSet[memref].begin(); - if (srcPos >= dstPos) + SmallVector loads = dstNode->loads; + while (!loads.empty()) { + auto *dstLoadOpStmt = loads.pop_back_val(); + auto *memref = + cast(dstLoadOpStmt->cast()->getMemRef()); + // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. + if (dstNode->getLoadOpCount(memref) != 1) continue; - auto *srcForStmt = getForStmtAtPos(srcPos); - // Skip if 'srcForStmt' has more than one store op. - if (forStmtToStoreOps[srcForStmt].size() > 1) + // Skip if no input edges along which to fuse. + if (mdg->inEdges.count(dstId) == 0) continue; - // Skip if fusion would violated dependences between 'memref' access - // for loop nests between 'srcPos' and 'dstPos': - // For each src load op: check for store ops in range (srcPos, dstPos). - // For each src store op: check for load ops in range (srcPos, dstPos). - if (moveWouldViolateDependences(srcPos, dstPos)) - continue; - auto *srcStoreOpStmt = forStmtToStoreOps[srcForStmt].front(); - // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'. 
- FusionCandidate candidate = - buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); - // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto *sliceLoopNest = mlir::insertBackwardComputationSlice( - &candidate.srcAccess, &candidate.dstAccess); - if (sliceLoopNest != nullptr) { - // Remove 'srcPos' mappings from 'state'. - moveAccessesAndRemovePos(srcPos, dstPos); - // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. - LoopNestStateCollector collector; - collector.walkForStmt(sliceLoopNest); - // Record mappings for loads and stores from 'collector'. - for (auto *opStmt : collector.loadOpStmts) { - addLoadOpStmtAt(dstPos, opStmt, dstForStmt); - // Add newly fused load ops to 'dstLoadOps' to be considered for - // fusion on subsequent iterations. - dstLoadOps.insert(opStmt); - } - for (auto *opStmt : collector.storeOpStmts) { - addStoreOpStmtAt(dstPos, opStmt, dstForStmt); + // Iterate through in edges for 'dstId'. + for (auto &srcEdge : mdg->inEdges[dstId]) { + // Skip 'srcEdge' if not for 'memref'. + if (srcEdge.memref != memref) + continue; + auto *srcNode = mdg->getNode(srcEdge.id); + // Skip if 'srcNode' is not a loop nest. + if (!isa(srcNode->stmt)) + continue; + // Skip if 'srcNode' has more than one store to 'memref'. + if (srcNode->getStoreOpCount(memref) != 1) + continue; + // Skip 'srcNode' if it has out edges on 'memref' other than 'dstId'. + if (mdg->getOutEdgeCount(srcNode->id, memref) != 1) + continue; + // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly + // TODO(andydavis) Track dependence type with edges, and just check + // for WAW dependence edge here. + if (mdg->getInEdgeCount(srcNode->id, memref) != 0) + continue; + // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. + if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId) + continue; + // Get unique 'srcNode' store op. + auto *srcStoreOpStmt = srcNode->stores.front(); + // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'. + FusionCandidate candidate = + buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. + auto *sliceLoopNest = mlir::insertBackwardComputationSlice( + &candidate.srcAccess, &candidate.dstAccess); + if (sliceLoopNest != nullptr) { + // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode' + mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); + // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. + LoopNestStateCollector collector; + collector.walkForStmt(sliceLoopNest); + mdg->addToNode(dstId, collector.loadOpStmts, + collector.storeOpStmts); + // Add new load ops to current Node load op list 'loads' to + // continue fusing based on new operands. + for (auto *loadOpStmt : collector.loadOpStmts) + loads.push_back(loadOpStmt); + // Promote single iteration loops to single IV value. + for (auto *forStmt : collector.forStmts) { + promoteIfSingleIteration(forStmt); + } + // Remove old src loop nest. + cast(srcNode->stmt)->erase(); } - for (auto *forStmt : collector.forStmts) { - promoteIfSingleIteration(forStmt); - } - // Remove old src loop nest. - srcForStmt->erase(); - } - } - } - } - - // Walk MLFunction 'f' assigning each top-level statement a position, and - // gathering state on load and store ops. - void init(MLFunction *f) { - unsigned pos = 0; - for (auto &stmt : *f) { - if (auto *forStmt = dyn_cast(&stmt)) { - // Record all loads and store accesses in 'forStmt' at 'pos'. 
- LoopNestStateCollector collector; - collector.walkForStmt(forStmt); - // Create StmtInfo for 'forStmt' for top-level loop nests. - addStmtInfoAt(pos, forStmt, collector.hasIfStmt); - // Record mappings for loads and stores from 'collector'. - for (auto *opStmt : collector.loadOpStmts) { - addLoadOpStmtAt(pos, opStmt, forStmt); - } - for (auto *opStmt : collector.storeOpStmts) { - addStoreOpStmtAt(pos, opStmt, forStmt); - } - // Add 'pos' associated with 'forStmt' to worklist. - worklist.push_back(pos); - } - if (auto *opStmt = dyn_cast(&stmt)) { - if (auto loadOp = opStmt->dyn_cast()) { - // Create StmtInfo for top-level load op. - addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false); - addLoadOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr); - } - if (auto storeOp = opStmt->dyn_cast()) { - // Create StmtInfo for top-level store op. - addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false); - addStoreOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr); - } - } - if (auto *ifStmt = dyn_cast(&stmt)) { - addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/true); - } - ++pos; - } - } - - // Check if fusing loop nest at 'srcPos' into the loop nest at 'dstPos' - // would violated any dependences w.r.t other loop nests in that range. - bool moveWouldViolateDependences(unsigned srcPos, unsigned dstPos) { - // Lookup src ForStmt at 'srcPos'. - auto *srcForStmt = getForStmtAtPos(srcPos); - // For each src load op: check for store ops in range (srcPos, dstPos). - if (forStmtToLoadOps.count(srcForStmt) > 0) { - for (auto *opStmt : forStmtToLoadOps[srcForStmt]) { - auto loadOp = opStmt->cast(); - auto *memref = cast(loadOp->getMemRef()); - for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) { - if (memrefToStorePosSet.count(memref) > 0 && - memrefToStorePosSet[memref].count(pos) > 0) - return true; } } } - // For each src store op: check for load ops in range (srcPos, dstPos). - if (forStmtToStoreOps.count(srcForStmt) > 0) { - for (auto *opStmt : forStmtToStoreOps[srcForStmt]) { - auto storeOp = opStmt->cast(); - auto *memref = cast(storeOp->getMemRef()); - for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) { - if (memrefToLoadPosSet.count(memref) > 0 && - memrefToLoadPosSet[memref].count(pos) > 0) - return true; - } - } - } - return false; - } - - // Update mappings of memref loads and stores at 'srcPos' to 'dstPos'. - void moveAccessesAndRemovePos(unsigned srcPos, unsigned dstPos) { - // Lookup ForStmt at 'srcPos'. - auto *srcForStmt = getForStmtAtPos(srcPos); - // Move load op accesses from src to dst. - if (forStmtToLoadOps.count(srcForStmt) > 0) { - for (auto *opStmt : forStmtToLoadOps[srcForStmt]) { - auto loadOp = opStmt->cast(); - auto *memref = cast(loadOp->getMemRef()); - // Remove 'memref' to 'srcPos' mapping. - memrefToLoadPosSet[memref].erase(srcPos); - } - } - // Move store op accesses from src to dst. - if (forStmtToStoreOps.count(srcForStmt) > 0) { - for (auto *opStmt : forStmtToStoreOps[srcForStmt]) { - auto storeOp = opStmt->cast(); - auto *memref = cast(storeOp->getMemRef()); - // Remove 'memref' to 'srcPos' mapping. - memrefToStorePosSet[memref].erase(srcPos); - } - } - // Remove old state. 
- forStmtToLoadOps.erase(srcForStmt); - forStmtToStoreOps.erase(srcForStmt); - posToStmtInfo.erase(srcPos); - } - - ForStmt *getForStmtAtPos(unsigned pos) { - assert(posToStmtInfo.count(pos) > 0); - assert(isa(posToStmtInfo[pos].stmt)); - return cast(posToStmtInfo[pos].stmt); - } - - void addStmtInfoAt(unsigned pos, Statement *stmt, bool hasIfStmt) { - StmtInfo stmtInfo; - stmtInfo.stmt = stmt; - stmtInfo.isOrContainsIfStmt = hasIfStmt; - // Add mapping from 'pos' to StmtInfo for 'forStmt'. - posToStmtInfo[pos] = stmtInfo; - } - - // Adds the following mappings: - // *) 'containingForStmt' to load 'opStmt' - // *) 'memref' of load 'opStmt' to 'topLevelPos'. - void addLoadOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt, - ForStmt *containingForStmt) { - if (containingForStmt != nullptr) { - // Add mapping from 'containingForStmt' to 'opStmt' for load op. - forStmtToLoadOps[containingForStmt].insert(opStmt); - } - auto loadOp = opStmt->cast(); - auto *memref = cast(loadOp->getMemRef()); - // Add mapping from 'memref' to 'topLevelPos' for load. - memrefToLoadPosSet[memref].insert(topLevelPos); - } - - // Adds the following mappings: - // *) 'containingForStmt' to store 'opStmt' - // *) 'memref' of store 'opStmt' to 'topLevelPos'. - void addStoreOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt, - ForStmt *containingForStmt) { - if (containingForStmt != nullptr) { - // Add mapping from 'forStmt' to 'opStmt' for store op. - forStmtToStoreOps[containingForStmt].insert(opStmt); - } - auto storeOp = opStmt->cast(); - auto *memref = cast(storeOp->getMemRef()); - // Add mapping from 'memref' to 'topLevelPos' for store. - memrefToStorePosSet[memref].insert(topLevelPos); - } - - bool hasIfStmts() { - for (auto &pair : posToStmtInfo) - if (pair.second.isOrContainsIfStmt) - return true; - return false; } }; } // end anonymous namespace PassResult LoopFusion::runOnMLFunction(MLFunction *f) { - GreedyFusionPolicy(f).run(); + MemRefDependenceGraph g; + if (g.init(f)) + GreedyFusion(&g).run(); return success(); } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index d0de62e8a06..a668e181cc1 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -35,20 +35,17 @@ mlfunc @should_fuse_raw_dep_for_locality() { // CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) -// TODO(andydavis) Turn this into a proper reduction when constraints on -// the current greedy fusion policy are relaxed. // CHECK-LABEL: mlfunc @should_fuse_reduction_to_pointwise() { mlfunc @should_fuse_reduction_to_pointwise() { %a = alloc() : memref<10x10xf32> %b = alloc() : memref<10xf32> %c = alloc() : memref<10xf32> - %d = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 for %i0 = 0 to 10 { for %i1 = 0 to 10 { - %v0 = load %d[%i0] : memref<10xf32> + %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x10xf32> %v3 = addf %v0, %v1 : f32 store %v3, %b[%i0] : memref<10xf32> @@ -62,15 +59,15 @@ mlfunc @should_fuse_reduction_to_pointwise() { // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. 
// CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %5 = load %3[%4] : memref<10xf32> - // CHECK-NEXT: %6 = load %0[%4, %i1] : memref<10x10xf32> - // CHECK-NEXT: %7 = addf %5, %6 : f32 - // CHECK-NEXT: store %7, %1[%4] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: %8 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: store %8, %2[%i0] : memref<10xf32> + // CHECK-NEXT: %4 = load %1[%3] : memref<10xf32> + // CHECK-NEXT: %5 = load %0[%3, %i1] : memref<10x10xf32> + // CHECK-NEXT: %6 = addf %4, %5 : f32 + // CHECK-NEXT: store %6, %1[%3] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %7 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %7, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return -- cgit v1.2.3 From 4eef795a1dbd7eafa9a45303f01c51921729f1f4 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 21 Dec 2018 11:06:23 -0800 Subject: Computation slice update: adds parameters to insertBackwardComputationSlice which specify the source loop nest depth at which to perform iteration space slicing, and the destination loop nest depth at which to insert the compution slice. Updates LoopFusion pass to take these parameters as command line flags for experimentation. PiperOrigin-RevId: 226514297 --- mlir/include/mlir/Analysis/Utils.h | 17 ++++++------- mlir/lib/Analysis/Utils.cpp | 43 +++++++++++++++++++++++++------- mlir/lib/Transforms/LoopFusion.cpp | 36 ++++++++++++++++++++++++++- mlir/test/Transforms/loop-fusion.mlir | 46 +++++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 19 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 1743d49aafd..365fc74a778 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -143,20 +143,19 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); /// Creates a clone of the computation contained in the loop nest surrounding -/// 'srcAccess', and inserts it at the beginning of the statement block of the -/// loop containing 'dstAccess'. Returns the top-level loop of the computation -/// slice on success, returns nullptr otherwise. -// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the -// dependence constraint system to create AffineMaps with which to adjust the -// loop bounds of the inserted compution slice so that they are functions of the -// loop IVs and symbols of the loops surrounding 'dstAccess'. -// TODO(andydavis) Add 'dstLoopDepth' argument for computation slice insertion. +/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop +/// IVs, and inserts the computation slice at the beginning of the statement +/// block of the loop at 'dstLoopDepth' in the loop nest surrounding +/// 'dstAccess'. Returns the top-level loop of the computation slice on +/// success, returns nullptr otherwise. // Loop depth is a crucial optimization choice that determines where to // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases // TODO(andydavis) Support computation slices with common surrounding loops. 
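// For example, with 'srcLoopDepth' == 1 and 'dstLoopDepth' == 1, only the
// outermost src loop IV is re-expressed in terms of the dst IVs and the
// slice is inserted at the top of the outermost dst loop's body, while the
// deeper src loops keep their original bounds; larger values of
// 'dstLoopDepth' sink the slice deeper into the dst nest, shifting the
// storage vs. redundant-computation trade-off noted above.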
ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess, - MemRefAccess *dstAccess); + MemRefAccess *dstAccess, + unsigned srcLoopDepth, + unsigned dstLoopDepth); } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 86f5fbf8ea4..cc30cfffb06 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "analysis-utils" @@ -374,11 +375,14 @@ static Statement *getStmtAtPosition(ArrayRef positions, return nullptr; } -// TODO(andydavis) Support a 'dstLoopDepth' argument for computation slice -// insertion (currently the computation slice is inserted at the same -// loop depth as 'dstAccess.opStmt'. +// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the +// dependence constraint system to create AffineMaps with which to adjust the +// loop bounds of the inserted compution slice so that they are functions of the +// loop IVs and symbols of the loops surrounding 'dstAccess'. ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, - MemRefAccess *dstAccess) { + MemRefAccess *dstAccess, + unsigned srcLoopDepth, + unsigned dstLoopDepth) { FlatAffineConstraints dependenceConstraints; if (!checkMemrefAccessDependence(*srcAccess, *dstAccess, /*loopDepth=*/1, &dependenceConstraints, @@ -389,21 +393,32 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, SmallVector srcLoopNest; getLoopIVs(*srcAccess->opStmt, &srcLoopNest); unsigned srcLoopNestSize = srcLoopNest.size(); + assert(srcLoopDepth <= srcLoopNestSize); // Get loop nest surrounding dst operation. SmallVector dstLoopNest; getLoopIVs(*dstAccess->opStmt, &dstLoopNest); unsigned dstLoopNestSize = dstLoopNest.size(); + (void)dstLoopNestSize; + assert(dstLoopDepth > 0); + assert(dstLoopDepth <= dstLoopNestSize); // Solve for src IVs in terms of dst IVs, symbols and constants. SmallVector srcIvMaps(srcLoopNestSize, AffineMap::Null()); std::vector> srcIvOperands(srcLoopNestSize); for (unsigned i = 0; i < srcLoopNestSize; ++i) { + // Skip IVs which are greater than requested loop depth. + if (i >= srcLoopDepth) { + srcIvMaps[i] = AffineMap::Null(); + continue; + } auto cst = dependenceConstraints.clone(); for (int j = srcLoopNestSize - 1; j >= 0; --j) { if (i != j) cst->projectOut(j); } + // TODO(andydavis) Check for case with two equalities where we have + // set on IV to a constant. Set a constant IV map for these cases. if (cst->getNumEqualities() != 1) { srcIvMaps[i] = AffineMap::Null(); continue; @@ -412,11 +427,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, SmallVector nonZeroSymbolIds; srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(), &nonZeroDimIds, &nonZeroSymbolIds); - if (srcIvMaps[i] == AffineMap::Null()) + if (srcIvMaps[i] == AffineMap::Null()) { continue; + } // Add operands for all non-zero dst dims and symbols. // TODO(andydavis) Add local variable support. for (auto dimId : nonZeroDimIds) { + if (dimId - 1 >= dstLoopDepth) { + // This src IV has a dependence on dst IV dstLoopDepth where it will + // be inserted. So we cannot slice the iteration space at srcLoopDepth, + // and also insert it into the dst loop nest at 'dstLoopDepth'. 
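// For example, if 'dstLoopDepth' is 1 but this IV's bound map references
// the dst IV at loop depth 2, the slice would be placed above the loop
// that defines that IV, so the map's operands would not be available at
// the insertion point; give up and let callers treat this pair as
// unfusable.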
+ return nullptr; + } srcIvOperands[i].push_back(dstLoopNest[dimId - 1]); } // TODO(andydavis) Add symbols from the access function. Ideally, we @@ -429,8 +451,8 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions); // Clone src loop nest and insert it a the beginning of the statement block - // of the same loop in which containts 'dstAccess->opStmt'. - auto *dstForStmt = dstLoopNest[dstLoopNestSize - 1]; + // of the loop at 'dstLoopDepth' in 'dstLoopNest'. + auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; MLFuncBuilder b(dstForStmt, dstForStmt->begin()); DenseMap operandMap; auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); @@ -442,11 +464,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, SmallVector sliceSurroundingLoops; getLoopIVs(*sliceStmt, &sliceSurroundingLoops); unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size(); + (void)sliceSurroundingLoopsSize; // Update loop bounds for loops in 'sliceLoopNest'. - for (unsigned i = dstLoopNestSize; i < sliceSurroundingLoopsSize; ++i) { + unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize; + assert(sliceLoopLimit <= sliceSurroundingLoopsSize); + for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) { auto *forStmt = sliceSurroundingLoops[i]; - unsigned index = i - dstLoopNestSize; + unsigned index = i - dstLoopDepth; AffineMap lbMap = srcIvMaps[index]; if (lbMap == AffineMap::Null()) continue; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6393fa6069d..df68765aeb7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -35,12 +35,27 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" using llvm::SetVector; using namespace mlir; +// TODO(andydavis) These flags are global for the pass to be used for +// experimentation. Find a way to provide more fine grained control (i.e. +// depth per-loop nest, or depth per load/store op) for this pass utilizing a +// cost model. +static llvm::cl::opt clSrcLoopDepth( + "src-loop-depth", llvm::cl::Hidden, + llvm::cl::desc("Controls the depth of the source loop nest at which " + "to apply loop iteration slicing before fusion.")); + +static llvm::cl::opt clDstLoopDepth( + "dst-loop-depth", llvm::cl::Hidden, + llvm::cl::desc("Controls the depth of the destination loop nest at which " + "to fuse the source loop nest slice.")); + namespace { /// Loop fusion pass. This pass currently supports a greedy fusion policy, @@ -107,6 +122,18 @@ static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, return candidate; } +// Returns the loop depth of the loop nest surrounding 'opStmt'. +static unsigned getLoopDepth(OperationStmt *opStmt) { + unsigned loopDepth = 0; + auto *currStmt = opStmt->getParentStmt(); + ForStmt *currForStmt; + while (currStmt && (currForStmt = dyn_cast(currStmt))) { + ++loopDepth; + currStmt = currStmt->getParentStmt(); + } + return loopDepth; +} + namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -487,8 +514,15 @@ public: FusionCandidate candidate = buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. + unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0 + ? 
clSrcLoopDepth + : getLoopDepth(srcStoreOpStmt); + unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0 + ? clDstLoopDepth + : getLoopDepth(dstLoadOpStmt); auto *sliceLoopNest = mlir::insertBackwardComputationSlice( - &candidate.srcAccess, &candidate.dstAccess); + &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth, + dstLoopDepth); if (sliceLoopNest != nullptr) { // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode' mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index a668e181cc1..f26041ed169 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s +// RUN: mlir-opt %s -loop-fusion -src-loop-depth=1 -dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1 // TODO(andydavis) Add more tests: // *) Add nested fusion test cases when non-constant loop bound support is @@ -550,3 +551,48 @@ mlfunc @remap_ivs() { return } + +// ----- + +// DEPTH1: #map0 = (d0) -> (d0) +// DEPTH1: #map1 = (d0, d1, d2) -> (d0, d1, d2) + +// DEPTH1-LABEL: mlfunc @fuse_slice_at_depth1() { +mlfunc @fuse_slice_at_depth1() { + %m = alloc() : memref<100x16x100xf32> + + %cf7 = constant 7.0 : f32 + for %i0 = 0 to 100 { + for %i1 = 0 to 16 { + for %i2 = 0 to 100 { + %a0 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i0, %i1, %i2) + store %cf7, %m[%a0#0, %a0#1, %a0#2] : memref<100x16x100xf32> + } + } + } + for %i3 = 0 to 100 { + for %i4 = 0 to 16 { + for %i5 = 0 to 100 { + %a1 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i3, %i4, %i5) + %v0 = load %m[%a1#0, %a1#1, %a1#2] : memref<100x16x100xf32> + } + } + } +// DEPTH1: for %i0 = 0 to 100 { +// DEPTH1-NEXT: %1 = affine_apply #map0(%i0) +// DEPTH1-NEXT: for %i1 = 0 to 16 { +// DEPTH1-NEXT: for %i2 = 0 to 100 { +// DEPTH1-NEXT: %2 = affine_apply #map1(%1, %i1, %i2) +// DEPTH1-NEXT: store %cst, %0[%2#0, %2#1, %2#2] : memref<100x16x100xf32> +// DEPTH1-NEXT: } +// DEPTH1-NEXT: } +// DEPTH1-NEXT: for %i3 = 0 to 16 { +// DEPTH1-NEXT: for %i4 = 0 to 100 { +// DEPTH1-NEXT: %3 = affine_apply #map1(%i0, %i3, %i4) +// DEPTH1-NEXT: %4 = load %0[%3#0, %3#1, %3#2] : memref<100x16x100xf32> +// DEPTH1-NEXT: } +// DEPTH1-NEXT: } +// DEPTH1-NEXT: } +// DEPTH1-NEXT: return + return +} -- cgit v1.2.3 From d613f5ab65bbb80c3f5a0a38fef22cb4878c4358 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Wed, 26 Dec 2018 11:21:53 -0800 Subject: Refactor MLFunction to contain a StmtBlock for its body instead of inheriting from it. This is necessary progress to squaring away the parent relationship that a StmtBlock has with its enclosing if/for/fn, and makes room for functions to have more than one block in the future. This also removes IfClause and ForStmtBody. This is step 5/n towards merging instructions and statements, NFC. 
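For callers the change is mechanical; a schematic before/after of the pattern
repeated throughout the diffs below (loop bodies elided here):

    // Before: MLFunction was itself a StmtBlock, so callers iterated and
    // built into the function directly.
    for (auto &stmt : *f) { /* ... */ }
    builder.setInsertionPointToStart(f);

    // After: the statements live in a contained block reached via getBody().
    for (auto &stmt : *f->getBody()) { /* ... */ }
    builder.setInsertionPointToStart(f->getBody());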
PiperOrigin-RevId: 226936541 --- mlir/include/mlir/IR/Builders.h | 8 ++- mlir/include/mlir/IR/MLFunction.h | 9 +-- mlir/include/mlir/IR/Statements.h | 70 ++++------------------ mlir/include/mlir/IR/StmtBlock.h | 24 ++++---- mlir/include/mlir/IR/StmtVisitor.h | 6 +- mlir/lib/Analysis/Verifier.cpp | 10 ++-- mlir/lib/IR/AsmPrinter.cpp | 6 +- mlir/lib/IR/BuiltinOps.cpp | 2 +- mlir/lib/IR/Function.cpp | 9 ++- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/StmtBlock.cpp | 27 ++++----- mlir/lib/Parser/Parser.cpp | 10 ++-- mlir/lib/Transforms/ConvertToCFG.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 3 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- 20 files changed, 81 insertions(+), 123 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index f743930bd58..3525c31e099 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -294,6 +294,12 @@ public: setInsertionPoint(stmt); } + MLFuncBuilder(StmtBlock *block) + // TODO: Eliminate findFunction from this. + : MLFuncBuilder(block->findFunction()) { + setInsertionPoint(block, block->end()); + } + MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint) // TODO: Eliminate findFunction from this. : MLFuncBuilder(block->findFunction()) { @@ -304,7 +310,7 @@ public: /// the function. MLFuncBuilder(MLFunction *func) : Builder(func->getContext()), function(func) { - setInsertionPoint(func, func->begin()); + setInsertionPoint(func->getBody(), func->getBody()->begin()); } /// Return the function this builder is referring to. diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h index cf7f64f869a..58261e04d8f 100644 --- a/mlir/include/mlir/IR/MLFunction.h +++ b/mlir/include/mlir/IR/MLFunction.h @@ -36,7 +36,6 @@ template class ArgumentIterator; // include nested affine for loops, conditionals and operations. class MLFunction final : public Function, - public StmtBlock, private llvm::TrailingObjects { public: /// Creates a new MLFunction with the specific type. @@ -44,6 +43,9 @@ public: FunctionType type, ArrayRef attrs = {}); + StmtBlock *getBody() { return &body; } + const StmtBlock *getBody() const { return &body; } + /// Destroys this statement and its subclass data. void destroy(); @@ -98,9 +100,6 @@ public: static bool classof(const Function *func) { return func->getKind() == Function::Kind::MLFunc; } - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::MLFunc; - } private: MLFunction(Location location, StringRef name, FunctionType type, @@ -119,6 +118,8 @@ private: MutableArrayRef getArgumentsInternal() { return {getTrailingObjects(), getNumArguments()}; } + + StmtBlock body; }; //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index fcffa9397c9..7ff955cc824 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -274,29 +274,6 @@ private: size_t numTrailingObjects(OverloadToken) const { return numSuccs; } }; -/// A ForStmtBody represents statements contained within a ForStmt. 
-class ForStmtBody : public StmtBlock { -public: - explicit ForStmtBody(ForStmt *stmt) - : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) { - assert(stmt != nullptr && "ForStmtBody must have non-null parent"); - } - - ~ForStmtBody() {} - - /// Methods for support type inquiry through isa, cast, and dyn_cast - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::ForBody; - } - - /// Returns the 'for' statement that contains this body. - ForStmt *getFor() { return forStmt; } - const ForStmt *getFor() const { return forStmt; } - -private: - ForStmt *forStmt; -}; - /// For statement represents an affine loop nest. class ForStmt : public Statement, public MLValue { public: @@ -324,10 +301,10 @@ public: using const_operand_range = llvm::iterator_range; /// Get the body of the ForStmt. - ForStmtBody *getBody() { return &body; } + StmtBlock *getBody() { return &body; } /// Get the body of the ForStmt. - const ForStmtBody *getBody() const { return &body; } + const StmtBlock *getBody() const { return &body; } //===--------------------------------------------------------------------===// // Bounds and step @@ -455,7 +432,7 @@ public: private: // The StmtBlock for the body. - ForStmtBody body; + StmtBlock body; // Affine map for the lower bound. AffineMap lbMap; @@ -525,31 +502,6 @@ private: friend class ForStmt; }; -/// An if clause represents statements contained within a then or an else clause -/// of an if statement. -class IfClause : public StmtBlock { -public: - explicit IfClause(IfStmt *stmt) - : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) { - assert(stmt != nullptr && "If clause must have non-null parent"); - } - - /// Methods for support type inquiry through isa, cast, and dyn_cast - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::IfClause; - } - - ~IfClause() {} - - /// Returns the if statement that contains this clause. - const IfStmt *getIf() const { return ifStmt; } - - IfStmt *getIf() { return ifStmt; } - -private: - IfStmt *ifStmt; -}; - /// If statement restricts execution to a subset of the loop iteration space. class IfStmt : public Statement { public: @@ -561,15 +513,15 @@ public: // Then, else, condition. //===--------------------------------------------------------------------===// - IfClause *getThen() { return &thenClause; } - const IfClause *getThen() const { return &thenClause; } - IfClause *getElse() { return elseClause; } - const IfClause *getElse() const { return elseClause; } + StmtBlock *getThen() { return &thenClause; } + const StmtBlock *getThen() const { return &thenClause; } + StmtBlock *getElse() { return elseClause; } + const StmtBlock *getElse() const { return elseClause; } bool hasElse() const { return elseClause != nullptr; } - IfClause *createElse() { + StmtBlock *createElse() { assert(elseClause == nullptr && "already has an else clause!"); - return (elseClause = new IfClause(this)); + return (elseClause = new StmtBlock(this)); } const AffineCondition getCondition() const; @@ -634,9 +586,9 @@ public: private: // it is always present. - IfClause thenClause; + StmtBlock thenClause; // 'else' clause of the if statement. 'nullptr' if there is no else clause. - IfClause *elseClause; + StmtBlock *elseClause; // The integer set capturing the conditional guard. 
IntegerSet set; diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index 65e0f19066e..9ee4d651029 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -39,12 +39,8 @@ template class StmtSuccessorIterator; /// children of a parent statement in the ML Function. class StmtBlock : public IRObjectWithUseList { public: - enum class StmtBlockKind { - MLFunc, // MLFunction - ForBody, // ForStmtBody - IfClause // IfClause - }; - + explicit StmtBlock(MLFunction *parent); + explicit StmtBlock(Statement *parent); ~StmtBlock(); void clear() { @@ -54,7 +50,9 @@ public: statements.pop_back(); } - StmtBlockKind getStmtBlockKind() const { return kind; } + llvm::PointerUnion getParent() const { + return parent; + } /// Returns the closest surrounding statement that contains this block or /// nullptr if this is a top-level statement block. @@ -66,7 +64,10 @@ public: /// Returns the function that this statement block is part of. /// The function is determined by traversing the chain of parent statements. - MLFunction *findFunction() const; + MLFunction *findFunction(); + const MLFunction *findFunction() const { + return const_cast(this)->findFunction(); + } //===--------------------------------------------------------------------===// // Block argument management @@ -224,11 +225,10 @@ public: void printBlock(raw_ostream &os) const; void dumpBlock() const; -protected: - StmtBlock(StmtBlockKind kind) : kind(kind) {} - private: - StmtBlockKind kind; + /// This is the parent function/IfStmt/ForStmt that owns this block. + llvm::PointerUnion parent; + /// This is the list of statements in the block. StmtListType statements; diff --git a/mlir/include/mlir/IR/StmtVisitor.h b/mlir/include/mlir/IR/StmtVisitor.h index 94bc0b0cdc1..8dcd5863096 100644 --- a/mlir/include/mlir/IR/StmtVisitor.h +++ b/mlir/include/mlir/IR/StmtVisitor.h @@ -132,11 +132,13 @@ public: // Define walkers for MLFunction and all MLFunction statement kinds. void walk(MLFunction *f) { static_cast(this)->visitMLFunction(f); - static_cast(this)->walk(f->begin(), f->end()); + static_cast(this)->walk(f->getBody()->begin(), + f->getBody()->end()); } void walkPostOrder(MLFunction *f) { - static_cast(this)->walkPostOrder(f->begin(), f->end()); + static_cast(this)->walkPostOrder(f->getBody()->begin(), + f->getBody()->end()); static_cast(this)->visitMLFunction(f); } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 6e1522a656f..07324ba7d52 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmtBody = dyn_cast(&block)) - liveValues.insert(forStmtBody->getFor(), true); + if (auto *forStmt = dyn_cast_or_null(block.getContainingStmt())) + liveValues.insert(forStmt, true); for (auto &stmt : block) { // Verify that each of the operands are live. @@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() { }; // Check the whole function out. - return walkBlock(fn); + return walkBlock(*fn.getBody()); } bool MLFuncVerifier::verifyReturn() { // TODO: fold return verification in the pass that verifies all statements. 
const char missingReturnMsg[] = "ML function must end with return statement"; - if (fn.getStatements().empty()) + if (fn.getBody()->getStatements().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getStatements().back(); + const auto &stmt = fn.getBody()->getStatements().back(); if (const auto *op = dyn_cast(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2de17563d93..3b193117355 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const MLFunction *fn) { visitType(fn->getType()); - for (auto &stmt : *fn) { + for (auto &stmt : *fn->getBody()) { ModuleState::visitStatement(&stmt); } } @@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() { printFunctionSignature(); printFunctionAttributes(getFunction()); os << " {\n"; - print(function); + print(function->getBody()); os << "}\n\n"; } @@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const { void Statement::dump() const { print(llvm::errs()); } void StmtBlock::printBlock(raw_ostream &os) const { - MLFunction *function = findFunction(); + const MLFunction *function = findFunction(); ModuleState state(function->getContext()); ModulePrinter modulePrinter(os, state); MLFunctionPrinter(function, modulePrinter).print(this); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 5d7ba237b44..dfd59c4d380 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const { bool ReturnOp::verify() const { const Function *function; if (auto *stmt = dyn_cast(getOperation())) - function = cast(stmt->getBlock()); + function = stmt->getBlock()->findFunction(); else function = cast(getOperation())->getFunction(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index b79e1596a65..533be8e2a29 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name, MLFunction::MLFunction(Location location, StringRef name, FunctionType type, ArrayRef attrs) - : Function(Kind::MLFunc, location, name, type, attrs), - StmtBlock(StmtBlockKind::MLFunc) {} + : Function(Kind::MLFunc, location, name, type, attrs), body(this) {} MLFunction::~MLFunction() { // Explicitly erase statements instead of relying of 'StmtBlock' destructor // since child statements need to be destroyed before function arguments // are destroyed. - clear(); + getBody()->clear(); // Explicitly run the destructors for the function arguments. for (auto &arg : getArgumentsInternal()) @@ -222,11 +221,11 @@ void MLFunction::destroy() { } const OperationStmt *MLFunction::getReturnStmt() const { - return cast(&back()); + return cast(&getBody()->back()); } OperationStmt *MLFunction::getReturnStmt() { - return cast(&back()); + return cast(&getBody()->back()); } void MLFunction::walk(std::function callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index d3a618d7da5..c946a76a98b 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) { // Verify that the operation is at the end of the respective parent block. 
if (auto *stmt = dyn_cast(op)) { StmtBlock *block = stmt->getBlock(); - if (!block || !isa(block) || &block->back() != stmt) + if (!block || block->getContainingStmt() || &block->back() != stmt) return op->emitOpError("must be the last statement in the ML function"); } else { const Instruction *inst = cast(op); diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 8ecb903d21d..fdee491c150 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -20,33 +20,30 @@ #include "mlir/IR/Statements.h" using namespace mlir; +StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {} + +StmtBlock::StmtBlock(Statement *parent) : parent(parent) {} + StmtBlock::~StmtBlock() { clear(); llvm::DeleteContainerPointers(arguments); } +/// Returns the closest surrounding statement that contains this block or +/// nullptr if this is a top-level statement block. Statement *StmtBlock::getContainingStmt() { - switch (kind) { - case StmtBlockKind::MLFunc: - return nullptr; - case StmtBlockKind::ForBody: - return cast(this)->getFor(); - case StmtBlockKind::IfClause: - return cast(this)->getIf(); - } + return parent.dyn_cast(); } -MLFunction *StmtBlock::findFunction() const { - // FIXME: const incorrect. - StmtBlock *block = const_cast(this); - - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); +MLFunction *StmtBlock::findFunction() { + StmtBlock *block = this; + while (auto *stmt = block->getContainingStmt()) { + block = stmt->getBlock(); if (!block) return nullptr; } - return dyn_cast(block); + return block->getParent().get(); } /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 781ec461b62..1a28648eba9 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser { public: MLFunctionParser(ParserState &state, MLFunction *function) : FunctionParser(state, Kind::MLFunc), function(function), - builder(function, function->end()) {} + builder(function->getBody()) {} ParseResult parseFunctionBody(); @@ -2796,7 +2796,7 @@ private: ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); - ParseResult parseElseClause(IfClause *elseClause); + ParseResult parseElseClause(StmtBlock *elseClause); ParseResult parseStatements(StmtBlock *block); ParseResult parseStmtBlock(StmtBlock *block); @@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); // Parse statements in this function. 
- if (parseStmtBlock(function)) + if (parseStmtBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); @@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() { IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), operands, set); - IfClause *thenClause = ifStmt->getThen(); + StmtBlock *thenClause = ifStmt->getThen(); // When parsing of an if statement body fails, the IR contains // the if statement with the portion of the body that has been @@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseSuccess; } -ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) { +ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); return parseIfStmt(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 8620230b2f1..4fafff51322 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { } // Convert statements in order. - for (auto &stmt : *mlFunc) { + for (auto &stmt : *mlFunc->getBody()) { visit(&stmt); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2b79064e53f..a927516345a 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { } PassResult DmaGeneration::runOnMLFunction(MLFunction *f) { - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast(&stmt)) { runOnForStmt(forStmt); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index df68765aeb7..e3609496cc5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -348,7 +348,7 @@ public: bool MemRefDependenceGraph::init(MLFunction *f) { unsigned id = 0; DenseMap> memrefAccesses; - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast(&stmt)) { // Create graph node 'id' to represent top-level 'forStmt' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 847db83aebc..b5c12865790 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f, bands->push_back(band); }; - for (auto &stmt : *f) { - ForStmt *forStmt = dyn_cast(&stmt); + for (auto &stmt : *f->getBody()) { + auto *forStmt = dyn_cast(&stmt); if (!forStmt) continue; getMaximalPerfectLoopNest(forStmt); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index dd491f8119b..ffff1c5b615 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnForStmt can be called on any // for Stmt. 
- if (!isa(f->begin())) + auto *forStmt = dyn_cast(f->getBody()->begin()); + if (!forStmt) return success(); - auto *forStmt = cast(f->begin()); runOnForStmt(forStmt); return success(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index d4069eaa638..fd07619a165 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -238,7 +238,7 @@ struct LowerVectorTransfersPass makeFuncWiseState(MLFunction *f) const override { auto state = llvm::make_unique(); auto builder = MLFuncBuilder(f); - builder.setInsertionPointToStart(f); + builder.setInsertionPointToStart(f->getBody()); state->zero = builder.create(builder.getUnknownLoc(), 0); return state; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 554e3cb47a9..fbde1fd1692 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, cast(op)->moveBefore(&entryBB, entryBB.begin()); } else { auto *mlFunc = cast(currentFunction); - cast(op)->moveBefore(mlFunc, mlFunc->begin()); + cast(op)->moveBefore(mlFunc->getBody(), + mlFunc->getBody()->begin()); } continue; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4d75f7c0835..023d3ebc643 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { if (!forStmt->use_empty()) { if (forStmt->hasConstantLowerBound()) { auto *mlFunc = forStmt->findFunction(); - MLFuncBuilder topBuilder(&mlFunc->front()); + MLFuncBuilder topBuilder(&mlFunc->getBody()->front()); auto constOp = topBuilder.create( forStmt->getLoc(), forStmt->getConstantLowerBound()); forStmt->replaceAllUsesWith(constOp); -- cgit v1.2.3 From 3f190312f8f7f09b5910bc77e80268402732ce6b Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 27 Dec 2018 14:35:10 -0800 Subject: Merge SSAValue, CFGValue, and MLValue together into a single Value class, which is the new base of the SSA value hierarchy. This CL also standardizes all the nomenclature and comments to use 'Value' where appropriate. This also eliminates a large number of cast(x)'s, which is very soothing. This is step 11/n towards merging instructions and statements, NFC. 
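The mechanical effect on call sites is that the SSAValue/CFGValue/MLValue
split, and the casts it forced, go away: for instance, MemRefAccess::memref
changes from an MLValue pointer to a Value pointer in AffineAnalysis.h below.
Schematically (a sketch of the cast-heavy pattern visible earlier in the
LoopFusion code, not an exact hunk from this commit):

    // Before: generic accessors returned SSAValue, so ML-side code had to
    // cast the result down to MLValue.
    auto *memref = cast<MLValue>(loadOp->getMemRef());

    // After: there is a single Value hierarchy, so the cast disappears.
    auto *memref = loadOp->getMemRef();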
PiperOrigin-RevId: 227064624 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 8 +- mlir/include/mlir/Analysis/AffineStructures.h | 79 ++++---- mlir/include/mlir/Analysis/Dominance.h | 4 +- mlir/include/mlir/Analysis/HyperRectangularSet.h | 6 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 10 +- mlir/include/mlir/Analysis/SliceAnalysis.h | 4 +- mlir/include/mlir/Analysis/Utils.h | 4 +- mlir/include/mlir/IR/Builders.h | 6 +- mlir/include/mlir/IR/BuiltinOps.h | 32 ++-- mlir/include/mlir/IR/Function.h | 1 - mlir/include/mlir/IR/MLValue.h | 133 -------------- mlir/include/mlir/IR/Matchers.h | 4 +- mlir/include/mlir/IR/OpDefinition.h | 80 ++++----- mlir/include/mlir/IR/OpImplementation.h | 17 +- mlir/include/mlir/IR/Operation.h | 22 ++- mlir/include/mlir/IR/OperationSupport.h | 10 +- mlir/include/mlir/IR/PatternMatch.h | 14 +- mlir/include/mlir/IR/SSAValue.h | 154 ---------------- mlir/include/mlir/IR/Statement.h | 21 ++- mlir/include/mlir/IR/Statements.h | 74 ++++---- mlir/include/mlir/IR/StmtBlock.h | 3 +- mlir/include/mlir/IR/UseDefLists.h | 20 +-- mlir/include/mlir/IR/Value.h | 198 +++++++++++++++++++++ mlir/include/mlir/IR/op_base.td | 4 +- mlir/include/mlir/StandardOps/StandardOps.h | 117 ++++++------ mlir/include/mlir/SuperVectorOps/SuperVectorOps.h | 42 ++--- mlir/include/mlir/Transforms/Utils.h | 17 +- mlir/lib/Analysis/AffineAnalysis.cpp | 68 ++++--- mlir/lib/Analysis/AffineStructures.cpp | 66 ++++--- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 22 +-- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 6 +- mlir/lib/Analysis/Utils.cpp | 18 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/Analysis/Verifier.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 43 +++-- mlir/lib/IR/Builders.cpp | 35 ++-- mlir/lib/IR/BuiltinOps.cpp | 42 +++-- mlir/lib/IR/Operation.cpp | 17 +- mlir/lib/IR/PatternMatch.cpp | 15 +- mlir/lib/IR/SSAValue.cpp | 29 ++- mlir/lib/IR/Statement.cpp | 57 +++--- mlir/lib/Parser/Parser.cpp | 71 ++++---- mlir/lib/StandardOps/StandardOps.cpp | 53 +++--- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 23 ++- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 25 ++- mlir/lib/Transforms/ConstantFold.cpp | 14 +- mlir/lib/Transforms/ConvertToCFG.cpp | 84 +++++---- mlir/lib/Transforms/DmaGeneration.cpp | 28 +-- mlir/lib/Transforms/LoopFusion.cpp | 39 ++-- mlir/lib/Transforms/LoopTiling.cpp | 13 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 6 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 29 ++- mlir/lib/Transforms/MaterializeVectors.cpp | 55 +++--- mlir/lib/Transforms/PipelineDataTransfer.cpp | 24 ++- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 16 +- mlir/lib/Transforms/Utils/LoweringUtils.cpp | 28 +-- mlir/lib/Transforms/Utils/Utils.cpp | 39 ++-- mlir/lib/Transforms/Vectorize.cpp | 54 +++--- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 8 +- 61 files changed, 958 insertions(+), 1161 deletions(-) delete mode 100644 mlir/include/mlir/IR/MLValue.h delete mode 100644 mlir/include/mlir/IR/SSAValue.h create mode 100644 mlir/include/mlir/IR/Value.h (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index df6e9a29480..9c288aac7cc 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -37,9 +37,9 @@ class ForStmt; class MLIRContext; class FlatAffineConstraints; class IntegerSet; -class MLValue; class OperationStmt; class Statement; +class Value; /// Simplify an affine 
expression through flattening and some amount of /// simple analysis. This has complexity linear in the number of nodes in @@ -78,7 +78,7 @@ AffineExpr composeWithUnboundedMap(AffineExpr e, AffineMap g); /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. void getReachableAffineApplyOps( - llvm::ArrayRef operands, + llvm::ArrayRef operands, llvm::SmallVectorImpl &affineApplyOps); /// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the @@ -122,9 +122,9 @@ bool getIndexSet(llvm::ArrayRef forStmts, FlatAffineConstraints *domain); struct MemRefAccess { - const MLValue *memref; + const Value *memref; const OperationStmt *opStmt; - llvm::SmallVector indices; + llvm::SmallVector indices; // Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. void getAccessMap(AffineValueMap *accessMap) const; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 15c559e2da5..a954dacc3f3 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -25,7 +25,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" namespace mlir { @@ -37,7 +36,7 @@ class AffineMap; class ForStmt; class IntegerSet; class MLIRContext; -class MLValue; +class Value; class HyperRectangularSet; /// A mutable affine map. Its affine expressions are however unique. @@ -132,7 +131,7 @@ public: AffineValueMap(const AffineApplyOp &op); AffineValueMap(const AffineBound &bound); AffineValueMap(AffineMap map); - AffineValueMap(AffineMap map, ArrayRef operands); + AffineValueMap(AffineMap map, ArrayRef operands); ~AffineValueMap(); @@ -155,13 +154,13 @@ public: // substitutions). // Resets this AffineValueMap with 'map' and 'operands'. - void reset(AffineMap map, ArrayRef operands); + void reset(AffineMap map, ArrayRef operands); /// Return true if the idx^th result can be proved to be a multiple of /// 'factor', false otherwise. inline bool isMultipleOf(unsigned idx, int64_t factor) const; /// Return true if the idx^th result depends on 'value', false otherwise. - bool isFunctionOf(unsigned idx, MLValue *value) const; + bool isFunctionOf(unsigned idx, Value *value) const; /// Return true if the result at 'idx' is a constant, false /// otherwise. @@ -175,8 +174,8 @@ public: inline unsigned getNumSymbols() const { return map.getNumSymbols(); } inline unsigned getNumResults() const { return map.getNumResults(); } - SSAValue *getOperand(unsigned i) const; - ArrayRef getOperands() const; + Value *getOperand(unsigned i) const; + ArrayRef getOperands() const; AffineMap getAffineMap() const; private: @@ -187,9 +186,9 @@ private: // TODO: make these trailing objects? /// The SSA operands binding to the dim's and symbols of 'map'. - SmallVector operands; + SmallVector operands; /// The SSA results binding to the results of 'map'. - SmallVector results; + SmallVector results; }; /// An IntegerValueSet is an integer set plus its operands. @@ -218,7 +217,7 @@ private: // 'AffineCondition'. MutableIntegerSet set; /// The SSA operands binding to the dim's and symbols of 'set'. - SmallVector operands; + SmallVector operands; }; /// A flat list of affine equalities and inequalities in the form. 
@@ -250,7 +249,7 @@ public: unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numReservedCols), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -269,7 +268,7 @@ public: /// dimensions and symbols. FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -309,10 +308,10 @@ public: // Clears any existing data and reserves memory for the specified constraints. void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); /// Appends constraints from 'other' into this. This is equivalent to an /// intersection with no simplification of any sort attempted. @@ -393,7 +392,7 @@ public: // Returns AffineMap::Null on error (i.e. if coefficient is zero or does // not divide other coefficients in the equality constraint). // TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this - // API when we can manage the mapping of MLValues and ids in the constraint + // API when we can manage the mapping of Values and ids in the constraint // system. AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context, SmallVectorImpl *nonZeroDimIds, @@ -413,10 +412,10 @@ public: void addLowerBound(ArrayRef expr, ArrayRef lb); /// Adds constraints (lower and upper bounds) for the specified 'for' - /// statement's MLValue using IR information stored in its bound maps. The - /// right identifier is first looked up using forStmt's MLValue. Returns + /// statement's Value using IR information stored in its bound maps. The + /// right identifier is first looked up using forStmt's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the - /// information is succesfully added. Asserts if the MLValue corresponding to + /// information is succesfully added. Asserts if the Value corresponding to /// the 'for' statement isn't found in the constraint system. Any new /// identifiers that are found in the bound operands of the 'for' statement /// are added as trailing identifiers (either dimensional or symbolic @@ -435,28 +434,28 @@ public: /// Sets the identifier at the specified position to a constant. void setIdToConstant(unsigned pos, int64_t val); - /// Sets the identifier corresponding to the specified MLValue id to a + /// Sets the identifier corresponding to the specified Value id to a /// constant. Asserts if the 'id' is not found. - void setIdToConstant(const MLValue &id, int64_t val); + void setIdToConstant(const Value &id, int64_t val); - /// Looks up the identifier with the specified MLValue. Returns false if not + /// Looks up the identifier with the specified Value. Returns false if not /// found, true if found. pos is set to the (column) position of the /// identifier. 
- bool findId(const MLValue &id, unsigned *pos) const; + bool findId(const Value &id, unsigned *pos) const; // Add identifiers of the specified kind - specified positions are relative to - // the kind of identifier. 'id' is the MLValue corresponding to the + // the kind of identifier. 'id' is the Value corresponding to the // identifier that can optionally be provided. - void addDimId(unsigned pos, MLValue *id = nullptr); - void addSymbolId(unsigned pos, MLValue *id = nullptr); + void addDimId(unsigned pos, Value *id = nullptr); + void addSymbolId(unsigned pos, Value *id = nullptr); void addLocalId(unsigned pos); - void addId(IdKind kind, unsigned pos, MLValue *id = nullptr); + void addId(IdKind kind, unsigned pos, Value *id = nullptr); /// Composes the affine value map with this FlatAffineConstrains, adding the /// results of the map as dimensions at the front [0, vMap->getNumResults()) /// and with the dimensions set to the equalities specified by the value map. /// Returns false if the composition fails (when vMap is a semi-affine map). - /// The vMap's operand MLValue's are used to look up the right positions in + /// The vMap's operand Value's are used to look up the right positions in /// the FlatAffineConstraints with which to associate. The dimensional and /// symbolic operands of vMap should match 1:1 (in the same order) with those /// of this constraint system, but the latter could have additional trailing @@ -471,8 +470,8 @@ public: void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } - /// Projects out the identifier that is associate with MLValue *. - void projectOut(MLValue *id); + /// Projects out the identifier that is associate with Value *. + void projectOut(Value *id); void removeId(IdKind idKind, unsigned pos); void removeId(unsigned pos); @@ -510,24 +509,24 @@ public: return numIds - numDims - numSymbols; } - inline ArrayRef> getIds() const { + inline ArrayRef> getIds() const { return {ids.data(), ids.size()}; } - /// Returns the MLValue's associated with the identifiers. Asserts if - /// no MLValue was associated with an identifier. - inline void getIdValues(SmallVectorImpl *values) const { + /// Returns the Value's associated with the identifiers. Asserts if + /// no Value was associated with an identifier. + inline void getIdValues(SmallVectorImpl *values) const { values->clear(); values->reserve(numIds); for (unsigned i = 0; i < numIds; i++) { - assert(ids[i].hasValue() && "identifier's MLValue not set"); + assert(ids[i].hasValue() && "identifier's Value not set"); values->push_back(ids[i].getValue()); } } - /// Returns the MLValue associated with the pos^th identifier. Asserts if - /// no MLValue identifier was associated. - inline MLValue *getIdValue(unsigned pos) const { + /// Returns the Value associated with the pos^th identifier. Asserts if + /// no Value identifier was associated. + inline Value *getIdValue(unsigned pos) const { assert(ids[pos].hasValue() && "identifier's ML Value not set"); return ids[pos].getValue(); } @@ -630,11 +629,11 @@ private: /// analysis). unsigned numSymbols; - /// MLValues corresponding to the (column) identifiers of this constraint + /// Values corresponding to the (column) identifiers of this constraint /// system appearing in the order the identifiers correspond to columns. - /// Temporary ones or those that aren't associated to any MLValue are to be + /// Temporary ones or those that aren't associated to any Value are to be /// set to None. 
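The constraint system tracks, for each column, an optional value that labels it, and findId is then a linear scan for a pointer match. The sketch below illustrates that bookkeeping only under stated assumptions: it is not the MLIR implementation, it uses std::optional in place of llvm::Optional, and the ToyValue and ToyConstraintColumns names are invented for illustration.

```cpp
#include <cassert>
#include <optional>
#include <vector>

struct ToyValue {}; // Stand-in for an SSA value.

// Columns of a toy constraint system, each optionally tied to a ToyValue.
class ToyConstraintColumns {
public:
  // Appends a column associated with 'id' (or none) and returns its position.
  unsigned addId(ToyValue *id) {
    ids.push_back(id ? std::optional<ToyValue *>(id) : std::nullopt);
    return ids.size() - 1;
  }

  // Mirrors findId: sets *pos and returns true if 'id' labels some column.
  bool findId(const ToyValue &id, unsigned *pos) const {
    for (unsigned i = 0, e = ids.size(); i < e; ++i)
      if (ids[i] && *ids[i] == &id) {
        *pos = i;
        return true;
      }
    return false;
  }

private:
  std::vector<std::optional<ToyValue *>> ids;
};

int main() {
  ToyValue a, b, c;
  ToyConstraintColumns cols;
  cols.addId(&a);
  cols.addId(nullptr); // A temporary column with no associated value.
  cols.addId(&b);

  unsigned pos = 0;
  assert(cols.findId(b, &pos) && pos == 2);
  assert(!cols.findId(c, &pos)); // 'c' never labeled a column.
  return 0;
}
```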
- SmallVector, 8> ids; + SmallVector, 8> ids; }; } // end namespace mlir. diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 4ec61869d2f..5374a451bd1 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -58,10 +58,10 @@ public: } /// Return true if value A properly dominates instruction B. - bool properlyDominates(const SSAValue *a, const Instruction *b); + bool properlyDominates(const Value *a, const Instruction *b); /// Return true if instruction A dominates instruction B. - bool dominates(const SSAValue *a, const Instruction *b) { + bool dominates(const Value *a, const Instruction *b) { return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/HyperRectangularSet.h b/mlir/include/mlir/Analysis/HyperRectangularSet.h index 266a0d7123d..27bb5da6dab 100644 --- a/mlir/include/mlir/Analysis/HyperRectangularSet.h +++ b/mlir/include/mlir/Analysis/HyperRectangularSet.h @@ -38,10 +38,10 @@ class AffineCondition; class AffineMap; class IntegerSet; class MLIRContext; -class MLValue; class MutableIntegerSet; class FlatAffineConstraints; class HyperRectangleList; +class Value; /// A list of affine bounds. // Not using a MutableAffineMap here since numSymbols is the same as the @@ -152,8 +152,8 @@ private: // expressions. std::vector upperBounds; - Optional> dims = None; - Optional> symbols = None; + Optional> dims = None; + Optional> symbols = None; /// Number of real dimensions. unsigned numDims; diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index d40ac521cec..9cac049e7b7 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -23,7 +23,6 @@ #define MLIR_ANALYSIS_LOOP_ANALYSIS_H #include "mlir/Support/LLVM.h" - #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" @@ -33,8 +32,8 @@ class AffineExpr; class AffineMap; class ForStmt; class MemRefType; -class MLValue; class OperationStmt; +class Value; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count @@ -66,7 +65,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -bool isAccessInvariant(const MLValue &iv, const MLValue &index); +bool isAccessInvariant(const Value &iv, const Value &index); /// Given an induction variable `iv` of type ForStmt and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. @@ -77,9 +76,8 @@ bool isAccessInvariant(const MLValue &iv, const MLValue &index); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -llvm::DenseSet> -getInvariantAccesses(const MLValue &iv, - llvm::ArrayRef indices); +llvm::DenseSet> +getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); /// Checks whether the loop is structurally vectorizable; i.e.: /// 1. the loop has proper dependence semantics (parallel, reduction, etc); diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index f6d88c178b8..f3e09655bf2 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -127,10 +127,10 @@ void getBackwardSlice( /// **includes** the original statement. /// /// This allows building a slice (i.e. 
multi-root DAG where everything -/// that is reachable from an SSAValue in forward and backward direction is +/// that is reachable from an Value in forward and backward direction is /// contained in the slice). /// This is the abstraction we need to materialize all the instructions for -/// supervectorization without worrying about orderings and SSAValue +/// supervectorization without worrying about orderings and Value /// replacements. /// /// Example starting from any node diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 365fc74a778..284a6fd6735 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -34,10 +34,10 @@ namespace mlir { class FlatAffineConstraints; class ForStmt; -class MLValue; class MemRefAccess; class OperationStmt; class Statement; +class Value; /// Returns true if statement 'a' dominates statement b. bool dominates(const Statement &a, const Statement &b); @@ -92,7 +92,7 @@ struct MemRefRegion { unsigned getRank() const; /// Memref that this region corresponds to. - MLValue *memref; + Value *memref; private: /// Read or write. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index a632078903a..1aecbd47e76 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -408,8 +408,8 @@ public: } // Creates a for statement. When step is not specified, it is set to 1. - ForStmt *createFor(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, + ForStmt *createFor(Location location, ArrayRef lbOperands, + AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step = 1); // Creates a for statement with known (constant) lower and upper bounds. @@ -417,7 +417,7 @@ public: ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); /// Creates if statement. - IfStmt *createIf(Location location, ArrayRef operands, + IfStmt *createIf(Location location, ArrayRef operands, IntegerSet set); private: diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index ec88e2d157b..3f4ec6fcccd 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -30,7 +30,6 @@ namespace mlir { class Builder; -class MLValue; class BuiltinDialect : public Dialect { public: @@ -57,7 +56,7 @@ class AffineApplyOp public: /// Builds an affine apply op with the specified map and operands. static void build(Builder *builder, OperationState *result, AffineMap map, - ArrayRef operands); + ArrayRef operands); /// Returns the affine map to be applied by this operation. AffineMap getAffineMap() const { @@ -101,7 +100,7 @@ public: static StringRef getOperationName() { return "br"; } static void build(Builder *builder, OperationState *result, BasicBlock *dest, - ArrayRef operands = {}); + ArrayRef operands = {}); // Hooks to customize behavior of this op. static bool parse(OpAsmParser *parser, OperationState *result); @@ -144,10 +143,9 @@ class CondBranchOp : public Op::Impl, public: static StringRef getOperationName() { return "cond_br"; } - static void build(Builder *builder, OperationState *result, - SSAValue *condition, BasicBlock *trueDest, - ArrayRef trueOperands, BasicBlock *falseDest, - ArrayRef falseOperands); + static void build(Builder *builder, OperationState *result, Value *condition, + BasicBlock *trueDest, ArrayRef trueOperands, + BasicBlock *falseDest, ArrayRef falseOperands); // Hooks to customize behavior of this op. 
static bool parse(OpAsmParser *parser, OperationState *result); @@ -155,8 +153,8 @@ public: bool verify() const; // The condition operand is the first operand in the list. - SSAValue *getCondition() { return getOperand(0); } - const SSAValue *getCondition() const { return getOperand(0); } + Value *getCondition() { return getOperand(0); } + const Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. BasicBlock *getTrueDest() const; @@ -165,14 +163,14 @@ public: BasicBlock *getFalseDest() const; // Accessors for operands to the 'true' destination. - SSAValue *getTrueOperand(unsigned idx) { + Value *getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } - const SSAValue *getTrueOperand(unsigned idx) const { + const Value *getTrueOperand(unsigned idx) const { return const_cast(this)->getTrueOperand(idx); } - void setTrueOperand(unsigned idx, SSAValue *value) { + void setTrueOperand(unsigned idx, Value *value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); } @@ -199,14 +197,14 @@ public: void eraseTrueOperand(unsigned index); // Accessors for operands to the 'false' destination. - SSAValue *getFalseOperand(unsigned idx) { + Value *getFalseOperand(unsigned idx) { assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } - const SSAValue *getFalseOperand(unsigned idx) const { + const Value *getFalseOperand(unsigned idx) const { return const_cast(this)->getFalseOperand(idx); } - void setFalseOperand(unsigned idx, SSAValue *value) { + void setFalseOperand(unsigned idx, Value *value) { assert(idx < getNumFalseOperands()); setOperand(getFalseDestOperandIndex() + idx, value); } @@ -361,7 +359,7 @@ public: static StringRef getOperationName() { return "return"; } static void build(Builder *builder, OperationState *result, - ArrayRef results = {}); + ArrayRef results = {}); // Hooks to customize behavior of this op. static bool parse(OpAsmParser *parser, OperationState *result); @@ -380,7 +378,7 @@ void printDimAndSymbolList(Operation::const_operand_iterator begin, // Parses dimension and symbol list and returns true if parsing failed. bool parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, + SmallVector &operands, unsigned &numDims); } // end namespace mlir diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 9f2e202857f..1a4e64fbc43 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -27,7 +27,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/Operation.h" #include "mlir/IR/StmtBlock.h" #include "mlir/IR/Types.h" diff --git a/mlir/include/mlir/IR/MLValue.h b/mlir/include/mlir/IR/MLValue.h deleted file mode 100644 index a1b5412affa..00000000000 --- a/mlir/include/mlir/IR/MLValue.h +++ /dev/null @@ -1,133 +0,0 @@ -//===- MLValue.h - MLValue base class and SSA type decls ------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines SSA manipulation implementations for ML functions. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_MLVALUE_H -#define MLIR_IR_MLVALUE_H - -#include "mlir/IR/SSAValue.h" - -namespace mlir { -class ForStmt; -class MLValue; -using MLFunction = Function; -class Statement; -class StmtBlock; - -/// This enum contains all of the SSA value kinds that are valid in an ML -/// function. This should be kept as a proper subtype of SSAValueKind, -/// including having all of the values of the enumerators align. -enum class MLValueKind { - BlockArgument = (int)SSAValueKind::BlockArgument, - StmtResult = (int)SSAValueKind::StmtResult, - ForStmt = (int)SSAValueKind::ForStmt, -}; - -/// The operand of ML function statement contains an MLValue. -using StmtOperand = IROperandImpl; - -/// MLValue is the base class for SSA values in ML functions. -class MLValue : public SSAValueImpl { -public: - /// Returns true if the given MLValue can be used as a dimension id. - bool isValidDim() const; - - /// Returns true if the given MLValue can be used as a symbol. - bool isValidSymbol() const; - - static bool classof(const SSAValue *value) { - switch (value->getKind()) { - case SSAValueKind::BlockArgument: - case SSAValueKind::StmtResult: - case SSAValueKind::ForStmt: - return true; - } - } - - /// Return the function that this MLValue is defined in. - MLFunction *getFunction(); - - /// Return the function that this MLValue is defined in. - const MLFunction *getFunction() const { - return const_cast(this)->getFunction(); - } - -protected: - MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {} -}; - -/// Block arguments are ML Values. -class BlockArgument : public MLValue { -public: - static bool classof(const SSAValue *value) { - return value->getKind() == SSAValueKind::BlockArgument; - } - - /// Return the function that this argument is defined in. - MLFunction *getFunction(); - const MLFunction *getFunction() const { - return const_cast(this)->getFunction(); - } - - StmtBlock *getOwner() { return owner; } - const StmtBlock *getOwner() const { return owner; } - -private: - friend class StmtBlock; // For access to private constructor. - BlockArgument(Type type, StmtBlock *owner) - : MLValue(MLValueKind::BlockArgument, type), owner(owner) {} - - /// The owner of this operand. - /// TODO: can encode this more efficiently to avoid the space hit of this - /// through bitpacking shenanigans. - StmtBlock *const owner; -}; - -/// This is a value defined by a result of an operation instruction. -class StmtResult : public MLValue { -public: - StmtResult(Type type, OperationStmt *owner) - : MLValue(MLValueKind::StmtResult, type), owner(owner) {} - - static bool classof(const SSAValue *value) { - return value->getKind() == SSAValueKind::StmtResult; - } - - OperationStmt *getOwner() { return owner; } - const OperationStmt *getOwner() const { return owner; } - - /// Returns the number of this result. 
- unsigned getResultNumber() const; - -private: - /// The owner of this operand. - /// TODO: can encode this more efficiently to avoid the space hit of this - /// through bitpacking shenanigans. - OperationStmt *const owner; -}; - -// TODO(clattner) clean all this up. -using CFGValue = MLValue; -using BBArgument = BlockArgument; -using InstResult = StmtResult; - -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 1b6ccb25b64..9ebe226600f 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -27,8 +27,8 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include namespace mlir { @@ -107,7 +107,7 @@ template struct op_matcher { /// Entry point for matching a pattern over an SSAValue. template -inline bool matchPattern(SSAValue *value, const Pattern &pattern) { +inline bool matchPattern(Value *value, const Pattern &pattern) { // TODO: handle other cases if (auto *op = value->getDefiningOperation()) return const_cast(pattern).match(op); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 3c68bb0f30f..5262ba5975b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -29,7 +29,7 @@ #define MLIR_IR_OPDEFINITION_H #include "mlir/IR/Operation.h" -#include "mlir/IR/SSAValue.h" +#include "mlir/IR/Value.h" #include namespace mlir { @@ -78,11 +78,11 @@ public: } /// If the OpType operation includes the OneResult trait, then OpPointer can - /// be implicitly converted to an SSAValue*. This yields the value of the + /// be implicitly converted to an Value*. This yields the value of the /// only result. template operator typename std::enable_if::value, - SSAValue *>::type() { + Value *>::type() { return value.getResult(); } @@ -114,14 +114,14 @@ public: } /// If the OpType operation includes the OneResult trait, then OpPointer can - /// be implicitly converted to an const SSAValue*. This yields the value of + /// be implicitly converted to an const Value*. This yields the value of /// the only result. 
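The implicit conversion described here is gated on the op actually having the OneResult trait, using enable_if over a convertibility test against the trait base. Below is a small, self-contained illustration of that SFINAE idiom; the ToyOneResult, ToyOpPointer, and related names are invented, and an int* stands in for Value*.

```cpp
#include <iostream>
#include <type_traits>

// Trait base: ops with exactly one result derive from this.
struct ToyOneResult {
  int *getResult() { return &result; }
  int result = 42;
};

struct ToySingleResultOp : ToyOneResult {};
struct ToyZeroResultOp {};

// Wrapper that only converts to int* when the op has the OneResult trait.
template <typename OpType> class ToyOpPointer {
public:
  explicit ToyOpPointer(OpType op) : op(op) {}

  template <typename SFINAE = OpType>
  operator typename std::enable_if<
      std::is_convertible<SFINAE *, ToyOneResult *>::value, int *>::type() {
    return op.getResult();
  }

  OpType op;
};

int main() {
  ToyOpPointer<ToySingleResultOp> single{ToySingleResultOp{}};
  int *res = single; // OK: the op has the OneResult trait.
  std::cout << *res << "\n";

  ToyOpPointer<ToyZeroResultOp> zero{ToyZeroResultOp{}};
  // int *bad = zero; // Would not compile: the conversion is SFINAE'd away.
  (void)zero;
  return 0;
}
```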
template operator typename std::enable_if< std::is_convertible< SFINAE *, OpTrait::OneResult *>::value, - const SSAValue *>::type() const { + const Value *>::type() const { return value.getResult(); } @@ -346,15 +346,13 @@ private: template class OneOperand : public TraitBase { public: - const SSAValue *getOperand() const { + const Value *getOperand() const { return this->getOperation()->getOperand(0); } - SSAValue *getOperand() { return this->getOperation()->getOperand(0); } + Value *getOperand() { return this->getOperation()->getOperand(0); } - void setOperand(SSAValue *value) { - this->getOperation()->setOperand(0, value); - } + void setOperand(Value *value) { this->getOperation()->setOperand(0, value); } static bool verifyTrait(const Operation *op) { return impl::verifyOneOperand(op); @@ -371,15 +369,15 @@ public: template class Impl : public TraitBase::Impl> { public: - const SSAValue *getOperand(unsigned i) const { + const Value *getOperand(unsigned i) const { return this->getOperation()->getOperand(i); } - SSAValue *getOperand(unsigned i) { + Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); } - void setOperand(unsigned i, SSAValue *value) { + void setOperand(unsigned i, Value *value) { this->getOperation()->setOperand(i, value); } @@ -402,15 +400,15 @@ public: unsigned getNumOperands() const { return this->getOperation()->getNumOperands(); } - const SSAValue *getOperand(unsigned i) const { + const Value *getOperand(unsigned i) const { return this->getOperation()->getOperand(i); } - SSAValue *getOperand(unsigned i) { + Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); } - void setOperand(unsigned i, SSAValue *value) { + void setOperand(unsigned i, Value *value) { this->getOperation()->setOperand(i, value); } @@ -453,15 +451,13 @@ public: return this->getOperation()->getNumOperands(); } - const SSAValue *getOperand(unsigned i) const { + const Value *getOperand(unsigned i) const { return this->getOperation()->getOperand(i); } - SSAValue *getOperand(unsigned i) { - return this->getOperation()->getOperand(i); - } + Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); } - void setOperand(unsigned i, SSAValue *value) { + void setOperand(unsigned i, Value *value) { this->getOperation()->setOperand(i, value); } @@ -503,17 +499,15 @@ public: template class OneResult : public TraitBase { public: - SSAValue *getResult() { return this->getOperation()->getResult(0); } - const SSAValue *getResult() const { - return this->getOperation()->getResult(0); - } + Value *getResult() { return this->getOperation()->getResult(0); } + const Value *getResult() const { return this->getOperation()->getResult(0); } Type getType() const { return getResult()->getType(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. 
- void replaceAllUsesWith(SSAValue *newValue) { + void replaceAllUsesWith(Value *newValue) { getResult()->replaceAllUsesWith(newValue); } @@ -548,13 +542,11 @@ public: public: static unsigned getNumResults() { return N; } - const SSAValue *getResult(unsigned i) const { + const Value *getResult(unsigned i) const { return this->getOperation()->getResult(i); } - SSAValue *getResult(unsigned i) { - return this->getOperation()->getResult(i); - } + Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } @@ -574,13 +566,11 @@ public: template class Impl : public TraitBase::Impl> { public: - const SSAValue *getResult(unsigned i) const { + const Value *getResult(unsigned i) const { return this->getOperation()->getResult(i); } - SSAValue *getResult(unsigned i) { - return this->getOperation()->getResult(i); - } + Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } @@ -599,13 +589,13 @@ public: return this->getOperation()->getNumResults(); } - const SSAValue *getResult(unsigned i) const { + const Value *getResult(unsigned i) const { return this->getOperation()->getResult(i); } - SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); } + Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } - void setResult(unsigned i, SSAValue *value) { + void setResult(unsigned i, Value *value) { this->getOperation()->setResult(i, value); } @@ -762,10 +752,10 @@ public: return this->getOperation()->setSuccessor(block, index); } - void addSuccessorOperand(unsigned index, SSAValue *value) { + void addSuccessorOperand(unsigned index, Value *value) { return this->getOperation()->addSuccessorOperand(index, value); } - void addSuccessorOperands(unsigned index, ArrayRef values) { + void addSuccessorOperands(unsigned index, ArrayRef values) { return this->getOperation()->addSuccessorOperand(index, values); } }; @@ -889,8 +879,8 @@ private: // These functions are out-of-line implementations of the methods in BinaryOp, // which avoids them being template instantiated/duplicated. namespace impl { -void buildBinaryOp(Builder *builder, OperationState *result, SSAValue *lhs, - SSAValue *rhs); +void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, + Value *rhs); bool parseBinaryOp(OpAsmParser *parser, OperationState *result); void printBinaryOp(const Operation *op, OpAsmPrinter *p); } // namespace impl @@ -906,8 +896,8 @@ class BinaryOp : public Op::Impl, OpTrait::OneResult, OpTrait::SameOperandsAndResultType, Traits...> { public: - static void build(Builder *builder, OperationState *result, SSAValue *lhs, - SSAValue *rhs) { + static void build(Builder *builder, OperationState *result, Value *lhs, + Value *rhs) { impl::buildBinaryOp(builder, result, lhs, rhs); } static bool parse(OpAsmParser *parser, OperationState *result) { @@ -926,7 +916,7 @@ protected: // These functions are out-of-line implementations of the methods in CastOp, // which avoids them being template instantiated/duplicated. namespace impl { -void buildCastOp(Builder *builder, OperationState *result, SSAValue *source, +void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); void printCastOp(const Operation *op, OpAsmPrinter *p); @@ -942,7 +932,7 @@ template class... 
Traits> class CastOp : public Op { public: - static void build(Builder *builder, OperationState *result, SSAValue *source, + static void build(Builder *builder, OperationState *result, Value *source, Type destType) { impl::buildCastOp(builder, result, source, destType); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index df714e00e1f..36dbb98fa68 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -48,7 +48,7 @@ public: virtual raw_ostream &getStream() const = 0; /// Print implementations for various things an operation contains. - virtual void printOperand(const SSAValue *value) = 0; + virtual void printOperand(const Value *value) = 0; /// Print a comma separated list of operands. template @@ -95,7 +95,7 @@ private: }; // Make the implementations convenient to use. -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) { p.printOperand(&value); return p; } @@ -119,7 +119,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) { // even if it isn't exactly one of them. For example, we want to print // FunctionType with the Type& version above, not have it match this. template ::value && + !std::is_convertible::value && !std::is_convertible::value && !std::is_convertible::value && !std::is_convertible::value, @@ -264,9 +264,8 @@ public: virtual bool parseOperand(OperandType &result) = 0; /// Parse a single operation successor and it's operand list. - virtual bool - parseSuccessorAndUseList(BasicBlock *&dest, - SmallVectorImpl &operands) = 0; + virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + SmallVectorImpl &operands) = 0; /// These are the supported delimiters around operand lists, used by /// parseOperandList. @@ -311,13 +310,13 @@ public: /// Resolve an operand to an SSA value, emitting an error and returning true /// on failure. virtual bool resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) = 0; + SmallVectorImpl &result) = 0; /// Resolve a list of operands to SSA values, emitting an error and returning /// true on failure, or appending the results to the list on success. /// This method should be used when all operands have the same type. virtual bool resolveOperands(ArrayRef operands, Type type, - SmallVectorImpl &result) { + SmallVectorImpl &result) { for (auto elt : operands) if (resolveOperand(elt, type, result)) return true; @@ -329,7 +328,7 @@ public: /// to the list on success. virtual bool resolveOperands(ArrayRef operands, ArrayRef types, llvm::SMLoc loc, - SmallVectorImpl &result) { + SmallVectorImpl &result) { if (operands.size() != types.size()) return emitError(loc, Twine(operands.size()) + " operands present, but expected " + diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 93a2c3061d9..eeee62cb647 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -64,21 +64,20 @@ public: /// Return the number of operands this operation has. unsigned getNumOperands() const; - SSAValue *getOperand(unsigned idx); - const SSAValue *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx); + const Value *getOperand(unsigned idx) const { return const_cast(this)->getOperand(idx); } - void setOperand(unsigned idx, SSAValue *value); + void setOperand(unsigned idx, Value *value); // Support non-const operand iteration. 
- using operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; operand_iterator operand_begin(); operand_iterator operand_end(); llvm::iterator_range getOperands(); // Support const operand iteration. - using const_operand_iterator = - OperandIterator; + using const_operand_iterator = OperandIterator; const_operand_iterator operand_begin() const; const_operand_iterator operand_end() const; llvm::iterator_range getOperands() const; @@ -87,26 +86,25 @@ public: unsigned getNumResults() const; /// Return the indicated result. - SSAValue *getResult(unsigned idx); - const SSAValue *getResult(unsigned idx) const { + Value *getResult(unsigned idx); + const Value *getResult(unsigned idx) const { return const_cast(this)->getResult(idx); } // Support non-const result iteration. - using result_iterator = ResultIterator; + using result_iterator = ResultIterator; result_iterator result_begin(); result_iterator result_end(); llvm::iterator_range getResults(); // Support const result iteration. - using const_result_iterator = ResultIterator; + using const_result_iterator = ResultIterator; const_result_iterator result_begin() const; const_result_iterator result_end() const; llvm::iterator_range getResults() const; // Support for result type iteration. - using result_type_iterator = - ResultTypeIterator; + using result_type_iterator = ResultTypeIterator; result_type_iterator result_type_begin() const; result_type_iterator result_type_end() const; llvm::iterator_range getResultTypes() const; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index eaaf927b642..e7d19b7eae0 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -40,8 +40,8 @@ class OpAsmPrinter; class Pattern; class RewritePattern; class StmtBlock; -class SSAValue; class Type; +class Value; using BasicBlock = StmtBlock; /// This is a vector that owns the patterns inside of it. @@ -203,7 +203,7 @@ struct OperationState { MLIRContext *const context; Location location; OperationName name; - SmallVector operands; + SmallVector operands; /// Types of the results of this operation. SmallVector types; SmallVector attributes; @@ -218,7 +218,7 @@ public: : context(context), location(location), name(name) {} OperationState(MLIRContext *context, Location location, StringRef name, - ArrayRef operands, ArrayRef types, + ArrayRef operands, ArrayRef types, ArrayRef attributes, ArrayRef successors = {}) : context(context), location(location), name(name, context), @@ -227,7 +227,7 @@ public: attributes(attributes.begin(), attributes.end()), successors(successors.begin(), successors.end()) {} - void addOperands(ArrayRef newOperands) { + void addOperands(ArrayRef newOperands) { assert(successors.empty() && "Non successor operands should be added first."); operands.append(newOperands.begin(), newOperands.end()); @@ -247,7 +247,7 @@ public: attributes.push_back({name, attr}); } - void addSuccessor(StmtBlock *successor, ArrayRef succOperands) { + void addSuccessor(StmtBlock *successor, ArrayRef succOperands) { successors.push_back(successor); // Insert a sentinal operand to mark a barrier between successor operands. 
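The sentinel described in the comment above is what later lets the operation recover which trailing operands belong to which successor: the main operands come first, and each successor's group is preceded by a null entry. A rough, self-contained sketch of that encoding follows; ToyOperationState and getSuccessorOperands are invented names, and strings stand in for values.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy version of the flat operand list used while building an operation:
// a nullptr entry acts as a barrier before each successor's operand group.
struct ToyOperationState {
  std::vector<const std::string *> operands;
  std::vector<int> successors; // Toy successor handles.

  void addOperands(const std::vector<const std::string *> &newOperands) {
    assert(successors.empty() && "non-successor operands must be added first");
    operands.insert(operands.end(), newOperands.begin(), newOperands.end());
  }

  void addSuccessor(int successor,
                    const std::vector<const std::string *> &succOperands) {
    successors.push_back(successor);
    operands.push_back(nullptr); // Sentinel marking the start of the group.
    operands.insert(operands.end(), succOperands.begin(), succOperands.end());
  }

  // Recovers the operand group of the idx'th successor by counting sentinels.
  std::vector<const std::string *> getSuccessorOperands(unsigned idx) const {
    unsigned seen = 0;
    size_t i = 0;
    for (; i < operands.size(); ++i)
      if (!operands[i] && seen++ == idx)
        break;
    std::vector<const std::string *> group;
    for (++i; i < operands.size() && operands[i]; ++i)
      group.push_back(operands[i]);
    return group;
  }
};

int main() {
  std::string a = "a", b = "b", c = "c";
  ToyOperationState state;
  state.addOperands({&a});
  state.addSuccessor(/*successor=*/0, {&b});
  state.addSuccessor(/*successor=*/1, {&c, &a});

  assert(state.getSuccessorOperands(0).size() == 1);
  assert(state.getSuccessorOperands(1).size() == 2);
  assert(state.getSuccessorOperands(1)[0] == &c);
  return 0;
}
```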
operands.push_back(nullptr); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index df368c85104..b015f7bb44d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -222,8 +222,8 @@ public: /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. - void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead = {}); + void replaceOp(Operation *op, ArrayRef newValues, + ArrayRef valuesToRemoveIfDead = {}); /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. @@ -237,8 +237,7 @@ public: /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template - void replaceOpWithNewOp(Operation *op, - ArrayRef valuesToRemoveIfDead, + void replaceOpWithNewOp(Operation *op, ArrayRef valuesToRemoveIfDead, Args... args) { auto newOp = create(op->getLoc(), args...); replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(), @@ -254,7 +253,7 @@ public: /// rewriter should remove if they are dead at this point. /// void updatedRootInPlace(Operation *op, - ArrayRef valuesToRemoveIfDead = {}); + ArrayRef valuesToRemoveIfDead = {}); protected: PatternRewriter(MLIRContext *context) : Builder(context) {} @@ -284,9 +283,8 @@ protected: private: /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp - void - replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, - ArrayRef valuesToRemoveIfDead); + void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, + ArrayRef valuesToRemoveIfDead); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/SSAValue.h b/mlir/include/mlir/IR/SSAValue.h deleted file mode 100644 index 5791cbfd17a..00000000000 --- a/mlir/include/mlir/IR/SSAValue.h +++ /dev/null @@ -1,154 +0,0 @@ -//===- SSAValue.h - Base of the value hierarchy -----------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines generic SSAValue type and manipulation utilities. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_SSAVALUE_H -#define MLIR_IR_SSAVALUE_H - -#include "mlir/IR/Types.h" -#include "mlir/IR/UseDefLists.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -class Function; -class OperationStmt; -class Operation; -class Statement; -using Instruction = Statement; -using OperationInst = OperationStmt; - -/// This enumerates all of the SSA value kinds in the MLIR system. 
-enum class SSAValueKind { - BlockArgument, // Block argument - StmtResult, // statement result - ForStmt, // for statement induction variable -}; - -/// This is the common base class for all values in the MLIR system, -/// representing a computable value that has a type and a set of users. -/// -class SSAValue : public IRObjectWithUseList { -public: - ~SSAValue() {} - - SSAValueKind getKind() const { return typeAndKind.getInt(); } - - Type getType() const { return typeAndKind.getPointer(); } - - /// Replace all uses of 'this' value with the new value, updating anything in - /// the IR that uses 'this' to use the other value instead. When this returns - /// there are zero uses of 'this'. - void replaceAllUsesWith(SSAValue *newValue) { - IRObjectWithUseList::replaceAllUsesWith(newValue); - } - - /// Return the function that this SSAValue is defined in. - Function *getFunction(); - - /// Return the function that this SSAValue is defined in. - const Function *getFunction() const { - return const_cast(this)->getFunction(); - } - - /// If this value is the result of an Instruction, return the instruction - /// that defines it. - OperationInst *getDefiningInst(); - const OperationInst *getDefiningInst() const { - return const_cast(this)->getDefiningInst(); - } - - /// If this value is the result of an OperationStmt, return the statement - /// that defines it. - OperationStmt *getDefiningStmt(); - const OperationStmt *getDefiningStmt() const { - return const_cast(this)->getDefiningStmt(); - } - - /// If this value is the result of an Operation, return the operation that - /// defines it. - Operation *getDefiningOperation(); - const Operation *getDefiningOperation() const { - return const_cast(this)->getDefiningOperation(); - } - - void print(raw_ostream &os) const; - void dump() const; - -protected: - SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {} - -private: - const llvm::PointerIntPair typeAndKind; -}; - -inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) { - value.print(os); - return os; -} - -/// This template unifies the implementation logic for CFGValue and MLValue -/// while providing more type-specific APIs when walking use lists etc. -/// -/// IROperandTy is the concrete instance of IROperand to use (including -/// substituted template arguments). -/// IROwnerTy is the type of the owner of an IROperandTy type. -/// KindTy is the enum 'kind' discriminator that subclasses want to use. -/// -template -class SSAValueImpl : public SSAValue { -public: - // Provide more specific implementations of the base class functionality. - KindTy getKind() const { return (KindTy)SSAValue::getKind(); } - - using use_iterator = SSAValueUseIterator; - using use_range = llvm::iterator_range; - - inline use_iterator use_begin() const; - inline use_iterator use_end() const; - - /// Returns a range of all uses, which is useful for iterating over all uses. - inline use_range getUses() const; - -protected: - SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {} -}; - -// Utility functions for iterating through SSAValue uses. 
-template -inline auto SSAValueImpl::use_begin() const - -> use_iterator { - return use_iterator((IROperandTy *)getFirstUse()); -} - -template -inline auto SSAValueImpl::use_end() const - -> use_iterator { - return use_iterator(nullptr); -} - -template -inline auto SSAValueImpl::getUses() const - -> llvm::iterator_range { - return {use_begin(), use_end()}; -} - -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index c7eddaf8d3c..2871b274de1 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -22,8 +22,8 @@ #ifndef MLIR_IR_STATEMENT_H #define MLIR_IR_STATEMENT_H -#include "mlir/IR/MLValue.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" @@ -84,8 +84,8 @@ public: // This is a verbose type used by the clone method below. using OperandMapTy = - DenseMap, - llvm::detail::DenseMapPair>; + DenseMap, + llvm::detail::DenseMapPair>; /// Create a deep copy of this statement, remapping any operands that use /// values outside of the statement using the map that is provided (leaving @@ -136,12 +136,12 @@ public: unsigned getNumOperands() const; - MLValue *getOperand(unsigned idx); - const MLValue *getOperand(unsigned idx) const; - void setOperand(unsigned idx, MLValue *value); + Value *getOperand(unsigned idx); + const Value *getOperand(unsigned idx) const; + void setOperand(unsigned idx, Value *value); // Support non-const operand iteration. - using operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; operand_iterator operand_begin() { return operand_iterator(this, 0); } @@ -149,14 +149,13 @@ public: return operand_iterator(this, getNumOperands()); } - /// Returns an iterator on the underlying MLValue's (MLValue *). + /// Returns an iterator on the underlying Value's (Value *). llvm::iterator_range getOperands() { return {operand_begin(), operand_end()}; } // Support const operand iteration. - using const_operand_iterator = - OperandIterator; + using const_operand_iterator = OperandIterator; const_operand_iterator operand_begin() const { return const_operand_iterator(this, 0); @@ -166,7 +165,7 @@ public: return const_operand_iterator(this, getNumOperands()); } - /// Returns a const iterator on the underlying MLValue's (MLValue *). + /// Returns a const iterator on the underlying Value's (Value *). llvm::iterator_range getOperands() const { return {operand_begin(), operand_end()}; } diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index b1c03e948db..1ca511c00fa 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -43,7 +43,7 @@ class OperationStmt final public: /// Create a new OperationStmt with the specific fields. 
static OperationStmt * - create(Location location, OperationName name, ArrayRef operands, + create(Location location, OperationName name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, ArrayRef successors, MLIRContext *context); @@ -69,16 +69,16 @@ public: unsigned getNumOperands() const { return numOperands; } - MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } - const MLValue *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } - void setOperand(unsigned idx, MLValue *value) { + void setOperand(unsigned idx, Value *value) { return getStmtOperand(idx).set(value); } // Support non-const operand iteration. - using operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; operand_iterator operand_begin() { return operand_iterator(this, 0); } @@ -86,14 +86,14 @@ public: return operand_iterator(this, getNumOperands()); } - /// Returns an iterator on the underlying MLValue's (MLValue *). + /// Returns an iterator on the underlying Value's (Value *). llvm::iterator_range getOperands() { return {operand_begin(), operand_end()}; } // Support const operand iteration. using const_operand_iterator = - OperandIterator; + OperandIterator; const_operand_iterator operand_begin() const { return const_operand_iterator(this, 0); @@ -103,7 +103,7 @@ public: return const_operand_iterator(this, getNumOperands()); } - /// Returns a const iterator on the underlying MLValue's (MLValue *). + /// Returns a const iterator on the underlying Value's (Value *). llvm::iterator_range getOperands() const { return {operand_begin(), operand_end()}; } @@ -126,11 +126,11 @@ public: unsigned getNumResults() const { return numResults; } - MLValue *getResult(unsigned idx) { return &getStmtResult(idx); } - const MLValue *getResult(unsigned idx) const { return &getStmtResult(idx); } + Value *getResult(unsigned idx) { return &getStmtResult(idx); } + const Value *getResult(unsigned idx) const { return &getStmtResult(idx); } // Support non-const result iteration. - using result_iterator = ResultIterator; + using result_iterator = ResultIterator; result_iterator result_begin() { return result_iterator(this, 0); } result_iterator result_end() { return result_iterator(this, getNumResults()); @@ -141,7 +141,7 @@ public: // Support const result iteration. using const_result_iterator = - ResultIterator; + ResultIterator; const_result_iterator result_begin() const { return const_result_iterator(this, 0); } @@ -170,7 +170,7 @@ public: // Support result type iteration. using result_type_iterator = - ResultTypeIterator; + ResultTypeIterator; result_type_iterator result_type_begin() const { return result_type_iterator(this, 0); } @@ -290,15 +290,15 @@ private: }; /// For statement represents an affine loop nest. -class ForStmt : public Statement, public MLValue { +class ForStmt : public Statement, public Value { public: - static ForStmt *create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, + static ForStmt *create(Location location, ArrayRef lbOperands, + AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step); ~ForStmt() { // Explicitly erase statements instead of relying of 'StmtBlock' destructor - // since child statements need to be destroyed before the MLValue that this + // since child statements need to be destroyed before the Value that this // for stmt represents is destroyed. 
Affine maps are immortal objects and // don't need to be deleted. getBody()->clear(); @@ -308,8 +308,8 @@ public: using Statement::getFunction; /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; /// Operand iterator range. using operand_range = llvm::iterator_range; @@ -340,9 +340,9 @@ public: AffineMap getUpperBoundMap() const { return ubMap; } /// Set lower bound. - void setLowerBound(ArrayRef operands, AffineMap map); + void setLowerBound(ArrayRef operands, AffineMap map); /// Set upper bound. - void setUpperBound(ArrayRef operands, AffineMap map); + void setUpperBound(ArrayRef operands, AffineMap map); /// Set the lower bound map without changing operands. void setLowerBoundMap(AffineMap map); @@ -385,11 +385,11 @@ public: unsigned getNumOperands() const { return operands.size(); } - MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } - const MLValue *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } - void setOperand(unsigned idx, MLValue *value) { + void setOperand(unsigned idx, Value *value) { getStmtOperand(idx).set(value); } @@ -439,10 +439,10 @@ public: } // For statement represents implicitly represents induction variable by - // inheriting from MLValue class. Whenever you need to refer to the loop + // inheriting from Value class. Whenever you need to refer to the loop // induction variable, just use the for statement itself. - static bool classof(const SSAValue *value) { - return value->getKind() == SSAValueKind::ForStmt; + static bool classof(const Value *value) { + return value->getKind() == Value::Kind::ForStmt; } private: @@ -475,7 +475,7 @@ public: AffineMap getMap() const { return map; } unsigned getNumOperands() const { return opEnd - opStart; } - const MLValue *getOperand(unsigned idx) const { + const Value *getOperand(unsigned idx) const { return stmt.getOperand(opStart + idx); } const StmtOperand &getStmtOperand(unsigned idx) const { @@ -486,15 +486,15 @@ public: using operand_range = ForStmt::operand_range; operand_iterator operand_begin() const { - // These are iterators over MLValue *. Not casting away const'ness would - // require the caller to use const MLValue *. + // These are iterators over Value *. Not casting away const'ness would + // require the caller to use const Value *. return operand_iterator(const_cast(&stmt), opStart); } operand_iterator operand_end() const { return operand_iterator(const_cast(&stmt), opEnd); } - /// Returns an iterator on the underlying MLValue's (MLValue *). + /// Returns an iterator on the underlying Value's (Value *). operand_range getOperands() const { return {operand_begin(), operand_end()}; } ArrayRef getStmtOperands() const { auto ops = stmt.getStmtOperands(); @@ -520,7 +520,7 @@ private: /// If statement restricts execution to a subset of the loop iteration space. class IfStmt : public Statement { public: - static IfStmt *create(Location location, ArrayRef operands, + static IfStmt *create(Location location, ArrayRef operands, IntegerSet set); ~IfStmt(); @@ -556,8 +556,8 @@ public: //===--------------------------------------------------------------------===// /// Operand iterators. 
- using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; /// Operand iterator range. using operand_range = llvm::iterator_range; @@ -565,11 +565,11 @@ public: unsigned getNumOperands() const { return operands.size(); } - MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } - const MLValue *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } - void setOperand(unsigned idx, MLValue *value) { + void setOperand(unsigned idx, Value *value) { getStmtOperand(idx).set(value); } diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index 23b682043ac..57d9f8dde70 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -26,7 +26,6 @@ namespace mlir { class IfStmt; -class MLValue; class StmtBlockList; using CFGFunction = Function; using MLFunction = Function; @@ -412,7 +411,7 @@ public: } private: - using BBUseIterator = SSAValueUseIterator; + using BBUseIterator = ValueUseIterator; BBUseIterator bbUseIterator; }; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index d5b5dda8c35..53164595c37 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -30,7 +30,7 @@ namespace mlir { class IROperand; class IROperandOwner; -template class SSAValueUseIterator; +template class ValueUseIterator; class IRObjectWithUseList { public: @@ -44,7 +44,7 @@ public: /// Returns true if this value has exactly one use. inline bool hasOneUse() const; - using use_iterator = SSAValueUseIterator; + using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; inline use_iterator use_begin() const; @@ -228,33 +228,33 @@ public: /// An iterator over all uses of a ValueBase. template -class SSAValueUseIterator +class ValueUseIterator : public std::iterator { public: - SSAValueUseIterator() = default; - explicit SSAValueUseIterator(OperandType *current) : current(current) {} + ValueUseIterator() = default; + explicit ValueUseIterator(OperandType *current) : current(current) {} OperandType *operator->() const { return current; } OperandType &operator*() const { return *current; } OwnerType *getUser() const { return current->getOwner(); } - SSAValueUseIterator &operator++() { + ValueUseIterator &operator++() { assert(current && "incrementing past end()!"); current = (OperandType *)current->getNextOperandUsingThisValue(); return *this; } - SSAValueUseIterator operator++(int unused) { - SSAValueUseIterator copy = *this; + ValueUseIterator operator++(int unused) { + ValueUseIterator copy = *this; ++*this; return copy; } - friend bool operator==(SSAValueUseIterator lhs, SSAValueUseIterator rhs) { + friend bool operator==(ValueUseIterator lhs, ValueUseIterator rhs) { return lhs.current == rhs.current; } - friend bool operator!=(SSAValueUseIterator lhs, SSAValueUseIterator rhs) { + friend bool operator!=(ValueUseIterator lhs, ValueUseIterator rhs) { return !(lhs == rhs); } diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h new file mode 100644 index 00000000000..c7fe8d3d130 --- /dev/null +++ b/mlir/include/mlir/IR/Value.h @@ -0,0 +1,198 @@ +//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines generic Value type and manipulation utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_VALUE_H +#define MLIR_IR_VALUE_H + +#include "mlir/IR/Types.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class Function; +class OperationStmt; +class Operation; +class Statement; +class StmtBlock; +class Value; +using Instruction = Statement; +using OperationInst = OperationStmt; + +/// The operand of ML function statement contains a Value. +using StmtOperand = IROperandImpl; + +/// This is the common base class for all values in the MLIR system, +/// representing a computable value that has a type and a set of users. +/// +class Value : public IRObjectWithUseList { +public: + /// This enumerates all of the SSA value kinds in the MLIR system. + enum class Kind { + BlockArgument, // block argument + StmtResult, // statement result + ForStmt, // for statement induction variable + }; + + ~Value() {} + + Kind getKind() const { return typeAndKind.getInt(); } + + Type getType() const { return typeAndKind.getPointer(); } + + /// Replace all uses of 'this' value with the new value, updating anything in + /// the IR that uses 'this' to use the other value instead. When this returns + /// there are zero uses of 'this'. + void replaceAllUsesWith(Value *newValue) { + IRObjectWithUseList::replaceAllUsesWith(newValue); + } + + /// TODO: move isValidDim/isValidSymbol to a utility library specific to the + /// polyhedral operations. + + /// Returns true if the given Value can be used as a dimension id. + bool isValidDim() const; + + /// Returns true if the given Value can be used as a symbol. + bool isValidSymbol() const; + + /// Return the function that this Value is defined in. + Function *getFunction(); + + /// Return the function that this Value is defined in. + const Function *getFunction() const { + return const_cast(this)->getFunction(); + } + + /// If this value is the result of an Instruction, return the instruction + /// that defines it. + OperationInst *getDefiningInst(); + const OperationInst *getDefiningInst() const { + return const_cast(this)->getDefiningInst(); + } + + /// If this value is the result of an OperationStmt, return the statement + /// that defines it. + OperationStmt *getDefiningStmt(); + const OperationStmt *getDefiningStmt() const { + return const_cast(this)->getDefiningStmt(); + } + + /// If this value is the result of an Operation, return the operation that + /// defines it. 
+ Operation *getDefiningOperation(); + const Operation *getDefiningOperation() const { + return const_cast(this)->getDefiningOperation(); + } + + using use_iterator = ValueUseIterator; + using use_range = llvm::iterator_range; + + inline use_iterator use_begin() const; + inline use_iterator use_end() const; + + /// Returns a range of all uses, which is useful for iterating over all uses. + inline use_range getUses() const; + + void print(raw_ostream &os) const; + void dump() const; + +protected: + Value(Kind kind, Type type) : typeAndKind(type, kind) {} + +private: + const llvm::PointerIntPair typeAndKind; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const Value &value) { + value.print(os); + return os; +} + +// Utility functions for iterating through Value uses. +inline auto Value::use_begin() const -> use_iterator { + return use_iterator((StmtOperand *)getFirstUse()); +} + +inline auto Value::use_end() const -> use_iterator { + return use_iterator(nullptr); +} + +inline auto Value::getUses() const -> llvm::iterator_range { + return {use_begin(), use_end()}; +} + +/// Block arguments are values. +class BlockArgument : public Value { +public: + static bool classof(const Value *value) { + return value->getKind() == Kind::BlockArgument; + } + + /// Return the function that this argument is defined in. + Function *getFunction(); + const Function *getFunction() const { + return const_cast(this)->getFunction(); + } + + StmtBlock *getOwner() { return owner; } + const StmtBlock *getOwner() const { return owner; } + +private: + friend class StmtBlock; // For access to private constructor. + BlockArgument(Type type, StmtBlock *owner) + : Value(Value::Kind::BlockArgument, type), owner(owner) {} + + /// The owner of this operand. + /// TODO: can encode this more efficiently to avoid the space hit of this + /// through bitpacking shenanigans. + StmtBlock *const owner; +}; + +/// This is a value defined by a result of an operation instruction. +class StmtResult : public Value { +public: + StmtResult(Type type, OperationStmt *owner) + : Value(Value::Kind::StmtResult, type), owner(owner) {} + + static bool classof(const Value *value) { + return value->getKind() == Kind::StmtResult; + } + + OperationStmt *getOwner() { return owner; } + const OperationStmt *getOwner() const { return owner; } + + /// Returns the number of this result. + unsigned getResultNumber() const; + +private: + /// The owner of this operand. + /// TODO: can encode this more efficiently to avoid the space hit of this + /// through bitpacking shenanigans. + OperationStmt *const owner; +}; + +// TODO(clattner) clean all this up. +using BBArgument = BlockArgument; +using InstResult = StmtResult; + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 5b579446c7f..2696ce37b78 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -157,14 +157,14 @@ class Op props = []> { // // static void build(Builder* builder, OperationState* result, // Type resultType0, Type resultType1, ..., - // SSAValue* arg0, SSAValue* arg1, ..., + // Value arg0, Value arg1, ..., // Attribute , Attribute , ...); // // * where the attributes follow the same declaration order as in the op. 
// // static void build(Builder* builder, OperationState* result, // ArrayRef resultTypes, - // ArrayRef args, + // ArrayRef args, // ArrayRef attributes); code builder = ?; diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 6073eba075f..789dfed43f1 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -30,7 +30,6 @@ namespace mlir { class AffineMap; class Builder; -class MLValue; class StandardOpsDialect : public Dialect { public: @@ -48,8 +47,8 @@ class AddFOp : public BinaryOp { public: - static void build(Builder *builder, OperationState *result, SSAValue *lhs, - SSAValue *rhs); + static void build(Builder *builder, OperationState *result, Value *lhs, + Value *rhs); static StringRef getOperationName() { return "addf"; } @@ -116,7 +115,7 @@ public: // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, - MemRefType memrefType, ArrayRef operands = {}); + MemRefType memrefType, ArrayRef operands = {}); bool verify() const; static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; @@ -140,7 +139,7 @@ public: static StringRef getOperationName() { return "call"; } static void build(Builder *builder, OperationState *result, Function *callee, - ArrayRef operands); + ArrayRef operands); Function *getCallee() const { return getAttrOfType("callee").getValue(); @@ -169,11 +168,11 @@ class CallIndirectOp : public Op operands); + static void build(Builder *builder, OperationState *result, Value *callee, + ArrayRef operands); - const SSAValue *getCallee() const { return getOperand(0); } - SSAValue *getCallee() { return getOperand(0); } + const Value *getCallee() const { return getOperand(0); } + Value *getCallee() { return getOperand(0); } // Hooks to customize behavior of this op. static bool parse(OpAsmParser *parser, OperationState *result); @@ -240,7 +239,7 @@ public: static CmpIPredicate getPredicateByName(StringRef name); static void build(Builder *builder, OperationState *result, CmpIPredicate, - SSAValue *lhs, SSAValue *rhs); + Value *lhs, Value *rhs); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; bool verify() const; @@ -263,14 +262,14 @@ private: class DeallocOp : public Op { public: - SSAValue *getMemRef() { return getOperand(); } - const SSAValue *getMemRef() const { return getOperand(); } - void setMemRef(SSAValue *value) { setOperand(value); } + Value *getMemRef() { return getOperand(); } + const Value *getMemRef() const { return getOperand(); } + void setMemRef(Value *value) { setOperand(value); } static StringRef getOperationName() { return "dealloc"; } // Hooks to customize behavior of this op. 
- static void build(Builder *builder, OperationState *result, SSAValue *memref); + static void build(Builder *builder, OperationState *result, Value *memref); bool verify() const; static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; @@ -292,7 +291,7 @@ class DimOp : public Op { public: static void build(Builder *builder, OperationState *result, - SSAValue *memrefOrTensor, unsigned index); + Value *memrefOrTensor, unsigned index); Attribute constantFold(ArrayRef operands, MLIRContext *context) const; @@ -354,15 +353,15 @@ private: class DmaStartOp : public Op { public: - static void build(Builder *builder, OperationState *result, - SSAValue *srcMemRef, ArrayRef srcIndices, - SSAValue *destMemRef, ArrayRef destIndices, - SSAValue *numElements, SSAValue *tagMemRef, - ArrayRef tagIndices, SSAValue *stride = nullptr, - SSAValue *elementsPerStride = nullptr); + static void build(Builder *builder, OperationState *result, Value *srcMemRef, + ArrayRef srcIndices, Value *destMemRef, + ArrayRef destIndices, Value *numElements, + Value *tagMemRef, ArrayRef tagIndices, + Value *stride = nullptr, + Value *elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. - const SSAValue *getSrcMemRef() const { return getOperand(0); } + const Value *getSrcMemRef() const { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() const { return getSrcMemRef()->getType().cast().getRank(); @@ -375,7 +374,7 @@ public: } // Returns the destination MemRefType for this DMA operations. - const SSAValue *getDstMemRef() const { + const Value *getDstMemRef() const { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. @@ -398,12 +397,12 @@ public: } // Returns the number of elements being transferred by this DMA operation. - const SSAValue *getNumElements() const { + const Value *getNumElements() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. - const SSAValue *getTagMemRef() const { + const Value *getTagMemRef() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. @@ -453,21 +452,21 @@ public: 1 + 1 + getTagMemRefRank(); } - SSAValue *getStride() { + Value *getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } - const SSAValue *getStride() const { + const Value *getStride() const { return const_cast(this)->getStride(); } - SSAValue *getNumElementsPerStride() { + Value *getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); } - const SSAValue *getNumElementsPerStride() const { + const Value *getNumElementsPerStride() const { return const_cast(this)->getNumElementsPerStride(); } @@ -493,15 +492,14 @@ protected: class DmaWaitOp : public Op { public: - static void build(Builder *builder, OperationState *result, - SSAValue *tagMemRef, ArrayRef tagIndices, - SSAValue *numElements); + static void build(Builder *builder, OperationState *result, Value *tagMemRef, + ArrayRef tagIndices, Value *numElements); static StringRef getOperationName() { return "dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. 
- const SSAValue *getTagMemRef() const { return getOperand(0); } - SSAValue *getTagMemRef() { return getOperand(0); } + const Value *getTagMemRef() const { return getOperand(0); } + Value *getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. llvm::iterator_range @@ -516,7 +514,7 @@ public: } // Returns the number of elements transferred in the associated DMA operation. - const SSAValue *getNumElements() const { + const Value *getNumElements() const { return getOperand(1 + getTagMemRefRank()); } @@ -545,11 +543,11 @@ class ExtractElementOp : public Op { public: - static void build(Builder *builder, OperationState *result, - SSAValue *aggregate, ArrayRef indices = {}); + static void build(Builder *builder, OperationState *result, Value *aggregate, + ArrayRef indices = {}); - SSAValue *getAggregate() { return getOperand(0); } - const SSAValue *getAggregate() const { return getOperand(0); } + Value *getAggregate() { return getOperand(0); } + const Value *getAggregate() const { return getOperand(0); } llvm::iterator_range getIndices() { return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; @@ -583,12 +581,12 @@ class LoadOp : public Op { public: // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, SSAValue *memref, - ArrayRef indices = {}); + static void build(Builder *builder, OperationState *result, Value *memref, + ArrayRef indices = {}); - SSAValue *getMemRef() { return getOperand(0); } - const SSAValue *getMemRef() const { return getOperand(0); } - void setMemRef(SSAValue *value) { setOperand(0, value); } + Value *getMemRef() { return getOperand(0); } + const Value *getMemRef() const { return getOperand(0); } + void setMemRef(Value *value) { setOperand(0, value); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } @@ -705,19 +703,18 @@ class SelectOp : public Op::Impl, OpTrait::OneResult, OpTrait::HasNoSideEffect> { public: static StringRef getOperationName() { return "select"; } - static void build(Builder *builder, OperationState *result, - SSAValue *condition, SSAValue *trueValue, - SSAValue *falseValue); + static void build(Builder *builder, OperationState *result, Value *condition, + Value *trueValue, Value *falseValue); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; bool verify() const; - SSAValue *getCondition() { return getOperand(0); } - const SSAValue *getCondition() const { return getOperand(0); } - SSAValue *getTrueValue() { return getOperand(1); } - const SSAValue *getTrueValue() const { return getOperand(1); } - SSAValue *getFalseValue() { return getOperand(2); } - const SSAValue *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + const Value *getCondition() const { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + const Value *getTrueValue() const { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + const Value *getFalseValue() const { return getOperand(2); } Attribute constantFold(ArrayRef operands, MLIRContext *context) const; @@ -742,15 +739,15 @@ class StoreOp public: // Hooks to customize behavior of this op. 
static void build(Builder *builder, OperationState *result, - SSAValue *valueToStore, SSAValue *memref, - ArrayRef indices = {}); + Value *valueToStore, Value *memref, + ArrayRef indices = {}); - SSAValue *getValueToStore() { return getOperand(0); } - const SSAValue *getValueToStore() const { return getOperand(0); } + Value *getValueToStore() { return getOperand(0); } + const Value *getValueToStore() const { return getOperand(0); } - SSAValue *getMemRef() { return getOperand(1); } - const SSAValue *getMemRef() const { return getOperand(1); } - void setMemRef(SSAValue *value) { setOperand(1, value); } + Value *getMemRef() { return getOperand(1); } + const Value *getMemRef() const { return getOperand(1); } + void setMemRef(Value *value) { setOperand(1, value); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index 918bc60d1cc..38bc82569a5 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -98,26 +98,24 @@ public: static StringRef getOperationName() { return "vector_transfer_read"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; } static void build(Builder *builder, OperationState *result, - VectorType vectorType, SSAValue *srcMemRef, - ArrayRef srcIndices, AffineMap permutationMap, - Optional paddingValue = None); + VectorType vectorType, Value *srcMemRef, + ArrayRef srcIndices, AffineMap permutationMap, + Optional paddingValue = None); VectorType getResultType() const { return getResult()->getType().cast(); } - SSAValue *getVector() { return getResult(); } - const SSAValue *getVector() const { return getResult(); } - SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); } - const SSAValue *getMemRef() const { - return getOperand(Offsets::MemRefOffset); - } + Value *getVector() { return getResult(); } + const Value *getVector() const { return getResult(); } + Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } + const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } VectorType getVectorType() const { return getResultType(); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } llvm::iterator_range getIndices(); llvm::iterator_range getIndices() const; - Optional getPaddingValue(); - Optional getPaddingValue() const; + Optional getPaddingValue(); + Optional getPaddingValue() const; AffineMap getPermutationMap() const; static bool parse(OpAsmParser *parser, OperationState *result); @@ -169,20 +167,16 @@ class VectorTransferWriteOp public: static StringRef getOperationName() { return "vector_transfer_write"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; } - static void build(Builder *builder, OperationState *result, - SSAValue *srcVector, SSAValue *dstMemRef, - ArrayRef dstIndices, AffineMap permutationMap); - SSAValue *getVector() { return getOperand(Offsets::VectorOffset); } - const SSAValue *getVector() const { - return getOperand(Offsets::VectorOffset); - } + static void build(Builder *builder, OperationState *result, Value *srcVector, + Value *dstMemRef, ArrayRef dstIndices, + AffineMap permutationMap); + Value *getVector() { return getOperand(Offsets::VectorOffset); } + const Value *getVector() const { return getOperand(Offsets::VectorOffset); } VectorType getVectorType() const { return getVector()->getType().cast(); } - SSAValue *getMemRef() { return 
getOperand(Offsets::MemRefOffset); } - const SSAValue *getMemRef() const { - return getOperand(Offsets::MemRefOffset); - } + Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } + const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } @@ -212,8 +206,8 @@ class VectorTypeCastOp : public Op { public: static StringRef getOperationName() { return "vector_type_cast"; } - static void build(Builder *builder, OperationState *result, - SSAValue *srcVector, Type dstType); + static void build(Builder *builder, OperationState *result, Value *srcVector, + Type dstType); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; bool verify() const; diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 119f2add54a..5670b60e0bd 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -35,10 +35,9 @@ namespace mlir { class ForStmt; class FuncBuilder; class Location; -class MLValue; class Module; class OperationStmt; -class SSAValue; + class Function; using CFGFunction = Function; @@ -52,12 +51,12 @@ using CFGFunction = Function; /// Returns true on success and false if the replacement is not possible /// (whenever a memref is used as an operand in a non-deferencing scenario). See /// comments at function definition for an example. -// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily +// TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily // extended to add additional indices at any position. -bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, - ArrayRef extraIndices = {}, +bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, + ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap::Null(), - ArrayRef extraOperands = {}, + ArrayRef extraOperands = {}, const Statement *domStmtFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of @@ -69,9 +68,9 @@ bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, /// parameter 'results'. Returns the affine apply op created. OperationStmt * createComposedAffineApplyOp(FuncBuilder *builder, Location loc, - ArrayRef operands, + ArrayRef operands, ArrayRef affineApplyOps, - SmallVectorImpl *results); + SmallVectorImpl *results); /// Given an operation statement, inserts a new single affine apply operation, /// that is exclusively used by this operation statement, and that provides all @@ -104,7 +103,7 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt); /// Forward substitutes results from 'AffineApplyOp' into any users which /// are also AffineApplyOps. // NOTE: This method may modify users of results of this operation. -// TODO(mlir-team): extend this for SSAValue / CFGFunctions. +// TODO(mlir-team): extend this for Value/ CFGFunctions. void forwardSubstitute(OpPointer affineApplyOp); /// Folds the lower and upper bounds of a 'for' stmt to constants if possible. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index bdc2c7ec286..04ef715d011 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -489,11 +489,11 @@ bool mlir::getFlattenedAffineExprs( // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. 
void mlir::getReachableAffineApplyOps( - ArrayRef operands, + ArrayRef operands, SmallVectorImpl &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. - MLValue *value; + Value *value; // The operand index of 'value' to explore next during DFS traversal. unsigned operandIndex; }; @@ -557,8 +557,8 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) { // setExprStride(ArrayRef expr, int64_t stride) bool mlir::getIndexSet(ArrayRef forStmts, FlatAffineConstraints *domain) { - SmallVector indices(forStmts.begin(), forStmts.end()); - // Reset while associated MLValues in 'indices' to the domain. + SmallVector indices(forStmts.begin(), forStmts.end()); + // Reset while associated Values in 'indices' to the domain. domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto *forStmt : forStmts) { // Add constraints from forStmt's bounds. @@ -583,10 +583,10 @@ static bool getStmtIndexSet(const Statement *stmt, return getIndexSet(loops, indexSet); } -// ValuePositionMap manages the mapping from MLValues which represent dimension +// ValuePositionMap manages the mapping from Values which represent dimension // and symbol identifiers from 'src' and 'dst' access functions to positions -// in new space where some MLValues are kept separate (using addSrc/DstValue) -// and some MLValues are merged (addSymbolValue). +// in new space where some Values are kept separate (using addSrc/DstValue) +// and some Values are merged (addSymbolValue). // Position lookups return the absolute position in the new space which // has the following format: // @@ -595,7 +595,7 @@ static bool getStmtIndexSet(const Statement *stmt, // Note: access function non-IV dimension identifiers (that have 'dimension' // positions in the access function position space) are assigned as symbols // in the output position space. Convienience access functions which lookup -// an MLValue in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle +// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle // the common case of resolving positions for all access function operands. // // TODO(andydavis) Generalize this: could take a template parameter for @@ -603,25 +603,25 @@ static bool getStmtIndexSet(const Statement *stmt, // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". 
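[Editor's note, not part of the patch.] The comment block above describes the merged identifier space that ValuePositionMap manages: source dimension identifiers first, then destination dimension identifiers, then symbols. As a minimal, self-contained sketch of that column layout (assuming only the documented order; the struct and function names below are hypothetical and purely illustrative):

// Sketch of the documented merged column order:
// [src-dim-identifiers | dst-dim-identifiers | symbol-identifiers]
#include <cassert>

struct MergedSpace {
  unsigned numSrcDims;
  unsigned numDstDims;
  unsigned numSymbols;

  // Absolute column of the i-th source-dimension identifier.
  unsigned srcDimPos(unsigned i) const {
    assert(i < numSrcDims);
    return i;
  }
  // Destination dimensions follow all source dimensions.
  unsigned dstDimPos(unsigned i) const {
    assert(i < numDstDims);
    return numSrcDims + i;
  }
  // Symbols follow both dimension ranges.
  unsigned symPos(unsigned i) const {
    assert(i < numSymbols);
    return numSrcDims + numDstDims + i;
  }
};

// Example: with 2 src dims, 2 dst dims, and 1 symbol, the symbol occupies
// column 4 of the merged space:
//   MergedSpace space{2, 2, 1};
//   unsigned col = space.symPos(0); // == 4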
class ValuePositionMap { public: - void addSrcValue(const MLValue *value) { + void addSrcValue(const Value *value) { if (addValueAt(value, &srcDimPosMap, numSrcDims)) ++numSrcDims; } - void addDstValue(const MLValue *value) { + void addDstValue(const Value *value) { if (addValueAt(value, &dstDimPosMap, numDstDims)) ++numDstDims; } - void addSymbolValue(const MLValue *value) { + void addSymbolValue(const Value *value) { if (addValueAt(value, &symbolPosMap, numSymbols)) ++numSymbols; } - unsigned getSrcDimOrSymPos(const MLValue *value) const { + unsigned getSrcDimOrSymPos(const Value *value) const { return getDimOrSymPos(value, srcDimPosMap, 0); } - unsigned getDstDimOrSymPos(const MLValue *value) const { + unsigned getDstDimOrSymPos(const Value *value) const { return getDimOrSymPos(value, dstDimPosMap, numSrcDims); } - unsigned getSymPos(const MLValue *value) const { + unsigned getSymPos(const Value *value) const { auto it = symbolPosMap.find(value); assert(it != symbolPosMap.end()); return numSrcDims + numDstDims + it->second; @@ -633,8 +633,7 @@ public: unsigned getNumSymbols() const { return numSymbols; } private: - bool addValueAt(const MLValue *value, - DenseMap *posMap, + bool addValueAt(const Value *value, DenseMap *posMap, unsigned position) { auto it = posMap->find(value); if (it == posMap->end()) { @@ -643,8 +642,8 @@ private: } return false; } - unsigned getDimOrSymPos(const MLValue *value, - const DenseMap &dimPosMap, + unsigned getDimOrSymPos(const Value *value, + const DenseMap &dimPosMap, unsigned dimPosOffset) const { auto it = dimPosMap.find(value); if (it != dimPosMap.end()) { @@ -658,25 +657,25 @@ private: unsigned numSrcDims = 0; unsigned numDstDims = 0; unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; }; -// Builds a map from MLValue to identifier position in a new merged identifier +// Builds a map from Value to identifier position in a new merged identifier // list, which is the result of merging dim/symbol lists from src/dst // iteration domains. The format of the new merged list is as follows: // // [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers] // -// This method populates 'valuePosMap' with mappings from operand MLValues in +// This method populates 'valuePosMap' with mappings from operand Values in // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') // to the position of these values in the merged list. 
static void buildDimAndSymbolPositionMaps( const FlatAffineConstraints &srcDomain, const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) { - auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { + auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = values.size(); i < e; ++i) { auto *value = values[i]; if (!isa(values[i])) @@ -688,7 +687,7 @@ static void buildDimAndSymbolPositionMaps( } }; - SmallVector srcValues, destValues; + SmallVector srcValues, destValues; srcDomain.getIdValues(&srcValues); dstDomain.getIdValues(&destValues); @@ -702,17 +701,10 @@ static void buildDimAndSymbolPositionMaps( updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); } -static unsigned getPos(const DenseMap &posMap, - const MLValue *value) { - auto it = posMap.find(value); - assert(it != posMap.end()); - return it->second; -} - // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into // 'dependenceDomain'. // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a -// srcDomain/dstDomain MLValue maps. +// srcDomain/dstDomain Value maps. static void addDomainConstraints(const FlatAffineConstraints &srcDomain, const FlatAffineConstraints &dstDomain, const ValuePositionMap &valuePosMap, @@ -790,10 +782,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, unsigned numResults = srcMap.getNumResults(); unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); - ArrayRef srcOperands = srcAccessMap.getOperands(); + ArrayRef srcOperands = srcAccessMap.getOperands(); unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); - ArrayRef dstOperands = dstAccessMap.getOperands(); + ArrayRef dstOperands = dstAccessMap.getOperands(); std::vector> srcFlatExprs; std::vector> destFlatExprs; @@ -848,7 +840,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { + auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (isa(operands[i])) continue; @@ -1095,7 +1087,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // upper/lower loop bounds for each ForStmt in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map -// MLValues from access functions and iteration domains to their position +// Values from access functions and iteration domains to their position // in the merged constraint system built by this method. // // This method builds a constraint system with the following column format: @@ -1202,7 +1194,7 @@ bool mlir::checkMemrefAccessDependence( return false; } // Build dim and symbol position maps for each access from access operand - // MLValue to position in merged contstraint system. + // Value to position in merged contstraint system. 
ValuePositionMap valuePosMap; buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, dstAccessMap, &valuePosMap); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index c22c5ec95bc..bfdaceff7e7 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/Statements.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" @@ -238,23 +237,23 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, AffineValueMap::AffineValueMap(const AffineApplyOp &op) : map(op.getAffineMap()) { for (auto *operand : op.getOperands()) - operands.push_back(cast(const_cast(operand))); + operands.push_back(const_cast(operand)); for (unsigned i = 0, e = op.getNumResults(); i < e; i++) - results.push_back(cast(const_cast(op.getResult(i)))); + results.push_back(const_cast(op.getResult(i))); } -AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands) +AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands) : map(map) { - for (MLValue *operand : operands) { + for (Value *operand : operands) { this->operands.push_back(operand); } } -void AffineValueMap::reset(AffineMap map, ArrayRef operands) { +void AffineValueMap::reset(AffineMap map, ArrayRef operands) { this->operands.clear(); this->results.clear(); this->map.reset(map); - for (MLValue *operand : operands) { + for (Value *operand : operands) { this->operands.push_back(operand); } } @@ -275,7 +274,7 @@ void AffineValueMap::forwardSubstituteSingle(const AffineApplyOp &inputOp, // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. -static bool findIndex(MLValue *valueToMatch, ArrayRef valuesToSearch, +static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, unsigned indexStart, unsigned *indexOfMatch) { unsigned size = valuesToSearch.size(); for (unsigned i = indexStart; i < size; ++i) { @@ -324,8 +323,7 @@ void AffineValueMap::forwardSubstitute( for (unsigned j = 0; j < inputNumResults; ++j) { if (!inputResultsToSubstitute[j]) continue; - if (operands[i] == - cast(const_cast(inputOp.getResult(j)))) { + if (operands[i] == const_cast(inputOp.getResult(j))) { currOperandToInputResult[i] = j; inputResultsUsed.insert(j); } @@ -365,7 +363,7 @@ void AffineValueMap::forwardSubstitute( } // Build new output operands list and map update. - SmallVector outputOperands; + SmallVector outputOperands; unsigned outputOperandPosition = 0; AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults()); @@ -385,8 +383,7 @@ void AffineValueMap::forwardSubstitute( if (inputPositionsUsed.count(i) == 0) continue; // Check if input operand has a dup in current operand list. - auto *inputOperand = - cast(const_cast(inputOp.getOperand(i))); + auto *inputOperand = const_cast(inputOp.getOperand(i)); unsigned outputIndex; if (findIndex(inputOperand, outputOperands, /*indexStart=*/0, &outputIndex)) { @@ -418,8 +415,7 @@ void AffineValueMap::forwardSubstitute( continue; unsigned inputSymbolPosition = i - inputNumDims; // Check if input operand has a dup in current operand list. - auto *inputOperand = - cast(const_cast(inputOp.getOperand(i))); + auto *inputOperand = const_cast(inputOp.getOperand(i)); // Find output operand index of 'inputOperand' dup. 
unsigned outputIndex; // Start at index 'outputNumDims' so that only symbol operands are searched. @@ -451,7 +447,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { /// This method uses the invariant that operands are always positionally aligned /// with the AffineDimExpr in the underlying AffineMap. -bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const { +bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { unsigned index; findIndex(value, operands, /*indexStart=*/0, &index); auto expr = const_cast(this)->getAffineMap().getResult(idx); @@ -460,12 +456,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const { return expr.isFunctionOfDim(index); } -SSAValue *AffineValueMap::getOperand(unsigned i) const { - return static_cast(operands[i]); +Value *AffineValueMap::getOperand(unsigned i) const { + return static_cast(operands[i]); } -ArrayRef AffineValueMap::getOperands() const { - return ArrayRef(operands); +ArrayRef AffineValueMap::getOperands() const { + return ArrayRef(operands); } AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } @@ -546,7 +542,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && "minimum 1 column"); numReservedCols = newNumReservedCols; @@ -570,7 +566,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, newNumSymbols, newNumLocals, idArgs); } @@ -597,17 +593,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) { addId(IdKind::Local, pos); } -void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) { +void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { addId(IdKind::Dimension, pos, id); } -void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) { +void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { addId(IdKind::Symbol, pos, id); } /// Adds a dimensional identifier. The added column is initialized to /// zero. -void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) { +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { if (kind == IdKind::Dimension) { assert(pos <= getNumDimIds()); } else if (kind == IdKind::Symbol) { @@ -755,7 +751,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { // Dims and symbols. 
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { unsigned loc; - bool ret = findId(*cast(vMap->getOperand(i)), &loc); + bool ret = findId(*vMap->getOperand(i), &loc); assert(ret && "value map's id can't be found"); (void)ret; // We need to negate 'eq[r]' since the newly added dimension is going to @@ -1231,7 +1227,7 @@ void FlatAffineConstraints::addUpperBound(ArrayRef expr, } } -bool FlatAffineConstraints::findId(const MLValue &id, unsigned *pos) const { +bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { unsigned i = 0; for (const auto &mayBeId : ids) { if (mayBeId.hasValue() && mayBeId.getValue() == &id) { @@ -1253,8 +1249,8 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { unsigned pos; // Pre-condition for this method. - if (!findId(*cast(&forStmt), &pos)) { - assert(0 && "MLValue not found"); + if (!findId(forStmt, &pos)) { + assert(0 && "Value not found"); return false; } @@ -1270,7 +1266,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { unsigned loc; if (!findId(*operand, &loc)) { if (operand->isValidSymbol()) { - addSymbolId(getNumSymbolIds(), const_cast(operand)); + addSymbolId(getNumSymbolIds(), const_cast(operand)); loc = getNumDimIds() + getNumSymbolIds() - 1; // Check if the symbol is a constant. if (auto *opStmt = operand->getDefiningStmt()) { @@ -1279,7 +1275,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { } } } else { - addDimId(getNumDimIds(), const_cast(operand)); + addDimId(getNumDimIds(), const_cast(operand)); loc = getNumDimIds() - 1; } } @@ -1352,7 +1348,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { /// Sets the specified identifer to a constant value; asserts if the id is not /// found. -void FlatAffineConstraints::setIdToConstant(const MLValue &id, int64_t val) { +void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) { unsigned pos; if (!findId(id, &pos)) // This is a pre-condition for this method. @@ -1572,7 +1568,7 @@ void FlatAffineConstraints::print(raw_ostream &os) const { if (ids[i] == None) os << "None "; else - os << "MLValue "; + os << "Value "; } os << " const)\n"; for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { @@ -1779,7 +1775,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate( unsigned newNumDims = dimsSymbols.first; unsigned newNumSymbols = dimsSymbols.second; - SmallVector, 8> newIds; + SmallVector, 8> newIds; newIds.reserve(numIds - 1); newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos); newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end()); @@ -1942,7 +1938,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { normalizeConstraintsByGCD(); } -void FlatAffineConstraints::projectOut(MLValue *id) { +void FlatAffineConstraints::projectOut(Value *id) { unsigned pos; bool ret = findId(*id, &pos); assert(ret); diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index b3faaf3eae0..1a28eb138f4 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -70,7 +70,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a, } /// Return true if value A properly dominates instruction B. 
-bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) { +bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { if (auto *aInst = a->getDefiningInst()) return properlyDominates(aInst, b); diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index f20b8bb19e5..7213ba5986a 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -124,14 +124,14 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { return tripCountExpr.getLargestKnownDivisor(); } -bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) { +bool mlir::isAccessInvariant(const Value &iv, const Value &index) { assert(isa(iv) && "iv must be a ForStmt"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; - getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); + getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); if (affineApplyOps.empty()) { - // Pointer equality test because of MLValue pointer semantics. + // Pointer equality test because of Value pointer semantics. return &index != &iv; } @@ -155,13 +155,13 @@ bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) { } assert(idx < std::numeric_limits::max()); return !AffineValueMap(*composeOp) - .isFunctionOf(idx, &const_cast(iv)); + .isFunctionOf(idx, &const_cast(iv)); } -llvm::DenseSet -mlir::getInvariantAccesses(const MLValue &iv, - llvm::ArrayRef indices) { - llvm::DenseSet res; +llvm::DenseSet +mlir::getInvariantAccesses(const Value &iv, + llvm::ArrayRef indices) { + llvm::DenseSet res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { auto *val = indices[idx]; if (isAccessInvariant(iv, *val)) { @@ -191,7 +191,7 @@ mlir::getInvariantAccesses(const MLValue &iv, /// // TODO(ntv): check strides. template -static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp, +static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp, unsigned fastestVaryingDim) { static_assert(std::is_same::value || std::is_same::value, @@ -220,7 +220,7 @@ static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp, if (fastestVaryingDim == (numIndices - 1) - d++) { continue; } - if (!isAccessInvariant(iv, cast(*index))) { + if (!isAccessInvariant(iv, *index)) { return false; } } @@ -316,7 +316,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, // outside). if (const auto *opStmt = dyn_cast(&stmt)) { for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { - const MLValue *result = opStmt->getResult(i); + const Value *result = opStmt->getResult(i); for (const StmtOperand &use : result->getUses()) { // If an ancestor statement doesn't lie in the block of forStmt, there // is no shift to check. 
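[Editor's note, not part of the patch.] The LoopAnalysis changes above keep the contiguity check working on Value: an access is treated as contiguous along the fastest-varying dimension only if every other index is invariant with respect to the loop induction variable. A minimal standalone sketch of that check, assuming each index has already been classified as invariant (as getInvariantAccesses does) and simplifying the fastest-varying dimension to a plain front-counted index; all names here are hypothetical:

#include <vector>

// 'invariant[d]' is true when index 'd' does not depend on the induction
// variable. The access is contiguous along 'fastestVaryingDim' only if every
// other index is invariant in the IV.
static bool isContiguousAlongDim(const std::vector<bool> &invariant,
                                 unsigned fastestVaryingDim) {
  for (unsigned d = 0, n = invariant.size(); d < n; ++d) {
    if (d == fastestVaryingDim)
      continue;
    if (!invariant[d])
      return false;
  }
  return true;
}

// Example: for an access A[i][j] inside the j-loop, index i is invariant in j
// while index j is the fastest-varying dimension, so the access is contiguous:
//   isContiguousAlongDim({true, false}, /*fastestVaryingDim=*/1) == true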
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 2e3df2d61f4..7c57a66310a 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -70,7 +70,7 @@ static void addMemRefAccessIndices( MemRefType memrefType, MemRefAccess *access) { access->indices.reserve(memrefType.getRank()); for (auto *index : opIndices) { - access->indices.push_back(cast(const_cast(index))); + access->indices.push_back(const_cast(index)); } } @@ -79,13 +79,13 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt, MemRefAccess *access) { access->opStmt = loadOrStoreOpStmt; if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { - access->memref = cast(loadOp->getMemRef()); + access->memref = loadOp->getMemRef(); addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(), access); } else { assert(loadOrStoreOpStmt->isa()); auto storeOp = loadOrStoreOpStmt->dyn_cast(); - access->memref = cast(storeOp->getMemRef()); + access->memref = storeOp->getMemRef(); addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(), access); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0c6cfea7ccd..7d397647bc9 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -150,21 +150,21 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, OpPointer loadOp; OpPointer storeOp; unsigned rank; - SmallVector indices; + SmallVector indices; if ((loadOp = opStmt->dyn_cast())) { rank = loadOp->getMemRefType().getRank(); for (auto *index : loadOp->getIndices()) { - indices.push_back(cast(index)); + indices.push_back(index); } - region->memref = cast(loadOp->getMemRef()); + region->memref = loadOp->getMemRef(); region->setWrite(false); } else if ((storeOp = opStmt->dyn_cast())) { rank = storeOp->getMemRefType().getRank(); for (auto *index : storeOp->getIndices()) { - indices.push_back(cast(index)); + indices.push_back(index); } - region->memref = cast(storeOp->getMemRef()); + region->memref = storeOp->getMemRef(); region->setWrite(true); } else { return false; @@ -201,7 +201,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, return false; } else { // Has to be a valid symbol. - auto *symbol = cast(accessValueMap.getOperand(i)); + auto *symbol = accessValueMap.getOperand(i); assert(symbol->isValidSymbol()); // Check if the symbol is a constant. if (auto *opStmt = symbol->getDefiningStmt()) { @@ -405,7 +405,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, // Solve for src IVs in terms of dst IVs, symbols and constants. SmallVector srcIvMaps(srcLoopNestSize, AffineMap::Null()); - std::vector> srcIvOperands(srcLoopNestSize); + std::vector> srcIvOperands(srcLoopNestSize); for (unsigned i = 0; i < srcLoopNestSize; ++i) { // Skip IVs which are greater than requested loop depth. if (i >= srcLoopDepth) { @@ -442,7 +442,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, srcIvOperands[i].push_back(dstLoopNest[dimId - 1]); } // TODO(andydavis) Add symbols from the access function. Ideally, we - // should be able to query the constaint system for the MLValue associated + // should be able to query the constaint system for the Value associated // with a symbol identifiers in 'nonZeroSymbolIds'. } @@ -454,7 +454,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, // of the loop at 'dstLoopDepth' in 'dstLoopNest'. 
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin()); - DenseMap operandMap; + DenseMap operandMap; auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); // Lookup stmt in cloned 'sliceLoopNest' at 'positions'. diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index bfef98d76da..ec19194f2fa 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -108,7 +108,7 @@ static AffineMap makePermutationMap( const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; - auto unwrappedIndices = map(makePtrDynCaster(), indices); + auto unwrappedIndices = map(makePtrDynCaster(), indices); SmallVector perm(enclosingLoopToVectorDim.size(), getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index a04cee7512d..e7abb899a11 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -277,7 +277,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker { /// Walk all of the code in this MLFunc and verify that the operands of any /// operations are properly dominated by their definitions. bool MLFuncVerifier::verifyDominance() { - using HashTable = llvm::ScopedHashTable; + using HashTable = llvm::ScopedHashTable; HashTable liveValues; HashTable::ScopeTy topScope(liveValues); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4778564cb4d..c44ce4d4d6c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -38,7 +38,6 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" - using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } @@ -967,7 +966,7 @@ public: void printFunctionAttributes(const Function *func) { return ModulePrinter::printFunctionAttributes(func); } - void printOperand(const SSAValue *value) { printValueID(value); } + void printOperand(const Value *value) { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) { @@ -977,7 +976,7 @@ public: enum { nameSentinel = ~0U }; protected: - void numberValueID(const SSAValue *value) { + void numberValueID(const Value *value) { assert(!valueIDs.count(value) && "Value numbered multiple times"); SmallString<32> specialNameBuffer; @@ -1004,7 +1003,7 @@ protected: if (specialNameBuffer.empty()) { switch (value->getKind()) { - case SSAValueKind::BlockArgument: + case Value::Kind::BlockArgument: // If this is an argument to the function, give it an 'arg' name. if (auto *block = cast(value)->getOwner()) if (auto *fn = block->getFunction()) @@ -1015,12 +1014,12 @@ protected: // Otherwise number it normally. valueIDs[value] = nextValueID++; return; - case SSAValueKind::StmtResult: + case Value::Kind::StmtResult: // This is an uninteresting result, give it a boring number and be // done with it. valueIDs[value] = nextValueID++; return; - case SSAValueKind::ForStmt: + case Value::Kind::ForStmt: specialName << 'i' << nextLoopID++; break; } @@ -1052,7 +1051,7 @@ protected: } } - void printValueID(const SSAValue *value, bool printResultNo = true) const { + void printValueID(const Value *value, bool printResultNo = true) const { int resultNo = -1; auto lookupValue = value; @@ -1093,8 +1092,8 @@ protected: private: /// This is the value ID for each SSA value in the current function. 
If this /// returns ~0, then the valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; + DenseMap valueIDs; + DenseMap valueNames; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. @@ -1135,7 +1134,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) { os << "\"("; interleaveComma(op->getOperands(), - [&](const SSAValue *value) { printValueID(value); }); + [&](const Value *value) { printValueID(value); }); os << ')'; auto attrs = op->getAttrs(); @@ -1144,16 +1143,15 @@ void FunctionPrinter::printDefaultOp(const Operation *op) { // Print the type signature of the operation. os << " : ("; interleaveComma(op->getOperands(), - [&](const SSAValue *value) { printType(value->getType()); }); + [&](const Value *value) { printType(value->getType()); }); os << ") -> "; if (op->getNumResults() == 1) { printType(op->getResult(0)->getType()); } else { os << '('; - interleaveComma(op->getResults(), [&](const SSAValue *result) { - printType(result->getType()); - }); + interleaveComma(op->getResults(), + [&](const Value *result) { printType(result->getType()); }); os << ')'; } } @@ -1297,11 +1295,10 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) { os << '('; interleaveComma(range, - [this](const SSAValue *operand) { printValueID(operand); }); + [this](const Value *operand) { printValueID(operand); }); os << " : "; - interleaveComma(range, [this](const SSAValue *operand) { - printType(operand->getType()); - }); + interleaveComma( + range, [this](const Value *operand) { printType(operand->getType()); }); os << ')'; } @@ -1576,20 +1573,20 @@ void IntegerSet::print(raw_ostream &os) const { ModulePrinter(os, state).printIntegerSet(*this); } -void SSAValue::print(raw_ostream &os) const { +void Value::print(raw_ostream &os) const { switch (getKind()) { - case SSAValueKind::BlockArgument: + case Value::Kind::BlockArgument: // TODO: Improve this. os << "\n"; return; - case SSAValueKind::StmtResult: + case Value::Kind::StmtResult: return getDefiningStmt()->print(os); - case SSAValueKind::ForStmt: + case Value::Kind::ForStmt: return cast(this)->print(os); } } -void SSAValue::dump() const { print(llvm::errs()); } +void Value::dump() const { print(llvm::errs()); } void Instruction::print(raw_ostream &os) const { auto *function = getFunction(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 0732448fb87..0b88216f66f 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -281,7 +281,7 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) { // If we are supposed to insert before a specific block, do so, otherwise add // the block to the end of the function. if (insertBefore) - function->getBlocks().insert(CFGFunction::iterator(insertBefore), b); + function->getBlocks().insert(Function::iterator(insertBefore), b); else function->push_back(b); @@ -291,16 +291,9 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) { /// Create an operation given the fields represented as an OperationState. OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) { - SmallVector operands; - operands.reserve(state.operands.size()); - // Allow null operands as they act as sentinal barriers between successor - // operand lists. 
- for (auto elt : state.operands) - operands.push_back(cast_or_null(elt)); - - auto *op = - OperationInst::create(state.location, state.name, operands, state.types, - state.attributes, state.successors, context); + auto *op = OperationInst::create(state.location, state.name, state.operands, + state.types, state.attributes, + state.successors, context); block->getStatements().insert(insertPoint, op); return op; } @@ -311,23 +304,17 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) { /// Create an operation given the fields represented as an OperationState. OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) { - SmallVector operands; - operands.reserve(state.operands.size()); - for (auto elt : state.operands) - operands.push_back(cast(elt)); - - auto *op = - OperationStmt::create(state.location, state.name, operands, state.types, - state.attributes, state.successors, context); + auto *op = OperationStmt::create(state.location, state.name, state.operands, + state.types, state.attributes, + state.successors, context); block->getStatements().insert(insertPoint, op); return op; } ForStmt *MLFuncBuilder::createFor(Location location, - ArrayRef lbOperands, - AffineMap lbMap, - ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step) { auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); block->getStatements().insert(insertPoint, stmt); @@ -341,7 +328,7 @@ ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub, return createFor(location, {}, lbMap, {}, ubMap, step); } -IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef operands, +IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef operands, IntegerSet set) { auto *stmt = IfStmt::create(location, operands, set); block->getStatements().insert(insertPoint, stmt); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index cdf98ca4bee..50ab254dd76 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -20,8 +20,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/Support/raw_ostream.h" @@ -54,7 +54,7 @@ void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin, // dimension operands parsed. // Returns 'false' on success and 'true' on error. 
bool mlir::parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, + SmallVector &operands, unsigned &numDims) { SmallVector opInfos; if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) @@ -76,7 +76,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser, //===----------------------------------------------------------------------===// void AffineApplyOp::build(Builder *builder, OperationState *result, - AffineMap map, ArrayRef operands) { + AffineMap map, ArrayRef operands) { result->addOperands(operands); result->types.append(map.getNumResults(), builder->getIndexType()); result->addAttribute("map", builder->getAffineMapAttr(map)); @@ -133,24 +133,22 @@ bool AffineApplyOp::verify() const { } // The result of the affine apply operation can be used as a dimension id if it -// is a CFG value or if it is an MLValue, and all the operands are valid +// is a CFG value or if it is an Value, and all the operands are valid // dimension ids. bool AffineApplyOp::isValidDim() const { for (auto *op : getOperands()) { - if (auto *v = dyn_cast(op)) - if (!v->isValidDim()) - return false; + if (!op->isValidDim()) + return false; } return true; } // The result of the affine apply operation can be used as a symbol if it is -// a CFG value or if it is an MLValue, and all the operands are symbols. +// a CFG value or if it is an Value, and all the operands are symbols. bool AffineApplyOp::isValidSymbol() const { for (auto *op : getOperands()) { - if (auto *v = dyn_cast(op)) - if (!v->isValidSymbol()) - return false; + if (!op->isValidSymbol()) + return false; } return true; } @@ -170,13 +168,13 @@ bool AffineApplyOp::constantFold(ArrayRef operandConstants, //===----------------------------------------------------------------------===// void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest, - ArrayRef operands) { + ArrayRef operands) { result->addSuccessor(dest, operands); } bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { BasicBlock *dest; - SmallVector destOperands; + SmallVector destOperands; if (parser->parseSuccessorAndUseList(dest, destOperands)) return true; result->addSuccessor(dest, destOperands); @@ -212,17 +210,16 @@ void BranchOp::eraseOperand(unsigned index) { //===----------------------------------------------------------------------===// void CondBranchOp::build(Builder *builder, OperationState *result, - SSAValue *condition, BasicBlock *trueDest, - ArrayRef trueOperands, - BasicBlock *falseDest, - ArrayRef falseOperands) { + Value *condition, BasicBlock *trueDest, + ArrayRef trueOperands, BasicBlock *falseDest, + ArrayRef falseOperands) { result->addOperands(condition); result->addSuccessor(trueDest, trueOperands); result->addSuccessor(falseDest, falseOperands); } bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector destOperands; + SmallVector destOperands; BasicBlock *dest; OpAsmParser::OperandType condInfo; @@ -446,7 +443,7 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result, //===----------------------------------------------------------------------===// void ReturnOp::build(Builder *builder, OperationState *result, - ArrayRef results) { + ArrayRef results) { result->addOperands(results); } @@ -465,9 +462,10 @@ void ReturnOp::print(OpAsmPrinter *p) const { *p << ' '; p->printOperands(operand_begin(), operand_end()); *p << " : "; - interleave(operand_begin(), operand_end(), - [&](const SSAValue *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); + 
interleave( + operand_begin(), operand_end(), + [&](const Value *e) { p->printType(e->getType()); }, + [&]() { *p << ", "; }); } } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 3a537d03e8f..6f22b854fbf 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -23,7 +23,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Statements.h" - using namespace mlir; /// Form the OperationName for an op with the specified string. This either is @@ -96,13 +95,13 @@ unsigned Operation::getNumOperands() const { return llvm::cast(this)->getNumOperands(); } -SSAValue *Operation::getOperand(unsigned idx) { +Value *Operation::getOperand(unsigned idx) { return llvm::cast(this)->getOperand(idx); } -void Operation::setOperand(unsigned idx, SSAValue *value) { +void Operation::setOperand(unsigned idx, Value *value) { auto *stmt = llvm::cast(this); - stmt->setOperand(idx, llvm::cast(value)); + stmt->setOperand(idx, value); } /// Return the number of results this operation has. @@ -111,7 +110,7 @@ unsigned Operation::getNumResults() const { } /// Return the indicated result. -SSAValue *Operation::getResult(unsigned idx) { +Value *Operation::getResult(unsigned idx) { return llvm::cast(this)->getResult(idx); } @@ -585,8 +584,8 @@ bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { // These functions are out-of-line implementations of the methods in BinaryOp, // which avoids them being template instantiated/duplicated. -void impl::buildBinaryOp(Builder *builder, OperationState *result, - SSAValue *lhs, SSAValue *rhs) { +void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, + Value *rhs) { assert(lhs->getType() == rhs->getType()); result->addOperands({lhs, rhs}); result->types.push_back(lhs->getType()); @@ -613,8 +612,8 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { // CastOp implementation //===----------------------------------------------------------------------===// -void impl::buildCastOp(Builder *builder, OperationState *result, - SSAValue *source, Type destType) { +void impl::buildCastOp(Builder *builder, OperationState *result, Value *source, + Type destType) { result->addOperands(source); result->addTypes(destType); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index a8b6aa1e738..9e4d8bb180c 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -16,8 +16,8 @@ // ============================================================================= #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Statements.h" +#include "mlir/IR/Value.h" using namespace mlir; PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { @@ -77,8 +77,8 @@ PatternRewriter::~PatternRewriter() { /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those ops are dead, this will /// remove them as well. -void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) { +void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, + ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. 
notifyRootReplaced(op); @@ -97,15 +97,14 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Operation *op, Operation *newOp, - ArrayRef valuesToRemoveIfDead) { + Operation *op, Operation *newOp, ArrayRef valuesToRemoveIfDead) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead); - SmallVector newResults(newOp->getResults().begin(), - newOp->getResults().end()); + SmallVector newResults(newOp->getResults().begin(), + newOp->getResults().end()); return replaceOp(op, newResults, valuesToRemoveIfDead); } @@ -118,7 +117,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp( /// should remove if they are dead at this point. /// void PatternRewriter::updatedRootInPlace( - Operation *op, ArrayRef valuesToRemoveIfDead) { + Operation *op, ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootUpdated(op); diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp index 9a26149ea1d..09825093fde 100644 --- a/mlir/lib/IR/SSAValue.cpp +++ b/mlir/lib/IR/SSAValue.cpp @@ -1,4 +1,4 @@ -//===- SSAValue.cpp - MLIR SSAValue Classes ------------===// +//===- SSAValue.cpp - MLIR ValueClasses ------------===// // // Copyright 2019 The MLIR Authors. // @@ -15,15 +15,15 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Function.h" #include "mlir/IR/Statements.h" +#include "mlir/IR/Value.h" using namespace mlir; /// If this value is the result of an Instruction, return the instruction /// that defines it. -OperationInst *SSAValue::getDefiningInst() { +OperationInst *Value::getDefiningInst() { if (auto *result = dyn_cast(this)) return result->getOwner(); return nullptr; @@ -31,13 +31,13 @@ OperationInst *SSAValue::getDefiningInst() { /// If this value is the result of an OperationStmt, return the statement /// that defines it. -OperationStmt *SSAValue::getDefiningStmt() { +OperationStmt *Value::getDefiningStmt() { if (auto *result = dyn_cast(this)) return result->getOwner(); return nullptr; } -Operation *SSAValue::getDefiningOperation() { +Operation *Value::getDefiningOperation() { if (auto *inst = getDefiningInst()) return inst; if (auto *stmt = getDefiningStmt()) @@ -45,14 +45,14 @@ Operation *SSAValue::getDefiningOperation() { return nullptr; } -/// Return the function that this SSAValue is defined in. -Function *SSAValue::getFunction() { +/// Return the function that this Valueis defined in. +Function *Value::getFunction() { switch (getKind()) { - case SSAValueKind::BlockArgument: + case Value::Kind::BlockArgument: return cast(this)->getFunction(); - case SSAValueKind::StmtResult: + case Value::Kind::StmtResult: return getDefiningStmt()->getFunction(); - case SSAValueKind::ForStmt: + case Value::Kind::ForStmt: return cast(this)->getFunction(); } } @@ -89,15 +89,6 @@ MLIRContext *IROperandOwner::getContext() const { } } -//===----------------------------------------------------------------------===// -// MLValue implementation. -//===----------------------------------------------------------------------===// - -/// Return the function that this MLValue is defined in. 
-MLFunction *MLValue::getFunction() { - return cast(static_cast(this)->getFunction()); -} - //===----------------------------------------------------------------------===// // BlockArgument implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 2a47eb56a28..63c2b26425f 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -85,18 +85,16 @@ MLFunction *Statement::getFunction() const { return block ? block->getFunction() : nullptr; } -MLValue *Statement::getOperand(unsigned idx) { - return getStmtOperand(idx).get(); -} +Value *Statement::getOperand(unsigned idx) { return getStmtOperand(idx).get(); } -const MLValue *Statement::getOperand(unsigned idx) const { +const Value *Statement::getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } -// MLValue can be used as a dimension id if it is valid as a symbol, or +// Value can be used as a dimension id if it is valid as a symbol, or // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. -bool MLValue::isValidDim() const { +bool Value::isValidDim() const { if (auto *stmt = getDefiningStmt()) { // Top level statement or constant operation is ok. if (stmt->getParentStmt() == nullptr || stmt->isa()) @@ -111,10 +109,10 @@ bool MLValue::isValidDim() const { return true; } -// MLValue can be used as a symbol if it is a constant, or it is defined at +// Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. -bool MLValue::isValidSymbol() const { +bool Value::isValidSymbol() const { if (auto *stmt = getDefiningStmt()) { // Top level statement or constant operation is ok. if (stmt->getParentStmt() == nullptr || stmt->isa()) @@ -129,7 +127,7 @@ bool MLValue::isValidSymbol() const { return isa(this); } -void Statement::setOperand(unsigned idx, MLValue *value) { +void Statement::setOperand(unsigned idx, Value *value) { getStmtOperand(idx).set(value); } @@ -271,7 +269,7 @@ void Statement::dropAllReferences() { /// Create a new OperationStmt with the specific fields. OperationStmt *OperationStmt::create(Location location, OperationName name, - ArrayRef operands, + ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, ArrayRef successors, @@ -420,8 +418,8 @@ void OperationInst::eraseOperand(unsigned index) { // ForStmt //===----------------------------------------------------------------------===// -ForStmt *ForStmt::create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, +ForStmt *ForStmt::create(Location location, ArrayRef lbOperands, + AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step) { assert(lbOperands.size() == lbMap.getNumInputs() && "lower bound operand count does not match the affine map"); @@ -444,9 +442,9 @@ ForStmt *ForStmt::create(Location location, ArrayRef lbOperands, ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, AffineMap ubMap, int64_t step) - : Statement(Kind::For, location), - MLValue(MLValueKind::ForStmt, - Type::getIndex(lbMap.getResult(0).getContext())), + : Statement(Statement::Kind::For, location), + Value(Value::Kind::ForStmt, + Type::getIndex(lbMap.getResult(0).getContext())), body(this), lbMap(lbMap), ubMap(ubMap), step(step) { // The body of a for stmt always has one block. 
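A minimal, self-contained sketch of the dual role ForStmt keeps playing after this change: the loop statement is itself the Value used as its induction variable. The types below are toy stand-ins, not the MLIR classes.

#include <cstdint>
#include <iostream>

struct Value {
  enum class Kind { BlockArgument, StmtResult, ForStmt };
  explicit Value(Kind kind) : kind(kind) {}
  virtual ~Value() = default;
  Kind kind;
};

struct Statement {
  enum class Kind { Operation, For, If };
  explicit Statement(Kind kind) : kind(kind) {}
  virtual ~Statement() = default;
  Kind kind;
};

// A for-loop node: a Statement in its enclosing block and, at the same time,
// the Value produced for its induction variable.
struct ForStmt : public Statement, public Value {
  ForStmt(int64_t lb, int64_t ub, int64_t step)
      : Statement(Statement::Kind::For), Value(Value::Kind::ForStmt),
        lb(lb), ub(ub), step(step) {}
  int64_t lb, ub, step;
};

int main() {
  ForStmt loop(/*lb=*/0, /*ub=*/128, /*step=*/1);
  Value *iv = &loop;        // the loop used as the induction-variable value
  Statement *stmt = &loop;  // the same object used as a statement
  std::cout << (iv->kind == Value::Kind::ForStmt) << " "
            << (stmt->kind == Statement::Kind::For) << "\n";
  return 0;
}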
@@ -462,11 +460,11 @@ const AffineBound ForStmt::getUpperBound() const { return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); } -void ForStmt::setLowerBound(ArrayRef lbOperands, AffineMap map) { +void ForStmt::setLowerBound(ArrayRef lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector ubOperands(getUpperBoundOperands()); + SmallVector ubOperands(getUpperBoundOperands()); operands.clear(); operands.reserve(lbOperands.size() + ubMap.getNumInputs()); @@ -479,11 +477,11 @@ void ForStmt::setLowerBound(ArrayRef lbOperands, AffineMap map) { this->lbMap = map; } -void ForStmt::setUpperBound(ArrayRef ubOperands, AffineMap map) { +void ForStmt::setUpperBound(ArrayRef ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector lbOperands(getLowerBoundOperands()); + SmallVector lbOperands(getLowerBoundOperands()); operands.clear(); operands.reserve(lbOperands.size() + ubOperands.size()); @@ -553,7 +551,7 @@ bool ForStmt::matchingBoundOperandList() const { unsigned numOperands = lbMap.getNumInputs(); for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare MLValue *'s. + // Compare Value *'s. if (getOperand(i) != getOperand(numOperands + i)) return false; } @@ -581,7 +579,7 @@ IfStmt::~IfStmt() { // allocated through MLIRContext's bump pointer allocator. } -IfStmt *IfStmt::create(Location location, ArrayRef operands, +IfStmt *IfStmt::create(Location location, ArrayRef operands, IntegerSet set) { unsigned numOperands = operands.size(); assert(numOperands == set.getNumOperands() && @@ -617,16 +615,16 @@ MLIRContext *IfStmt::getContext() const { /// them alone if no entry is present). Replaces references to cloned /// sub-statements to the corresponding statement that is copied, and adds /// those mappings to the map. -Statement *Statement::clone(DenseMap &operandMap, +Statement *Statement::clone(DenseMap &operandMap, MLIRContext *context) const { // If the specified value is in operandMap, return the remapped value. // Otherwise return the value itself. - auto remapOperand = [&](const MLValue *value) -> MLValue * { + auto remapOperand = [&](const Value *value) -> Value * { auto it = operandMap.find(value); - return it != operandMap.end() ? it->second : const_cast(value); + return it != operandMap.end() ? it->second : const_cast(value); }; - SmallVector operands; + SmallVector operands; SmallVector successors; if (auto *opStmt = dyn_cast(this)) { operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); @@ -683,10 +681,9 @@ Statement *Statement::clone(DenseMap &operandMap, auto ubMap = forStmt->getUpperBoundMap(); auto *newFor = ForStmt::create( - getLoc(), - ArrayRef(operands).take_front(lbMap.getNumInputs()), lbMap, - ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, - forStmt->getStep()); + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), + ubMap, forStmt->getStep()); // Remember the induction variable mapping. 
operandMap[forStmt] = newFor; @@ -716,6 +713,6 @@ Statement *Statement::clone(DenseMap &operandMap, } Statement *Statement::clone(MLIRContext *context) const { - DenseMap operandMap; + DenseMap operandMap; return clone(operandMap, context); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index d58d687ee0c..9852b69e91b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -42,7 +42,6 @@ #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include - using namespace mlir; using llvm::MemoryBuffer; using llvm::SMLoc; @@ -1890,10 +1889,10 @@ public: /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. - SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type); + Value *resolveSSAUse(SSAUseInfo useInfo, Type type); /// Register a definition of a value with the symbol table. - ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value); + ParseResult addDefinition(SSAUseInfo useInfo, Value *value); // SSA parsing productions. ParseResult parseSSAUse(SSAUseInfo &result); @@ -1903,9 +1902,9 @@ public: ResultType parseSSADefOrUseAndType( const std::function &action); - SSAValue *parseSSAUseAndType() { - return parseSSADefOrUseAndType( - [&](SSAUseInfo useInfo, Type type) -> SSAValue * { + Value *parseSSAUseAndType() { + return parseSSADefOrUseAndType( + [&](SSAUseInfo useInfo, Type type) -> Value * { return resolveSSAUse(useInfo, type); }); } @@ -1920,9 +1919,8 @@ public: Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc); /// Parse a single operation successor and it's operand list. - virtual bool - parseSuccessorAndUseList(BasicBlock *&dest, - SmallVectorImpl &operands) = 0; + virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + SmallVectorImpl &operands) = 0; protected: FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {} @@ -1934,24 +1932,23 @@ private: Kind kind; /// This keeps track of all of the SSA values we are tracking, indexed by /// their name. This has one entry per result number. - llvm::StringMap, 1>> values; + llvm::StringMap, 1>> values; /// These are all of the placeholders we've made along with the location of /// their first reference, to allow checking for use of undefined values. - DenseMap forwardReferencePlaceholders; + DenseMap forwardReferencePlaceholders; - SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type); + Value *createForwardReferencePlaceholder(SMLoc loc, Type type); /// Return true if this is a forward reference. - bool isForwardReferencePlaceholder(SSAValue *value) { + bool isForwardReferencePlaceholder(Value *value) { return forwardReferencePlaceholders.count(value); } }; } // end anonymous namespace /// Create and remember a new placeholder for a forward reference. -SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, - Type type) { +Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) { // Forward references are always created as instructions, even in ML // functions, because we just need something with a def/use chain. // @@ -1969,7 +1966,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, /// Given an unbound reference to an SSA value and its type, return the value /// it specifies. This returns null on failure. 
-SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { +Value *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = values[useInfo.name]; // If we have already seen a value of this name, return it. @@ -2010,7 +2007,7 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { } /// Register a definition of a value with the symbol table. -ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) { +ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) { auto &entries = values[useInfo.name]; // Make sure there is a slot for this value. @@ -2046,7 +2043,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) { // Check for any forward references that are left. If we find any, error // out. if (!forwardReferencePlaceholders.empty()) { - SmallVector, 4> errors; + SmallVector, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. for (auto entry : forwardReferencePlaceholders) errors.push_back({entry.second.getPointer(), entry.first}); @@ -2399,9 +2396,8 @@ public: return false; } - bool - parseSuccessorAndUseList(BasicBlock *&dest, - SmallVectorImpl &operands) override { + bool parseSuccessorAndUseList(BasicBlock *&dest, + SmallVectorImpl &operands) override { // Defer successor parsing to the function parsers. return parser.parseSuccessorAndUseList(dest, operands); } @@ -2493,7 +2489,7 @@ public: llvm::SMLoc getNameLoc() const override { return nameLoc; } bool resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) override { + SmallVectorImpl &result) override { FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; if (auto *value = parser.resolveSSAUse(operandInfo, type)) { @@ -2573,7 +2569,7 @@ public: ParseResult parseFunctionBody(); bool parseSuccessorAndUseList(BasicBlock *&dest, - SmallVectorImpl &operands); + SmallVectorImpl &operands); private: CFGFunction *function; @@ -2636,7 +2632,7 @@ private: /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// bool CFGFunctionParser::parseSuccessorAndUseList( - BasicBlock *&dest, SmallVectorImpl &operands) { + BasicBlock *&dest, SmallVectorImpl &operands) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::bare_identifier)) return emitError("expected basic block name"); @@ -2790,10 +2786,10 @@ private: ParseResult parseForStmt(); ParseResult parseIntConstant(int64_t &val); - ParseResult parseDimAndSymbolList(SmallVectorImpl &operands, + ParseResult parseDimAndSymbolList(SmallVectorImpl &operands, unsigned numDims, unsigned numOperands, const char *affineStructName); - ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, + ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); ParseResult parseElseClause(StmtBlock *elseClause); @@ -2801,7 +2797,7 @@ private: ParseResult parseStmtBlock(StmtBlock *block); bool parseSuccessorAndUseList(BasicBlock *&dest, - SmallVectorImpl &operands) { + SmallVectorImpl &operands) { assert(false && "MLFunctions do not have terminators with successors."); return true; } @@ -2838,7 +2834,7 @@ ParseResult MLFunctionParser::parseForStmt() { return ParseFailure; // Parse lower bound. 
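The parser changes above keep the forward-reference machinery intact while retyping its tables to Value. A minimal sketch of that scheme, using standard containers and toy names rather than the parser's real classes: a use seen before its definition gets a placeholder, and the later definition retires it.

#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

struct Value { std::string name; };

struct SSATable {
  std::map<std::string, Value *> values;     // SSA name -> current value
  std::set<Value *> forwardPlaceholders;     // placeholders still undefined
  std::vector<std::unique_ptr<Value>> storage;

  // Return the known value for 'name', or hand out a placeholder on a
  // forward reference.
  Value *resolveUse(const std::string &name) {
    if (auto it = values.find(name); it != values.end())
      return it->second;
    storage.push_back(std::make_unique<Value>(Value{name + "<fwd>"}));
    Value *placeholder = storage.back().get();
    forwardPlaceholders.insert(placeholder);
    values[name] = placeholder;
    return placeholder;
  }

  // Record a definition; a pending placeholder for 'name' is retired, and a
  // true redefinition is rejected.
  bool addDefinition(const std::string &name, Value *def) {
    auto it = values.find(name);
    if (it != values.end() && !forwardPlaceholders.count(it->second))
      return false;
    if (it != values.end())
      forwardPlaceholders.erase(it->second);
    values[name] = def;
    return true;
  }
};

int main() {
  SSATable table;
  Value realDef{"%0"};
  Value *use = table.resolveUse("%0");           // use before definition
  assert(table.forwardPlaceholders.count(use) == 1);
  assert(table.addDefinition("%0", &realDef));
  assert(table.forwardPlaceholders.empty());     // nothing left dangling
  return 0;
}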
- SmallVector lbOperands; + SmallVector lbOperands; AffineMap lbMap; if (parseBound(lbOperands, lbMap, /*isLower*/ true)) return ParseFailure; @@ -2847,7 +2843,7 @@ ParseResult MLFunctionParser::parseForStmt() { return ParseFailure; // Parse upper bound. - SmallVector ubOperands; + SmallVector ubOperands; AffineMap ubMap; if (parseBound(ubOperands, ubMap, /*isLower*/ false)) return ParseFailure; @@ -2913,7 +2909,7 @@ ParseResult MLFunctionParser::parseIntConstant(int64_t &val) { /// dim-and-symbol-use-list ::= dim-use-list symbol-use-list? /// ParseResult -MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, +MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, unsigned numDims, unsigned numOperands, const char *affineStructName) { if (parseToken(Token::l_paren, "expected '('")) @@ -2942,18 +2938,17 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, // Resolve SSA uses. Type indexType = builder.getIndexType(); for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { - SSAValue *sval = resolveSSAUse(opInfo[i], indexType); + Value *sval = resolveSSAUse(opInfo[i], indexType); if (!sval) return ParseFailure; - auto *v = cast(sval); - if (i < numDims && !v->isValidDim()) + if (i < numDims && !sval->isValidDim()) return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + "' cannot be used as a dimension id"); - if (i >= numDims && !v->isValidSymbol()) + if (i >= numDims && !sval->isValidSymbol()) return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + "' cannot be used as a symbol"); - operands.push_back(v); + operands.push_back(sval); } return ParseSuccess; @@ -2965,7 +2960,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, /// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list /// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal /// -ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, +ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower) { // 'min' / 'max' prefixes are syntactic sugar. Ignore them. if (isLower) @@ -3003,7 +2998,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, // TODO: improve error message when SSA value is not an affine integer. // Currently it is 'use of value ... 
expects different type than prior uses' if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) - operands.push_back(cast(value)); + operands.push_back(value); else return ParseFailure; @@ -3113,7 +3108,7 @@ ParseResult MLFunctionParser::parseIfStmt() { if (!set) return ParseFailure; - SmallVector operands; + SmallVector operands; if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), "integer set")) return ParseFailure; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 9613c56daf0..7611c6e741b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -23,8 +23,8 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -78,8 +78,8 @@ struct MemRefCastFolder : public RewritePattern { // AddFOp //===----------------------------------------------------------------------===// -void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs, - SSAValue *rhs) { +void AddFOp::build(Builder *builder, OperationState *result, Value *lhs, + Value *rhs) { assert(lhs->getType() == rhs->getType()); result->addOperands({lhs, rhs}); result->types.push_back(lhs->getType()); @@ -146,7 +146,7 @@ void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void AllocOp::build(Builder *builder, OperationState *result, - MemRefType memrefType, ArrayRef operands) { + MemRefType memrefType, ArrayRef operands) { result->addOperands(operands); result->types.push_back(memrefType); } @@ -247,8 +247,8 @@ struct SimplifyAllocConst : public RewritePattern { // and keep track of the resultant memref type to build. 
SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); - SmallVector newOperands; - SmallVector droppedOperands; + SmallVector newOperands; + SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { @@ -301,7 +301,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void CallOp::build(Builder *builder, OperationState *result, Function *callee, - ArrayRef operands) { + ArrayRef operands) { result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); result->addTypes(callee->getType().getResults()); @@ -370,7 +370,7 @@ bool CallOp::verify() const { //===----------------------------------------------------------------------===// void CallIndirectOp::build(Builder *builder, OperationState *result, - SSAValue *callee, ArrayRef operands) { + Value *callee, ArrayRef operands) { auto fnType = callee->getType().cast(); result->operands.push_back(callee); result->addOperands(operands); @@ -507,7 +507,7 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { } void CmpIOp::build(Builder *build, OperationState *result, - CmpIPredicate predicate, SSAValue *lhs, SSAValue *rhs) { + CmpIPredicate predicate, Value *lhs, Value *rhs) { result->addOperands({lhs, rhs}); result->types.push_back(getI1SameShape(build, lhs->getType())); result->addAttribute(getPredicateAttrName(), @@ -580,8 +580,7 @@ bool CmpIOp::verify() const { // DeallocOp //===----------------------------------------------------------------------===// -void DeallocOp::build(Builder *builder, OperationState *result, - SSAValue *memref) { +void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) { result->addOperands(memref); } @@ -615,7 +614,7 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void DimOp::build(Builder *builder, OperationState *result, - SSAValue *memrefOrTensor, unsigned index) { + Value *memrefOrTensor, unsigned index) { result->addOperands(memrefOrTensor); auto type = builder->getIndexType(); result->addAttribute("index", builder->getIntegerAttr(type, index)); @@ -689,11 +688,11 @@ Attribute DimOp::constantFold(ArrayRef operands, // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState *result, - SSAValue *srcMemRef, ArrayRef srcIndices, - SSAValue *destMemRef, ArrayRef destIndices, - SSAValue *numElements, SSAValue *tagMemRef, - ArrayRef tagIndices, SSAValue *stride, - SSAValue *elementsPerStride) { + Value *srcMemRef, ArrayRef srcIndices, + Value *destMemRef, ArrayRef destIndices, + Value *numElements, Value *tagMemRef, + ArrayRef tagIndices, Value *stride, + Value *elementsPerStride) { result->addOperands(srcMemRef); result->addOperands(srcIndices); result->addOperands(destMemRef); @@ -836,8 +835,8 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // --------------------------------------------------------------------------- void DmaWaitOp::build(Builder *builder, OperationState *result, - SSAValue *tagMemRef, ArrayRef tagIndices, - SSAValue *numElements) { + Value *tagMemRef, ArrayRef tagIndices, + Value *numElements) { result->addOperands(tagMemRef); result->addOperands(tagIndices); result->addOperands(numElements); @@ -896,8 +895,7 @@ void 
DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void ExtractElementOp::build(Builder *builder, OperationState *result, - SSAValue *aggregate, - ArrayRef indices) { + Value *aggregate, ArrayRef indices) { auto aggregateType = aggregate->getType().cast(); result->addOperands(aggregate); result->addOperands(indices); @@ -955,8 +953,8 @@ bool ExtractElementOp::verify() const { // LoadOp //===----------------------------------------------------------------------===// -void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, - ArrayRef indices) { +void LoadOp::build(Builder *builder, OperationState *result, Value *memref, + ArrayRef indices) { auto memrefType = memref->getType().cast(); result->addOperands(memref); result->addOperands(indices); @@ -1130,9 +1128,8 @@ void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // SelectOp //===----------------------------------------------------------------------===// -void SelectOp::build(Builder *builder, OperationState *result, - SSAValue *condition, SSAValue *trueValue, - SSAValue *falseValue) { +void SelectOp::build(Builder *builder, OperationState *result, Value *condition, + Value *trueValue, Value *falseValue) { result->addOperands({condition, trueValue, falseValue}); result->addTypes(trueValue->getType()); } @@ -1201,8 +1198,8 @@ Attribute SelectOp::constantFold(ArrayRef operands, //===----------------------------------------------------------------------===// void StoreOp::build(Builder *builder, OperationState *result, - SSAValue *valueToStore, SSAValue *memref, - ArrayRef indices) { + Value *valueToStore, Value *memref, + ArrayRef indices) { result->addOperands(valueToStore); result->addOperands(memref); result->addOperands(indices); diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 3b9f5ed1b3a..02b4c4674ab 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -72,10 +72,10 @@ static bool verifyPermutationMap(AffineMap permutationMap, } void VectorTransferReadOp::build(Builder *builder, OperationState *result, - VectorType vectorType, SSAValue *srcMemRef, - ArrayRef srcIndices, + VectorType vectorType, Value *srcMemRef, + ArrayRef srcIndices, AffineMap permutationMap, - Optional paddingValue) { + Optional paddingValue) { result->addOperands(srcMemRef); result->addOperands(srcIndices); if (paddingValue) { @@ -100,21 +100,20 @@ VectorTransferReadOp::getIndices() const { return {begin, end}; } -Optional VectorTransferReadOp::getPaddingValue() { +Optional VectorTransferReadOp::getPaddingValue() { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { return None; } - return Optional( - getOperand(Offsets::FirstIndexOffset + memRefRank)); + return Optional(getOperand(Offsets::FirstIndexOffset + memRefRank)); } -Optional VectorTransferReadOp::getPaddingValue() const { +Optional VectorTransferReadOp::getPaddingValue() const { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { return None; } - return Optional( + return Optional( getOperand(Offsets::FirstIndexOffset + memRefRank)); } @@ -136,7 +135,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) const { // Construct the FunctionType and print it. 
llvm::SmallVector inputs{getMemRefType()}; // Must have at least one actual index, see verify. - const SSAValue *firstIndex = *(getIndices().begin()); + const Value *firstIndex = *(getIndices().begin()); Type indexType = firstIndex->getType(); inputs.append(getMemRefType().getRank(), indexType); if (optionalPaddingValue) { @@ -295,8 +294,8 @@ bool VectorTransferReadOp::verify() const { // VectorTransferWriteOp //===----------------------------------------------------------------------===// void VectorTransferWriteOp::build(Builder *builder, OperationState *result, - SSAValue *srcVector, SSAValue *dstMemRef, - ArrayRef dstIndices, + Value *srcVector, Value *dstMemRef, + ArrayRef dstIndices, AffineMap permutationMap) { result->addOperands({srcVector, dstMemRef}); result->addOperands(dstIndices); @@ -457,7 +456,7 @@ bool VectorTransferWriteOp::verify() const { // VectorTypeCastOp //===----------------------------------------------------------------------===// void VectorTypeCastOp::build(Builder *builder, OperationState *result, - SSAValue *srcVector, Type dstType) { + Value *srcVector, Type dstType) { result->addOperands(srcVector); result->addTypes(dstType); } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 5c325dbd95d..a4d474dc24a 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -111,7 +111,7 @@ private: /// descriptor and get the pointer to the element indexed by the linearized /// subscript. Return nullptr on errors. llvm::Value *emitMemRefElementAccess( - const SSAValue *memRef, const Operation &op, + const Value *memRef, const Operation &op, llvm::iterator_range opIndices); /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create @@ -136,12 +136,12 @@ private: /// Create a single LLVM value of struct type that includes the list of /// given MLIR values. The `values` list must contain at least 2 elements. - llvm::Value *packValues(ArrayRef values); + llvm::Value *packValues(ArrayRef values); /// Extract a list of `num` LLVM values from a `value` of struct type. SmallVector unpackValues(llvm::Value *value, unsigned num); llvm::DenseMap functionMapping; - llvm::DenseMap valueMapping; + llvm::DenseMap valueMapping; llvm::DenseMap blockMapping; llvm::LLVMContext &llvmContext; llvm::IRBuilder builder; @@ -316,7 +316,7 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { } llvm::Value *ModuleLowerer::emitMemRefElementAccess( - const SSAValue *memRef, const Operation &op, + const Value *memRef, const Operation &op, llvm::iterator_range opIndices) { auto type = memRef->getType().dyn_cast(); assert(type && "expected memRef value to have a MemRef type"); @@ -340,7 +340,7 @@ llvm::Value *ModuleLowerer::emitMemRefElementAccess( // Obtain the list of access subscripts as values and linearize it given the // list of sizes. auto indices = functional::map( - [this](const SSAValue *value) { return valueMapping.lookup(value); }, + [this](const Value *value) { return valueMapping.lookup(value); }, opIndices); auto subscript = linearizeSubscripts(indices, sizes); @@ -460,11 +460,11 @@ llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) { } // Create an undef struct value and insert individual values into it. 
-llvm::Value *ModuleLowerer::packValues(ArrayRef values) { +llvm::Value *ModuleLowerer::packValues(ArrayRef values) { assert(values.size() > 1 && "cannot pack less than 2 values"); auto types = - functional::map([](const SSAValue *v) { return v->getType(); }, values); + functional::map([](const Value *v) { return v->getType(); }, values); llvm::Type *packedType = getPackedResultType(types); llvm::Value *packed = llvm::UndefValue::get(packedType); @@ -641,7 +641,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { return false; } if (auto dimOp = inst.dyn_cast()) { - const SSAValue *container = dimOp->getOperand(); + const Value *container = dimOp->getOperand(); MemRefType type = container->getType().dyn_cast(); if (!type) return dimOp->emitError("only memref types are supported"); @@ -672,7 +672,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { if (auto callOp = inst.dyn_cast()) { auto operands = functional::map( - [this](const SSAValue *value) { return valueMapping.lookup(value); }, + [this](const Value *value) { return valueMapping.lookup(value); }, callOp->getOperands()); auto numResults = callOp->getNumResults(); llvm::Value *result = @@ -779,10 +779,9 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const SSAValue *getPHISourceValue(const BasicBlock *current, - const BasicBlock *pred, - unsigned numArguments, - unsigned index) { +static const Value *getPHISourceValue(const BasicBlock *current, + const BasicBlock *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa()) { return terminator.getOperand(index); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index d4a50a05989..53e633f53cd 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -30,13 +30,12 @@ struct ConstantFold : public FunctionPass, StmtWalker { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. - SmallVector existingConstants; + SmallVector existingConstants; // Operation statements that were folded and that need to be erased. std::vector opStmtsToErase; - using ConstantFactoryType = std::function; + using ConstantFactoryType = std::function; - bool foldOperation(Operation *op, - SmallVectorImpl &existingConstants, + bool foldOperation(Operation *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory); void visitOperationStmt(OperationStmt *stmt); void visitForStmt(ForStmt *stmt); @@ -54,9 +53,8 @@ char ConstantFold::passID = 0; /// /// This returns false if the operation was successfully folded. bool ConstantFold::foldOperation(Operation *op, - SmallVectorImpl &existingConstants, + SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory) { - // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. 
if (auto constant = op->dyn_cast()) { @@ -114,7 +112,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { if (!inst) continue; - auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> Value * { builder.setInsertionPoint(inst); return builder.create(inst->getLoc(), value, type); }; @@ -142,7 +140,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { // Override the walker's operation statement visit for constant folding. void ConstantFold::visitOperationStmt(OperationStmt *stmt) { - auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> Value * { MLFuncBuilder builder(stmt); return builder.create(stmt->getLoc(), value, type); }; diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 4423891a4bf..ab8ee28ba7c 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -50,28 +50,28 @@ public: void visitOperationStmt(OperationStmt *opStmt); private: - CFGValue *getConstantIndexValue(int64_t value); + Value *getConstantIndexValue(int64_t value); void visitStmtBlock(StmtBlock *stmtBlock); - CFGValue *buildMinMaxReductionSeq( + Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range values); CFGFunction *cfgFunc; CFGFuncBuilder builder; - // Mapping between original MLValues and lowered CFGValues. - llvm::DenseMap valueRemapping; + // Mapping between original Values and lowered Values. + llvm::DenseMap valueRemapping; }; } // end anonymous namespace -// Return a vector of OperationStmt's arguments as SSAValues. For each -// statement operands, represented as MLValue, lookup its CFGValue conterpart in +// Return a vector of OperationStmt's arguments as Values. For each +// statement operands, represented as Value, lookup its Value conterpart in // the valueRemapping table. -static llvm::SmallVector +static llvm::SmallVector operandsAs(Statement *opStmt, - const llvm::DenseMap &valueRemapping) { - llvm::SmallVector operands; - for (const MLValue *operand : opStmt->getOperands()) { + const llvm::DenseMap &valueRemapping) { + llvm::SmallVector operands; + for (const Value *operand : opStmt->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } @@ -81,8 +81,8 @@ operandsAs(Statement *opStmt, // Convert an operation statement into an operation instruction. // // The operation description (name, number and types of operands or results) -// remains the same but the values must be updated to be CFGValues. Update the -// mapping MLValue->CFGValue as the conversion is performed. The operation +// remains the same but the values must be updated to be Values. Update the +// mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { // Set up basic operation state (context, name, operands). @@ -90,11 +90,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { opStmt->getName()); state.addOperands(operandsAs(opStmt, valueRemapping)); - // Set up operation return types. The corresponding SSAValues will become + // Set up operation return types. The corresponding Values will become // available after the operation is created. 
- state.addTypes( - functional::map([](SSAValue *result) { return result->getType(); }, - opStmt->getResults())); + state.addTypes(functional::map( + [](Value *result) { return result->getType(); }, opStmt->getResults())); // Copy attributes. for (auto attr : opStmt->getAttrs()) { @@ -112,10 +111,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { } } -// Create a CFGValue for the given integer constant of index type. -CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) { +// Create a Value for the given integer constant of index type. +Value *FunctionConverter::getConstantIndexValue(int64_t value) { auto op = builder.create(builder.getUnknownLoc(), value); - return cast(op->getResult()); + return op->getResult(); } // Visit all statements in the given statement block. @@ -135,18 +134,18 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { // Multiple values are scanned in a linear sequence. This creates a data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes. -CFGValue *FunctionConverter::buildMinMaxReductionSeq( +Value *FunctionConverter::buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range values) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); - CFGValue *value = cast(*valueIt++); + Value *value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); auto selectOp = builder.create(loc, cmpOp->getResult(), value, *valueIt); - value = cast(selectOp->getResult()); + value = selectOp->getResult(); } return value; @@ -231,9 +230,9 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // The loop condition block has an argument for loop induction variable. // Create it upfront and make the loop induction variable -> basic block // argument remapping available to the following instructions. ForStatement - // is-a MLValue corresponding to the loop induction variable. + // is-a Value corresponding to the loop induction variable. builder.setInsertionPoint(loopConditionBlock); - CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType()); + Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); valueRemapping.insert(std::make_pair(forStmt, iv)); // Recursively construct loop body region. @@ -251,7 +250,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = builder.create(forStmt->getLoc(), affStepMap, iv); - CFGValue *nextIvValue = cast(stepOp->getResult(0)); + Value *nextIvValue = stepOp->getResult(0); builder.create(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); @@ -260,20 +259,19 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPoint(loopInitBlock); // Compute loop bounds using affine_apply after remapping its operands. 
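For context on the bound handling in the converted loops: the comparison-plus-select chain built by buildMinMaxReductionSeq reduces several candidate bound values one at a time, as a straight-line dependence chain rather than a tree. A minimal plain C++ sketch of that linear reduction (toy code, not the CFG builder API):

#include <cassert>
#include <cstddef>
#include <vector>

// predicate(a, b) == true means "keep a"; passing a > b yields a max
// reduction, a < b yields a min reduction.
template <typename Pred>
long reduceLinear(const std::vector<long> &values, Pred predicate) {
  assert(!values.empty() && "empty min/max chain");
  long value = values.front();
  for (std::size_t i = 1; i < values.size(); ++i) {
    bool keep = predicate(value, values[i]);  // models the CmpIOp
    value = keep ? value : values[i];         // models the SelectOp
  }
  return value;
}

int main() {
  // A lower bound given by max(4, 9, 7), as when a bound map has several
  // results.
  std::vector<long> lowerBounds = {4, 9, 7};
  assert(reduceLinear(lowerBounds, [](long a, long b) { return a > b; }) == 9);
  return 0;
}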
- auto remapOperands = [this](const SSAValue *value) -> SSAValue * { - const MLValue *mlValue = dyn_cast(value); - return valueRemapping.lookup(mlValue); + auto remapOperands = [this](const Value *value) -> Value * { + return valueRemapping.lookup(value); }; auto operands = functional::map(remapOperands, forStmt->getLowerBoundOperands()); auto lbAffineApply = builder.create( forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); - CFGValue *lowerBound = buildMinMaxReductionSeq( + Value *lowerBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); auto ubAffineApply = builder.create( forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); - CFGValue *upperBound = buildMinMaxReductionSeq( + Value *upperBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create(builder.getUnknownLoc(), loopConditionBlock, lowerBound); @@ -281,10 +279,10 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPoint(loopConditionBlock); auto comparisonOp = builder.create( forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); - auto comparisonResult = cast(comparisonOp->getResult()); + auto comparisonResult = comparisonOp->getResult(); builder.create(builder.getUnknownLoc(), comparisonResult, - loopBodyFirstBlock, ArrayRef(), - postLoopBlock, ArrayRef()); + loopBodyFirstBlock, ArrayRef(), + postLoopBlock, ArrayRef()); // Finally, make sure building can continue by setting the post-loop block // (end of loop SESE region) as the insertion point. @@ -401,7 +399,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // If the test succeeds, jump to the next block testing testing the next // conjunct of the condition in the similar way. When all conjuncts have been // handled, jump to the 'then' block instead. - SSAValue *zeroConstant = getConstantIndexValue(0); + Value *zeroConstant = getConstantIndexValue(0); ifConditionExtraBlocks.push_back(thenBlock); for (auto tuple : llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(), @@ -416,16 +414,16 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create( ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); - SSAValue *affResult = affineApplyOp->getResult(0); + Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create( ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); builder.create(ifStmt->getLoc(), comparisonOp->getResult(), - nextBlock, /*trueArgs*/ ArrayRef(), + nextBlock, /*trueArgs*/ ArrayRef(), elseBlock, - /*falseArgs*/ ArrayRef()); + /*falseArgs*/ ArrayRef()); builder.setInsertionPoint(nextBlock); } ifConditionExtraBlocks.pop_back(); @@ -468,10 +466,10 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // of the current region. The SESE invariant allows us to easily handle nested // structures of arbitrary complexity. // -// During the conversion, we maintain a mapping between the MLValues present in -// the original function and their CFGValue images in the function under -// construction. When an MLValue is used, it gets replaced with the -// corresponding CFGValue that has been defined previously. 
The value flow +// During the conversion, we maintain a mapping between the Values present in +// the original function and their Value images in the function under +// construction. When an Value is used, it gets replaced with the +// corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { auto outerBlock = builder.createBlock(); @@ -482,8 +480,8 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { outerBlock->addArguments(mlFunc->getType().getInputs()); assert(mlFunc->getNumArguments() == outerBlock->getNumArguments()); for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) { - const MLValue *mlArgument = mlFunc->getArgument(i); - CFGValue *cfgArgument = outerBlock->getArgument(i); + const Value *mlArgument = mlFunc->getArgument(i); + Value *cfgArgument = outerBlock->getArgument(i); valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 62cf55e37d9..917cd3d0c13 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -76,7 +76,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker { // Map from original memref's to the DMA buffers that their accesses are // replaced with. - DenseMap fastBufferMap; + DenseMap fastBufferMap; // Slow memory space associated with DMAs. const unsigned slowMemorySpace; @@ -195,11 +195,11 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, // Indices to use for the DmaStart op. // Indices for the original memref being DMAed from/to. - SmallVector memIndices; + SmallVector memIndices; // Indices for the faster buffer being DMAed into/from. - SmallVector bufIndices; + SmallVector bufIndices; - SSAValue *zeroIndex = top.create(loc, 0); + Value *zeroIndex = top.create(loc, 0); unsigned rank = memRefType.getRank(); SmallVector fastBufferShape; @@ -226,10 +226,10 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, // DMA generation is being done. const FlatAffineConstraints *cst = region.getConstraints(); auto ids = cst->getIds(); - SmallVector outerIVs; + SmallVector outerIVs; for (unsigned i = rank, e = ids.size(); i < e; i++) { auto id = cst->getIds()[i]; - assert(id.hasValue() && "MLValue id expected"); + assert(id.hasValue() && "Value id expected"); outerIVs.push_back(id.getValue()); } @@ -253,15 +253,15 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, // Set DMA start location for this dimension in the lower memory space // memref. if (auto caf = offset.dyn_cast()) { - memIndices.push_back(cast( - top.create(loc, caf.getValue())->getResult())); + memIndices.push_back( + top.create(loc, caf.getValue())->getResult()); } else { // The coordinate for the start location is just the lower bound along the // corresponding dimension on the memory region (stored in 'offset'). auto map = top.getAffineMap( cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {}); - memIndices.push_back(cast( - b->create(loc, map, outerIVs)->getResult(0))); + memIndices.push_back( + b->create(loc, map, outerIVs)->getResult(0)); } // The fast buffer is DMAed into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); @@ -272,7 +272,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, } // The faster memory space buffer. 
- SSAValue *fastMemRef; + Value *fastMemRef; // Check if a buffer was already created. // TODO(bondhugula): union across all memory op's per buffer. For now assuming @@ -321,8 +321,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, return false; } - SSAValue *stride = nullptr; - SSAValue *numEltPerStride = nullptr; + Value *stride = nullptr; + Value *numEltPerStride = nullptr; if (!strideInfos.empty()) { stride = top.create(loc, strideInfos[0].stride); numEltPerStride = @@ -362,7 +362,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); // *Only* those uses within the body of 'forStmt' are replaced. - replaceAllMemRefUsesWith(memref, cast(fastMemRef), + replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, /*domStmtFilter=*/&*forStmt->getBody()->begin()); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index e3609496cc5..c86eec3d276 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -83,22 +83,22 @@ FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { - access->memref = cast(loadOp->getMemRef()); + access->memref = loadOp->getMemRef(); access->opStmt = loadOrStoreOpStmt; auto loadMemrefType = loadOp->getMemRefType(); access->indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { - access->indices.push_back(cast(index)); + access->indices.push_back(index); } } else { assert(loadOrStoreOpStmt->isa()); auto storeOp = loadOrStoreOpStmt->dyn_cast(); access->opStmt = loadOrStoreOpStmt; - access->memref = cast(storeOp->getMemRef()); + access->memref = storeOp->getMemRef(); auto storeMemrefType = storeOp->getMemRefType(); access->indices.reserve(storeMemrefType.getRank()); for (auto *index : storeOp->getIndices()) { - access->indices.push_back(cast(index)); + access->indices.push_back(index); } } } @@ -178,20 +178,20 @@ public: Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} // Returns the load op count for 'memref'. - unsigned getLoadOpCount(MLValue *memref) { + unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; for (auto *loadOpStmt : loads) { - if (memref == cast(loadOpStmt->cast()->getMemRef())) + if (memref == loadOpStmt->cast()->getMemRef()) ++loadOpCount; } return loadOpCount; } // Returns the store op count for 'memref'. - unsigned getStoreOpCount(MLValue *memref) { + unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; for (auto *storeOpStmt : stores) { - if (memref == cast(storeOpStmt->cast()->getMemRef())) + if (memref == storeOpStmt->cast()->getMemRef()) ++storeOpCount; } return storeOpCount; @@ -203,7 +203,7 @@ public: // The id of the node at the other end of the edge. unsigned id; // The memref on which this edge represents a dependence. - MLValue *memref; + Value *memref; }; // Map from node id to Node. @@ -227,13 +227,13 @@ public: } // Adds an edge from node 'srcId' to node 'dstId' for 'memref'. - void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + void addEdge(unsigned srcId, unsigned dstId, Value *memref) { outEdges[srcId].push_back({dstId, memref}); inEdges[dstId].push_back({srcId, memref}); } // Removes an edge from node 'srcId' to node 'dstId' for 'memref'. 
- void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + void removeEdge(unsigned srcId, unsigned dstId, Value *memref) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); // Remove 'srcId' from 'inEdges[dstId]'. @@ -253,7 +253,7 @@ public: } // Returns the input edge count for node 'id' and 'memref'. - unsigned getInEdgeCount(unsigned id, MLValue *memref) { + unsigned getInEdgeCount(unsigned id, Value *memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) for (auto &inEdge : inEdges[id]) @@ -263,7 +263,7 @@ public: } // Returns the output edge count for node 'id' and 'memref'. - unsigned getOutEdgeCount(unsigned id, MLValue *memref) { + unsigned getOutEdgeCount(unsigned id, Value *memref) { unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) @@ -347,7 +347,7 @@ public: // dependence graph at a different depth. bool MemRefDependenceGraph::init(MLFunction *f) { unsigned id = 0; - DenseMap> memrefAccesses; + DenseMap> memrefAccesses; for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast(&stmt)) { // Create graph node 'id' to represent top-level 'forStmt' and record @@ -360,12 +360,12 @@ bool MemRefDependenceGraph::init(MLFunction *f) { Node node(id++, &stmt); for (auto *opStmt : collector.loadOpStmts) { node.loads.push_back(opStmt); - auto *memref = cast(opStmt->cast()->getMemRef()); + auto *memref = opStmt->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opStmt : collector.storeOpStmts) { node.stores.push_back(opStmt); - auto *memref = cast(opStmt->cast()->getMemRef()); + auto *memref = opStmt->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } nodes.insert({node.id, node}); @@ -375,7 +375,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // Create graph node for top-level load op. Node node(id++, &stmt); node.loads.push_back(opStmt); - auto *memref = cast(opStmt->cast()->getMemRef()); + auto *memref = opStmt->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } @@ -383,7 +383,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // Create graph node for top-level store op. Node node(id++, &stmt); node.stores.push_back(opStmt); - auto *memref = cast(opStmt->cast()->getMemRef()); + auto *memref = opStmt->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } @@ -477,8 +477,7 @@ public: SmallVector loads = dstNode->loads; while (!loads.empty()) { auto *dstLoadOpStmt = loads.pop_back_val(); - auto *memref = - cast(dstLoadOpStmt->cast()->getMemRef()); + auto *memref = dstLoadOpStmt->cast()->getMemRef(); // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. 
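As a reference for the graph bookkeeping being retyped here, a minimal sketch of a dependence graph whose edges are keyed by the memref that carries the dependence, so edge counts can be queried per memref when deciding whether two loop nests may be fused. The container choices and string-keyed memrefs below are simplifications, not the pass's actual data structures.

#include <cassert>
#include <map>
#include <string>
#include <vector>

struct MemRefDependenceGraph {
  struct Edge {
    unsigned id;          // node at the other end of the edge
    std::string memref;   // memref carrying the dependence (a Value* upstream)
  };
  std::map<unsigned, std::vector<Edge>> inEdges, outEdges;

  // Record a dependence from node 'srcId' to node 'dstId' on 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, const std::string &memref) {
    outEdges[srcId].push_back({dstId, memref});
    inEdges[dstId].push_back({srcId, memref});
  }

  // Count incoming dependences on 'memref' for node 'id'.
  unsigned getInEdgeCount(unsigned id, const std::string &memref) const {
    unsigned count = 0;
    if (auto it = inEdges.find(id); it != inEdges.end())
      for (const Edge &edge : it->second)
        if (edge.memref == memref)
          ++count;
    return count;
  }
};

int main() {
  MemRefDependenceGraph graph;
  graph.addEdge(/*srcId=*/0, /*dstId=*/1, "%A");  // store in nest 0, load in nest 1
  graph.addEdge(/*srcId=*/0, /*dstId=*/1, "%B");
  assert(graph.getInEdgeCount(1, "%A") == 1);
  assert(graph.getInEdgeCount(1, "%C") == 0);
  return 0;
}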
if (dstNode->getLoadOpCount(memref) != 1) continue; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index b5c12865790..5f49ed217a2 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -85,10 +85,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, for (unsigned i = 0; i < width; i++) { auto lbOperands = origLoops[i]->getLowerBoundOperands(); auto ubOperands = origLoops[i]->getUpperBoundOperands(); - SmallVector newLbOperands(lbOperands.begin(), - lbOperands.end()); - SmallVector newUbOperands(ubOperands.begin(), - ubOperands.end()); + SmallVector newLbOperands(lbOperands.begin(), lbOperands.end()); + SmallVector newUbOperands(ubOperands.begin(), ubOperands.end()); newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap()); newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap()); newLoops[i]->setStep(tileSizes[i]); @@ -112,8 +110,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, // Construct the upper bound map; the operands are the original operands // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. - SmallVector ubOperands( - origLoops[i]->getUpperBoundOperands()); + SmallVector ubOperands(origLoops[i]->getUpperBoundOperands()); ubOperands.push_back(newLoops[i]); auto origUbMap = origLoops[i]->getUpperBoundMap(); @@ -191,8 +188,8 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector origLoopIVs(band.begin(), band.end()); - SmallVector, 6> ids(band.begin(), band.end()); + SmallVector origLoopIVs(band.begin(), band.end()); + SmallVector, 6> ids(band.begin(), band.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index ffff1c5b615..2a121529ed9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -191,7 +191,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - DenseMap operandMap; + DenseMap operandMap; // Insert the cleanup loop right after 'forStmt'. MLFuncBuilder builder(forStmt->getBlock(), std::next(StmtBlock::iterator(forStmt))); @@ -219,7 +219,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Unroll and jam (appends unrollJamFactor-1 additional copies). for (unsigned i = 1; i < unrollJamFactor; i++) { - DenseMap operandMapping; + DenseMap operandMapping; // If the induction variable is used, create a remapping to the value for // this unrolled instance. @@ -230,7 +230,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { auto *ivUnroll = builder.create(forStmt->getLoc(), bumpMap, forStmt) ->getResult(0); - operandMapping[forStmt] = cast(ivUnroll); + operandMapping[forStmt] = ivUnroll; } // Clone the sub-block being unroll-jammed. 
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index fd07619a165..013b5080367 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -29,17 +29,14 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" @@ -62,26 +59,26 @@ using namespace mlir; #define DEBUG_TYPE "lower-vector-transfers" -/// Creates the SSAValue for the sum of `a` and `b` without building a +/// Creates the Value for the sum of `a` and `b` without building a /// full-fledged AffineMap for all indices. /// /// Prerequisites: /// `a` and `b` must be of IndexType. -static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) { +static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) { assert(v->getType().isa() && "v must be of IndexType"); assert(w->getType().isa() && "w must be of IndexType"); auto *context = b->getContext(); auto d0 = getAffineDimExpr(0, context); auto d1 = getAffineDimExpr(1, context); auto map = AffineMap::get(2, 0, {d0 + d1}, {}); - return b->create(loc, map, ArrayRef{v, w}) + return b->create(loc, map, ArrayRef{v, w}) ->getResult(0); } namespace { struct LowerVectorTransfersState : public MLFuncGlobalLoweringState { // Top of the function constant zero index. - SSAValue *zero; + Value *zero; }; } // namespace @@ -131,7 +128,8 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // case of GPUs. if (std::is_same::value) { b.create(vecView->getLoc(), transfer->getVector(), - vecView->getResult(), ArrayRef{state->zero}); + vecView->getResult(), + ArrayRef{state->zero}); } // 3. Emit the loop-nest. @@ -140,7 +138,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // TODO(ntv): Handle broadcast / slice properly. auto permutationMap = transfer->getPermutationMap(); SetVector loops; - SmallVector accessIndices(transfer->getIndices()); + SmallVector accessIndices(transfer->getIndices()); for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) { auto composed = composeWithUnboundedMap( getAffineDimExpr(it.index(), b.getContext()), permutationMap); @@ -168,17 +166,16 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // b. write scalar to local. auto scalarLoad = b.create(transfer->getLoc(), transfer->getMemRef(), accessIndices); - b.create( - transfer->getLoc(), scalarLoad->getResult(), - tmpScalarAlloc->getResult(), - functional::map([](SSAValue *val) { return val; }, loops)); + b.create(transfer->getLoc(), scalarLoad->getResult(), + tmpScalarAlloc->getResult(), + functional::map([](Value *val) { return val; }, loops)); } else { // VectorTransferWriteOp. // a. read scalar from local; // b. write scalar to remote. 
auto scalarLoad = b.create( transfer->getLoc(), tmpScalarAlloc->getResult(), - functional::map([](SSAValue *val) { return val; }, loops)); + functional::map([](Value *val) { return val; }, loops)); b.create(transfer->getLoc(), scalarLoad->getResult(), transfer->getMemRef(), accessIndices); } @@ -186,11 +183,11 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // 5. Read the vector from local storage in case of a vector_transfer_read. // TODO(ntv): This vector_load operation should be further lowered in the // case of GPUs. - llvm::SmallVector newResults = {}; + llvm::SmallVector newResults = {}; if (std::is_same::value) { b.setInsertionPoint(cast(transfer->getOperation())); auto *vector = b.create(transfer->getLoc(), vecView->getResult(), - ArrayRef{state->zero}) + ArrayRef{state->zero}) ->getResult(); newResults.push_back(vector); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index faea9953d86..a12c563fe1a 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -32,9 +32,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -192,7 +190,7 @@ struct MaterializationState { VectorType superVectorType; VectorType hwVectorType; SmallVector hwVectorInstance; - DenseMap *substitutionsMap; + DenseMap *substitutionsMap; }; struct MaterializeVectorsPass : public FunctionPass { @@ -250,9 +248,9 @@ static SmallVector delinearize(unsigned linearIndex, static OperationStmt * instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, - DenseMap *substitutionsMap); + DenseMap *substitutionsMap); -/// Not all SSAValue belong to a program slice scoped within the immediately +/// Not all Values belong to a program slice scoped within the immediately /// enclosing loop. /// One simple example is constants defined outside the innermost loop scope. /// For such cases the substitutionsMap has no entry and we allow an additional @@ -261,17 +259,16 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, /// indices and will need to be extended in the future. /// /// If substitution fails, returns nullptr. -static MLValue * -substitute(SSAValue *v, VectorType hwVectorType, - DenseMap *substitutionsMap) { - auto it = substitutionsMap->find(cast(v)); +static Value *substitute(Value *v, VectorType hwVectorType, + DenseMap *substitutionsMap) { + auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { auto *opStmt = cast(v->getDefiningOperation()); if (opStmt->isa()) { MLFuncBuilder b(opStmt); auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); - auto res = substitutionsMap->insert( - std::make_pair(cast(v), cast(inst->getResult(0)))); + auto res = + substitutionsMap->insert(std::make_pair(v, inst->getResult(0))); assert(res.second && "Insertion failed"); return res.first->second; } @@ -336,10 +333,10 @@ substitute(SSAValue *v, VectorType hwVectorType, /// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. 
-static SmallVector +static SmallVector reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, ArrayRef hwVectorInstance, - ArrayRef memrefIndices) { + ArrayRef memrefIndices) { auto vectorShape = hwVectorType.getShape(); assert(hwVectorInstance.size() >= vectorShape.size()); @@ -380,7 +377,7 @@ reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, // TODO(ntv): support a concrete map and composition. auto app = b->create(b->getInsertionPoint()->getLoc(), affineMap, memrefIndices); - return SmallVector{app->getResults()}; + return SmallVector{app->getResults()}; } /// Returns attributes with the following substitutions applied: @@ -402,21 +399,21 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { /// Creates an instantiated version of `opStmt`. /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no -/// affine reindexing. Just substitute their SSAValue* operands and be done. For -/// this case the actual instance is irrelevant. Just use the SSA values in +/// affine reindexing. Just substitute their Value operands and be done. For +/// this case the actual instance is irrelevant. Just use the values in /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. static OperationStmt * instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, - DenseMap *substitutionsMap) { + DenseMap *substitutionsMap) { assert(!opStmt->isa() && "Should call the function specialized for VectorTransferReadOp"); assert(!opStmt->isa() && "Should call the function specialized for VectorTransferWriteOp"); bool fail = false; auto operands = map( - [hwVectorType, substitutionsMap, &fail](SSAValue *v) -> SSAValue * { + [hwVectorType, substitutionsMap, &fail](Value *v) -> Value * { auto *res = fail ? 
nullptr : substitute(v, hwVectorType, substitutionsMap); fail |= !res; @@ -481,9 +478,9 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, static OperationStmt * instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { - SmallVector indices = - map(makePtrDynCaster(), read->getIndices()); + DenseMap *substitutionsMap) { + SmallVector indices = + map(makePtrDynCaster(), read->getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b->create( @@ -501,9 +498,9 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, static OperationStmt * instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { - SmallVector indices = - map(makePtrDynCaster(), write->getIndices()); + DenseMap *substitutionsMap) { + SmallVector indices = + map(makePtrDynCaster(), write->getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b->create( @@ -555,8 +552,8 @@ static bool instantiateMaterialization(Statement *stmt, } else if (auto read = opStmt->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); - state->substitutionsMap->insert(std::make_pair( - cast(read->getResult()), cast(clone->getResult(0)))); + state->substitutionsMap->insert( + std::make_pair(read->getResult(), clone->getResult(0))); return false; } // The only op with 0 results reaching this point must, by construction, be @@ -571,8 +568,8 @@ static bool instantiateMaterialization(Statement *stmt, if (!clone) { return true; } - state->substitutionsMap->insert(std::make_pair( - cast(opStmt->getResult(0)), cast(clone->getResult(0)))); + state->substitutionsMap->insert( + std::make_pair(opStmt->getResult(0), clone->getResult(0))); return false; } @@ -610,7 +607,7 @@ static bool emitSlice(MaterializationState *state, // Fresh RAII instanceIndices and substitutionsMap. MaterializationState scopedState = *state; scopedState.hwVectorInstance = delinearize(idx, *ratio); - DenseMap substitutionMap; + DenseMap substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. for (auto *stmt : *slice) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 13d3ea92307..de1952ca0f5 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -32,7 +32,6 @@ #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" - #define DEBUG_TYPE "pipeline-data-transfer" using namespace mlir; @@ -80,7 +79,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' statement modulo 2. Returns false if such /// a replacement cannot be performed. 
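// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the patch] What doubleBuffer() does to the
// buffer, modeled with plain containers: the memref type gains a new leading
// dimension of size 2, and every access inside the pipelined loop indexes that
// dimension with iv mod 2. Shapes are vectors of extents here; the helper
// names are invented for illustration and are not MLIR APIs.
#include <cassert>
#include <cstdint>
#include <vector>

// memref<128x64xf32> becomes memref<2x128x64xf32>.
std::vector<int64_t> doubleBufferShape(const std::vector<int64_t> &oldShape) {
  std::vector<int64_t> newShape;
  newShape.reserve(oldShape.size() + 1);
  newShape.push_back(2);
  newShape.insert(newShape.end(), oldShape.begin(), oldShape.end());
  return newShape;
}

// The extra leading index used by every rewritten access: even and odd
// iterations alternate buffers, so the DMA for iteration iv+1 can overlap
// with the compute on iteration iv.
int64_t leadingIndex(int64_t iv) { return iv % 2; }

int main() {
  auto doubled = doubleBufferShape({128, 64});
  assert(doubled.size() == 3 && doubled[0] == 2);
  assert(leadingIndex(4) == 0 && leadingIndex(5) == 1);
  return 0;
}
// ---------------------------------------------------------------------------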
-static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { +static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { auto *forBody = forStmt->getBody(); MLFuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -103,7 +102,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { // Put together alloc operands for the dynamic dimensions of the memref. MLFuncBuilder bOuter(forStmt); - SmallVector allocOperands; + SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) @@ -114,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { // Create and place the alloc right before the 'for' statement. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. - SSAValue *newMemRef = + Value *newMemRef = bOuter.create(forStmt->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. @@ -126,8 +125,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { // replaceAllMemRefUsesWith will always succeed unless the forStmt body has // non-deferencing uses of the memref. - if (!replaceAllMemRefUsesWith(oldMemRef, cast(newMemRef), - ivModTwoOp->getResult(0), AffineMap::Null(), {}, + if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0), + AffineMap::Null(), {}, &*forStmt->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); @@ -225,8 +224,7 @@ static void findMatchingStartFinishStmts( continue; // We only double buffer if the buffer is not live out of loop. - const MLValue *memref = - cast(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); + auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { @@ -280,8 +278,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // dimension. for (auto &pair : startWaitPairs) { auto *dmaStartStmt = pair.first; - MLValue *oldMemRef = cast(dmaStartStmt->getOperand( - dmaStartStmt->cast()->getFasterMemPos())); + Value *oldMemRef = dmaStartStmt->getOperand( + dmaStartStmt->cast()->getFasterMemPos()); if (!doubleBuffer(oldMemRef, forStmt)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. @@ -302,8 +300,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // Double the buffers for tag memrefs. for (auto &pair : startWaitPairs) { auto *dmaFinishStmt = pair.second; - MLValue *oldTagMemRef = cast( - dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt))); + Value *oldTagMemRef = + dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)); if (!doubleBuffer(oldTagMemRef, forStmt)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); @@ -332,7 +330,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. 
SmallVector affineApplyStmts; - SmallVector operands(dmaStartStmt->getOperands()); + SmallVector operands(dmaStartStmt->getOperands()); getReachableAffineApplyOps(operands, affineApplyStmts); for (const auto *stmt : affineApplyStmts) { stmtShiftMap[stmt] = 0; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 0af7e52b5b1..9d955fb6a81 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -217,7 +217,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // If we already have a canonicalized version of this constant, just // reuse it. Otherwise create a new one. - SSAValue *cstValue; + Value *cstValue; auto it = uniquedConstants.find({resultConstants[i], res->getType()}); if (it != uniquedConstants.end()) cstValue = it->second->getResult(0); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 5a5617f3fb1..e8fc5e7ca14 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -31,7 +31,6 @@ #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" - #define DEBUG_TYPE "LoopUtils" using namespace mlir; @@ -108,8 +107,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { forStmt->replaceAllUsesWith(constOp); } else { const AffineBound lb = forStmt->getLowerBound(); - SmallVector lbOperands(lb.operand_begin(), - lb.operand_end()); + SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt)); auto affineApplyOp = builder.create( forStmt->getLoc(), lb.getMap(), lbOperands); @@ -149,8 +147,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &stmtGroupQueue, unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) { - SmallVector lbOperands(srcForStmt->getLowerBoundOperands()); - SmallVector ubOperands(srcForStmt->getUpperBoundOperands()); + SmallVector lbOperands(srcForStmt->getLowerBoundOperands()); + SmallVector ubOperands(srcForStmt->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); @@ -176,7 +174,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, srcForStmt->getStep() * shift)), loopChunk) ->getResult(0); - operandMap[srcForStmt] = cast(ivRemap); + operandMap[srcForStmt] = ivRemap; } else { operandMap[srcForStmt] = loopChunk; } @@ -380,7 +378,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { - DenseMap operandMap; + DenseMap operandMap; MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); auto *cleanupForStmt = cast(builder.clone(*forStmt, operandMap)); auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); @@ -414,7 +412,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { - DenseMap operandMap; + DenseMap operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. 
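// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the patch] The control structure that
// loopUnrollByFactor() produces, written as plain C++ over a [lb, ub) loop
// with unit step and factor >= 1. Each unrolled copy of the body sees the
// remapped induction variable iv + i (the analogue of the
// operandMap[forStmt] = ivUnroll entries in the hunks around this point), and
// a cleanup loop runs the remaining tripCount % factor iterations.
#include <cstdint>
#include <functional>

void unrollByFactor(int64_t lb, int64_t ub, int64_t factor,
                    const std::function<void(int64_t)> &body) {
  int64_t tripCount = ub - lb;
  int64_t mainUb = lb + (tripCount / factor) * factor;
  // Main loop: 'factor' copies of the body per iteration.
  for (int64_t iv = lb; iv < mainUb; iv += factor)
    for (int64_t i = 0; i < factor; ++i)
      body(iv + i);
  // Cleanup loop when the trip count is not a multiple of the factor.
  for (int64_t iv = mainUb; iv < ub; ++iv)
    body(iv);
}
// ---------------------------------------------------------------------------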
@@ -425,7 +423,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { auto *ivUnroll = builder.create(forStmt->getLoc(), bumpMap, forStmt) ->getResult(0); - operandMap[forStmt] = cast(ivUnroll); + operandMap[forStmt] = ivUnroll; } // Clone the original body of 'forStmt'. diff --git a/mlir/lib/Transforms/Utils/LoweringUtils.cpp b/mlir/lib/Transforms/Utils/LoweringUtils.cpp index c8ac881dba7..8457ce4ce28 100644 --- a/mlir/lib/Transforms/Utils/LoweringUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoweringUtils.cpp @@ -32,17 +32,17 @@ using namespace mlir; namespace { // Visit affine expressions recursively and build the sequence of instructions -// that correspond to it. Visitation functions return an SSAValue of the +// that correspond to it. Visitation functions return an Value of the // expression subtree they visited or `nullptr` on error. class AffineApplyExpander - : public AffineExprVisitor { + : public AffineExprVisitor { public: // This internal clsas expects arguments to be non-null, checks must be // performed at the call site. AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op) : builder(*builder), applyOp(*op), loc(op->getLoc()) {} - template SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) { + template Value *buildBinaryExpr(AffineBinaryOpExpr expr) { auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) @@ -51,33 +51,33 @@ public: return op->getResult(); } - SSAValue *visitAddExpr(AffineBinaryOpExpr expr) { + Value *visitAddExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } - SSAValue *visitMulExpr(AffineBinaryOpExpr expr) { + Value *visitMulExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } // TODO(zinenko): implement when the standard operators are made available. - SSAValue *visitModExpr(AffineBinaryOpExpr) { + Value *visitModExpr(AffineBinaryOpExpr) { builder.getContext()->emitError(loc, "unsupported binary operator: mod"); return nullptr; } - SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) { + Value *visitFloorDivExpr(AffineBinaryOpExpr) { builder.getContext()->emitError(loc, "unsupported binary operator: floor_div"); return nullptr; } - SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) { + Value *visitCeilDivExpr(AffineBinaryOpExpr) { builder.getContext()->emitError(loc, "unsupported binary operator: ceil_div"); return nullptr; } - SSAValue *visitConstantExpr(AffineConstantExpr expr) { + Value *visitConstantExpr(AffineConstantExpr expr) { auto valueAttr = builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); auto op = @@ -85,7 +85,7 @@ public: return op->getResult(); } - SSAValue *visitDimExpr(AffineDimExpr expr) { + Value *visitDimExpr(AffineDimExpr expr) { assert(expr.getPosition() < applyOp.getNumOperands() && "affine dim position out of range"); // FIXME: this assumes a certain order of AffineApplyOp operands, the @@ -93,7 +93,7 @@ public: return applyOp.getOperand(expr.getPosition()); } - SSAValue *visitSymbolExpr(AffineSymbolExpr expr) { + Value *visitSymbolExpr(AffineSymbolExpr expr) { // FIXME: this assumes a certain order of AffineApplyOp operands, the // cleaner interface would be to separate them at the op level. assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() < @@ -114,8 +114,8 @@ private: // Given an affine expression `expr` extracted from `op`, build the sequence of // primitive instructions that correspond to the affine expression in the // `builder`. 
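// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the patch] The recursion the
// AffineApplyExpander above performs, shown on a toy expression tree: add and
// mul recurse on both operands, dims read the apply's operands, constants
// materialize immediates, and mod/floordiv/ceildiv are rejected exactly as in
// the visitor's unsupported cases. ToyAffineExpr and expand() are invented
// names; the real code builds ops with a FuncBuilder instead of computing
// integers.
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

struct ToyAffineExpr {
  enum Kind { Add, Mul, Dim, Constant, Mod } kind;
  int64_t value = 0; // constant value, or dim position for Kind::Dim
  std::unique_ptr<ToyAffineExpr> lhs, rhs;
};

std::optional<int64_t> expand(const ToyAffineExpr &e,
                              const std::vector<int64_t> &operands) {
  switch (e.kind) {
  case ToyAffineExpr::Constant:
    return e.value;
  case ToyAffineExpr::Dim:
    return operands.at(static_cast<size_t>(e.value));
  case ToyAffineExpr::Add:
  case ToyAffineExpr::Mul: {
    auto l = expand(*e.lhs, operands), r = expand(*e.rhs, operands);
    if (!l || !r)
      return std::nullopt; // propagate failure from a subtree
    return e.kind == ToyAffineExpr::Add ? *l + *r : *l * *r;
  }
  case ToyAffineExpr::Mod:
    return std::nullopt; // "unsupported binary operator: mod"
  }
  return std::nullopt;
}
// ---------------------------------------------------------------------------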
-static SSAValue *expandAffineExpr(FuncBuilder *builder, AffineExpr expr, - AffineApplyOp *op) { +static mlir::Value *expandAffineExpr(FuncBuilder *builder, AffineExpr expr, + AffineApplyOp *op) { auto expander = AffineApplyExpander(builder, op); return expander.visit(expr); } @@ -127,7 +127,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) { FuncBuilder builder(op->getOperation()); auto affineMap = op->getAffineMap(); for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) { - SSAValue *expanded = expandAffineExpr(&builder, numberedExpr.value(), op); + Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op); if (!expanded) return true; op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 2818e8c2e4f..624a8a758b5 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -31,7 +31,6 @@ #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" - using namespace mlir; /// Return true if this operation dereferences one or more memref's. @@ -61,13 +60,12 @@ static bool isMemRefDereferencingOp(const Operation &op) { // extra operands, note that 'indexRemap' would just be applied to the existing // indices (%i, %j). // -// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily +// TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily // extended to add additional indices at any position. -bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, - MLValue *newMemRef, - ArrayRef extraIndices, +bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, + ArrayRef extraIndices, AffineMap indexRemap, - ArrayRef extraOperands, + ArrayRef extraOperands, const Statement *domStmtFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode @@ -128,16 +126,15 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, // operation. assert(extraIndex->getDefiningStmt()->getNumResults() == 1 && "single result op's expected to generate these indices"); - assert((cast(extraIndex)->isValidDim() || - cast(extraIndex)->isValidSymbol()) && + assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) && "invalid memory op index"); - state.operands.push_back(cast(extraIndex)); + state.operands.push_back(extraIndex); } // Construct new indices as a remap of the old ones if a remapping has been // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. - SmallVector remapOperands; + SmallVector remapOperands; remapOperands.reserve(oldMemRefRank + extraOperands.size()); remapOperands.insert(remapOperands.end(), extraOperands.begin(), extraOperands.end()); @@ -149,11 +146,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, remapOperands); // Remapped indices. for (auto *index : remapOp->getOperation()->getResults()) - state.operands.push_back(cast(index)); + state.operands.push_back(index); } else { // No remapping specified. for (auto *index : remapOperands) - state.operands.push_back(cast(index)); + state.operands.push_back(index); } // Insert the remaining operands unmodified. @@ -191,9 +188,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, // composed AffineApplyOp are returned in output parameter 'results'. 
OperationStmt * mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, - ArrayRef operands, + ArrayRef operands, ArrayRef affineApplyOps, - SmallVectorImpl *results) { + SmallVectorImpl *results) { // Create identity map with same number of dimensions as number of operands. auto map = builder->getMultiDimIdentityMap(operands.size()); // Initialize AffineValueMap with identity map. @@ -208,7 +205,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, // Compose affine maps from all ancestor AffineApplyOps. // Create new AffineApplyOp from 'valueMap'. unsigned numOperands = valueMap.getNumOperands(); - SmallVector outOperands(numOperands); + SmallVector outOperands(numOperands); for (unsigned i = 0; i < numOperands; ++i) { outOperands[i] = valueMap.getOperand(i); } @@ -252,7 +249,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// otherwise. OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { // Collect all operands that are results of affine apply ops. - SmallVector subOperands; + SmallVector subOperands; subOperands.reserve(opStmt->getNumOperands()); for (auto *operand : opStmt->getOperands()) { auto *defStmt = operand->getDefiningStmt(); @@ -285,7 +282,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { return nullptr; FuncBuilder builder(opStmt); - SmallVector results; + SmallVector results; auto *affineApplyStmt = createComposedAffineApplyOp( &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results); assert(results.size() == subOperands.size() && @@ -295,7 +292,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { // affine apply op above instead of existing ones (subOperands). So, they // differ from opStmt's operands only for those operands in 'subOperands', for // which they will be replaced by the corresponding one from 'results'. - SmallVector newOperands(opStmt->getOperands()); + SmallVector newOperands(opStmt->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. unsigned j, f; @@ -304,7 +301,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { break; } if (j < subOperands.size()) { - newOperands[i] = cast(results[j]); + newOperands[i] = results[j]; } } @@ -326,7 +323,7 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { // into any uses which are AffineApplyOps. for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; ++resultIndex) { - const MLValue *result = opStmt->getResult(resultIndex); + const Value *result = opStmt->getResult(resultIndex); for (auto it = result->use_begin(); it != result->use_end();) { StmtOperand &use = *(it++); auto *useStmt = use.getOwner(); @@ -347,7 +344,7 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { // Create new AffineApplyOp from 'valueMap'. 
unsigned numOperands = valueMap.getNumOperands(); - SmallVector operands(numOperands); + SmallVector operands(numOperands); for (unsigned i = 0; i < numOperands; ++i) { operands[i] = valueMap.getOperand(i); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index aa80f47b826..9fe002c8fcb 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -27,8 +27,6 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -740,8 +738,8 @@ struct VectorizationState { DenseSet vectorizedSet; // Map of old scalar OperationStmt to new vectorized OperationStmt. DenseMap vectorizationMap; - // Map of old scalar MLValue to new vectorized MLValue. - DenseMap replacementMap; + // Map of old scalar Value to new vectorized Value. + DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; // Use-def roots. These represent the starting points for the worklist in the @@ -761,7 +759,7 @@ struct VectorizationState { void registerTerminator(OperationStmt *stmt); private: - void registerReplacement(const SSAValue *key, SSAValue *value); + void registerReplacement(const Value *key, Value *value); }; } // end namespace @@ -802,12 +800,9 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(const SSAValue *key, - SSAValue *value) { - assert(replacementMap.count(cast(key)) == 0 && - "replacement already registered"); - replacementMap.insert( - std::make_pair(cast(key), cast(value))); +void VectorizationState::registerReplacement(const Value *key, Value *value) { + assert(replacementMap.count(key) == 0 && "replacement already registered"); + replacementMap.insert(std::make_pair(key, value)); } ////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. //// @@ -825,7 +820,7 @@ void VectorizationState::registerReplacement(const SSAValue *key, /// Such special cases force us to delay the vectorization of the stores /// until the last step. Here we merely register the store operation. template -static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp, +static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, VectorizationState *state) { auto memRefType = memoryOp->getMemRef()->getType().template cast(); @@ -850,8 +845,7 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp, MLFuncBuilder b(opStmt); auto transfer = b.create( opStmt->getLoc(), vectorType, memoryOp->getMemRef(), - map(makePtrDynCaster(), memoryOp->getIndices()), - permutationMap); + map(makePtrDynCaster(), memoryOp->getIndices()), permutationMap); state->registerReplacement(opStmt, cast(transfer->getOperation())); } else { @@ -970,8 +964,8 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. 
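// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the patch] The delayed-replacement lookup
// that vectorizeOperand() and vectorizeConstant() implement together, with
// strings standing in for scalar Values and std::vector<double> for their
// vectorized counterparts. A scalar resolves either through the replacement
// map populated earlier in the pattern or, if it is a known constant, by
// splatting it to the vector width on demand; anything else fails. All names
// here are invented for illustration.
#include <map>
#include <optional>
#include <string>
#include <vector>

std::optional<std::vector<double>>
vectorizeOperand(const std::string &operand, size_t vectorWidth,
                 std::map<std::string, std::vector<double>> &replacementMap,
                 const std::map<std::string, double> &knownConstants) {
  // 1. Already vectorized: reuse the registered replacement.
  if (auto it = replacementMap.find(operand); it != replacementMap.end())
    return it->second;
  // 2. A constant can be broadcast (splat) to the vector width.
  if (auto c = knownConstants.find(operand); c != knownConstants.end()) {
    std::vector<double> splat(vectorWidth, c->second);
    replacementMap[operand] = splat; // register for later uses
    return splat;
  }
  // 3. Everything else cannot be vectorized at this point.
  return std::nullopt;
}
// ---------------------------------------------------------------------------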
-static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant, - Type type) { +static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, + Type type) { if (!type || !type.isa() || !VectorType::isValidElementType(constant.getType())) { return nullptr; @@ -988,7 +982,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant, {make_pair(Identifier::get("value", b.getContext()), attr)}); auto *splat = cast(b.createOperation(state)); - return cast(splat->getResult(0)); + return splat->getResult(0); } /// Returns a uniqu'ed VectorType. @@ -996,7 +990,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant, /// vectorizedSet, just returns the type of `v`. /// Otherwise, constructs a new VectorType of shape defined by `state.strategy` /// and of elemental type the type of `v`. -static Type getVectorType(SSAValue *v, const VectorizationState &state) { +static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } @@ -1028,23 +1022,23 @@ static Type getVectorType(SSAValue *v, const VectorizationState &state) { /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt, - VectorizationState *state) { +static Value *vectorizeOperand(Value *operand, Statement *stmt, + VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); auto *definingStatement = cast(operand->getDefiningStmt()); // 1. If this value has already been vectorized this round, we are done. if (state->vectorizedSet.count(definingStatement) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); - return cast(operand); + return operand; } // 1.b. Delayed on-demand replacement of a use. // Note that we cannot just call replaceAllUsesWith because it may result // in ops with mixed types, for ops whose operands have not all yet // been vectorized. This would be invalid IR. - auto it = state->replacementMap.find(cast(operand)); + auto it = state->replacementMap.find(operand); if (it != state->replacementMap.end()) { - auto *res = cast(it->second); + auto *res = it->second; LLVM_DEBUG(dbgs() << "-> delayed replacement by: "); LLVM_DEBUG(res->print(dbgs())); return res; @@ -1089,7 +1083,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b, auto *memRef = store->getMemRef(); auto *value = store->getValueToStore(); auto *vectorValue = vectorizeOperand(value, opStmt, state); - auto indices = map(makePtrDynCaster(), store->getIndices()); + auto indices = map(makePtrDynCaster(), store->getIndices()); MLFuncBuilder b(opStmt); auto permutationMap = makePermutationMap(opStmt, state->strategy->loopToVectorDim); @@ -1104,14 +1098,14 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b, return res; } - auto types = map([state](SSAValue *v) { return getVectorType(v, *state); }, + auto types = map([state](Value *v) { return getVectorType(v, *state); }, opStmt->getResults()); - auto vectorizeOneOperand = [opStmt, state](SSAValue *op) -> SSAValue * { + auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * { return vectorizeOperand(op, opStmt, state); }; auto operands = map(vectorizeOneOperand, opStmt->getOperands()); // Check whether a single operand is null. If so, vectorization failed. 
-  bool success = llvm::all_of(operands, [](SSAValue *op) { return op; });
+  bool success = llvm::all_of(operands, [](Value *op) { return op; });
   if (!success) {
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
     return nullptr;
   }
@@ -1207,7 +1201,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
       continue;
     }
     MLFuncBuilder builder(loop); // builder to insert in place of loop
-    DenseMap<const MLValue *, MLValue *> nomap;
+    DenseMap<const Value *, Value *> nomap;
     ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
     auto fail = doVectorize(m, &state);
     /// Sets up error handling for this root loop. This is how the root match
@@ -1229,8 +1223,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
     // Form the root operations that have been set in the replacementMap.
     // For now, these roots are the loads for which vector_transfer_read
     // operations have been inserted.
-    auto getDefiningOperation = [](const MLValue *val) {
-      return const_cast<MLValue *>(val)->getDefiningOperation();
+    auto getDefiningOperation = [](const Value *val) {
+      return const_cast<Value *>(val)->getDefiningOperation();
     };
     using ReferenceTy = decltype(*(state.replacementMap.begin()));
     auto getKey = [](ReferenceTy it) { return it.first; };
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8357f427fb8..0fa23c69566 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -288,10 +288,10 @@ void OpEmitter::emitAttrGetters() {
 }

 void OpEmitter::emitNamedOperands() {
-  const auto operandMethods = R"( SSAValue *{0}() {
+  const auto operandMethods = R"( Value *{0}() {
     return this->getOperation()->getOperand({1});
   }
-  const SSAValue *{0}() const {
+  const Value *{0}() const {
     return this->getOperation()->getOperand({1});
   }
 )";
@@ -329,7 +329,7 @@ void OpEmitter::emitBuilder() {
   // Emit parameters for all operands
   for (const auto &pair : operands)
-    os << ", SSAValue* " << pair.first;
+    os << ", Value* " << pair.first;

   // Emit parameters for all attributes
   // TODO(antiagainst): Support default initializer for attributes
@@ -369,7 +369,7 @@ void OpEmitter::emitBuilder() {
   // Signature
   os << " static void build(Builder* builder, OperationState* result, "
-     << "ArrayRef<Type> resultTypes, ArrayRef<SSAValue *> args, "
+     << "ArrayRef<Type> resultTypes, ArrayRef<Value *> args, "
        "ArrayRef<NamedAttribute> attributes) {\n";

   // Result types
-- 
cgit v1.2.3


From 5187cfcf03d36fcd9a08adb768d0bc584ef9e50d Mon Sep 17 00:00:00 2001
From: Chris Lattner
Date: Thu, 27 Dec 2018 21:21:41 -0800
Subject: Merge Operation into OperationInst and standardize nomenclature around
 OperationInst.

This is a big mechanical patch. This is step 16/n towards merging instructions
and statements, NFC.
PiperOrigin-RevId: 227093712 --- mlir/g3doc/GenericDAGRewriter.md | 8 +- mlir/include/mlir/Analysis/AffineAnalysis.h | 8 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 2 +- mlir/include/mlir/Analysis/MLFunctionMatcher.h | 2 +- mlir/include/mlir/Analysis/SliceAnalysis.h | 6 +- mlir/include/mlir/Analysis/Utils.h | 4 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 6 +- mlir/include/mlir/IR/Builders.h | 6 +- mlir/include/mlir/IR/BuiltinOps.h | 57 +-- mlir/include/mlir/IR/Dialect.h | 4 +- mlir/include/mlir/IR/Function.h | 12 +- mlir/include/mlir/IR/MLIRContext.h | 4 +- mlir/include/mlir/IR/Matchers.h | 8 +- mlir/include/mlir/IR/OpDefinition.h | 142 ++++---- mlir/include/mlir/IR/OpImplementation.h | 4 +- mlir/include/mlir/IR/Operation.h | 399 --------------------- mlir/include/mlir/IR/OperationSupport.h | 21 +- mlir/include/mlir/IR/PatternMatch.h | 27 +- mlir/include/mlir/IR/Statement.h | 4 +- mlir/include/mlir/IR/Statements.h | 340 +++++++++++++++--- mlir/include/mlir/IR/StmtBlock.h | 6 +- mlir/include/mlir/IR/StmtVisitor.h | 24 +- mlir/include/mlir/IR/UseDefLists.h | 6 +- mlir/include/mlir/IR/Value.h | 29 +- mlir/include/mlir/StandardOps/StandardOps.h | 103 +++--- mlir/include/mlir/SuperVectorOps/SuperVectorOps.h | 22 +- .../mlir/Transforms/MLPatternLoweringPass.h | 10 +- mlir/include/mlir/Transforms/Utils.h | 10 +- mlir/lib/Analysis/AffineAnalysis.cpp | 14 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 14 +- mlir/lib/Analysis/MLFunctionMatcher.cpp | 4 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 10 +- mlir/lib/Analysis/OpStats.cpp | 4 +- mlir/lib/Analysis/SliceAnalysis.cpp | 6 +- mlir/lib/Analysis/Utils.cpp | 6 +- mlir/lib/Analysis/VectorAnalysis.cpp | 6 +- mlir/lib/Analysis/Verifier.cpp | 24 +- mlir/lib/IR/AsmPrinter.cpp | 40 +-- mlir/lib/IR/Builders.cpp | 2 +- mlir/lib/IR/BuiltinOps.cpp | 24 +- mlir/lib/IR/Function.cpp | 24 +- mlir/lib/IR/MLIRContext.cpp | 2 +- mlir/lib/IR/Operation.cpp | 256 ++----------- mlir/lib/IR/PatternMatch.cpp | 15 +- mlir/lib/IR/Statement.cpp | 166 +++++++-- mlir/lib/IR/StmtBlock.cpp | 4 +- mlir/lib/IR/Value.cpp | 24 +- mlir/lib/Parser/Parser.cpp | 18 +- mlir/lib/StandardOps/StandardOps.cpp | 24 +- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 8 +- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 12 +- mlir/lib/Transforms/CSE.cpp | 18 +- mlir/lib/Transforms/ComposeAffineMaps.cpp | 6 +- mlir/lib/Transforms/ConstantFold.cpp | 19 +- mlir/lib/Transforms/ConvertToCFG.cpp | 12 +- mlir/lib/Transforms/DmaGeneration.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 28 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 9 +- mlir/lib/Transforms/MaterializeVectors.cpp | 35 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 20 +- mlir/lib/Transforms/SimplifyAffineExpr.cpp | 4 +- .../Utils/GreedyPatternRewriteDriver.cpp | 39 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoweringUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 33 +- .../Vectorization/VectorizerTestPass.cpp | 10 +- mlir/lib/Transforms/Vectorize.cpp | 88 ++--- mlir/test/mlir-rewriter-gen/one-op-one-result.td | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 4 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 4 +- 73 files changed, 1033 insertions(+), 1297 deletions(-) delete mode 100644 mlir/include/mlir/IR/Operation.h (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/GenericDAGRewriter.md b/mlir/g3doc/GenericDAGRewriter.md index 
5f425cac7b4..174b46a6391 100644 --- a/mlir/g3doc/GenericDAGRewriter.md +++ b/mlir/g3doc/GenericDAGRewriter.md @@ -322,14 +322,14 @@ class Pattern { /// returns a None value. On success it a (possibly null) pattern-specific /// state wrapped in a Some. This state is passed back into its rewrite /// function if this match is selected. - virtual Optional match(Operation *op) const = 0; + virtual Optional match(OperationInst *op) const = 0; /// Rewrite the IR rooted at the specified operation with the result of /// this pattern, generating any new operations with the specified /// rewriter. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternState *state, + virtual void rewrite(OperationInst *op, PatternState *state, PatternRewriter &rewriter) const; }; ``` @@ -372,8 +372,8 @@ class PatternMatcher { // Given a specific operation, see if there is some rewrite that is // interesting. If so, return success and return the list of new // operations that were created. If not, return failure. - bool matchAndRewrite(Operation *op, - SmallVectorImpl &newlyCreatedOps); + bool matchAndRewrite(OperationInst *op, + SmallVectorImpl &newlyCreatedOps); }; ``` diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 9c288aac7cc..5ffaf845cfc 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -37,7 +37,7 @@ class ForStmt; class MLIRContext; class FlatAffineConstraints; class IntegerSet; -class OperationStmt; +class OperationInst; class Statement; class Value; @@ -74,12 +74,12 @@ AffineMap composeUnboundedMaps(AffineMap f, AffineMap g); /// smaller than the number of results of `g`. AffineExpr composeWithUnboundedMap(AffineExpr e, AffineMap g); -/// Returns the sequence of AffineApplyOp OperationStmts operation in +/// Returns the sequence of AffineApplyOp OperationInsts operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. void getReachableAffineApplyOps( llvm::ArrayRef operands, - llvm::SmallVectorImpl &affineApplyOps); + llvm::SmallVectorImpl &affineApplyOps); /// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the /// operands of 'valueMap'. @@ -123,7 +123,7 @@ bool getIndexSet(llvm::ArrayRef forStmts, struct MemRefAccess { const Value *memref; - const OperationStmt *opStmt; + const OperationInst *opStmt; llvm::SmallVector indices; // Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. 
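[Editorial aside, not part of the patch] The GenericDAGRewriter.md hunks above describe a match/rewrite split: match() returns optional pattern-specific state, and a driver calls rewrite() only for the pattern that matched. Below is a minimal, self-contained sketch of that control flow; the Toy* names are invented stand-ins, not the real MLIR classes.

#include <memory>
#include <optional>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
};

struct ToyPattern {
  virtual ~ToyPattern() = default;
  // Returns pattern-specific state on success, std::nullopt on failure.
  virtual std::optional<int> match(const ToyOp &op) const = 0;
  virtual void rewrite(ToyOp &op, int state) const = 0;
};

struct RenameAddPattern : ToyPattern {
  std::optional<int> match(const ToyOp &op) const override {
    if (op.name == "add")
      return 0; // any payload works as state here
    return std::nullopt;
  }
  void rewrite(ToyOp &op, int) const override { op.name = "add.fast"; }
};

// Mirrors matchAndRewrite(): the first pattern that matches gets to rewrite.
bool matchAndRewrite(ToyOp &op,
                     const std::vector<std::unique_ptr<ToyPattern>> &patterns) {
  for (const auto &p : patterns)
    if (auto state = p->match(op)) {
      p->rewrite(op, *state);
      return true;
    }
  return false;
}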
diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 9cac049e7b7..69fb81d0a1f 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -32,7 +32,7 @@ class AffineExpr; class AffineMap; class ForStmt; class MemRefType; -class OperationStmt; +class OperationInst; class Value; /// Returns the trip count of the loop as an affine expression if the latter is diff --git a/mlir/include/mlir/Analysis/MLFunctionMatcher.h b/mlir/include/mlir/Analysis/MLFunctionMatcher.h index aa05b7dad05..bd99363cafb 100644 --- a/mlir/include/mlir/Analysis/MLFunctionMatcher.h +++ b/mlir/include/mlir/Analysis/MLFunctionMatcher.h @@ -121,7 +121,7 @@ private: void visitForStmt(ForStmt *forStmt) { matchOne(forStmt); } void visitIfStmt(IfStmt *ifStmt) { matchOne(ifStmt); } - void visitOperationStmt(OperationStmt *opStmt) { matchOne(opStmt); } + void visitOperationInst(OperationInst *opStmt) { matchOne(opStmt); } /// Underlying global bump allocator managed by an MLFunctionMatcherContext. static llvm::BumpPtrAllocator *&allocator(); diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index f3e09655bf2..c3db378d971 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -168,10 +168,10 @@ void getBackwardSlice( /// /____\ /// /// We want to iteratively apply `getSlice` to construct the whole -/// list of OperationStmt that are reachable by (use|def)+ from stmt. +/// list of OperationInst that are reachable by (use|def)+ from stmt. /// We want the resulting slice in topological order. /// Ideally we would like the ordering to be maintained in-place to avoid -/// copying OperationStmt at each step. Keeping this ordering by construction +/// copying OperationInst at each step. Keeping this ordering by construction /// seems very unclear, so we list invariants in the hope of seeing whether /// useful properties pop up. /// @@ -207,7 +207,7 @@ llvm::SetVector getSlice( [](Statement *) { return true; }); /// Multi-root DAG topological sort. -/// Performs a topological sort of the OperationStmt in the `toSort` SetVector. +/// Performs a topological sort of the OperationInst in the `toSort` SetVector. /// Returns a topologically sorted SetVector. llvm::SetVector topologicalSort(const llvm::SetVector &toSort); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 284a6fd6735..eb8dbe530ea 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -35,7 +35,7 @@ namespace mlir { class FlatAffineConstraints; class ForStmt; class MemRefAccess; -class OperationStmt; +class OperationInst; class Statement; class Value; @@ -128,7 +128,7 @@ private: /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. 
/// -bool getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, +bool getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, MemRefRegion *region); /// Returns the size of memref data in bytes if it's statically shaped, None diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 44310f7485d..f84aff29946 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -27,7 +27,7 @@ namespace mlir { class AffineMap; class ForStmt; class MemRefType; -class OperationStmt; +class OperationInst; class VectorType; /// Computes and returns the multi-dimensional ratio of `superShape` to @@ -118,7 +118,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. /// AffineMap -makePermutationMap(OperationStmt *opStmt, +makePermutationMap(OperationInst *opStmt, const llvm::DenseMap &loopToVectorDim); namespace matcher { @@ -131,7 +131,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnStrictSuperVectors(const OperationStmt &stmt, +bool operatesOnStrictSuperVectors(const OperationInst &stmt, VectorType subVectorType); } // end namespace matcher diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 2efa21a3dc8..1ad533b0983 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -243,7 +243,7 @@ public: StmtBlock *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. - OperationStmt *createOperation(const OperationState &state); + OperationInst *createOperation(const OperationState &state); /// Create operation of specific op type at the current insertion point. template @@ -265,7 +265,7 @@ public: OpTy::build(this, &state, args...); auto *stmt = createOperation(state); - // If the OperationStmt we produce is valid, return it. + // If the OperationInst we produce is valid, return it. if (!OpTy::verifyInvariants(stmt)) { auto result = stmt->dyn_cast(); assert(result && "Builder didn't return the right type"); @@ -284,7 +284,7 @@ public: /// sub-statements to the corresponding statement that is copied, and adds /// those mappings to the map. Statement *clone(const Statement &stmt, - OperationStmt::OperandMapTy &operandMapping) { + OperationInst::OperandMapTy &operandMapping) { Statement *cloneStmt = stmt.clone(operandMapping, getContext()); block->getStatements().insert(insertPoint, cloneStmt); return cloneStmt; diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 3f4ec6fcccd..3ccfe4f9f2d 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -80,8 +80,8 @@ public: MLIRContext *context) const; private: - friend class Operation; - explicit AffineApplyOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit AffineApplyOp(const OperationInst *state) : Op(state) {} }; /// The "br" operation represents a branch instruction in a CFG function. @@ -108,15 +108,18 @@ public: bool verify() const; /// Return the block this branch jumps to. 
- BasicBlock *getDest() const; + BasicBlock *getDest(); + const BasicBlock *getDest() const { + return const_cast(this)->getDest(); + } void setDest(BasicBlock *block); /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); private: - friend class Operation; - explicit BranchOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit BranchOp(const OperationInst *state) : Op(state) {} }; /// The "cond_br" operation represents a conditional branch instruction in a @@ -157,10 +160,16 @@ public: const Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. - BasicBlock *getTrueDest() const; + BasicBlock *getTrueDest(); + const BasicBlock *getTrueDest() const { + return const_cast(this)->getTrueDest(); + } /// Return the destination if the condition is false. - BasicBlock *getFalseDest() const; + BasicBlock *getFalseDest(); + const BasicBlock *getFalseDest() const { + return const_cast(this)->getFalseDest(); + } // Accessors for operands to the 'true' destination. Value *getTrueOperand(unsigned idx) { @@ -241,8 +250,8 @@ private: return getTrueDestOperandIndex() + getNumTrueOperands(); } - friend class Operation; - explicit CondBranchOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit CondBranchOp(const OperationInst *state) : Op(state) {} }; /// The "constant" operation requires a single attribute named "value". @@ -270,8 +279,8 @@ public: MLIRContext *context) const; protected: - friend class Operation; - explicit ConstantOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit ConstantOp(const OperationInst *state) : Op(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -289,11 +298,11 @@ public: return getAttrOfType("value").getValue(); } - static bool isClassFor(const Operation *op); + static bool isClassFor(const OperationInst *op); private: - friend class Operation; - explicit ConstantFloatOp(const Operation *state) : ConstantOp(state) {} + friend class OperationInst; + explicit ConstantFloatOp(const OperationInst *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -316,11 +325,11 @@ public: return getAttrOfType("value").getInt(); } - static bool isClassFor(const Operation *op); + static bool isClassFor(const OperationInst *op); private: - friend class Operation; - explicit ConstantIntOp(const Operation *state) : ConstantOp(state) {} + friend class OperationInst; + explicit ConstantIntOp(const OperationInst *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -337,11 +346,11 @@ public: return getAttrOfType("value").getInt(); } - static bool isClassFor(const Operation *op); + static bool isClassFor(const OperationInst *op); private: - friend class Operation; - explicit ConstantIndexOp(const Operation *state) : ConstantOp(state) {} + friend class OperationInst; + explicit ConstantIndexOp(const OperationInst *state) : ConstantOp(state) {} }; /// The "return" operation represents a return statement within a function. @@ -367,13 +376,13 @@ public: bool verify() const; private: - friend class Operation; - explicit ReturnOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit ReturnOp(const OperationInst *state) : Op(state) {} }; // Prints dimension and symbol list. 
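// ---------------------------------------------------------------------------
// [Editorial sketch, not part of the patch] The accessor idiom the BranchOp /
// CondBranchOp hunks above switch to: the non-const overload does the real
// work, and the const overload forwards to it through a const_cast on 'this',
// which is safe because the non-const version does not mutate the object.
// 'Branch' and 'Block' below are invented stand-ins.
struct Block {};

class Branch {
public:
  Block *getDest() { return dest; }
  const Block *getDest() const {
    return const_cast<Branch *>(this)->getDest();
  }

private:
  Block *dest = nullptr;
};
// ---------------------------------------------------------------------------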
-void printDimAndSymbolList(Operation::const_operand_iterator begin, - Operation::const_operand_iterator end, +void printDimAndSymbolList(OperationInst::const_operand_iterator begin, + OperationInst::const_operand_iterator end, unsigned numDims, OpAsmPrinter *p); // Parses dimension and symbol list and returns true if parsing failed. diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 72ea426070b..335d64e3064 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -27,7 +27,7 @@ namespace mlir { using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; + const OperationInst *, ArrayRef, SmallVectorImpl &)>; /// Dialects are groups of MLIR operations and behavior associated with the /// entire group. For example, hooks into other systems for constant folding, @@ -50,7 +50,7 @@ public: /// and fills in the `results` vector. If not, this returns true and /// `results` is unspecified. DialectConstantFoldHook constantFoldHook = - [](const Operation *op, ArrayRef operands, + [](const OperationInst *op, ArrayRef operands, SmallVectorImpl &results) { return true; }; // TODO: Hook to return the list of named types that are known. diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index ece70f3ef92..0d039ee5d9b 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -120,8 +120,8 @@ public: } /// Return the 'return' statement of this MLFunction. - const OperationStmt *getReturnStmt() const; - OperationStmt *getReturnStmt(); + const OperationInst *getReturnStmt() const; + OperationInst *getReturnStmt(); // These should only be used on MLFunctions. StmtBlock *getBody() { @@ -133,12 +133,12 @@ public: } /// Walk the statements in the function in preorder, calling the callback for - /// each Operation statement. - void walk(std::function callback); + /// each operation statement. + void walk(std::function callback); /// Walk the statements in the function in postorder, calling the callback for - /// each Operation statement. - void walkPostOrder(std::function callback); + /// each operation statement. + void walkPostOrder(std::function callback); //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 3dbf6dc3c37..baf55b4b2c2 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -61,7 +61,7 @@ public: // Diagnostic handler registration and use. MLIR supports the ability for the // IR to carry arbitrary metadata about operation location information. If an // problem is detected by the compiler, it can invoke the emitError / - // emitWarning / emitNote method on an Operation and have it get reported + // emitWarning / emitNote method on an OperationInst and have it get reported // through this interface. // // Tools using MLIR are encouraged to register error handlers and define a @@ -81,7 +81,7 @@ public: /// Emit a diagnostic using the registered issue handle if present, or with /// the default behavior if not. The MLIR compiler should not generally - /// interact with this, it should use methods on Operation instead. + /// interact with this, it should use methods on OperationInst instead. 
void emitDiagnostic(Location location, const Twine &message, DiagnosticKind kind) const; diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index b8e947fbf99..ea109c08d65 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -68,7 +68,7 @@ struct constant_int_op_binder { /// Creates a matcher instance that binds the value to bv if match succeeds. constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} - bool match(Operation *op) { + bool match(OperationInst *op) { if (auto constOp = op->dyn_cast()) { auto type = constOp->getResult()->getType(); auto attr = constOp->getAttr("value"); @@ -90,7 +90,7 @@ struct constant_int_op_binder { // The matcher that matches a given target constant scalar / vector splat / // tensor splat integer value. template struct constant_int_value_matcher { - bool match(Operation *op) { + bool match(OperationInst *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; @@ -99,7 +99,7 @@ template struct constant_int_value_matcher { /// The matcher that matches a certain kind of op. template struct op_matcher { - bool match(Operation *op) { return op->isa(); } + bool match(OperationInst *op) { return op->isa(); } }; } // end namespace detail @@ -108,7 +108,7 @@ template struct op_matcher { template inline bool matchPattern(Value *value, const Pattern &pattern) { // TODO: handle other cases - if (auto *op = value->getDefiningOperation()) + if (auto *op = value->getDefiningInst()) return const_cast(pattern).match(op); return false; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index aa3be30d7d4..8bedb348cfa 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -54,12 +54,12 @@ template struct IsSingleResult { OpType *, OpTrait::OneResult *>::value; }; -/// This pointer represents a notional "Operation*" but where the actual +/// This pointer represents a notional "OperationInst*" but where the actual /// storage of the pointer is maintained in the templated "OpType" class. template class OpPointer { public: - explicit OpPointer() : value(Operation::getNull().value) {} + explicit OpPointer() : value(OperationInst::getNull().value) {} explicit OpPointer(OpType value) : value(value) {} OpType &operator*() { return value; } @@ -69,7 +69,7 @@ public: operator bool() const { return value.getOperation(); } /// OpPointer can be implicitly converted to OpType*. - /// Return `nullptr` if there is no associated Operation*. + /// Return `nullptr` if there is no associated OperationInst*. operator OpType *() { if (!value.getOperation()) return nullptr; @@ -89,12 +89,12 @@ private: OpType value; }; -/// This pointer represents a notional "const Operation*" but where the actual -/// storage of the pointer is maintained in the templated "OpType" class. +/// This pointer represents a notional "const OperationInst*" but where the +/// actual storage of the pointer is maintained in the templated "OpType" class. template class ConstOpPointer { public: - explicit ConstOpPointer() : value(Operation::getNull().value) {} + explicit ConstOpPointer() : value(OperationInst::getNull().value) {} explicit ConstOpPointer(OpType value) : value(value) {} const OpType &operator*() const { return value; } @@ -105,7 +105,7 @@ public: operator bool() const { return value.getOperation(); } /// ConstOpPointer can always be implicitly converted to const OpType*. 
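A sketch of how the renamed matcher entry point above might be used; the wrapper function is hypothetical, the detail matchers from the hunk are used directly (any m_* convenience wrappers are not visible here), and the code is assumed to live inside namespace mlir:

  // True if 'v' is produced by an operation that folds to the constant 0.
  static bool isDefinedByConstantZero(Value *v) {
    // matchPattern() now looks through Value::getDefiningInst().
    return matchPattern(v, detail::constant_int_value_matcher<0>());
  }
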
- /// Return `nullptr` if there is no associated Operation*. + /// Return `nullptr` if there is no associated OperationInst*. operator const OpType *() const { if (!value.getOperation()) return nullptr; @@ -137,8 +137,8 @@ private: class OpState { public: /// Return the operation that this refers to. - const Operation *getOperation() const { return state; } - Operation *getOperation() { return state; } + const OperationInst *getOperation() const { return state; } + OperationInst *getOperation() { return state; } /// The source location the operation was defined or derived from. Location getLoc() const { return state->getLoc(); } @@ -207,11 +207,11 @@ protected: /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. - explicit OpState(const Operation *state) - : state(const_cast(state)) {} + explicit OpState(const OperationInst *state) + : state(const_cast(state)) {} private: - Operation *state; + OperationInst *state; }; /// This template defines the constantFoldHook as used by AbstractOperation. @@ -223,7 +223,7 @@ public: /// This hook implements a constant folder for this operation. It returns /// true if folding failed, or returns false and fills in `results` on /// success. - static bool constantFoldHook(const Operation *op, + static bool constantFoldHook(const OperationInst *op, ArrayRef operands, SmallVectorImpl &results) { return op->cast()->constantFold(operands, results, @@ -256,7 +256,7 @@ public: /// This hook implements a constant folder for this operation. It returns /// true if folding failed, or returns false and fills in `results` on /// success. - static bool constantFoldHook(const Operation *op, + static bool constantFoldHook(const OperationInst *op, ArrayRef operands, SmallVectorImpl &results) { auto result = @@ -270,7 +270,7 @@ public: }; //===----------------------------------------------------------------------===// -// Operation Trait Types +// OperationInst Trait Types //===----------------------------------------------------------------------===// namespace OpTrait { @@ -279,22 +279,22 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. 
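For context on the constantFoldHook templates above, a hypothetical single-result op's folder might look like the sketch below; the class is invented and the member signature is inferred from the hooks rather than copied from this revision:

  // Fold a hypothetical identity-like op: if its operand is a known constant,
  // return that attribute; returning a null Attribute signals "no fold".
  Attribute HypotheticalIdentityOp::constantFold(ArrayRef<Attribute> operands,
                                                 MLIRContext *context) const {
    return operands[0];
  }
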
namespace impl { -bool verifyZeroOperands(const Operation *op); -bool verifyOneOperand(const Operation *op); -bool verifyNOperands(const Operation *op, unsigned numOperands); -bool verifyAtLeastNOperands(const Operation *op, unsigned numOperands); -bool verifyOperandsAreIntegerLike(const Operation *op); -bool verifySameTypeOperands(const Operation *op); -bool verifyZeroResult(const Operation *op); -bool verifyOneResult(const Operation *op); -bool verifyNResults(const Operation *op, unsigned numOperands); -bool verifyAtLeastNResults(const Operation *op, unsigned numOperands); -bool verifySameOperandsAndResultShape(const Operation *op); -bool verifySameOperandsAndResultType(const Operation *op); -bool verifyResultsAreBoolLike(const Operation *op); -bool verifyResultsAreFloatLike(const Operation *op); -bool verifyResultsAreIntegerLike(const Operation *op); -bool verifyIsTerminator(const Operation *op); +bool verifyZeroOperands(const OperationInst *op); +bool verifyOneOperand(const OperationInst *op); +bool verifyNOperands(const OperationInst *op, unsigned numOperands); +bool verifyAtLeastNOperands(const OperationInst *op, unsigned numOperands); +bool verifyOperandsAreIntegerLike(const OperationInst *op); +bool verifySameTypeOperands(const OperationInst *op); +bool verifyZeroResult(const OperationInst *op); +bool verifyOneResult(const OperationInst *op); +bool verifyNResults(const OperationInst *op, unsigned numOperands); +bool verifyAtLeastNResults(const OperationInst *op, unsigned numOperands); +bool verifySameOperandsAndResultShape(const OperationInst *op); +bool verifySameOperandsAndResultType(const OperationInst *op); +bool verifyResultsAreBoolLike(const OperationInst *op); +bool verifyResultsAreFloatLike(const OperationInst *op); +bool verifyResultsAreIntegerLike(const OperationInst *op); +bool verifyIsTerminator(const OperationInst *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -302,8 +302,8 @@ bool verifyIsTerminator(const Operation *op); template class TraitType> class TraitBase { protected: - /// Return the ultimate Operation being worked on. - Operation *getOperation() { + /// Return the ultimate OperationInst being worked on. + OperationInst *getOperation() { // We have to cast up to the trait type, then to the concrete type, then to // the BaseState class in explicit hops because the concrete type will // multiply derive from the (content free) TraitBase class, and we need to @@ -313,13 +313,13 @@ protected: auto *base = static_cast(concrete); return base->getOperation(); } - const Operation *getOperation() const { + const OperationInst *getOperation() const { return const_cast(this)->getOperation(); } /// Provide default implementations of trait hooks. This allows traits to /// provide exactly the overrides they care about. 
- static bool verifyTrait(const Operation *op) { return false; } + static bool verifyTrait(const OperationInst *op) { return false; } static AbstractOperation::OperationProperties getTraitProperties() { return 0; } @@ -330,7 +330,7 @@ protected: template class ZeroOperands : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyZeroOperands(op); } @@ -353,7 +353,7 @@ public: void setOperand(Value *value) { this->getOperation()->setOperand(0, value); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyOneOperand(op); } }; @@ -380,7 +380,7 @@ public: this->getOperation()->setOperand(i, value); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyNOperands(op, N); } }; @@ -412,7 +412,7 @@ public: } // Support non-const operand iteration. - using operand_iterator = Operation::operand_iterator; + using operand_iterator = OperationInst::operand_iterator; operand_iterator operand_begin() { return this->getOperation()->operand_begin(); } @@ -424,7 +424,7 @@ public: } // Support const operand iteration. - using const_operand_iterator = Operation::const_operand_iterator; + using const_operand_iterator = OperationInst::const_operand_iterator; const_operand_iterator operand_begin() const { return this->getOperation()->operand_begin(); } @@ -435,7 +435,7 @@ public: return this->getOperation()->getOperands(); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyAtLeastNOperands(op, N); } }; @@ -461,7 +461,7 @@ public: } // Support non-const operand iteration. - using operand_iterator = Operation::operand_iterator; + using operand_iterator = OperationInst::operand_iterator; operand_iterator operand_begin() { return this->getOperation()->operand_begin(); } @@ -471,7 +471,7 @@ public: } // Support const operand iteration. - using const_operand_iterator = Operation::const_operand_iterator; + using const_operand_iterator = OperationInst::const_operand_iterator; const_operand_iterator operand_begin() const { return this->getOperation()->operand_begin(); } @@ -488,7 +488,7 @@ public: template class ZeroResult : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyZeroResult(op); } }; @@ -510,7 +510,7 @@ public: getResult()->replaceAllUsesWith(newValue); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyOneResult(op); } @@ -549,7 +549,7 @@ public: Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyNResults(op, N); } }; @@ -573,7 +573,7 @@ public: Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -599,7 +599,7 @@ public: } // Support non-const result iteration. - using result_iterator = Operation::result_iterator; + using result_iterator = OperationInst::result_iterator; result_iterator result_begin() { return this->getOperation()->result_begin(); } @@ -609,7 +609,7 @@ public: } // Support const result iteration. 
- using const_result_iterator = Operation::const_result_iterator; + using const_result_iterator = OperationInst::const_result_iterator; const_result_iterator result_begin() const { return this->getOperation()->result_begin(); } @@ -628,7 +628,7 @@ template class SameOperandsAndResultShape : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -643,7 +643,7 @@ template class SameOperandsAndResultType : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -653,7 +653,7 @@ public: template class ResultsAreBoolLike : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -664,7 +664,7 @@ template class ResultsAreFloatLike : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyResultsAreFloatLike(op); } }; @@ -675,7 +675,7 @@ template class ResultsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyResultsAreIntegerLike(op); } }; @@ -706,7 +706,7 @@ template class OperandsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyOperandsAreIntegerLike(op); } }; @@ -716,7 +716,7 @@ public: template class SameTypeOperands : public TraitBase { public: - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifySameTypeOperands(op); } }; @@ -729,7 +729,7 @@ public: return static_cast( OperationProperty::Terminator); } - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return impl::verifyIsTerminator(op); } @@ -762,7 +762,7 @@ public: } // end namespace OpTrait //===----------------------------------------------------------------------===// -// Operation Definition classes +// OperationInst Definition classes //===----------------------------------------------------------------------===// /// This provides public APIs that all operations should have. The template @@ -777,14 +777,14 @@ class Op : public OpState, Traits...>::value> { public: /// Return the operation that this refers to. - const Operation *getOperation() const { return OpState::getOperation(); } - Operation *getOperation() { return OpState::getOperation(); } + const OperationInst *getOperation() const { return OpState::getOperation(); } + OperationInst *getOperation() { return OpState::getOperation(); } /// Return true if this "op class" can match against the specified operation. /// This hook can be overridden with a more specific implementation in /// the subclass of Base. /// - static bool isClassFor(const Operation *op) { + static bool isClassFor(const OperationInst *op) { return op->getName().getStringRef() == ConcreteType::getOperationName(); } @@ -798,7 +798,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. 
- static void printAssembly(const Operation *op, OpAsmPrinter *p) { + static void printAssembly(const OperationInst *op, OpAsmPrinter *p) { auto opPointer = op->dyn_cast(); assert(opPointer && "op's name does not match name of concrete type instantiated with"); @@ -812,7 +812,7 @@ public: /// /// On success this returns false; on failure it emits an error to the /// diagnostic subsystem and returns true. - static bool verifyInvariants(const Operation *op) { + static bool verifyInvariants(const OperationInst *op) { return BaseVerifier...>::verifyTrait(op) || op->cast()->verify(); } @@ -830,26 +830,26 @@ public: using ConcreteOpType = ConcreteType; protected: - explicit Op(const Operation *state) : OpState(state) {} + explicit Op(const OperationInst *state) : OpState(state) {} private: template struct BaseVerifier; template struct BaseVerifier { - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return First::verifyTrait(op) || BaseVerifier::verifyTrait(op); } }; template struct BaseVerifier { - static bool verifyTrait(const Operation *op) { + static bool verifyTrait(const OperationInst *op) { return First::verifyTrait(op); } }; template <> struct BaseVerifier<> { - static bool verifyTrait(const Operation *op) { return false; } + static bool verifyTrait(const OperationInst *op) { return false; } }; template struct BaseProperties; @@ -881,7 +881,7 @@ namespace impl { void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, Value *rhs); bool parseBinaryOp(OpAsmParser *parser, OperationState *result); -void printBinaryOp(const Operation *op, OpAsmPrinter *p); +void printBinaryOp(const OperationInst *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are simple binary ops that have @@ -907,7 +907,7 @@ public: } protected: - explicit BinaryOp(const Operation *state) + explicit BinaryOp(const OperationInst *state) : Op::Impl, OpTrait::OneResult, OpTrait::SameOperandsAndResultType, Traits...>(state) {} }; @@ -918,7 +918,7 @@ namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); -void printCastOp(const Operation *op, OpAsmPrinter *p); +void printCastOp(const OperationInst *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are cast operations, that have a @@ -943,7 +943,7 @@ public: } protected: - explicit CastOp(const Operation *state) + explicit CastOp(const OperationInst *state) : Op(state) {} }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 36dbb98fa68..9ebc55b2ae8 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -75,7 +75,7 @@ public: /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. - virtual void printSuccessorAndUseList(const Operation *term, + virtual void printSuccessorAndUseList(const OperationInst *term, unsigned index) = 0; /// If the specified operation has attributes, print out an attribute @@ -87,7 +87,7 @@ public: ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default verbose formatting. 
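Putting the Op template and traits above together, a custom op written against this revision would roughly follow the sketch below; the op and its name are hypothetical, and only the hooks visible in this hunk are spelled out (verify() reached through verifyInvariants, print() reached through printAssembly):

  class HypotheticalNoteOp
      : public Op<HypotheticalNoteOp, OpTrait::ZeroOperands,
                  OpTrait::ZeroResult> {
  public:
    static StringRef getOperationName() { return "test.note"; }

    // Called by Op::verifyInvariants(); returning false means "no error".
    bool verify() const { return false; }
    // Called by Op::printAssembly().
    void print(OpAsmPrinter *p) const;

  private:
    friend class OperationInst;
    explicit HypotheticalNoteOp(const OperationInst *state) : Op(state) {}
  };
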
- virtual void printDefaultOp(const Operation *op) = 0; + virtual void printDefaultOp(const OperationInst *op) = 0; private: OpAsmPrinter(const OpAsmPrinter &) = delete; diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h deleted file mode 100644 index 6fe88ac8852..00000000000 --- a/mlir/include/mlir/IR/Operation.h +++ /dev/null @@ -1,399 +0,0 @@ -//===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef MLIR_IR_OPERATION_H -#define MLIR_IR_OPERATION_H - -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/Statement.h" -#include "llvm/ADT/Twine.h" - -namespace mlir { -class AttributeListStorage; -template class ConstOpPointer; -template class OpPointer; -template class OperandIterator; -template class ResultIterator; -template class ResultTypeIterator; -class Function; -class IROperandOwner; -class Statement; -class OperationStmt; -using Instruction = Statement; - -/// Operations represent all of the arithmetic and other basic computation in -/// MLIR. This class is the common implementation details behind Instruction -/// and OperationStmt. -/// -class Operation : public Statement { -public: - - /// Return the function this operation is defined in. This has a verbose - /// name to avoid name lookup ambiguities. - Function *getOperationFunction(); - - const Function *getOperationFunction() const { - return const_cast(this)->getOperationFunction(); - } - - /// The name of an operation is the key identifier for it. - OperationName getName() const { return name; } - - // Support non-const operand iteration. - using operand_iterator = OperandIterator; - operand_iterator operand_begin(); - operand_iterator operand_end(); - llvm::iterator_range getOperands(); - - // Support const operand iteration. - using const_operand_iterator = OperandIterator; - const_operand_iterator operand_begin() const; - const_operand_iterator operand_end() const; - llvm::iterator_range getOperands() const; - - /// Return the number of results this operation has. - unsigned getNumResults() const; - - /// Return the indicated result. - Value *getResult(unsigned idx); - const Value *getResult(unsigned idx) const { - return const_cast(this)->getResult(idx); - } - - // Support non-const result iteration. - using result_iterator = ResultIterator; - result_iterator result_begin(); - result_iterator result_end(); - llvm::iterator_range getResults(); - - // Support const result iteration. - using const_result_iterator = ResultIterator; - const_result_iterator result_begin() const; - const_result_iterator result_end() const; - llvm::iterator_range getResults() const; - - // Support for result type iteration. 
- using result_type_iterator = ResultTypeIterator; - result_type_iterator result_type_begin() const; - result_type_iterator result_type_end() const; - llvm::iterator_range getResultTypes() const; - - // Support for successor querying. - unsigned getNumSuccessors() const; - unsigned getNumSuccessorOperands(unsigned index) const; - BasicBlock *getSuccessor(unsigned index); - BasicBlock *getSuccessor(unsigned index) const { - return const_cast(this)->getSuccessor(index); - } - void setSuccessor(BasicBlock *block, unsigned index); - void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex); - llvm::iterator_range - getSuccessorOperands(unsigned index) const; - llvm::iterator_range getSuccessorOperands(unsigned index); - - /// Return true if there are no users of any results of this operation. - bool use_empty() const; - - // Attributes. Operations may optionally carry a list of attributes that - // associate constants to names. Attributes may be dynamically added and - // removed over the lifetime of an operation. - // - // We assume there will be relatively few attributes on a given operation - // (maybe a dozen or so, but not hundreds or thousands) so we use linear - // searches for everything. - - /// Return all of the attributes on this operation. - ArrayRef getAttrs() const; - - /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) const { - for (auto elt : getAttrs()) - if (elt.first == name) - return elt.second; - return nullptr; - } - - Attribute getAttr(StringRef name) const { - for (auto elt : getAttrs()) - if (elt.first.is(name)) - return elt.second; - return nullptr; - } - - template AttrClass getAttrOfType(Identifier name) const { - return getAttr(name).dyn_cast_or_null(); - } - - template AttrClass getAttrOfType(StringRef name) const { - return getAttr(name).dyn_cast_or_null(); - } - - /// If the an attribute exists with the specified name, change it to the new - /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value); - - enum class RemoveResult { - Removed, NotFound - }; - - /// Remove the attribute with the specified name if it exists. The return - /// value indicates whether the attribute was present or not. - RemoveResult removeAttr(Identifier name); - - /// Emit an error about fatal conditions with this operation, reporting up to - /// any diagnostic handlers that may be listening. This function always - /// returns true. NOTE: This may terminate the containing application, only - /// use when the IR is in an inconsistent state. - bool emitError(const Twine &message) const; - - /// Emit an error with the op name prefixed, like "'dim' op " which is - /// convenient for verifiers. This function always returns true. - bool emitOpError(const Twine &message) const; - - /// Emit a warning about this operation, reporting up to any diagnostic - /// handlers that may be listening. - void emitWarning(const Twine &message) const; - - /// Emit a note about this operation, reporting up to any diagnostic - /// handlers that may be listening. - void emitNote(const Twine &message) const; - - /// If this operation has a registered operation description, return it. - /// Otherwise return null. - const AbstractOperation *getAbstractOperation() const { - return getName().getAbstractOperation(); - } - - // Return a null OpPointer for the specified type. 
- template static OpPointer getNull() { - return OpPointer(OpClass(nullptr)); - } - - /// The dyn_cast methods perform a dynamic cast from an Operation (like - /// Instruction and OperationStmt) to a typed Op like DimOp. This returns - /// a null OpPointer on failure. - template OpPointer dyn_cast() { - if (isa()) { - return cast(); - } else { - return OpPointer(OpClass(nullptr)); - } - } - - /// The dyn_cast methods perform a dynamic cast from an Operation (like - /// Instruction and OperationStmt) to a typed Op like DimOp. This returns - /// a null ConstOpPointer on failure. - template ConstOpPointer dyn_cast() const { - if (isa()) { - return cast(); - } else { - return ConstOpPointer(OpClass(nullptr)); - } - } - - /// The cast methods perform a cast from an Operation (like - /// Instruction and OperationStmt) to a typed Op like DimOp. This aborts - /// if the parameter to the template isn't an instance of the template type - /// argument. - template OpPointer cast() { - assert(isa() && "cast() argument of incompatible type!"); - return OpPointer(OpClass(this)); - } - - /// The cast methods perform a cast from an Operation (like - /// Instruction and OperationStmt) to a typed Op like DimOp. This aborts - /// if the parameter to the template isn't an instance of the template type - /// argument. - template ConstOpPointer cast() const { - assert(isa() && "cast() argument of incompatible type!"); - return ConstOpPointer(OpClass(this)); - } - - /// The is methods return true if the operation is a typed op (like DimOp) of - /// of the given class. - template bool isa() const { - return OpClass::isClassFor(this); - } - - // Returns whether the operation is commutative. - bool isCommutative() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::Commutative); - return false; - } - - // Returns whether the operation has side-effects. - bool hasNoSideEffect() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::NoSideEffect); - return false; - } - - // Returns whether the operation is a terminator. - bool isTerminator() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::Terminator); - return false; - } - - /// Remove this operation from its parent block and delete it. - void erase(); - - /// Attempt to constant fold this operation with the specified constant - /// operand values - the elements in "operands" will correspond directly to - /// the operands of the operation, but may be null if non-constant. If - /// constant folding is successful, this returns false and fills in the - /// `results` vector. If not, this returns true and `results` is unspecified. - bool constantFold(ArrayRef operands, - SmallVectorImpl &results) const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Statement *stmt); - static bool classof(const IROperandOwner *ptr); - -protected: - Operation(OperationName name, ArrayRef attrs, - Location location, MLIRContext *context); - ~Operation(); - -private: - Operation(const Operation&) = delete; - void operator=(const Operation&) = delete; - - /// This holds the name of the operation. - OperationName name; - - /// This holds general named attributes for the operation. - AttributeListStorage *attrs; -}; - -/// This template implements the result iterators for the various IR classes -/// in terms of getResult(idx). 
-template -class ResultIterator final - : public IndexedAccessorIterator, - ObjectType, ElementType> { -public: - /// Initializes the result iterator to the specified index. - ResultIterator(ObjectType *object, unsigned index) - : IndexedAccessorIterator, - ObjectType, ElementType>(object, index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator ResultIterator() const { - return ResultIterator(this->object, - this->index); - } - - ElementType *operator*() const { - return this->object->getResult(this->index); - } -}; - -/// This template implements the result type iterators for the various IR -/// classes in terms of getResult(idx)->getType(). -template -class ResultTypeIterator final - : public IndexedAccessorIterator< - ResultTypeIterator, ObjectType, - ElementType> { -public: - /// Initializes the result type iterator to the specified index. - ResultTypeIterator(ObjectType *object, unsigned index) - : IndexedAccessorIterator, - ObjectType, ElementType>(object, index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator ResultTypeIterator() const { - return ResultTypeIterator(this->object, - this->index); - } - - Type operator*() const { - return this->object->getResult(this->index)->getType(); - } -}; - -// Implement the inline operand iterator methods. -inline auto Operation::operand_begin() -> operand_iterator { - return operand_iterator(this, 0); -} - -inline auto Operation::operand_end() -> operand_iterator { - return operand_iterator(this, getNumOperands()); -} - -inline auto Operation::getOperands() -> llvm::iterator_range { - return {operand_begin(), operand_end()}; -} - -inline auto Operation::operand_begin() const -> const_operand_iterator { - return const_operand_iterator(this, 0); -} - -inline auto Operation::operand_end() const -> const_operand_iterator { - return const_operand_iterator(this, getNumOperands()); -} - -inline auto Operation::getOperands() const - -> llvm::iterator_range { - return {operand_begin(), operand_end()}; -} - -// Implement the inline result iterator methods. -inline auto Operation::result_begin() -> result_iterator { - return result_iterator(this, 0); -} - -inline auto Operation::result_end() -> result_iterator { - return result_iterator(this, getNumResults()); -} - -inline auto Operation::getResults() -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto Operation::result_begin() const -> const_result_iterator { - return const_result_iterator(this, 0); -} - -inline auto Operation::result_end() const -> const_result_iterator { - return const_result_iterator(this, getNumResults()); -} - -inline auto Operation::getResults() const - -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto Operation::result_type_begin() const -> result_type_iterator { - return result_type_iterator(this, 0); -} - -inline auto Operation::result_type_end() const -> result_type_iterator { - return result_type_iterator(this, getNumResults()); -} - -inline auto Operation::getResultTypes() const - -> llvm::iterator_range { - return {result_type_begin(), result_type_end()}; -} -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index e7d19b7eae0..2bc75a2a40d 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -15,7 +15,7 @@ // limitations under the License. 
// ============================================================================= // -// This file defines a number of support types that Operation and related +// This file defines a number of support types that OperationInst and related // classes build on top of. // //===----------------------------------------------------------------------===// @@ -32,7 +32,7 @@ namespace mlir { class Dialect; -class Operation; +class OperationInst; class OperationState; class OpAsmParser; class OpAsmParserResult; @@ -78,23 +78,24 @@ public: Dialect &dialect; /// Return true if this "op class" can match against the specified operation. - bool (&isClassFor)(const Operation *op); + bool (&isClassFor)(const OperationInst *op); /// Use the specified object to parse this ops custom assembly format. bool (&parseAssembly)(OpAsmParser *parser, OperationState *result); /// This hook implements the AsmPrinter for this operation. - void (&printAssembly)(const Operation *op, OpAsmPrinter *p); + void (&printAssembly)(const OperationInst *op, OpAsmPrinter *p); /// This hook implements the verifier for this operation. It should emits an /// error message and returns true if a problem is detected, or returns false /// if everything is ok. - bool (&verifyInvariants)(const Operation *op); + bool (&verifyInvariants)(const OperationInst *op); /// This hook implements a constant folder for this operation. It returns /// true if folding failed, or returns false and fills in `results` on /// success. - bool (&constantFoldHook)(const Operation *op, ArrayRef operands, + bool (&constantFoldHook)(const OperationInst *op, + ArrayRef operands, SmallVectorImpl &results); /// This hook returns any canonicalization pattern rewrites that the operation @@ -124,11 +125,11 @@ public: private: AbstractOperation( StringRef name, Dialect &dialect, OperationProperties opProperties, - bool (&isClassFor)(const Operation *op), + bool (&isClassFor)(const OperationInst *op), bool (&parseAssembly)(OpAsmParser *parser, OperationState *result), - void (&printAssembly)(const Operation *op, OpAsmPrinter *p), - bool (&verifyInvariants)(const Operation *op), - bool (&constantFoldHook)(const Operation *op, + void (&printAssembly)(const OperationInst *op, OpAsmPrinter *p), + bool (&verifyInvariants)(const OperationInst *op), + bool (&constantFoldHook)(const OperationInst *op, ArrayRef operands, SmallVectorImpl &results), void (&getCanonicalizationPatterns)(OwningRewritePatternList &results, diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b015f7bb44d..1a467131cab 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -108,7 +108,7 @@ public: /// returns a None value. On success it a (possibly null) pattern-specific /// state wrapped in a Some. This state is passed back into its rewrite /// function if this match is selected. - virtual PatternMatchResult match(Operation *op) const = 0; + virtual PatternMatchResult match(OperationInst *op) const = 0; virtual ~Pattern() {} @@ -148,7 +148,7 @@ public: /// rewriter. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, std::unique_ptr state, + virtual void rewrite(OperationInst *op, std::unique_ptr state, PatternRewriter &rewriter) const; /// Rewrite the IR rooted at the specified operation with the result of @@ -156,7 +156,7 @@ public: /// builder. 
If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; + virtual void rewrite(OperationInst *op, PatternRewriter &rewriter) const; protected: /// Patterns must specify the root operation name they match against, and can @@ -222,13 +222,13 @@ public: /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. - void replaceOp(Operation *op, ArrayRef newValues, + void replaceOp(OperationInst *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead = {}); /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template - void replaceOpWithNewOp(Operation *op, Args... args) { + void replaceOpWithNewOp(OperationInst *op, Args... args) { auto newOp = create(op->getLoc(), args...); replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(), {}); } @@ -237,7 +237,8 @@ public: /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template - void replaceOpWithNewOp(Operation *op, ArrayRef valuesToRemoveIfDead, + void replaceOpWithNewOp(OperationInst *op, + ArrayRef valuesToRemoveIfDead, Args... args) { auto newOp = create(op->getLoc(), args...); replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(), @@ -252,7 +253,7 @@ public: /// The valuesToRemoveIfDead list is an optional list of values that the /// rewriter should remove if they are dead at this point. /// - void updatedRootInPlace(Operation *op, + void updatedRootInPlace(OperationInst *op, ArrayRef valuesToRemoveIfDead = {}); protected: @@ -264,26 +265,26 @@ protected: /// This is implemented to create the specified operations and serves as a /// notification hook for rewriters that want to know about new operations. - virtual Operation *createOperation(const OperationState &state) = 0; + virtual OperationInst *createOperation(const OperationState &state) = 0; /// Notify the pattern rewriter that the specified operation has been mutated /// in place. This is called after the mutation is done. - virtual void notifyRootUpdated(Operation *op) {} + virtual void notifyRootUpdated(OperationInst *op) {} /// Notify the pattern rewriter that the specified operation is about to be /// replaced with another set of operations. This is called before the uses /// of the operation have been changed. - virtual void notifyRootReplaced(Operation *op) {} + virtual void notifyRootReplaced(OperationInst *op) {} /// This is called on an operation that a pattern match is removing, right /// before the operation is deleted. At this point, the operation has zero /// uses. - virtual void notifyOperationRemoved(Operation *op) {} + virtual void notifyOperationRemoved(OperationInst *op) {} private: /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp - void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, + void replaceOpWithResultsOfAnotherOp(OperationInst *op, OperationInst *newOp, ArrayRef valuesToRemoveIfDead); }; @@ -316,7 +317,7 @@ public: /// Find the highest benefit pattern available in the pattern set for the DAG /// rooted at the specified node. 
This returns the pattern (and any state it /// needs) if found, or null if there are no matches. - MatchResult findMatch(Operation *op); + MatchResult findMatch(OperationInst *op); private: PatternMatcher(const PatternMatcher &) = delete; diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index 97d35cf79bc..d03d1daaa88 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -35,7 +35,7 @@ class ForStmt; class MLIRContext; /// The operand of a Terminator contains a StmtBlock. -using StmtBlockOperand = IROperandImpl; +using StmtBlockOperand = IROperandImpl; } // namespace mlir @@ -72,7 +72,7 @@ class Statement : public IROperandOwner, public llvm::ilist_node_with_parent { public: enum class Kind { - Operation = (int)IROperandOwner::Kind::OperationStmt, + OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForStmt, If = (int)IROperandOwner::Kind::IfStmt, }; diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index d7a584843d0..7b94486e42c 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -24,25 +24,33 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Statement.h" #include "mlir/IR/StmtBlock.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { class AffineBound; class IntegerSet; class AffineCondition; -class OperationStmt; -using OperationInst = OperationStmt; - -/// Operation statements represent operations inside ML functions. -class OperationStmt final - : public Operation, - private llvm::TrailingObjects class ConstOpPointer; +template class OpPointer; +template class ResultIterator; +template class ResultTypeIterator; +class Function; + +/// Operations represent all of the arithmetic and other basic computation in +/// MLIR. +/// +class OperationInst final + : public Statement, + private llvm::TrailingObjects { public: - /// Create a new OperationStmt with the specific fields. - static OperationStmt * + /// Create a new OperationInst with the specific fields. + static OperationInst * create(Location location, OperationName name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, ArrayRef successors, MLIRContext *context); @@ -50,6 +58,15 @@ public: /// Return the context this operation is associated with. MLIRContext *getContext() const; + /// The name of an operation is the key identifier for it. + OperationName getName() const { return name; } + + /// If this operation has a registered operation description, return it. + /// Otherwise return null. + const AbstractOperation *getAbstractOperation() const { + return getName().getAbstractOperation(); + } + /// Check if this statement is a return statement. bool isReturn() const; @@ -68,7 +85,7 @@ public: } // Support non-const operand iteration. - using operand_iterator = OperandIterator; + using operand_iterator = OperandIterator; operand_iterator operand_begin() { return operand_iterator(this, 0); } @@ -83,7 +100,7 @@ public: // Support const operand iteration. using const_operand_iterator = - OperandIterator; + OperandIterator; const_operand_iterator operand_begin() const { return const_operand_iterator(this, 0); @@ -114,35 +131,28 @@ public: // Results //===--------------------------------------------------------------------===// + /// Return true if there are no users of any results of this operation. 
+ bool use_empty() const; + unsigned getNumResults() const { return numResults; } Value *getResult(unsigned idx) { return &getStmtResult(idx); } const Value *getResult(unsigned idx) const { return &getStmtResult(idx); } // Support non-const result iteration. - using result_iterator = ResultIterator; - result_iterator result_begin() { return result_iterator(this, 0); } - result_iterator result_end() { - return result_iterator(this, getNumResults()); - } - llvm::iterator_range getResults() { - return {result_begin(), result_end()}; - } + using result_iterator = ResultIterator; + result_iterator result_begin(); + result_iterator result_end(); + llvm::iterator_range getResults(); // Support const result iteration. using const_result_iterator = - ResultIterator; - const_result_iterator result_begin() const { - return const_result_iterator(this, 0); - } + ResultIterator; + const_result_iterator result_begin() const; - const_result_iterator result_end() const { - return const_result_iterator(this, getNumResults()); - } + const_result_iterator result_end() const; - llvm::iterator_range getResults() const { - return {result_begin(), result_end()}; - } + llvm::iterator_range getResults() const; ArrayRef getStmtResults() const { return {getTrailingObjects(), numResults}; @@ -160,19 +170,61 @@ public: // Support result type iteration. using result_type_iterator = - ResultTypeIterator; - result_type_iterator result_type_begin() const { - return result_type_iterator(this, 0); + ResultTypeIterator; + result_type_iterator result_type_begin() const; + + result_type_iterator result_type_end() const; + + llvm::iterator_range getResultTypes() const; + + //===--------------------------------------------------------------------===// + // Attributes + //===--------------------------------------------------------------------===// + + // Operations may optionally carry a list of attributes that associate + // constants to names. Attributes may be dynamically added and removed over + // the lifetime of an operation. + // + // We assume there will be relatively few attributes on a given operation + // (maybe a dozen or so, but not hundreds or thousands) so we use linear + // searches for everything. + + /// Return all of the attributes on this operation. + ArrayRef getAttrs() const; + + /// Return the specified attribute if present, null otherwise. + Attribute getAttr(Identifier name) const { + for (auto elt : getAttrs()) + if (elt.first == name) + return elt.second; + return nullptr; + } + + Attribute getAttr(StringRef name) const { + for (auto elt : getAttrs()) + if (elt.first.is(name)) + return elt.second; + return nullptr; } - result_type_iterator result_type_end() const { - return result_type_iterator(this, getNumResults()); + template AttrClass getAttrOfType(Identifier name) const { + return getAttr(name).dyn_cast_or_null(); } - llvm::iterator_range getResultTypes() const { - return {result_type_begin(), result_type_end()}; + template AttrClass getAttrOfType(StringRef name) const { + return getAttr(name).dyn_cast_or_null(); } + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setAttr(Identifier name, Attribute value); + + enum class RemoveResult { Removed, NotFound }; + + /// Remove the attribute with the specified name if it exists. The return + /// value indicates whether the attribute was present or not. 
+ RemoveResult removeAttr(Identifier name); + //===--------------------------------------------------------------------===// // Terminators //===--------------------------------------------------------------------===// @@ -182,19 +234,12 @@ public: return {getTrailingObjects(), numSuccs}; } ArrayRef getBlockOperands() const { - return const_cast(this)->getBlockOperands(); + return const_cast(this)->getBlockOperands(); } - MutableArrayRef getSuccessorOperands(unsigned index) { - assert(isTerminator() && "Only terminators have successors"); - assert(index < getNumSuccessors()); - unsigned succOpIndex = getSuccessorOperandIndex(index); - auto *operandBegin = getStmtOperands().data() + succOpIndex; - return {operandBegin, getNumSuccessorOperands(index)}; - } - ArrayRef getSuccessorOperands(unsigned index) const { - return const_cast(this)->getSuccessorOperands(index); - } + llvm::iterator_range + getSuccessorOperands(unsigned index) const; + llvm::iterator_range getSuccessorOperands(unsigned index); unsigned getNumSuccessors() const { return numSuccs; } unsigned getNumSuccessorOperands(unsigned index) const { @@ -208,7 +253,7 @@ public: return getBlockOperands()[index].get(); } const StmtBlock *getSuccessor(unsigned index) const { - return const_cast(this)->getSuccessor(index); + return const_cast(this)->getSuccessor(index); } void setSuccessor(BasicBlock *block, unsigned index); @@ -237,33 +282,129 @@ public: return getNumOperands() - postSuccessorOpCount; } + //===--------------------------------------------------------------------===// + // Accessors for various properties of operations + //===--------------------------------------------------------------------===// + + /// Returns whether the operation is commutative. + bool isCommutative() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::Commutative); + return false; + } + + /// Returns whether the operation has side-effects. + bool hasNoSideEffect() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::NoSideEffect); + return false; + } + + /// Returns whether the operation is a terminator. + bool isTerminator() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::Terminator); + return false; + } + + /// Attempt to constant fold this operation with the specified constant + /// operand values - the elements in "operands" will correspond directly to + /// the operands of the operation, but may be null if non-constant. If + /// constant folding is successful, this returns false and fills in the + /// `results` vector. If not, this returns true and `results` is unspecified. + bool constantFold(ArrayRef operands, + SmallVectorImpl &results) const; + + //===--------------------------------------------------------------------===// + // Conversions to declared operations like DimOp + //===--------------------------------------------------------------------===// + + // Return a null OpPointer for the specified type. + template static OpPointer getNull() { + return OpPointer(OpClass(nullptr)); + } + + /// The dyn_cast methods perform a dynamic cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This returns + /// a null OpPointer on failure. 
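A brief sketch against the attribute accessors that move onto OperationInst above; the helper is hypothetical:

  // Set 'name' on 'inst' only if it is not already present; getAttr() does a
  // linear search, per the comment above.
  static void setAttrIfMissing(OperationInst *inst, Identifier name,
                               Attribute value) {
    if (!inst->getAttr(name))
      inst->setAttr(name, value);
  }
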
+ template OpPointer dyn_cast() { + if (isa()) { + return cast(); + } else { + return OpPointer(OpClass(nullptr)); + } + } + + /// The dyn_cast methods perform a dynamic cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This returns + /// a null ConstOpPointer on failure. + template ConstOpPointer dyn_cast() const { + if (isa()) { + return cast(); + } else { + return ConstOpPointer(OpClass(nullptr)); + } + } + + /// The cast methods perform a cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This aborts + /// if the parameter to the template isn't an instance of the template type + /// argument. + template OpPointer cast() { + assert(isa() && "cast() argument of incompatible type!"); + return OpPointer(OpClass(this)); + } + + /// The cast methods perform a cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This aborts + /// if the parameter to the template isn't an instance of the template type + /// argument. + template ConstOpPointer cast() const { + assert(isa() && "cast() argument of incompatible type!"); + return ConstOpPointer(OpClass(this)); + } + + /// The is methods return true if the operation is a typed op (like DimOp) of + /// of the given class. + template bool isa() const { + return OpClass::isClassFor(this); + } + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// + /// Emit an error with the op name prefixed, like "'dim' op " which is + /// convenient for verifiers. This function always returns true. + bool emitOpError(const Twine &message) const; + void destroy(); - using Statement::erase; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::OperationStmt; + return ptr->getKind() == IROperandOwner::Kind::OperationInst; } - static bool classof(const Operation *op) { return true; } private: unsigned numOperands; const unsigned numResults, numSuccs; - OperationStmt(Location location, OperationName name, unsigned numOperands, + /// This holds the name of the operation. + OperationName name; + + /// This holds general named attributes for the operation. + AttributeListStorage *attrs; + + OperationInst(Location location, OperationName name, unsigned numOperands, unsigned numResults, unsigned numSuccessors, ArrayRef attributes, MLIRContext *context); - ~OperationStmt(); + ~OperationInst(); /// Erase the operand at 'index'. void eraseOperand(unsigned index); // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects; size_t numTrailingObjects(OverloadToken) const { return numOperands; @@ -277,6 +418,95 @@ private: size_t numTrailingObjects(OverloadToken) const { return numSuccs; } }; +/// This template implements the result iterators for the OperationInst class +/// in terms of getResult(idx). +template +class ResultIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the result iterator to the specified index. + ResultIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. 
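A sketch of the typed-op conversions that now live on OperationInst; the helper is hypothetical and DimOp is the standard op declared elsewhere in this patch:

  // Report any dim op encountered.
  static void reportDim(const OperationInst *inst) {
    if (!inst->isa<DimOp>())   // isClassFor() on DimOp performs the name check.
      return;
    auto dim = inst->cast<DimOp>();  // ConstOpPointer<DimOp>; asserts on mismatch.
    dim->getOperation()->emitOpError("unexpected dim");
  }
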
+ operator ResultIterator() const { + return ResultIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getResult(this->index); + } +}; + +/// This template implements the result type iterators for the OperationInst +/// class in terms of getResult(idx)->getType(). +template +class ResultTypeIterator final + : public IndexedAccessorIterator< + ResultTypeIterator, ObjectType, + ElementType> { +public: + /// Initializes the result type iterator to the specified index. + ResultTypeIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator ResultTypeIterator() const { + return ResultTypeIterator(this->object, + this->index); + } + + Type operator*() const { + return this->object->getResult(this->index)->getType(); + } +}; + +// Implement the inline result iterator methods. +inline auto OperationInst::result_begin() -> result_iterator { + return result_iterator(this, 0); +} + +inline auto OperationInst::result_end() -> result_iterator { + return result_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResults() + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto OperationInst::result_begin() const -> const_result_iterator { + return const_result_iterator(this, 0); +} + +inline auto OperationInst::result_end() const -> const_result_iterator { + return const_result_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResults() const + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto OperationInst::result_type_begin() const -> result_type_iterator { + return result_type_iterator(this, 0); +} + +inline auto OperationInst::result_type_end() const -> result_type_iterator { + return result_type_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResultTypes() const + -> llvm::iterator_range { + return {result_type_begin(), result_type_end()}; +} + /// For statement represents an affine loop nest. class ForStmt : public Statement, public Value { public: diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index a5487319da4..01ef68c7d18 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -161,9 +161,9 @@ public: /// Get the terminator instruction of this block, or null if the block is /// malformed. - OperationStmt *getTerminator(); + OperationInst *getTerminator(); - const OperationStmt *getTerminator() const { + const OperationInst *getTerminator() const { return const_cast(this)->getTerminator(); } @@ -408,7 +408,7 @@ public: } private: - using BBUseIterator = ValueUseIterator; + using BBUseIterator = ValueUseIterator; BBUseIterator bbUseIterator; }; diff --git a/mlir/include/mlir/IR/StmtVisitor.h b/mlir/include/mlir/IR/StmtVisitor.h index a0f787fea4d..bcc416c00ae 100644 --- a/mlir/include/mlir/IR/StmtVisitor.h +++ b/mlir/include/mlir/IR/StmtVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for Operation, ForStmt, IfStmt, and +// There are 'visit' methods for OperationInst, ForStmt, IfStmt, and // MLFunction, which recursively process all contained statements. 
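A small sketch against the inline result accessors defined above; the helper is hypothetical:

  // Collect the result types of an operation via result_type_begin()/end().
  static SmallVector<Type, 4> collectResultTypes(const OperationInst *inst) {
    return SmallVector<Type, 4>(inst->result_type_begin(),
                                inst->result_type_end());
  }
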
// // Note that if you don't implement visitXXX for some statement type, @@ -87,9 +87,9 @@ public: return static_cast(this)->visitForStmt(cast(s)); case Statement::Kind::If: return static_cast(this)->visitIfStmt(cast(s)); - case Statement::Kind::Operation: - return static_cast(this)->visitOperationStmt( - cast(s)); + case Statement::Kind::OperationInst: + return static_cast(this)->visitOperationInst( + cast(s)); } } @@ -105,7 +105,7 @@ public: // methods get called to indicate when transitioning into a new unit. void visitForStmt(ForStmt *forStmt) {} void visitIfStmt(IfStmt *ifStmt) {} - void visitOperationStmt(OperationStmt *opStmt) {} + void visitOperationInst(OperationInst *opStmt) {} }; /// Base class for statement walkers. A walker can traverse depth first in @@ -142,8 +142,8 @@ public: static_cast(this)->visitMLFunction(f); } - RetTy walkOpStmt(OperationStmt *opStmt) { - return static_cast(this)->visitOperationStmt(opStmt); + RetTy walkOpStmt(OperationInst *opStmt) { + return static_cast(this)->visitOperationInst(opStmt); } void walkForStmt(ForStmt *forStmt) { @@ -186,8 +186,8 @@ public: return static_cast(this)->walkForStmt(cast(s)); case Statement::Kind::If: return static_cast(this)->walkIfStmt(cast(s)); - case Statement::Kind::Operation: - return static_cast(this)->walkOpStmt(cast(s)); + case Statement::Kind::OperationInst: + return static_cast(this)->walkOpStmt(cast(s)); } } @@ -203,8 +203,8 @@ public: case Statement::Kind::If: return static_cast(this)->walkIfStmtPostOrder( cast(s)); - case Statement::Kind::Operation: - return static_cast(this)->walkOpStmt(cast(s)); + case Statement::Kind::OperationInst: + return static_cast(this)->walkOpStmt(cast(s)); } } @@ -222,7 +222,7 @@ public: void visitMLFunction(MLFunction *f) {} void visitForStmt(ForStmt *forStmt) {} void visitIfStmt(IfStmt *ifStmt) {} - void visitOperationStmt(OperationStmt *opStmt) {} + void visitOperationInst(OperationInst *opStmt) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index ce96d6820f3..08b0898342c 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -76,7 +76,7 @@ private: class IROperandOwner { public: enum class Kind { - OperationStmt, + OperationInst, ForStmt, IfStmt, @@ -127,7 +127,7 @@ public: insertIntoCurrent(); } - /// Return the owner of this operand, for example, the OperationStmt that + /// Return the owner of this operand, for example, the OperationInst that /// contains a StmtOperand. IROperandOwner *getOwner() { return owner; } const IROperandOwner *getOwner() const { return owner; } @@ -175,7 +175,7 @@ private: /// This points to the previous link in the use-chain. IROperand **back = nullptr; - /// The owner of this operand, for example, the OperationStmt that contains a + /// The owner of this operand, for example, the OperationInst that contains a /// StmtOperand. IROperandOwner *const owner; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index d1ec774234b..3d8f1590371 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -28,13 +28,12 @@ namespace mlir { class Function; -class OperationStmt; -class Operation; +class OperationInst; class Statement; class StmtBlock; class Value; using Instruction = Statement; -using OperationInst = OperationStmt; +using OperationInst = OperationInst; /// Operands contain a Value. 
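To illustrate the renamed visitor hook above, a minimal StmtVisitor sketch (the visitor class is hypothetical and assumes the default void return type):

  // Counts operation statements; visitOperationInst() replaces the old
  // visitOperationStmt() hook.
  struct OperationCounter : public StmtVisitor<OperationCounter> {
    unsigned numOps = 0;
    void visitOperationInst(OperationInst *opInst) { ++numOps; }
  };

Driving such a visitor over a statement goes through the visit() dispatch shown above, which now switches on Statement::Kind::OperationInst.
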
using StmtOperand = IROperandImpl; @@ -81,27 +80,13 @@ public: return const_cast(this)->getFunction(); } - /// If this value is the result of an Instruction, return the instruction + /// If this value is the result of an operation, return the instruction /// that defines it. OperationInst *getDefiningInst(); const OperationInst *getDefiningInst() const { return const_cast(this)->getDefiningInst(); } - /// If this value is the result of an OperationStmt, return the statement - /// that defines it. - OperationStmt *getDefiningStmt(); - const OperationStmt *getDefiningStmt() const { - return const_cast(this)->getDefiningStmt(); - } - - /// If this value is the result of an Operation, return the operation that - /// defines it. - Operation *getDefiningOperation(); - const Operation *getDefiningOperation() const { - return const_cast(this)->getDefiningOperation(); - } - using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; @@ -169,15 +154,15 @@ private: /// This is a value defined by a result of an operation instruction. class StmtResult : public Value { public: - StmtResult(Type type, OperationStmt *owner) + StmtResult(Type type, OperationInst *owner) : Value(Value::Kind::StmtResult, type), owner(owner) {} static bool classof(const Value *value) { return value->getKind() == Kind::StmtResult; } - OperationStmt *getOwner() { return owner; } - const OperationStmt *getOwner() const { return owner; } + OperationInst *getOwner() { return owner; } + const OperationInst *getOwner() const { return owner; } /// Returns the number of this result. unsigned getResultNumber() const; @@ -186,7 +171,7 @@ private: /// The owner of this operand. /// TODO: can encode this more efficiently to avoid the space hit of this /// through bitpacking shenanigans. - OperationStmt *const owner; + OperationInst *const owner; }; // TODO(clattner) clean all this up. diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 789dfed43f1..33e08bae2fc 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -56,8 +56,8 @@ public: MLIRContext *context) const; private: - friend class Operation; - explicit AddFOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit AddFOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "addi" operation takes two operands and returns one result, each of @@ -80,8 +80,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit AddIOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit AddIOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "alloc" operation allocates a region of memory, as specified by its @@ -123,8 +123,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit AllocOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit AllocOp(const OperationInst *state) : Op(state) {} }; /// The "call" operation represents a direct call to a function. 
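Since getDefiningStmt() and getDefiningOperation() are folded into a single getDefiningInst(), call sites migrate roughly as sketched below. `operand` is a hypothetical `Value *`; the constant check simply mirrors the pattern used elsewhere in this patch.

  // Resolve a Value back to the operation that defines it, if any.
  if (OperationInst *defInst = operand->getDefiningInst()) {
    // getDefiningInst() returns nullptr for values that are not operation
    // results, e.g. a ForStmt induction variable or a block argument.
    if (auto constOp = defInst->dyn_cast<ConstantIndexOp>()) {
      (void)constOp->getValue();
    }
  }
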
The operands @@ -151,8 +151,8 @@ public: bool verify() const; protected: - friend class Operation; - explicit CallOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit CallOp(const OperationInst *state) : Op(state) {} }; /// The "call_indirect" operation represents an indirect call to a value of @@ -180,8 +180,8 @@ public: bool verify() const; protected: - friend class Operation; - explicit CallIndirectOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit CallIndirectOp(const OperationInst *state) : Op(state) {} }; /// The predicate indicates the type of the comparison to perform: @@ -245,8 +245,8 @@ public: bool verify() const; private: - friend class Operation; - explicit CmpIOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit CmpIOp(const OperationInst *state) : Op(state) {} }; /// The "dealloc" operation frees the region of memory referenced by a memref @@ -277,8 +277,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit DeallocOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit DeallocOp(const OperationInst *state) : Op(state) {} }; /// The "dim" operation takes a memref or tensor operand and returns an @@ -309,8 +309,8 @@ public: void print(OpAsmPrinter *p) const; private: - friend class Operation; - explicit DimOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit DimOp(const OperationInst *state) : Op(state) {} }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a @@ -367,7 +367,7 @@ public: return getSrcMemRef()->getType().cast().getRank(); } // Returns the source memerf indices for this DMA operation. - llvm::iterator_range + llvm::iterator_range getSrcIndices() const { return {getOperation()->operand_begin() + 1, getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; @@ -389,7 +389,7 @@ public: } // Returns the destination memref indices for this DMA operation. - llvm::iterator_range + llvm::iterator_range getDstIndices() const { return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + @@ -411,7 +411,7 @@ public: } // Returns the tag memref index for this DMA operation. - llvm::iterator_range + llvm::iterator_range getTagIndices() const { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; @@ -471,8 +471,8 @@ public: } protected: - friend class ::mlir::Operation; - explicit DmaStartOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit DmaStartOp(const OperationInst *state) : Op(state) {} }; // DmaWaitOp blocks until the completion of a DMA operation associated with the @@ -502,7 +502,7 @@ public: Value *getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. 
- llvm::iterator_range + llvm::iterator_range getTagIndices() const { return {getOperation()->operand_begin() + 1, getOperation()->operand_begin() + 1 + getTagMemRefRank()}; @@ -524,8 +524,8 @@ public: MLIRContext *context); protected: - friend class ::mlir::Operation; - explicit DmaWaitOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit DmaWaitOp(const OperationInst *state) : Op(state) {} }; /// The "extract_element" op reads a tensor or vector and returns one element @@ -549,11 +549,12 @@ public: Value *getAggregate() { return getOperand(0); } const Value *getAggregate() const { return getOperand(0); } - llvm::iterator_range getIndices() { + llvm::iterator_range getIndices() { return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } - llvm::iterator_range getIndices() const { + llvm::iterator_range + getIndices() const { return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } @@ -565,8 +566,8 @@ public: void print(OpAsmPrinter *p) const; private: - friend class Operation; - explicit ExtractElementOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit ExtractElementOp(const OperationInst *state) : Op(state) {} }; /// The "load" op reads an element from a memref specified by an index list. The @@ -591,11 +592,12 @@ public: return getMemRef()->getType().cast(); } - llvm::iterator_range getIndices() { + llvm::iterator_range getIndices() { return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } - llvm::iterator_range getIndices() const { + llvm::iterator_range + getIndices() const { return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } @@ -608,8 +610,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit LoadOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit LoadOp(const OperationInst *state) : Op(state) {} }; /// The "memref_cast" operation converts a memref from one type to an equivalent @@ -639,8 +641,8 @@ public: bool verify() const; private: - friend class Operation; - explicit MemRefCastOp(const Operation *state) : CastOp(state) {} + friend class OperationInst; + explicit MemRefCastOp(const OperationInst *state) : CastOp(state) {} }; /// The "mulf" operation takes two operands and returns one result, each of @@ -660,8 +662,8 @@ public: MLIRContext *context) const; private: - friend class Operation; - explicit MulFOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit MulFOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "muli" operation takes two operands and returns one result, each of @@ -684,8 +686,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit MulIOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit MulIOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "select" operation chooses one value based on a binary condition @@ -720,8 +722,8 @@ public: MLIRContext *context) const; private: - friend class Operation; - explicit SelectOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit SelectOp(const OperationInst *state) : Op(state) {} }; /// The "store" op writes an element to a memref specified by an index list. 
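The index accessors on the standard ops now hand back operand-iterator ranges defined on OperationInst. A hedged usage sketch, where `opInst` is assumed to be an `OperationInst *` taken from an ML function body:

  // Iterate a load's index operands through the updated accessor.
  if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
    for (auto *index : loadOp->getIndices()) {
      (void)index;       // every operand after the memref operand
    }
  }
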
@@ -752,11 +754,12 @@ public: return getMemRef()->getType().cast(); } - llvm::iterator_range getIndices() { + llvm::iterator_range getIndices() { return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; } - llvm::iterator_range getIndices() const { + llvm::iterator_range + getIndices() const { return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; } @@ -770,8 +773,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit StoreOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit StoreOp(const OperationInst *state) : Op(state) {} }; /// The "subf" operation takes two operands and returns one result, each of @@ -790,8 +793,8 @@ public: MLIRContext *context) const; private: - friend class Operation; - explicit SubFOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit SubFOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "subi" operation takes two operands and returns one result, each of @@ -813,8 +816,8 @@ public: MLIRContext *context); private: - friend class Operation; - explicit SubIOp(const Operation *state) : BinaryOp(state) {} + friend class OperationInst; + explicit SubIOp(const OperationInst *state) : BinaryOp(state) {} }; /// The "tensor_cast" operation converts a tensor from one type to an equivalent @@ -839,8 +842,8 @@ public: bool verify() const; private: - friend class Operation; - explicit TensorCastOp(const Operation *state) : CastOp(state) {} + friend class OperationInst; + explicit TensorCastOp(const OperationInst *state) : CastOp(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index 38bc82569a5..dcdfafe8720 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -112,8 +112,9 @@ public: MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } - llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; + llvm::iterator_range getIndices(); + llvm::iterator_range + getIndices() const; Optional getPaddingValue(); Optional getPaddingValue() const; AffineMap getPermutationMap() const; @@ -123,8 +124,8 @@ public: bool verify() const; private: - friend class Operation; - explicit VectorTransferReadOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit VectorTransferReadOp(const OperationInst *state) : Op(state) {} }; /// VectorTransferWriteOp performs a blocking write from a super-vector to @@ -180,8 +181,9 @@ public: MemRefType getMemRefType() const { return getMemRef()->getType().cast(); } - llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; + llvm::iterator_range getIndices(); + llvm::iterator_range + getIndices() const; AffineMap getPermutationMap() const; static bool parse(OpAsmParser *parser, OperationState *result); @@ -189,8 +191,8 @@ public: bool verify() const; private: - friend class Operation; - explicit VectorTransferWriteOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit VectorTransferWriteOp(const OperationInst *state) : Op(state) {} }; /// VectorTypeCastOp performs a conversion from a memref with scalar element to @@ -213,8 +215,8 @@ public: bool verify() const; private: - friend class Operation; - explicit VectorTypeCastOp(const Operation *state) : Op(state) {} + friend class OperationInst; + explicit VectorTypeCastOp(const 
OperationInst *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index aaaec12653f..b680d78fce9 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -37,7 +37,7 @@ public: FuncBuilder *getBuilder() { return builder; } - Operation *createOperation(const OperationState &state) override { + OperationInst *createOperation(const OperationState &state) override { auto *result = builder->createOperation(state); return result; } @@ -66,7 +66,7 @@ public: /// must override). It will be passed the function-wise state, common to all /// matches, and the state returned by the `match` call, if any. The subclass /// must use `rewriter` to modify the function. - virtual void rewriteOpStmt(Operation *op, + virtual void rewriteOpStmt(OperationInst *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const = 0; @@ -143,10 +143,10 @@ PassResult MLPatternLoweringPass::runOnMLFunction(MLFunction *f) { FuncBuilder builder(f); MLFuncLoweringRewriter rewriter(&builder); - llvm::SmallVector ops; - f->walk([&ops](OperationStmt *stmt) { ops.push_back(stmt); }); + llvm::SmallVector ops; + f->walk([&ops](OperationInst *stmt) { ops.push_back(stmt); }); - for (OperationStmt *stmt : ops) { + for (OperationInst *stmt : ops) { for (const auto &pattern : patterns) { rewriter.getBuilder()->setInsertionPoint(stmt); auto matchResult = pattern->match(stmt); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 5670b60e0bd..131a1f16815 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -36,7 +36,7 @@ class ForStmt; class FuncBuilder; class Location; class Module; -class OperationStmt; +class OperationInst; class Function; using CFGFunction = Function; @@ -66,10 +66,10 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// these will also be collected into a single (multi-result) affine apply op. /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. -OperationStmt * +OperationInst * createComposedAffineApplyOp(FuncBuilder *builder, Location loc, ArrayRef operands, - ArrayRef affineApplyOps, + ArrayRef affineApplyOps, SmallVectorImpl *results); /// Given an operation statement, inserts a new single affine apply operation, @@ -98,7 +98,7 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// Returns nullptr if none of the operands were the result of an affine_apply /// and thus there was no affine computation slice to create. Returns the newly /// affine_apply operation statement otherwise. -OperationStmt *createAffineComputationSlice(OperationStmt *opStmt); +OperationInst *createAffineComputationSlice(OperationInst *opStmt); /// Forward substitutes results from 'AffineApplyOp' into any users which /// are also AffineApplyOps. @@ -113,7 +113,7 @@ bool constantFoldBounds(ForStmt *forStmt); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". void remapFunctionAttrs( - Operation &op, const DenseMap &remappingTable); + OperationInst &op, const DenseMap &remappingTable); /// Replaces (potentially nested) function attributes all operations of the /// Function "fn" with those specified in "remappingTable". 
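As a usage sketch for the retyped utility declared above (hedged; `opInst` is a hypothetical load/store `OperationInst *`):

  // Insert a single affine_apply that computes the affine operands of opInst.
  if (OperationInst *sliceOp = createAffineComputationSlice(opInst)) {
    (void)sliceOp;       // the newly created affine_apply feeding opInst
  } else {
    // nullptr: none of opInst's operands were affine_apply results,
    // so there was no computation slice to create.
  }
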
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 78115b974a1..f3fde8bb95f 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -478,14 +478,14 @@ bool mlir::getFlattenedAffineExprs( localVarCst); } -/// Returns the sequence of AffineApplyOp OperationStmts operation in +/// Returns the sequence of AffineApplyOp OperationInsts operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( ArrayRef operands, - SmallVectorImpl &affineApplyOps) { + SmallVectorImpl &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. Value *value; @@ -499,9 +499,9 @@ void mlir::getReachableAffineApplyOps( while (!worklist.empty()) { State &state = worklist.back(); - auto *opStmt = state.value->getDefiningStmt(); - // Note: getDefiningStmt will return nullptr if the operand is not an - // OperationStmt (i.e. ForStmt), which is a terminator for the search. + auto *opStmt = state.value->getDefiningInst(); + // Note: getDefiningInst will return nullptr if the operand is not an + // OperationInst (i.e. ForStmt), which is a terminator for the search. if (opStmt == nullptr || !opStmt->isa()) { worklist.pop_back(); continue; @@ -531,7 +531,7 @@ void mlir::getReachableAffineApplyOps( // operands of 'valueMap'. void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) { // Gather AffineApplyOps reachable from 'indices'. - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps); // Compose AffineApplyOps in 'affineApplyOps'. for (auto *opStmt : affineApplyOps) { @@ -842,7 +842,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, auto *symbol = operands[i]; assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opStmt = symbol->getDefiningStmt()) { + if (auto *opStmt = symbol->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast()) { dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), constOp->getValue()); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index bfdaceff7e7..dd564df3017 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1269,7 +1269,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { addSymbolId(getNumSymbolIds(), const_cast(operand)); loc = getNumDimIds() + getNumSymbolIds() - 1; // Check if the symbol is a constant. 
- if (auto *opStmt = operand->getDefiningStmt()) { + if (auto *opStmt = operand->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast()) { setIdToConstant(*operand, constOp->getValue()); } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 7213ba5986a..85af39222c4 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -127,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { bool mlir::isAccessInvariant(const Value &iv, const Value &index) { assert(isa(iv) && "iv must be a ForStmt"); assert(index.getType().isa() && "index must be of IndexType"); - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); if (affineApplyOps.empty()) { @@ -234,13 +234,13 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { } static bool isVectorTransferReadOrWrite(const Statement &stmt) { - const auto *opStmt = cast(&stmt); + const auto *opStmt = cast(&stmt); return opStmt->isa() || opStmt->isa(); } using VectorizableStmtFun = - std::function; + std::function; static bool isVectorizableLoopWithCond(const ForStmt &loop, VectorizableStmtFun isVectorizableStmt) { @@ -265,7 +265,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, auto loadAndStores = matcher::Op(matcher::isLoadOrStore); auto loadAndStoresMatched = loadAndStores.match(forStmt); for (auto ls : loadAndStoresMatched) { - auto *op = cast(ls.first); + auto *op = cast(ls.first); auto load = op->dyn_cast(); auto store = op->dyn_cast(); // Only scalar types are considered vectorizable, all load/store must be @@ -285,7 +285,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( const ForStmt &loop, unsigned fastestVaryingDim) { VectorizableStmtFun fun( - [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) { + [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); return load ? isContiguousAccess(loop, *load, fastestVaryingDim) @@ -297,7 +297,7 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( bool mlir::isVectorizableLoop(const ForStmt &loop) { VectorizableStmtFun fun( // TODO: implement me - [](const ForStmt &loop, const OperationStmt &op) { return true; }); + [](const ForStmt &loop, const OperationInst &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } @@ -314,7 +314,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, for (const auto &stmt : *forBody) { // A for or if stmt does not produce any def/results (that are used // outside). 
- if (const auto *opStmt = dyn_cast(&stmt)) { + if (const auto *opStmt = dyn_cast(&stmt)) { for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { const Value *result = opStmt->getResult(i); for (const StmtOperand &use : result->getUses()) { diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp index c227aa3fcdd..c03fed5986b 100644 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp @@ -200,7 +200,7 @@ namespace mlir { namespace matcher { MLFunctionMatcher Op(FilterFunctionType filter) { - return MLFunctionMatcher(Statement::Kind::Operation, {}, filter); + return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter); } MLFunctionMatcher If(MLFunctionMatcher child) { @@ -246,7 +246,7 @@ bool isReductionLoop(const Statement &stmt) { }; bool isLoadOrStore(const Statement &stmt) { - const auto *opStmt = dyn_cast(&stmt); + const auto *opStmt = dyn_cast(&stmt); return opStmt && (opStmt->isa() || opStmt->isa()); }; diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 995bb466fef..1cb039fe00e 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -45,7 +45,7 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker { // Not applicable to CFG functions. PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); static char passID; }; @@ -58,7 +58,7 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) { +void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) { if (auto loadOp = opStmt->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opStmt->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 7c57a66310a..ec33c619a17 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -40,7 +40,7 @@ namespace { /// Checks dependences between all pairs of memref accesses in an MLFunction. struct MemRefDependenceCheck : public FunctionPass, StmtWalker { - SmallVector loadsAndStores; + SmallVector loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} @@ -48,7 +48,7 @@ struct MemRefDependenceCheck : public FunctionPass, // Not applicable to CFG functions. PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { if (opStmt->isa() || opStmt->isa()) { loadsAndStores.push_back(opStmt); } @@ -66,7 +66,7 @@ FunctionPass *mlir::createMemRefDependenceCheckPass() { // Adds memref access indices 'opIndices' from 'memrefType' to 'access'. static void addMemRefAccessIndices( - llvm::iterator_range opIndices, + llvm::iterator_range opIndices, MemRefType memrefType, MemRefAccess *access) { access->indices.reserve(memrefType.getRank()); for (auto *index : opIndices) { @@ -75,7 +75,7 @@ static void addMemRefAccessIndices( } // Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'. 
-static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt, +static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt, MemRefAccess *access) { access->opStmt = loadOrStoreOpStmt; if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { @@ -131,7 +131,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. -static void checkDependences(ArrayRef loadsAndStores) { +static void checkDependences(ArrayRef loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpStmt = loadsAndStores[i]; MemRefAccess srcAccess; diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index d9a0edd6d83..cea0c087297 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker { // Process ML functions and operation statments in ML functions. PassResult runOnMLFunction(MLFunction *function) override; - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); // Print summary of op stats. void printSummary(); @@ -69,7 +69,7 @@ PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) { return success(); } -void PrintOpStatsPass::visitOperationStmt(OperationStmt *stmt) { +void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) { ++opCount[stmt->getName().getStringRef()]; } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index b7873f8327f..c06bf4df61e 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -52,7 +52,7 @@ void mlir::getForwardSlice(Statement *stmt, return; } - if (auto *opStmt = dyn_cast(stmt)) { + if (auto *opStmt = dyn_cast(stmt)) { assert(opStmt->getNumResults() <= 1 && "NYI: multiple results"); if (opStmt->getNumResults() > 0) { for (auto &u : opStmt->getResult(0)->getUses()) { @@ -102,7 +102,7 @@ void mlir::getBackwardSlice(Statement *stmt, } for (auto *operand : stmt->getOperands()) { - auto *stmt = operand->getDefiningStmt(); + auto *stmt = operand->getDefiningInst(); if (backwardSlice->count(stmt) == 0) { getBackwardSlice(stmt, backwardSlice, filter, /*topLevel=*/false); @@ -156,7 +156,7 @@ struct DFSState { } // namespace static void DFSPostorder(Statement *current, DFSState *state) { - auto *opStmt = cast(current); + auto *opStmt = cast(current); assert(opStmt->getNumResults() <= 1 && "NYI: multi-result"); if (opStmt->getNumResults() > 0) { for (auto &u : opStmt->getResult(0)->getUses()) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e6975ac5d09..a63723b333c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -145,7 +145,7 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, +bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, MemRefRegion *region) { OpPointer loadOp; OpPointer storeOp; @@ -204,7 +204,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, auto *symbol = accessValueMap.getOperand(i); assert(symbol->isValidSymbol()); // Check if the symbol is a constant. 
- if (auto *opStmt = symbol->getDefiningStmt()) { + if (auto *opStmt = symbol->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast()) { regionCst->setIdToConstant(*symbol, constOp->getValue()); } @@ -282,7 +282,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same>::value, "function argument should be either a LoadOp or a StoreOp"); - OperationStmt *opStmt = cast(loadOrStoreOp->getOperation()); + OperationInst *opStmt = cast(loadOrStoreOp->getOperation()); MemRefRegion region; if (!getMemRefRegion(opStmt, /*loopDepth=*/0, ®ion)) return false; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index ec19194f2fa..cd9451cd5e9 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -104,7 +104,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, /// header file. static AffineMap makePermutationMap( MLIRContext *context, - llvm::iterator_range indices, + llvm::iterator_range indices, const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; @@ -157,7 +157,7 @@ static SetVector getEnclosingForStmts(Statement *stmt) { } AffineMap -mlir::makePermutationMap(OperationStmt *opStmt, +mlir::makePermutationMap(OperationInst *opStmt, const DenseMap &loopToVectorDim) { DenseMap enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingForStmts(opStmt); @@ -178,7 +178,7 @@ mlir::makePermutationMap(OperationStmt *opStmt, enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt, +bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index e7abb899a11..e1de6191de6 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -51,7 +51,7 @@ namespace { /// class Verifier { public: - bool failure(const Twine &message, const Operation &value) { + bool failure(const Twine &message, const OperationInst &value) { return value.emitError(message); } @@ -62,15 +62,15 @@ public: bool failure(const Twine &message, const BasicBlock &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) - if (auto *op = dyn_cast(&bb.front())) + if (auto *op = dyn_cast(&bb.front())) return failure(message, *op); // Worst case, fall back to using the function's location. return failure(message, fn); } - bool verifyOperation(const Operation &op); - bool verifyAttribute(Attribute attr, const Operation &op); + bool verifyOperation(const OperationInst &op); + bool verifyAttribute(Attribute attr, const OperationInst &op); protected: explicit Verifier(const Function &fn) : fn(fn) {} @@ -82,7 +82,7 @@ private: } // end anonymous namespace // Check that function attributes are all well formed. -bool Verifier::verifyAttribute(Attribute attr, const Operation &op) { +bool Verifier::verifyAttribute(Attribute attr, const OperationInst &op) { if (!attr.isOrContainsFunction()) return false; @@ -109,9 +109,9 @@ bool Verifier::verifyAttribute(Attribute attr, const Operation &op) { return false; } -/// Check the invariants of the specified operation instruction or statement. -bool Verifier::verifyOperation(const Operation &op) { - if (op.getOperationFunction() != &fn) +/// Check the invariants of the specified operation. 
+bool Verifier::verifyOperation(const OperationInst &op) { + if (op.getFunction() != &fn) return failure("operation in the wrong function", op); // Check that operands are non-nil and structurally ok. @@ -245,7 +245,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker { MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {} - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { hadError |= verifyOperation(*opStmt); } @@ -302,14 +302,14 @@ bool MLFuncVerifier::verifyDominance() { if (!liveValues.count(opValue)) { stmt.emitError("operand #" + Twine(operandNo) + " does not dominate this use"); - if (auto *useStmt = opValue->getDefiningStmt()) + if (auto *useStmt = opValue->getDefiningInst()) useStmt->emitNote("operand defined here"); return true; } ++operandNo; } - if (auto *opStmt = dyn_cast(&stmt)) { + if (auto *opStmt = dyn_cast(&stmt)) { // Operations define values, add them to the hash table. for (auto *result : opStmt->getResults()) liveValues.insert(result, true); @@ -344,7 +344,7 @@ bool MLFuncVerifier::verifyReturn() { return failure(missingReturnMsg, fn); const auto &stmt = fn.getBody()->getStatements().back(); - if (const auto *op = dyn_cast(&stmt)) { + if (const auto *op = dyn_cast(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c44ce4d4d6c..9f465ab8507 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -120,10 +120,10 @@ private: void visitStatement(const Statement *stmt); void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); - void visitOperationStmt(const OperationStmt *opStmt); + void visitOperationInst(const OperationInst *opStmt); void visitType(Type type); void visitAttribute(Attribute attr); - void visitOperation(const Operation *op); + void visitOperation(const OperationInst *op); DenseMap affineMapIds; std::vector affineMapsById; @@ -161,7 +161,7 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitOperation(const Operation *op) { +void ModuleState::visitOperation(const OperationInst *op) { // Visit all the types used in the operation. for (auto *operand : op->getOperands()) visitType(operand->getType()); @@ -212,7 +212,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) { visitStatement(&childStmt); } -void ModuleState::visitOperationStmt(const OperationStmt *opStmt) { +void ModuleState::visitOperationInst(const OperationInst *opStmt) { for (auto attr : opStmt->getAttrs()) visitAttribute(attr.second); } @@ -223,8 +223,8 @@ void ModuleState::visitStatement(const Statement *stmt) { return visitIfStmt(cast(stmt)); case Statement::Kind::For: return visitForStmt(cast(stmt)); - case Statement::Kind::Operation: - return visitOperationStmt(cast(stmt)); + case Statement::Kind::OperationInst: + return visitOperationInst(cast(stmt)); default: return; } @@ -944,8 +944,8 @@ class FunctionPrinter : public ModulePrinter, private OpAsmPrinter { public: FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {} - void printOperation(const Operation *op); - void printDefaultOp(const Operation *op); + void printOperation(const OperationInst *op); + void printDefaultOp(const OperationInst *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -983,7 +983,7 @@ protected: llvm::raw_svector_ostream specialName(specialNameBuffer); // Give constant integers special names. 
- if (auto *op = value->getDefiningOperation()) { + if (auto *op = value->getDefiningInst()) { if (auto intOp = op->dyn_cast()) { // i1 constants get special names. if (intOp->getType().isInteger(1)) { @@ -1111,7 +1111,7 @@ private: }; } // end anonymous namespace -void FunctionPrinter::printOperation(const Operation *op) { +void FunctionPrinter::printOperation(const OperationInst *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1128,7 +1128,7 @@ void FunctionPrinter::printOperation(const Operation *op) { printDefaultOp(op); } -void FunctionPrinter::printDefaultOp(const Operation *op) { +void FunctionPrinter::printDefaultOp(const OperationInst *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1172,7 +1172,7 @@ public: void print(const Instruction *inst); - void printSuccessorAndUseList(const Operation *term, unsigned index); + void printSuccessorAndUseList(const OperationInst *term, unsigned index); void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); } @@ -1302,7 +1302,7 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) { os << ')'; } -void CFGFunctionPrinter::printSuccessorAndUseList(const Operation *term, +void CFGFunctionPrinter::printSuccessorAndUseList(const OperationInst *term, unsigned index) { printBBName(term->getSuccessor(index)); printBranchOperands(term->getSuccessorOperands(index)); @@ -1331,11 +1331,11 @@ public: // Methods to print ML function statements. void print(const Statement *stmt); - void print(const OperationStmt *stmt); + void print(const OperationInst *stmt); void print(const ForStmt *stmt); void print(const IfStmt *stmt); void print(const StmtBlock *block); - void printSuccessorAndUseList(const Operation *term, unsigned index) { + void printSuccessorAndUseList(const OperationInst *term, unsigned index) { assert(false && "MLFunctions do not have terminators with successors."); } @@ -1371,7 +1371,7 @@ void MLFunctionPrinter::numberValues() { // the first result of the operation statements. 
struct NumberValuesPass : public StmtWalker { NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {} - void visitOperationStmt(OperationStmt *stmt) { + void visitOperationInst(OperationInst *stmt) { if (stmt->getNumResults() != 0) printer->numberValueID(stmt->getResult(0)); } @@ -1421,8 +1421,8 @@ void MLFunctionPrinter::print(const StmtBlock *block) { void MLFunctionPrinter::print(const Statement *stmt) { switch (stmt->getKind()) { - case Statement::Kind::Operation: - return print(cast(stmt)); + case Statement::Kind::OperationInst: + return print(cast(stmt)); case Statement::Kind::For: return print(cast(stmt)); case Statement::Kind::If: @@ -1430,7 +1430,7 @@ void MLFunctionPrinter::print(const Statement *stmt) { } } -void MLFunctionPrinter::print(const OperationStmt *stmt) { +void MLFunctionPrinter::print(const OperationInst *stmt) { os.indent(numSpaces); printOperation(stmt); } @@ -1580,7 +1580,7 @@ void Value::print(raw_ostream &os) const { os << "\n"; return; case Value::Kind::StmtResult: - return getDefiningStmt()->print(os); + return getDefiningInst()->print(os); case Value::Kind::ForStmt: return cast(this)->print(os); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 0d3e54364b3..81a3b7c2950 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -290,7 +290,7 @@ StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) { } /// Create an operation given the fields represented as an OperationState. -OperationStmt *FuncBuilder::createOperation(const OperationState &state) { +OperationInst *FuncBuilder::createOperation(const OperationState &state) { auto *op = OperationInst::create(state.location, state.name, state.operands, state.types, state.attributes, state.successors, context); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 50ab254dd76..a87ae6b85f0 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -36,8 +36,8 @@ BuiltinDialect::BuiltinDialect(MLIRContext *context) addOperations(); } -void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin, - Operation::const_operand_iterator end, +void mlir::printDimAndSymbolList(OperationInst::const_operand_iterator begin, + OperationInst::const_operand_iterator end, unsigned numDims, OpAsmPrinter *p) { *p << '('; p->printOperands(begin, begin + numDims); @@ -188,14 +188,12 @@ void BranchOp::print(OpAsmPrinter *p) const { bool BranchOp::verify() const { // ML functions do not have branching terminators. - if (getOperation()->getOperationFunction()->isML()) + if (getOperation()->getFunction()->isML()) return (emitOpError("cannot occur in a ML function"), true); return false; } -BasicBlock *BranchOp::getDest() const { - return getOperation()->getSuccessor(0); -} +BasicBlock *BranchOp::getDest() { return getOperation()->getSuccessor(0); } void BranchOp::setDest(BasicBlock *block) { return getOperation()->setSuccessor(block, 0); @@ -258,18 +256,18 @@ void CondBranchOp::print(OpAsmPrinter *p) const { bool CondBranchOp::verify() const { // ML functions do not have branching terminators. 
- if (getOperation()->getOperationFunction()->isML()) + if (getOperation()->getFunction()->isML()) return (emitOpError("cannot occur in a ML function"), true); if (!getCondition()->getType().isInteger(1)) return emitOpError("expected condition type was boolean (i1)"); return false; } -BasicBlock *CondBranchOp::getTrueDest() const { +BasicBlock *CondBranchOp::getTrueDest() { return getOperation()->getSuccessor(trueIndex); } -BasicBlock *CondBranchOp::getFalseDest() const { +BasicBlock *CondBranchOp::getFalseDest() { return getOperation()->getSuccessor(falseIndex); } @@ -399,13 +397,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState *result, ConstantOp::build(builder, result, builder->getFloatAttr(type, value), type); } -bool ConstantFloatOp::isClassFor(const Operation *op) { +bool ConstantFloatOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isa(); } /// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::isClassFor(const Operation *op) { +bool ConstantIntOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isa(); } @@ -427,7 +425,7 @@ void ConstantIntOp::build(Builder *builder, OperationState *result, } /// ConstantIndexOp only matches values whose result type is Index. -bool ConstantIndexOp::isClassFor(const Operation *op) { +bool ConstantIndexOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); } @@ -470,7 +468,7 @@ void ReturnOp::print(OpAsmPrinter *p) const { } bool ReturnOp::verify() const { - auto *function = cast(getOperation())->getFunction(); + auto *function = cast(getOperation())->getFunction(); // The operand number and types must match the function signature. const auto &results = function->getType().getResults(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 62f1dca067d..19b137071f4 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -161,34 +161,34 @@ bool Function::emitError(const Twine &message) const { // MLFunction implementation. 
//===----------------------------------------------------------------------===// -const OperationStmt *MLFunction::getReturnStmt() const { - return cast(&getBody()->back()); +const OperationInst *MLFunction::getReturnStmt() const { + return cast(&getBody()->back()); } -OperationStmt *MLFunction::getReturnStmt() { - return cast(&getBody()->back()); +OperationInst *MLFunction::getReturnStmt() { + return cast(&getBody()->back()); } -void MLFunction::walk(std::function callback) { +void MLFunction::walk(std::function callback) { struct Walker : public StmtWalker { - std::function const &callback; - Walker(std::function const &callback) + std::function const &callback; + Walker(std::function const &callback) : callback(callback) {} - void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } }; Walker v(callback); v.walk(this); } -void MLFunction::walkPostOrder(std::function callback) { +void MLFunction::walkPostOrder(std::function callback) { struct Walker : public StmtWalker { - std::function const &callback; - Walker(std::function const &callback) + std::function const &callback; + Walker(std::function const &callback) : callback(callback) {} - void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } }; Walker v(callback); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index da0bc4b1595..abc3e1cfda4 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -450,7 +450,7 @@ auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy { /// This emits a diagnostic using the registered issue handle if present, or /// with the default behavior if not. The MLIR compiler should not generally -/// interact with this, it should use methods on Operation instead. +/// interact with this, it should use methods on OperationInst instead. void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message, DiagnosticKind kind) const { // Check to see if we are emitting a diagnostic on a fused location. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 0526c6ea610..6a9b37560db 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1,4 +1,4 @@ -//===- Operation.cpp - MLIR Operation Class -------------------------------===// +//===- Operation.cpp - Operation support code -----------------------------===// // // Copyright 2019 The MLIR Authors. // @@ -15,8 +15,6 @@ // limitations under the License. 
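The MLFunction::walk overloads above now traverse OperationInsts. A minimal, hedged usage sketch, mirroring the updated call site in MLPatternLoweringPass.h earlier in this patch; `f` is assumed to be an `MLFunction *`.

  // Collect every operation instruction in an MLFunction.
  SmallVector<OperationInst *, 8> ops;
  f->walk([&ops](OperationInst *opInst) { ops.push_back(opInst); });
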
// ============================================================================= -#include "mlir/IR/Operation.h" -#include "AttributeListStorage.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/MLIRContext.h" @@ -52,200 +50,6 @@ OperationName OperationName::getFromOpaquePointer(void *pointer) { OpAsmParser::~OpAsmParser() {} -//===----------------------------------------------------------------------===// -// Operation class -//===----------------------------------------------------------------------===// - -Operation::Operation(OperationName name, ArrayRef attrs, - Location location, MLIRContext *context) - : Statement(Kind::Operation, location), name(name) { - this->attrs = AttributeListStorage::get(attrs, context); - -#ifndef NDEBUG - for (auto elt : attrs) - assert(elt.second != nullptr && "Attributes cannot have null entries"); -#endif -} - -Operation::~Operation() {} - - -/// Return the function this operation is defined in. -Function *Operation::getOperationFunction() { - return llvm::cast(this)->getFunction(); -} - -/// Return the number of results this operation has. -unsigned Operation::getNumResults() const { - return llvm::cast(this)->getNumResults(); -} - -/// Return the indicated result. -Value *Operation::getResult(unsigned idx) { - return llvm::cast(this)->getResult(idx); -} - -unsigned Operation::getNumSuccessors() const { - assert(isTerminator() && "Only terminators have successors."); - return llvm::cast(this)->getNumSuccessors(); -} - -unsigned Operation::getNumSuccessorOperands(unsigned index) const { - assert(isTerminator() && "Only terminators have successors."); - return llvm::cast(this)->getNumSuccessorOperands(index); -} -BasicBlock *Operation::getSuccessor(unsigned index) { - assert(isTerminator() && "Only terminators have successors"); - return llvm::cast(this)->getSuccessor(index); -} -void Operation::setSuccessor(BasicBlock *block, unsigned index) { - assert(isTerminator() && "Only terminators have successors"); - llvm::cast(this)->setSuccessor(block, index); -} - -void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { - assert(isTerminator() && "Only terminators have successors"); - return llvm::cast(this)->eraseSuccessorOperand(succIndex, - opIndex); -} -auto Operation::getSuccessorOperands(unsigned index) const - -> llvm::iterator_range { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = - llvm::cast(this)->getSuccessorOperandIndex(index); - return {const_operand_iterator(this, succOperandIndex), - const_operand_iterator(this, succOperandIndex + - getNumSuccessorOperands(index))}; -} -auto Operation::getSuccessorOperands(unsigned index) - -> llvm::iterator_range { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = - llvm::cast(this)->getSuccessorOperandIndex(index); - return {operand_iterator(this, succOperandIndex), - operand_iterator(this, - succOperandIndex + getNumSuccessorOperands(index))}; -} - -/// Return true if there are no users of any results of this operation. -bool Operation::use_empty() const { - for (auto *result : getResults()) - if (!result->use_empty()) - return false; - return true; -} - -ArrayRef Operation::getAttrs() const { - if (!attrs) - return {}; - return attrs->getElements(); -} - -/// If an attribute exists with the specified name, change it to the new -/// value. Otherwise, add a new attribute with the specified name/value. 
-void Operation::setAttr(Identifier name, Attribute value) { - assert(value && "attributes may never be null"); - auto origAttrs = getAttrs(); - - SmallVector newAttrs(origAttrs.begin(), origAttrs.end()); - auto *context = getContext(); - - // If we already have this attribute, replace it. - for (auto &elt : newAttrs) - if (elt.first == name) { - elt.second = value; - attrs = AttributeListStorage::get(newAttrs, context); - return; - } - - // Otherwise, add it. - newAttrs.push_back({name, value}); - attrs = AttributeListStorage::get(newAttrs, context); -} - -/// Remove the attribute with the specified name if it exists. The return -/// value indicates whether the attribute was present or not. -auto Operation::removeAttr(Identifier name) -> RemoveResult { - auto origAttrs = getAttrs(); - for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { - if (origAttrs[i].first == name) { - SmallVector newAttrs; - newAttrs.reserve(origAttrs.size() - 1); - newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); - newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = AttributeListStorage::get(newAttrs, getContext()); - return RemoveResult::Removed; - } - } - return RemoveResult::NotFound; -} - -/// Emit a note about this operation, reporting up to any diagnostic -/// handlers that may be listening. -void Operation::emitNote(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Note); -} - -/// Emit a warning about this operation, reporting up to any diagnostic -/// handlers that may be listening. -void Operation::emitWarning(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Warning); -} - -/// Emit an error about fatal conditions with this operation, reporting up to -/// any diagnostic handlers that may be listening. This function always returns -/// true. NOTE: This may terminate the containing application, only use when -/// the IR is in an inconsistent state. -bool Operation::emitError(const Twine &message) const { - return getContext()->emitError(getLoc(), message); -} - -/// Emit an error with the op name prefixed, like "'dim' op " which is -/// convenient for verifiers. -bool Operation::emitOpError(const Twine &message) const { - return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); -} - -/// Remove this operation from its parent block and delete it. -void Operation::erase() { - return llvm::cast(this)->erase(); -} - -/// Attempt to constant fold this operation with the specified constant -/// operand values. If successful, this returns false and fills in the -/// results vector. If not, this returns true and results is unspecified. -bool Operation::constantFold(ArrayRef operands, - SmallVectorImpl &results) const { - if (auto *abstractOp = getAbstractOperation()) { - // If we have a registered operation definition matching this one, use it to - // try to constant fold the operation. - if (!abstractOp->constantFoldHook(this, operands, results)) - return false; - - // Otherwise, fall back on the dialect hook to handle it. - return abstractOp->dialect.constantFoldHook(this, operands, results); - } - - // If this operation hasn't been registered or doesn't have abstract - // operation, fall back to a dialect which matches the prefix. 
- auto opName = getName().getStringRef(); - if (auto *dialect = getContext()->getRegisteredDialect(opName)) { - return dialect->constantFoldHook(this, operands, results); - } - - return true; -} - -/// Methods for support type inquiry through isa, cast, and dyn_cast. -bool Operation::classof(const Statement *stmt) { - return stmt->getKind() == Statement::Kind::Operation; -} -bool Operation::classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::OperationStmt; -} - //===----------------------------------------------------------------------===// // OpState trait class. //===----------------------------------------------------------------------===// @@ -290,19 +94,20 @@ void OpState::emitNote(const Twine &message) const { // Op Trait implementations //===----------------------------------------------------------------------===// -bool OpTrait::impl::verifyZeroOperands(const Operation *op) { +bool OpTrait::impl::verifyZeroOperands(const OperationInst *op) { if (op->getNumOperands() != 0) return op->emitOpError("requires zero operands"); return false; } -bool OpTrait::impl::verifyOneOperand(const Operation *op) { +bool OpTrait::impl::verifyOneOperand(const OperationInst *op) { if (op->getNumOperands() != 1) return op->emitOpError("requires a single operand"); return false; } -bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { +bool OpTrait::impl::verifyNOperands(const OperationInst *op, + unsigned numOperands) { if (op->getNumOperands() != numOperands) { return op->emitOpError("expected " + Twine(numOperands) + " operands, but found " + @@ -311,7 +116,7 @@ bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { return false; } -bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op, +bool OpTrait::impl::verifyAtLeastNOperands(const OperationInst *op, unsigned numOperands) { if (op->getNumOperands() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -331,7 +136,7 @@ static Type getTensorOrVectorElementType(Type type) { return type; } -bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) { +bool OpTrait::impl::verifyOperandsAreIntegerLike(const OperationInst *op) { for (auto *operand : op->getOperands()) { auto type = getTensorOrVectorElementType(operand->getType()); if (!type.isIntOrIndex()) @@ -340,7 +145,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) { return false; } -bool OpTrait::impl::verifySameTypeOperands(const Operation *op) { +bool OpTrait::impl::verifySameTypeOperands(const OperationInst *op) { // Zero or one operand always have the "same" type. 
unsigned nOperands = op->getNumOperands(); if (nOperands < 2) @@ -354,25 +159,26 @@ bool OpTrait::impl::verifySameTypeOperands(const Operation *op) { return false; } -bool OpTrait::impl::verifyZeroResult(const Operation *op) { +bool OpTrait::impl::verifyZeroResult(const OperationInst *op) { if (op->getNumResults() != 0) return op->emitOpError("requires zero results"); return false; } -bool OpTrait::impl::verifyOneResult(const Operation *op) { +bool OpTrait::impl::verifyOneResult(const OperationInst *op) { if (op->getNumResults() != 1) return op->emitOpError("requires one result"); return false; } -bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) { +bool OpTrait::impl::verifyNResults(const OperationInst *op, + unsigned numOperands) { if (op->getNumResults() != numOperands) return op->emitOpError("expected " + Twine(numOperands) + " results"); return false; } -bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, +bool OpTrait::impl::verifyAtLeastNResults(const OperationInst *op, unsigned numOperands) { if (op->getNumResults() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -401,7 +207,7 @@ static bool verifyShapeMatch(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) { +bool OpTrait::impl::verifySameOperandsAndResultShape(const OperationInst *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -419,7 +225,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) { +bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -438,8 +244,8 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) { } static bool verifyBBArguments( - llvm::iterator_range operands, - const BasicBlock *destBB, const Operation *op) { + llvm::iterator_range operands, + const BasicBlock *destBB, const OperationInst *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -455,9 +261,9 @@ static bool verifyBBArguments( return false; } -static bool verifyTerminatorSuccessors(const Operation *op) { +static bool verifyTerminatorSuccessors(const OperationInst *op) { // Verify that the operands lines up with the BB arguments in the successor. - const Function *fn = op->getOperationFunction(); + const Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { auto *succ = op->getSuccessor(i); if (succ->getFunction() != fn) @@ -468,17 +274,15 @@ static bool verifyTerminatorSuccessors(const Operation *op) { return false; } -bool OpTrait::impl::verifyIsTerminator(const Operation *op) { +bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { // Verify that the operation is at the end of the respective parent block. 
- if (op->getOperationFunction()->isML()) { - auto *stmt = cast(op); - StmtBlock *block = stmt->getBlock(); - if (!block || block->getContainingStmt() || &block->back() != stmt) + if (op->getFunction()->isML()) { + StmtBlock *block = op->getBlock(); + if (!block || block->getContainingStmt() || &block->back() != op) return op->emitOpError("must be the last statement in the ML function"); } else { - auto *inst = cast(op); - const BasicBlock *block = inst->getBlock(); - if (!block || &block->back() != inst) + const BasicBlock *block = op->getBlock(); + if (!block || &block->back() != op) return op->emitOpError( "must be the last instruction in the parent basic block."); } @@ -489,7 +293,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); auto intType = elementType.dyn_cast(); @@ -501,7 +305,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreFloatLike(const OperationInst *op) { for (auto *result : op->getResults()) { if (!getTensorOrVectorElementType(result->getType()).isa()) return op->emitOpError("requires a floating point type"); @@ -510,7 +314,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreIntegerLike(const OperationInst *op) { for (auto *result : op->getResults()) { auto type = getTensorOrVectorElementType(result->getType()); if (!type.isIntOrIndex()) @@ -543,7 +347,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types); } -void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { +void impl::printBinaryOp(const OperationInst *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p->printOptionalAttrDict(op->getAttrs()); @@ -569,7 +373,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(dstType, result->types); } -void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { +void impl::printCastOp(const OperationInst *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 9e4d8bb180c..8c41d488a8b 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -58,12 +58,14 @@ void Pattern::anchor() {} // RewritePattern and PatternRewriter implementation //===----------------------------------------------------------------------===// -void RewritePattern::rewrite(Operation *op, std::unique_ptr state, +void RewritePattern::rewrite(OperationInst *op, + std::unique_ptr state, PatternRewriter &rewriter) const { rewrite(op, rewriter); } -void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { +void RewritePattern::rewrite(OperationInst *op, + PatternRewriter &rewriter) const { llvm_unreachable("need to implement one of the rewrite functions!"); } @@ -77,7 +79,7 @@ PatternRewriter::~PatternRewriter() { /// clients can specify a list of other nodes 
that this replacement may make /// (perhaps transitively) dead. If any of those ops are dead, this will /// remove them as well. -void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, +void PatternRewriter::replaceOp(OperationInst *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootReplaced(op); @@ -97,7 +99,8 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Operation *op, Operation *newOp, ArrayRef valuesToRemoveIfDead) { + OperationInst *op, OperationInst *newOp, + ArrayRef valuesToRemoveIfDead) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) @@ -117,7 +120,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp( /// should remove if they are dead at this point. /// void PatternRewriter::updatedRootInPlace( - Operation *op, ArrayRef valuesToRemoveIfDead) { + OperationInst *op, ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootUpdated(op); @@ -132,7 +135,7 @@ void PatternRewriter::updatedRootInPlace( /// Find the highest benefit pattern available in the pattern set for the DAG /// rooted at the specified node. This returns the pattern if found, or null /// if there are no matches. -auto PatternMatcher::findMatch(Operation *op) -> MatchResult { +auto PatternMatcher::findMatch(OperationInst *op) -> MatchResult { // TODO: This is a completely trivial implementation, expand this in the // future. diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 8bff23d41ed..19457efa8c3 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "AttributeListStorage.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -65,8 +66,8 @@ Statement::~Statement() { /// Destroy this statement or one of its subclasses. void Statement::destroy() { switch (this->getKind()) { - case Kind::Operation: - cast(this)->destroy(); + case Kind::OperationInst: + cast(this)->destroy(); break; case Kind::For: delete cast(this); @@ -95,7 +96,7 @@ const Value *Statement::getOperand(unsigned idx) const { // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. bool Value::isValidDim() const { - if (auto *stmt = getDefiningStmt()) { + if (auto *stmt = getDefiningInst()) { // Top level statement or constant operation is ok. if (stmt->getParentStmt() == nullptr || stmt->isa()) return true; @@ -113,7 +114,7 @@ bool Value::isValidDim() const { // the top level, or it is a result of affine apply operation with symbol // arguments. bool Value::isValidSymbol() const { - if (auto *stmt = getDefiningStmt()) { + if (auto *stmt = getDefiningInst()) { // Top level statement or constant operation is ok. 
if (stmt->getParentStmt() == nullptr || stmt->isa()) return true; @@ -133,8 +134,8 @@ void Statement::setOperand(unsigned idx, Value *value) { unsigned Statement::getNumOperands() const { switch (getKind()) { - case Kind::Operation: - return cast(this)->getNumOperands(); + case Kind::OperationInst: + return cast(this)->getNumOperands(); case Kind::For: return cast(this)->getNumOperands(); case Kind::If: @@ -144,8 +145,8 @@ unsigned Statement::getNumOperands() const { MutableArrayRef Statement::getStmtOperands() { switch (getKind()) { - case Kind::Operation: - return cast(this)->getStmtOperands(); + case Kind::OperationInst: + return cast(this)->getStmtOperands(); case Kind::For: return cast(this)->getStmtOperands(); case Kind::If: @@ -177,7 +178,7 @@ bool Statement::emitError(const Twine &message) const { // Returns whether the Statement is a terminator. bool Statement::isTerminator() const { - if (auto *op = dyn_cast(this)) + if (auto *op = dyn_cast(this)) return op->isTerminator(); return false; } @@ -264,11 +265,11 @@ void Statement::dropAllReferences() { } //===----------------------------------------------------------------------===// -// OperationStmt +// OperationInst //===----------------------------------------------------------------------===// -/// Create a new OperationStmt with the specific fields. -OperationStmt *OperationStmt::create(Location location, OperationName name, +/// Create a new OperationInst with the specific fields. +OperationInst *OperationInst::create(Location location, OperationName name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, @@ -285,9 +286,9 @@ OperationStmt *OperationStmt::create(Location location, OperationName name, resultTypes.size(), numSuccessors, numSuccessors, numOperands); void *rawMem = malloc(byteSize); - // Initialize the OperationStmt part of the statement. + // Initialize the OperationInst part of the statement. auto stmt = ::new (rawMem) - OperationStmt(location, name, numOperands, resultTypes.size(), + OperationInst(location, name, numOperands, resultTypes.size(), numSuccessors, attributes, context); // Initialize the results and operands. @@ -355,15 +356,22 @@ OperationStmt *OperationStmt::create(Location location, OperationName name, return stmt; } -OperationStmt::OperationStmt(Location location, OperationName name, +OperationInst::OperationInst(Location location, OperationName name, unsigned numOperands, unsigned numResults, unsigned numSuccessors, ArrayRef attributes, MLIRContext *context) - : Operation(name, attributes, location, context), numOperands(numOperands), - numResults(numResults), numSuccs(numSuccessors) {} + : Statement(Kind::OperationInst, location), numOperands(numOperands), + numResults(numResults), numSuccs(numSuccessors), name(name) { +#ifndef NDEBUG + for (auto elt : attributes) + assert(elt.second != nullptr && "Attributes cannot have null entries"); +#endif -OperationStmt::~OperationStmt() { + this->attrs = AttributeListStorage::get(attributes, context); +} + +OperationInst::~OperationInst() { // Explicitly run the destructors for the operands and results. for (auto &operand : getStmtOperands()) operand.~StmtOperand(); @@ -377,13 +385,27 @@ OperationStmt::~OperationStmt() { successor.~StmtBlockOperand(); } -void OperationStmt::destroy() { - this->~OperationStmt(); +/// Return true if there are no users of any results of this operation. 
+bool OperationInst::use_empty() const { + for (auto *result : getResults()) + if (!result->use_empty()) + return false; + return true; +} + +ArrayRef OperationInst::getAttrs() const { + if (!attrs) + return {}; + return attrs->getElements(); +} + +void OperationInst::destroy() { + this->~OperationInst(); free(this); } /// Return the context this operation is associated with. -MLIRContext *OperationStmt::getContext() const { +MLIRContext *OperationInst::getContext() const { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) @@ -396,9 +418,9 @@ MLIRContext *OperationStmt::getContext() const { return getFunction()->getContext(); } -bool OperationStmt::isReturn() const { return isa(); } +bool OperationInst::isReturn() const { return isa(); } -void OperationStmt::setSuccessor(BasicBlock *block, unsigned index) { +void OperationInst::setSuccessor(BasicBlock *block, unsigned index) { assert(index < getNumSuccessors()); getBlockOperands()[index].set(block); } @@ -413,6 +435,96 @@ void OperationInst::eraseOperand(unsigned index) { Operands[getNumOperands()].~StmtOperand(); } +auto OperationInst::getSuccessorOperands(unsigned index) const + -> llvm::iterator_range { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {const_operand_iterator(this, succOperandIndex), + const_operand_iterator(this, succOperandIndex + + getNumSuccessorOperands(index))}; +} +auto OperationInst::getSuccessorOperands(unsigned index) + -> llvm::iterator_range { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {operand_iterator(this, succOperandIndex), + operand_iterator(this, + succOperandIndex + getNumSuccessorOperands(index))}; +} + +/// If an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +void OperationInst::setAttr(Identifier name, Attribute value) { + assert(value && "attributes may never be null"); + auto origAttrs = getAttrs(); + + SmallVector newAttrs(origAttrs.begin(), origAttrs.end()); + auto *context = getContext(); + + // If we already have this attribute, replace it. + for (auto &elt : newAttrs) + if (elt.first == name) { + elt.second = value; + attrs = AttributeListStorage::get(newAttrs, context); + return; + } + + // Otherwise, add it. + newAttrs.push_back({name, value}); + attrs = AttributeListStorage::get(newAttrs, context); +} + +/// Remove the attribute with the specified name if it exists. The return +/// value indicates whether the attribute was present or not. +auto OperationInst::removeAttr(Identifier name) -> RemoveResult { + auto origAttrs = getAttrs(); + for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { + if (origAttrs[i].first == name) { + SmallVector newAttrs; + newAttrs.reserve(origAttrs.size() - 1); + newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); + newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); + attrs = AttributeListStorage::get(newAttrs, getContext()); + return RemoveResult::Removed; + } + } + return RemoveResult::NotFound; +} + +/// Attempt to constant fold this operation with the specified constant +/// operand values. If successful, this returns false and fills in the +/// results vector. If not, this returns true and results is unspecified. 
+bool OperationInst::constantFold(ArrayRef operands, + SmallVectorImpl &results) const { + if (auto *abstractOp = getAbstractOperation()) { + // If we have a registered operation definition matching this one, use it to + // try to constant fold the operation. + if (!abstractOp->constantFoldHook(llvm::cast(this), operands, + results)) + return false; + + // Otherwise, fall back on the dialect hook to handle it. + return abstractOp->dialect.constantFoldHook(llvm::cast(this), + operands, results); + } + + // If this operation hasn't been registered or doesn't have abstract + // operation, fall back to a dialect which matches the prefix. + auto opName = getName().getStringRef(); + if (auto *dialect = getContext()->getRegisteredDialect(opName)) { + return dialect->constantFoldHook(llvm::cast(this), operands, + results); + } + + return true; +} + +/// Emit an error with the op name prefixed, like "'dim' op " which is +/// convenient for verifiers. +bool OperationInst::emitOpError(const Twine &message) const { + return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); +} + //===----------------------------------------------------------------------===// // ForStmt //===----------------------------------------------------------------------===// @@ -625,7 +737,7 @@ Statement *Statement::clone(DenseMap &operandMap, SmallVector operands; SmallVector successors; - if (auto *opStmt = dyn_cast(this)) { + if (auto *opStmt = dyn_cast(this)) { operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); if (!opStmt->isTerminator()) { @@ -653,8 +765,8 @@ Statement *Statement::clone(DenseMap &operandMap, operands.push_back(nullptr); // Remap the successors operands. - for (auto &operand : opStmt->getSuccessorOperands(succ)) - operands.push_back(remapOperand(operand.get())); + for (auto *operand : opStmt->getSuccessorOperands(succ)) + operands.push_back(remapOperand(operand)); } } @@ -662,7 +774,7 @@ Statement *Statement::clone(DenseMap &operandMap, resultTypes.reserve(opStmt->getNumResults()); for (auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); - auto *newOp = OperationStmt::create(getLoc(), opStmt->getName(), operands, + auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), successors, context); // Remember the mapping of any results. diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index a50861a3060..cfb09e6bf45 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -100,13 +100,13 @@ void StmtBlock::eraseArgument(unsigned index) { // Terminator management //===----------------------------------------------------------------------===// -OperationStmt *StmtBlock::getTerminator() { +OperationInst *StmtBlock::getTerminator() { if (empty()) return nullptr; // Check if the last instruction is a terminator. auto &backInst = statements.back(); - auto *opStmt = dyn_cast(&backInst); + auto *opStmt = dyn_cast(&backInst); if (!opStmt || !opStmt->isTerminator()) return nullptr; return opStmt; diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index db58e126e61..41a6d80e2a2 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -28,29 +28,13 @@ OperationInst *Value::getDefiningInst() { return nullptr; } -/// If this value is the result of an OperationStmt, return the statement -/// that defines it. 
-OperationStmt *Value::getDefiningStmt() { - if (auto *result = dyn_cast(this)) - return result->getOwner(); - return nullptr; -} - -Operation *Value::getDefiningOperation() { - if (auto *inst = getDefiningInst()) - return inst; - if (auto *stmt = getDefiningStmt()) - return stmt; - return nullptr; -} - -/// Return the function that this Valueis defined in. +/// Return the function that this Value is defined in. Function *Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); case Value::Kind::StmtResult: - return getDefiningStmt()->getFunction(); + return getDefiningInst()->getFunction(); case Value::Kind::ForStmt: return cast(this)->getFunction(); } @@ -73,8 +57,8 @@ void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) { /// Return the context this operation is associated with. MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { - case Kind::OperationStmt: - return cast(this)->getContext(); + case Kind::OperationInst: + return cast(this)->getContext(); case Kind::ForStmt: return cast(this)->getContext(); case Kind::IfStmt: diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 06495eb81ab..35891f5784b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -103,7 +103,7 @@ private: namespace { using CreateOperationFunction = - std::function; + std::function; /// This class implements support for parsing global entities like types and /// shared entities like SSA names. It is intended to be subclassed by @@ -1915,8 +1915,10 @@ public: // Operations ParseResult parseOperation(const CreateOperationFunction &createOpFunc); - Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc); - Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc); + OperationInst * + parseVerboseOperation(const CreateOperationFunction &createOpFunc); + OperationInst * + parseCustomOperation(const CreateOperationFunction &createOpFunc); /// Parse a single operation successor and its operand list. virtual bool parseSuccessorAndUseList(BasicBlock *&dest, @@ -2184,7 +2186,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) { return ParseFailure; } - Operation *op; + OperationInst *op; if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) op = parseCustomOperation(createOpFunc); else if (getToken().is(Token::string)) @@ -2220,7 +2222,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) { return ParseSuccess; } -Operation *FunctionParser::parseVerboseOperation( +OperationInst *FunctionParser::parseVerboseOperation( const CreateOperationFunction &createOpFunc) { // Get location information for the operation. @@ -2516,7 +2518,7 @@ private: }; } // end anonymous namespace. -Operation *FunctionParser::parseCustomOperation( +OperationInst *FunctionParser::parseCustomOperation( const CreateOperationFunction &createOpFunc) { auto opLoc = getToken().getLoc(); auto opName = getTokenSpelling(); @@ -2746,7 +2748,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() { // into.
builder.setInsertionPointToEnd(block); - auto createOpFunc = [&](const OperationState &result) -> Operation * { + auto createOpFunc = [&](const OperationState &result) -> OperationInst * { return builder.createOperation(result); }; @@ -3149,7 +3151,7 @@ ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { /// Parse a list of statements ending with `return` or `}` /// ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { - auto createOpFunc = [&](const OperationState &state) -> Operation * { + auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 7611c6e741b..19a4c8d1afe 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -56,7 +56,7 @@ struct MemRefCastFolder : public RewritePattern { MemRefCastFolder(StringRef rootOpName, MLIRContext *context) : RewritePattern(rootOpName, 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { for (auto *operand : op->getOperands()) if (matchPattern(operand, m_Op())) return matchSuccess(); @@ -64,9 +64,9 @@ struct MemRefCastFolder : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOperation()) + if (auto *memref = op->getOperand(i)->getDefiningInst()) if (auto cast = memref->dyn_cast()) op->setOperand(i, cast->getOperand()); rewriter.updatedRootInPlace(op); @@ -122,7 +122,7 @@ struct SimplifyAddX0 : public RewritePattern { SimplifyAddX0(MLIRContext *context) : RewritePattern(AddIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto addi = op->cast(); if (matchPattern(addi->getOperand(1), m_Zero())) @@ -130,7 +130,7 @@ struct SimplifyAddX0 : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { rewriter.replaceOp(op, op->getOperand(0)); } }; @@ -228,7 +228,7 @@ struct SimplifyAllocConst : public RewritePattern { SimplifyAllocConst(MLIRContext *context) : RewritePattern(AllocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto alloc = op->cast(); // Check to see if any dimensions operands are constants. If so, we can @@ -239,7 +239,7 @@ struct SimplifyAllocConst : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { auto allocOp = op->cast(); auto memrefType = allocOp->getType(); @@ -258,7 +258,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation(); + auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); OpPointer constantIndexOp; if (defOp && (constantIndexOp = defOp->dyn_cast())) { // Dynamic shape dimension will be folded. 
@@ -1105,7 +1105,7 @@ struct SimplifyMulX1 : public RewritePattern { SimplifyMulX1(MLIRContext *context) : RewritePattern(MulIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto muli = op->cast(); if (matchPattern(muli->getOperand(1), m_One())) @@ -1113,7 +1113,7 @@ struct SimplifyMulX1 : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { rewriter.replaceOp(op, op->getOperand(0)); } }; @@ -1308,14 +1308,14 @@ struct SimplifyXMinusX : public RewritePattern { SimplifyXMinusX(MLIRContext *context) : RewritePattern(SubIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto subi = op->cast(); if (subi->getOperand(0) == subi->getOperand(1)) return matchSuccess(); return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { auto subi = op->cast(); auto result = rewriter.create(op->getLoc(), 0, subi->getType()); diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 02b4c4674ab..e4243a6de25 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -86,14 +86,14 @@ void VectorTransferReadOp::build(Builder *builder, OperationState *result, result->addTypes(vectorType); } -llvm::iterator_range +llvm::iterator_range VectorTransferReadOp::getIndices() { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } -llvm::iterator_range +llvm::iterator_range VectorTransferReadOp::getIndices() const { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); @@ -303,14 +303,14 @@ void VectorTransferWriteOp::build(Builder *builder, OperationState *result, builder->getAffineMapAttr(permutationMap)); } -llvm::iterator_range +llvm::iterator_range VectorTransferWriteOp::getIndices() { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } -llvm::iterator_range +llvm::iterator_range VectorTransferWriteOp::getIndices() const { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index a4d474dc24a..713aa0b1791 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -111,8 +111,8 @@ private: /// descriptor and get the pointer to the element indexed by the linearized /// subscript. Return nullptr on errors. llvm::Value *emitMemRefElementAccess( - const Value *memRef, const Operation &op, - llvm::iterator_range opIndices); + const Value *memRef, const OperationInst &op, + llvm::iterator_range opIndices); /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create /// a Value for the MemRef descriptor, store any dynamic sizes passed to @@ -307,7 +307,7 @@ ModuleLowerer::linearizeSubscripts(ArrayRef indices, // the location of `op` and return true. 
Return false if the type is supported. // TODO(zinenko): this function should disappear when the conversion fully // supports MemRefs. -static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { +static bool checkSupportedMemRefType(MemRefType type, const OperationInst &op) { if (!type.getAffineMaps().empty()) return op.emitError("NYI: memrefs with affine maps"); if (type.getMemorySpace() != 0) @@ -316,8 +316,8 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { } llvm::Value *ModuleLowerer::emitMemRefElementAccess( - const Value *memRef, const Operation &op, - llvm::iterator_range opIndices) { + const Value *memRef, const OperationInst &op, + llvm::iterator_range opIndices) { auto type = memRef->getType().dyn_cast(); assert(type && "expected memRef value to have a MemRef type"); if (checkSupportedMemRefType(type, op)) @@ -425,7 +425,7 @@ ModuleLowerer::emitMemRefDealloc(ConstOpPointer deallocOp) { // This forcibly recreates the APFloat with IEEESingle semantics to make sure // LLVM constructs a `float` constant. static llvm::ConstantFP *getFloatConstant(APFloat APvalue, - const Operation &inst, + const OperationInst &inst, llvm::LLVMContext *context) { bool unused; APFloat::opStatus status = APvalue.convert( diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 575ae2e1c9b..4b198589e2c 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -50,10 +50,10 @@ struct CSE : public FunctionPass { }; // TODO(riverriddle) Handle commutative operations. -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Operation *op) { +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const OperationInst *op) { // Hash the operations based upon their: - // - Operation Name + // - OperationInst Name // - Attributes // - Result Types // - Operands @@ -62,7 +62,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const Operation *lhs, const Operation *rhs) { + static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) { if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -93,8 +93,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { struct CSEImpl { using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; - using ScopedMapTy = llvm::ScopedHashTable>; + using ScopedMapTy = llvm::ScopedHashTable; /// Erase any operations that were marked as dead during simplification. @@ -104,7 +104,7 @@ struct CSEImpl { } /// Attempt to eliminate a redundant operation. - void simplifyOperation(Operation *op) { + void simplifyOperation(OperationInst *op) { // TODO(riverriddle) We currently only eliminate non side-effecting // operations. if (!op->hasNoSideEffect()) @@ -141,7 +141,7 @@ struct CSEImpl { ScopedMapTy knownValues; /// Operations marked as dead and to be erased. - std::vector opsToErase; + std::vector opsToErase; }; /// Common sub-expression elimination for CFG functions. 
@@ -224,7 +224,7 @@ struct MLCSE : public CSEImpl, StmtWalker { StmtWalker::walk(Start, End); } - void visitOperationStmt(OperationStmt *stmt) { simplifyOperation(stmt); } + void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 84507b91703..365533561f9 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -42,12 +42,12 @@ namespace { // with no remaining uses are collected and erased after the walk. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. struct ComposeAffineMaps : public FunctionPass, StmtWalker { - std::vector affineApplyOpsToErase; + std::vector affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} using StmtListType = llvm::iplist; void walk(StmtListType::iterator Start, StmtListType::iterator End); - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); PassResult runOnMLFunction(MLFunction *f) override; using StmtWalker::walk; @@ -72,7 +72,7 @@ void ComposeAffineMaps::walk(StmtListType::iterator Start, } } -void ComposeAffineMaps::visitOperationStmt(OperationStmt *opStmt) { +void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { if (auto affineApplyOp = opStmt->dyn_cast()) { forwardSubstitute(affineApplyOp); bool allUsesEmpty = true; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index b6b1dec7b17..a83e625c240 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -31,13 +31,14 @@ struct ConstantFold : public FunctionPass, StmtWalker { // All constants in the function post folding. SmallVector existingConstants; - // Operation statements that were folded and that need to be erased. - std::vector opStmtsToErase; + // Operations that were folded and that need to be erased. + std::vector opStmtsToErase; using ConstantFactoryType = std::function; - bool foldOperation(Operation *op, SmallVectorImpl &existingConstants, + bool foldOperation(OperationInst *op, + SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory); - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); void visitForStmt(ForStmt *stmt); PassResult runOnCFGFunction(CFGFunction *f) override; PassResult runOnMLFunction(MLFunction *f) override; @@ -52,7 +53,7 @@ char ConstantFold::passID = 0; /// constants are found, we keep track of them in the existingConstants list. /// /// This returns false if the operation was successfully folded. -bool ConstantFold::foldOperation(Operation *op, +bool ConstantFold::foldOperation(OperationInst *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory) { // If this operation is already a constant, just remember it for cleanup @@ -67,7 +68,7 @@ bool ConstantFold::foldOperation(Operation *op, SmallVector operandConstants; for (auto *operand : op->getOperands()) { Attribute operandCst = nullptr; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast()) operandCst = operandConstantOp->getValue(); } @@ -138,8 +139,8 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { return success(); } -// Override the walker's operation statement visit for constant folding. 
-void ConstantFold::visitOperationStmt(OperationStmt *stmt) { +// Override the walker's operation visitor for constant folding. +void ConstantFold::visitOperationInst(OperationInst *stmt) { auto constantFactory = [&](Attribute value, Type type) -> Value * { FuncBuilder builder(stmt); return builder.create(stmt->getLoc(), value, type); @@ -172,7 +173,7 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) { // around dead constants. Check for them now and remove them. for (auto *cst : existingConstants) { if (cst->use_empty()) - cst->getDefiningStmt()->erase(); + cst->getDefiningInst()->erase(); } return success(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index fefe9f700c4..ca158a17e92 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -47,14 +47,14 @@ public: void visitForStmt(ForStmt *forStmt); void visitIfStmt(IfStmt *ifStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); private: Value *getConstantIndexValue(int64_t value); void visitStmtBlock(StmtBlock *stmtBlock); Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, - llvm::iterator_range values); + llvm::iterator_range values); CFGFunction *cfgFunc; FuncBuilder builder; @@ -64,7 +64,7 @@ private: }; } // end anonymous namespace -// Return a vector of OperationStmt's arguments as Values. For each +// Return a vector of OperationInst's arguments as Values. For each // statement operand, represented as a Value, look up its Value counterpart in // the valueRemapping table. static llvm::SmallVector @@ -84,7 +84,7 @@ operandsAs(Statement *opStmt, // remains the same but the values must be updated to be Values. Update the // mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). -void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { +void FunctionConverter::visitOperationInst(OperationInst *opStmt) { // Set up basic operation state (context, name, operands). OperationState state(cfgFunc->getContext(), opStmt->getLoc(), opStmt->getName()); @@ -136,7 +136,7 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { // recognize as a reduction by the subsequent passes. Value *FunctionConverter::buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, - llvm::iterator_range values) { + llvm::iterator_range values) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); @@ -600,7 +600,7 @@ void ModuleConverter::replaceReferences() { // operation "op" and containing an MLFunction-typed value with the result of // converting "func" to a CFGFunction.
static inline void replaceMLFunctionAttr( - Operation &op, Identifier name, const Function *func, + OperationInst &op, Identifier name, const Function *func, const llvm::DenseMap &generatedFuncs) { if (!func->isML()) return; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index cc2ca32421b..ed184dc9421 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -67,7 +67,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker { PassResult runOnMLFunction(MLFunction *f) override; void runOnForStmt(ForStmt *forStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); bool generateDma(const MemRefRegion &region, ForStmt *forStmt, uint64_t *sizeInBytes); @@ -108,7 +108,7 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. -void DmaGeneration::visitOperationStmt(OperationStmt *opStmt) { +void DmaGeneration::visitOperationInst(OperationInst *opStmt) { if (auto loadOp = opStmt->dyn_cast()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c86eec3d276..67b36cfda30 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -80,7 +80,7 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, +static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { access->memref = loadOp->getMemRef(); @@ -112,8 +112,8 @@ struct FusionCandidate { MemRefAccess dstAccess; }; -static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, - OperationStmt *dstLoadOpStmt) { +static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt, + OperationInst *dstLoadOpStmt) { FusionCandidate candidate; // Get store access for src loop nest. getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess); @@ -123,7 +123,7 @@ static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, } // Returns the loop depth of the loop nest surrounding 'opStmt'. -static unsigned getLoopDepth(OperationStmt *opStmt) { +static unsigned getLoopDepth(OperationInst *opStmt) { unsigned loopDepth = 0; auto *currStmt = opStmt->getParentStmt(); ForStmt *currForStmt; @@ -141,15 +141,15 @@ namespace { class LoopNestStateCollector : public StmtWalker { public: SmallVector forStmts; - SmallVector loadOpStmts; - SmallVector storeOpStmts; + SmallVector loadOpStmts; + SmallVector storeOpStmts; bool hasIfStmt = false; void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { if (opStmt->isa()) loadOpStmts.push_back(opStmt); if (opStmt->isa()) @@ -171,10 +171,10 @@ public: unsigned id; // The top-level statement which is (or contains) loads/stores. Statement *stmt; - // List of load op stmts. - SmallVector loads; + // List of load operations. + SmallVector loads; // List of store op stmts. - SmallVector stores; + SmallVector stores; Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} // Returns the load op count for 'memref'.
@@ -312,8 +312,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. - void addToNode(unsigned id, const SmallVectorImpl &loads, - const SmallVectorImpl &stores) { + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { Node *node = getNode(id); for (auto *loadOpStmt : loads) node->loads.push_back(loadOpStmt); @@ -370,7 +370,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { } nodes.insert({node.id, node}); } - if (auto *opStmt = dyn_cast(&stmt)) { + if (auto *opStmt = dyn_cast(&stmt)) { if (auto loadOp = opStmt->dyn_cast()) { // Create graph node for top-level load op. Node node(id++, &stmt); @@ -474,7 +474,7 @@ public: if (!isa(dstNode->stmt)) continue; - SmallVector loads = dstNode->loads; + SmallVector loads = dstNode->loads; while (!loads.empty()) { auto *dstLoadOpStmt = loads.pop_back_val(); auto *memref = dstLoadOpStmt->cast()->getMemRef(); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 183613a2f69..0a3dd65d1f4 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -120,7 +120,7 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { return hasInnerLoops; } - bool visitOperationStmt(OperationStmt *opStmt) { return false; } + bool visitOperationInst(OperationInst *opStmt) { return false; } // FIXME: can't use base class method for this because that in turn would // need to use the derived class method above. CRTP doesn't allow it, and diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index fd23c341903..e2fd8b66e34 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -185,7 +185,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // case of GPUs. llvm::SmallVector newResults = {}; if (std::is_same::value) { - b.setInsertionPoint(cast(transfer->getOperation())); + b.setInsertionPoint(cast(transfer->getOperation())); auto *vector = b.create(transfer->getLoc(), vecView->getResult(), ArrayRef{state->zero}) ->getResult(); @@ -193,7 +193,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, } // 6. Free the local buffer. - b.setInsertionPoint(cast(transfer->getOperation())); + b.setInsertionPoint(cast(transfer->getOperation())); b.create(transfer->getLoc(), tmpScalarAlloc); // 7. It is now safe to erase the statement. 
@@ -207,13 +207,14 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { if (m_Op().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpStmt(Operation *op, MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpStmt(OperationInst *op, + MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { rewriteAsLoops(&*op->dyn_cast(), rewriter, diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index b4d91b2506c..6f033710798 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -246,8 +246,8 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static OperationStmt * -instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, +static OperationInst * +instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -263,7 +263,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, DenseMap *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { - auto *opStmt = cast(v->getDefiningOperation()); + auto *opStmt = cast(v->getDefiningInst()); if (opStmt->isa()) { FuncBuilder b(opStmt); auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); @@ -272,7 +272,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, assert(res.second && "Insertion failed"); return res.first->second; } - v->getDefiningOperation()->emitError("Missing substitution"); + v->getDefiningInst()->emitError("Missing substitution"); return nullptr; } return it->second; @@ -384,7 +384,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector -materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { +materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { SmallVector res; for (auto a : opStmt->getAttrs()) { if (auto splat = a.second.dyn_cast()) { @@ -404,8 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static OperationStmt * -instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, +static OperationInst * +instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, DenseMap *substitutionsMap) { assert(!opStmt->isa() && "Should call the function specialized for VectorTransferReadOp"); @@ -475,7 +475,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. 
-static OperationStmt * +static OperationInst * instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -486,7 +486,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, auto cloned = b->create( read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, projectedPermutationMap(read, hwVectorType), read->getPaddingValue()); - return cast(cloned->getOperation()); + return cast(cloned->getOperation()); } /// Creates an instantiated version of `write` for the instance of @@ -495,7 +495,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationStmt * +static OperationInst * instantiate(FuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -508,7 +508,7 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, substitute(write->getVector(), hwVectorType, substitutionsMap), write->getMemRef(), affineIndices, projectedPermutationMap(write, hwVectorType)); - return cast(cloned->getOperation()); + return cast(cloned->getOperation()); } /// Returns `true` if stmt instance is properly cloned and inserted, false @@ -544,7 +544,7 @@ static bool instantiateMaterialization(Statement *stmt, // Create a builder here for unroll-and-jam effects. FuncBuilder b(stmt); - auto *opStmt = cast(stmt); + auto *opStmt = cast(stmt); if (auto write = opStmt->dyn_cast()) { instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); @@ -620,8 +620,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG( - cast((*slice)[0])->getOperationFunction()->print(dbgs())); + LLVM_DEBUG(cast((*slice)[0])->getFunction()->print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -652,7 +651,7 @@ static bool emitSlice(MaterializationState *state, /// because we currently disallow vectorization of defs that come from another /// scope. static bool materialize(MLFunction *f, - const SetVector &terminators, + const SetVector &terminators, MaterializationState *state) { DenseSet seen; for (auto *term : terminators) { @@ -724,7 +723,7 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. 
auto filter = [subVectorType](const Statement &stmt) { - const auto &opStmt = cast(stmt); + const auto &opStmt = cast(stmt); if (!opStmt.isa()) { return false; } @@ -732,9 +731,9 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) { }; auto pat = Op(filter); auto matches = pat.match(f); - SetVector terminators; + SetVector terminators; for (auto m : matches) { - terminators.insert(cast(m.first)); + terminators.insert(cast(m.first)); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index ce2fac72933..0096cd7be2d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -64,7 +64,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { +static unsigned getTagMemRefPos(const OperationInst &dmaStmt) { assert(dmaStmt.isa() || dmaStmt.isa()); if (dmaStmt.isa()) { // Second to last operand. @@ -179,13 +179,13 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish statements to overlap computation with. static void findMatchingStartFinishStmts( ForStmt *forStmt, - SmallVectorImpl> + SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA statements - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast(&stmt); + auto *opStmt = dyn_cast(&stmt); if (!opStmt) continue; OpPointer dmaStartOp; @@ -194,9 +194,9 @@ static void findMatchingStartFinishStmts( outgoingDmaOps.push_back(dmaStartOp); } - SmallVector dmaStartStmts, dmaFinishStmts; + SmallVector dmaStartStmts, dmaFinishStmts; for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast(&stmt); + auto *opStmt = dyn_cast(&stmt); if (!opStmt) continue; // Collect DMA finish statements. @@ -260,7 +260,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { return success(); } - SmallVector, 4> startWaitPairs; + SmallVector, 4> startWaitPairs; findMatchingStartFinishStmts(forStmt, startWaitPairs); if (startWaitPairs.empty()) { @@ -293,7 +293,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // operation could have been used on it if it was dynamically shaped in // order to create the double buffer above) if (oldMemRef->use_empty()) - if (auto *allocStmt = oldMemRef->getDefiningStmt()) + if (auto *allocStmt = oldMemRef->getDefiningInst()) allocStmt->erase(); } @@ -309,7 +309,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocStmt = oldTagMemRef->getDefiningStmt()) + if (auto *allocStmt = oldTagMemRef->getDefiningInst()) allocStmt->erase(); } @@ -329,7 +329,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. 
- SmallVector affineApplyStmts; + SmallVector affineApplyStmts; SmallVector operands(dmaStartStmt->getOperands()); getReachableAffineApplyOps(operands, affineApplyStmts); for (const auto *stmt : affineApplyStmts) { @@ -352,7 +352,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { shifts[s++] = stmtShiftMap[&stmt]; LLVM_DEBUG( // Tagging statements with shifts for debugging purposes. - if (auto *opStmt = dyn_cast(&stmt)) { + if (auto *opStmt = dyn_cast(&stmt)) { FuncBuilder b(opStmt); opStmt->setAttr(b.getIdentifier("shift"), b.getI64IntegerAttr(shifts[s - 1])); diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 048e26ae115..b0b31e01175 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -47,7 +47,7 @@ struct SimplifyAffineStructures : public FunctionPass, PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } void visitIfStmt(IfStmt *ifStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); static char passID; }; @@ -75,7 +75,7 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) { ifStmt->setIntegerSet(simplifyIntegerSet(set)); } -void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) { +void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) { for (auto attr : opStmt->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a690844f7a6..f493e4b090b 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,7 +39,7 @@ public: void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter); - void addToWorklist(Operation *op) { + void addToWorklist(OperationInst *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -48,7 +48,7 @@ public: worklist.push_back(op); } - Operation *popFromWorklist() { + OperationInst *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -60,7 +60,7 @@ public: /// If the specified operation is in the worklist, remove it. If not, this is /// a no-op. - void removeFromWorklist(Operation *op) { + void removeFromWorklist(OperationInst *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -76,13 +76,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are removed even /// if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; + std::vector worklist; + DenseMap worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. 
- DenseMap, Operation *> uniquedConstants; + DenseMap, OperationInst *> uniquedConstants; }; }; // end anonymous namespace @@ -94,22 +94,22 @@ public: WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context) : PatternRewriter(context), driver(driver) {} - virtual void setInsertionPoint(Operation *op) = 0; + virtual void setInsertionPoint(OperationInst *op) = 0; // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(Operation *op) override { + void notifyOperationRemoved(OperationInst *op) override { driver.removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(Operation *op) override { + void notifyRootReplaced(OperationInst *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. for (auto &user : result->getUses()) { - if (auto *op = dyn_cast(user.getOwner())) + if (auto *op = dyn_cast(user.getOwner())) driver.addToWorklist(op); } @@ -168,7 +168,6 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // canonical version. To ensure safe dominance, move the operation to the // top of the function. entry = op; - auto &entryBB = currentFunction->front(); op->moveBefore(&entryBB, entryBB.begin()); continue; @@ -186,7 +185,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, operandConstants.clear(); for (auto *operand : op->getOperands()) { Attribute operandCst; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast()) operandCst = operandConstantOp->getValue(); } @@ -219,7 +218,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // // TODO: Add a result->getUsers() iterator. for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) + if (auto *op = dyn_cast(operand.getOwner())) addToWorklist(op); } @@ -265,15 +264,15 @@ static void processMLFunction(MLFunction *fn, // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - Operation *createOperation(const OperationState &state) override { + OperationInst *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); driver.addToWorklist(result); return result; } - void setInsertionPoint(Operation *op) override { + void setInsertionPoint(OperationInst *op) override { // Any new operations should be added before this statement. - builder.setInsertionPoint(cast(op)); + builder.setInsertionPoint(cast(op)); } private: @@ -281,7 +280,7 @@ static void processMLFunction(MLFunction *fn, }; GreedyPatternRewriteDriver driver(std::move(patterns)); - fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); }); + fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); }); FuncBuilder mlBuilder(fn); MLFuncRewriter rewriter(driver, mlBuilder); @@ -297,13 +296,13 @@ static void processCFGFunction(CFGFunction *fn, // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. 
- Operation *createOperation(const OperationState &state) override { + OperationInst *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); driver.addToWorklist(result); return result; } - void setInsertionPoint(Operation *op) override { + void setInsertionPoint(OperationInst *op) override { // Any new operations should be added before this instruction. builder.setInsertionPoint(cast(op)); } @@ -315,7 +314,7 @@ static void processCFGFunction(CFGFunction *fn, GreedyPatternRewriteDriver driver(std::move(patterns)); for (auto &bb : *fn) for (auto &op : bb) - if (auto *opInst = dyn_cast(&op)) + if (auto *opInst = dyn_cast(&op)) driver.addToWorklist(opInst); FuncBuilder cfgBuilder(fn); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index b92e15d7857..4a2831c0a83 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -156,7 +156,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForStmt->getStep()); - OperationStmt::OperandMapTy operandMap; + OperationInst::OperandMapTy operandMap; for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end(); it != e; ++it) { diff --git a/mlir/lib/Transforms/Utils/LoweringUtils.cpp b/mlir/lib/Transforms/Utils/LoweringUtils.cpp index 90f4d0c028d..6fca54a9972 100644 --- a/mlir/lib/Transforms/Utils/LoweringUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoweringUtils.cpp @@ -124,7 +124,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) { if (!op) return true; - FuncBuilder builder(cast(op->getOperation())); + FuncBuilder builder(cast(op->getOperation())); auto affineMap = op->getAffineMap(); for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) { Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 2bc1be1b785..c8317c27f74 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -36,7 +36,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const Operation &op) { +static bool isMemRefDereferencingOp(const OperationInst &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -82,10 +82,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, assert(oldMemRef->getType().cast().getElementType() == newMemRef->getType().cast().getElementType()); - // Walk all uses of old memref. Statement using the memref gets replaced. + // Walk all uses of old memref. Operation using the memref gets replaced. for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { StmtOperand &use = *(it++); - auto *opStmt = cast(use.getOwner()); + auto *opStmt = cast(use.getOwner()); // Skip this use if it's not dominated by domStmtFilter. if (domStmtFilter && !dominates(*domStmtFilter, *opStmt)) @@ -124,7 +124,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // TODO(mlir-team): An operation/SSA value should provide a method to // return the position of an SSA result in its defining // operation. 
- assert(extraIndex->getDefiningStmt()->getNumResults() == 1 && + assert(extraIndex->getDefiningInst()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) && "invalid memory op index"); @@ -186,10 +186,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // operands were drawing results from multiple affine apply ops, this also leads // to a collapse into a single affine apply op. The final results of the // composed AffineApplyOp are returned in output parameter 'results'. -OperationStmt * +OperationInst * mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, ArrayRef operands, - ArrayRef affineApplyOps, + ArrayRef affineApplyOps, SmallVectorImpl *results) { // Create identity map with same number of dimensions as number of operands. auto map = builder->getMultiDimIdentityMap(operands.size()); @@ -216,7 +216,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, for (unsigned i = 0, e = operands.size(); i < e; ++i) { (*results)[i] = affineApplyOp->getResult(i); } - return cast(affineApplyOp->getOperation()); + return cast(affineApplyOp->getOperation()); } /// Given an operation statement, inserts a new single affine apply operation, @@ -247,19 +247,19 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// all the affine_apply op's supplying operands to this opStmt do not have any /// uses besides this opStmt. Returns the new affine_apply operation statement /// otherwise. -OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { +OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opStmt->getNumOperands()); for (auto *operand : opStmt->getOperands()) { - auto *defStmt = operand->getDefiningStmt(); + auto *defStmt = operand->getDefiningInst(); if (defStmt && defStmt->isa()) { subOperands.push_back(operand); } } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) @@ -313,11 +313,11 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { } void mlir::forwardSubstitute(OpPointer affineApplyOp) { - if (!affineApplyOp->getOperation()->getOperationFunction()->isML()) { + if (!affineApplyOp->getOperation()->getFunction()->isML()) { // TODO: Support forward substitution for CFG style functions. return; } - auto *opStmt = cast(affineApplyOp->getOperation()); + auto *opStmt = cast(affineApplyOp->getOperation()); // Iterate through all uses of all results of 'opStmt', forward substituting // into any uses which are AffineApplyOps. for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; @@ -326,7 +326,7 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { for (auto it = result->use_begin(); it != result->use_end();) { StmtOperand &use = *(it++); auto *useStmt = use.getOwner(); - auto *useOpStmt = dyn_cast(useStmt); + auto *useOpStmt = dyn_cast(useStmt); // Skip if use is not AffineApplyOp. 
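//
// [Illustrative sketch, not part of the original patch.] Both
// replaceAllMemRefUsesWith and forwardSubstitute above iterate a use list
// with `StmtOperand &use = *(it++);`, advancing the iterator before the
// current use is rewritten, because rewriting unlinks that use from the old
// value's use list. The same idiom with toy Value/Use types (assumed names,
// not the MLIR classes):
//
#include <list>

struct ValueSketch;
struct UseSketch { ValueSketch *value; };
struct ValueSketch { std::list<UseSketch *> uses; };

inline void replaceAllUsesWithSketch(ValueSketch &oldVal, ValueSketch &newVal) {
  for (auto it = oldVal.uses.begin(); it != oldVal.uses.end();) {
    UseSketch *use = *(it++);       // step past 'use' before mutating it
    use->value = &newVal;           // retarget the use
    oldVal.uses.remove(use);        // unlink from the old value's use list
    newVal.uses.push_back(use);     // link into the new one
  }
}
//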
if (useOpStmt == nullptr || !useOpStmt->isa()) continue; @@ -379,7 +379,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { : forStmt->getUpperBoundOperands(); for (const auto *operand : boundOperands) { Attribute operandCst; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast()) operandCst = operandConstantOp->getValue(); } @@ -415,7 +415,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { } void mlir::remapFunctionAttrs( - Operation &op, const DenseMap &remappingTable) { + OperationInst &op, + const DenseMap &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. @@ -451,7 +452,7 @@ void mlir::remapFunctionAttrs( struct MLFnWalker : public StmtWalker { MLFnWalker(const DenseMap &remappingTable) : remappingTable(remappingTable) {} - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { remapFunctionAttrs(*opStmt, remappingTable); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b8145126770..5abd3a3cfcc 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -98,7 +98,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) { // Only filter statements that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [subVectorType](const Statement &stmt) { - auto *opStmt = dyn_cast(&stmt); + auto *opStmt = dyn_cast(&stmt); if (!opStmt) { return false; } @@ -116,7 +116,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) { auto pat = Op(filter); auto matches = pat.match(f); for (auto m : matches) { - auto *opStmt = cast(m.first); + auto *opStmt = cast(m.first); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -146,7 +146,7 @@ static MLFunctionMatches matchTestSlicingOps(MLFunction *f) { using matcher::Op; // Match all OpStatements with the kTestSlicingOpName name. auto filter = [](const Statement &stmt) { - const auto &opStmt = cast(stmt); + const auto &opStmt = cast(stmt); return opStmt.getName().getStringRef() == kTestSlicingOpName; }; auto pat = Op(filter); @@ -192,7 +192,7 @@ void VectorizerTestPass::testSlicing(MLFunction *f) { } bool customOpWithAffineMapAttribute(const Statement &stmt) { - const auto &opStmt = cast(stmt); + const auto &opStmt = cast(stmt); return opStmt.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -205,7 +205,7 @@ void VectorizerTestPass::testComposeMaps(MLFunction *f) { maps.reserve(matches.size()); std::reverse(matches.begin(), matches.end()); for (auto m : matches) { - auto *opStmt = cast(m.first); + auto *opStmt = cast(m.first); auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast() .getValue(); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 80d16475e47..0efe727f5b4 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -722,22 +722,22 @@ namespace { struct VectorizationState { /// Adds an entry of pre/post vectorization statements in the state. 
- void registerReplacement(OperationStmt *key, OperationStmt *value); + void registerReplacement(OperationInst *key, OperationInst *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original OperationStmt that have been vectorized. + // In-order tracking of original OperationInst that have been vectorized. // Erase in reverse order. - SmallVector toErase; - // Set of OperationStmt that have been vectorized (the values in the + SmallVector toErase; + // Set of OperationInst that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in // particular to filter the statements that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet vectorizedSet; - // Map of old scalar OperationStmt to new vectorized OperationStmt. - DenseMap vectorizationMap; + DenseSet vectorizedSet; + // Map of old scalar OperationInst to new vectorized OperationInst. + DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -746,17 +746,17 @@ struct VectorizationState { // vectorizeOperations function. They consist of the subset of load operations // that have been vectorized. They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. - DenseSet roots; + DenseSet roots; // Terminator statements for the worklist in the vectorizeOperations function. // They consist of the subset of store operations that have been vectorized. // They can be retrieved from `vectorizationMap` but it is convenient to keep // track of them in a separate data structure. Since they do not necessarily // belong to use-def chains starting from loads (e.g storing a constant), we // need to handle them in a post-pass. - DenseSet terminators; + DenseSet terminators; // Checks that the type of `stmt` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationStmt *stmt); + void registerTerminator(OperationInst *stmt); private: void registerReplacement(const Value *key, Value *value); @@ -764,8 +764,8 @@ private: } // end namespace -void VectorizationState::registerReplacement(OperationStmt *key, - OperationStmt *value) { +void VectorizationState::registerReplacement(OperationInst *key, + OperationInst *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -784,7 +784,7 @@ void VectorizationState::registerReplacement(OperationStmt *key, } } -void VectorizationState::registerTerminator(OperationStmt *stmt) { +void VectorizationState::registerTerminator(OperationInst *stmt) { assert(stmt->isa() && "terminator must be a StoreOp"); assert(terminators.count(stmt) == 0 && "terminator was already inserted previously"); @@ -832,7 +832,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. 
- auto *opStmt = cast(memoryOp->getOperation()); + auto *opStmt = cast(memoryOp->getOperation()); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector_transfer operations @@ -847,7 +847,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, opStmt->getLoc(), vectorType, memoryOp->getMemRef(), map(makePtrDynCaster(), memoryOp->getIndices()), permutationMap); state->registerReplacement(opStmt, - cast(transfer->getOperation())); + cast(transfer->getOperation())); } else { state->registerTerminator(opStmt); } @@ -866,7 +866,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, if (!matcher::isLoadOrStore(stmt)) { return false; } - auto *opStmt = cast(&stmt); + auto *opStmt = cast(&stmt); return state->vectorizationMap.count(opStmt) == 0 && state->vectorizedSet.count(opStmt) == 0 && state->roots.count(opStmt) == 0 && @@ -875,7 +875,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, auto loadAndStores = matcher::Op(notVectorizedThisPattern); auto matches = loadAndStores.match(loop); for (auto ls : matches) { - auto *opStmt = cast(ls.first); + auto *opStmt = cast(ls.first); auto load = opStmt->dyn_cast(); auto store = opStmt->dyn_cast(); LLVM_DEBUG(opStmt->print(dbgs())); @@ -974,14 +974,14 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, Location loc = stmt->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpStmt = cast(constant.getOperation()); + auto *constantOpStmt = cast(constant.getOperation()); OperationState state( b.getContext(), loc, constantOpStmt->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); - auto *splat = cast(b.createOperation(state)); + auto *splat = cast(b.createOperation(state)); return splat->getResult(0); } @@ -994,7 +994,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpStmt = cast(v->getDefiningStmt()); + auto *definingOpStmt = cast(v->getDefiningInst()); if (state.vectorizedSet.count(definingOpStmt) > 0) { return v->getType().cast(); } @@ -1026,7 +1026,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingStatement = cast(operand->getDefiningStmt()); + auto *definingStatement = cast(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. if (state->vectorizedSet.count(definingStatement) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); @@ -1049,7 +1049,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningStmt()->dyn_cast()) { + if (auto constant = operand->getDefiningInst()->dyn_cast()) { return vectorizeConstant(stmt, *constant, getVectorType(operand, *state).cast()); } @@ -1059,17 +1059,17 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, return nullptr; }; -/// Encodes OperationStmt-specific behavior for vectorization. In general we +/// Encodes OperationInst-specific behavior for vectorization. 
In general we /// assume that all operands of an op must be vectorized but this is not always /// true. In the future, it would be nice to have a trait that describes how a /// particular operation vectorizes. For now we implement the case distinction /// here. -/// Returns a vectorized form of stmt or nullptr if vectorization fails. +/// Returns a vectorized form of an operation or nullptr if vectorization fails. /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. -static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, - OperationStmt *opStmt, +static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, + OperationInst *opStmt, VectorizationState *state) { // Sanity checks. assert(!opStmt->isa() && @@ -1091,7 +1091,7 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( opStmt->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = cast(transfer->getOperation()); + auto *res = cast(transfer->getOperation()); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. opStmt->erase(); @@ -1114,8 +1114,8 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, // Create a clone of the op with the proper operands and return types. // TODO(ntv): The following assumes there is always an op with a fixed // name that works both in scalar mode and vector mode. - // TODO(ntv): Is it worth considering an OperationStmt.clone operation - // which changes the type so we can promote an OperationStmt with less + // TODO(ntv): Is it worth considering an OperationInst.clone operation + // which changes the type so we can promote an OperationInst with less // boilerplate? OperationState newOp(b->getContext(), opStmt->getLoc(), opStmt->getName().getStringRef(), operands, types, @@ -1123,22 +1123,22 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, return b->createOperation(newOp); } -/// Iterates over the OperationStmt in the loop and rewrites them using their +/// Iterates over the OperationInst in the loop and rewrites them using their /// vectorized counterpart by: -/// 1. iteratively building a worklist of uses of the OperationStmt vectorized +/// 1. iteratively building a worklist of uses of the OperationInst vectorized /// so far by this pattern; -/// 2. for each OperationStmt in the worklist, create the vector form of this +/// 2. for each OperationInst in the worklist, create the vector form of this /// operation and replace all its uses by the vectorized form. For this step, /// the worklist must be traversed in order; /// 3. verify that all operands of the newly vectorized operation have been /// vectorized by this pattern. static bool vectorizeOperations(VectorizationState *state) { // 1. create initial worklist with the uses of the roots. - SetVector worklist; - auto insertUsesOf = [&worklist, state](Operation *vectorized) { - for (auto *r : cast(vectorized)->getResults()) + SetVector worklist; + auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { + for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *stmt = cast(u.getOwner()); + auto *stmt = cast(u.getOwner()); // Don't propagate to terminals, a separate pass is needed for those. 
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented. if (state->terminators.count(stmt) > 0) { @@ -1160,7 +1160,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 2. Create vectorized form of the statement. // Insert it just before stmt, on success register stmt as replaced. FuncBuilder b(stmt); - auto *vectorizedStmt = vectorizeOneOperationStmt(&b, stmt, state); + auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state); if (!vectorizedStmt) { return true; } @@ -1169,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(cast(stmt), vectorizedStmt); + state->registerReplacement(stmt, vectorizedStmt); // 4. Augment the worklist with uses of the statement we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef{stmt}); + apply(insertUsesOf, ArrayRef{stmt}); } return false; } @@ -1223,12 +1223,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, // Form the root operationsthat have been set in the replacementMap. // For now, these roots are the loads for which vector_transfer_read // operations have been inserted. - auto getDefiningOperation = [](const Value *val) { - return const_cast(val)->getDefiningOperation(); + auto getDefiningInst = [](const Value *val) { + return const_cast(val)->getDefiningInst(); }; using ReferenceTy = decltype(*(state.replacementMap.begin())); auto getKey = [](ReferenceTy it) { return it.first; }; - auto roots = map(getDefiningOperation, map(getKey, state.replacementMap)); + auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); // Vectorize the root operations and everything reached by use-def chains // except the terminators (store statements) that need to be post-processed @@ -1240,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } // Finally, vectorize the terminators. If anything fails to vectorize, skip. 
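//
// [Illustrative sketch, not part of the original patch.] vectorizeOperations
// above seeds a SetVector worklist with the uses of the vectorized roots
// (loads), processes it in insertion order, and appends the users of each
// newly vectorized op, while deferring terminators (stores) to a post-pass.
// A standalone rendering of that propagation, with a toy OpNode graph and
// std::set standing in for llvm::SetVector's membership check:
//
#include <cstddef>
#include <functional>
#include <set>
#include <vector>

struct OpNode {
  std::vector<OpNode *> users;   // ops consuming this op's results
  bool isTerminator = false;     // e.g. a store: handled separately
};

inline bool vectorizeFromRootsSketch(
    const std::vector<OpNode *> &roots,
    const std::function<bool(OpNode *)> &vectorizeOne) {
  std::vector<OpNode *> worklist;     // preserves insertion order
  std::set<OpNode *> enqueued;        // membership filter
  auto enqueueUsersOf = [&](OpNode *op) {
    for (OpNode *user : op->users) {
      if (user->isTerminator)         // don't propagate into terminators
        continue;
      if (enqueued.insert(user).second)
        worklist.push_back(user);
    }
  };
  for (OpNode *root : roots)
    enqueueUsersOf(root);
  // Index-based loop so appending users while iterating stays well-defined.
  for (std::size_t i = 0; i < worklist.size(); ++i) {
    OpNode *op = worklist[i];
    if (!vectorizeOne(op))
      return false;                   // any failure aborts the whole pattern
    enqueueUsersOf(op);
  }
  return true;
}
//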
- auto vectorizeOrFail = [&fail, &state](OperationStmt *stmt) { + auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) { if (fail) { return; } FuncBuilder b(stmt); - auto *res = vectorizeOneOperationStmt(&b, stmt, &state); + auto *res = vectorizeOneOperationInst(&b, stmt, &state); if (res == nullptr) { fail = true; } diff --git a/mlir/test/mlir-rewriter-gen/one-op-one-result.td b/mlir/test/mlir-rewriter-gen/one-op-one-result.td index 55b9e3669e8..2aa88469e76 100644 --- a/mlir/test/mlir-rewriter-gen/one-op-one-result.td +++ b/mlir/test/mlir-rewriter-gen/one-op-one-result.td @@ -36,8 +36,8 @@ def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>; // CHECK: struct GeneratedConvert0 : public RewritePattern // CHECK: RewritePattern("x.add", 1, context) -// CHECK: PatternMatchResult match(Operation *op) -// CHECK: void rewrite(Operation *op, PatternRewriter &rewriter) +// CHECK: PatternMatchResult match(OperationInst *op) +// CHECK: void rewrite(OperationInst *op, PatternRewriter &rewriter) // CHECK: rewriter.replaceOpWithNewOp(op, op->getResult(0)->getType() // CHECK: void populateWithGenerated diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 0fa23c69566..68174cb99d4 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -239,9 +239,9 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { emitter.emitAttrGetters(); emitter.emitCanonicalizationPatterns(); - os << "private:\n friend class ::mlir::Operation;\n"; + os << "private:\n friend class ::mlir::OperationInst;\n"; os << " explicit " << emitter.cppClassName() - << "(const Operation* state) : Op(state) {}\n"; + << "(const OperationInst* state) : Op(state) {}\n"; os << "};\n"; emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; }); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index a677cf68aa0..ed9d078a635 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -62,7 +62,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "struct " << rewriteName << " : public RewritePattern {\n" << " " << rewriteName << "(MLIRContext *context) : RewritePattern(" << rootName->getAsString() << ", 1, context) {}\n" - << " PatternMatchResult match(Operation *op) const override {\n" + << " PatternMatchResult match(OperationInst *op) const override {\n" << " // TODO: This just handle 1 result\n" << " if (op->getNumResults() != 1) return matchFailure();\n" << " return matchSuccess();\n }\n"; @@ -85,7 +85,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { SplitString(opName, split, "_"); auto className = join(split, "::"); os << formatv(R"( - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { auto* context = op->getContext(); (void)context; rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", className); -- cgit v1.2.3 From 69d9e990facf4d07c31b96e47e98cde07228f229 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Fri, 28 Dec 2018 08:48:09 -0800 Subject: Eliminate the using decls for MLFunction and CFGFunction standardizing on Function. This is step 18/n towards merging instructions and statements, NFC. 
PiperOrigin-RevId: 227139399 --- mlir/g3doc/LangRef.md | 10 +++---- mlir/g3doc/Rationale.md | 9 +++--- mlir/include/mlir/Analysis/AffineStructures.h | 2 +- mlir/include/mlir/Analysis/Dominance.h | 2 +- mlir/include/mlir/Analysis/HyperRectangularSet.h | 2 +- mlir/include/mlir/Analysis/MLFunctionMatcher.h | 6 ++-- mlir/include/mlir/Analysis/Passes.h | 4 +-- mlir/include/mlir/Analysis/Utils.h | 2 +- mlir/include/mlir/IR/Function.h | 27 +++++++++--------- mlir/include/mlir/IR/FunctionGraphTraits.h | 33 +++++++++++----------- mlir/include/mlir/IR/Statement.h | 3 +- mlir/include/mlir/IR/StmtBlock.h | 14 ++++----- mlir/include/mlir/IR/StmtVisitor.h | 14 ++++----- mlir/include/mlir/Pass.h | 10 +++---- mlir/include/mlir/Transforms/LoopUtils.h | 5 ++-- .../mlir/Transforms/MLPatternLoweringPass.h | 12 ++++---- mlir/include/mlir/Transforms/Passes.h | 2 +- mlir/include/mlir/Transforms/Utils.h | 1 - mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/MLFunctionMatcher.cpp | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 8 +++--- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 10 +++---- mlir/lib/Analysis/OpStats.cpp | 8 +++--- mlir/lib/Analysis/SliceAnalysis.cpp | 4 +-- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Analysis/Verifier.cpp | 16 +++++------ mlir/lib/IR/AsmPrinter.cpp | 30 ++++++++++---------- mlir/lib/IR/Function.cpp | 14 ++++----- mlir/lib/IR/Statement.cpp | 2 +- mlir/lib/IR/StmtBlock.cpp | 10 +++---- mlir/lib/IR/Value.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 14 ++++----- mlir/lib/StandardOps/StandardOps.cpp | 4 +-- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 12 ++++---- mlir/lib/Transforms/CSE.cpp | 12 ++++---- mlir/lib/Transforms/ComposeAffineMaps.cpp | 8 +++--- mlir/lib/Transforms/ConstantFold.cpp | 8 +++--- mlir/lib/Transforms/ConvertToCFG.cpp | 33 +++++++++++----------- mlir/lib/Transforms/DmaGeneration.cpp | 6 ++-- mlir/lib/Transforms/LoopFusion.cpp | 14 ++++----- mlir/lib/Transforms/LoopTiling.cpp | 10 +++---- mlir/lib/Transforms/LoopUnroll.cpp | 4 +-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 10 +++---- mlir/lib/Transforms/LowerAffineApply.cpp | 8 +++--- mlir/lib/Transforms/LowerVectorTransfers.cpp | 4 +-- mlir/lib/Transforms/MaterializeVectors.cpp | 10 +++---- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +-- mlir/lib/Transforms/SimplifyAffineExpr.cpp | 8 +++--- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 8 +++--- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- .../Vectorization/VectorizerTestPass.cpp | 26 ++++++++--------- mlir/lib/Transforms/Vectorize.cpp | 12 ++++---- mlir/lib/Transforms/ViewFunctionGraph.cpp | 19 ++++++------- 55 files changed, 249 insertions(+), 261 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 4737a03071b..5c562e390af 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -1519,9 +1519,9 @@ each followed by its indices, size of the data transfer in terms of the number of elements (of the elemental type of the memref), and a tag memref with its indices. The tag location is used by a dma_wait operation to check for completion. The indices of the source memref, destination memref, and the tag -memref have the same restrictions as any load/store instruction in an MLFunction -(whenever DMA operations appear in ML Functions). 
This allows powerful static -analysis and transformations in the presence of such DMAs including +memref have the same restrictions as any load/store instruction in an ML +Function (whenever DMA operations appear in ML Functions). This allows powerful +static analysis and transformations in the presence of such DMAs including rescheduling, pipelining / overlap with computation, and checking for matching start/end operations. The source and destination memref need not be of the same dimensionality, but need to have the same elemental type. @@ -1599,7 +1599,7 @@ The arity of indices is the rank of the memref (i.e., if the memref loaded from is of rank 3, then 3 indices are required for the load following the memref identifier). -In an MLFunction, the indices of a load are restricted to SSA values bound to +In an ML Function, the indices of a load are restricted to SSA values bound to surrounding loop induction variables, [symbols](#dimensions-and-symbols), results of a [`constant` operation](#'constant'-operation), or the results of an `affine_apply` operation that can in turn take as arguments all of the @@ -1641,7 +1641,7 @@ Store value to memref location given by indices. The value stored should have the same type as the elemental type of the memref. The number of arguments provided within brackets need to match the rank of the memref. -In an MLFunction, the indices of a store are restricted to SSA values bound to +In an ML Function, the indices of a store are restricted to SSA values bound to surrounding loop induction variables, [symbols](#dimensions-and-symbols), results of a [`constant` operation](#'constant'-operation), or the results of an [`affine_apply`](#'affine_apply'-operation) operation that can in turn take as diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index a2e38466dfc..791fe31fce9 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -580,7 +580,7 @@ consideration on demand. We will revisit these discussions when we have more implementation experience and learn more about the challenges and limitations of our current design in practice. -### MLFunction representation alternatives: polyhedral schedule lists vs polyhedral schedules trees vs affine loop/If forms {#mlfunction-representation-alternatives-polyhedral-schedule-lists-vs-polyhedral-schedules-trees-vs-affine-loop-if-forms} +### ML Function representation alternatives: polyhedral schedule lists vs polyhedral schedules trees vs affine loop/If forms {#mlfunction-representation-alternatives-polyhedral-schedule-lists-vs-polyhedral-schedules-trees-vs-affine-loop-if-forms} The current MLIR uses a representation of polyhedral schedules using a tree of if/for loops. We extensively debated the tradeoffs involved in the typical @@ -609,8 +609,9 @@ At a high level, we have two alternatives here: 1. Having two different forms of MLFunctions: an affine loop tree form (AffineLoopTreeFunction) and a polyhedral schedule tree form as two different forms of MLFunctions. Or in effect, having four different forms - for functions in MLIR instead of three: CFGFunction, AffineLoopTreeFunction, - Polyhedral Schedule Tree function, and external functions. + for functions in MLIR instead of three: CFG Function, + AffineLoopTreeFunction, Polyhedral Schedule Tree function, and external + functions. #### Schedule Tree Representation for MLFunctions {#schedule-tree-representation-for-mlfunctions} @@ -785,7 +786,7 @@ extfunc @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a, representation. 
2(b) requires no change, but impacts how cost models look at index and layout maps. -### MLFunction Extensions for "Escaping Scalars" {#mlfunction-extensions-for-"escaping-scalars"} +### ML Function Extensions for "Escaping Scalars" {#mlfunction-extensions-for-"escaping-scalars"} We considered providing a representation for SSA values that are live out of if/else conditional bodies or for loops of ML functions. We ultimately abandoned diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 70786a14444..de60dc2115c 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -417,7 +417,7 @@ public: /// the 'for' statement isn't found in the constraint system. Any new /// identifiers that are found in the bound operands of the 'for' statement /// are added as trailing identifiers (either dimensional or symbolic - /// depending on whether the operand is a valid MLFunction symbol). + /// depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. bool addForStmtDomain(const ForStmt &forStmt); diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 5374a451bd1..6d3853e396d 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -47,7 +47,7 @@ class DominanceInfo : public DominatorTreeBase { using super = DominatorTreeBase; public: - DominanceInfo(CFGFunction *F); + DominanceInfo(Function *F); /// Return true if instruction A properly dominates instruction B. bool properlyDominates(const Instruction *a, const Instruction *b); diff --git a/mlir/include/mlir/Analysis/HyperRectangularSet.h b/mlir/include/mlir/Analysis/HyperRectangularSet.h index 27bb5da6dab..74961308f47 100644 --- a/mlir/include/mlir/Analysis/HyperRectangularSet.h +++ b/mlir/include/mlir/Analysis/HyperRectangularSet.h @@ -62,7 +62,7 @@ using AffineBoundExprList = SmallVector; // 0 <= d0 <= 511 // max(128,M) <= d1 <= min(N-1,256) // -// Symbols here aren't necessarily associated with MLFunction's symbols; they +// Symbols here aren't necessarily associated with Function's symbols; they // could also correspond to outer loop IVs for example or anything abstract. The // binding to SSA values for dimensions/symbols is optional, and these are in an // abstract integer domain. As an example, to describe data accessed in a tile diff --git a/mlir/include/mlir/Analysis/MLFunctionMatcher.h b/mlir/include/mlir/Analysis/MLFunctionMatcher.h index bd99363cafb..753d741f448 100644 --- a/mlir/include/mlir/Analysis/MLFunctionMatcher.h +++ b/mlir/include/mlir/Analysis/MLFunctionMatcher.h @@ -29,7 +29,7 @@ struct MLFunctionMatchesStorage; class Statement; /// An MLFunctionMatcher is a recursive matcher that captures nested patterns in -/// an MLFunction. It is used in conjunction with a scoped +/// an ML Function. It is used in conjunction with a scoped /// MLFunctionMatcherContext that handles the memory allocations efficiently. /// /// In order to use MLFunctionMatchers creates a scoped context and uses @@ -47,7 +47,7 @@ class Statement; /// /// Recursive abstraction for matching results. -/// Provides iteration over the MLFunction Statement* captured by a Matcher. +/// Provides iteration over the Statement* captured by a Matcher. /// /// Implemented as a POD value-type with underlying storage pointer. 
/// The underlying storage lives in a scoped bumper allocator whose lifetime @@ -99,7 +99,7 @@ struct MLFunctionMatcher : public StmtWalker { FilterFunctionType filter = defaultFilterFunction); /// Returns all the matches in `function`. - MLFunctionMatches match(MLFunction *function); + MLFunctionMatches match(Function *function); /// Returns all the matches nested under `statement`. MLFunctionMatches match(Statement *statement); diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index 3a663dc4fc4..8fd1f9c4bf9 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -29,10 +29,10 @@ namespace mlir { class FunctionPass; -/// Creates a pass to check memref accesses in an MLFunction. +/// Creates a pass to check memref accesses in an ML Function. FunctionPass *createMemRefBoundCheckPass(); -/// Creates a pass to check memref access dependences in an MLFunction. +/// Creates a pass to check memref access dependences in an ML Function. FunctionPass *createMemRefDependenceCheckPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index eb8dbe530ea..fd57cda3902 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -115,7 +115,7 @@ private: /// cases. The computed region's 'cst' field has exactly as many dimensional /// identifiers as the rank of the memref, and *potentially* additional symbolic /// identifiers which could include any of the loop IVs surrounding opStmt up -/// until 'loopDepth' and another additional MLFunction symbols involved with +/// until 'loopDepth' and another additional Function symbols involved with /// the access (for eg., those appear in affine_apply's, loop bounds, etc.). /// For example, the memref region for this operation at loopDepth = 1 will be: /// diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 0d039ee5d9b..5b52a5de7e7 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -111,15 +111,15 @@ public: BasicBlock &back() { return blocks.back(); } const BasicBlock &back() const { - return const_cast(this)->back(); + return const_cast(this)->back(); } BasicBlock &front() { return blocks.front(); } const BasicBlock &front() const { - return const_cast(this)->front(); + return const_cast(this)->front(); } - /// Return the 'return' statement of this MLFunction. + /// Return the 'return' statement of this Function. const OperationInst *getReturnStmt() const; OperationInst *getReturnStmt(); @@ -157,14 +157,14 @@ public: } // Supports non-const operand iteration. - using args_iterator = ArgumentIterator; + using args_iterator = ArgumentIterator; args_iterator args_begin(); args_iterator args_end(); llvm::iterator_range getArguments(); // Supports const operand iteration. using const_args_iterator = - ArgumentIterator; + ArgumentIterator; const_args_iterator args_begin() const; const_args_iterator args_end() const; llvm::iterator_range getArguments() const; @@ -252,32 +252,31 @@ public: }; //===--------------------------------------------------------------------===// -// MLFunction iterator methods. +// Function iterator methods. 
//===--------------------------------------------------------------------===// -inline MLFunction::args_iterator MLFunction::args_begin() { +inline Function::args_iterator Function::args_begin() { return args_iterator(this, 0); } -inline MLFunction::args_iterator MLFunction::args_end() { +inline Function::args_iterator Function::args_end() { return args_iterator(this, getNumArguments()); } -inline llvm::iterator_range -MLFunction::getArguments() { +inline llvm::iterator_range Function::getArguments() { return {args_begin(), args_end()}; } -inline MLFunction::const_args_iterator MLFunction::args_begin() const { +inline Function::const_args_iterator Function::args_begin() const { return const_args_iterator(this, 0); } -inline MLFunction::const_args_iterator MLFunction::args_end() const { +inline Function::const_args_iterator Function::args_end() const { return const_args_iterator(this, getNumArguments()); } -inline llvm::iterator_range -MLFunction::getArguments() const { +inline llvm::iterator_range +Function::getArguments() const { return {args_begin(), args_end()}; } diff --git a/mlir/include/mlir/IR/FunctionGraphTraits.h b/mlir/include/mlir/IR/FunctionGraphTraits.h index 95b20476880..54305c90d25 100644 --- a/mlir/include/mlir/IR/FunctionGraphTraits.h +++ b/mlir/include/mlir/IR/FunctionGraphTraits.h @@ -86,14 +86,13 @@ template <> struct GraphTraits> { }; template <> -struct GraphTraits - : public GraphTraits { - using GraphType = mlir::CFGFunction *; +struct GraphTraits : public GraphTraits { + using GraphType = mlir::Function *; using NodeRef = mlir::BasicBlock *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn->begin()); } @@ -103,14 +102,14 @@ struct GraphTraits }; template <> -struct GraphTraits +struct GraphTraits : public GraphTraits { - using GraphType = const mlir::CFGFunction *; + using GraphType = const mlir::Function *; using NodeRef = const mlir::BasicBlock *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn->begin()); } @@ -120,14 +119,14 @@ struct GraphTraits }; template <> -struct GraphTraits> +struct GraphTraits> : public GraphTraits> { - using GraphType = Inverse; + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn.Graph->begin()); } @@ -137,14 +136,14 @@ struct GraphTraits> }; template <> -struct GraphTraits> +struct GraphTraits> : public GraphTraits> { - using GraphType = Inverse; + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn.Graph->begin()); } @@ -161,7 +160,7 @@ struct GraphTraits static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn->begin()); } @@ -178,7 +177,7 @@ struct 
GraphTraits static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn->begin()); } @@ -195,7 +194,7 @@ struct GraphTraits> static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn.Graph->begin()); } @@ -212,7 +211,7 @@ struct GraphTraits> static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } - using nodes_iterator = pointer_iterator; + using nodes_iterator = pointer_iterator; static nodes_iterator nodes_begin(GraphType fn) { return nodes_iterator(fn.Graph->begin()); } diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index 1e2e1103c98..48135514dcf 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -29,7 +29,6 @@ namespace mlir { class Location; -using MLFunction = Function; class StmtBlock; class ForStmt; class MLIRContext; @@ -105,7 +104,7 @@ public: /// Returns the function that this statement is part of. /// The function is determined by traversing the chain of parent statements. /// Returns nullptr if the statement is unlinked. - MLFunction *getFunction() const; + Function *getFunction() const; /// Destroys this statement and its subclass data. void destroy(); diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index 01ef68c7d18..916834dfbdc 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -28,8 +28,6 @@ namespace mlir { class IfStmt; class StmtBlockList; -using CFGFunction = Function; -using MLFunction = Function; template class PredecessorIterator; template class SuccessorIterator; @@ -61,8 +59,8 @@ public: /// Returns the function that this statement block is part of. The function /// is determined by traversing the chain of parent statements. - MLFunction *getFunction(); - const MLFunction *getFunction() const { + Function *getFunction(); + const Function *getFunction() const { return const_cast(this)->getFunction(); } @@ -293,7 +291,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - an MLFunction or IfStmt or ForStmt. +/// is part of - a Function or IfStmt or ForStmt. class StmtBlockList { public: explicit StmtBlockList(Function *container); @@ -345,7 +343,7 @@ public: } /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is - /// part of an Function, then return it, otherwise return null. + /// part of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { return const_cast(this)->getContainingFunction(); @@ -353,8 +351,8 @@ public: // TODO(clattner): This is only to help ML -> CFG migration, remove in the // near future. This makes StmtBlockList work more like BasicBlock did. - CFGFunction *getFunction(); - const CFGFunction *getFunction() const { + Function *getFunction(); + const Function *getFunction() const { return const_cast(this)->getFunction(); } diff --git a/mlir/include/mlir/IR/StmtVisitor.h b/mlir/include/mlir/IR/StmtVisitor.h index bcc416c00ae..570036a0d99 100644 --- a/mlir/include/mlir/IR/StmtVisitor.h +++ b/mlir/include/mlir/IR/StmtVisitor.h @@ -15,7 +15,7 @@ // limitations under the License. 
// ============================================================================= // -// This file defines the base classes for MLFunction's statement visitors and +// This file defines the base classes for Function's statement visitors and // walkers. A visit is a O(1) operation that visits just the node in question. A // walk visits the node it's called on as well as the node's descendants. // @@ -29,7 +29,7 @@ // resolved overloading, not virtual functions. // // For example, here is a walker that counts the number of for loops in an -// MLFunction. +// Function. // // /// Declare the class. Note that we derive from StmtWalker instantiated // /// with _our new subclasses_ type. @@ -45,7 +45,7 @@ // numLoops = lc.numLoops; // // There are 'visit' methods for OperationInst, ForStmt, IfStmt, and -// MLFunction, which recursively process all contained statements. +// Function, which recursively process all contained statements. // // Note that if you don't implement visitXXX for some statement type, // the visitXXX method for Statement superclass will be invoked. @@ -129,14 +129,14 @@ public: } } - // Define walkers for MLFunction and all MLFunction statement kinds. - void walk(MLFunction *f) { + // Define walkers for Function and all Function statement kinds. + void walk(Function *f) { static_cast(this)->visitMLFunction(f); static_cast(this)->walk(f->getBody()->begin(), f->getBody()->end()); } - void walkPostOrder(MLFunction *f) { + void walkPostOrder(Function *f) { static_cast(this)->walkPostOrder(f->getBody()->begin(), f->getBody()->end()); static_cast(this)->visitMLFunction(f); @@ -219,7 +219,7 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. When using RetTy, all of these // need to be overridden. - void visitMLFunction(MLFunction *f) {} + void visitMLFunction(Function *f) {} void visitForStmt(ForStmt *forStmt) {} void visitIfStmt(IfStmt *ifStmt) {} void visitOperationInst(OperationInst *opStmt) {} diff --git a/mlir/include/mlir/Pass.h b/mlir/include/mlir/Pass.h index 75f682d3c10..6c2a3322aa6 100644 --- a/mlir/include/mlir/Pass.h +++ b/mlir/include/mlir/Pass.h @@ -25,8 +25,6 @@ namespace mlir { class Function; -using CFGFunction = Function; -using MLFunction = Function; class Module; // Values that can be used by to signal success/failure. This can be implicitly @@ -93,11 +91,11 @@ public: /// runOnCFGFunction or runOnMLFunction. virtual PassResult runOnFunction(Function *fn); - /// Implement this function if you want to see CFGFunction's specifically. - virtual PassResult runOnCFGFunction(CFGFunction *fn) { return success(); } + /// Implement this function if you want to see CFG Function's specifically. + virtual PassResult runOnCFGFunction(Function *fn) { return success(); } - /// Implement this function if you want to see MLFunction's specifically. - virtual PassResult runOnMLFunction(MLFunction *fn) { return success(); } + /// Implement this function if you want to see ML Function's specifically. + virtual PassResult runOnMLFunction(Function *fn) { return success(); } // Iterates over all functions in a module, halting upon failure. 
virtual PassResult runOnModule(Module *m) override; diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 38314101c60..d214a96f335 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -31,7 +31,6 @@ namespace mlir { class AffineMap; class ForStmt; class Function; -using MLFunction = Function; class FuncBuilder; // Values that can be used to signal success/failure. This can be implicitly @@ -66,9 +65,9 @@ bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor); /// was known to have a single iteration. Returns false otherwise. bool promoteIfSingleIteration(ForStmt *forStmt); -/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves +/// Promotes all single iteration ForStmt's in the Function, i.e., moves /// their body into the containing StmtBlock. -void promoteSingleIterationLoops(MLFunction *f); +void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop /// with the specified unroll factor. diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index b680d78fce9..c75fddbb4de 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -46,9 +46,9 @@ private: FuncBuilder *builder; }; -/// Base class for the MLFunction-wise lowering state. A pointer to the same +/// Base class for the Function-wise lowering state. A pointer to the same /// instance of the subclass will be passed to all `rewrite` calls on operations -/// that belong to the same MLFunction. +/// that belong to the same Function. class MLFuncGlobalLoweringState { public: virtual ~MLFuncGlobalLoweringState() {} @@ -58,7 +58,7 @@ protected: MLFuncGlobalLoweringState() {} }; -/// Base class for MLFunction lowering patterns. +/// Base class for Function lowering patterns. class MLLoweringPattern : public Pattern { public: /// Subclasses must override this function to implement rewriting. It will be @@ -104,11 +104,11 @@ public: explicit MLPatternLoweringPass(void *ID) : FunctionPass(ID) {} virtual std::unique_ptr - makeFuncWiseState(MLFunction *f) const { + makeFuncWiseState(Function *f) const { return nullptr; } - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; }; ///////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ template struct ListAdder { } // namespace detail template -PassResult MLPatternLoweringPass::runOnMLFunction(MLFunction *f) { +PassResult MLPatternLoweringPass::runOnMLFunction(Function *f) { detail::OwningMLLoweringPatternList patterns; detail::ListAdder::addPatternsToList(&patterns, f->getContext()); auto funcWiseState = makeFuncWiseState(f); diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 4179dbccb42..fd376fbbb97 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -95,7 +95,7 @@ FunctionPass *createDmaGenerationPass(unsigned lowMemorySpace, /// Replaces affine_apply operations in CFGFunctions with the arithmetic /// primitives (addition, multplication) they comprise. Errors out on -/// any MLFunction since it may contain affine_applies baked into the For loop +/// any Function since it may contain affine_applies baked into the For loop /// bounds that cannot be replaced. 
FunctionPass *createLowerAffineApplyPass(); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 131a1f16815..f33f774bb22 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -39,7 +39,6 @@ class Module; class OperationInst; class Function; -using CFGFunction = Function; /// Replace all uses of oldMemRef with newMemRef while optionally remapping the /// old memref's indices using the supplied affine map and adding any additional diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index f3fde8bb95f..e28c2e87651 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -565,7 +565,7 @@ bool mlir::getIndexSet(ArrayRef forStmts, // Computes the iteration domain for 'opStmt' and populates 'indexSet', which // encapsulates the constraints involving loops surrounding 'opStmt' and -// potentially involving any MLFunction symbols. The dimensional identifiers in +// potentially involving any Function symbols. The dimensional identifiers in // 'indexSet' correspond to the loops surounding 'stmt' from outermost to // innermost. // TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index c796bb8cd00..0ebbec9c025 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -30,7 +30,7 @@ template class llvm::DominatorTreeBase; template class llvm::DomTreeNodeBase; /// Compute the immediate-dominators map. -DominanceInfo::DominanceInfo(CFGFunction *function) : DominatorTreeBase() { +DominanceInfo::DominanceInfo(Function *function) : DominatorTreeBase() { // Build the dominator tree for the function. recalculate(function->getBlockList()); } diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp index c03fed5986b..12ce8481516 100644 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp @@ -92,7 +92,7 @@ static MLFunctionMatches combine(ArrayRef matches) { } /// Calls walk on `function`. -MLFunctionMatches MLFunctionMatcher::match(MLFunction *function) { +MLFunctionMatches MLFunctionMatcher::match(Function *function) { assert(!matches && "MLFunctionMatcher already matched!"); this->walkPostOrder(function); return matches; diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 1cb039fe00e..ad935faf05d 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -41,9 +41,9 @@ namespace { struct MemRefBoundCheck : public FunctionPass, StmtWalker { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; // Not applicable to CFG functions. - PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } + PassResult runOnCFGFunction(Function *f) override { return success(); } void visitOperationInst(OperationInst *opStmt); @@ -67,10 +67,10 @@ void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) { // TODO(bondhugula): do this for DMA ops as well. 
} -PassResult MemRefBoundCheck::runOnMLFunction(MLFunction *f) { +PassResult MemRefBoundCheck::runOnMLFunction(Function *f) { return walk(f), success(); } static PassRegistration memRefBoundCheck("memref-bound-check", - "Check memref access bounds in an MLFunction"); + "Check memref access bounds in a Function"); diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index ec33c619a17..bb668f78624 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -37,16 +37,16 @@ using namespace mlir; namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. -/// Checks dependences between all pairs of memref accesses in an MLFunction. +/// Checks dependences between all pairs of memref accesses in a Function. struct MemRefDependenceCheck : public FunctionPass, StmtWalker { SmallVector loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; // Not applicable to CFG functions. - PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } + PassResult runOnCFGFunction(Function *f) override { return success(); } void visitOperationInst(OperationInst *opStmt) { if (opStmt->isa() || opStmt->isa()) { @@ -166,9 +166,9 @@ static void checkDependences(ArrayRef loadsAndStores) { } } -// Walks the MLFunction 'f' adding load and store ops to 'loadsAndStores'. +// Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. -PassResult MemRefDependenceCheck::runOnMLFunction(MLFunction *f) { +PassResult MemRefDependenceCheck::runOnMLFunction(Function *f) { loadsAndStores.clear(); walk(f); checkDependences(loadsAndStores); diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index cea0c087297..f4c509a5132 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -34,10 +34,10 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker { PassResult runOnModule(Module *m) override; // Process CFG function considering the instructions in basic blocks. - PassResult runOnCFGFunction(CFGFunction *function) override; + PassResult runOnCFGFunction(Function *function) override; // Process ML functions and operation statments in ML functions. - PassResult runOnMLFunction(MLFunction *function) override; + PassResult runOnMLFunction(Function *function) override; void visitOperationInst(OperationInst *stmt); // Print summary of op stats. @@ -61,7 +61,7 @@ PassResult PrintOpStatsPass::runOnModule(Module *m) { return result; } -PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) { +PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) { for (const auto &bb : *function) for (const auto &inst : bb) if (auto *op = dyn_cast(&inst)) @@ -73,7 +73,7 @@ void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) { ++opCount[stmt->getName().getStringRef()]; } -PassResult PrintOpStatsPass::runOnMLFunction(MLFunction *function) { +PassResult PrintOpStatsPass::runOnMLFunction(Function *function) { walk(function); return success(); } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index c06bf4df61e..393d7c59de0 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -15,7 +15,7 @@ // limitations under the License. 
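//
// [Illustrative sketch, not part of the original patch.] MemRefDependenceCheck
// above collects every load and store during a walk and then runs pair-wise
// checks over the collected list. A standalone skeleton of that driver loop,
// with a toy access record and a placeholder where the real pass would build
// and test the dependence polyhedron (the same-memref and read-after-read
// filters below are simplifications of this sketch, not claims about the
// pass):
//
#include <cstddef>
#include <vector>

struct AccessSketch {
  int memrefId;    // which buffer the op touches
  bool isStore;    // write vs. read
};

inline std::size_t countCandidatePairsSketch(
    const std::vector<AccessSketch> &accesses) {
  std::size_t numCandidates = 0;
  for (std::size_t i = 0; i < accesses.size(); ++i) {
    for (std::size_t j = i + 1; j < accesses.size(); ++j) {
      const AccessSketch &src = accesses[i];
      const AccessSketch &dst = accesses[j];
      if (src.memrefId != dst.memrefId)
        continue;                  // different buffers: nothing to check here
      if (!src.isStore && !dst.isStore)
        continue;                  // read-after-read carries no dependence
      ++numCandidates;             // a real check tests the polyhedron here
    }
  }
  return numCandidates;
}
//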
// ============================================================================= // -// This file implements Analysis functions specific to slicing in MLFunction. +// This file implements Analysis functions specific to slicing in Function. // //===----------------------------------------------------------------------===// @@ -30,7 +30,7 @@ #include /// -/// Implements Analysis functions specific to slicing in MLFunction. +/// Implements Analysis functions specific to slicing in Function. /// using namespace mlir; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 80adb369aef..e17c27ac941 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -129,7 +129,7 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opStmt and any additional MLFunction symbols. Returns false if +/// surrounding opStmt and any additional Function symbols. Returns false if /// this fails due to yet unimplemented cases. // For example, the memref region for this load operation at loopDepth = 1 will // be as below: diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index e1de6191de6..43c29dbb6ac 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -146,11 +146,11 @@ bool Verifier::verifyOperation(const OperationInst &op) { namespace { struct CFGFuncVerifier : public Verifier { - const CFGFunction &fn; + const Function &fn; DominanceInfo domInfo; - CFGFuncVerifier(const CFGFunction &fn) - : Verifier(fn), fn(fn), domInfo(const_cast(&fn)) {} + CFGFuncVerifier(const Function &fn) + : Verifier(fn), fn(fn), domInfo(const_cast(&fn)) {} bool verify(); bool verifyBlock(const BasicBlock &block); @@ -240,10 +240,10 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) { namespace { struct MLFuncVerifier : public Verifier, public StmtWalker { - const MLFunction &fn; + const Function &fn; bool hadError = false; - MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {} + MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {} void visitOperationInst(OperationInst *opStmt) { hadError |= verifyOperation(*opStmt); @@ -254,7 +254,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker { fn.getName().c_str()); // Check basic structural properties. - walk(const_cast(&fn)); + walk(const_cast(&fn)); if (hadError) return true; @@ -366,9 +366,9 @@ bool Function::verify() const { // No body, nothing can be wrong here. return false; case Kind::CFGFunc: - return CFGFuncVerifier(*cast(this)).verify(); + return CFGFuncVerifier(*this).verify(); case Kind::MLFunc: - return MLFuncVerifier(*cast(this)).verify(); + return MLFuncVerifier(*this).verify(); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c6b731aac57..19943573bc3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -115,8 +115,8 @@ private: // Visit functions. 
void visitFunction(const Function *fn); void visitExtFunction(const Function *fn); - void visitCFGFunction(const CFGFunction *fn); - void visitMLFunction(const MLFunction *fn); + void visitCFGFunction(const Function *fn); + void visitMLFunction(const Function *fn); void visitStatement(const Statement *stmt); void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); @@ -177,14 +177,14 @@ void ModuleState::visitExtFunction(const Function *fn) { visitType(fn->getType()); } -void ModuleState::visitCFGFunction(const CFGFunction *fn) { +void ModuleState::visitCFGFunction(const Function *fn) { visitType(fn->getType()); for (auto &block : *fn) { for (auto &op : block.getStatements()) { if (auto *opInst = dyn_cast(&op)) visitOperation(opInst); else { - llvm_unreachable("IfStmt/ForStmt in a CFGFunction isn't supported"); + llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported"); } } } @@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) { } } -void ModuleState::visitMLFunction(const MLFunction *fn) { +void ModuleState::visitMLFunction(const Function *fn) { visitType(fn->getType()); for (auto &stmt : *fn->getBody()) { ModuleState::visitStatement(&stmt); @@ -1103,7 +1103,7 @@ private: unsigned nextValueID = 0; /// This is the ID to assign to the next induction variable. unsigned nextLoopID = 0; - /// This is the next ID to assign to an MLFunction argument. + /// This is the next ID to assign to a Function argument. unsigned nextArgumentID = 0; /// This is the next ID to assign when a name conflict is detected. @@ -1163,9 +1163,9 @@ void FunctionPrinter::printDefaultOp(const OperationInst *op) { namespace { class CFGFunctionPrinter : public FunctionPrinter { public: - CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other); + CFGFunctionPrinter(const Function *function, const ModulePrinter &other); - const CFGFunction *getFunction() const { return function; } + const Function *getFunction() const { return function; } void print(); void print(const BasicBlock *block); @@ -1183,7 +1183,7 @@ public: } private: - const CFGFunction *function; + const Function *function; DenseMap basicBlockIDs; void numberValuesInBlock(const BasicBlock *block); @@ -1192,7 +1192,7 @@ private: }; } // end anonymous namespace -CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function, +CFGFunctionPrinter::CFGFunctionPrinter(const Function *function, const ModulePrinter &other) : FunctionPrinter(other), function(function) { // Each basic block gets a unique ID per function. @@ -1319,9 +1319,9 @@ void ModulePrinter::printCFG(const Function *fn) { namespace { class MLFunctionPrinter : public FunctionPrinter { public: - MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other); + MLFunctionPrinter(const Function *function, const ModulePrinter &other); - const MLFunction *getFunction() const { return function; } + const Function *getFunction() const { return function; } // Prints ML function. 
void print(); @@ -1349,12 +1349,12 @@ public: private: void numberValues(); - const MLFunction *function; + const Function *function; int numSpaces; }; } // end anonymous namespace -MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function, +MLFunctionPrinter::MLFunctionPrinter(const Function *function, const ModulePrinter &other) : FunctionPrinter(other), function(function), numSpaces(0) { assert(function && "Cannot print nullptr function"); @@ -1381,7 +1381,7 @@ void MLFunctionPrinter::numberValues() { NumberValuesPass pass(this); // TODO: it'd be cleaner to have constant visitor instead of using const_cast. - pass.walk(const_cast(function)); + pass.walk(const_cast(function)); } void MLFunctionPrinter::print() { diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 19b137071f4..0e777c65f23 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -32,11 +32,11 @@ Function::Function(Kind kind, Location location, StringRef name, location(location), type(type), blocks(this) { this->attrs = AttributeListStorage::get(attrs, getContext()); - // Creating of an MLFunction automatically populates the entry block and + // Creating of a Function automatically populates the entry block and // arguments. // TODO(clattner): Unify this behavior. if (kind == Kind::MLFunc) { - // The body of an MLFunction always has one block. + // The body of an ML Function always has one block. auto *entry = new StmtBlock(); blocks.push_back(entry); @@ -158,18 +158,18 @@ bool Function::emitError(const Twine &message) const { } //===----------------------------------------------------------------------===// -// MLFunction implementation. +// Function implementation. //===----------------------------------------------------------------------===// -const OperationInst *MLFunction::getReturnStmt() const { +const OperationInst *Function::getReturnStmt() const { return cast(&getBody()->back()); } -OperationInst *MLFunction::getReturnStmt() { +OperationInst *Function::getReturnStmt() { return cast(&getBody()->back()); } -void MLFunction::walk(std::function callback) { +void Function::walk(std::function callback) { struct Walker : public StmtWalker { std::function const &callback; Walker(std::function const &callback) @@ -182,7 +182,7 @@ void MLFunction::walk(std::function callback) { v.walk(this); } -void MLFunction::walkPostOrder(std::function callback) { +void Function::walkPostOrder(std::function callback) { struct Walker : public StmtWalker { std::function const &callback; Walker(std::function const &callback) diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 649bb9c4f78..96b44600460 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -82,7 +82,7 @@ Statement *Statement::getParentStmt() const { return block ? block->getContainingStmt() : nullptr; } -MLFunction *Statement::getFunction() const { +Function *Statement::getFunction() const { return block ? block->getFunction() : nullptr; } diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index cfb09e6bf45..b551b1121a7 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -32,7 +32,7 @@ Statement *StmtBlock::getContainingStmt() { return parent ? 
parent->getContainingStmt() : nullptr; } -MLFunction *StmtBlock::getFunction() { +Function *StmtBlock::getFunction() { StmtBlock *block = this; while (auto *stmt = block->getContainingStmt()) { block = stmt->getBlock(); @@ -143,7 +143,7 @@ StmtBlock *StmtBlock::getSinglePredecessor() { // Other //===----------------------------------------------------------------------===// -/// Unlink this BasicBlock from its CFGFunction and delete it. +/// Unlink this BasicBlock from its Function and delete it. void BasicBlock::eraseFromFunction() { assert(getFunction() && "BasicBlock has no parent"); getFunction()->getBlocks().erase(this); @@ -163,7 +163,7 @@ BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) { // Start by creating a new basic block, and insert it immediate after this // one in the containing function. auto newBB = new BasicBlock(); - getFunction()->getBlocks().insert(++CFGFunction::iterator(this), newBB); + getFunction()->getBlocks().insert(++Function::iterator(this), newBB); auto branchLoc = splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc(); @@ -186,9 +186,7 @@ StmtBlockList::StmtBlockList(Function *container) : container(container) {} StmtBlockList::StmtBlockList(Statement *container) : container(container) {} -CFGFunction *StmtBlockList::getFunction() { - return dyn_cast_or_null(getContainingFunction()); -} +Function *StmtBlockList::getFunction() { return getContainingFunction(); } Statement *StmtBlockList::getContainingStmt() { return container.dyn_cast(); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 830da558d9e..c7a5e42dd99 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -71,7 +71,7 @@ MLIRContext *IROperandOwner::getContext() const { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -MLFunction *BlockArgument::getFunction() { +Function *BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c5cbb0716de..9b67ef8b150 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2560,11 +2560,11 @@ OperationInst *FunctionParser::parseCustomOperation( namespace { -/// This is a specialized parser for CFGFunction's, maintaining the state +/// This is a specialized parser for Function's, maintaining the state /// transient to their bodies. class CFGFunctionParser : public FunctionParser { public: - CFGFunctionParser(ParserState &state, CFGFunction *function) + CFGFunctionParser(ParserState &state, Function *function) : FunctionParser(state, Kind::CFGFunc), function(function), builder(function) {} @@ -2574,7 +2574,7 @@ public: SmallVectorImpl &operands); private: - CFGFunction *function; + Function *function; llvm::StringMap> blocksByName; DenseMap forwardRef; @@ -2770,17 +2770,17 @@ ParseResult CFGFunctionParser::parseBasicBlock() { //===----------------------------------------------------------------------===// namespace { -/// Refined parser for MLFunction bodies. +/// Refined parser for Function bodies. 
class MLFunctionParser : public FunctionParser { public: - MLFunctionParser(ParserState &state, MLFunction *function) + MLFunctionParser(ParserState &state, Function *function) : FunctionParser(state, Kind::MLFunc), function(function), builder(function->getBody()) {} ParseResult parseFunctionBody(); private: - MLFunction *function; + Function *function; /// This builder intentionally shadows the builder in the base class, with a /// more specific builder type. @@ -3271,7 +3271,7 @@ ParseResult ModuleParser::parseAffineStructureDef() { return ParseSuccess; } -/// Parse a (possibly empty) list of MLFunction arguments with types. +/// Parse a (possibly empty) list of Function arguments with types. /// /// ml-argument ::= ssa-id `:` type /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 8b57dadf3c6..44ca8277e78 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -1005,7 +1005,7 @@ bool LoadOp::verify() const { // TODO: Verify we have the right number of indices. - // TODO: in MLFunction verify that the indices are parameters, IV's, or the + // TODO: in Function verify that the indices are parameters, IV's, or the // result of an affine_apply. return false; } @@ -1255,7 +1255,7 @@ bool StoreOp::verify() const { // TODO: Verify we have the right number of indices. - // TODO: in MLFunction verify that the indices are parameters, IV's, or the + // TODO: in Function verify that the indices are parameters, IV's, or the // result of an affine_apply. return false; } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 7c22f274e3a..e9942ff824b 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -53,11 +53,11 @@ public: private: bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false); - bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc); + bool convertCFGFunction(const Function &cfgFunc, llvm::Function &llvmFunc); bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule); bool convertInstruction(const OperationInst &inst); - void connectPHINodes(const CFGFunction &cfgFunc); + void connectPHINodes(const Function &cfgFunc); /// Type conversion functions. If any conversion fails, report errors to the /// context of the MLIR type and return nullptr. @@ -799,7 +799,7 @@ static const Value *getPHISourceValue(const BasicBlock *current, return nullptr; } -void ModuleLowerer::connectPHINodes(const CFGFunction &cfgFunc) { +void ModuleLowerer::connectPHINodes(const Function &cfgFunc) { // Skip the first block, it cannot be branched to and its arguments correspond // to the arguments of the LLVM function. for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit; @@ -821,7 +821,7 @@ void ModuleLowerer::connectPHINodes(const CFGFunction &cfgFunc) { } } -bool ModuleLowerer::convertCFGFunction(const CFGFunction &cfgFunc, +bool ModuleLowerer::convertCFGFunction(const Function &cfgFunc, llvm::Function &llvmFunc) { // Clear the block mapping. Blocks belong to a function, no need to keep // blocks from the previous functions around. Furthermore, we use this @@ -868,10 +868,10 @@ bool ModuleLowerer::convertFunctions(const Module &mlirModule, continue; llvm::Function *llvmFunc = functionMapping[functionPtr]; - // Add function arguments to the value remapping table. 
In CFGFunction, + // Add function arguments to the value remapping table. In Function, // arguments of the first block are those of the function. assert(!functionPtr->getBlocks().empty() && - "expected at least one basic block in a CFGFunction"); + "expected at least one basic block in a Function"); const BasicBlock &firstBlock = *functionPtr->begin(); for (auto arg : llvm::enumerate(llvmFunc->args())) { valueMapping[firstBlock.getArgument(arg.index())] = &arg.value(); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 4b198589e2c..04f7cfdc3e9 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -43,8 +43,8 @@ namespace { struct CSE : public FunctionPass { CSE() : FunctionPass(&CSE::passID) {} - PassResult runOnCFGFunction(CFGFunction *f) override; - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnCFGFunction(Function *f) override; + PassResult runOnMLFunction(Function *f) override; static char passID; }; @@ -162,7 +162,7 @@ struct CFGCSE : public CSEImpl { bool processed; }; - void run(CFGFunction *f) { + void run(Function *f) { // Note, deque is being used here because there was significant performance // gains over vector when the container becomes very large due to the // specific access patterns. If/when these performance issues are no @@ -210,7 +210,7 @@ struct CFGCSE : public CSEImpl { struct MLCSE : public CSEImpl, StmtWalker { using StmtWalker::walk; - void run(MLFunction *f) { + void run(Function *f) { // Walk the function statements. walk(f); @@ -231,12 +231,12 @@ struct MLCSE : public CSEImpl, StmtWalker { char CSE::passID = 0; -PassResult CSE::runOnCFGFunction(CFGFunction *f) { +PassResult CSE::runOnCFGFunction(Function *f) { CFGCSE().run(f); return success(); } -PassResult CSE::runOnMLFunction(MLFunction *f) { +PassResult CSE::runOnMLFunction(Function *f) { MLCSE().run(f); return success(); } diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index a1ecf38dabd..8c69fa61578 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -16,7 +16,7 @@ // ============================================================================= // // This file implements a testing pass which composes affine maps from -// AffineApplyOps in an MLFunction, by forward subtituting results from an +// AffineApplyOps in a Function, by forward subtituting results from an // AffineApplyOp into any of its users which are also AffineApplyOps. // //===----------------------------------------------------------------------===// @@ -36,7 +36,7 @@ using namespace mlir; namespace { -// ComposeAffineMaps walks stmt blocks in an MLFunction, and for each +// ComposeAffineMaps walks stmt blocks in a Function, and for each // AffineApplyOp, forward substitutes its results into any users which are // also AffineApplyOps. After forward subtituting its results, AffineApplyOps // with no remaining uses are collected and erased after the walk. 
@@ -48,7 +48,7 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker { using StmtListType = llvm::iplist; void walk(StmtListType::iterator Start, StmtListType::iterator End); void visitOperationInst(OperationInst *stmt); - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; using StmtWalker::walk; static char passID; @@ -88,7 +88,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { } } -PassResult ComposeAffineMaps::runOnMLFunction(MLFunction *f) { +PassResult ComposeAffineMaps::runOnMLFunction(Function *f) { affineApplyOpsToErase.clear(); walk(f); for (auto *opStmt : affineApplyOpsToErase) { diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index a83e625c240..08087777e72 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -40,8 +40,8 @@ struct ConstantFold : public FunctionPass, StmtWalker { ConstantFactoryType constantFactory); void visitOperationInst(OperationInst *stmt); void visitForStmt(ForStmt *stmt); - PassResult runOnCFGFunction(CFGFunction *f) override; - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnCFGFunction(Function *f) override; + PassResult runOnMLFunction(Function *f) override; static char passID; }; @@ -103,7 +103,7 @@ bool ConstantFold::foldOperation(OperationInst *op, // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, constant PHI nodes, folding // conditional branches, or anything else fancy. -PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { +PassResult ConstantFold::runOnCFGFunction(Function *f) { existingConstants.clear(); FuncBuilder builder(f); @@ -155,7 +155,7 @@ void ConstantFold::visitForStmt(ForStmt *forStmt) { constantFoldBounds(forStmt); } -PassResult ConstantFold::runOnMLFunction(MLFunction *f) { +PassResult ConstantFold::runOnMLFunction(Function *f) { existingConstants.clear(); opStmtsToErase.clear(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index ca158a17e92..270a25dd339 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -41,9 +41,8 @@ namespace { // Generates CFG function equivalent to the given ML function. class FunctionConverter : public StmtVisitor { public: - FunctionConverter(CFGFunction *cfgFunc) - : cfgFunc(cfgFunc), builder(cfgFunc) {} - CFGFunction *convert(MLFunction *mlFunc); + FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {} + Function *convert(Function *mlFunc); void visitForStmt(ForStmt *forStmt); void visitIfStmt(IfStmt *ifStmt); @@ -56,7 +55,7 @@ private: Location loc, CmpIPredicate predicate, llvm::iterator_range values); - CFGFunction *cfgFunc; + Function *cfgFunc; FuncBuilder builder; // Mapping between original Values and lowered Values. @@ -455,7 +454,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Entry point of the function convertor. // -// Conversion is performed by recursively visiting statements of an MLFunction. +// Conversion is performed by recursively visiting statements of a Function. // It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the @@ -471,11 +470,11 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // construction. 
When an Value is used, it gets replaced with the // corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. -CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { +Function *FunctionConverter::convert(Function *mlFunc) { auto outerBlock = builder.createBlock(); // CFGFunctions do not have explicit arguments but use the arguments to the - // first basic block instead. Create those from the MLFunction arguments and + // first basic block instead. Create those from the Function arguments and // set up the value remapping. outerBlock->addArguments(mlFunc->getType().getInputs()); assert(mlFunc->getNumArguments() == outerBlock->getNumArguments()); @@ -511,17 +510,17 @@ private: // Generates CFG functions for all ML functions in the module. void convertMLFunctions(); // Generates CFG function for the given ML function. - CFGFunction *convert(MLFunction *mlFunc); + Function *convert(Function *mlFunc); // Replaces all ML function references in the module // with references to the generated CFG functions. void replaceReferences(); // Replaces function references in the given function. - void replaceReferences(CFGFunction *cfgFunc); + void replaceReferences(Function *cfgFunc); // Replaces MLFunctions with their CFG counterparts in the module. void replaceFunctions(); // Map from ML functions to generated CFG functions. - llvm::DenseMap generatedFuncs; + llvm::DenseMap generatedFuncs; Module *module = nullptr; }; } // end anonymous namespace @@ -554,7 +553,7 @@ void ModuleConverter::convertMLFunctions() { } // Creates CFG function equivalent to the given ML function. -CFGFunction *ModuleConverter::convert(MLFunction *mlFunc) { +Function *ModuleConverter::convert(Function *mlFunc) { // Use the same name as for ML function; do not add the converted function to // the module yet to avoid collision. auto name = mlFunc->getName().str(); @@ -578,7 +577,7 @@ void ModuleConverter::replaceReferences() { for (const Function &fn : *module) { if (!fn.isML()) continue; - CFGFunction *convertedFunc = generatedFuncs.lookup(&fn); + Function *convertedFunc = generatedFuncs.lookup(&fn); assert(convertedFunc && "ML function was not converted"); MLIRContext *context = module->getContext(); @@ -597,11 +596,11 @@ void ModuleConverter::replaceReferences() { } // Replace the value of a function attribute named "name" attached to the -// operation "op" and containing an MLFunction-typed value with the result of -// converting "func" to a CFGFunction. +// operation "op" and containing a Function-typed value with the result of +// converting "func" to a Function. static inline void replaceMLFunctionAttr( OperationInst &op, Identifier name, const Function *func, - const llvm::DenseMap &generatedFuncs) { + const llvm::DenseMap &generatedFuncs) { if (!func->isML()) return; @@ -610,8 +609,8 @@ static inline void replaceMLFunctionAttr( op.setAttr(name, b.getFunctionAttr(cfgFunc)); } -// The CFG and ML functions have the same name. First, erase the MLFunction. -// Then insert the CFGFunction at the same place. +// The CFG and ML functions have the same name. First, erase the Function. +// Then insert the Function at the same place. 
void ModuleConverter::replaceFunctions() { for (auto pair : generatedFuncs) { auto &functions = module->getFunctions(); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index ed184dc9421..925c50abfec 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -63,8 +63,8 @@ struct DmaGeneration : public FunctionPass, StmtWalker { } // Not applicable to CFG functions. - PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnCFGFunction(Function *f) override { return success(); } + PassResult runOnMLFunction(Function *f) override; void runOnForStmt(ForStmt *forStmt); void visitOperationInst(OperationInst *opStmt); @@ -425,7 +425,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { << " KiB of DMA buffers in fast memory space\n";); } -PassResult DmaGeneration::runOnMLFunction(MLFunction *f) { +PassResult DmaGeneration::runOnMLFunction(Function *f) { for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast(&stmt)) { runOnForStmt(forStmt); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 67b36cfda30..2ddd613d6af 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,7 +70,7 @@ namespace { struct LoopFusion : public FunctionPass { LoopFusion() : FunctionPass(&LoopFusion::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; static char passID; }; @@ -158,7 +158,7 @@ public: }; // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level statements in an MLFunction which contain load/store ops, and edges +// top-level statements in a Function which contain load/store ops, and edges // are memref dependences between the nodes. // TODO(andydavis) Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { @@ -217,7 +217,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(MLFunction *f); + bool init(Function *f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -345,7 +345,7 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a StmtBlock arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(MLFunction *f) { +bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; DenseMap> memrefAccesses; for (auto &stmt : *f->getBody()) { @@ -415,7 +415,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // GreedyFusion greedily fuses loop nests which have a producer/consumer // relationship on a memref, with the goal of improving locality. Currently, // this the producer/consumer relationship is required to be unique in the -// MLFunction (there are TODOs to relax this constraint in the future). +// Function (there are TODOs to relax this constraint in the future). // // The steps of the algorithm are as follows: // @@ -425,7 +425,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // destination ForStmt into which fusion will be attempted. // *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'. 
// *) For each LoadOp in 'dstLoadOps' do: -// *) Lookup dependent loop nests at earlier positions in the MLFunction +// *) Lookup dependent loop nests at earlier positions in the Function // which have a single store op to the same memref. // *) Check if dependences would be violated by the fusion. For example, // the src loop nest may load from memrefs which are different than @@ -549,7 +549,7 @@ public: } // end anonymous namespace -PassResult LoopFusion::runOnMLFunction(MLFunction *f) { +PassResult LoopFusion::runOnMLFunction(Function *f) { MemRefDependenceGraph g; if (g.init(f)) GreedyFusion(&g).run(); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 9e365567b17..d6c1eed3a0c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -38,10 +38,10 @@ static llvm::cl::opt namespace { -/// A pass to perform loop tiling on all suitable loop nests of an MLFunction. +/// A pass to perform loop tiling on all suitable loop nests of a Function. struct LoopTiling : public FunctionPass { LoopTiling() : FunctionPass(&LoopTiling::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; constexpr static unsigned kDefaultTileSize = 4; static char passID; @@ -52,7 +52,7 @@ struct LoopTiling : public FunctionPass { char LoopTiling::passID = 0; /// Creates a pass to perform loop tiling on all suitable loop nests of an -/// MLFunction. +/// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } // Move the loop body of ForStmt 'src' from 'src' into the specified location in @@ -214,7 +214,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(MLFunction *f, +static void getTileableBands(Function *f, std::vector> *bands) { // Get maximal perfect nest of 'for' stmts starting from root (inclusive). auto getMaximalPerfectLoopNest = [&](ForStmt *root) { @@ -235,7 +235,7 @@ static void getTileableBands(MLFunction *f, } } -PassResult LoopTiling::runOnMLFunction(MLFunction *f) { +PassResult LoopTiling::runOnMLFunction(Function *f) { std::vector> bands; getTileableBands(f, &bands); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 0a3dd65d1f4..c3651e53593 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -70,7 +70,7 @@ struct LoopUnroll : public FunctionPass { : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; /// Unroll this for stmt. Returns false if nothing was done. bool runOnForStmt(ForStmt *forStmt); @@ -83,7 +83,7 @@ struct LoopUnroll : public FunctionPass { char LoopUnroll::passID = 0; -PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { +PassResult LoopUnroll::runOnMLFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. 
class InnermostLoopGatherer : public StmtWalker { public: diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 179216d243e..7ed9be19644 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -65,7 +65,7 @@ static llvm::cl::opt namespace { /// Loop unroll jam pass. Currently, this just unroll jams the first -/// outer loop in an MLFunction. +/// outer loop in a Function. struct LoopUnrollAndJam : public FunctionPass { Optional unrollJamFactor; static const unsigned kDefaultUnrollJamFactor = 4; @@ -74,7 +74,7 @@ struct LoopUnrollAndJam : public FunctionPass { : FunctionPass(&LoopUnrollAndJam::passID), unrollJamFactor(unrollJamFactor) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; bool runOnForStmt(ForStmt *forStmt); static char passID; @@ -88,7 +88,7 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); } -PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { +PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnForStmt can be called on any // for Stmt. @@ -165,8 +165,8 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { auto ubMap = forStmt->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be - // expressed as an MLFunction in the general case). However, the right way to - // do such unrolling for an MLFunction would be to specialize the loop for the + // expressed as a Function in the general case). However, the right way to + // do such unrolling for a Function would be to specialize the loop for the // 'hotspot' case and unroll that hotspot. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return false; diff --git a/mlir/lib/Transforms/LowerAffineApply.cpp b/mlir/lib/Transforms/LowerAffineApply.cpp index e8a2af54b8e..52146fdb5b7 100644 --- a/mlir/lib/Transforms/LowerAffineApply.cpp +++ b/mlir/lib/Transforms/LowerAffineApply.cpp @@ -35,8 +35,8 @@ struct LowerAffineApply : public FunctionPass { explicit LowerAffineApply() : FunctionPass(&LowerAffineApply::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; - PassResult runOnCFGFunction(CFGFunction *f) override; + PassResult runOnMLFunction(Function *f) override; + PassResult runOnCFGFunction(Function *f) override; static char passID; }; @@ -45,13 +45,13 @@ struct LowerAffineApply : public FunctionPass { char LowerAffineApply::passID = 0; -PassResult LowerAffineApply::runOnMLFunction(MLFunction *f) { +PassResult LowerAffineApply::runOnMLFunction(Function *f) { f->emitError("ML Functions contain syntactically hidden affine_apply's that " "cannot be expanded"); return failure(); } -PassResult LowerAffineApply::runOnCFGFunction(CFGFunction *f) { +PassResult LowerAffineApply::runOnCFGFunction(Function *f) { for (BasicBlock &bb : *f) { // Handle iterators with care because we erase in the same loop. // In particular, step to the next element before erasing the current one. 
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 6907c322856..4d24191dcb2 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -108,7 +108,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, auto vectorMemRefType = MemRefType::get({1}, vectorType, {}, 0); // Get the ML function builder. - // We need access to the MLFunction builder stored internally in the + // We need access to the Function builder stored internally in the // MLFunctionLoweringRewriter general rewriting API does not provide // ML-specific functions (ForStmt and StmtBlock manipulation). While we could // forward them or define a whole rewriting chain based on MLFunctionBuilder @@ -233,7 +233,7 @@ struct LowerVectorTransfersPass : MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {} std::unique_ptr - makeFuncWiseState(MLFunction *f) const override { + makeFuncWiseState(Function *f) const override { auto state = llvm::make_unique(); auto builder = FuncBuilder(f); builder.setInsertionPointToStart(f->getBody()); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 93fec70b8a7..a30e8164760 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -77,7 +77,7 @@ /// words, this pass operates on a scoped program slice. Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top -/// MLFunction scope of the innermost loop scope, by construction. As a +/// Function scope of the innermost loop scope, by construction. As a /// consequence, the implementation just starts from vector_transfer_write /// operations and builds the slice scoped the innermost loop enclosing the /// current vector_transfer_write. These assumptions and the implementation @@ -196,7 +196,7 @@ struct MaterializationState { struct MaterializeVectorsPass : public FunctionPass { MaterializeVectorsPass() : FunctionPass(&MaterializeVectorsPass::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext mlContext; @@ -650,7 +650,7 @@ static bool emitSlice(MaterializationState *state, /// Additionally, this set is limited to statements in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. 
-static bool materialize(MLFunction *f, +static bool materialize(Function *f, const SetVector &terminators, MaterializationState *state) { DenseSet seen; @@ -709,9 +709,9 @@ static bool materialize(MLFunction *f, return false; } -PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) { +PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) { using matcher::Op; - LLVM_DEBUG(dbgs() << "\nMaterializeVectors on MLFunction\n"); + LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); LLVM_DEBUG(f->print(dbgs())); MaterializationState state; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 9798225f90b..a0964a67fa6 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -41,7 +41,7 @@ namespace { struct PipelineDataTransfer : public FunctionPass, StmtWalker { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; PassResult runOnForStmt(ForStmt *forStmt); // Collect all 'for' statements. @@ -137,7 +137,7 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { } /// Returns success if the IR is in a valid state. -PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { +PassResult PipelineDataTransfer::runOnMLFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is // necessary since 'for' statements nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index b0b31e01175..853a814e516 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -33,7 +33,7 @@ using llvm::report_fatal_error; namespace { /// Simplifies all affine expressions appearing in the operation statements of -/// the MLFunction. This is mainly to test the simplifyAffineExpr method. +/// the Function. This is mainly to test the simplifyAffineExpr method. // TODO(someone): Gradually, extend this to all affine map references found in // ML functions and CFG functions. struct SimplifyAffineStructures : public FunctionPass, @@ -41,10 +41,10 @@ struct SimplifyAffineStructures : public FunctionPass, explicit SimplifyAffineStructures() : FunctionPass(&SimplifyAffineStructures::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; // Does nothing on CFG functions for now. No reusable walkers/visitors exist // for this yet? TODO(someone). 
- PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } + PassResult runOnCFGFunction(Function *f) override { return success(); } void visitIfStmt(IfStmt *ifStmt); void visitOperationInst(OperationInst *opStmt); @@ -86,7 +86,7 @@ void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) { } } -PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) { +PassResult SimplifyAffineStructures::runOnMLFunction(Function *f) { walk(f); return success(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index f493e4b090b..a4116667794 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -255,7 +255,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, uniquedConstants.clear(); } -static void processMLFunction(MLFunction *fn, +static void processMLFunction(Function *fn, OwningRewritePatternList &&patterns) { class MLFuncRewriter : public WorklistRewriter { public: @@ -287,7 +287,7 @@ static void processMLFunction(MLFunction *fn, driver.simplifyFunction(fn, rewriter); } -static void processCFGFunction(CFGFunction *fn, +static void processCFGFunction(Function *fn, OwningRewritePatternList &&patterns) { class CFGFuncRewriter : public WorklistRewriter { public: diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4a2831c0a83..7def4fe2f09 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -122,9 +122,9 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { return true; } -/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves +/// Promotes all single iteration for stmt's in the Function, i.e., moves /// their body into the containing StmtBlock. -void mlir::promoteSingleIterationLoops(MLFunction *f) { +void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. class LoopBodyPromoter : public StmtWalker { public: @@ -357,8 +357,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { auto ubMap = forStmt->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be - // expressed as an MLFunction in the general case). However, the right way to - // do such unrolling for an MLFunction would be to specialize the loop for the + // expressed as a Function in the general case). However, the right way to + // do such unrolling for a Function would be to specialize the loop for the // 'hotspot' case and unroll that hotspot. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return false; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index a0556c43c7f..3661c1bdbbc 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -434,7 +434,7 @@ void mlir::remapFunctionAttrs( void mlir::remapFunctionAttrs( Function &fn, const DenseMap &remappingTable) { - // Look at all instructions in a CFGFunction. + // Look at all instructions in a Function. 
if (fn.isCFG()) { for (auto &bb : fn.getBlockList()) { for (auto &inst : bb) { diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 5abd3a3cfcc..78d048b4778 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -73,12 +73,12 @@ struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapAttrName = "affine_map"; VectorizerTestPass() : FunctionPass(&VectorizerTestPass::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; - void testVectorShapeRatio(MLFunction *f); - void testForwardSlicing(MLFunction *f); - void testBackwardSlicing(MLFunction *f); - void testSlicing(MLFunction *f); - void testComposeMaps(MLFunction *f); + PassResult runOnMLFunction(Function *f) override; + void testVectorShapeRatio(Function *f); + void testForwardSlicing(Function *f); + void testBackwardSlicing(Function *f); + void testSlicing(Function *f); + void testComposeMaps(Function *f); // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext MLContext; @@ -90,7 +90,7 @@ struct VectorizerTestPass : public FunctionPass { char VectorizerTestPass::passID = 0; -void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) { +void VectorizerTestPass::testVectorShapeRatio(Function *f) { using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); @@ -139,7 +139,7 @@ static std::string toString(Statement *stmt) { return res; } -static MLFunctionMatches matchTestSlicingOps(MLFunction *f) { +static MLFunctionMatches matchTestSlicingOps(Function *f) { // Just use a custom op name for this test, it makes life easier. 
constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; @@ -153,7 +153,7 @@ static MLFunctionMatches matchTestSlicingOps(MLFunction *f) { return pat.match(f); } -void VectorizerTestPass::testBackwardSlicing(MLFunction *f) { +void VectorizerTestPass::testBackwardSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { SetVector backwardSlice; @@ -166,7 +166,7 @@ void VectorizerTestPass::testBackwardSlicing(MLFunction *f) { } } -void VectorizerTestPass::testForwardSlicing(MLFunction *f) { +void VectorizerTestPass::testForwardSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { SetVector forwardSlice; @@ -179,7 +179,7 @@ void VectorizerTestPass::testForwardSlicing(MLFunction *f) { } } -void VectorizerTestPass::testSlicing(MLFunction *f) { +void VectorizerTestPass::testSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { SetVector staticSlice = getSlice(m.first); @@ -197,7 +197,7 @@ bool customOpWithAffineMapAttribute(const Statement &stmt) { VectorizerTestPass::kTestAffineMapOpName; } -void VectorizerTestPass::testComposeMaps(MLFunction *f) { +void VectorizerTestPass::testComposeMaps(Function *f) { using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); auto matches = pattern.match(f); @@ -218,7 +218,7 @@ void VectorizerTestPass::testComposeMaps(MLFunction *f) { res.print(outs() << "\nComposed map: "); } -PassResult VectorizerTestPass::runOnMLFunction(MLFunction *f) { +PassResult VectorizerTestPass::runOnMLFunction(Function *f) { if (!clTestVectorShapeRatio.empty()) { testVectorShapeRatio(f); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 762a09ea048..ddbd6256782 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -47,7 +47,7 @@ using namespace mlir; /// -/// Implements a high-level vectorization strategy on an MLFunction. +/// Implements a high-level vectorization strategy on a Function. /// The abstraction used is that of super-vectors, which provide a single, /// compact, representation in the vector types, information that is expected /// to reduce the impact of the phase ordering problem @@ -382,7 +382,7 @@ using namespace mlir; /// /// Examples: /// ========= -/// Consider the following MLFunction: +/// Consider the following Function: /// ```mlir /// mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { /// %A = alloc (%M, %N) : memref @@ -651,7 +651,7 @@ namespace { struct Vectorize : public FunctionPass { Vectorize() : FunctionPass(&Vectorize::passID) {} - PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnMLFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext MLContext; @@ -1264,13 +1264,13 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, return false; } -/// Applies vectorization to the current MLFunction by searching over a bunch of +/// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. 
-PassResult Vectorize::runOnMLFunction(MLFunction *f) { +PassResult Vectorize::runOnMLFunction(Function *f) { for (auto pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); - LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on MLFunction\n"); + LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); LLVM_DEBUG(f->print(dbgs())); unsigned patternDepth = pat.getDepth(); auto matches = pat.match(f); diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 9c1614acb95..2ce8af3613a 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -25,16 +25,15 @@ namespace llvm { // Specialize DOTGraphTraits to produce more readable output. template <> -struct llvm::DOTGraphTraits - : public DefaultDOTGraphTraits { +struct llvm::DOTGraphTraits : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; static std::string getNodeLabel(const BasicBlock *basicBlock, - const CFGFunction *); + const Function *); }; -std::string llvm::DOTGraphTraits::getNodeLabel( - const BasicBlock *basicBlock, const CFGFunction *) { +std::string llvm::DOTGraphTraits::getNodeLabel( + const BasicBlock *basicBlock, const Function *) { // Reuse the print output for the node labels. std::string outStreamStr; raw_string_ostream os(outStreamStr); @@ -57,19 +56,19 @@ std::string llvm::DOTGraphTraits::getNodeLabel( } // end namespace llvm -void mlir::viewGraph(const CFGFunction &function, const llvm::Twine &name, +void mlir::viewGraph(const Function &function, const llvm::Twine &name, bool shortNames, const llvm::Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&function, name, shortNames, title, program); } llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, - const CFGFunction *function, - bool shortNames, const llvm::Twine &title) { + const Function *function, bool shortNames, + const llvm::Twine &title) { return llvm::WriteGraph(os, function, shortNames, title); } -void mlir::CFGFunction::viewGraph() const { +void mlir::Function::viewGraph() const { ::mlir::viewGraph(*this, llvm::Twine("cfgfunc ") + getName().str()); } @@ -79,7 +78,7 @@ struct PrintCFGPass : public FunctionPass { const llvm::Twine &title = "") : FunctionPass(&PrintCFGPass::passID), os(os), shortNames(shortNames), title(title) {} - PassResult runOnCFGFunction(CFGFunction *function) override { + PassResult runOnCFGFunction(Function *function) override { mlir::writeGraph(os, function, shortNames, title); return success(); } -- cgit v1.2.3 From 315a466aed9bcc896007098719ed9e0a35a3459d Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Fri, 28 Dec 2018 13:07:39 -0800 Subject: Rename BasicBlock and StmtBlock to Block, and make a pass cleaning it up. I did not make an effort to rename all of the 'bb' names in the codebase, since they are still correct and any specific missed once can be fixed up on demand. The last major renaming is Statement -> Instruction, which is why Statement and Stmt still appears in various places. This is step 19/n towards merging instructions and statements, NFC. 
PiperOrigin-RevId: 227163082 --- mlir/g3doc/LangRef.md | 51 ++- mlir/g3doc/Rationale.md | 17 +- mlir/include/mlir/Analysis/Dominance.h | 16 +- mlir/include/mlir/IR/Block.h | 486 ++++++++++++++++++++++++++ mlir/include/mlir/IR/Builders.h | 28 +- mlir/include/mlir/IR/BuiltinOps.h | 20 +- mlir/include/mlir/IR/Function.h | 31 +- mlir/include/mlir/IR/FunctionGraphTraits.h | 63 ++-- mlir/include/mlir/IR/OpDefinition.h | 6 +- mlir/include/mlir/IR/OpImplementation.h | 2 +- mlir/include/mlir/IR/OperationSupport.h | 9 +- mlir/include/mlir/IR/Statement.h | 18 +- mlir/include/mlir/IR/Statements.h | 50 +-- mlir/include/mlir/IR/StmtBlock.h | 497 --------------------------- mlir/include/mlir/IR/Value.h | 12 +- mlir/include/mlir/Transforms/LoopUtils.h | 2 +- mlir/lib/Analysis/AffineAnalysis.cpp | 24 +- mlir/lib/Analysis/Dominance.cpp | 10 +- mlir/lib/Analysis/LoopAnalysis.cpp | 6 +- mlir/lib/Analysis/Utils.cpp | 24 +- mlir/lib/Analysis/Verifier.cpp | 16 +- mlir/lib/IR/AsmPrinter.cpp | 43 +-- mlir/lib/IR/Block.cpp | 232 +++++++++++++ mlir/lib/IR/Builders.cpp | 10 +- mlir/lib/IR/BuiltinOps.cpp | 18 +- mlir/lib/IR/Function.cpp | 2 +- mlir/lib/IR/Operation.cpp | 8 +- mlir/lib/IR/Statement.cpp | 43 ++- mlir/lib/IR/StmtBlock.cpp | 236 ------------- mlir/lib/Parser/Parser.cpp | 62 ++-- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 16 +- mlir/lib/Transforms/CSE.cpp | 4 +- mlir/lib/Transforms/ComposeAffineMaps.cpp | 8 +- mlir/lib/Transforms/ConvertToCFG.cpp | 38 +- mlir/lib/Transforms/DmaGeneration.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 17 +- mlir/lib/Transforms/LoopUnroll.cpp | 6 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 8 +- mlir/lib/Transforms/LowerAffineApply.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 14 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 10 +- 44 files changed, 1074 insertions(+), 1101 deletions(-) create mode 100644 mlir/include/mlir/IR/Block.h delete mode 100644 mlir/include/mlir/IR/StmtBlock.h create mode 100644 mlir/lib/IR/Block.cpp delete mode 100644 mlir/lib/IR/StmtBlock.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 5c562e390af..3469133a8ba 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -33,8 +33,8 @@ list of [Functions](#functions), and there are two types of function definitions, a "[CFG Function](#cfg-functions)" and an "[ML Function](#ml-functions)". Both kinds of functions are represented as a composition of [operations](#operations), but represent control flow in -different ways: A CFG Function control flow using a CFG of -[BasicBlocks](#basic-blocks), which contain instructions and end with +different ways: A CFG Function control flow using a CFG of [Blocks](#blocks), +which contain instructions and end with [control flow terminator statements](#terminator-instructions) (like branches). ML Functions represents control flow with a nest of affine loops and if conditions, and are said to contain statements. Both types of functions can call @@ -65,7 +65,7 @@ Here's an example of an MLIR module: // result using a TensorFlow op. The dimensions of A and B are partially // known. The shapes are assumed to match. cfgfunc @mul(tensor<100x?xf32>, tensor) -> (tensor<100x50xf32>) { -// Basic block bb0. %A and %B come from function arguments. +// Block bb0. %A and %B come from function arguments. 
bb0(%A: tensor<100x?xf32>, %B: tensor): // Compute the inner dimension of %A using the dim operation. %n = dim %A, 1 : tensor<100x?xf32> @@ -606,9 +606,8 @@ function-type ::= type-list-parens `->` type-list MLIR supports first-class functions: the [`constant` operation](#'constant'-operation) produces the address of a function as an SSA value. This SSA value may be passed to and returned from functions, -merged across control flow boundaries with -[basic block arguments](#basic-blocks), and called with the -[`call_indirect` instruction](#'call_indirect'-operation). +merged across control flow boundaries with [block arguments](#blocks), and +called with the [`call_indirect` instruction](#'call_indirect'-operation). Function types are also used to indicate the arguments and results of [operations](#operations). @@ -916,7 +915,7 @@ Syntax: ``` {.ebnf} cfg-func ::= `cfgfunc` function-signature - (`attributes` attribute-dict)? `{` basic-block+ `}` + (`attributes` attribute-dict)? `{` block+ `}` ``` A simple CFG function that returns its argument twice looks like this: @@ -935,14 +934,14 @@ TensorFlow dataflow graph, where the instructions are TensorFlow "ops" producing values of Tensor type. It can also represent scalar math, and can be used as a way to lower [ML Functions](#ml-functions) before late code generation. -#### Basic Blocks {#basic-blocks} +#### Blocks {#blocks} Syntax: ``` {.ebnf} -basic-block ::= bb-label operation* terminator-stmt -bb-label ::= bb-id bb-arg-list? `:` -bb-id ::= bare-id +block ::= bb-label operation* terminator-stmt +bb-label ::= bb-id bb-arg-list? `:` +bb-id ::= bare-id ssa-id-and-type ::= ssa-id `:` type // Non-empty list of names and types. @@ -954,14 +953,14 @@ bb-arg-list ::= `(` ssa-id-and-type-list? `)` A [basic block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list of operation instructions without control flow (calls are not considered control flow for this purpose) that are executed from top to bottom. The last -instruction in a basic block is a -[terminator instruction](#terminator-instructions), which ends the block. +instruction in a block is a [terminator instruction](#terminator-instructions), +which ends the block. -Basic blocks in MLIR take a list of arguments, which represent SSA PHI nodes in -a functional notation. The arguments are defined by the block, and values are -provided for these basic block arguments by branches that go to the block. +Blocks in MLIR take a list of arguments, which represent SSA PHI nodes in a +functional notation. The arguments are defined by the block, and values are +provided for these block arguments by branches that go to the block. -Here is a simple example function showing branches, returns, and basic block +Here is a simple example function showing branches, returns, and block arguments: ```mlir {.mlir} @@ -987,13 +986,13 @@ bb4(%d : i64, %e : i64): } ``` -**Context:** The "basic block argument" representation eliminates a number of -special cases from the IR compared to traditional "PHI nodes are instructions" -SSA IRs (like LLVM). For example, the +**Context:** The "block argument" representation eliminates a number of special +cases from the IR compared to traditional "PHI nodes are instructions" SSA IRs +(like LLVM). 
For example, the [parallel copy semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) of SSA is immediately apparent, and function arguments are no longer a special case: they become arguments to the entry block -[[more rationale](Rationale.md#basic-block-arguments-vs-phi-nodes)]. +[[more rationale](Rationale.md#block-arguments-vs-phi-nodes)]. Control flow within a CFG function is implemented with unconditional branches, conditional branches, and a return statement. @@ -1014,9 +1013,9 @@ terminator-stmt ::= `br` bb-id branch-use-list? branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` ``` -The `br` terminator statement represents an unconditional jump to a target basic +The `br` terminator statement represents an unconditional jump to a target block. The count and types of operands to the branch must align with the -arguments in the target basic block. +arguments in the target block. The MLIR branch instruction is not allowed to target the entry block for a function. @@ -1040,7 +1039,7 @@ for a function. The two destinations of the conditional branch instruction are allowed to be the same. The following example illustrates a CFG function with a conditional branch -instruction that targets the same basic block: +instruction that targets the same block: ```mlir {.mlir} cfgfunc @select(%a : i32, %b :i32, %flag : i1) -> i32 { @@ -1318,8 +1317,8 @@ operation ::= ssa-id `=` `call_indirect` ssa-use The `call_indirect` operation represents an indirect call to a value of function type. Functions are first class types in MLIR, and may be passed as arguments -and merged together with basic block arguments. The operands and result types of -the call must match the specified function type. +and merged together with block arguments. The operands and result types of the +call must match the specified function type. Function values can be created with the [`constant` operation](#'constant'-operation). diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 791fe31fce9..17cbd1d15c1 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -171,15 +171,14 @@ type - memref<8x%Nxf32>. We went for the current approach in MLIR because it simplifies the design --- types remain immutable when the values of symbols change. -### Basic Block Arguments vs PHI nodes {#basic-block-arguments-vs-phi-nodes} +### Block Arguments vs PHI nodes {#block-arguments-vs-phi-nodes} -MLIR CFG Functions represent SSA using -"[basic block arguments](LangRef.md#basic-blocks)" rather than -[PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in LLVM. This -choice is representationally identical (the same constructs can be represented -in either form) but basic block arguments have several advantages: +MLIR CFG Functions represent SSA using "[block arguments](LangRef.md#blocks)" +rather than [PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in +LLVM. This choice is representationally identical (the same constructs can be +represented in either form) but block arguments have several advantages: -1. LLVM PHI nodes always have to be kept at the top of a basic block, and +1. LLVM PHI nodes always have to be kept at the top of a block, and transformations frequently have to manually skip over them. This is defined away with BB arguments. 1. LLVM has a separate function Argument node. This is defined away with BB @@ -202,7 +201,7 @@ in either form) but basic block arguments have several advantages: but SIL uses it extensively, e.g. 
in the [switch_enum instruction](https://github.com/apple/swift/blob/master/docs/SIL.rst#switch-enum). -For more context, basic block arguments were previously used in the Swift +For more context, block arguments were previously used in the Swift [SIL Intermediate Representation](https://github.com/apple/swift/blob/master/docs/SIL.rst), and described in [a talk on YouTube](https://www.youtube.com/watch?v=Ntj8ab-5cvE). The section of @@ -474,7 +473,7 @@ for (i=0; i ; -extern template class llvm::DominatorTreeBase; -extern template class llvm::DomTreeNodeBase; +extern template class llvm::DominatorTreeBase; +extern template class llvm::DominatorTreeBase; +extern template class llvm::DomTreeNodeBase; namespace llvm { namespace DomTreeBuilder { -using MLIRDomTree = llvm::DomTreeBase; -using MLIRPostDomTree = llvm::PostDomTreeBase; +using MLIRDomTree = llvm::DomTreeBase; +using MLIRPostDomTree = llvm::PostDomTreeBase; // extern template void Calculate(MLIRDomTree &DT); // extern template void Calculate(MLIRPostDomTree &DT); @@ -38,9 +38,9 @@ using MLIRPostDomTree = llvm::PostDomTreeBase; } // namespace llvm namespace mlir { -using DominatorTreeBase = llvm::DominatorTreeBase; -using PostDominatorTreeBase = llvm::DominatorTreeBase; -using DominanceInfoNode = llvm::DomTreeNodeBase; +using DominatorTreeBase = llvm::DominatorTreeBase; +using PostDominatorTreeBase = llvm::DominatorTreeBase; +using DominanceInfoNode = llvm::DomTreeNodeBase; /// A class for computing basic dominance information. class DominanceInfo : public DominatorTreeBase { diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h new file mode 100644 index 00000000000..985d0fdb075 --- /dev/null +++ b/mlir/include/mlir/IR/Block.h @@ -0,0 +1,486 @@ +//===- Block.h - MLIR Block and BlockList Classes ---------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines Block and BlockList classes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BLOCK_H +#define MLIR_IR_BLOCK_H + +#include "mlir/IR/Statement.h" +#include "llvm/ADT/PointerUnion.h" + +namespace mlir { +class IfStmt; +class BlockList; + +template class PredecessorIterator; +template class SuccessorIterator; + +/// `Block` represents an ordered list of `Instruction`s. +class Block : public IRObjectWithUseList, + public llvm::ilist_node_with_parent { +public: + explicit Block() {} + ~Block(); + + void clear() { + // Clear instructions in the reverse order so that uses are destroyed + // before their defs. + while (!empty()) + instructions.pop_back(); + } + + /// Blocks are maintained in a list by BlockList type. + BlockList *getParent() const { return parent; } + + /// Returns the closest surrounding instruction that contains this block or + /// nullptr if this is a top-level block. 
+ Instruction *getContainingInst(); + + const Instruction *getContainingInst() const { + return const_cast(this)->getContainingInst(); + } + + /// Returns the function that this block is part of, even if the block is + /// nested under an IfStmt or ForStmt. + Function *getFunction(); + const Function *getFunction() const { + return const_cast(this)->getFunction(); + } + + //===--------------------------------------------------------------------===// + // Block argument management + //===--------------------------------------------------------------------===// + + // This is the list of arguments to the block. + using BlockArgListType = ArrayRef; + + // FIXME: Not const correct. + BlockArgListType getArguments() const { return arguments; } + + using args_iterator = BlockArgListType::iterator; + using reverse_args_iterator = BlockArgListType::reverse_iterator; + args_iterator args_begin() const { return getArguments().begin(); } + args_iterator args_end() const { return getArguments().end(); } + reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); } + reverse_args_iterator args_rend() const { return getArguments().rend(); } + + bool args_empty() const { return arguments.empty(); } + + /// Add one value to the argument list. + BlockArgument *addArgument(Type type); + + /// Add one argument to the argument list for each type specified in the list. + llvm::iterator_range addArguments(ArrayRef types); + + /// Erase the argument at 'index' and remove it from the argument list. + void eraseArgument(unsigned index); + + unsigned getNumArguments() const { return arguments.size(); } + BlockArgument *getArgument(unsigned i) { return arguments[i]; } + const BlockArgument *getArgument(unsigned i) const { return arguments[i]; } + + //===--------------------------------------------------------------------===// + // Instruction list management + //===--------------------------------------------------------------------===// + + /// This is the list of instructions in the block. + using InstListType = llvm::iplist; + InstListType &getInstructions() { return instructions; } + const InstListType &getInstructions() const { return instructions; } + + // Iteration over the instructions in the block. + using iterator = InstListType::iterator; + using const_iterator = InstListType::const_iterator; + using reverse_iterator = InstListType::reverse_iterator; + using const_reverse_iterator = InstListType::const_reverse_iterator; + + iterator begin() { return instructions.begin(); } + iterator end() { return instructions.end(); } + const_iterator begin() const { return instructions.begin(); } + const_iterator end() const { return instructions.end(); } + reverse_iterator rbegin() { return instructions.rbegin(); } + reverse_iterator rend() { return instructions.rend(); } + const_reverse_iterator rbegin() const { return instructions.rbegin(); } + const_reverse_iterator rend() const { return instructions.rend(); } + + bool empty() const { return instructions.empty(); } + void push_back(Instruction *inst) { instructions.push_back(inst); } + void push_front(Instruction *inst) { instructions.push_front(inst); } + + Instruction &back() { return instructions.back(); } + const Instruction &back() const { return const_cast(this)->back(); } + Instruction &front() { return instructions.front(); } + const Instruction &front() const { + return const_cast(this)->front(); + } + + /// Returns the instructions's position in this block or -1 if the instruction + /// is not present. 
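As a quick illustration of the block-argument and instruction-list APIs declared above, here is a minimal sketch. It sticks to accessors visible in this header; the helper names and the idea of counting instructions are illustrative assumptions, not part of the patch.

```c++
#include "mlir/IR/Block.h"

// Count the instructions held by a block using its iterator interface.
unsigned countInstructions(const mlir::Block &block) {
  unsigned numInsts = 0;
  for (const mlir::Instruction &inst : block) {
    (void)inst; // Each element is an Instruction owned by the block.
    ++numInsts;
  }
  return numInsts;
}

// Report whether the block carries any block arguments (the SSA values that
// replace PHI nodes in this IR).
bool hasBlockArguments(const mlir::Block &block) { return !block.args_empty(); }
```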
+ /// TODO: This is needlessly inefficient, and should not be API on Block. + int64_t findInstPositionInBlock(const Instruction &stmt) const { + int64_t j = 0; + for (const auto &s : instructions) { + if (&s == &stmt) + return j; + j++; + } + return -1; + } + + /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the + /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if + /// the latter fails. + /// TODO: This is very specific functionality that should live somewhere else. + const Instruction *findAncestorInstInBlock(const Instruction &inst) const; + /// TODO: This const overload is wrong. + Instruction *findAncestorInstInBlock(Instruction *inst) { + return const_cast(findAncestorInstInBlock(*inst)); + } + + //===--------------------------------------------------------------------===// + // Terminator management + //===--------------------------------------------------------------------===// + + /// Get the terminator instruction of this block, or null if the block is + /// malformed. + OperationInst *getTerminator(); + + const OperationInst *getTerminator() const { + return const_cast(this)->getTerminator(); + } + + //===--------------------------------------------------------------------===// + // Predecessors and successors. + //===--------------------------------------------------------------------===// + + // Predecessor iteration. + using const_pred_iterator = PredecessorIterator; + const_pred_iterator pred_begin() const; + const_pred_iterator pred_end() const; + llvm::iterator_range getPredecessors() const; + + using pred_iterator = PredecessorIterator; + pred_iterator pred_begin(); + pred_iterator pred_end(); + llvm::iterator_range getPredecessors(); + + /// Return true if this block has no predecessors. + bool hasNoPredecessors() const; + + /// If this block has exactly one predecessor, return it. Otherwise, return + /// null. + /// + /// Note that if a block has duplicate predecessors from a single block (e.g. + /// if you have a conditional branch with the same block as the true/false + /// destinations) is not considered to be a single predecessor. + Block *getSinglePredecessor(); + + const Block *getSinglePredecessor() const { + return const_cast(this)->getSinglePredecessor(); + } + + // Indexed successor access. + unsigned getNumSuccessors() const; + const Block *getSuccessor(unsigned i) const { + return const_cast(this)->getSuccessor(i); + } + Block *getSuccessor(unsigned i); + + // Successor iteration. + using const_succ_iterator = SuccessorIterator; + const_succ_iterator succ_begin() const; + const_succ_iterator succ_end() const; + llvm::iterator_range getSuccessors() const; + + using succ_iterator = SuccessorIterator; + succ_iterator succ_begin(); + succ_iterator succ_end(); + llvm::iterator_range getSuccessors(); + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + /// Unlink this Block from its Function and delete it. + void eraseFromFunction(); + + /// Split the basic block into two basic blocks before the specified + /// instruction or iterator. + /// + /// Note that all instructions BEFORE the specified iterator stay as part of + /// the original basic block, an unconditional branch is added to the original + /// block (going to the new block), and the rest of the instructions in the + /// original block are moved to the new block, including the old terminator. + /// The newly formed Block is returned. 
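The predecessor and successor accessors above make simple CFG queries direct. Below is a minimal sketch using only the declarations shown in this header; the helper names and the notion of a "straight-line" block are illustrative assumptions.

```c++
#include "mlir/IR/Block.h"

// A block is "straight-line" here if it has exactly one predecessor and its
// terminator transfers control to exactly one successor.
bool isStraightLineBlock(const mlir::Block *block) {
  // getSinglePredecessor() returns null for zero predecessors and also when
  // the same predecessor reaches this block through duplicate edges.
  if (!block->getSinglePredecessor())
    return false;
  return block->getNumSuccessors() == 1;
}

// Visit every successor reachable from the block's terminator.
void visitSuccessors(mlir::Block *block) {
  for (mlir::Block *succ : block->getSuccessors())
    (void)succ; // Process 'succ' here.
}
```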
+ /// + /// This function invalidates the specified iterator. + Block *splitBlock(iterator splitBefore); + Block *splitBlock(Instruction *splitBeforeInst) { + return splitBlock(iterator(splitBeforeInst)); + } + + /// Returns pointer to member of instruction list. + static InstListType Block::*getSublistAccess(Instruction *) { + return &Block::instructions; + } + + void print(raw_ostream &os) const; + void dump() const; + + /// Print out the name of the basic block without printing its body. + /// NOTE: The printType argument is ignored. We keep it for compatibility + /// with LLVM dominator machinery that expects it to exist. + void printAsOperand(raw_ostream &os, bool printType = true); + +private: + /// This is the parent object that owns this block. + BlockList *parent = nullptr; + + /// This is the list of instructions in the block. + InstListType instructions; + + /// This is the list of arguments to the block. + std::vector arguments; + + Block(const Block &) = delete; + void operator=(const Block &) = delete; + + friend struct llvm::ilist_traits; +}; + +} // end namespace mlir + +//===----------------------------------------------------------------------===// +// ilist_traits for Block +//===----------------------------------------------------------------------===// + +namespace llvm { + +template <> +struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> { + using Block = ::mlir::Block; + using block_iterator = simple_ilist<::mlir::Block>::iterator; + + void addNodeToList(Block *block); + void removeNodeFromList(Block *block); + void transferNodesFromList(ilist_traits &otherList, + block_iterator first, block_iterator last); + +private: + mlir::BlockList *getContainingBlockList(); +}; +} // end namespace llvm + +namespace mlir { + +/// This class contains a list of basic blocks and has a notion of the object it +/// is part of - a Function or IfStmt or ForStmt. +class BlockList { +public: + explicit BlockList(Function *container); + explicit BlockList(Instruction *container); + + using BlockListType = llvm::iplist; + BlockListType &getBlocks() { return blocks; } + const BlockListType &getBlocks() const { return blocks; } + + // Iteration over the block in the function. + using iterator = BlockListType::iterator; + using const_iterator = BlockListType::const_iterator; + using reverse_iterator = BlockListType::reverse_iterator; + using const_reverse_iterator = BlockListType::const_reverse_iterator; + + iterator begin() { return blocks.begin(); } + iterator end() { return blocks.end(); } + const_iterator begin() const { return blocks.begin(); } + const_iterator end() const { return blocks.end(); } + reverse_iterator rbegin() { return blocks.rbegin(); } + reverse_iterator rend() { return blocks.rend(); } + const_reverse_iterator rbegin() const { return blocks.rbegin(); } + const_reverse_iterator rend() const { return blocks.rend(); } + + bool empty() const { return blocks.empty(); } + void push_back(Block *block) { blocks.push_back(block); } + void push_front(Block *block) { blocks.push_front(block); } + + Block &back() { return blocks.back(); } + const Block &back() const { return const_cast(this)->back(); } + + Block &front() { return blocks.front(); } + const Block &front() const { return const_cast(this)->front(); } + + /// getSublistAccess() - Returns pointer to member of block list. + static BlockListType BlockList::*getSublistAccess(Block *) { + return &BlockList::blocks; + } + + /// A BlockList is part of a Function or and IfStmt/ForStmt. 
If it is + /// part of an IfStmt/ForStmt, then return it, otherwise return null. + Instruction *getContainingInst(); + const Instruction *getContainingInst() const { + return const_cast(this)->getContainingInst(); + } + + /// A BlockList is part of a Function or and IfStmt/ForStmt. If it is + /// part of a Function, then return it, otherwise return null. + Function *getContainingFunction(); + const Function *getContainingFunction() const { + return const_cast(this)->getContainingFunction(); + } + +private: + BlockListType blocks; + + /// This is the object we are part of. + llvm::PointerUnion container; +}; + +//===----------------------------------------------------------------------===// +// Predecessors +//===----------------------------------------------------------------------===// + +/// Implement a predecessor iterator as a forward iterator. This works by +/// walking the use lists of the blocks. The entries on this list are the +/// BlockOperands that are embedded into terminator instructions. From the +/// operand, we can get the terminator that contains it, and it's parent block +/// is the predecessor. +template +class PredecessorIterator + : public llvm::iterator_facade_base, + std::forward_iterator_tag, + BlockType *> { +public: + PredecessorIterator(BlockOperand *firstOperand) + : bbUseIterator(firstOperand) {} + + PredecessorIterator &operator=(const PredecessorIterator &rhs) { + bbUseIterator = rhs.bbUseIterator; + } + + bool operator==(const PredecessorIterator &rhs) const { + return bbUseIterator == rhs.bbUseIterator; + } + + BlockType *operator*() const { + // The use iterator points to an operand of a terminator. The predecessor + // we return is the block that the terminator is embedded into. + return bbUseIterator.getUser()->getBlock(); + } + + PredecessorIterator &operator++() { + ++bbUseIterator; + return *this; + } + + /// Get the successor number in the predecessor terminator. + unsigned getSuccessorIndex() const { + return bbUseIterator->getOperandNumber(); + } + +private: + using BBUseIterator = ValueUseIterator; + BBUseIterator bbUseIterator; +}; + +inline auto Block::pred_begin() const -> const_pred_iterator { + return const_pred_iterator((BlockOperand *)getFirstUse()); +} + +inline auto Block::pred_end() const -> const_pred_iterator { + return const_pred_iterator(nullptr); +} + +inline auto Block::getPredecessors() const + -> llvm::iterator_range { + return {pred_begin(), pred_end()}; +} + +inline auto Block::pred_begin() -> pred_iterator { + return pred_iterator((BlockOperand *)getFirstUse()); +} + +inline auto Block::pred_end() -> pred_iterator { + return pred_iterator(nullptr); +} + +inline auto Block::getPredecessors() -> llvm::iterator_range { + return {pred_begin(), pred_end()}; +} + +//===----------------------------------------------------------------------===// +// Successors +//===----------------------------------------------------------------------===// + +/// This template implements the successor iterators for Block. +template +class SuccessorIterator final + : public IndexedAccessorIterator, BlockType, + BlockType> { +public: + /// Initializes the result iterator to the specified index. + SuccessorIterator(BlockType *object, unsigned index) + : IndexedAccessorIterator, BlockType, + BlockType>(object, index) {} + + SuccessorIterator(const SuccessorIterator &other) + : SuccessorIterator(other.object, other.index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. 
+ operator SuccessorIterator() const { + return SuccessorIterator(this->object, this->index); + } + + BlockType *operator*() const { + return this->object->getSuccessor(this->index); + } + + /// Get the successor number in the terminator. + unsigned getSuccessorIndex() const { return this->index; } +}; + +inline auto Block::succ_begin() const -> const_succ_iterator { + return const_succ_iterator(this, 0); +} + +inline auto Block::succ_end() const -> const_succ_iterator { + return const_succ_iterator(this, getNumSuccessors()); +} + +inline auto Block::getSuccessors() const + -> llvm::iterator_range { + return {succ_begin(), succ_end()}; +} + +inline auto Block::succ_begin() -> succ_iterator { + return succ_iterator(this, 0); +} + +inline auto Block::succ_end() -> succ_iterator { + return succ_iterator(this, getNumSuccessors()); +} + +inline auto Block::getSuccessors() -> llvm::iterator_range { + return {succ_begin(), succ_end()}; +} + +} // end namespace mlir + +#endif // MLIR_IR_BLOCK_H diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1ad533b0983..5c1331e880d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -178,11 +178,11 @@ public: setInsertionPoint(stmt); } - FuncBuilder(StmtBlock *block) : FuncBuilder(block->getFunction()) { + FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) { setInsertionPoint(block, block->end()); } - FuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint) + FuncBuilder(Block *block, Block::iterator insertPoint) : FuncBuilder(block->getFunction()) { setInsertionPoint(block, insertPoint); } @@ -195,11 +195,11 @@ public: /// current insertion point a builder refers to is being removed. void clearInsertionPoint() { this->block = nullptr; - insertPoint = StmtBlock::iterator(); + insertPoint = Block::iterator(); } /// Set the insertion point to the specified location. - void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) { + void setInsertionPoint(Block *block, Block::iterator insertPoint) { // TODO: check that insertPoint is in this rather than some other block. this->block = block; this->insertPoint = insertPoint; @@ -208,31 +208,31 @@ public: /// Sets the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. void setInsertionPoint(Statement *stmt) { - setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt)); + setInsertionPoint(stmt->getBlock(), Block::iterator(stmt)); } /// Sets the insertion point to the start of the specified block. - void setInsertionPointToStart(StmtBlock *block) { + void setInsertionPointToStart(Block *block) { setInsertionPoint(block, block->begin()); } /// Sets the insertion point to the end of the specified block. - void setInsertionPointToEnd(StmtBlock *block) { + void setInsertionPointToEnd(Block *block) { setInsertionPoint(block, block->end()); } /// Return the block the current insertion point belongs to. Note that the /// the insertion point is not necessarily the end of the block. - BasicBlock *getInsertionBlock() const { return block; } + Block *getInsertionBlock() const { return block; } /// Returns the current insertion point of the builder. - StmtBlock::iterator getInsertionPoint() const { return insertPoint; } + Block::iterator getInsertionPoint() const { return insertPoint; } /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. 
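To connect the Block plumbing above with the builder changes in this diff, here is a minimal sketch of positioning a FuncBuilder inside a block. It uses only the insertion-point methods shown here; the helper name is an assumption, and no operations are actually created.

```c++
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"

// Position a builder at several points of an existing block.
void positionBuilder(mlir::Block *block) {
  // Constructing from a block leaves the insertion point at its end.
  mlir::FuncBuilder builder(block);

  // Move the insertion point to just before the terminator, if one exists.
  if (mlir::OperationInst *terminator = block->getTerminator())
    builder.setInsertionPoint(terminator);

  // Or move it to the very start of the block.
  builder.setInsertionPointToStart(block);
}
```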
If not, the block will be appended to the end of the /// current function. - StmtBlock *createBlock(StmtBlock *insertBefore = nullptr); + Block *createBlock(Block *insertBefore = nullptr); /// Returns a builder for the body of a for Stmt. static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) { @@ -240,7 +240,7 @@ public: } /// Returns the current block of the builder. - StmtBlock *getBlock() const { return block; } + Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. OperationInst *createOperation(const OperationState &state); @@ -286,7 +286,7 @@ public: Statement *clone(const Statement &stmt, OperationInst::OperandMapTy &operandMapping) { Statement *cloneStmt = stmt.clone(operandMapping, getContext()); - block->getStatements().insert(insertPoint, cloneStmt); + block->getInstructions().insert(insertPoint, cloneStmt); return cloneStmt; } @@ -305,8 +305,8 @@ public: private: Function *function; - StmtBlock *block = nullptr; - StmtBlock::iterator insertPoint; + Block *block = nullptr; + Block::iterator insertPoint; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 3ccfe4f9f2d..e608a704f99 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -99,7 +99,7 @@ class BranchOp : public Op operands = {}); // Hooks to customize behavior of this op. @@ -108,11 +108,11 @@ public: bool verify() const; /// Return the block this branch jumps to. - BasicBlock *getDest(); - const BasicBlock *getDest() const { + Block *getDest(); + const Block *getDest() const { return const_cast(this)->getDest(); } - void setDest(BasicBlock *block); + void setDest(Block *block); /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); @@ -147,8 +147,8 @@ public: static StringRef getOperationName() { return "cond_br"; } static void build(Builder *builder, OperationState *result, Value *condition, - BasicBlock *trueDest, ArrayRef trueOperands, - BasicBlock *falseDest, ArrayRef falseOperands); + Block *trueDest, ArrayRef trueOperands, + Block *falseDest, ArrayRef falseOperands); // Hooks to customize behavior of this op. static bool parse(OpAsmParser *parser, OperationState *result); @@ -160,14 +160,14 @@ public: const Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. - BasicBlock *getTrueDest(); - const BasicBlock *getTrueDest() const { + Block *getTrueDest(); + const Block *getTrueDest() const { return const_cast(this)->getTrueDest(); } /// Return the destination if the condition is false. 
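The renamed branch accessors can be exercised directly from an OperationInst. A minimal sketch follows, using the op-casting style seen elsewhere in this codebase; the helper name is illustrative, and it assumes 'inst' is known to be an unconditional branch.

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"

// Retarget an unconditional branch to a new destination block.
void retargetBranch(mlir::OperationInst *inst, mlir::Block *newDest) {
  if (auto branch = inst->dyn_cast<mlir::BranchOp>())
    branch->setDest(newDest);
}
```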
- BasicBlock *getFalseDest(); - const BasicBlock *getFalseDest() const { + Block *getFalseDest(); + const Block *getFalseDest() const { return const_cast(this)->getFalseDest(); } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 5b52a5de7e7..b79b64b68b5 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -25,9 +25,9 @@ #define MLIR_IR_FUNCTION_H #include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" -#include "mlir/IR/StmtBlock.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ilist.h" @@ -38,7 +38,6 @@ class FunctionType; class MLIRContext; class Module; template class ArgumentIterator; -using BasicBlock = StmtBlock; /// NamedAttribute is used for function attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -82,11 +81,11 @@ public: // Body Handling //===--------------------------------------------------------------------===// - StmtBlockList &getBlockList() { return blocks; } - const StmtBlockList &getBlockList() const { return blocks; } + BlockList &getBlockList() { return blocks; } + const BlockList &getBlockList() const { return blocks; } /// This is the list of blocks in the function. - using BlockListType = llvm::iplist; + using BlockListType = llvm::iplist; BlockListType &getBlocks() { return blocks.getBlocks(); } const BlockListType &getBlocks() const { return blocks.getBlocks(); } @@ -106,29 +105,25 @@ public: const_reverse_iterator rend() const { return blocks.rend(); } bool empty() const { return blocks.empty(); } - void push_back(BasicBlock *block) { blocks.push_back(block); } - void push_front(BasicBlock *block) { blocks.push_front(block); } + void push_back(Block *block) { blocks.push_back(block); } + void push_front(Block *block) { blocks.push_front(block); } - BasicBlock &back() { return blocks.back(); } - const BasicBlock &back() const { - return const_cast(this)->back(); - } + Block &back() { return blocks.back(); } + const Block &back() const { return const_cast(this)->back(); } - BasicBlock &front() { return blocks.front(); } - const BasicBlock &front() const { - return const_cast(this)->front(); - } + Block &front() { return blocks.front(); } + const Block &front() const { return const_cast(this)->front(); } /// Return the 'return' statement of this Function. const OperationInst *getReturnStmt() const; OperationInst *getReturnStmt(); // These should only be used on MLFunctions. - StmtBlock *getBody() { + Block *getBody() { assert(isML()); return &blocks.front(); } - const StmtBlock *getBody() const { + const Block *getBody() const { return const_cast(this)->getBody(); } @@ -218,7 +213,7 @@ private: AttributeListStorage *attrs; /// The contents of the body. 
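With Function now exposing its body through BlockList, whole-function CFG walks read as in the following minimal sketch; the helper name and the particular query are illustrative assumptions.

```c++
#include "mlir/IR/Function.h"

// Count the blocks of a function that have no predecessors. For a well-formed
// CFG function this should only match the entry block.
unsigned countPredecessorFreeBlocks(mlir::Function *func) {
  unsigned count = 0;
  for (mlir::Block &block : func->getBlocks())
    if (block.hasNoPredecessors())
      ++count;
  return count;
}
```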
- StmtBlockList blocks; + BlockList blocks; void operator=(const Function &) = delete; friend struct llvm::ilist_traits; diff --git a/mlir/include/mlir/IR/FunctionGraphTraits.h b/mlir/include/mlir/IR/FunctionGraphTraits.h index 54305c90d25..6ba50e7ca9e 100644 --- a/mlir/include/mlir/IR/FunctionGraphTraits.h +++ b/mlir/include/mlir/IR/FunctionGraphTraits.h @@ -28,9 +28,9 @@ #include "llvm/ADT/GraphTraits.h" namespace llvm { -template <> struct GraphTraits { - using ChildIteratorType = mlir::BasicBlock::succ_iterator; - using Node = mlir::BasicBlock; +template <> struct GraphTraits { + using ChildIteratorType = mlir::Block::succ_iterator; + using Node = mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(NodeRef bb) { return bb; } @@ -41,9 +41,9 @@ template <> struct GraphTraits { static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } }; -template <> struct GraphTraits { - using ChildIteratorType = mlir::BasicBlock::const_succ_iterator; - using Node = const mlir::BasicBlock; +template <> struct GraphTraits { + using ChildIteratorType = mlir::Block::const_succ_iterator; + using Node = const mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(NodeRef bb) { return bb; } @@ -54,9 +54,9 @@ template <> struct GraphTraits { static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } }; -template <> struct GraphTraits> { - using ChildIteratorType = mlir::BasicBlock::pred_iterator; - using Node = mlir::BasicBlock; +template <> struct GraphTraits> { + using ChildIteratorType = mlir::Block::pred_iterator; + using Node = mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(Inverse inverseGraph) { return inverseGraph.Graph; @@ -69,9 +69,9 @@ template <> struct GraphTraits> { } }; -template <> struct GraphTraits> { - using ChildIteratorType = mlir::BasicBlock::const_pred_iterator; - using Node = const mlir::BasicBlock; +template <> struct GraphTraits> { + using ChildIteratorType = mlir::Block::const_pred_iterator; + using Node = const mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(Inverse inverseGraph) { @@ -86,9 +86,9 @@ template <> struct GraphTraits> { }; template <> -struct GraphTraits : public GraphTraits { +struct GraphTraits : public GraphTraits { using GraphType = mlir::Function *; - using NodeRef = mlir::BasicBlock *; + using NodeRef = mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -103,9 +103,9 @@ struct GraphTraits : public GraphTraits { template <> struct GraphTraits - : public GraphTraits { + : public GraphTraits { using GraphType = const mlir::Function *; - using NodeRef = const mlir::BasicBlock *; + using NodeRef = const mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -120,7 +120,7 @@ struct GraphTraits template <> struct GraphTraits> - : public GraphTraits> { + : public GraphTraits> { using GraphType = Inverse; using NodeRef = NodeRef; @@ -137,7 +137,7 @@ struct GraphTraits> template <> struct GraphTraits> - : public GraphTraits> { + : public GraphTraits> { using GraphType = Inverse; using NodeRef = NodeRef; @@ -153,10 +153,9 @@ struct GraphTraits> }; template <> -struct GraphTraits - : public GraphTraits { - using GraphType = mlir::StmtBlockList *; - using NodeRef = mlir::BasicBlock *; +struct GraphTraits : public GraphTraits { + using GraphType = mlir::BlockList *; + using NodeRef = mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -170,10 +169,10 @@ struct GraphTraits }; template <> 
-struct GraphTraits - : public GraphTraits { - using GraphType = const mlir::StmtBlockList *; - using NodeRef = const mlir::BasicBlock *; +struct GraphTraits + : public GraphTraits { + using GraphType = const mlir::BlockList *; + using NodeRef = const mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -187,9 +186,9 @@ struct GraphTraits }; template <> -struct GraphTraits> - : public GraphTraits> { - using GraphType = Inverse; +struct GraphTraits> + : public GraphTraits> { + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } @@ -204,9 +203,9 @@ struct GraphTraits> }; template <> -struct GraphTraits> - : public GraphTraits> { - using GraphType = Inverse; +struct GraphTraits> + : public GraphTraits> { + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 4e840409a27..e1b90b6e39e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -752,14 +752,14 @@ public: return this->getInstruction()->getNumSuccessorOperands(index); } - const BasicBlock *getSuccessor(unsigned index) const { + const Block *getSuccessor(unsigned index) const { return this->getInstruction()->getSuccessor(index); } - BasicBlock *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) { return this->getInstruction()->getSuccessor(index); } - void setSuccessor(BasicBlock *block, unsigned index) { + void setSuccessor(Block *block, unsigned index) { return this->getInstruction()->setSuccessor(block, index); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 9ebc55b2ae8..587eabdee96 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -264,7 +264,7 @@ public: virtual bool parseOperand(OperandType &result) = 0; /// Parse a single operation successor and it's operand list. - virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + virtual bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) = 0; /// These are the supported delimiters around operand lists, used by diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 2bc75a2a40d..15c882b90f7 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -31,6 +31,7 @@ #include namespace mlir { +class Block; class Dialect; class OperationInst; class OperationState; @@ -39,10 +40,8 @@ class OpAsmParserResult; class OpAsmPrinter; class Pattern; class RewritePattern; -class StmtBlock; class Type; class Value; -using BasicBlock = StmtBlock; /// This is a vector that owns the patterns inside of it. using OwningPatternList = std::vector>; @@ -209,7 +208,7 @@ struct OperationState { SmallVector types; SmallVector attributes; /// Successors of this operation and their respective operands. 
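The GraphTraits specializations above let LLVM's generic graph utilities traverse the MLIR CFG without adapters. Here is a minimal sketch of a depth-first walk over a function's blocks; the helper name is an assumption.

```c++
#include "mlir/IR/FunctionGraphTraits.h"
#include "llvm/ADT/DepthFirstIterator.h"

// Visit the blocks of a function in depth-first preorder starting from the
// entry block, relying on GraphTraits<mlir::Function *>.
void walkCFGDepthFirst(mlir::Function *func) {
  for (mlir::Block *block : llvm::depth_first(func))
    (void)block; // Process 'block' here.
}
```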
- SmallVector successors; + SmallVector successors; public: OperationState(MLIRContext *context, Location location, StringRef name) @@ -221,7 +220,7 @@ public: OperationState(MLIRContext *context, Location location, StringRef name, ArrayRef operands, ArrayRef types, ArrayRef attributes, - ArrayRef successors = {}) + ArrayRef successors = {}) : context(context), location(location), name(name, context), operands(operands.begin(), operands.end()), types(types.begin(), types.end()), @@ -248,7 +247,7 @@ public: attributes.push_back({name, attr}); } - void addSuccessor(StmtBlock *successor, ArrayRef succOperands) { + void addSuccessor(Block *successor, ArrayRef succOperands) { successors.push_back(successor); // Insert a sentinal operand to mark a barrier between successor operands. operands.push_back(nullptr); diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index 48135514dcf..9ca5530f33c 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -28,13 +28,13 @@ #include "llvm/ADT/ilist_node.h" namespace mlir { +class Block; class Location; -class StmtBlock; class ForStmt; class MLIRContext; -/// The operand of a Terminator contains a StmtBlock. -using StmtBlockOperand = IROperandImpl; +/// Terminator operations can have Block operands to represent successors. +using BlockOperand = IROperandImpl; } // namespace mlir @@ -55,7 +55,7 @@ template <> struct ilist_traits<::mlir::Statement> { stmt_iterator first, stmt_iterator last); private: - mlir::StmtBlock *getContainingBlock(); + mlir::Block *getContainingBlock(); }; } // end namespace llvm @@ -66,9 +66,9 @@ template class OperandIterator; /// Statement is a basic unit of execution within an ML function. /// Statements can be nested within for and if statements effectively /// forming a tree. Child statements are organized into statement blocks -/// represented by a 'StmtBlock' class. +/// represented by a 'Block' class. class Statement : public IROperandOwner, - public llvm::ilist_node_with_parent { + public llvm::ilist_node_with_parent { public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, @@ -95,7 +95,7 @@ public: Statement *clone(MLIRContext *context) const; /// Returns the statement block that contains this statement. - StmtBlock *getBlock() const { return block; } + Block *getBlock() const { return block; } /// Returns the closest surrounding statement that contains this statement /// or nullptr if this is a top-level statement. @@ -121,7 +121,7 @@ public: /// Unlink this operation instruction from its current basic block and insert /// it right before `iterator` in the specified basic block. - void moveBefore(StmtBlock *block, llvm::iplist::iterator iterator); + void moveBefore(Block *block, llvm::iplist::iterator iterator); // Returns whether the Statement is a terminator. bool isTerminator() const; @@ -198,7 +198,7 @@ protected: private: /// The statement block that containts this statement. - StmtBlock *block = nullptr; + Block *block = nullptr; // allow ilist_traits access to 'block' field. 
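Terminators record their successors through OperationState before construction. The sketch below builds a branch-like operation with one successor whose operands are forwarded to the destination's block arguments. The operation name "mydialect.br" is a hypothetical placeholder, and whether such an op verifies as a terminator is outside the scope of this sketch.

```c++
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

// Create an operation that forwards 'destOperands' to the block arguments of
// 'dest'. The operation name is a hypothetical placeholder.
mlir::OperationInst *createBranchLikeOp(mlir::FuncBuilder &builder,
                                        mlir::Location loc, mlir::Block *dest,
                                        llvm::ArrayRef<mlir::Value *> destOperands) {
  mlir::OperationState state(builder.getContext(), loc, "mydialect.br");
  state.addSuccessor(dest, destOperands);
  return builder.createOperation(state);
}
```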
friend struct llvm::ilist_traits; diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index d04ebd776b9..aa4157714a7 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -23,10 +23,10 @@ #define MLIR_IR_STATEMENTS_H #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Block.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Statement.h" -#include "mlir/IR/StmtBlock.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/TrailingObjects.h" @@ -46,14 +46,14 @@ class Function; /// class OperationInst final : public Statement, - private llvm::TrailingObjects { public: /// Create a new OperationInst with the specific fields. static OperationInst * create(Location location, OperationName name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, - ArrayRef successors, MLIRContext *context); + ArrayRef successors, MLIRContext *context); /// Return the context this operation is associated with. MLIRContext *getContext() const; @@ -229,11 +229,11 @@ public: // Terminators //===--------------------------------------------------------------------===// - MutableArrayRef getBlockOperands() { + MutableArrayRef getBlockOperands() { assert(isTerminator() && "Only terminators have a block operands list"); - return {getTrailingObjects(), numSuccs}; + return {getTrailingObjects(), numSuccs}; } - ArrayRef getBlockOperands() const { + ArrayRef getBlockOperands() const { return const_cast(this)->getBlockOperands(); } @@ -248,14 +248,14 @@ public: return getTrailingObjects()[index]; } - StmtBlock *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) { assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } - const StmtBlock *getSuccessor(unsigned index) const { + const Block *getSuccessor(unsigned index) const { return const_cast(this)->getSuccessor(index); } - void setSuccessor(BasicBlock *block, unsigned index); + void setSuccessor(Block *block, unsigned index); /// Erase a specific operand from the operand list of the successor at /// 'index'. @@ -404,7 +404,7 @@ private: void eraseOperand(unsigned index); // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects; size_t numTrailingObjects(OverloadToken) const { return numOperands; @@ -412,7 +412,7 @@ private: size_t numTrailingObjects(OverloadToken) const { return numResults; } - size_t numTrailingObjects(OverloadToken) const { + size_t numTrailingObjects(OverloadToken) const { return numSuccs; } size_t numTrailingObjects(OverloadToken) const { return numSuccs; } @@ -515,7 +515,7 @@ public: AffineMap ubMap, int64_t step); ~ForStmt() { - // Explicitly erase statements instead of relying of 'StmtBlock' destructor + // Explicitly erase statements instead of relying of 'Block' destructor // since child statements need to be destroyed before the Value that this // for stmt represents is destroyed. Affine maps are immortal objects and // don't need to be deleted. @@ -534,10 +534,10 @@ public: using const_operand_range = llvm::iterator_range; /// Get the body of the ForStmt. - StmtBlock *getBody() { return &body.front(); } + Block *getBody() { return &body.front(); } /// Get the body of the ForStmt. - const StmtBlock *getBody() const { return &body.front(); } + const Block *getBody() const { return &body.front(); } //===--------------------------------------------------------------------===// // Bounds and step @@ -664,8 +664,8 @@ public: } private: - // The StmtBlock for the body. 
- StmtBlockList body; + // The Block for the body. + BlockList body; // Affine map for the lower bound. AffineMap lbMap; @@ -746,18 +746,18 @@ public: // Then, else, condition. //===--------------------------------------------------------------------===// - StmtBlock *getThen() { return &thenClause.front(); } - const StmtBlock *getThen() const { return &thenClause.front(); } - StmtBlock *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const StmtBlock *getElse() const { + Block *getThen() { return &thenClause.front(); } + const Block *getThen() const { return &thenClause.front(); } + Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } + const Block *getElse() const { return elseClause ? &elseClause->front() : nullptr; } bool hasElse() const { return elseClause != nullptr; } - StmtBlock *createElse() { + Block *createElse() { assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new StmtBlockList(this); - elseClause->push_back(new StmtBlock()); + elseClause = new BlockList(this); + elseClause->push_back(new Block()); return &elseClause->front(); } @@ -823,9 +823,9 @@ public: private: // it is always present. - StmtBlockList thenClause; + BlockList thenClause; // 'else' clause of the if statement. 'nullptr' if there is no else clause. - StmtBlockList *elseClause; + BlockList *elseClause; // The integer set capturing the conditional guard. IntegerSet set; diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h deleted file mode 100644 index 916834dfbdc..00000000000 --- a/mlir/include/mlir/IR/StmtBlock.h +++ /dev/null @@ -1,497 +0,0 @@ -//===- StmtBlock.h ----------------------------------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines StmtBlock and *Stmt classes that extend Statement. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_STMTBLOCK_H -#define MLIR_IR_STMTBLOCK_H - -#include "mlir/IR/Statement.h" -#include "llvm/ADT/PointerUnion.h" - -namespace mlir { -class IfStmt; -class StmtBlockList; - -template class PredecessorIterator; -template class SuccessorIterator; - -/// Blocks represents an ordered list of Instructions. -class StmtBlock - : public IRObjectWithUseList, - public llvm::ilist_node_with_parent { -public: - explicit StmtBlock() {} - ~StmtBlock(); - - void clear() { - // Clear statements in the reverse order so that uses are destroyed - // before their defs. - while (!empty()) - statements.pop_back(); - } - - StmtBlockList *getParent() const { return parent; } - - /// Returns the closest surrounding statement that contains this block or - /// nullptr if this is a top-level statement block. 
- Statement *getContainingStmt(); - - const Statement *getContainingStmt() const { - return const_cast(this)->getContainingStmt(); - } - - /// Returns the function that this statement block is part of. The function - /// is determined by traversing the chain of parent statements. - Function *getFunction(); - const Function *getFunction() const { - return const_cast(this)->getFunction(); - } - - //===--------------------------------------------------------------------===// - // Block argument management - //===--------------------------------------------------------------------===// - - // This is the list of arguments to the block. - using BlockArgListType = ArrayRef; - - // FIXME: Not const correct. - BlockArgListType getArguments() const { return arguments; } - - using args_iterator = BlockArgListType::iterator; - using reverse_args_iterator = BlockArgListType::reverse_iterator; - args_iterator args_begin() const { return getArguments().begin(); } - args_iterator args_end() const { return getArguments().end(); } - reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); } - reverse_args_iterator args_rend() const { return getArguments().rend(); } - - bool args_empty() const { return arguments.empty(); } - - /// Add one value to the argument list. - BlockArgument *addArgument(Type type); - - /// Add one argument to the argument list for each type specified in the list. - llvm::iterator_range addArguments(ArrayRef types); - - /// Erase the argument at 'index' and remove it from the argument list. - void eraseArgument(unsigned index); - - unsigned getNumArguments() const { return arguments.size(); } - BlockArgument *getArgument(unsigned i) { return arguments[i]; } - const BlockArgument *getArgument(unsigned i) const { return arguments[i]; } - - //===--------------------------------------------------------------------===// - // Statement list management - //===--------------------------------------------------------------------===// - - /// This is the list of statements in the block. - using StmtListType = llvm::iplist; - StmtListType &getStatements() { return statements; } - const StmtListType &getStatements() const { return statements; } - - // Iteration over the statements in the block. - using iterator = StmtListType::iterator; - using const_iterator = StmtListType::const_iterator; - using reverse_iterator = StmtListType::reverse_iterator; - using const_reverse_iterator = StmtListType::const_reverse_iterator; - - iterator begin() { return statements.begin(); } - iterator end() { return statements.end(); } - const_iterator begin() const { return statements.begin(); } - const_iterator end() const { return statements.end(); } - reverse_iterator rbegin() { return statements.rbegin(); } - reverse_iterator rend() { return statements.rend(); } - const_reverse_iterator rbegin() const { return statements.rbegin(); } - const_reverse_iterator rend() const { return statements.rend(); } - - bool empty() const { return statements.empty(); } - void push_back(Statement *stmt) { statements.push_back(stmt); } - void push_front(Statement *stmt) { statements.push_front(stmt); } - - Statement &back() { return statements.back(); } - const Statement &back() const { - return const_cast(this)->back(); - } - Statement &front() { return statements.front(); } - const Statement &front() const { - return const_cast(this)->front(); - } - - /// Returns the statement's position in this block or -1 if the statement is - /// not present. 
- int64_t findStmtPosInBlock(const Statement &stmt) const { - int64_t j = 0; - for (const auto &s : statements) { - if (&s == &stmt) - return j; - j++; - } - return -1; - } - - /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the - /// ancestor statement of 'stmt' that lies in this block. Returns nullptr if - /// the latter fails. - const Statement *findAncestorStmtInBlock(const Statement &stmt) const; - Statement *findAncestorStmtInBlock(Statement *stmt) { - return const_cast(findAncestorStmtInBlock(*stmt)); - } - - //===--------------------------------------------------------------------===// - // Terminator management - //===--------------------------------------------------------------------===// - - /// Get the terminator instruction of this block, or null if the block is - /// malformed. - OperationInst *getTerminator(); - - const OperationInst *getTerminator() const { - return const_cast(this)->getTerminator(); - } - - //===--------------------------------------------------------------------===// - // Predecessors and successors. - //===--------------------------------------------------------------------===// - - // Predecessor iteration. - using const_pred_iterator = PredecessorIterator; - const_pred_iterator pred_begin() const; - const_pred_iterator pred_end() const; - llvm::iterator_range getPredecessors() const; - - using pred_iterator = PredecessorIterator; - pred_iterator pred_begin(); - pred_iterator pred_end(); - llvm::iterator_range getPredecessors(); - - /// Return true if this block has no predecessors. - bool hasNoPredecessors() const; - - /// If this block has exactly one predecessor, return it. Otherwise, return - /// null. - /// - /// Note that if a block has duplicate predecessors from a single block (e.g. - /// if you have a conditional branch with the same block as the true/false - /// destinations) is not considered to be a single predecessor. - StmtBlock *getSinglePredecessor(); - - const StmtBlock *getSinglePredecessor() const { - return const_cast(this)->getSinglePredecessor(); - } - - // Indexed successor access. - unsigned getNumSuccessors() const; - const StmtBlock *getSuccessor(unsigned i) const { - return const_cast(this)->getSuccessor(i); - } - StmtBlock *getSuccessor(unsigned i); - - // Successor iteration. - using const_succ_iterator = SuccessorIterator; - const_succ_iterator succ_begin() const; - const_succ_iterator succ_end() const; - llvm::iterator_range getSuccessors() const; - - using succ_iterator = SuccessorIterator; - succ_iterator succ_begin(); - succ_iterator succ_end(); - llvm::iterator_range getSuccessors(); - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - /// Unlink this Block from its Function and delete it. - void eraseFromFunction(); - - /// Split the basic block into two basic blocks before the specified - /// instruction or iterator. - /// - /// Note that all instructions BEFORE the specified iterator stay as part of - /// the original basic block, an unconditional branch is added to the original - /// block (going to the new block), and the rest of the instructions in the - /// original block are moved to the new BB, including the old terminator. The - /// newly formed Block is returned. - /// - /// This function invalidates the specified iterator. 
- StmtBlock *splitBasicBlock(iterator splitBefore); - StmtBlock *splitBasicBlock(Instruction *splitBeforeInst) { - return splitBasicBlock(iterator(splitBeforeInst)); - } - - /// getSublistAccess() - Returns pointer to member of statement list - static StmtListType StmtBlock::*getSublistAccess(Statement *) { - return &StmtBlock::statements; - } - - void print(raw_ostream &os) const; - void dump() const; - - /// Print out the name of the basic block without printing its body. - /// NOTE: The printType argument is ignored. We keep it for compatibility - /// with LLVM dominator machinery that expects it to exist. - void printAsOperand(raw_ostream &os, bool printType = true); - -private: - /// This is the parent function/IfStmt/ForStmt that owns this block. - StmtBlockList *parent = nullptr; - - /// This is the list of statements in the block. - StmtListType statements; - - /// This is the list of arguments to the block. - std::vector arguments; - - StmtBlock(const StmtBlock &) = delete; - void operator=(const StmtBlock &) = delete; - - friend struct llvm::ilist_traits; -}; - -} // end namespace mlir - -//===----------------------------------------------------------------------===// -// ilist_traits for StmtBlock -//===----------------------------------------------------------------------===// - -namespace llvm { - -template <> -struct ilist_traits<::mlir::StmtBlock> - : public ilist_alloc_traits<::mlir::StmtBlock> { - using StmtBlock = ::mlir::StmtBlock; - using block_iterator = simple_ilist<::mlir::StmtBlock>::iterator; - - void addNodeToList(StmtBlock *block); - void removeNodeFromList(StmtBlock *block); - void transferNodesFromList(ilist_traits &otherList, - block_iterator first, block_iterator last); - -private: - mlir::StmtBlockList *getContainingBlockList(); -}; -} // end namespace llvm - -namespace mlir { - -/// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or IfStmt or ForStmt. -class StmtBlockList { -public: - explicit StmtBlockList(Function *container); - explicit StmtBlockList(Statement *container); - - using BlockListType = llvm::iplist; - BlockListType &getBlocks() { return blocks; } - const BlockListType &getBlocks() const { return blocks; } - - // Iteration over the block in the function. - using iterator = BlockListType::iterator; - using const_iterator = BlockListType::const_iterator; - using reverse_iterator = BlockListType::reverse_iterator; - using const_reverse_iterator = BlockListType::const_reverse_iterator; - - iterator begin() { return blocks.begin(); } - iterator end() { return blocks.end(); } - const_iterator begin() const { return blocks.begin(); } - const_iterator end() const { return blocks.end(); } - reverse_iterator rbegin() { return blocks.rbegin(); } - reverse_iterator rend() { return blocks.rend(); } - const_reverse_iterator rbegin() const { return blocks.rbegin(); } - const_reverse_iterator rend() const { return blocks.rend(); } - - bool empty() const { return blocks.empty(); } - void push_back(StmtBlock *block) { blocks.push_back(block); } - void push_front(StmtBlock *block) { blocks.push_front(block); } - - StmtBlock &back() { return blocks.back(); } - const StmtBlock &back() const { - return const_cast(this)->back(); - } - - StmtBlock &front() { return blocks.front(); } - const StmtBlock &front() const { - return const_cast(this)->front(); - } - - /// getSublistAccess() - Returns pointer to member of block list. 
- static BlockListType StmtBlockList::*getSublistAccess(StmtBlock *) { - return &StmtBlockList::blocks; - } - - /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is - /// part of an IfStmt/ForStmt, then return it, otherwise return null. - Statement *getContainingStmt(); - const Statement *getContainingStmt() const { - return const_cast(this)->getContainingStmt(); - } - - /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is - /// part of a Function, then return it, otherwise return null. - Function *getContainingFunction(); - const Function *getContainingFunction() const { - return const_cast(this)->getContainingFunction(); - } - - // TODO(clattner): This is only to help ML -> CFG migration, remove in the - // near future. This makes StmtBlockList work more like BasicBlock did. - Function *getFunction(); - const Function *getFunction() const { - return const_cast(this)->getFunction(); - } - -private: - BlockListType blocks; - - /// This is the object we are part of. - llvm::PointerUnion container; -}; - -//===----------------------------------------------------------------------===// -// Predecessors -//===----------------------------------------------------------------------===// - -/// Implement a predecessor iterator as a forward iterator. This works by -/// walking the use lists of the blocks. The entries on this list are the -/// StmtBlockOperands that are embedded into terminator instructions. From the -/// operand, we can get the terminator that contains it, and it's parent block -/// is the predecessor. -template -class PredecessorIterator - : public llvm::iterator_facade_base, - std::forward_iterator_tag, - BlockType *> { -public: - PredecessorIterator(StmtBlockOperand *firstOperand) - : bbUseIterator(firstOperand) {} - - PredecessorIterator &operator=(const PredecessorIterator &rhs) { - bbUseIterator = rhs.bbUseIterator; - } - - bool operator==(const PredecessorIterator &rhs) const { - return bbUseIterator == rhs.bbUseIterator; - } - - BlockType *operator*() const { - // The use iterator points to an operand of a terminator. The predecessor - // we return is the block that the terminator is embedded into. - return bbUseIterator.getUser()->getBlock(); - } - - PredecessorIterator &operator++() { - ++bbUseIterator; - return *this; - } - - /// Get the successor number in the predecessor terminator. - unsigned getSuccessorIndex() const { - return bbUseIterator->getOperandNumber(); - } - -private: - using BBUseIterator = ValueUseIterator; - BBUseIterator bbUseIterator; -}; - -inline auto StmtBlock::pred_begin() const -> const_pred_iterator { - return const_pred_iterator((StmtBlockOperand *)getFirstUse()); -} - -inline auto StmtBlock::pred_end() const -> const_pred_iterator { - return const_pred_iterator(nullptr); -} - -inline auto StmtBlock::getPredecessors() const - -> llvm::iterator_range { - return {pred_begin(), pred_end()}; -} - -inline auto StmtBlock::pred_begin() -> pred_iterator { - return pred_iterator((StmtBlockOperand *)getFirstUse()); -} - -inline auto StmtBlock::pred_end() -> pred_iterator { - return pred_iterator(nullptr); -} - -inline auto StmtBlock::getPredecessors() - -> llvm::iterator_range { - return {pred_begin(), pred_end()}; -} - -//===----------------------------------------------------------------------===// -// Successors -//===----------------------------------------------------------------------===// - -/// This template implments the successor iterators for StmtBlock. 
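For readers unfamiliar with use-list based predecessor iteration, the standalone sketch below mirrors the shape of the PredecessorIterator above with hypothetical toy types; the real iterator additionally builds on llvm::iterator_facade_base and the StmtBlockOperand use chain.

#include <cassert>

struct ToyBlock {};
struct ToyTerminator { ToyBlock *parentBlock; };
struct ToyBlockOperand {
  ToyTerminator *owner;     // terminator that holds this block operand
  ToyBlockOperand *nextUse; // next use of the same target block
};

class ToyPredIterator {
  ToyBlockOperand *use;
public:
  explicit ToyPredIterator(ToyBlockOperand *first) : use(first) {}
  bool operator!=(const ToyPredIterator &rhs) const { return use != rhs.use; }
  ToyPredIterator &operator++() { use = use->nextUse; return *this; }
  // The predecessor is the block the owning terminator is embedded into.
  ToyBlock *operator*() const { return use->owner->parentBlock; }
};

int main() {
  ToyBlock pred1, pred2;
  ToyTerminator t1{&pred1}, t2{&pred2};
  ToyBlockOperand u2{&t2, nullptr}, u1{&t1, &u2}; // use list: u1 -> u2
  int count = 0;
  for (ToyPredIterator it(&u1), e(nullptr); it != e; ++it)
    ++count;
  assert(count == 2);
}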
-template -class SuccessorIterator final - : public IndexedAccessorIterator, BlockType, - BlockType> { -public: - /// Initializes the result iterator to the specified index. - SuccessorIterator(BlockType *object, unsigned index) - : IndexedAccessorIterator, BlockType, - BlockType>(object, index) {} - - SuccessorIterator(const SuccessorIterator &other) - : SuccessorIterator(other.object, other.index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator SuccessorIterator() const { - return SuccessorIterator(this->object, this->index); - } - - BlockType *operator*() const { - return this->object->getSuccessor(this->index); - } - - /// Get the successor number in the terminator. - unsigned getSuccessorIndex() const { return this->index; } -}; - -inline auto StmtBlock::succ_begin() const -> const_succ_iterator { - return const_succ_iterator(this, 0); -} - -inline auto StmtBlock::succ_end() const -> const_succ_iterator { - return const_succ_iterator(this, getNumSuccessors()); -} - -inline auto StmtBlock::getSuccessors() const - -> llvm::iterator_range { - return {succ_begin(), succ_end()}; -} - -inline auto StmtBlock::succ_begin() -> succ_iterator { - return succ_iterator(this, 0); -} - -inline auto StmtBlock::succ_end() -> succ_iterator { - return succ_iterator(this, getNumSuccessors()); -} - -inline auto StmtBlock::getSuccessors() -> llvm::iterator_range { - return {succ_begin(), succ_end()}; -} - -} // end namespace mlir -#endif // MLIR_IR_STMTBLOCK_H diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 75184dc7a3f..2213fe79852 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -27,10 +27,10 @@ #include "mlir/Support/LLVM.h" namespace mlir { +class Block; class Function; class OperationInst; class Statement; -class StmtBlock; class Value; using Instruction = Statement; @@ -136,18 +136,18 @@ public: return const_cast(this)->getFunction(); } - StmtBlock *getOwner() { return owner; } - const StmtBlock *getOwner() const { return owner; } + Block *getOwner() { return owner; } + const Block *getOwner() const { return owner; } private: - friend class StmtBlock; // For access to private constructor. - BlockArgument(Type type, StmtBlock *owner) + friend class Block; // For access to private constructor. + BlockArgument(Type type, Block *owner) : Value(Value::Kind::BlockArgument, type), owner(owner) {} /// The owner of this operand. /// TODO: can encode this more efficiently to avoid the space hit of this /// through bitpacking shenanigans. - StmtBlock *const owner; + Block *const owner; }; /// This is a value defined by a result of an operation instruction. diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index d214a96f335..2694433d5a0 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -66,7 +66,7 @@ bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor); bool promoteIfSingleIteration(ForStmt *forStmt); /// Promotes all single iteration ForStmt's in the Function, i.e., moves -/// their body into the containing StmtBlock. +/// their body into the containing Block. 
void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index e28c2e87651..12af803fdad 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -878,15 +878,15 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, return numCommonLoops; } -// Returns StmtBlock common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. -static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - const FlatAffineConstraints &srcDomain, - unsigned numCommonLoops) { +// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. +static Block *getCommonBlock(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, + unsigned numCommonLoops) { if (numCommonLoops == 0) { auto *block = srcAccess.opStmt->getBlock(); - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); + while (block->getContainingInst()) { + block = block->getContainingInst()->getBlock(); } return block; } @@ -906,14 +906,14 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { - // Get StmtBlock common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. + // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. auto *commonBlock = - getCommonStmtBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); + getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); // Check the dominance relationship between the respective ancestors of the - // src and dst in the StmtBlock of the innermost among the common loops. - auto *srcStmt = commonBlock->findAncestorStmtInBlock(*srcAccess.opStmt); + // src and dst in the Block of the innermost among the common loops. + auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt); assert(srcStmt != nullptr); - auto *dstStmt = commonBlock->findAncestorStmtInBlock(*dstAccess.opStmt); + auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt); assert(dstStmt != nullptr); return mlir::properlyDominates(*srcStmt, *dstStmt); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 0ebbec9c025..0c8db07dbb4 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -25,9 +25,9 @@ #include "llvm/Support/GenericDomTreeConstruction.h" using namespace mlir; -template class llvm::DominatorTreeBase; -template class llvm::DominatorTreeBase; -template class llvm::DomTreeNodeBase; +template class llvm::DominatorTreeBase; +template class llvm::DominatorTreeBase; +template class llvm::DomTreeNodeBase; /// Compute the immediate-dominators map. DominanceInfo::DominanceInfo(Function *function) : DominatorTreeBase() { @@ -57,8 +57,8 @@ bool DominanceInfo::properlyDominates(const Instruction *a, return true; // Otherwise, do a linear scan to determine whether B comes after A. 
- auto aIter = BasicBlock::const_iterator(a); - auto bIter = BasicBlock::const_iterator(b); + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); auto fIter = aBlock->begin(); while (bIter != fIter) { --bIter; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index caeeccb677f..dd14f38df55 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -309,7 +309,7 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) { bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, ArrayRef shifts) { auto *forBody = forStmt.getBody(); - assert(shifts.size() == forBody->getStatements().size()); + assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &stmt : *forBody) { // A for or if stmt does not produce any def/results (that are used @@ -323,8 +323,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, // This is a naive way. If performance becomes an issue, a map can // be used to store 'shifts' - to look up the shift for a statement in // constant time. - if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner())) - if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)]) + if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)]) return false; } } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e17c27ac941..f6191418f54 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,8 +44,8 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) { if (a.getBlock() == b.getBlock()) { // Do a linear scan to determine whether b comes after a. - auto aIter = StmtBlock::const_iterator(a); - auto bIter = StmtBlock::const_iterator(b); + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); auto aBlockStart = a.getBlock()->begin(); while (bIter != aBlockStart) { --bIter; @@ -56,7 +56,7 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) { } // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b)) + if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b)) // a and bAncestor are in the same block; check if the former dominates it. return dominates(a, *bAncestor); @@ -333,26 +333,26 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, template bool mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, bool emitError); -// Returns in 'positions' the StmtBlock positions of 'stmt' in each ancestor -// StmtBlock from the StmtBlock containing statement, stopping at 'limitBlock'. -static void findStmtPosition(const Statement *stmt, StmtBlock *limitBlock, +// Returns in 'positions' the Block positions of 'stmt' in each ancestor +// Block from the Block containing statement, stopping at 'limitBlock'. 
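The properlyDominates() change above keeps the same-block fast path: when both instructions live in one block, dominance reduces to a positional scan of that block. A minimal illustration with toy data, assuming nothing beyond the standard library (the real code scans backwards from 'b'; the forward formulation below is equivalent):

#include <cassert>
#include <vector>

bool properlyDominatesInBlock(const std::vector<int> &block, int a, int b) {
  for (int inst : block) {
    if (inst == a)
      return inst != b; // found 'a' first; it dominates unless a == b
    if (inst == b)
      return false;     // found 'b' first
  }
  return false;
}

int main() {
  std::vector<int> block = {1, 2, 3, 4};
  assert(properlyDominatesInBlock(block, 2, 4));
  assert(!properlyDominatesInBlock(block, 4, 2));
  assert(!properlyDominatesInBlock(block, 3, 3));
}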
+static void findStmtPosition(const Statement *stmt, Block *limitBlock, SmallVectorImpl *positions) { - StmtBlock *block = stmt->getBlock(); + Block *block = stmt->getBlock(); while (block != limitBlock) { - int stmtPosInBlock = block->findStmtPosInBlock(*stmt); + int stmtPosInBlock = block->findInstPositionInBlock(*stmt); assert(stmtPosInBlock >= 0); positions->push_back(stmtPosInBlock); - stmt = block->getContainingStmt(); + stmt = block->getContainingInst(); block = stmt->getBlock(); } std::reverse(positions->begin(), positions->end()); } -// Returns the Statement in a possibly nested set of StmtBlocks, where the +// Returns the Statement in a possibly nested set of Blocks, where the // position of the statement is represented by 'positions', which has a -// StmtBlock position for each level of nesting. +// Block position for each level of nesting. static Statement *getStmtAtPosition(ArrayRef positions, - unsigned level, StmtBlock *block) { + unsigned level, Block *block) { unsigned i = 0; for (auto &stmt : *block) { if (i != positions[level]) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 43c29dbb6ac..4cad531ecaa 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -59,7 +59,7 @@ public: return fn.emitError(message); } - bool failure(const Twine &message, const BasicBlock &bb) { + bool failure(const Twine &message, const Block &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) if (auto *op = dyn_cast(&bb.front())) @@ -153,7 +153,7 @@ struct CFGFuncVerifier : public Verifier { : Verifier(fn), fn(fn), domInfo(const_cast(&fn)) {} bool verify(); - bool verifyBlock(const BasicBlock &block); + bool verifyBlock(const Block &block); bool verifyInstOperands(const Instruction &inst); }; } // end anonymous namespace @@ -214,7 +214,7 @@ bool CFGFuncVerifier::verifyInstOperands(const Instruction &inst) { return false; } -bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) { +bool CFGFuncVerifier::verifyBlock(const Block &block) { if (!block.getTerminator()) return failure("basic block with no terminator", block); @@ -287,12 +287,12 @@ bool MLFuncVerifier::verifyDominance() { // This recursive function walks the statement list pushing scopes onto the // stack as it goes, and popping them to remove them from the table. - std::function walkBlock; - walkBlock = [&](const StmtBlock &block) -> bool { + std::function walkBlock; + walkBlock = [&](const Block &block) -> bool { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmt = dyn_cast_or_null(block.getContainingStmt())) + if (auto *forStmt = dyn_cast_or_null(block.getContainingInst())) liveValues.insert(forStmt, true); for (auto &stmt : block) { @@ -340,10 +340,10 @@ bool MLFuncVerifier::verifyDominance() { bool MLFuncVerifier::verifyReturn() { // TODO: fold return verification in the pass that verifies all statements. 
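The verifyDominance() walk above relies on one scope of live values per block. A compact stand-in for that scoping discipline, using plain containers instead of llvm::ScopedHashTable and hypothetical toy statement types:

#include <cassert>
#include <set>
#include <string>
#include <vector>

struct ToyStmt {
  std::vector<std::string> uses;
  std::vector<std::string> defs;
  std::vector<ToyStmt> nestedBlock; // body of a nested for/if, possibly empty
};

// Returns false on a use that is not dominated by a definition. Passing
// 'live' by value is the scope push/pop: nested definitions vanish on return.
bool walkBlock(const std::vector<ToyStmt> &block, std::set<std::string> live) {
  for (const ToyStmt &stmt : block) {
    for (const std::string &u : stmt.uses)
      if (!live.count(u))
        return false;
    for (const std::string &d : stmt.defs)
      live.insert(d); // later statements in this block may use it
    if (!stmt.nestedBlock.empty() && !walkBlock(stmt.nestedBlock, live))
      return false;
  }
  return true;
}

int main() {
  std::vector<ToyStmt> body = {{{}, {"a"}, {}}, {{"a"}, {}, {}}};
  assert(walkBlock(body, {}));
  std::vector<ToyStmt> bad = {{{"b"}, {}, {}}};
  assert(!walkBlock(bad, {}));
}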
const char missingReturnMsg[] = "ML function must end with return statement"; - if (fn.getBody()->getStatements().empty()) + if (fn.getBody()->getInstructions().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getBody()->getStatements().back(); + const auto &stmt = fn.getBody()->getInstructions().back(); if (const auto *op = dyn_cast(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2ff7220f8ee..daaaee7010c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -180,7 +180,7 @@ void ModuleState::visitExtFunction(const Function *fn) { void ModuleState::visitCFGFunction(const Function *fn) { visitType(fn->getType()); for (auto &block : *fn) { - for (auto &op : block.getStatements()) { + for (auto &op : block.getInstructions()) { if (auto *opInst = dyn_cast(&op)) visitOperation(opInst); else { @@ -914,7 +914,7 @@ public: void print(const OperationInst *inst); void print(const ForStmt *stmt); void print(const IfStmt *stmt); - void print(const StmtBlock *block); + void print(const Block *block); void printOperation(const OperationInst *op); void printDefaultOp(const OperationInst *op); @@ -944,11 +944,11 @@ public: enum { nameSentinel = ~0U }; - void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); } + void printBlockName(const Block *block) { os << "bb" << getBlockID(block); } - unsigned getBBID(const BasicBlock *block) { - auto it = basicBlockIDs.find(block); - assert(it != basicBlockIDs.end() && "Block not in this function?"); + unsigned getBlockID(const Block *block) { + auto it = blockIDs.find(block); + assert(it != blockIDs.end() && "Block not in this function?"); return it->second; } @@ -964,7 +964,7 @@ public: protected: void numberValueID(const Value *value); - void numberValuesInBlock(const StmtBlock &block); + void numberValuesInBlock(const Block &block); void printValueID(const Value *value, bool printResultNo = true) const; private: @@ -976,7 +976,7 @@ private: DenseMap valueNames; /// This is the block ID for each block in the current function. - DenseMap basicBlockIDs; + DenseMap blockIDs; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. @@ -1007,10 +1007,10 @@ FunctionPrinter::FunctionPrinter(const Function *function, } /// Number all of the SSA values in the specified block list. -void FunctionPrinter::numberValuesInBlock(const StmtBlock &block) { +void FunctionPrinter::numberValuesInBlock(const Block &block) { // Each block gets a unique ID, and all of the instructions within it get // numbered as well. - basicBlockIDs[&block] = nextBlockID++; + blockIDs[&block] = nextBlockID++; for (auto *arg : block.getArguments()) numberValueID(arg); @@ -1154,6 +1154,7 @@ void FunctionPrinter::printMLFunctionSignature() { os << " : "; printType(arg->getType()); } + os << ')'; printFunctionResultType(type); } @@ -1174,11 +1175,11 @@ void FunctionPrinter::printOtherFunctionSignature() { printFunctionResultType(type); } -void FunctionPrinter::print(const StmtBlock *block) { +void FunctionPrinter::print(const Block *block) { // Print the block label and argument list, unless we are in an ML function. if (!block->getFunction()->isML()) { os.indent(currentIndent); - printBBName(block); + printBlockName(block); // Print the argument list if non-empty. 
if (!block->args_empty()) { @@ -1201,13 +1202,13 @@ void FunctionPrinter::print(const StmtBlock *block) { os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { os << "\t// pred: "; - printBBName(pred); + printBlockName(pred); } else { // We want to print the predecessors in increasing numeric order, not in // whatever order the use-list is in, so gather and sort them. SmallVector predIDs; for (auto *pred : block->getPredecessors()) - predIDs.push_back(getBBID(pred)); + predIDs.push_back(getBlockID(pred)); llvm::array_pod_sort(predIDs.begin(), predIDs.end()); os << "\t// " << predIDs.size() << " preds: "; @@ -1218,7 +1219,8 @@ void FunctionPrinter::print(const StmtBlock *block) { } currentIndent += indentWidth; - for (auto &stmt : block->getStatements()) { + + for (auto &stmt : block->getInstructions()) { print(&stmt); os << '\n'; } @@ -1358,10 +1360,9 @@ void FunctionPrinter::printDefaultOp(const OperationInst *op) { void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, unsigned index) { - printBBName(term->getSuccessor(index)); + printBlockName(term->getSuccessor(index)); auto succOperands = term->getSuccessorOperands(index); - if (succOperands.begin() == succOperands.end()) return; @@ -1516,7 +1517,7 @@ void Instruction::dump() const { llvm::errs() << "\n"; } -void BasicBlock::print(raw_ostream &os) const { +void Block::print(raw_ostream &os) const { auto *function = getFunction(); if (!function) { os << "<>\n"; @@ -1528,17 +1529,17 @@ void BasicBlock::print(raw_ostream &os) const { FunctionPrinter(function, modulePrinter).print(this); } -void BasicBlock::dump() const { print(llvm::errs()); } +void Block::dump() const { print(llvm::errs()); } /// Print out the name of the basic block without printing its body. -void StmtBlock::printAsOperand(raw_ostream &os, bool printType) { +void Block::printAsOperand(raw_ostream &os, bool printType) { if (!getFunction()) { os << "<>\n"; return; } ModuleState state(getFunction()->getContext()); ModulePrinter modulePrinter(os, state); - FunctionPrinter(getFunction(), modulePrinter).printBBName(this); + FunctionPrinter(getFunction(), modulePrinter).printBlockName(this); } void Function::print(raw_ostream &os) const { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp new file mode 100644 index 00000000000..c7e84194c35 --- /dev/null +++ b/mlir/lib/IR/Block.cpp @@ -0,0 +1,232 @@ +//===- Block.cpp - MLIR Block and BlockList Classes -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +using namespace mlir; + +Block::~Block() { + clear(); + + llvm::DeleteContainerPointers(arguments); +} + +/// Returns the closest surrounding statement that contains this block or +/// nullptr if this is a top-level statement block. 
+Statement *Block::getContainingInst() {
+  return parent ? parent->getContainingInst() : nullptr;
+}
+
+Function *Block::getFunction() {
+  Block *block = this;
+  while (auto *stmt = block->getContainingInst()) {
+    block = stmt->getBlock();
+    if (!block)
+      return nullptr;
+  }
+  if (auto *list = block->getParent())
+    return list->getContainingFunction();
+  return nullptr;
+}
+
+/// Returns 'inst' if 'inst' lies in this block, or otherwise finds the
+/// ancestor instruction of 'inst' that lies in this block. Returns nullptr if
+/// the latter fails.
+const Instruction *
+Block::findAncestorInstInBlock(const Instruction &inst) const {
+  // Traverse up the statement hierarchy starting from the owner of operand to
+  // find the ancestor statement that resides in the block of 'forStmt'.
+  const auto *currInst = &inst;
+  while (currInst->getBlock() != this) {
+    currInst = currInst->getParentStmt();
+    if (!currInst)
+      return nullptr;
+  }
+  return currInst;
+}
+
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+BlockArgument *Block::addArgument(Type type) {
+  auto *arg = new BlockArgument(type, this);
+  arguments.push_back(arg);
+  return arg;
+}
+
+/// Add one argument to the argument list for each type specified in the list.
+auto Block::addArguments(ArrayRef<Type> types)
+    -> llvm::iterator_range<BlockArgument **> {
+  arguments.reserve(arguments.size() + types.size());
+  auto initialSize = arguments.size();
+  for (auto type : types) {
+    addArgument(type);
+  }
+  return {arguments.data() + initialSize, arguments.data() + arguments.size()};
+}
+
+void Block::eraseArgument(unsigned index) {
+  assert(index < arguments.size());
+
+  // Delete the argument.
+  delete arguments[index];
+  arguments.erase(arguments.begin() + index);
+
+  // Erase this argument from each of the predecessor's terminator.
+  for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
+       ++predIt) {
+    auto *predTerminator = (*predIt)->getTerminator();
+    predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management
+//===----------------------------------------------------------------------===//
+
+OperationInst *Block::getTerminator() {
+  if (empty())
+    return nullptr;
+
+  // Check if the last instruction is a terminator.
+  auto &backInst = back();
+  auto *opStmt = dyn_cast<OperationInst>(&backInst);
+  if (!opStmt || !opStmt->isTerminator())
+    return nullptr;
+  return opStmt;
+}
+
+/// Return true if this block has no predecessors.
+bool Block::hasNoPredecessors() const { return pred_begin() == pred_end(); }
+
+// Indexed successor access.
+unsigned Block::getNumSuccessors() const {
+  return getTerminator()->getNumSuccessors();
+}
+
+Block *Block::getSuccessor(unsigned i) {
+  return getTerminator()->getSuccessor(i);
+}
+
+/// If this block has exactly one predecessor, return it. Otherwise, return
+/// null.
+///
+/// Note that multiple edges from a single block (e.g. if you have a cond
+/// branch with the same block as the true/false destinations) is not
+/// considered to be a single predecessor.
+Block *Block::getSinglePredecessor() {
+  auto it = pred_begin();
+  if (it == pred_end())
+    return nullptr;
+  auto *firstPred = *it;
+  ++it;
+  return it == pred_end() ?
firstPred : nullptr; +} + +//===----------------------------------------------------------------------===// +// Other +//===----------------------------------------------------------------------===// + +/// Unlink this Block from its Function and delete it. +void Block::eraseFromFunction() { + assert(getFunction() && "Block has no parent"); + getFunction()->getBlocks().erase(this); +} + +/// Split the basic block into two basic blocks before the specified +/// instruction or iterator. +/// +/// Note that all instructions BEFORE the specified iterator stay as part of +/// the original basic block, an unconditional branch is added to the original +/// block (going to the new block), and the rest of the instructions in the +/// original block are moved to the new BB, including the old terminator. The +/// newly formed Block is returned. +/// +/// This function invalidates the specified iterator. +Block *Block::splitBlock(iterator splitBefore) { + // Start by creating a new basic block, and insert it immediate after this + // one in the containing function. + auto newBB = new Block(); + getFunction()->getBlocks().insert(++Function::iterator(this), newBB); + auto branchLoc = + splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc(); + + // Move all of the operations from the split point to the end of the function + // into the new block. + newBB->getInstructions().splice(newBB->end(), getInstructions(), splitBefore, + end()); + + // Create an unconditional branch to the new block, and move our terminator + // to the new block. + FuncBuilder(this).create(branchLoc, newBB); + return newBB; +} + +//===----------------------------------------------------------------------===// +// BlockList +//===----------------------------------------------------------------------===// + +BlockList::BlockList(Function *container) : container(container) {} + +BlockList::BlockList(Statement *container) : container(container) {} + +Statement *BlockList::getContainingInst() { + return container.dyn_cast(); +} + +Function *BlockList::getContainingFunction() { + return container.dyn_cast(); +} + +BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() { + size_t Offset( + size_t(&((BlockList *)nullptr->*BlockList::getSublistAccess(nullptr)))); + iplist *Anchor(static_cast *>(this)); + return reinterpret_cast(reinterpret_cast(Anchor) - + Offset); +} + +/// This is a trait method invoked when a basic block is added to a function. +/// We keep the function pointer up to date. +void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) { + assert(!block->parent && "already in a function!"); + block->parent = getContainingBlockList(); +} + +/// This is a trait method invoked when an instruction is removed from a +/// function. We keep the function pointer up to date. +void llvm::ilist_traits<::mlir::Block>::removeNodeFromList(Block *block) { + assert(block->parent && "not already in a function!"); + block->parent = nullptr; +} + +/// This is a trait method invoked when an instruction is moved from one block +/// to another. We keep the block pointer up to date. +void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( + ilist_traits &otherList, block_iterator first, block_iterator last) { + // If we are transferring instructions within the same function, the parent + // pointer doesn't need to be updated. + auto *curParent = getContainingBlockList(); + if (curParent == otherList.getContainingBlockList()) + return; + + // Update the 'parent' member of each Block. 
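getContainingBlockList() above recovers the owning BlockList from the address of its embedded iplist by subtracting a pointer-to-member offset. The same idiom reduced to standard containers, with hypothetical toy types; it uses the same null-pointer member-access trick the patch does, which works on the toolchains LLVM targets even though it is not strictly portable C++.

#include <cassert>
#include <cstddef>
#include <list>

struct Owner {
  int tag = 7;
  std::list<int> sublist;
};

static Owner *getOwnerFromSublist(std::list<int> *anchor) {
  // Offset of 'sublist' inside Owner, computed from a pointer-to-member.
  std::size_t offset = reinterpret_cast<std::size_t>(
      &(static_cast<Owner *>(nullptr)->*(&Owner::sublist)));
  return reinterpret_cast<Owner *>(reinterpret_cast<char *>(anchor) - offset);
}

int main() {
  Owner o;
  assert(getOwnerFromSublist(&o.sublist)->tag == 7);
}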
+ for (; first != last; ++first) + first->parent = curParent; +} diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 81a3b7c2950..a9eb6fe8c8a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -275,8 +275,8 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { /// 'insertBefore' basic block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the /// current function. -StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) { - StmtBlock *b = new StmtBlock(); +Block *FuncBuilder::createBlock(Block *insertBefore) { + Block *b = new Block(); // If we are supposed to insert before a specific block, do so, otherwise add // the block to the end of the function. @@ -294,7 +294,7 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { auto *op = OperationInst::create(state.location, state.name, state.operands, state.types, state.attributes, state.successors, context); - block->getStatements().insert(insertPoint, op); + block->getInstructions().insert(insertPoint, op); return op; } @@ -303,7 +303,7 @@ ForStmt *FuncBuilder::createFor(Location location, ArrayRef lbOperands, AffineMap ubMap, int64_t step) { auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getStatements().insert(insertPoint, stmt); + block->getInstructions().insert(insertPoint, stmt); return stmt; } @@ -317,6 +317,6 @@ ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, IfStmt *FuncBuilder::createIf(Location location, ArrayRef operands, IntegerSet set) { auto *stmt = IfStmt::create(location, operands, set); - block->getStatements().insert(insertPoint, stmt); + block->getInstructions().insert(insertPoint, stmt); return stmt; } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 51596a9f09e..a0264fc11b0 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -167,13 +167,13 @@ bool AffineApplyOp::constantFold(ArrayRef operandConstants, // BranchOp //===----------------------------------------------------------------------===// -void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest, +void BranchOp::build(Builder *builder, OperationState *result, Block *dest, ArrayRef operands) { result->addSuccessor(dest, operands); } bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { - BasicBlock *dest; + Block *dest; SmallVector destOperands; if (parser->parseSuccessorAndUseList(dest, destOperands)) return true; @@ -193,9 +193,9 @@ bool BranchOp::verify() const { return false; } -BasicBlock *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } +Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } -void BranchOp::setDest(BasicBlock *block) { +void BranchOp::setDest(Block *block) { return getInstruction()->setSuccessor(block, 0); } @@ -208,8 +208,8 @@ void BranchOp::eraseOperand(unsigned index) { //===----------------------------------------------------------------------===// void CondBranchOp::build(Builder *builder, OperationState *result, - Value *condition, BasicBlock *trueDest, - ArrayRef trueOperands, BasicBlock *falseDest, + Value *condition, Block *trueDest, + ArrayRef trueOperands, Block *falseDest, ArrayRef falseOperands) { result->addOperands(condition); result->addSuccessor(trueDest, trueOperands); @@ -218,7 +218,7 @@ void CondBranchOp::build(Builder *builder, OperationState *result, bool 
CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector destOperands; - BasicBlock *dest; + Block *dest; OpAsmParser::OperandType condInfo; // Parse the condition. @@ -263,11 +263,11 @@ bool CondBranchOp::verify() const { return false; } -BasicBlock *CondBranchOp::getTrueDest() { +Block *CondBranchOp::getTrueDest() { return getInstruction()->getSuccessor(trueIndex); } -BasicBlock *CondBranchOp::getFalseDest() { +Block *CondBranchOp::getFalseDest() { return getInstruction()->getSuccessor(falseIndex); } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 0e777c65f23..cbe84e10247 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -37,7 +37,7 @@ Function::Function(Kind kind, Location location, StringRef name, // TODO(clattner): Unify this behavior. if (kind == Kind::MLFunc) { // The body of an ML Function always has one block. - auto *entry = new StmtBlock(); + auto *entry = new Block(); blocks.push_back(entry); // Initialize the arguments. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 23e54b3638e..ccd7d65f7c8 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -245,7 +245,7 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { static bool verifyBBArguments( llvm::iterator_range operands, - const BasicBlock *destBB, const OperationInst *op) { + const Block *destBB, const OperationInst *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -277,11 +277,11 @@ static bool verifyTerminatorSuccessors(const OperationInst *op) { bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { // Verify that the operation is at the end of the respective parent block. if (op->getFunction()->isML()) { - StmtBlock *block = op->getBlock(); - if (!block || block->getContainingStmt() || &block->back() != op) + Block *block = op->getBlock(); + if (!block || block->getContainingInst() || &block->back() != op) return op->emitOpError("must be the last statement in the ML function"); } else { - const BasicBlock *block = op->getBlock(); + const Block *block = op->getBlock(); if (!block || &block->back() != op) return op->emitOpError( "must be the last instruction in the parent basic block."); diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 96b44600460..6bd9944bb65 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -49,7 +49,7 @@ template <> unsigned InstOperand::getOperandNumber() const { } /// Return which operand this is in the operand list. -template <> unsigned StmtBlockOperand::getOperandNumber() const { +template <> unsigned BlockOperand::getOperandNumber() const { return this - &getOwner()->getBlockOperands()[0]; } @@ -79,7 +79,7 @@ void Statement::destroy() { } Statement *Statement::getParentStmt() const { - return block ? block->getContainingStmt() : nullptr; + return block ? 
block->getContainingInst() : nullptr; } Function *Statement::getFunction() const { @@ -191,12 +191,10 @@ void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) { stmt->destroy(); } -StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() { - size_t Offset( - size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr)))); +Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() { + size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); iplist *Anchor(static_cast *>(this)); - return reinterpret_cast(reinterpret_cast(Anchor) - - Offset); + return reinterpret_cast(reinterpret_cast(Anchor) - Offset); } /// This is a trait method invoked when a statement is added to a block. We @@ -221,7 +219,7 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( stmt_iterator last) { // If we are transferring statements within the same block, the block // pointer doesn't need to be updated. - StmtBlock *curParent = getContainingBlock(); + Block *curParent = getContainingBlock(); if (curParent == otherList.getContainingBlock()) return; @@ -230,11 +228,11 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( first->block = curParent; } -/// Remove this statement (and its descendants) from its StmtBlock and delete +/// Remove this statement (and its descendants) from its Block and delete /// all of them. void Statement::erase() { assert(getBlock() && "Statement has no block"); - getBlock()->getStatements().erase(this); + getBlock()->getInstructions().erase(this); } /// Unlink this statement from its current block and insert it right before @@ -246,10 +244,10 @@ void Statement::moveBefore(Statement *existingStmt) { /// Unlink this operation instruction from its current basic block and insert /// it right before `iterator` in the specified basic block. -void Statement::moveBefore(StmtBlock *block, +void Statement::moveBefore(Block *block, llvm::iplist::iterator iterator) { - block->getStatements().splice(iterator, getBlock()->getStatements(), - getIterator()); + block->getInstructions().splice(iterator, getBlock()->getInstructions(), + getIterator()); } /// This drops all operand uses from this instruction, which is an essential @@ -273,7 +271,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes, - ArrayRef successors, + ArrayRef successors, MLIRContext *context) { unsigned numSuccessors = successors.size(); @@ -282,7 +280,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, unsigned numOperands = operands.size() - numSuccessors; auto byteSize = - totalSizeToAlloc( + totalSizeToAlloc( resultTypes.size(), numSuccessors, numSuccessors, numOperands); void *rawMem = malloc(byteSize); @@ -340,7 +338,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, } new (&instBlockOperands[currentSuccNum]) - StmtBlockOperand(stmt, successors[currentSuccNum]); + BlockOperand(stmt, successors[currentSuccNum]); *succOperandCountIt = 0; ++currentSuccNum; continue; @@ -382,7 +380,7 @@ OperationInst::~OperationInst() { // Explicitly run the destructors for the successors. if (isTerminator()) for (auto &successor : getBlockOperands()) - successor.~StmtBlockOperand(); + successor.~BlockOperand(); } /// Return true if there are no users of any results of this operation. 
@@ -420,7 +418,7 @@ MLIRContext *OperationInst::getContext() const { bool OperationInst::isReturn() const { return isa(); } -void OperationInst::setSuccessor(BasicBlock *block, unsigned index) { +void OperationInst::setSuccessor(Block *block, unsigned index) { assert(index < getNumSuccessors()); getBlockOperands()[index].set(block); } @@ -559,7 +557,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, body(this), lbMap(lbMap), ubMap(ubMap), step(step) { // The body of a for stmt always has one block. - body.push_back(new StmtBlock()); + body.push_back(new Block()); operands.reserve(numOperands); } @@ -679,7 +677,7 @@ IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set) operands.reserve(numOperands); // The then of an 'if' stmt always has one block. - thenClause.push_back(new StmtBlock()); + thenClause.push_back(new Block()); } IfStmt::~IfStmt() { @@ -736,7 +734,7 @@ Statement *Statement::clone(DenseMap &operandMap, }; SmallVector operands; - SmallVector successors; + SmallVector successors; if (auto *opStmt = dyn_cast(this)) { operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); @@ -758,8 +756,7 @@ Statement *Statement::clone(DenseMap &operandMap, successors.reserve(opStmt->getNumSuccessors()); for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e; ++succ) { - successors.push_back( - const_cast(opStmt->getSuccessor(succ))); + successors.push_back(const_cast(opStmt->getSuccessor(succ))); // Add sentinel to delineate successor operands. operands.push_back(nullptr); diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp deleted file mode 100644 index b551b1121a7..00000000000 --- a/mlir/lib/IR/StmtBlock.cpp +++ /dev/null @@ -1,236 +0,0 @@ -//===- StmtBlock.cpp - MLIR Statement Instruction Classes -----------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/IR/StmtBlock.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -using namespace mlir; - -StmtBlock::~StmtBlock() { - clear(); - - llvm::DeleteContainerPointers(arguments); -} - -/// Returns the closest surrounding statement that contains this block or -/// nullptr if this is a top-level statement block. -Statement *StmtBlock::getContainingStmt() { - return parent ? parent->getContainingStmt() : nullptr; -} - -Function *StmtBlock::getFunction() { - StmtBlock *block = this; - while (auto *stmt = block->getContainingStmt()) { - block = stmt->getBlock(); - if (!block) - return nullptr; - } - if (auto *list = block->getParent()) - return list->getContainingFunction(); - return nullptr; -} - -/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor -/// statement of 'stmt' that lies in this block. Returns nullptr if the latter -/// fails. 
-const Statement * -StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const { - // Traverse up the statement hierarchy starting from the owner of operand to - // find the ancestor statement that resides in the block of 'forStmt'. - const auto *currStmt = &stmt; - while (currStmt->getBlock() != this) { - currStmt = currStmt->getParentStmt(); - if (!currStmt) - return nullptr; - } - return currStmt; -} - -//===----------------------------------------------------------------------===// -// Argument list management. -//===----------------------------------------------------------------------===// - -BlockArgument *StmtBlock::addArgument(Type type) { - auto *arg = new BlockArgument(type, this); - arguments.push_back(arg); - return arg; -} - -/// Add one argument to the argument list for each type specified in the list. -auto StmtBlock::addArguments(ArrayRef types) - -> llvm::iterator_range { - arguments.reserve(arguments.size() + types.size()); - auto initialSize = arguments.size(); - for (auto type : types) { - addArgument(type); - } - return {arguments.data() + initialSize, arguments.data() + arguments.size()}; -} - -void StmtBlock::eraseArgument(unsigned index) { - assert(index < arguments.size()); - - // Delete the argument. - delete arguments[index]; - arguments.erase(arguments.begin() + index); - - // Erase this argument from each of the predecessor's terminator. - for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE; - ++predIt) { - auto *predTerminator = (*predIt)->getTerminator(); - predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index); - } -} - -//===----------------------------------------------------------------------===// -// Terminator management -//===----------------------------------------------------------------------===// - -OperationInst *StmtBlock::getTerminator() { - if (empty()) - return nullptr; - - // Check if the last instruction is a terminator. - auto &backInst = statements.back(); - auto *opStmt = dyn_cast(&backInst); - if (!opStmt || !opStmt->isTerminator()) - return nullptr; - return opStmt; -} - -/// Return true if this block has no predecessors. -bool StmtBlock::hasNoPredecessors() const { return pred_begin() == pred_end(); } - -// Indexed successor access. -unsigned StmtBlock::getNumSuccessors() const { - return getTerminator()->getNumSuccessors(); -} - -StmtBlock *StmtBlock::getSuccessor(unsigned i) { - return getTerminator()->getSuccessor(i); -} - -/// If this block has exactly one predecessor, return it. Otherwise, return -/// null. -/// -/// Note that multiple edges from a single block (e.g. if you have a cond -/// branch with the same block as the true/false destinations) is not -/// considered to be a single predecessor. -StmtBlock *StmtBlock::getSinglePredecessor() { - auto it = pred_begin(); - if (it == pred_end()) - return nullptr; - auto *firstPred = *it; - ++it; - return it == pred_end() ? firstPred : nullptr; -} - -//===----------------------------------------------------------------------===// -// Other -//===----------------------------------------------------------------------===// - -/// Unlink this BasicBlock from its Function and delete it. -void BasicBlock::eraseFromFunction() { - assert(getFunction() && "BasicBlock has no parent"); - getFunction()->getBlocks().erase(this); -} - -/// Split the basic block into two basic blocks before the specified -/// instruction or iterator. 
-/// -/// Note that all instructions BEFORE the specified iterator stay as part of -/// the original basic block, an unconditional branch is added to the original -/// block (going to the new block), and the rest of the instructions in the -/// original block are moved to the new BB, including the old terminator. The -/// newly formed BasicBlock is returned. -/// -/// This function invalidates the specified iterator. -BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) { - // Start by creating a new basic block, and insert it immediate after this - // one in the containing function. - auto newBB = new BasicBlock(); - getFunction()->getBlocks().insert(++Function::iterator(this), newBB); - auto branchLoc = - splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc(); - - // Move all of the operations from the split point to the end of the function - // into the new block. - newBB->getStatements().splice(newBB->end(), getStatements(), splitBefore, - end()); - - // Create an unconditional branch to the new block, and move our terminator - // to the new block. - FuncBuilder(this).create(branchLoc, newBB); - return newBB; -} - -//===----------------------------------------------------------------------===// -// StmtBlockList -//===----------------------------------------------------------------------===// - -StmtBlockList::StmtBlockList(Function *container) : container(container) {} - -StmtBlockList::StmtBlockList(Statement *container) : container(container) {} - -Function *StmtBlockList::getFunction() { return getContainingFunction(); } - -Statement *StmtBlockList::getContainingStmt() { - return container.dyn_cast(); -} - -Function *StmtBlockList::getContainingFunction() { - return container.dyn_cast(); -} - -StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() { - size_t Offset(size_t( - &((StmtBlockList *)nullptr->*StmtBlockList::getSublistAccess(nullptr)))); - iplist *Anchor(static_cast *>(this)); - return reinterpret_cast(reinterpret_cast(Anchor) - - Offset); -} - -/// This is a trait method invoked when a basic block is added to a function. -/// We keep the function pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::addNodeToList(StmtBlock *block) { - assert(!block->parent && "already in a function!"); - block->parent = getContainingBlockList(); -} - -/// This is a trait method invoked when an instruction is removed from a -/// function. We keep the function pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::removeNodeFromList( - StmtBlock *block) { - assert(block->parent && "not already in a function!"); - block->parent = nullptr; -} - -/// This is a trait method invoked when an instruction is moved from one block -/// to another. We keep the block pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::transferNodesFromList( - ilist_traits &otherList, block_iterator first, - block_iterator last) { - // If we are transferring instructions within the same function, the parent - // pointer doesn't need to be updated. - auto *curParent = getContainingBlockList(); - if (curParent == otherList.getContainingBlockList()) - return; - - // Update the 'parent' member of each StmtBlock. 
- for (; first != last; ++first) - first->parent = curParent; -} diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 9b67ef8b150..6cc1aba72b3 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1921,7 +1921,7 @@ public: parseCustomOperation(const CreateOperationFunction &createOpFunc); /// Parse a single operation successor and it's operand list. - virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + virtual bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) = 0; protected: @@ -2398,7 +2398,7 @@ public: return false; } - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) override { // Defer successor parsing to the function parsers. return parser.parseSuccessorAndUseList(dest, operands); @@ -2570,13 +2570,13 @@ public: ParseResult parseFunctionBody(); - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands); private: Function *function; - llvm::StringMap> blocksByName; - DenseMap forwardRef; + llvm::StringMap> blocksByName; + DenseMap forwardRef; /// This builder intentionally shadows the builder in the base class, with a /// more specific builder type. @@ -2585,10 +2585,10 @@ private: /// Get the basic block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. - BasicBlock *getBlockNamed(StringRef name, SMLoc loc) { + Block *getBlockNamed(StringRef name, SMLoc loc) { auto &blockAndLoc = blocksByName[name]; if (!blockAndLoc.first) { - blockAndLoc.first = new BasicBlock(); + blockAndLoc.first = new Block(); forwardRef[blockAndLoc.first] = loc; function->push_back(blockAndLoc.first); blockAndLoc.second = loc; @@ -2597,9 +2597,9 @@ private: return blockAndLoc.first; } - // Define the basic block with the specified name. Returns the BasicBlock* or + // Define the basic block with the specified name. Returns the Block* or // nullptr in the case of redefinition. - BasicBlock *defineBlockNamed(StringRef name, SMLoc loc) { + Block *defineBlockNamed(StringRef name, SMLoc loc) { auto &blockAndLoc = blocksByName[name]; if (!blockAndLoc.first) { blockAndLoc.first = builder.createBlock(); @@ -2621,10 +2621,10 @@ private: } ParseResult - parseOptionalBasicBlockArgList(SmallVectorImpl &results, - BasicBlock *owner); + parseOptionalBlockArgList(SmallVectorImpl &results, + Block *owner); - ParseResult parseBasicBlock(); + ParseResult parseBlock(); }; } // end anonymous namespace @@ -2634,7 +2634,7 @@ private: /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// bool CFGFunctionParser::parseSuccessorAndUseList( - BasicBlock *&dest, SmallVectorImpl &operands) { + Block *&dest, SmallVectorImpl &operands) { // Verify branch is identifier and get the matching block. 
if (!getToken().is(Token::bare_identifier)) return emitError("expected basic block name"); @@ -2656,8 +2656,8 @@ bool CFGFunctionParser::parseSuccessorAndUseList( /// /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* /// -ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList( - SmallVectorImpl &results, BasicBlock *owner) { +ParseResult CFGFunctionParser::parseOptionalBlockArgList( + SmallVectorImpl &results, Block *owner) { if (getToken().is(Token::r_brace)) return ParseSuccess; @@ -2684,12 +2684,12 @@ ParseResult CFGFunctionParser::parseFunctionBody() { // Parse the list of blocks. while (!consumeIf(Token::r_brace)) - if (parseBasicBlock()) + if (parseBlock()) return ParseFailure; // Verify that all referenced blocks were defined. if (!forwardRef.empty()) { - SmallVector, 4> errors; + SmallVector, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. for (auto entry : forwardRef) errors.push_back({entry.second.getPointer(), entry.first}); @@ -2721,7 +2721,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() { /// bb-id ::= bare-id /// bb-arg-list ::= `(` ssa-id-and-type-list? `)` /// -ParseResult CFGFunctionParser::parseBasicBlock() { +ParseResult CFGFunctionParser::parseBlock() { SMLoc nameLoc = getToken().getLoc(); auto name = getTokenSpelling(); if (parseToken(Token::bare_identifier, "expected basic block name")) @@ -2736,7 +2736,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() { // If an argument list is present, parse it. if (consumeIf(Token::l_paren)) { SmallVector bbArgs; - if (parseOptionalBasicBlockArgList(bbArgs, block) || + if (parseOptionalBlockArgList(bbArgs, block) || parseToken(Token::r_paren, "expected ')' to end argument list")) return ParseFailure; } @@ -2794,11 +2794,11 @@ private: ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); - ParseResult parseElseClause(StmtBlock *elseClause); - ParseResult parseStatements(StmtBlock *block); - ParseResult parseStmtBlock(StmtBlock *block); + ParseResult parseElseClause(Block *elseClause); + ParseResult parseStatements(Block *block); + ParseResult parseBlock(Block *block); - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) { assert(false && "MLFunctions do not have terminators with successors."); return true; @@ -2810,7 +2810,7 @@ ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); // Parse statements in this function. - if (parseStmtBlock(function->getBody())) + if (parseBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); @@ -2874,7 +2874,7 @@ ParseResult MLFunctionParser::parseForStmt() { // If parsing of the for statement body fails, // MLIR contains for statement with those nested statements that have been // successfully parsed. - if (parseStmtBlock(forStmt->getBody())) + if (parseBlock(forStmt->getBody())) return ParseFailure; // Reset insertion point to the current block. @@ -3118,12 +3118,12 @@ ParseResult MLFunctionParser::parseIfStmt() { IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), operands, set); - StmtBlock *thenClause = ifStmt->getThen(); + Block *thenClause = ifStmt->getThen(); // When parsing of an if statement body fails, the IR contains // the if statement with the portion of the body that has been // successfully parsed. 
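The parser above resolves blocks that are branched to before they are defined by keeping a forwardRef table that must be empty once the function body has been parsed. A simplified, self-contained model of that bookkeeping, using std::map instead of StringMap/DenseMap and hypothetical names:

#include <cassert>
#include <map>
#include <string>

struct ToyBlock { bool defined = false; };

std::map<std::string, ToyBlock> blocksByName;
std::map<ToyBlock *, std::string> forwardRef;

ToyBlock *getBlockNamed(const std::string &name) {
  ToyBlock *b = &blocksByName[name];
  if (!b->defined && !forwardRef.count(b))
    forwardRef[b] = name; // referenced before its definition
  return b;
}

ToyBlock *defineBlockNamed(const std::string &name) {
  ToyBlock *b = &blocksByName[name];
  if (b->defined)
    return nullptr;    // redefinition error
  b->defined = true;
  forwardRef.erase(b); // no longer a dangling forward reference
  return b;
}

int main() {
  ToyBlock *use = getBlockNamed("bb1");   // named by a branch first
  assert(!forwardRef.empty());
  assert(defineBlockNamed("bb1") == use); // later definition resolves it
  assert(forwardRef.empty());             // nothing left undefined
}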
- if (parseStmtBlock(thenClause)) + if (parseBlock(thenClause)) return ParseFailure; if (consumeIf(Token::kw_else)) { @@ -3138,19 +3138,19 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseSuccess; } -ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { +ParseResult MLFunctionParser::parseElseClause(Block *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); return parseIfStmt(); } - return parseStmtBlock(elseClause); + return parseBlock(elseClause); } /// /// Parse a list of statements ending with `return` or `}` /// -ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { +ParseResult MLFunctionParser::parseStatements(Block *block) { auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; @@ -3188,7 +3188,7 @@ ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { /// /// Parse `{` ml-stmt* `}` /// -ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) { +ParseResult MLFunctionParser::parseBlock(Block *block) { if (parseToken(Token::l_brace, "expected '{' before statement list") || parseStatements(block) || parseToken(Token::r_brace, "expected '}' after statement list")) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index e9942ff824b..0f130e19e26 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -52,7 +52,7 @@ public: bool runOnModule(Module &m, llvm::Module &llvmModule); private: - bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false); + bool convertBlock(const Block &bb, bool ignoreArguments = false); bool convertCFGFunction(const Function &cfgFunc, llvm::Function &llvmFunc); bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule); bool convertInstruction(const OperationInst &inst); @@ -142,7 +142,7 @@ private: llvm::DenseMap functionMapping; llvm::DenseMap valueMapping; - llvm::DenseMap blockMapping; + llvm::DenseMap blockMapping; llvm::LLVMContext &llvmContext; llvm::IRBuilder builder; llvm::IntegerType *indexType; @@ -742,8 +742,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { return inst.emitError("unsupported operation"); } -bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, - bool ignoreArguments) { +bool ModuleLowerer::convertBlock(const Block &bb, bool ignoreArguments) { builder.SetInsertPoint(blockMapping[&bb]); // Before traversing instructions, make block arguments available through @@ -780,8 +779,7 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const Value *getPHISourceValue(const BasicBlock *current, - const BasicBlock *pred, +static const Value *getPHISourceValue(const Block *current, const Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa()) { @@ -804,7 +802,7 @@ void ModuleLowerer::connectPHINodes(const Function &cfgFunc) { // to the arguments of the LLVM function. for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit; ++it) { - const BasicBlock *bb = &*it; + const Block *bb = &*it; llvm::BasicBlock *llvmBB = blockMapping[bb]; auto phis = llvmBB->phis(); auto numArguments = bb->getNumArguments(); @@ -837,7 +835,7 @@ bool ModuleLowerer::convertCFGFunction(const Function &cfgFunc, // Then, convert blocks one by one. 
for (auto indexedBB : llvm::enumerate(cfgFunc)) { const auto &bb = indexedBB.value(); - if (convertBasicBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0)) + if (convertBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0)) return true; } @@ -872,7 +870,7 @@ bool ModuleLowerer::convertFunctions(const Module &mlirModule, // arguments of the first block are those of the function. assert(!functionPtr->getBlocks().empty() && "expected at least one basic block in a Function"); - const BasicBlock &firstBlock = *functionPtr->begin(); + const Block &firstBlock = *functionPtr->begin(); for (auto arg : llvm::enumerate(llvmFunc->args())) { valueMapping[firstBlock.getArgument(arg.index())] = &arg.value(); } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 04f7cfdc3e9..a5b45ba4098 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -182,7 +182,7 @@ struct CFGCSE : public CSEImpl { // Check to see if we need to process this node. if (!currentNode->processed) { currentNode->processed = true; - simplifyBasicBlock(currentNode->node->getBlock()); + simplifyBlock(currentNode->node->getBlock()); // Otherwise, check to see if we need to process a child node. } else if (currentNode->childIterator != currentNode->node->end()) { auto *childNode = *(currentNode->childIterator++); @@ -199,7 +199,7 @@ struct CFGCSE : public CSEImpl { eraseDeadOperations(); } - void simplifyBasicBlock(BasicBlock *bb) { + void simplifyBlock(Block *bb) { for (auto &i : *bb) if (auto *opInst = dyn_cast(&i)) simplifyOperation(opInst); diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 8c69fa61578..c97b83f8485 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -45,8 +45,8 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker { std::vector affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} - using StmtListType = llvm::iplist; - void walk(StmtListType::iterator Start, StmtListType::iterator End); + using InstListType = llvm::iplist; + void walk(InstListType::iterator Start, InstListType::iterator End); void visitOperationInst(OperationInst *stmt); PassResult runOnMLFunction(Function *f) override; using StmtWalker::walk; @@ -62,8 +62,8 @@ FunctionPass *mlir::createComposeAffineMapsPass() { return new ComposeAffineMaps(); } -void ComposeAffineMaps::walk(StmtListType::iterator Start, - StmtListType::iterator End) { +void ComposeAffineMaps::walk(InstListType::iterator Start, + InstListType::iterator End) { while (Start != End) { walk(&(*Start)); // Increment iterator after walk as visit function can mutate stmt list diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 270a25dd339..821f35ca539 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -50,7 +50,7 @@ public: private: Value *getConstantIndexValue(int64_t value); - void visitStmtBlock(StmtBlock *stmtBlock); + void visitBlock(Block *Block); Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range values); @@ -117,8 +117,8 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) { } // Visit all statements in the given statement block. 
-void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { - for (auto &stmt : *stmtBlock) +void FunctionConverter::visitBlock(Block *Block) { + for (auto &stmt : *Block) this->visit(&stmt); } @@ -214,13 +214,13 @@ Value *FunctionConverter::buildMinMaxReductionSeq( void FunctionConverter::visitForStmt(ForStmt *forStmt) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). - BasicBlock *loopInsertionPoint = builder.getInsertionBlock(); + Block *loopInsertionPoint = builder.getInsertionBlock(); // Create blocks so that they appear in more human-readable order in the // output. - BasicBlock *loopInitBlock = builder.createBlock(); - BasicBlock *loopConditionBlock = builder.createBlock(); - BasicBlock *loopBodyFirstBlock = builder.createBlock(); + Block *loopInitBlock = builder.createBlock(); + Block *loopConditionBlock = builder.createBlock(); + Block *loopBodyFirstBlock = builder.createBlock(); // At the loop insertion location, branch immediately to the loop init block. builder.setInsertionPointToEnd(loopInsertionPoint); @@ -238,7 +238,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); - visitStmtBlock(forStmt->getBody()); + visitBlock(forStmt->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit @@ -254,7 +254,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { nextIvValue); // Create post-loop block here so that it appears after all loop body blocks. - BasicBlock *postLoopBlock = builder.createBlock(); + Block *postLoopBlock = builder.createBlock(); builder.setInsertionPointToEnd(loopInitBlock); // Compute loop bounds using affine_apply after remapping its operands. @@ -378,15 +378,15 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. - BasicBlock *ifInsertionBlock = builder.getInsertionBlock(); - SmallVector ifConditionExtraBlocks; + Block *ifInsertionBlock = builder.getInsertionBlock(); + SmallVector ifConditionExtraBlocks; unsigned numConstraints = integerSet.getNumConstraints(); ifConditionExtraBlocks.reserve(numConstraints - 1); for (unsigned i = 0, e = numConstraints - 1; i < e; ++i) { ifConditionExtraBlocks.push_back(builder.createBlock()); } - BasicBlock *thenBlock = builder.createBlock(); - BasicBlock *elseBlock = builder.createBlock(); + Block *thenBlock = builder.createBlock(); + Block *elseBlock = builder.createBlock(); builder.setInsertionPointToEnd(ifInsertionBlock); // Implement short-circuit logic. For each affine expression in the 'if' @@ -405,7 +405,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { ifConditionExtraBlocks)) { AffineExpr constraintExpr = std::get<0>(tuple); bool isEquality = std::get<1>(tuple); - BasicBlock *nextBlock = std::get<2>(tuple); + Block *nextBlock = std::get<2>(tuple); // Build and apply an affine map. auto affineMap = @@ -429,19 +429,19 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Recursively traverse the 'then' block. 
builder.setInsertionPointToEnd(thenBlock); - visitStmtBlock(ifStmt->getThen()); - BasicBlock *lastThenBlock = builder.getInsertionBlock(); + visitBlock(ifStmt->getThen()); + Block *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); if (ifStmt->hasElse()) - visitStmtBlock(ifStmt->getElse()); - BasicBlock *lastElseBlock = builder.getInsertionBlock(); + visitBlock(ifStmt->getElse()); + Block *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the // 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions // to the continuation block. - BasicBlock *continuationBlock = builder.createBlock(); + Block *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); builder.create(ifStmt->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 925c50abfec..69344819ed8 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -176,7 +176,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, FuncBuilder prologue(forStmt); // DMAs for write regions are going to be inserted just after the for loop. FuncBuilder epilogue(forStmt->getBlock(), - std::next(StmtBlock::iterator(forStmt))); + std::next(Block::iterator(forStmt))); FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. @@ -382,7 +382,7 @@ static unsigned getNestingDepth(const Statement &stmt) { return depth; } -// TODO(bondhugula): make this run on a StmtBlock instead of a 'for' stmt. +// TODO(bondhugula): make this run on a Block instead of a 'for' stmt. void DmaGeneration::runOnForStmt(ForStmt *forStmt) { // For now (for testing purposes), we'll run this on the outermost among 'for' // stmt's with unit stride, i.e., right at the top of the tile if tiling has diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2ddd613d6af..d31337437ad 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -343,7 +343,7 @@ public: // Intializes the data dependence graph by walking statements in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. -// TODO(andydavis) Add support for taking a StmtBlock arg to construct the +// TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index d6c1eed3a0c..109953f2296 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -58,8 +58,9 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } // Move the loop body of ForStmt 'src' from 'src' into the specified location in // destination's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest, - StmtBlock::iterator loc) { - dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements()); + Block::iterator loc) { + dest->getBody()->getInstructions().splice(loc, + src->getBody()->getInstructions()); } // Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. 
@@ -164,8 +165,8 @@ UtilResult mlir::tileCodeGen(ArrayRef band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto *pointLoop = b.createFor(loc, 0, 0); - pointLoop->getBody()->getStatements().splice( - pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + pointLoop->getBody()->getInstructions().splice( + pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; topLoop = pointLoop; @@ -178,9 +179,9 @@ UtilResult mlir::tileCodeGen(ArrayRef band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto *tileSpaceLoop = b.createFor(loc, 0, 0); - tileSpaceLoop->getBody()->getStatements().splice( - tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), - topLoop); + tileSpaceLoop->getBody()->getInstructions().splice( + tileSpaceLoop->getBody()->begin(), + topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; topLoop = tileSpaceLoop; } @@ -222,7 +223,7 @@ static void getTileableBands(Function *f, ForStmt *currStmt = root; do { band.push_back(currStmt); - } while (currStmt->getBody()->getStatements().size() == 1 && + } while (currStmt->getBody()->getInstructions().size() == 1 && (currStmt = dyn_cast(&*currStmt->getBody()->begin()))); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index c3651e53593..15ea0f841cc 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -91,9 +91,9 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { std::vector loops; // This method specialized to encode custom return logic. - using StmtListType = llvm::iplist; - bool walkPostOrder(StmtListType::iterator Start, - StmtListType::iterator End) { + using InstListType = llvm::iplist; + bool walkPostOrder(InstListType::iterator Start, + InstListType::iterator End) { bool hasInnerLoops = false; // We need to walk all elements since all innermost loops need to be // gathered as opposed to determining whether this list has any inner diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7ed9be19644..60e8d154f98 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -130,13 +130,13 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // tree). class JamBlockGatherer : public StmtWalker { public: - using StmtListType = llvm::iplist; + using InstListType = llvm::iplist; // Store iterators to the first and last stmt of each sub-block found. - std::vector> subBlocks; + std::vector> subBlocks; // This is a linear time walk. - void walk(StmtListType::iterator Start, StmtListType::iterator End) { + void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; while (it != End && !isa(it)) @@ -194,7 +194,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { DenseMap operandMap; // Insert the cleanup loop right after 'forStmt'. 
FuncBuilder builder(forStmt->getBlock(), - std::next(StmtBlock::iterator(forStmt))); + std::next(Block::iterator(forStmt))); auto *cleanupForStmt = cast(builder.clone(*forStmt, operandMap)); cleanupForStmt->setLowerBoundMap( getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder)); diff --git a/mlir/lib/Transforms/LowerAffineApply.cpp b/mlir/lib/Transforms/LowerAffineApply.cpp index 52146fdb5b7..747733de41e 100644 --- a/mlir/lib/Transforms/LowerAffineApply.cpp +++ b/mlir/lib/Transforms/LowerAffineApply.cpp @@ -52,7 +52,7 @@ PassResult LowerAffineApply::runOnMLFunction(Function *f) { } PassResult LowerAffineApply::runOnCFGFunction(Function *f) { - for (BasicBlock &bb : *f) { + for (Block &bb : *f) { // Handle iterators with care because we erase in the same loop. // In particular, step to the next element before erasing the current one. for (auto it = bb.begin(); it != bb.end();) { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 4d24191dcb2..51577009abb 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // Get the ML function builder. // We need access to the Function builder stored internally in the // MLFunctionLoweringRewriter general rewriting API does not provide - // ML-specific functions (ForStmt and StmtBlock manipulation). While we could + // ML-specific functions (ForStmt and Block manipulation). While we could // forward them or define a whole rewriting chain based on MLFunctionBuilder // instead of Builer, the code for it would be duplicate boilerplate. As we // go towards unifying ML and CFG functions, this separation will disappear. diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index a0964a67fa6..c8a6ced4ed1 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -345,7 +345,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } // Get shifts stored in map. - std::vector shifts(forStmt->getBody()->getStatements().size()); + std::vector shifts(forStmt->getBody()->getInstructions().size()); unsigned s = 0; for (auto &stmt : *forStmt->getBody()) { assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 7def4fe2f09..03b4bb29e19 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -108,7 +108,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { } else { const AffineBound lb = forStmt->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt)); + FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt)); auto affineApplyOp = builder.create( forStmt->getLoc(), lb.getMap(), lbOperands); forStmt->replaceAllUsesWith(affineApplyOp->getResult(0)); @@ -116,14 +116,14 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { } // Move the loop body statements to the loop's containing block. 
auto *block = forStmt->getBlock(); - block->getStatements().splice(StmtBlock::iterator(forStmt), - forStmt->getBody()->getStatements()); + block->getInstructions().splice(Block::iterator(forStmt), + forStmt->getBody()->getInstructions()); forStmt->erase(); return true; } /// Promotes all single iteration for stmt's in the Function, i.e., moves -/// their body into the containing StmtBlock. +/// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. class LoopBodyPromoter : public StmtWalker { @@ -223,7 +223,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, int64_t step = forStmt->getStep(); - unsigned numChildStmts = forStmt->getBody()->getStatements().size(); + unsigned numChildStmts = forStmt->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -379,7 +379,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { DenseMap operandMap; - FuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); + FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt)); auto *cleanupForStmt = cast(builder.clone(*forStmt, operandMap)); auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); assert(clLbMap && @@ -408,7 +408,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Keep a pointer to the last statement in the original block so that we know // what to clone (since we are doing this in-place). - StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 2ce8af3613a..50a3cf5a595 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -28,16 +28,16 @@ template <> struct llvm::DOTGraphTraits : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - static std::string getNodeLabel(const BasicBlock *basicBlock, - const Function *); + static std::string getNodeLabel(const Block *Block, const Function *); }; -std::string llvm::DOTGraphTraits::getNodeLabel( - const BasicBlock *basicBlock, const Function *) { +std::string +llvm::DOTGraphTraits::getNodeLabel(const Block *Block, + const Function *) { // Reuse the print output for the node labels. std::string outStreamStr; raw_string_ostream os(outStreamStr); - basicBlock->print(os); + Block->print(os); std::string &outStr = os.str(); if (outStr[0] == '\n') -- cgit v1.2.3 From 456ad6a8e0ca78ce6277da897a0b820533387d84 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Fri, 28 Dec 2018 16:05:35 -0800 Subject: Standardize naming of statements -> instructions, revisting the code base to be consistent and moving the using declarations over. Hopefully this is the last truly massive patch in this refactoring. This is step 21/n towards merging instructions and statements, NFC. 
PiperOrigin-RevId: 227178245 --- mlir/g3doc/LangRef.md | 94 +-- mlir/g3doc/Rationale.md | 42 +- mlir/g3doc/RationaleSimplifiedPolyhedralForm.md | 55 +- mlir/include/mlir/Analysis/AffineAnalysis.h | 12 +- mlir/include/mlir/Analysis/AffineStructures.h | 14 +- mlir/include/mlir/Analysis/HyperRectangularSet.h | 2 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 24 +- mlir/include/mlir/Analysis/MLFunctionMatcher.h | 48 +- mlir/include/mlir/Analysis/SliceAnalysis.h | 68 +- mlir/include/mlir/Analysis/Utils.h | 30 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 12 +- mlir/include/mlir/IR/AffineExprVisitor.h | 10 +- mlir/include/mlir/IR/Block.h | 18 +- mlir/include/mlir/IR/Builders.h | 66 +- mlir/include/mlir/IR/BuiltinOps.h | 2 +- mlir/include/mlir/IR/Function.h | 14 +- mlir/include/mlir/IR/InstVisitor.h | 230 ++++++ mlir/include/mlir/IR/Instruction.h | 304 ++++++++ mlir/include/mlir/IR/Instructions.h | 864 +++++++++++++++++++++ mlir/include/mlir/IR/IntegerSet.h | 2 +- mlir/include/mlir/IR/OpDefinition.h | 2 +- mlir/include/mlir/IR/Statement.h | 301 ------- mlir/include/mlir/IR/Statements.h | 863 -------------------- mlir/include/mlir/IR/StmtVisitor.h | 230 ------ mlir/include/mlir/IR/UseDefLists.h | 12 +- mlir/include/mlir/IR/Value.h | 9 +- mlir/include/mlir/Support/Functional.h | 2 +- mlir/include/mlir/Transforms/LoopUtils.h | 38 +- .../mlir/Transforms/MLPatternLoweringPass.h | 14 +- mlir/include/mlir/Transforms/Passes.h | 4 +- mlir/include/mlir/Transforms/Utils.h | 22 +- mlir/lib/Analysis/AffineAnalysis.cpp | 108 +-- mlir/lib/Analysis/AffineStructures.cpp | 26 +- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 101 ++- mlir/lib/Analysis/MLFunctionMatcher.cpp | 80 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 12 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 44 +- mlir/lib/Analysis/OpStats.cpp | 12 +- mlir/lib/Analysis/SliceAnalysis.cpp | 114 +-- mlir/lib/Analysis/Utils.cpp | 152 ++-- mlir/lib/Analysis/VectorAnalysis.cpp | 55 +- mlir/lib/Analysis/Verifier.cpp | 53 +- mlir/lib/IR/AsmPrinter.cpp | 136 ++-- mlir/lib/IR/Block.cpp | 28 +- mlir/lib/IR/Builders.cpp | 22 +- mlir/lib/IR/Function.cpp | 14 +- mlir/lib/IR/Instruction.cpp | 829 ++++++++++++++++++++ mlir/lib/IR/Operation.cpp | 4 +- mlir/lib/IR/PatternMatch.cpp | 2 +- mlir/lib/IR/Statement.cpp | 826 -------------------- mlir/lib/IR/Value.cpp | 14 +- mlir/lib/Parser/Parser.cpp | 90 +-- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 2 +- mlir/lib/Transforms/CSE.cpp | 14 +- mlir/lib/Transforms/ComposeAffineMaps.cpp | 24 +- mlir/lib/Transforms/ConstantFold.cpp | 32 +- mlir/lib/Transforms/ConvertToCFG.cpp | 122 +-- mlir/lib/Transforms/DmaGeneration.cpp | 74 +- mlir/lib/Transforms/LoopFusion.cpp | 180 ++--- mlir/lib/Transforms/LoopTiling.cpp | 58 +- mlir/lib/Transforms/LoopUnroll.cpp | 68 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 100 +-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 12 +- mlir/lib/Transforms/MaterializeVectors.cpp | 104 +-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 200 ++--- mlir/lib/Transforms/SimplifyAffineExpr.cpp | 22 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 288 +++---- mlir/lib/Transforms/Utils/Utils.cpp | 152 ++-- .../Vectorization/VectorizerTestPass.cpp | 48 +- mlir/lib/Transforms/Vectorize.cpp | 207 ++--- mlir/test/IR/invalid.mlir | 14 +- mlir/test/IR/parser.mlir | 12 +- mlir/test/Transforms/loop-fusion.mlir | 12 +- mlir/test/Transforms/memref-dependence-check.mlir | 2 +- mlir/test/Transforms/pipeline-data-transfer.mlir | 2 +- 
mlir/utils/vim/mlir.vim | 4 +- 78 files changed, 3950 insertions(+), 3939 deletions(-) create mode 100644 mlir/include/mlir/IR/InstVisitor.h create mode 100644 mlir/include/mlir/IR/Instruction.h create mode 100644 mlir/include/mlir/IR/Instructions.h delete mode 100644 mlir/include/mlir/IR/Statement.h delete mode 100644 mlir/include/mlir/IR/Statements.h delete mode 100644 mlir/include/mlir/IR/StmtVisitor.h create mode 100644 mlir/lib/IR/Instruction.cpp delete mode 100644 mlir/lib/IR/Statement.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 3469133a8ba..e9087dc5415 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -35,10 +35,10 @@ definitions, a "[CFG Function](#cfg-functions)" and an composition of [operations](#operations), but represent control flow in different ways: A CFG Function control flow using a CFG of [Blocks](#blocks), which contain instructions and end with -[control flow terminator statements](#terminator-instructions) (like branches). -ML Functions represents control flow with a nest of affine loops and if -conditions, and are said to contain statements. Both types of functions can call -back and forth between each other arbitrarily. +[control flow terminator instructions](#terminator-instructions) (like +branches). ML Functions represents control flow with a nest of affine loops and +if conditions. Both types of functions can call back and forth between each +other arbitrarily. MLIR is an [SSA-based](https://en.wikipedia.org/wiki/Static_single_assignment_form) IR, @@ -258,12 +258,12 @@ and symbol identifiers. In an [ML Function](#ml-functions), a symbolic identifier can be bound to an SSA value that is either an argument to the function, a value defined at the top -level of that function (outside of all loops and if statements), the result of a -[`constant` operation](#'constant'-operation), or the result of an +level of that function (outside of all loops and if instructions), the result of +a [`constant` operation](#'constant'-operation), or the result of an [`affine_apply`](#'affine_apply'-operation) operation that recursively takes as arguments any symbolic identifiers. Dimensions may be bound not only to anything that a symbol is bound to, but also to induction variables of enclosing -[for statements](#'for'-statement), and the results of an +[for instructions](#'for'-instruction), and the results of an [`affine_apply` operation](#'affine_apply'-operation) (which recursively may use other dimensions and symbols). @@ -939,7 +939,7 @@ way to lower [ML Functions](#ml-functions) before late code generation. Syntax: ``` {.ebnf} -block ::= bb-label operation* terminator-stmt +block ::= bb-label operation* terminator-inst bb-label ::= bb-id bb-arg-list? `:` bb-id ::= bare-id ssa-id-and-type ::= ssa-id `:` type @@ -951,10 +951,10 @@ bb-arg-list ::= `(` ssa-id-and-type-list? `)` ``` A [basic block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list -of operation instructions without control flow (calls are not considered control -flow for this purpose) that are executed from top to bottom. The last -instruction in a block is a [terminator instruction](#terminator-instructions), -which ends the block. +of instructions without control flow (calls are not considered control flow for +this purpose) that are executed from top to bottom. The last instruction in a +block is a [terminator instruction](#terminator-instructions), which ends the +block. 
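For illustration, a minimal block consistent with the grammar above; the block ids and the operation name `"test.op"` are hypothetical, and the operation is written in the generic quoted form rather than as any particular dialect op:

```mlir
bb3 (%x : i32):                        // block argument
  %y = "test.op"(%x) : (i32) -> i32    // ordinary (non-control-flow) operation
  br bb4(%y : i32)                     // terminator instruction ends the block
```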
Blocks in MLIR take a list of arguments, which represent SSA PHI nodes in a functional notation. The arguments are defined by the block, and values are @@ -995,7 +995,7 @@ case: they become arguments to the entry block [[more rationale](Rationale.md#block-arguments-vs-phi-nodes)]. Control flow within a CFG function is implemented with unconditional branches, -conditional branches, and a return statement. +conditional branches, and a `return` instruction. TODO: We can add [switches](http://llvm.org/docs/LangRef.html#switch-instruction), @@ -1009,11 +1009,11 @@ if/when there is demand. Syntax: ``` {.ebnf} -terminator-stmt ::= `br` bb-id branch-use-list? +terminator-inst ::= `br` bb-id branch-use-list? branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` ``` -The `br` terminator statement represents an unconditional jump to a target +The `br` terminator instruction represents an unconditional jump to a target block. The count and types of operands to the branch must align with the arguments in the target block. @@ -1025,14 +1025,14 @@ function. Syntax: ``` {.ebnf} -terminator-stmt ::= +terminator-inst ::= `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list? ``` -The `cond_br` terminator statement represents a conditional branch on a boolean -(1-bit integer) value. If the bit is set, then the first destination is jumped -to; if it is false, the second destination is chosen. The count and types of -operands must align with the arguments in the corresponding target blocks. +The `cond_br` terminator instruction represents a conditional branch on a +boolean (1-bit integer) value. If the bit is set, then the first destination is +jumped to; if it is false, the second destination is chosen. The count and types +of operands must align with the arguments in the corresponding target blocks. The MLIR conditional branch instruction is not allowed to target the entry block for a function. The two destinations of the conditional branch instruction are @@ -1057,10 +1057,10 @@ bb1 (%x : i32) : Syntax: ``` {.ebnf} -terminator-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)? +terminator-inst ::= `return` (ssa-use-list `:` type-list-no-parens)? ``` -The `return` terminator statement represents the completion of a cfg function, +The `return` terminator instruction represents the completion of a cfg function, and produces the result values. The count and types of the operands must match the result types of the enclosing function. It is legal for multiple blocks in a single function to return. @@ -1071,60 +1071,60 @@ Syntax: ``` {.ebnf} ml-func ::= `mlfunc` ml-func-signature - (`attributes` attribute-dict)? `{` stmt* return-stmt `}` + (`attributes` attribute-dict)? `{` inst* return-inst `}` ml-argument ::= ssa-id `:` type ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ ml-func-signature ::= function-id `(` ml-argument-list `)` (`->` type-list)? -stmt ::= operation | for-stmt | if-stmt +inst ::= operation | for-inst | if-inst ``` The body of an ML Function is made up of nested affine for loops, conditionals, -and [operation](#operations) statements, and ends with a return statement. Each -of the control flow statements is made up a list of instructions and other -control flow statements. +and [operation](#operations) instructions, and ends with a return instruction. +Each of the control flow instructions is made up a list of instructions and +other control flow instructions. 
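As a rough sketch of such a body (following the grammar above; `"test.work"` is a made-up placeholder operation in generic form):

```mlir
mlfunc @simple(%N : index) {
  for %i = 0 to %N {
    "test.work"(%i) : (index) -> ()   // operation instruction inside the loop
  }
  return                              // the body ends with a return instruction
}
```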
While ML Functions are restricted to affine loops and conditionals, they may freely call (and be called) by CFG Functions which do not have these restrictions. As such, the expressivity of MLIR is not restricted in general; one can choose to apply MLFunctions when it is beneficial. -#### 'return' statement {#'return'-statement} +#### 'return' instruction {#'return'-instruction} Syntax: ``` {.ebnf} -return-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)? +return-inst ::= `return` (ssa-use-list `:` type-list-no-parens)? ``` -The arity and operand types of the return statement must match the result of the -enclosing function. +The arity and operand types of the return instruction must match the result of +the enclosing function. -#### 'for' statement {#'for'-statement} +#### 'for' instruction {#'for'-instruction} Syntax: ``` {.ebnf} -for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound - (`step` integer-literal)? `{` stmt* `}` +for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound + (`step` integer-literal)? `{` inst* `}` lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal ``` -The `for` statement in an ML Function represents an affine loop nest, defining +The `for` instruction in an ML Function represents an affine loop nest, defining an SSA value for its induction variable. This SSA value always has type [`index`](#index-type), which is the size of the machine word. -The `for` statement executes its body a number of times iterating from a lower +The `for` instruction executes its body a number of times iterating from a lower bound to an upper bound by a stride. The stride, represented by `step`, is a positive constant integer which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. -The lower and upper bounds of a `for` statement are represented as an +The lower and upper bounds of a `for` instruction are represented as an application of an affine mapping to a list of SSA values passed to the map. The [same restrictions](#dimensions-and-symbols) hold for these SSA values as for all bindings of SSA values to dimensions and symbols. @@ -1159,23 +1159,23 @@ mlfunc @simple_example(%A: memref, %B: memref) { } ``` -#### 'if' statement {#'if'-statement} +#### 'if' instruction {#'if'-instruction} Syntax: ``` {.ebnf} -if-stmt-head ::= `if` if-stmt-cond `{` stmt* `}` - | if-stmt-head `else` `if` if-stmt-cond `{` stmt* `}` -if-stmt-cond ::= integer-set dim-and-symbol-use-list +if-inst-head ::= `if` if-inst-cond `{` inst* `}` + | if-inst-head `else` `if` if-inst-cond `{` inst* `}` +if-inst-cond ::= integer-set dim-and-symbol-use-list -if-stmt ::= if-stmt-head - | if-stmt-head `else` `{` stmt* `}` +if-inst ::= if-inst-head + | if-inst-head `else` `{` inst* `}` ``` -The `if` statement in an ML Function restricts execution to a subset of the loop -iteration space defined by an integer set (a conjunction of affine constraints). -A single `if` may have a number of optional `else if` clauses, and may end with -an optional `else` clause. +The `if` instruction in an ML Function restricts execution to a subset of the +loop iteration space defined by an integer set (a conjunction of affine +constraints). A single `if` may have a number of optional `else if` clauses, and +may end with an optional `else` clause. 
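A hedged sketch combining the two constructs; `#set0` is assumed to be an integer set defined elsewhere in the module (its role as the `if` condition is described next), and `"test.compute"`/`"test.boundary"` are placeholder operations:

```mlir
for %i = 0 to %N step 2 {
  if #set0(%i)[%N] {
    "test.compute"(%i) : (index) -> ()    // executed for points inside #set0
  } else {
    "test.boundary"(%i) : (index) -> ()   // executed for the remaining points
  }
}
```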
The condition of the `if` is represented by an [integer set](#integer-sets) (a conjunction of affine constraints), and the SSA values bound to the dimensions diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 17cbd1d15c1..883951637ef 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -583,7 +583,7 @@ our current design in practice. The current MLIR uses a representation of polyhedral schedules using a tree of if/for loops. We extensively debated the tradeoffs involved in the typical -unordered polyhedral statement representation (where each statement has +unordered polyhedral instruction representation (where each instruction has multi-dimensional schedule information), discussed the benefits of schedule tree forms, and eventually decided to go with a syntactic tree of affine if/else conditionals and affine for loops. Discussion of the tradeoff was captured in @@ -598,13 +598,13 @@ At a high level, we have two alternatives here: as multidimensional affine functions. A schedule tree form however makes polyhedral domains and schedules a first class concept in the IR allowing compact expression of transformations through the schedule tree without - changing the domains of MLStmts. Such a representation also hides prologues, - epilogues, partial tiles, complex loop bounds and conditionals making loop - nests free of "syntax". Cost models instead look at domains and schedules. - In addition, if necessary such a domain schedule representation can be - normalized to explicitly propagate the schedule into domains and model all - the cleanup code. An example and more detail on the schedule tree form is in - the next section. + changing the domains of instructions. Such a representation also hides + prologues, epilogues, partial tiles, complex loop bounds and conditionals + making loop nests free of "syntax". Cost models instead look at domains and + schedules. In addition, if necessary such a domain schedule representation + can be normalized to explicitly propagate the schedule into domains and + model all the cleanup code. An example and more detail on the schedule tree + form is in the next section. 1. Having two different forms of MLFunctions: an affine loop tree form (AffineLoopTreeFunction) and a polyhedral schedule tree form as two different forms of MLFunctions. Or in effect, having four different forms @@ -620,7 +620,7 @@ has to be executed while schedules represent the order in which domain elements are interleaved. We model domains as non piece-wise convex integer sets, and schedules as affine functions; however, the former can be disjunctive, and the latter can be piece-wise affine relations. In the schedule tree representation, -domain and schedules for statements are represented in a tree-like structure +domain and schedules for instructions are represented in a tree-like structure which is called a schedule tree. Each non-leaf node of the tree is an abstract polyhedral dimension corresponding to an abstract fused loop for each ML instruction that appears in that branch. Each leaf node is an ML Instruction. @@ -790,26 +790,26 @@ extfunc @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a, We considered providing a representation for SSA values that are live out of if/else conditional bodies or for loops of ML functions. We ultimately abandoned this approach due to its complexity. In the current design of MLIR, scalar -variables cannot escape for loops or if statements. 
In situations, where +variables cannot escape for loops or if instructions. In situations, where escaping is necessary, we use zero-dimensional tensors and memrefs instead of scalars. The abandoned design of supporting escaping scalars is as follows: -#### For Statement {#for-statement} +#### For Instruction {#for-instruction} Syntax: ``` {.ebnf} [ =] for % = ... step - [with ] { } + [with ] { } ``` out-var-list is a comma separated list of SSA values defined in the loop body and used outside the loop body. in-var-list is a comma separated list of SSA -values used inside the loop body and their initializers. loop-statement-list is -a list of statements that may also include a yield statement. +values used inside the loop body and their initializers. loop-instruction-list +is a list of instructions that may also include a yield instruction. Example: @@ -826,7 +826,7 @@ mlfunc int32 @sum(%A : memref, %N : i32) -> (i32) { } ``` -#### If/else Statement {#if-else-statement} +#### If/else Instruction {#if-else-instruction} Syntax: @@ -834,12 +834,12 @@ Syntax: = if () {...} [else {...}] ``` -Out-var-list is a list of SSA values defined by the if-statement. The values are -arguments to the yield-statement that occurs in both then and else clauses when -else clause is present. When if statement contains only if clause, the escaping -value defined in the then clause should be merged with the value the variable -had before the if statement. The design captured here does not handle this -situation. +Out-var-list is a list of SSA values defined by the if-instruction. The values +are arguments to the yield-instruction that occurs in both then and else clauses +when else clause is present. When if instruction contains only if clause, the +escaping value defined in the then clause should be merged with the value the +variable had before the if instruction. The design captured here does not handle +this situation. Example: diff --git a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md index c2770e1e26d..f42bee0303f 100644 --- a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md +++ b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md @@ -96,7 +96,7 @@ and probably slightly incorrect below): } ``` -In this design, an mlfunc is an unordered bag of statements whose execution +In this design, an mlfunc is an unordered bag of instructions whose execution order is fully controlled by their schedule. However, we recently agreed that a more explicit schedule tree representation is @@ -128,9 +128,9 @@ representation, and makes lexical ordering within a loop significant (eliminating the constant 0/1/2 of schedules). It isn't obvious in the example above, but the representation allows for some -interesting features, including the ability for statements within a loop nest to -have non-equal domains, like this - the second statement ignores the outer 10 -points inside the loop: +interesting features, including the ability for instructions within a loop nest +to have non-equal domains, like this - the second instruction ignores the outer +10 points inside the loop: ``` mlfunc @reduced_domain_example(... %N) { @@ -147,9 +147,9 @@ points inside the loop: } ``` -It also allows schedule remapping within the statement, like this example that +It also allows schedule remapping within the instruction, like this example that introduces a diagonal skew through a simple change to the schedules of the two -statements: +instructions: ``` mlfunc @skewed_domain_example(... 
%N) { @@ -175,9 +175,9 @@ structure. This document proposes and explores the idea of going one step further, moving all of the domain and schedule information into the "schedule tree". In this -form, we would have a representation where all statements inside of a given +form, we would have a representation where all instructions inside of a given for-loop are known to have the same domain, which is maintained by the loop. In -the simplified form, we also have an "if" statement that takes an affine +the simplified form, we also have an "if" instruction that takes an affine condition. Our simple example above would be represented as: @@ -199,7 +199,7 @@ Our simple example above would be represented as: } ``` -The example with the reduced domain would be represented with an if statement: +The example with the reduced domain would be represented with an if instruction: ```mlir mlfunc @reduced_domain_example(... %N) { @@ -223,13 +223,13 @@ The example with the reduced domain would be represented with an if statement: These IRs represent exactly the same information, and use a similar information density. The 'traditional' form introduces an extra level of abstraction -(schedules and domains) that make it easy to transform statements at the expense -of making it difficult to reason about how those statements will come out after -code generation. With the simplified form, transformations have to do parts of -code generation inline with their transformation: instead of simply changing a -schedule to **(i+j, j)** to get skewing, you'd have to generate this code -explicitly (potentially implemented by making polyhedral codegen a library that -transformations call into): +(schedules and domains) that make it easy to transform instructions at the +expense of making it difficult to reason about how those instructions will come +out after code generation. With the simplified form, transformations have to do +parts of code generation inline with their transformation: instead of simply +changing a schedule to **(i+j, j)** to get skewing, you'd have to generate this +code explicitly (potentially implemented by making polyhedral codegen a library +that transformations call into): ```mlir mlfunc @skewed_domain_example(... %N) { @@ -268,12 +268,12 @@ representation helps solve this inherently hard problem. ### Commonality: compactness of IR In the cases that are most relevant to us (hyper rectangular spaces) these forms -are directly equivalent: a traditional statement with a limited domain (e.g. the -"reduced_domain_example" above) ends up having one level of ML 'if' inside its -loops. The simplified form pays for this by eliminating schedules and domains -from the IR. Both forms allow code duplication to reduce dynamic branches in the -IR: the traditional approach allows statement splitting, the simplified form -supports statement duplication. +are directly equivalent: a traditional instruction with a limited domain (e.g. +the "reduced_domain_example" above) ends up having one level of ML 'if' inside +its loops. The simplified form pays for this by eliminating schedules and +domains from the IR. Both forms allow code duplication to reduce dynamic +branches in the IR: the traditional approach allows instruction splitting, the +simplified form supports instruction duplication. It is important to point out that the traditional form wins on compactness in the extreme cases: e.g. the loop skewing case. These cases will be rare in @@ -296,7 +296,7 @@ possible to do this, but it is a non-trivial transformation. 
An advantage for the traditional form is that it is easier to perform certain transformations on it: skewing and tiling are just transformations on the -schedule of the statements in question, it doesn't require changing the loop +schedule of the instructions in question, it doesn't require changing the loop structure. In practice, the simplified form requires moving the complexity of code @@ -317,7 +317,7 @@ The simplified form is much easier for analyses and transformations to build cost models for (e.g. answering the question of "how much code bloat will be caused by unrolling a loop at this level?"), because it is easier to predict what target code will be generated. With the traditional form, these analyses -will have to anticipate what polyhedral codegen will do to a set of statements +will have to anticipate what polyhedral codegen will do to a set of instructions under consideration: something that is non-trivial in the interesting cases in question (see "Cost of code generation"). @@ -343,7 +343,7 @@ stages of a code generator for an accelerator. We agree already that values defined in an mlfunc can include scalar values and they are defined based on traditional dominance. In the simplified form, this is very simple: arguments and induction variables defined in for-loops are live -inside their lexical body, and linear series of statements have the same "top +inside their lexical body, and linear series of instructions have the same "top down" dominance relation that a basic block does. In the traditional form though, this is not the case: it seems that a lot of @@ -374,8 +374,9 @@ mlfunc's (if we support them) will also have to have domains. The traditional form has multiple encodings for the same sorts of behavior: you end up having bits on `for` loops to specify whether codegen should use -"atomic/separate" policies, unroll loops, etc. Statements can be split or can -generate multiple copies of their statement because of overlapping domains, etc. +"atomic/separate" policies, unroll loops, etc. Instructions can be split or can +generate multiple copies of their instruction because of overlapping domains, +etc. This is a problem for analyses and cost models, because they each have to reason about these additional forms in the IR. diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 5ffaf845cfc..b769841b451 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -33,12 +33,12 @@ namespace mlir { class AffineExpr; class AffineMap; class AffineValueMap; -class ForStmt; +class ForInst; class MLIRContext; class FlatAffineConstraints; class IntegerSet; class OperationInst; -class Statement; +class Instruction; class Value; /// Simplify an affine expression through flattening and some amount of @@ -113,17 +113,17 @@ bool getFlattenedAffineExprs( FlatAffineConstraints *cst = nullptr); /// Builds a system of constraints with dimensional identifiers corresponding to -/// the loop IVs of the forStmts appearing in that order. Bounds of the loop are +/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are /// used to add appropriate inequalities. Any symbols founds in the bound /// operands are added as symbols in the system. Returns false for the yet /// unimplemented cases. // TODO(bondhugula): handle non-unit strides. 
-bool getIndexSet(llvm::ArrayRef forStmts, +bool getIndexSet(llvm::ArrayRef forInsts, FlatAffineConstraints *domain); struct MemRefAccess { const Value *memref; - const OperationInst *opStmt; + const OperationInst *opInst; llvm::SmallVector indices; // Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. @@ -146,7 +146,7 @@ struct DependenceComponent { /// Checks whether two accesses to the same memref access the same element. /// Each access is specified using the MemRefAccess structure, which contains -/// the operation statement, indices and memref associated with the access. +/// the operation instruction, indices and memref associated with the access. /// Returns 'false' if it can be determined conclusively that the accesses do /// not access the same memref element. Returns 'true' otherwise. // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index de60dc2115c..c644a33c938 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -30,7 +30,7 @@ class AffineApplyOp; class AffineBound; class AffineCondition; class AffineMap; -class ForStmt; +class ForInst; class IntegerSet; class MLIRContext; class Value; @@ -113,7 +113,7 @@ private: /// results, and its map can themselves change as a result of /// substitutions, simplifications, and other analysis. // An affine value map can readily be constructed from an AffineApplyOp, or an -// AffineBound of a ForStmt. It can be further transformed, substituted into, +// AffineBound of a ForInst. It can be further transformed, substituted into, // or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed // during analysis. Only the AffineMap expressions that are pointed by them are // unique'd. @@ -410,16 +410,16 @@ public: void addLowerBound(ArrayRef expr, ArrayRef lb); /// Adds constraints (lower and upper bounds) for the specified 'for' - /// statement's Value using IR information stored in its bound maps. The - /// right identifier is first looked up using forStmt's Value. Returns + /// instruction's Value using IR information stored in its bound maps. The + /// right identifier is first looked up using forInst's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the /// information is succesfully added. Asserts if the Value corresponding to - /// the 'for' statement isn't found in the constraint system. Any new - /// identifiers that are found in the bound operands of the 'for' statement + /// the 'for' instruction isn't found in the constraint system. Any new + /// identifiers that are found in the bound operands of the 'for' instruction /// are added as trailing identifiers (either dimensional or symbolic /// depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. - bool addForStmtDomain(const ForStmt &forStmt); + bool addForInstDomain(const ForInst &forInst); /// Adds an upper bound expression for the specified expression. 
void addUpperBound(ArrayRef expr, ArrayRef ub); diff --git a/mlir/include/mlir/Analysis/HyperRectangularSet.h b/mlir/include/mlir/Analysis/HyperRectangularSet.h index 74961308f47..52c242c0607 100644 --- a/mlir/include/mlir/Analysis/HyperRectangularSet.h +++ b/mlir/include/mlir/Analysis/HyperRectangularSet.h @@ -262,7 +262,7 @@ public: using HyperRectangleListTy = ::llvm::iplist; HyperRectangleListTy &getRectangles() { return hyperRectangles; } - // Iteration over the statements in the block. + // Iteration over the instructions in the block. using const_iterator = HyperRectangleListTy::const_iterator; const_iterator begin() const { return hyperRectangles.begin(); } diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 69fb81d0a1f..1b3d0ce9675 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -30,7 +30,7 @@ namespace mlir { class AffineExpr; class AffineMap; -class ForStmt; +class ForInst; class MemRefType; class OperationInst; class Value; @@ -38,19 +38,19 @@ class Value; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr getTripCountExpr(const ForStmt &forStmt); +AffineExpr getTripCountExpr(const ForInst &forInst); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count /// in non-trivial cases. -llvm::Optional getConstantTripCount(const ForStmt &forStmt); +llvm::Optional getConstantTripCount(const ForInst &forInst); /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt); +uint64_t getLargestDivisorOfTripCount(const ForInst &forInst); -/// Given an induction variable `iv` of type ForStmt and an `index` of type +/// Given an induction variable `iv` of type ForInst and an `index` of type /// IndexType, returns `true` if `index` is independent of `iv` and false /// otherwise. /// The determination supports composition with at most one AffineApplyOp. @@ -67,7 +67,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt); /// conservative. bool isAccessInvariant(const Value &iv, const Value &index); -/// Given an induction variable `iv` of type ForStmt and `indices` of type +/// Given an induction variable `iv` of type ForInst and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. /// /// Prerequisites (inherited from `isAccessInvariant` above): @@ -85,21 +85,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); /// 3. all nested load/stores are to scalar MemRefs. /// TODO(ntv): implement dependence semantics /// TODO(ntv): relax the no-conditionals restriction -bool isVectorizableLoop(const ForStmt &loop); +bool isVectorizableLoop(const ForInst &loop); /// Checks whether the loop is structurally vectorizable and that all the LoadOp /// and StoreOp matched have access indexing functions that are are either: /// 1. invariant along the loop induction variable created by 'loop'; /// 2. varying along the 'fastestVaryingDim' memory dimension. 
-bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForStmt &loop, +bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop, unsigned fastestVaryingDim); -/// Checks where SSA dominance would be violated if a for stmt's body statements -/// are shifted by the specified shifts. This method checks if a 'def' and all -/// its uses have the same shift factor. +/// Checks where SSA dominance would be violated if a for inst's body +/// instructions are shifted by the specified shifts. This method checks if a +/// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool isStmtwiseShiftValid(const ForStmt &forStmt, +bool isInstwiseShiftValid(const ForInst &forInst, llvm::ArrayRef shifts); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/MLFunctionMatcher.h b/mlir/include/mlir/Analysis/MLFunctionMatcher.h index 753d741f448..0c6917a0749 100644 --- a/mlir/include/mlir/Analysis/MLFunctionMatcher.h +++ b/mlir/include/mlir/Analysis/MLFunctionMatcher.h @@ -18,7 +18,7 @@ #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "llvm/Support/Allocator.h" #include @@ -26,7 +26,7 @@ namespace mlir { struct MLFunctionMatcherStorage; struct MLFunctionMatchesStorage; -class Statement; +class Instruction; /// An MLFunctionMatcher is a recursive matcher that captures nested patterns in /// an ML Function. It is used in conjunction with a scoped @@ -47,14 +47,14 @@ class Statement; /// /// Recursive abstraction for matching results. -/// Provides iteration over the Statement* captured by a Matcher. +/// Provides iteration over the Instruction* captured by a Matcher. /// /// Implemented as a POD value-type with underlying storage pointer. /// The underlying storage lives in a scoped bumper allocator whose lifetime /// is managed by an RAII MLFunctionMatcherContext. /// This should be used by value everywhere. struct MLFunctionMatches { - using EntryType = std::pair; + using EntryType = std::pair; using iterator = EntryType *; MLFunctionMatches() : storage(nullptr) {} @@ -66,8 +66,8 @@ struct MLFunctionMatches { unsigned size() { return end() - begin(); } unsigned empty() { return size() == 0; } - /// Appends the pair to the current matches. - void append(Statement *stmt, MLFunctionMatches children); + /// Appends the pair to the current matches. + void append(Instruction *inst, MLFunctionMatches children); private: friend class MLFunctionMatcher; @@ -79,7 +79,7 @@ private: MLFunctionMatchesStorage *storage; }; -/// A MLFunctionMatcher is a special type of StmtWalker that: +/// A MLFunctionMatcher is a special type of InstWalker that: /// 1. recursively matches a substructure in the tree; /// 2. uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); @@ -89,39 +89,39 @@ private: /// The underlying storage lives in a scoped bumper allocator whose lifetime /// is managed by an RAII MLFunctionMatcherContext. /// This should be used by value everywhere. 
-using FilterFunctionType = std::function; -static bool defaultFilterFunction(const Statement &) { return true; }; -struct MLFunctionMatcher : public StmtWalker { - MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child, +using FilterFunctionType = std::function; +static bool defaultFilterFunction(const Instruction &) { return true; }; +struct MLFunctionMatcher : public InstWalker { + MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child, FilterFunctionType filter = defaultFilterFunction); - MLFunctionMatcher(Statement::Kind k, + MLFunctionMatcher(Instruction::Kind k, MutableArrayRef children, FilterFunctionType filter = defaultFilterFunction); /// Returns all the matches in `function`. MLFunctionMatches match(Function *function); - /// Returns all the matches nested under `statement`. - MLFunctionMatches match(Statement *statement); + /// Returns all the matches nested under `instruction`. + MLFunctionMatches match(Instruction *instruction); unsigned getDepth(); private: friend class MLFunctionMatcherContext; - friend StmtWalker; + friend InstWalker; - Statement::Kind getKind(); + Instruction::Kind getKind(); MutableArrayRef getChildrenMLFunctionMatchers(); FilterFunctionType getFilterFunction(); MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl, - Statement *stmt); + Instruction *inst); - void matchOne(Statement *elem); + void matchOne(Instruction *elem); - void visitForStmt(ForStmt *forStmt) { matchOne(forStmt); } - void visitIfStmt(IfStmt *ifStmt) { matchOne(ifStmt); } - void visitOperationInst(OperationInst *opStmt) { matchOne(opStmt); } + void visitForInst(ForInst *forInst) { matchOne(forInst); } + void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } + void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// Underlying global bump allocator managed by an MLFunctionMatcherContext. static llvm::BumpPtrAllocator *&allocator(); @@ -160,9 +160,9 @@ MLFunctionMatcher For(MutableArrayRef children = {}); MLFunctionMatcher For(FilterFunctionType filter, MutableArrayRef children = {}); -bool isParallelLoop(const Statement &stmt); -bool isReductionLoop(const Statement &stmt); -bool isLoadOrStore(const Statement &stmt); +bool isParallelLoop(const Instruction &inst); +bool isReductionLoop(const Instruction &inst); +bool isLoadOrStore(const Instruction &inst); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index c3db378d971..c3cafbc8ae4 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -27,24 +27,24 @@ namespace mlir { -class Statement; +class Instruction; /// Type of the condition to limit the propagation of transitive use-defs. /// This can be used in particular to limit the propagation to a given Scope or -/// to avoid passing through certain types of statement in a configurable +/// to avoid passing through certain types of instruction in a configurable /// manner. -using TransitiveFilter = std::function; +using TransitiveFilter = std::function; /// Fills `forwardSlice` with the computed forward slice (i.e. all -/// the transitive uses of stmt), **without** including that statement. +/// the transitive uses of inst), **without** including that instruction. /// /// This additionally takes a TransitiveFilter which acts as a frontier: -/// when looking at uses transitively, a statement that does not pass the filter -/// is never propagated through. 
This allows in particular to carve out the -/// scope within a ForStmt or the scope within an IfStmt. +/// when looking at uses transitively, an instruction that does not pass the +/// filter is never propagated through. This allows in particular to carve out +/// the scope within a ForInst or the scope within an IfInst. /// /// The implementation traverses the use chains in postorder traversal for -/// efficiency reasons: if a statement is already in `forwardSlice`, no +/// efficiency reasons: if an instruction is already in `forwardSlice`, no /// need to traverse its uses again. Since use-def chains form a DAG, this /// terminates. /// @@ -77,21 +77,21 @@ using TransitiveFilter = std::function; /// {4, 3, 6, 2, 1, 5, 8, 7, 9} /// void getForwardSlice( - Statement *stmt, llvm::SetVector *forwardSlice, + Instruction *inst, llvm::SetVector *forwardSlice, TransitiveFilter filter = /* pass-through*/ - [](Statement *) { return true; }, + [](Instruction *) { return true; }, bool topLevel = true); /// Fills `backwardSlice` with the computed backward slice (i.e. -/// all the transitive defs of stmt), **without** including that statement. +/// all the transitive defs of inst), **without** including that instruction. /// /// This additionally takes a TransitiveFilter which acts as a frontier: -/// when looking at defs transitively, a statement that does not pass the filter -/// is never propagated through. This allows in particular to carve out the -/// scope within a ForStmt or the scope within an IfStmt. +/// when looking at defs transitively, an instruction that does not pass the +/// filter is never propagated through. This allows in particular to carve out +/// the scope within a ForInst or the scope within an IfInst. /// /// The implementation traverses the def chains in postorder traversal for -/// efficiency reasons: if a statement is already in `backwardSlice`, no +/// efficiency reasons: if an instruction is already in `backwardSlice`, no /// need to traverse its definitions again. Since use-def chains form a DAG, /// this terminates. /// @@ -117,14 +117,14 @@ void getForwardSlice( /// {1, 2, 5, 7, 3, 4, 6, 8} /// void getBackwardSlice( - Statement *stmt, llvm::SetVector *backwardSlice, + Instruction *inst, llvm::SetVector *backwardSlice, TransitiveFilter filter = /* pass-through*/ - [](Statement *) { return true; }, + [](Instruction *) { return true; }, bool topLevel = true); /// Iteratively computes backward slices and forward slices until -/// a fixed point is reached. Returns an `llvm::SetVector` which -/// **includes** the original statement. +/// a fixed point is reached. Returns an `llvm::SetVector` which +/// **includes** the original instruction. /// /// This allows building a slice (i.e. multi-root DAG where everything /// that is reachable from a Value in forward and backward direction is /// @@ -158,17 +158,17 @@ void getBackwardSlice( /// /// Additional implementation considerations /// ======================================== -/// Consider the defs-stmt-uses hourglass. +/// Consider the defs-inst-uses hourglass. /// ____ /// \ / defs (in some topological order) /// \/ -/// stmt +/// inst /// /\ /// / \ uses (in some topological order) /// /____\ /// /// We want to iteratively apply `getSlice` to construct the whole -/// list of OperationInst that are reachable by (use|def)+ from stmt. +/// list of OperationInst that are reachable by (use|def)+ from inst. /// We want the resulting slice in topological order.
/// Ideally we would like the ordering to be maintained in-place to avoid /// copying OperationInst at each step. Keeping this ordering by construction @@ -183,34 +183,34 @@ void getBackwardSlice( /// =========== /// We wish to maintain the following property by a recursive argument: /// """ -/// defs << {stmt} < getSlice( - Statement *stmt, +llvm::SetVector getSlice( + Instruction *inst, TransitiveFilter backwardFilter = /* pass-through*/ - [](Statement *) { return true; }, + [](Instruction *) { return true; }, TransitiveFilter forwardFilter = /* pass-through*/ - [](Statement *) { return true; }); + [](Instruction *) { return true; }); /// Multi-root DAG topological sort. /// Performs a topological sort of the OperationInst in the `toSort` SetVector. /// Returns a topologically sorted SetVector. -llvm::SetVector -topologicalSort(const llvm::SetVector &toSort); +llvm::SetVector +topologicalSort(const llvm::SetVector &toSort); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index fd57cda3902..fe04d401bcd 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -33,22 +33,22 @@ namespace mlir { class FlatAffineConstraints; -class ForStmt; +class ForInst; class MemRefAccess; class OperationInst; -class Statement; +class Instruction; class Value; -/// Returns true if statement 'a' dominates statement b. -bool dominates(const Statement &a, const Statement &b); +/// Returns true if instruction 'a' dominates instruction b. +bool dominates(const Instruction &a, const Instruction &b); -/// Returns true if statement 'a' properly dominates statement b. -bool properlyDominates(const Statement &a, const Statement &b); +/// Returns true if instruction 'a' properly dominates instruction b. +bool properlyDominates(const Instruction &a, const Instruction &b); -/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from -/// the outermost 'for' statement to the innermost one. -// TODO(bondhugula): handle 'if' stmt's. -void getLoopIVs(const Statement &stmt, SmallVectorImpl *loops); +/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from +/// the outermost 'for' instruction to the innermost one. +// TODO(bondhugula): handle 'if' inst's. +void getLoopIVs(const Instruction &inst, SmallVectorImpl *loops); /// A region of a memref's data space; this is typically constructed by /// analyzing load/store op's on this memref and the index space of loops @@ -111,10 +111,10 @@ private: /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opStmt. Returns false if this fails due to yet unimplemented +/// surrounding opInst. Returns false if this fails due to yet unimplemented /// cases. The computed region's 'cst' field has exactly as many dimensional /// identifiers as the rank of the memref, and *potentially* additional symbolic -/// identifiers which could include any of the loop IVs surrounding opStmt up +/// identifiers which could include any of the loop IVs surrounding opInst up /// until 'loopDepth' and another additional Function symbols involved with /// the access (for eg., those appear in affine_apply's, loop bounds, etc.). /// For example, the memref region for this operation at loopDepth = 1 will be: @@ -128,7 +128,7 @@ private: /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. 
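/// A hypothetical calling sketch for the declaration that follows; 'loadOpInst'
/// is an assumed OperationInst* for a load op, and a default-constructed
/// MemRefRegion is assumed to be acceptable as the output argument:
///
///   MemRefRegion region;
///   if (getMemRefRegion(loadOpInst, /*loopDepth=*/1, &region)) {
///     // The region's 'cst' field now holds the FlatAffineConstraints
///     // describing the accessed elements, as in the example above.
///   }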
/// -bool getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, +bool getMemRefRegion(OperationInst *opInst, unsigned loopDepth, MemRefRegion *region); /// Returns the size of memref data in bytes if it's statically shaped, None @@ -144,7 +144,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop -/// IVs, and inserts the computation slice at the beginning of the statement +/// IVs, and inserts the computation slice at the beginning of the instruction /// block of the loop at 'dstLoopDepth' in the loop nest surrounding /// 'dstAccess'. Returns the top-level loop of the computation slice on /// success, returns nullptr otherwise. @@ -152,7 +152,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases // TODO(andydavis) Support computation slices with common surrounding loops. -ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess, +ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess, MemRefAccess *dstAccess, unsigned srcLoopDepth, unsigned dstLoopDepth); diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index f84aff29946..9f9eaba056f 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -25,7 +25,7 @@ namespace mlir { class AffineMap; -class ForStmt; +class ForInst; class MemRefType; class OperationInst; class VectorType; @@ -65,8 +65,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// Note that loopToVectorDim is a whole function map from which only enclosing /// loop information is extracted. /// -/// Prerequisites: `opStmt` is a vectorizable load or store operation (i.e. at -/// most one invariant index along each ForStmt of `loopToVectorDim`). +/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at +/// most one invariant index along each ForInst of `loopToVectorDim`). /// /// Example 1: /// The following MLIR snippet: @@ -118,8 +118,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. /// AffineMap -makePermutationMap(OperationInst *opStmt, - const llvm::DenseMap &loopToVectorDim); +makePermutationMap(OperationInst *opInst, + const llvm::DenseMap &loopToVectorDim); namespace matcher { @@ -131,7 +131,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnStrictSuperVectors(const OperationInst &stmt, +bool operatesOnStrictSuperVectors(const OperationInst &inst, VectorType subVectorType); } // end namespace matcher diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index 5c1f07e98a6..b3995352e61 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -30,7 +30,7 @@ namespace mlir { /// /// AffineExpr visitors are used when you want to perform different actions /// for different kinds of AffineExprs without having to use lots of casts -/// and a big switch statement. +/// and a big switch instruction. 
/// /// To define your own visitor, inherit from this class, specifying your /// new type for the 'SubClass' template parameter, and "override" visitXXX @@ -66,11 +66,11 @@ namespace mlir { // AffineSymbolExpr. /// /// Note that if you don't implement visitXXX for some affine expression type, -/// the visitXXX method for Statement superclass will be invoked. +/// the visitXXX method for Instruction superclass will be invoked. /// /// Note that this class is specifically designed as a template to avoid /// virtual function call overhead. Defining and using an AffineExprVisitor is -/// just as efficient as having your own switch statement over the statement +/// just as efficient as having your own switch statement over the instruction /// opcode. template class AffineExprVisitor { @@ -159,8 +159,8 @@ public: //===--------------------------------------------------------------------===// // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular statement type. - // The default behavior is to generalize the statement type to its subtype + // the user does not specify what to do for a particular instruction type. + // The default behavior is to generalize the instruction type to its subtype // and try visiting the subtype. All of this should be inlined perfectly, // because there are no virtual functions to get in the way. // diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 985d0fdb075..8e6c80a146d 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -22,11 +22,11 @@ #ifndef MLIR_IR_BLOCK_H #define MLIR_IR_BLOCK_H -#include "mlir/IR/Statement.h" +#include "mlir/IR/Instruction.h" #include "llvm/ADT/PointerUnion.h" namespace mlir { -class IfStmt; +class IfInst; class BlockList; template class PredecessorIterator; @@ -58,7 +58,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an IfStmt or ForStmt. + /// nested under an IfInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast(this)->getFunction(); } @@ -134,10 +134,10 @@ public: /// Returns the instruction's position in this block or -1 if the instruction /// is not present. /// TODO: This is needlessly inefficient, and should not be API on Block. - int64_t findInstPositionInBlock(const Instruction &stmt) const { + int64_t findInstPositionInBlock(const Instruction &inst) const { int64_t j = 0; for (const auto &s : instructions) { - if (&s == &stmt) + if (&s == &inst) return j; j++; } @@ -291,7 +291,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or IfStmt or ForStmt. +/// is part of - a Function or IfInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -331,14 +331,14 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and IfStmt/ForStmt. If it is - /// part of an IfStmt/ForStmt, then return it, otherwise return null. + /// A BlockList is part of a Function or an IfInst/ForInst. If it is + /// part of an IfInst/ForInst, then return it, otherwise return null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a Function or and IfStmt/ForStmt. If it is + /// A BlockList is part of a Function or an IfInst/ForInst.
If it is /// part of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 5c1331e880d..5cba4caef0d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -19,7 +19,7 @@ #define MLIR_IR_BUILDERS_H #include "mlir/IR/Function.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" namespace mlir { @@ -172,10 +172,10 @@ public: clearInsertionPoint(); } - /// Create a function builder and set insertion point to the given statement, - /// which will cause subsequent insertions to go right before it. - FuncBuilder(Statement *stmt) : FuncBuilder(stmt->getFunction()) { - setInsertionPoint(stmt); + /// Create a function builder and set insertion point to the given + /// instruction, which will cause subsequent insertions to go right before it. + FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) { + setInsertionPoint(inst); } FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) { @@ -207,8 +207,8 @@ public: /// Sets the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. - void setInsertionPoint(Statement *stmt) { - setInsertionPoint(stmt->getBlock(), Block::iterator(stmt)); + void setInsertionPoint(Instruction *inst) { + setInsertionPoint(inst->getBlock(), Block::iterator(inst)); } /// Sets the insertion point to the start of the specified block. @@ -234,9 +234,9 @@ public: /// current function. Block *createBlock(Block *insertBefore = nullptr); - /// Returns a builder for the body of a for Stmt. - static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) { - return FuncBuilder(forStmt->getBody(), forStmt->getBody()->end()); + /// Returns a builder for the body of a 'for' instruction. + static FuncBuilder getForInstBodyBuilder(ForInst *forInst) { + return FuncBuilder(forInst->getBody(), forInst->getBody()->end()); } /// Returns the current block of the builder. @@ -250,8 +250,8 @@ public: OpPointer create(Location location, Args... args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); - auto *stmt = createOperation(state); - auto result = stmt->dyn_cast(); + auto *inst = createOperation(state); + auto result = inst->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } @@ -263,44 +263,44 @@ public: OpPointer createChecked(Location location, Args... args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); - auto *stmt = createOperation(state); + auto *inst = createOperation(state); // If the OperationInst we produce is valid, return it. - if (!OpTy::verifyInvariants(stmt)) { - auto result = stmt->dyn_cast(); + if (!OpTy::verifyInvariants(inst)) { + auto result = inst->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } - // Otherwise, the error message got emitted. Just remove the statement + // Otherwise, the error message got emitted. Just remove the instruction // we made. - stmt->erase(); + inst->erase(); return OpPointer(); } - /// Creates a deep copy of the specified statement, remapping any operands - /// that use values outside of the statement using the map that is provided ( - /// leaving them alone if no entry is present). 
Replaces references to cloned - /// sub-statements to the corresponding statement that is copied, and adds - /// those mappings to the map. - Statement *clone(const Statement &stmt, - OperationInst::OperandMapTy &operandMapping) { - Statement *cloneStmt = stmt.clone(operandMapping, getContext()); - block->getInstructions().insert(insertPoint, cloneStmt); - return cloneStmt; + /// Creates a deep copy of the specified instruction, remapping any operands + /// that use values outside of the instruction using the map that is provided + /// ( leaving them alone if no entry is present). Replaces references to + /// cloned sub-instructions to the corresponding instruction that is copied, + /// and adds those mappings to the map. + Instruction *clone(const Instruction &inst, + OperationInst::OperandMapTy &operandMapping) { + Instruction *cloneInst = inst.clone(operandMapping, getContext()); + block->getInstructions().insert(insertPoint, cloneInst); + return cloneInst; } - // Creates a for statement. When step is not specified, it is set to 1. - ForStmt *createFor(Location location, ArrayRef lbOperands, + // Creates a for instruction. When step is not specified, it is set to 1. + ForInst *createFor(Location location, ArrayRef lbOperands, AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step = 1); - // Creates a for statement with known (constant) lower and upper bounds. + // Creates a for instruction with known (constant) lower and upper bounds. // Default step is 1. - ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); + ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - /// Creates if statement. - IfStmt *createIf(Location location, ArrayRef operands, + /// Creates if instruction. + IfInst *createIf(Location location, ArrayRef operands, IntegerSet set); private: diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index e608a704f99..7f2b098e565 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -353,7 +353,7 @@ private: explicit ConstantIndexOp(const OperationInst *state) : ConstantOp(state) {} }; -/// The "return" operation represents a return statement within a function. +/// The "return" operation represents a return instruction within a function. /// The operation takes variable number of operands and produces no results. /// The operand number and types must match the signature of the function /// that contains the operation. For example: diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index b79b64b68b5..c03fba61d2d 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -114,9 +114,9 @@ public: Block &front() { return blocks.front(); } const Block &front() const { return const_cast(this)->front(); } - /// Return the 'return' statement of this Function. - const OperationInst *getReturnStmt() const; - OperationInst *getReturnStmt(); + /// Return the 'return' instruction of this Function. + const OperationInst *getReturn() const; + OperationInst *getReturn(); // These should only be used on MLFunctions. Block *getBody() { @@ -127,12 +127,12 @@ public: return const_cast(this)->getBody(); } - /// Walk the statements in the function in preorder, calling the callback for - /// each operation statement. + /// Walk the instructions in the function in preorder, calling the callback + /// for each operation instruction. 
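/// For example (an illustrative sketch, assuming the callback parameter is an
/// OperationInst pointer and 'f' is a Function*), counting the operation
/// instructions in a function:
///
///   unsigned numOps = 0;
///   f->walk([&](OperationInst *opInst) { ++numOps; });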
void walk(std::function callback); - /// Walk the statements in the function in postorder, calling the callback for - /// each operation statement. + /// Walk the instructions in the function in postorder, calling the callback + /// for each operation instruction. void walkPostOrder(std::function callback); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h new file mode 100644 index 00000000000..3ce7d25cafe --- /dev/null +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -0,0 +1,230 @@ +//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the base classes for Function's instruction visitors and +// walkers. A visit is an O(1) operation that visits just the node in question. A +// walk visits the node it's called on as well as the node's descendants. +// +// Instruction visitors/walkers are used when you want to perform different +// actions for different kinds of instructions without having to use lots of +// casts and a big switch statement. +// +// To define your own visitor/walker, inherit from these classes, specifying +// your new type for the 'SubClass' template parameter, and "override" visitXXX +// functions in your class. This class is defined in terms of statically +// resolved overloading, not virtual functions. +// +// For example, here is a walker that counts the number of for loops in a +// Function. +// +// /// Declare the class. Note that we derive from InstWalker instantiated +// /// with _our new subclasses_ type. +// struct LoopCounter : public InstWalker { +// unsigned numLoops; +// LoopCounter() : numLoops(0) {} +// void visitForInst(ForInst &fs) { ++numLoops; } +// }; +// +// And this class would be used like this: +// LoopCounter lc; +// lc.walk(function); +// numLoops = lc.numLoops; +// +// There are 'visit' methods for OperationInst, ForInst, IfInst, and +// Function, which recursively process all contained instructions. +// +// Note that if you don't implement visitXXX for some instruction type, +// the visitXXX method for Instruction superclass will be invoked. +// +// The optional second template argument specifies the type that instruction +// visitation functions should return. If you specify this, you *MUST* provide +// an implementation of every visit<#Instruction>(InstType *). +// +// Note that these classes are specifically designed as a template to avoid +// virtual function call overhead. Defining and using an InstVisitor is just +// as efficient as having your own switch statement over the instruction +// opcode.
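//
// The walkers below also support post-order traversal (children are processed
// before their enclosing instruction). As an illustrative sketch, the
// LoopCounter above could equally be driven with:
//
//   LoopCounter lc;
//   lc.walkPostOrder(function);  // visits nested ForInsts before their parents
//   numLoops = lc.numLoops;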
+ +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_INSTVISITOR_H +#define MLIR_IR_INSTVISITOR_H + +#include "mlir/IR/Function.h" +#include "mlir/IR/Instructions.h" + +namespace mlir { + +/// Base class for instruction visitors. +template class InstVisitor { + //===--------------------------------------------------------------------===// + // Interface code - This is the public interface of the InstVisitor that you + // use to visit instructions. + +public: + // Function to visit a instruction. + RetTy visit(Instruction *s) { + static_assert(std::is_base_of::value, + "Must pass the derived type to this template!"); + + switch (s->getKind()) { + case Instruction::Kind::For: + return static_cast(this)->visitForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->visitIfInst(cast(s)); + case Instruction::Kind::OperationInst: + return static_cast(this)->visitOperationInst( + cast(s)); + } + } + + //===--------------------------------------------------------------------===// + // Visitation functions... these functions provide default fallbacks in case + // the user does not specify what to do for a particular instruction type. + // The default behavior is to generalize the instruction type to its subtype + // and try visiting the subtype. All of this should be inlined perfectly, + // because there are no virtual functions to get in the way. + // + + // When visiting a for inst, if inst, or an operation inst directly, these + // methods get called to indicate when transitioning into a new unit. + void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} + void visitOperationInst(OperationInst *opInst) {} +}; + +/// Base class for instruction walkers. A walker can traverse depth first in +/// pre-order or post order. The walk methods without a suffix do a pre-order +/// traversal while those that traverse in post order have a PostOrder suffix. +template class InstWalker { + //===--------------------------------------------------------------------===// + // Interface code - This is the public interface of the InstWalker used to + // walk instructions. + +public: + // Generic walk method - allow walk to all instructions in a range. + template void walk(Iterator Start, Iterator End) { + while (Start != End) { + walk(&(*Start++)); + } + } + template void walkPostOrder(Iterator Start, Iterator End) { + while (Start != End) { + walkPostOrder(&(*Start++)); + } + } + + // Define walkers for Function and all Function instruction kinds. 
+ void walk(Function *f) { + static_cast(this)->visitMLFunction(f); + static_cast(this)->walk(f->getBody()->begin(), + f->getBody()->end()); + } + + void walkPostOrder(Function *f) { + static_cast(this)->walkPostOrder(f->getBody()->begin(), + f->getBody()->end()); + static_cast(this)->visitMLFunction(f); + } + + RetTy walkOpInst(OperationInst *opInst) { + return static_cast(this)->visitOperationInst(opInst); + } + + void walkForInst(ForInst *forInst) { + static_cast(this)->visitForInst(forInst); + auto *body = forInst->getBody(); + static_cast(this)->walk(body->begin(), body->end()); + } + + void walkForInstPostOrder(ForInst *forInst) { + auto *body = forInst->getBody(); + static_cast(this)->walkPostOrder(body->begin(), body->end()); + static_cast(this)->visitForInst(forInst); + } + + void walkIfInst(IfInst *ifInst) { + static_cast(this)->visitIfInst(ifInst); + static_cast(this)->walk(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (ifInst->hasElse()) + static_cast(this)->walk(ifInst->getElse()->begin(), + ifInst->getElse()->end()); + } + + void walkIfInstPostOrder(IfInst *ifInst) { + static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (ifInst->hasElse()) + static_cast(this)->walkPostOrder(ifInst->getElse()->begin(), + ifInst->getElse()->end()); + static_cast(this)->visitIfInst(ifInst); + } + + // Function to walk a instruction. + RetTy walk(Instruction *s) { + static_assert(std::is_base_of::value, + "Must pass the derived type to this template!"); + + switch (s->getKind()) { + case Instruction::Kind::For: + return static_cast(this)->walkForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInst(cast(s)); + case Instruction::Kind::OperationInst: + return static_cast(this)->walkOpInst(cast(s)); + } + } + + // Function to walk a instruction in post order DFS. + RetTy walkPostOrder(Instruction *s) { + static_assert(std::is_base_of::value, + "Must pass the derived type to this template!"); + + switch (s->getKind()) { + case Instruction::Kind::For: + return static_cast(this)->walkForInstPostOrder( + cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInstPostOrder( + cast(s)); + case Instruction::Kind::OperationInst: + return static_cast(this)->walkOpInst(cast(s)); + } + } + + //===--------------------------------------------------------------------===// + // Visitation functions... these functions provide default fallbacks in case + // the user does not specify what to do for a particular instruction type. + // The default behavior is to generalize the instruction type to its subtype + // and try visiting the subtype. All of this should be inlined perfectly, + // because there are no virtual functions to get in the way. + + // When visiting a specific inst directly during a walk, these methods get + // called. These are typically O(1) complexity and shouldn't be recursively + // processing their descendants in some way. When using RetTy, all of these + // need to be overridden. 
+ void visitMLFunction(Function *f) {} + void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} + void visitOperationInst(OperationInst *opInst) {} +}; + +} // end namespace mlir + +#endif // MLIR_IR_INSTVISITOR_H diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h new file mode 100644 index 00000000000..533266f8e8f --- /dev/null +++ b/mlir/include/mlir/IR/Instruction.h @@ -0,0 +1,304 @@ +//===- Instruction.h - MLIR ML Instruction Class --------------------*- C++ +//-*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the Instruction class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_INSTRUCTION_H +#define MLIR_IR_INSTRUCTION_H + +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ilist.h" +#include "llvm/ADT/ilist_node.h" + +namespace mlir { +class Block; +class Location; +class ForInst; +class MLIRContext; + +/// Terminator operations can have Block operands to represent successors. +using BlockOperand = IROperandImpl; + +} // namespace mlir + +//===----------------------------------------------------------------------===// +// ilist_traits for Instruction +//===----------------------------------------------------------------------===// + +namespace llvm { + +template <> struct ilist_traits<::mlir::Instruction> { + using Instruction = ::mlir::Instruction; + using inst_iterator = simple_ilist::iterator; + + static void deleteNode(Instruction *inst); + void addNodeToList(Instruction *inst); + void removeNodeFromList(Instruction *inst); + void transferNodesFromList(ilist_traits &otherList, + inst_iterator first, inst_iterator last); + +private: + mlir::Block *getContainingBlock(); +}; + +} // end namespace llvm + +namespace mlir { +template class OperandIterator; + +/// Instruction is a basic unit of execution within an ML function. +/// Instructions can be nested within for and if instructions effectively +/// forming a tree. Child instructions are organized into instruction blocks +/// represented by a 'Block' class. +class Instruction : public IROperandOwner, + public llvm::ilist_node_with_parent { +public: + enum class Kind { + OperationInst = (int)IROperandOwner::Kind::OperationInst, + For = (int)IROperandOwner::Kind::ForInst, + If = (int)IROperandOwner::Kind::IfInst, + }; + + Kind getKind() const { return (Kind)IROperandOwner::getKind(); } + + /// Remove this instruction from its parent block and delete it. + void erase(); + + // This is a verbose type used by the clone method below. + using OperandMapTy = + DenseMap, + llvm::detail::DenseMapPair>; + + /// Create a deep copy of this instruction, remapping any operands that use + /// values outside of the instruction using the map that is provided (leaving + /// them alone if no entry is present). 
Replaces references to cloned + /// sub-instructions to the corresponding instruction that is copied, and adds + /// those mappings to the map. + Instruction *clone(OperandMapTy &operandMap, MLIRContext *context) const; + Instruction *clone(MLIRContext *context) const; + + /// Returns the instruction block that contains this instruction. + Block *getBlock() const { return block; } + + /// Returns the closest surrounding instruction that contains this instruction + /// or nullptr if this is a top-level instruction. + Instruction *getParentInst() const; + + /// Returns the function that this instruction is part of. + /// The function is determined by traversing the chain of parent instructions. + /// Returns nullptr if the instruction is unlinked. + Function *getFunction() const; + + /// Destroys this instruction and its subclass data. + void destroy(); + + /// This drops all operand uses from this instruction, which is an essential + /// step in breaking cyclic dependences between references when they are to + /// be deleted. + void dropAllReferences(); + + /// Unlink this instruction from its current block and insert it right before + /// `existingInst` which may be in the same or another block in the same + /// function. + void moveBefore(Instruction *existingInst); + + /// Unlink this operation instruction from its current basic block and insert + /// it right before `iterator` in the specified basic block. + void moveBefore(Block *block, llvm::iplist::iterator iterator); + + // Returns whether the Instruction is a terminator. + bool isTerminator() const; + + void print(raw_ostream &os) const; + void dump() const; + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + unsigned getNumOperands() const; + + Value *getOperand(unsigned idx); + const Value *getOperand(unsigned idx) const; + void setOperand(unsigned idx, Value *value); + + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + + operand_iterator operand_begin(); + + operand_iterator operand_end(); + + /// Returns an iterator on the underlying Values. + llvm::iterator_range getOperands(); + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + + const_operand_iterator operand_begin() const; + + const_operand_iterator operand_end() const; + + /// Returns a const iterator on the underlying Values. + llvm::iterator_range getOperands() const; + + MutableArrayRef getInstOperands(); + ArrayRef getInstOperands() const { + return const_cast(this)->getInstOperands(); + } + + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + /// Emit an error about fatal conditions with this operation, reporting up to + /// any diagnostic handlers that may be listening. This function always + /// returns true. NOTE: This may terminate the containing application, only + /// use when the IR is in an inconsistent state. + bool emitError(const Twine &message) const; + + /// Emit a warning about this operation, reporting up to any diagnostic + /// handlers that may be listening. + void emitWarning(const Twine &message) const; + + /// Emit a note about this operation, reporting up to any diagnostic + /// handlers that may be listening. 
+ void emitNote(const Twine &message) const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() <= IROperandOwner::Kind::INST_LAST; + } + +protected: + Instruction(Kind kind, Location location) + : IROperandOwner((IROperandOwner::Kind)kind, location) {} + + // Instructions are deleted through the destroy() member because this class + // does not have a virtual destructor. + ~Instruction(); + +private: + /// The instruction block that containts this instruction. + Block *block = nullptr; + + // allow ilist_traits access to 'block' field. + friend struct llvm::ilist_traits; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) { + inst.print(os); + return os; +} + +/// This is a helper template used to implement an iterator that contains a +/// pointer to some object and an index into it. The iterator moves the +/// index but keeps the object constant. +template +class IndexedAccessorIterator + : public llvm::iterator_facade_base< + ConcreteType, std::random_access_iterator_tag, ElementType *, + std::ptrdiff_t, ElementType *, ElementType *> { +public: + ptrdiff_t operator-(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index - rhs.index; + } + bool operator==(const IndexedAccessorIterator &rhs) const { + return object == rhs.object && index == rhs.index; + } + bool operator<(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index < rhs.index; + } + + ConcreteType &operator+=(ptrdiff_t offset) { + this->index += offset; + return static_cast(*this); + } + ConcreteType &operator-=(ptrdiff_t offset) { + this->index -= offset; + return static_cast(*this); + } + +protected: + IndexedAccessorIterator(ObjectType *object, unsigned index) + : object(object), index(index) {} + ObjectType *object; + unsigned index; +}; + +/// This template implements the const/non-const operand iterators for the +/// Instruction class in terms of getOperand(idx). +template +class OperandIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the operand iterator to the specified operand index. + OperandIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator OperandIterator() const { + return OperandIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getOperand(this->index); + } +}; + +// Implement the inline operand iterator methods. 
+inline auto Instruction::operand_begin() -> operand_iterator { + return operand_iterator(this, 0); +} + +inline auto Instruction::operand_end() -> operand_iterator { + return operand_iterator(this, getNumOperands()); +} + +inline auto Instruction::getOperands() + -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +inline auto Instruction::operand_begin() const -> const_operand_iterator { + return const_operand_iterator(this, 0); +} + +inline auto Instruction::operand_end() const -> const_operand_iterator { + return const_operand_iterator(this, getNumOperands()); +} + +inline auto Instruction::getOperands() const + -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +} // end namespace mlir + +#endif // MLIR_IR_INSTRUCTION_H diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h new file mode 100644 index 00000000000..bd1b371ed06 --- /dev/null +++ b/mlir/include/mlir/IR/Instructions.h @@ -0,0 +1,864 @@ +//===- Instructions.h - MLIR ML Instruction Classes -----------------*- C++ +//-*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines classes for special kinds of ML Function instructions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_INSTRUCTIONS_H +#define MLIR_IR_INSTRUCTIONS_H + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Instruction.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OperationSupport.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/TrailingObjects.h" + +namespace mlir { +class AffineBound; +class IntegerSet; +class AffineCondition; +class AttributeListStorage; +template class ConstOpPointer; +template class OpPointer; +template class ResultIterator; +template class ResultTypeIterator; +class Function; + +/// Operations represent all of the arithmetic and other basic computation in +/// MLIR. +/// +class OperationInst final + : public Instruction, + private llvm::TrailingObjects { +public: + /// Create a new OperationInst with the specific fields. + static OperationInst * + create(Location location, OperationName name, ArrayRef operands, + ArrayRef resultTypes, ArrayRef attributes, + ArrayRef successors, MLIRContext *context); + + /// Return the context this operation is associated with. + MLIRContext *getContext() const; + + /// The name of an operation is the key identifier for it. + OperationName getName() const { return name; } + + /// If this operation has a registered operation description, return it. + /// Otherwise return null. + const AbstractOperation *getAbstractOperation() const { + return getName().getAbstractOperation(); + } + + /// Check if this instruction is a return instruction. 
+ bool isReturn() const; + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + unsigned getNumOperands() const { return numOperands; } + + Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { + return getInstOperand(idx).get(); + } + void setOperand(unsigned idx, Value *value) { + return getInstOperand(idx).set(value); + } + + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + /// Returns an iterator on the underlying Value's (Value *). + llvm::iterator_range getOperands() { + return {operand_begin(), operand_end()}; + } + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + /// Returns a const iterator on the underlying Value's (Value *). + llvm::iterator_range getOperands() const { + return {operand_begin(), operand_end()}; + } + + ArrayRef getInstOperands() const { + return {getTrailingObjects(), numOperands}; + } + MutableArrayRef getInstOperands() { + return {getTrailingObjects(), numOperands}; + } + + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + //===--------------------------------------------------------------------===// + // Results + //===--------------------------------------------------------------------===// + + /// Return true if there are no users of any results of this operation. + bool use_empty() const; + + unsigned getNumResults() const { return numResults; } + + Value *getResult(unsigned idx) { return &getInstResult(idx); } + const Value *getResult(unsigned idx) const { return &getInstResult(idx); } + + // Support non-const result iteration. + using result_iterator = ResultIterator; + result_iterator result_begin(); + result_iterator result_end(); + llvm::iterator_range getResults(); + + // Support const result iteration. + using const_result_iterator = + ResultIterator; + const_result_iterator result_begin() const; + + const_result_iterator result_end() const; + + llvm::iterator_range getResults() const; + + ArrayRef getInstResults() const { + return {getTrailingObjects(), numResults}; + } + + MutableArrayRef getInstResults() { + return {getTrailingObjects(), numResults}; + } + + InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } + + const InstResult &getInstResult(unsigned idx) const { + return getInstResults()[idx]; + } + + // Support result type iteration. + using result_type_iterator = + ResultTypeIterator; + result_type_iterator result_type_begin() const; + + result_type_iterator result_type_end() const; + + llvm::iterator_range getResultTypes() const; + + //===--------------------------------------------------------------------===// + // Attributes + //===--------------------------------------------------------------------===// + + // Operations may optionally carry a list of attributes that associate + // constants to names. 
Attributes may be dynamically added and removed over + // the lifetime of an operation. + // + // We assume there will be relatively few attributes on a given operation + // (maybe a dozen or so, but not hundreds or thousands) so we use linear + // searches for everything. + + /// Return all of the attributes on this operation. + ArrayRef getAttrs() const; + + /// Return the specified attribute if present, null otherwise. + Attribute getAttr(Identifier name) const { + for (auto elt : getAttrs()) + if (elt.first == name) + return elt.second; + return nullptr; + } + + Attribute getAttr(StringRef name) const { + for (auto elt : getAttrs()) + if (elt.first.is(name)) + return elt.second; + return nullptr; + } + + template AttrClass getAttrOfType(Identifier name) const { + return getAttr(name).dyn_cast_or_null(); + } + + template AttrClass getAttrOfType(StringRef name) const { + return getAttr(name).dyn_cast_or_null(); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setAttr(Identifier name, Attribute value); + + enum class RemoveResult { Removed, NotFound }; + + /// Remove the attribute with the specified name if it exists. The return + /// value indicates whether the attribute was present or not. + RemoveResult removeAttr(Identifier name); + + //===--------------------------------------------------------------------===// + // Terminators + //===--------------------------------------------------------------------===// + + MutableArrayRef getBlockOperands() { + assert(isTerminator() && "Only terminators have a block operands list"); + return {getTrailingObjects(), numSuccs}; + } + ArrayRef getBlockOperands() const { + return const_cast(this)->getBlockOperands(); + } + + llvm::iterator_range + getSuccessorOperands(unsigned index) const; + llvm::iterator_range getSuccessorOperands(unsigned index); + + unsigned getNumSuccessors() const { return numSuccs; } + unsigned getNumSuccessorOperands(unsigned index) const { + assert(isTerminator() && "Only terminators have successors"); + assert(index < getNumSuccessors()); + return getTrailingObjects()[index]; + } + + Block *getSuccessor(unsigned index) { + assert(index < getNumSuccessors()); + return getBlockOperands()[index].get(); + } + const Block *getSuccessor(unsigned index) const { + return const_cast(this)->getSuccessor(index); + } + void setSuccessor(Block *block, unsigned index); + + /// Erase a specific operand from the operand list of the successor at + /// 'index'. + void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { + assert(succIndex < getNumSuccessors()); + assert(opIndex < getNumSuccessorOperands(succIndex)); + eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex); + --getTrailingObjects()[succIndex]; + } + + /// Get the index of the first operand of the successor at the provided + /// index. + unsigned getSuccessorOperandIndex(unsigned index) const { + assert(isTerminator() && "Only terminators have successors."); + assert(index < getNumSuccessors()); + + // Count the number of operands for each of the successors after, and + // including, the one at 'index'. This is based upon the assumption that all + // non successor operands are placed at the beginning of the operand list. 
+ auto *successorOpCountBegin = getTrailingObjects(); + unsigned postSuccessorOpCount = + std::accumulate(successorOpCountBegin + index, + successorOpCountBegin + getNumSuccessors(), 0); + return getNumOperands() - postSuccessorOpCount; + } + + //===--------------------------------------------------------------------===// + // Accessors for various properties of operations + //===--------------------------------------------------------------------===// + + /// Returns whether the operation is commutative. + bool isCommutative() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::Commutative); + return false; + } + + /// Returns whether the operation has side-effects. + bool hasNoSideEffect() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::NoSideEffect); + return false; + } + + /// Returns whether the operation is a terminator. + bool isTerminator() const { + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::Terminator); + return false; + } + + /// Attempt to constant fold this operation with the specified constant + /// operand values - the elements in "operands" will correspond directly to + /// the operands of the operation, but may be null if non-constant. If + /// constant folding is successful, this returns false and fills in the + /// `results` vector. If not, this returns true and `results` is unspecified. + bool constantFold(ArrayRef operands, + SmallVectorImpl &results) const; + + //===--------------------------------------------------------------------===// + // Conversions to declared operations like DimOp + //===--------------------------------------------------------------------===// + + // Return a null OpPointer for the specified type. + template static OpPointer getNull() { + return OpPointer(OpClass(nullptr)); + } + + /// The dyn_cast methods perform a dynamic cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This returns + /// a null OpPointer on failure. + template OpPointer dyn_cast() { + if (isa()) { + return cast(); + } else { + return OpPointer(OpClass(nullptr)); + } + } + + /// The dyn_cast methods perform a dynamic cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This returns + /// a null ConstOpPointer on failure. + template ConstOpPointer dyn_cast() const { + if (isa()) { + return cast(); + } else { + return ConstOpPointer(OpClass(nullptr)); + } + } + + /// The cast methods perform a cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This aborts + /// if the parameter to the template isn't an instance of the template type + /// argument. + template OpPointer cast() { + assert(isa() && "cast() argument of incompatible type!"); + return OpPointer(OpClass(this)); + } + + /// The cast methods perform a cast from an OperationInst (like + /// Instruction and OperationInst) to a typed Op like DimOp. This aborts + /// if the parameter to the template isn't an instance of the template type + /// argument. + template ConstOpPointer cast() const { + assert(isa() && "cast() argument of incompatible type!"); + return ConstOpPointer(OpClass(this)); + } + + /// The is methods return true if the operation is a typed op (like DimOp) of + /// of the given class. 
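/// A hypothetical usage sketch of these cast helpers (LoadOp and StoreOp serve
/// purely as example op classes):
///
///   if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
///     // 'loadOp' is a non-null OpPointer<LoadOp> here.
///   }
///   bool accessesMemRef = opInst->isa<LoadOp>() || opInst->isa<StoreOp>();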
+ template bool isa() const { + return OpClass::isClassFor(this); + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + /// Emit an error with the op name prefixed, like "'dim' op " which is + /// convenient for verifiers. This function always returns true. + bool emitOpError(const Twine &message) const; + + void destroy(); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::OperationInst; + } + +private: + unsigned numOperands; + const unsigned numResults, numSuccs; + + /// This holds the name of the operation. + OperationName name; + + /// This holds general named attributes for the operation. + AttributeListStorage *attrs; + + OperationInst(Location location, OperationName name, unsigned numOperands, + unsigned numResults, unsigned numSuccessors, + ArrayRef attributes, MLIRContext *context); + ~OperationInst(); + + /// Erase the operand at 'index'. + void eraseOperand(unsigned index); + + // This stuff is used by the TrailingObjects template. + friend llvm::TrailingObjects; + size_t numTrailingObjects(OverloadToken) const { + return numOperands; + } + size_t numTrailingObjects(OverloadToken) const { + return numResults; + } + size_t numTrailingObjects(OverloadToken) const { + return numSuccs; + } + size_t numTrailingObjects(OverloadToken) const { return numSuccs; } +}; + +/// This template implements the result iterators for the OperationInst class +/// in terms of getResult(idx). +template +class ResultIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the result iterator to the specified index. + ResultIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator ResultIterator() const { + return ResultIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getResult(this->index); + } +}; + +/// This template implements the result type iterators for the OperationInst +/// class in terms of getResult(idx)->getType(). +template +class ResultTypeIterator final + : public IndexedAccessorIterator< + ResultTypeIterator, ObjectType, + ElementType> { +public: + /// Initializes the result type iterator to the specified index. + ResultTypeIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator ResultTypeIterator() const { + return ResultTypeIterator(this->object, + this->index); + } + + Type operator*() const { + return this->object->getResult(this->index)->getType(); + } +}; + +// Implement the inline result iterator methods. 
+inline auto OperationInst::result_begin() -> result_iterator { + return result_iterator(this, 0); +} + +inline auto OperationInst::result_end() -> result_iterator { + return result_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResults() + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto OperationInst::result_begin() const -> const_result_iterator { + return const_result_iterator(this, 0); +} + +inline auto OperationInst::result_end() const -> const_result_iterator { + return const_result_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResults() const + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto OperationInst::result_type_begin() const -> result_type_iterator { + return result_type_iterator(this, 0); +} + +inline auto OperationInst::result_type_end() const -> result_type_iterator { + return result_type_iterator(this, getNumResults()); +} + +inline auto OperationInst::getResultTypes() const + -> llvm::iterator_range { + return {result_type_begin(), result_type_end()}; +} + +/// For instruction represents an affine loop nest. +class ForInst : public Instruction, public Value { +public: + static ForInst *create(Location location, ArrayRef lbOperands, + AffineMap lbMap, ArrayRef ubOperands, + AffineMap ubMap, int64_t step); + + ~ForInst() { + // Explicitly erase instructions instead of relying of 'Block' destructor + // since child instructions need to be destroyed before the Value that this + // for inst represents is destroyed. Affine maps are immortal objects and + // don't need to be deleted. + getBody()->clear(); + } + + /// Resolve base class ambiguity. + using Instruction::getFunction; + + /// Operand iterators. + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; + + /// Operand iterator range. + using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + /// Get the body of the ForInst. + Block *getBody() { return &body.front(); } + + /// Get the body of the ForInst. + const Block *getBody() const { return &body.front(); } + + //===--------------------------------------------------------------------===// + // Bounds and step + //===--------------------------------------------------------------------===// + + /// Returns information about the lower bound as a single object. + const AffineBound getLowerBound() const; + + /// Returns information about the upper bound as a single object. + const AffineBound getUpperBound() const; + + /// Returns loop step. + int64_t getStep() const { return step; } + + /// Returns affine map for the lower bound. + AffineMap getLowerBoundMap() const { return lbMap; } + /// Returns affine map for the upper bound. The upper bound is exclusive. + AffineMap getUpperBoundMap() const { return ubMap; } + + /// Set lower bound. + void setLowerBound(ArrayRef operands, AffineMap map); + /// Set upper bound. + void setUpperBound(ArrayRef operands, AffineMap map); + + /// Set the lower bound map without changing operands. + void setLowerBoundMap(AffineMap map); + + /// Set the upper bound map without changing operands. + void setUpperBoundMap(AffineMap map); + + /// Set loop step. + void setStep(int64_t step) { + assert(step > 0 && "step has to be a positive integer constant"); + this->step = step; + } + + /// Returns true if the lower bound is constant. + bool hasConstantLowerBound() const; + /// Returns true if the upper bound is constant. 
+ bool hasConstantUpperBound() const; + /// Returns true if both bounds are constant. + bool hasConstantBounds() const { + return hasConstantLowerBound() && hasConstantUpperBound(); + } + /// Returns the value of the constant lower bound. + /// Fails assertion if the bound is non-constant. + int64_t getConstantLowerBound() const; + /// Returns the value of the constant upper bound. The upper bound is + /// exclusive. Fails assertion if the bound is non-constant. + int64_t getConstantUpperBound() const; + /// Sets the lower bound to the given constant value. + void setConstantLowerBound(int64_t value); + /// Sets the upper bound to the given constant value. + void setConstantUpperBound(int64_t value); + + /// Returns true if both the lower and upper bound have the same operand lists + /// (same operands in the same order). + bool matchingBoundOperandList() const; + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + unsigned getNumOperands() const { return operands.size(); } + + Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { + return getInstOperand(idx).get(); + } + void setOperand(unsigned idx, Value *value) { + getInstOperand(idx).set(value); + } + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + ArrayRef getInstOperands() const { return operands; } + MutableArrayRef getInstOperands() { return operands; } + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + // TODO: provide iterators for the lower and upper bound operands + // if the current access via getLowerBound(), getUpperBound() is too slow. + + /// Returns operands for the lower bound map. + operand_range getLowerBoundOperands(); + const_operand_range getLowerBoundOperands() const; + + /// Returns operands for the upper bound map. + operand_range getUpperBoundOperands(); + const_operand_range getUpperBoundOperands() const; + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + /// Return the context this operation is associated with. + MLIRContext *getContext() const { return getType().getContext(); } + + using Instruction::dump; + using Instruction::print; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::ForInst; + } + + // For instruction represents implicitly represents induction variable by + // inheriting from Value class. Whenever you need to refer to the loop + // induction variable, just use the for instruction itself. + static bool classof(const Value *value) { + return value->getKind() == Value::Kind::ForInst; + } + +private: + // The Block for the body. + BlockList body; + + // Affine map for the lower bound. + AffineMap lbMap; + // Affine map for the upper bound. The upper bound is exclusive. + AffineMap ubMap; + // Positive constant step. 
Since index is stored as an int64_t, we restrict + // step to the set of positive integers that int64_t can represent. + int64_t step; + // Operands for the lower and upper bounds, with the former followed by the + // latter. Dimensional operands are followed by symbolic operands for each + // bound. + std::vector operands; + + explicit ForInst(Location location, unsigned numOperands, AffineMap lbMap, + AffineMap ubMap, int64_t step); +}; + +/// AffineBound represents a lower or upper bound in the for instruction. +/// This class does not own the underlying operands. Instead, it refers +/// to the operands stored in the ForInst. Its life span should not exceed +/// that of the for instruction it refers to. +class AffineBound { +public: + const ForInst *getForInst() const { return &inst; } + AffineMap getMap() const { return map; } + + unsigned getNumOperands() const { return opEnd - opStart; } + const Value *getOperand(unsigned idx) const { + return inst.getOperand(opStart + idx); + } + const InstOperand &getInstOperand(unsigned idx) const { + return inst.getInstOperand(opStart + idx); + } + + using operand_iterator = ForInst::operand_iterator; + using operand_range = ForInst::operand_range; + + operand_iterator operand_begin() const { + // These are iterators over Value *. Not casting away const'ness would + // require the caller to use const Value *. + return operand_iterator(const_cast(&inst), opStart); + } + operand_iterator operand_end() const { + return operand_iterator(const_cast(&inst), opEnd); + } + + /// Returns an iterator on the underlying Value's (Value *). + operand_range getOperands() const { return {operand_begin(), operand_end()}; } + ArrayRef getInstOperands() const { + auto ops = inst.getInstOperands(); + return ArrayRef(ops.begin() + opStart, ops.begin() + opEnd); + } + +private: + // 'for' instruction that contains this bound. + const ForInst &inst; + // Start and end positions of this affine bound operands in the list of + // the containing 'for' instruction operands. + unsigned opStart, opEnd; + // Affine map for this bound. + AffineMap map; + + AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd, + AffineMap map) + : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} + + friend class ForInst; +}; + +/// If instruction restricts execution to a subset of the loop iteration space. +class IfInst : public Instruction { +public: + static IfInst *create(Location location, ArrayRef operands, + IntegerSet set); + ~IfInst(); + + //===--------------------------------------------------------------------===// + // Then, else, condition. + //===--------------------------------------------------------------------===// + + Block *getThen() { return &thenClause.front(); } + const Block *getThen() const { return &thenClause.front(); } + Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } + const Block *getElse() const { + return elseClause ? 
&elseClause->front() : nullptr; + } + bool hasElse() const { return elseClause != nullptr; } + + Block *createElse() { + assert(elseClause == nullptr && "already has an else clause!"); + elseClause = new BlockList(this); + elseClause->push_back(new Block()); + return &elseClause->front(); + } + + const AffineCondition getCondition() const; + + IntegerSet getIntegerSet() const { return set; } + void setIntegerSet(IntegerSet newSet) { + assert(newSet.getNumOperands() == operands.size()); + set = newSet; + } + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + /// Operand iterators. + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; + + /// Operand iterator range. + using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + unsigned getNumOperands() const { return operands.size(); } + + Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { + return getInstOperand(idx).get(); + } + void setOperand(unsigned idx, Value *value) { + getInstOperand(idx).set(value); + } + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + ArrayRef getInstOperands() const { return operands; } + MutableArrayRef getInstOperands() { return operands; } + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + MLIRContext *getContext() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::IfInst; + } + +private: + // it is always present. + BlockList thenClause; + // 'else' clause of the if instruction. 'nullptr' if there is no else clause. + BlockList *elseClause; + + // The integer set capturing the conditional guard. + IntegerSet set; + + // Condition operands. + std::vector operands; + + explicit IfInst(Location location, unsigned numOperands, IntegerSet set); +}; + +/// AffineCondition represents a condition of the 'if' instruction. +/// Its life span should not exceed that of the objects it refers to. +/// AffineCondition does not provide its own methods for iterating over +/// the operands since the iterators of the if instruction accomplish +/// the same purpose. +/// +/// AffineCondition is trivially copyable, so it should be passed by value. +class AffineCondition { +public: + const IfInst *getIfInst() const { return &inst; } + IntegerSet getIntegerSet() const { return set; } + +private: + // 'if' instruction that contains this affine condition. + const IfInst &inst; + // Integer set for this affine condition. 
+ IntegerSet set; + + AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} + + friend class IfInst; +}; +} // end namespace mlir + +#endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index 71c6ea6f79a..dc7eff7f572 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -17,7 +17,7 @@ // // Integer sets are sets of points from the integer lattice constrained by // affine equality/inequality constraints. This class is meant to represent -// affine equality/inequality conditions for MLFunctions' if statements. As +// affine equality/inequality conditions for MLFunctions' if instructions. As // such, it is only expected to contain a handful of affine constraints, and it // is immutable like an Affine Map. Integer sets are however not unique'd - // although affine expressions that make up the equalities and inequalites of an diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index e1b90b6e39e..8441c750764 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -28,7 +28,7 @@ #ifndef MLIR_IR_OPDEFINITION_H #define MLIR_IR_OPDEFINITION_H -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include namespace mlir { diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h deleted file mode 100644 index 9ca5530f33c..00000000000 --- a/mlir/include/mlir/IR/Statement.h +++ /dev/null @@ -1,301 +0,0 @@ -//===- Statement.h - MLIR ML Statement Class --------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the Statement class. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_STATEMENT_H -#define MLIR_IR_STATEMENT_H - -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/ilist.h" -#include "llvm/ADT/ilist_node.h" - -namespace mlir { -class Block; -class Location; -class ForStmt; -class MLIRContext; - -/// Terminator operations can have Block operands to represent successors. 
-using BlockOperand = IROperandImpl; - -} // namespace mlir - -//===----------------------------------------------------------------------===// -// ilist_traits for Statement -//===----------------------------------------------------------------------===// - -namespace llvm { - -template <> struct ilist_traits<::mlir::Statement> { - using Statement = ::mlir::Statement; - using stmt_iterator = simple_ilist::iterator; - - static void deleteNode(Statement *stmt); - void addNodeToList(Statement *stmt); - void removeNodeFromList(Statement *stmt); - void transferNodesFromList(ilist_traits &otherList, - stmt_iterator first, stmt_iterator last); - -private: - mlir::Block *getContainingBlock(); -}; - -} // end namespace llvm - -namespace mlir { -template class OperandIterator; - -/// Statement is a basic unit of execution within an ML function. -/// Statements can be nested within for and if statements effectively -/// forming a tree. Child statements are organized into statement blocks -/// represented by a 'Block' class. -class Statement : public IROperandOwner, - public llvm::ilist_node_with_parent { -public: - enum class Kind { - OperationInst = (int)IROperandOwner::Kind::OperationInst, - For = (int)IROperandOwner::Kind::ForStmt, - If = (int)IROperandOwner::Kind::IfStmt, - }; - - Kind getKind() const { return (Kind)IROperandOwner::getKind(); } - - /// Remove this statement from its parent block and delete it. - void erase(); - - // This is a verbose type used by the clone method below. - using OperandMapTy = - DenseMap, - llvm::detail::DenseMapPair>; - - /// Create a deep copy of this statement, remapping any operands that use - /// values outside of the statement using the map that is provided (leaving - /// them alone if no entry is present). Replaces references to cloned - /// sub-statements to the corresponding statement that is copied, and adds - /// those mappings to the map. - Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const; - Statement *clone(MLIRContext *context) const; - - /// Returns the statement block that contains this statement. - Block *getBlock() const { return block; } - - /// Returns the closest surrounding statement that contains this statement - /// or nullptr if this is a top-level statement. - Statement *getParentStmt() const; - - /// Returns the function that this statement is part of. - /// The function is determined by traversing the chain of parent statements. - /// Returns nullptr if the statement is unlinked. - Function *getFunction() const; - - /// Destroys this statement and its subclass data. - void destroy(); - - /// This drops all operand uses from this instruction, which is an essential - /// step in breaking cyclic dependences between references when they are to - /// be deleted. - void dropAllReferences(); - - /// Unlink this statement from its current block and insert it right before - /// `existingStmt` which may be in the same or another block in the same - /// function. - void moveBefore(Statement *existingStmt); - - /// Unlink this operation instruction from its current basic block and insert - /// it right before `iterator` in the specified basic block. - void moveBefore(Block *block, llvm::iplist::iterator iterator); - - // Returns whether the Statement is a terminator. 
- bool isTerminator() const; - - void print(raw_ostream &os) const; - void dump() const; - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - unsigned getNumOperands() const; - - Value *getOperand(unsigned idx); - const Value *getOperand(unsigned idx) const; - void setOperand(unsigned idx, Value *value); - - // Support non-const operand iteration. - using operand_iterator = OperandIterator; - - operand_iterator operand_begin(); - - operand_iterator operand_end(); - - /// Returns an iterator on the underlying Values. - llvm::iterator_range getOperands(); - - // Support const operand iteration. - using const_operand_iterator = OperandIterator; - - const_operand_iterator operand_begin() const; - - const_operand_iterator operand_end() const; - - /// Returns a const iterator on the underlying Values. - llvm::iterator_range getOperands() const; - - MutableArrayRef getInstOperands(); - ArrayRef getInstOperands() const { - return const_cast(this)->getInstOperands(); - } - - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - /// Emit an error about fatal conditions with this operation, reporting up to - /// any diagnostic handlers that may be listening. This function always - /// returns true. NOTE: This may terminate the containing application, only - /// use when the IR is in an inconsistent state. - bool emitError(const Twine &message) const; - - /// Emit a warning about this operation, reporting up to any diagnostic - /// handlers that may be listening. - void emitWarning(const Twine &message) const; - - /// Emit a note about this operation, reporting up to any diagnostic - /// handlers that may be listening. - void emitNote(const Twine &message) const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() <= IROperandOwner::Kind::STMT_LAST; - } - -protected: - Statement(Kind kind, Location location) - : IROperandOwner((IROperandOwner::Kind)kind, location) {} - - // Statements are deleted through the destroy() member because this class - // does not have a virtual destructor. - ~Statement(); - -private: - /// The statement block that containts this statement. - Block *block = nullptr; - - // allow ilist_traits access to 'block' field. - friend struct llvm::ilist_traits; -}; - -inline raw_ostream &operator<<(raw_ostream &os, const Statement &stmt) { - stmt.print(os); - return os; -} - -/// This is a helper template used to implement an iterator that contains a -/// pointer to some object and an index into it. The iterator moves the -/// index but keeps the object constant. 
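[Editorial note: the indexed-accessor idiom described in the comment above, which the deleted Statement.h implements and the new Instructions.h reuses for its operand and result iterators, reduces to a few lines of standard C++. The sketch below is illustrative only; IndexedIterator and ToyInst are invented names, and the real classes additionally derive from llvm::iterator_facade_base to obtain the full random-access operator set.]

// Illustrative only: an iterator that stores a container pointer plus an
// index and maps dereference back to container->get(index).
#include <cstddef>
#include <iostream>
#include <vector>

template <typename Container, typename Element>
class IndexedIterator {
public:
  IndexedIterator(Container *object, unsigned index)
      : object(object), index(index) {}

  Element &operator*() const { return object->get(index); }
  IndexedIterator &operator++() {
    ++index;
    return *this;
  }
  bool operator!=(const IndexedIterator &rhs) const {
    return object != rhs.object || index != rhs.index;
  }

private:
  Container *object;
  unsigned index;
};

// A toy "instruction" exposing its operands only through get(idx)/size().
struct ToyInst {
  std::vector<int> operands;
  int &get(unsigned idx) { return operands[idx]; }
  unsigned size() const { return static_cast<unsigned>(operands.size()); }

  using iterator = IndexedIterator<ToyInst, int>;
  iterator begin() { return iterator(this, 0); }
  iterator end() { return iterator(this, size()); }
};

int main() {
  ToyInst inst{{1, 2, 3}};
  for (int operand : inst)
    std::cout << operand << ' ';
  std::cout << '\n';
}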
-template -class IndexedAccessorIterator - : public llvm::iterator_facade_base< - ConcreteType, std::random_access_iterator_tag, ElementType *, - std::ptrdiff_t, ElementType *, ElementType *> { -public: - ptrdiff_t operator-(const IndexedAccessorIterator &rhs) const { - assert(object == rhs.object && "incompatible iterators"); - return index - rhs.index; - } - bool operator==(const IndexedAccessorIterator &rhs) const { - return object == rhs.object && index == rhs.index; - } - bool operator<(const IndexedAccessorIterator &rhs) const { - assert(object == rhs.object && "incompatible iterators"); - return index < rhs.index; - } - - ConcreteType &operator+=(ptrdiff_t offset) { - this->index += offset; - return static_cast(*this); - } - ConcreteType &operator-=(ptrdiff_t offset) { - this->index -= offset; - return static_cast(*this); - } - -protected: - IndexedAccessorIterator(ObjectType *object, unsigned index) - : object(object), index(index) {} - ObjectType *object; - unsigned index; -}; - -/// This template implements the const/non-const operand iterators for the -/// Instruction class in terms of getOperand(idx). -template -class OperandIterator final - : public IndexedAccessorIterator, - ObjectType, ElementType> { -public: - /// Initializes the operand iterator to the specified operand index. - OperandIterator(ObjectType *object, unsigned index) - : IndexedAccessorIterator, - ObjectType, ElementType>(object, index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator OperandIterator() const { - return OperandIterator(this->object, - this->index); - } - - ElementType *operator*() const { - return this->object->getOperand(this->index); - } -}; - -// Implement the inline operand iterator methods. -inline auto Statement::operand_begin() -> operand_iterator { - return operand_iterator(this, 0); -} - -inline auto Statement::operand_end() -> operand_iterator { - return operand_iterator(this, getNumOperands()); -} - -inline auto Statement::getOperands() -> llvm::iterator_range { - return {operand_begin(), operand_end()}; -} - -inline auto Statement::operand_begin() const -> const_operand_iterator { - return const_operand_iterator(this, 0); -} - -inline auto Statement::operand_end() const -> const_operand_iterator { - return const_operand_iterator(this, getNumOperands()); -} - -inline auto Statement::getOperands() const - -> llvm::iterator_range { - return {operand_begin(), operand_end()}; -} - -} // end namespace mlir - -#endif // MLIR_IR_STATEMENT_H diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h deleted file mode 100644 index aa4157714a7..00000000000 --- a/mlir/include/mlir/IR/Statements.h +++ /dev/null @@ -1,863 +0,0 @@ -//===- Statements.h - MLIR ML Statement Classes -----------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= -// -// This file defines classes for special kinds of ML Function statements. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_STATEMENTS_H -#define MLIR_IR_STATEMENTS_H - -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/Statement.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/TrailingObjects.h" - -namespace mlir { -class AffineBound; -class IntegerSet; -class AffineCondition; -class AttributeListStorage; -template class ConstOpPointer; -template class OpPointer; -template class ResultIterator; -template class ResultTypeIterator; -class Function; - -/// Operations represent all of the arithmetic and other basic computation in -/// MLIR. -/// -class OperationInst final - : public Statement, - private llvm::TrailingObjects { -public: - /// Create a new OperationInst with the specific fields. - static OperationInst * - create(Location location, OperationName name, ArrayRef operands, - ArrayRef resultTypes, ArrayRef attributes, - ArrayRef successors, MLIRContext *context); - - /// Return the context this operation is associated with. - MLIRContext *getContext() const; - - /// The name of an operation is the key identifier for it. - OperationName getName() const { return name; } - - /// If this operation has a registered operation description, return it. - /// Otherwise return null. - const AbstractOperation *getAbstractOperation() const { - return getName().getAbstractOperation(); - } - - /// Check if this statement is a return statement. - bool isReturn() const; - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - unsigned getNumOperands() const { return numOperands; } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - return getInstOperand(idx).set(value); - } - - // Support non-const operand iteration. - using operand_iterator = OperandIterator; - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - /// Returns an iterator on the underlying Value's (Value *). - llvm::iterator_range getOperands() { - return {operand_begin(), operand_end()}; - } - - // Support const operand iteration. - using const_operand_iterator = - OperandIterator; - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - /// Returns a const iterator on the underlying Value's (Value *). 
- llvm::iterator_range getOperands() const { - return {operand_begin(), operand_end()}; - } - - ArrayRef getInstOperands() const { - return {getTrailingObjects(), numOperands}; - } - MutableArrayRef getInstOperands() { - return {getTrailingObjects(), numOperands}; - } - - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - //===--------------------------------------------------------------------===// - // Results - //===--------------------------------------------------------------------===// - - /// Return true if there are no users of any results of this operation. - bool use_empty() const; - - unsigned getNumResults() const { return numResults; } - - Value *getResult(unsigned idx) { return &getInstResult(idx); } - const Value *getResult(unsigned idx) const { return &getInstResult(idx); } - - // Support non-const result iteration. - using result_iterator = ResultIterator; - result_iterator result_begin(); - result_iterator result_end(); - llvm::iterator_range getResults(); - - // Support const result iteration. - using const_result_iterator = - ResultIterator; - const_result_iterator result_begin() const; - - const_result_iterator result_end() const; - - llvm::iterator_range getResults() const; - - ArrayRef getInstResults() const { - return {getTrailingObjects(), numResults}; - } - - MutableArrayRef getInstResults() { - return {getTrailingObjects(), numResults}; - } - - InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } - - const InstResult &getInstResult(unsigned idx) const { - return getInstResults()[idx]; - } - - // Support result type iteration. - using result_type_iterator = - ResultTypeIterator; - result_type_iterator result_type_begin() const; - - result_type_iterator result_type_end() const; - - llvm::iterator_range getResultTypes() const; - - //===--------------------------------------------------------------------===// - // Attributes - //===--------------------------------------------------------------------===// - - // Operations may optionally carry a list of attributes that associate - // constants to names. Attributes may be dynamically added and removed over - // the lifetime of an operation. - // - // We assume there will be relatively few attributes on a given operation - // (maybe a dozen or so, but not hundreds or thousands) so we use linear - // searches for everything. - - /// Return all of the attributes on this operation. - ArrayRef getAttrs() const; - - /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) const { - for (auto elt : getAttrs()) - if (elt.first == name) - return elt.second; - return nullptr; - } - - Attribute getAttr(StringRef name) const { - for (auto elt : getAttrs()) - if (elt.first.is(name)) - return elt.second; - return nullptr; - } - - template AttrClass getAttrOfType(Identifier name) const { - return getAttr(name).dyn_cast_or_null(); - } - - template AttrClass getAttrOfType(StringRef name) const { - return getAttr(name).dyn_cast_or_null(); - } - - /// If the an attribute exists with the specified name, change it to the new - /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value); - - enum class RemoveResult { Removed, NotFound }; - - /// Remove the attribute with the specified name if it exists. The return - /// value indicates whether the attribute was present or not. 
- RemoveResult removeAttr(Identifier name); - - //===--------------------------------------------------------------------===// - // Terminators - //===--------------------------------------------------------------------===// - - MutableArrayRef getBlockOperands() { - assert(isTerminator() && "Only terminators have a block operands list"); - return {getTrailingObjects(), numSuccs}; - } - ArrayRef getBlockOperands() const { - return const_cast(this)->getBlockOperands(); - } - - llvm::iterator_range - getSuccessorOperands(unsigned index) const; - llvm::iterator_range getSuccessorOperands(unsigned index); - - unsigned getNumSuccessors() const { return numSuccs; } - unsigned getNumSuccessorOperands(unsigned index) const { - assert(isTerminator() && "Only terminators have successors"); - assert(index < getNumSuccessors()); - return getTrailingObjects()[index]; - } - - Block *getSuccessor(unsigned index) { - assert(index < getNumSuccessors()); - return getBlockOperands()[index].get(); - } - const Block *getSuccessor(unsigned index) const { - return const_cast(this)->getSuccessor(index); - } - void setSuccessor(Block *block, unsigned index); - - /// Erase a specific operand from the operand list of the successor at - /// 'index'. - void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { - assert(succIndex < getNumSuccessors()); - assert(opIndex < getNumSuccessorOperands(succIndex)); - eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex); - --getTrailingObjects()[succIndex]; - } - - /// Get the index of the first operand of the successor at the provided - /// index. - unsigned getSuccessorOperandIndex(unsigned index) const { - assert(isTerminator() && "Only terminators have successors."); - assert(index < getNumSuccessors()); - - // Count the number of operands for each of the successors after, and - // including, the one at 'index'. This is based upon the assumption that all - // non successor operands are placed at the beginning of the operand list. - auto *successorOpCountBegin = getTrailingObjects(); - unsigned postSuccessorOpCount = - std::accumulate(successorOpCountBegin + index, - successorOpCountBegin + getNumSuccessors(), 0); - return getNumOperands() - postSuccessorOpCount; - } - - //===--------------------------------------------------------------------===// - // Accessors for various properties of operations - //===--------------------------------------------------------------------===// - - /// Returns whether the operation is commutative. - bool isCommutative() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::Commutative); - return false; - } - - /// Returns whether the operation has side-effects. - bool hasNoSideEffect() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::NoSideEffect); - return false; - } - - /// Returns whether the operation is a terminator. - bool isTerminator() const { - if (auto *absOp = getAbstractOperation()) - return absOp->hasProperty(OperationProperty::Terminator); - return false; - } - - /// Attempt to constant fold this operation with the specified constant - /// operand values - the elements in "operands" will correspond directly to - /// the operands of the operation, but may be null if non-constant. If - /// constant folding is successful, this returns false and fills in the - /// `results` vector. If not, this returns true and `results` is unspecified. 
- bool constantFold(ArrayRef operands, - SmallVectorImpl &results) const; - - //===--------------------------------------------------------------------===// - // Conversions to declared operations like DimOp - //===--------------------------------------------------------------------===// - - // Return a null OpPointer for the specified type. - template static OpPointer getNull() { - return OpPointer(OpClass(nullptr)); - } - - /// The dyn_cast methods perform a dynamic cast from an OperationInst (like - /// Instruction and OperationInst) to a typed Op like DimOp. This returns - /// a null OpPointer on failure. - template OpPointer dyn_cast() { - if (isa()) { - return cast(); - } else { - return OpPointer(OpClass(nullptr)); - } - } - - /// The dyn_cast methods perform a dynamic cast from an OperationInst (like - /// Instruction and OperationInst) to a typed Op like DimOp. This returns - /// a null ConstOpPointer on failure. - template ConstOpPointer dyn_cast() const { - if (isa()) { - return cast(); - } else { - return ConstOpPointer(OpClass(nullptr)); - } - } - - /// The cast methods perform a cast from an OperationInst (like - /// Instruction and OperationInst) to a typed Op like DimOp. This aborts - /// if the parameter to the template isn't an instance of the template type - /// argument. - template OpPointer cast() { - assert(isa() && "cast() argument of incompatible type!"); - return OpPointer(OpClass(this)); - } - - /// The cast methods perform a cast from an OperationInst (like - /// Instruction and OperationInst) to a typed Op like DimOp. This aborts - /// if the parameter to the template isn't an instance of the template type - /// argument. - template ConstOpPointer cast() const { - assert(isa() && "cast() argument of incompatible type!"); - return ConstOpPointer(OpClass(this)); - } - - /// The is methods return true if the operation is a typed op (like DimOp) of - /// of the given class. - template bool isa() const { - return OpClass::isClassFor(this); - } - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - /// Emit an error with the op name prefixed, like "'dim' op " which is - /// convenient for verifiers. This function always returns true. - bool emitOpError(const Twine &message) const; - - void destroy(); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::OperationInst; - } - -private: - unsigned numOperands; - const unsigned numResults, numSuccs; - - /// This holds the name of the operation. - OperationName name; - - /// This holds general named attributes for the operation. - AttributeListStorage *attrs; - - OperationInst(Location location, OperationName name, unsigned numOperands, - unsigned numResults, unsigned numSuccessors, - ArrayRef attributes, MLIRContext *context); - ~OperationInst(); - - /// Erase the operand at 'index'. - void eraseOperand(unsigned index); - - // This stuff is used by the TrailingObjects template. 
- friend llvm::TrailingObjects; - size_t numTrailingObjects(OverloadToken) const { - return numOperands; - } - size_t numTrailingObjects(OverloadToken) const { - return numResults; - } - size_t numTrailingObjects(OverloadToken) const { - return numSuccs; - } - size_t numTrailingObjects(OverloadToken) const { return numSuccs; } -}; - -/// This template implements the result iterators for the OperationInst class -/// in terms of getResult(idx). -template -class ResultIterator final - : public IndexedAccessorIterator, - ObjectType, ElementType> { -public: - /// Initializes the result iterator to the specified index. - ResultIterator(ObjectType *object, unsigned index) - : IndexedAccessorIterator, - ObjectType, ElementType>(object, index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator ResultIterator() const { - return ResultIterator(this->object, - this->index); - } - - ElementType *operator*() const { - return this->object->getResult(this->index); - } -}; - -/// This template implements the result type iterators for the OperationInst -/// class in terms of getResult(idx)->getType(). -template -class ResultTypeIterator final - : public IndexedAccessorIterator< - ResultTypeIterator, ObjectType, - ElementType> { -public: - /// Initializes the result type iterator to the specified index. - ResultTypeIterator(ObjectType *object, unsigned index) - : IndexedAccessorIterator, - ObjectType, ElementType>(object, index) {} - - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator ResultTypeIterator() const { - return ResultTypeIterator(this->object, - this->index); - } - - Type operator*() const { - return this->object->getResult(this->index)->getType(); - } -}; - -// Implement the inline result iterator methods. -inline auto OperationInst::result_begin() -> result_iterator { - return result_iterator(this, 0); -} - -inline auto OperationInst::result_end() -> result_iterator { - return result_iterator(this, getNumResults()); -} - -inline auto OperationInst::getResults() - -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto OperationInst::result_begin() const -> const_result_iterator { - return const_result_iterator(this, 0); -} - -inline auto OperationInst::result_end() const -> const_result_iterator { - return const_result_iterator(this, getNumResults()); -} - -inline auto OperationInst::getResults() const - -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto OperationInst::result_type_begin() const -> result_type_iterator { - return result_type_iterator(this, 0); -} - -inline auto OperationInst::result_type_end() const -> result_type_iterator { - return result_type_iterator(this, getNumResults()); -} - -inline auto OperationInst::getResultTypes() const - -> llvm::iterator_range { - return {result_type_begin(), result_type_end()}; -} - -/// For statement represents an affine loop nest. -class ForStmt : public Statement, public Value { -public: - static ForStmt *create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step); - - ~ForStmt() { - // Explicitly erase statements instead of relying of 'Block' destructor - // since child statements need to be destroyed before the Value that this - // for stmt represents is destroyed. Affine maps are immortal objects and - // don't need to be deleted. - getBody()->clear(); - } - - /// Resolve base class ambiguity. 
- using Statement::getFunction; - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. - using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - /// Get the body of the ForStmt. - Block *getBody() { return &body.front(); } - - /// Get the body of the ForStmt. - const Block *getBody() const { return &body.front(); } - - //===--------------------------------------------------------------------===// - // Bounds and step - //===--------------------------------------------------------------------===// - - /// Returns information about the lower bound as a single object. - const AffineBound getLowerBound() const; - - /// Returns information about the upper bound as a single object. - const AffineBound getUpperBound() const; - - /// Returns loop step. - int64_t getStep() const { return step; } - - /// Returns affine map for the lower bound. - AffineMap getLowerBoundMap() const { return lbMap; } - /// Returns affine map for the upper bound. The upper bound is exclusive. - AffineMap getUpperBoundMap() const { return ubMap; } - - /// Set lower bound. - void setLowerBound(ArrayRef operands, AffineMap map); - /// Set upper bound. - void setUpperBound(ArrayRef operands, AffineMap map); - - /// Set the lower bound map without changing operands. - void setLowerBoundMap(AffineMap map); - - /// Set the upper bound map without changing operands. - void setUpperBoundMap(AffineMap map); - - /// Set loop step. - void setStep(int64_t step) { - assert(step > 0 && "step has to be a positive integer constant"); - this->step = step; - } - - /// Returns true if the lower bound is constant. - bool hasConstantLowerBound() const; - /// Returns true if the upper bound is constant. - bool hasConstantUpperBound() const; - /// Returns true if both bounds are constant. - bool hasConstantBounds() const { - return hasConstantLowerBound() && hasConstantUpperBound(); - } - /// Returns the value of the constant lower bound. - /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound() const; - /// Returns the value of the constant upper bound. The upper bound is - /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound() const; - /// Sets the lower bound to the given constant value. - void setConstantLowerBound(int64_t value); - /// Sets the upper bound to the given constant value. - void setConstantUpperBound(int64_t value); - - /// Returns true if both the lower and upper bound have the same operand lists - /// (same operands in the same order). 
- bool matchingBoundOperandList() const; - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - unsigned getNumOperands() const { return operands.size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { return operands; } - MutableArrayRef getInstOperands() { return operands; } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - // TODO: provide iterators for the lower and upper bound operands - // if the current access via getLowerBound(), getUpperBound() is too slow. - - /// Returns operands for the lower bound map. - operand_range getLowerBoundOperands(); - const_operand_range getLowerBoundOperands() const; - - /// Returns operands for the upper bound map. - operand_range getUpperBoundOperands(); - const_operand_range getUpperBoundOperands() const; - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - /// Return the context this operation is associated with. - MLIRContext *getContext() const { return getType().getContext(); } - - using Statement::dump; - using Statement::print; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::ForStmt; - } - - // For statement represents implicitly represents induction variable by - // inheriting from Value class. Whenever you need to refer to the loop - // induction variable, just use the for statement itself. - static bool classof(const Value *value) { - return value->getKind() == Value::Kind::ForStmt; - } - -private: - // The Block for the body. - BlockList body; - - // Affine map for the lower bound. - AffineMap lbMap; - // Affine map for the upper bound. The upper bound is exclusive. - AffineMap ubMap; - // Positive constant step. Since index is stored as an int64_t, we restrict - // step to the set of positive integers that int64_t can represent. - int64_t step; - // Operands for the lower and upper bounds, with the former followed by the - // latter. Dimensional operands are followed by symbolic operands for each - // bound. - std::vector operands; - - explicit ForStmt(Location location, unsigned numOperands, AffineMap lbMap, - AffineMap ubMap, int64_t step); -}; - -/// AffineBound represents a lower or upper bound in the for statement. -/// This class does not own the underlying operands. Instead, it refers -/// to the operands stored in the ForStmt. Its life span should not exceed -/// that of the for statement it refers to. 
-class AffineBound { -public: - const ForStmt *getForStmt() const { return &stmt; } - AffineMap getMap() const { return map; } - - unsigned getNumOperands() const { return opEnd - opStart; } - const Value *getOperand(unsigned idx) const { - return stmt.getOperand(opStart + idx); - } - const InstOperand &getInstOperand(unsigned idx) const { - return stmt.getInstOperand(opStart + idx); - } - - using operand_iterator = ForStmt::operand_iterator; - using operand_range = ForStmt::operand_range; - - operand_iterator operand_begin() const { - // These are iterators over Value *. Not casting away const'ness would - // require the caller to use const Value *. - return operand_iterator(const_cast(&stmt), opStart); - } - operand_iterator operand_end() const { - return operand_iterator(const_cast(&stmt), opEnd); - } - - /// Returns an iterator on the underlying Value's (Value *). - operand_range getOperands() const { return {operand_begin(), operand_end()}; } - ArrayRef getInstOperands() const { - auto ops = stmt.getInstOperands(); - return ArrayRef(ops.begin() + opStart, ops.begin() + opEnd); - } - -private: - // 'for' statement that contains this bound. - const ForStmt &stmt; - // Start and end positions of this affine bound operands in the list of - // the containing 'for' statement operands. - unsigned opStart, opEnd; - // Affine map for this bound. - AffineMap map; - - AffineBound(const ForStmt &stmt, unsigned opStart, unsigned opEnd, - AffineMap map) - : stmt(stmt), opStart(opStart), opEnd(opEnd), map(map) {} - - friend class ForStmt; -}; - -/// If statement restricts execution to a subset of the loop iteration space. -class IfStmt : public Statement { -public: - static IfStmt *create(Location location, ArrayRef operands, - IntegerSet set); - ~IfStmt(); - - //===--------------------------------------------------------------------===// - // Then, else, condition. - //===--------------------------------------------------------------------===// - - Block *getThen() { return &thenClause.front(); } - const Block *getThen() const { return &thenClause.front(); } - Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const Block *getElse() const { - return elseClause ? &elseClause->front() : nullptr; - } - bool hasElse() const { return elseClause != nullptr; } - - Block *createElse() { - assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new BlockList(this); - elseClause->push_back(new Block()); - return &elseClause->front(); - } - - const AffineCondition getCondition() const; - - IntegerSet getIntegerSet() const { return set; } - void setIntegerSet(IntegerSet newSet) { - assert(newSet.getNumOperands() == operands.size()); - set = newSet; - } - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. 
- using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - unsigned getNumOperands() const { return operands.size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { return operands; } - MutableArrayRef getInstOperands() { return operands; } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - MLIRContext *getContext() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::IfStmt; - } - -private: - // it is always present. - BlockList thenClause; - // 'else' clause of the if statement. 'nullptr' if there is no else clause. - BlockList *elseClause; - - // The integer set capturing the conditional guard. - IntegerSet set; - - // Condition operands. - std::vector operands; - - explicit IfStmt(Location location, unsigned numOperands, IntegerSet set); -}; - -/// AffineCondition represents a condition of the 'if' statement. -/// Its life span should not exceed that of the objects it refers to. -/// AffineCondition does not provide its own methods for iterating over -/// the operands since the iterators of the if statement accomplish -/// the same purpose. -/// -/// AffineCondition is trivially copyable, so it should be passed by value. -class AffineCondition { -public: - const IfStmt *getIfStmt() const { return &stmt; } - IntegerSet getIntegerSet() const { return set; } - -private: - // 'if' statement that contains this affine condition. - const IfStmt &stmt; - // Integer set for this affine condition. - IntegerSet set; - - AffineCondition(const IfStmt &stmt, IntegerSet set) : stmt(stmt), set(set) {} - - friend class IfStmt; -}; -} // end namespace mlir - -#endif // MLIR_IR_STATEMENTS_H diff --git a/mlir/include/mlir/IR/StmtVisitor.h b/mlir/include/mlir/IR/StmtVisitor.h deleted file mode 100644 index 570036a0d99..00000000000 --- a/mlir/include/mlir/IR/StmtVisitor.h +++ /dev/null @@ -1,230 +0,0 @@ -//===- StmtVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the base classes for Function's statement visitors and -// walkers. A visit is a O(1) operation that visits just the node in question. A -// walk visits the node it's called on as well as the node's descendants. -// -// Statement visitors/walkers are used when you want to perform different -// actions for different kinds of statements without having to use lots of casts -// and a big switch statement. -// -// To define your own visitor/walker, inherit from these classes, specifying -// your new type for the 'SubClass' template parameter, and "override" visitXXX -// functions in your class. This class is defined in terms of statically -// resolved overloading, not virtual functions. -// -// For example, here is a walker that counts the number of for loops in an -// Function. -// -// /// Declare the class. Note that we derive from StmtWalker instantiated -// /// with _our new subclasses_ type. -// struct LoopCounter : public StmtWalker { -// unsigned numLoops; -// LoopCounter() : numLoops(0) {} -// void visitForStmt(ForStmt &fs) { ++numLoops; } -// }; -// -// And this class would be used like this: -// LoopCounter lc; -// lc.walk(function); -// numLoops = lc.numLoops; -// -// There are 'visit' methods for OperationInst, ForStmt, IfStmt, and -// Function, which recursively process all contained statements. -// -// Note that if you don't implement visitXXX for some statement type, -// the visitXXX method for Statement superclass will be invoked. -// -// The optional second template argument specifies the type that statement -// visitation functions should return. If you specify this, you *MUST* provide -// an implementation of every visit<#Statement>(StmtType *). -// -// Note that these classes are specifically designed as a template to avoid -// virtual function call overhead. Defining and using a StmtVisitor is just -// as efficient as having your own switch statement over the statement -// opcode. - -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_STMTVISITOR_H -#define MLIR_IR_STMTVISITOR_H - -#include "mlir/IR/Function.h" -#include "mlir/IR/Statements.h" - -namespace mlir { - -/// Base class for statement visitors. -template class StmtVisitor { - //===--------------------------------------------------------------------===// - // Interface code - This is the public interface of the StmtVisitor that you - // use to visit statements. - -public: - // Function to visit a statement. - RetTy visit(Statement *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - - switch (s->getKind()) { - case Statement::Kind::For: - return static_cast(this)->visitForStmt(cast(s)); - case Statement::Kind::If: - return static_cast(this)->visitIfStmt(cast(s)); - case Statement::Kind::OperationInst: - return static_cast(this)->visitOperationInst( - cast(s)); - } - } - - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular statement type. - // The default behavior is to generalize the statement type to its subtype - // and try visiting the subtype. 
All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. - // - - // When visiting a for stmt, if stmt, or an operation stmt directly, these - // methods get called to indicate when transitioning into a new unit. - void visitForStmt(ForStmt *forStmt) {} - void visitIfStmt(IfStmt *ifStmt) {} - void visitOperationInst(OperationInst *opStmt) {} -}; - -/// Base class for statement walkers. A walker can traverse depth first in -/// pre-order or post order. The walk methods without a suffix do a pre-order -/// traversal while those that traverse in post order have a PostOrder suffix. -template class StmtWalker { - //===--------------------------------------------------------------------===// - // Interface code - This is the public interface of the StmtWalker used to - // walk statements. - -public: - // Generic walk method - allow walk to all statements in a range. - template void walk(Iterator Start, Iterator End) { - while (Start != End) { - walk(&(*Start++)); - } - } - template void walkPostOrder(Iterator Start, Iterator End) { - while (Start != End) { - walkPostOrder(&(*Start++)); - } - } - - // Define walkers for Function and all Function statement kinds. - void walk(Function *f) { - static_cast(this)->visitMLFunction(f); - static_cast(this)->walk(f->getBody()->begin(), - f->getBody()->end()); - } - - void walkPostOrder(Function *f) { - static_cast(this)->walkPostOrder(f->getBody()->begin(), - f->getBody()->end()); - static_cast(this)->visitMLFunction(f); - } - - RetTy walkOpStmt(OperationInst *opStmt) { - return static_cast(this)->visitOperationInst(opStmt); - } - - void walkForStmt(ForStmt *forStmt) { - static_cast(this)->visitForStmt(forStmt); - auto *body = forStmt->getBody(); - static_cast(this)->walk(body->begin(), body->end()); - } - - void walkForStmtPostOrder(ForStmt *forStmt) { - auto *body = forStmt->getBody(); - static_cast(this)->walkPostOrder(body->begin(), body->end()); - static_cast(this)->visitForStmt(forStmt); - } - - void walkIfStmt(IfStmt *ifStmt) { - static_cast(this)->visitIfStmt(ifStmt); - static_cast(this)->walk(ifStmt->getThen()->begin(), - ifStmt->getThen()->end()); - if (ifStmt->hasElse()) - static_cast(this)->walk(ifStmt->getElse()->begin(), - ifStmt->getElse()->end()); - } - - void walkIfStmtPostOrder(IfStmt *ifStmt) { - static_cast(this)->walkPostOrder(ifStmt->getThen()->begin(), - ifStmt->getThen()->end()); - if (ifStmt->hasElse()) - static_cast(this)->walkPostOrder(ifStmt->getElse()->begin(), - ifStmt->getElse()->end()); - static_cast(this)->visitIfStmt(ifStmt); - } - - // Function to walk a statement. - RetTy walk(Statement *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - - switch (s->getKind()) { - case Statement::Kind::For: - return static_cast(this)->walkForStmt(cast(s)); - case Statement::Kind::If: - return static_cast(this)->walkIfStmt(cast(s)); - case Statement::Kind::OperationInst: - return static_cast(this)->walkOpStmt(cast(s)); - } - } - - // Function to walk a statement in post order DFS. 
- RetTy walkPostOrder(Statement *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - - switch (s->getKind()) { - case Statement::Kind::For: - return static_cast(this)->walkForStmtPostOrder( - cast(s)); - case Statement::Kind::If: - return static_cast(this)->walkIfStmtPostOrder( - cast(s)); - case Statement::Kind::OperationInst: - return static_cast(this)->walkOpStmt(cast(s)); - } - } - - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular statement type. - // The default behavior is to generalize the statement type to its subtype - // and try visiting the subtype. All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. - - // When visiting a specific stmt directly during a walk, these methods get - // called. These are typically O(1) complexity and shouldn't be recursively - // processing their descendants in some way. When using RetTy, all of these - // need to be overridden. - void visitMLFunction(Function *f) {} - void visitForStmt(ForStmt *forStmt) {} - void visitIfStmt(IfStmt *ifStmt) {} - void visitOperationInst(OperationInst *opStmt) {} -}; - -} // end namespace mlir - -#endif // MLIR_IR_STMTVISITOR_H diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 4b6ad287a87..17f7616b73d 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -72,16 +72,16 @@ private: }; /// Subclasses of IROperandOwner can be the owner of an IROperand. In practice -/// this is the common base between Instruction and Statement. +/// this is the common base between Instruction and Instruction. class IROperandOwner { public: enum class Kind { OperationInst, - ForStmt, - IfStmt, + ForInst, + IfInst, /// These enums define ranges used for classof implementations. - STMT_LAST = IfStmt, + INST_LAST = IfInst, }; Kind getKind() const { return locationAndKind.getInt(); } @@ -106,7 +106,7 @@ private: }; /// A reference to a value, suitable for use as an operand of an instruction, -/// statement, etc. +/// instruction, etc. class IROperand { public: IROperand(IROperandOwner *owner) : owner(owner) {} @@ -201,7 +201,7 @@ private: }; /// A reference to a value, suitable for use as an operand of an instruction, -/// statement, etc. IRValueTy is the root type to use for values this tracks, +/// instruction, etc. IRValueTy is the root type to use for values this tracks, /// and SSAUserTy is the type that will contain operands. template class IROperandImpl : public IROperand { diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 2213fe79852..47dd43f193f 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -30,12 +30,11 @@ namespace mlir { class Block; class Function; class OperationInst; -class Statement; +class Instruction; class Value; -using Instruction = Statement; /// Operands contain a Value. -using InstOperand = IROperandImpl; +using InstOperand = IROperandImpl; /// This is the common base class for all SSA values in the MLIR system, /// representing a computable value that has a type and a set of users. 
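The walker interface deleted above lives on under its new name in mlir/IR/InstVisitor.h, and several passes touched later in this patch (MemRefBoundCheck, MemRefDependenceCheck, PrintOpStatsPass) derive from it. As a rough sketch of how the renamed CRTP interface is meant to be used; InstWalker, ForInst and visitForInst are the new spellings implied by this renaming, and the counter class itself is illustrative, echoing the LoopCounter example from the removed header comment:

  #include "mlir/IR/Function.h"
  #include "mlir/IR/InstVisitor.h"

  using namespace mlir;

  namespace {
  // Counts 'for' instructions and operation instructions in a Function.
  struct InstCounter : public InstWalker<InstCounter> {
    unsigned numLoops = 0;
    unsigned numOps = 0;
    void visitForInst(ForInst *forInst) { ++numLoops; }
    void visitOperationInst(OperationInst *opInst) { ++numOps; }
  };
  } // end anonymous namespace

  // Typical use, mirroring the removed example:
  //   InstCounter counter;
  //   counter.walk(f);            // pre-order walk over Function 'f'
  //   unsigned loops = counter.numLoops;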
@@ -46,7 +45,7 @@ public: enum class Kind { BlockArgument, // block argument InstResult, // operation instruction result - ForStmt, // 'for' statement induction variable + ForInst, // 'for' instruction induction variable }; ~Value() {} @@ -86,7 +85,7 @@ public: return const_cast(this)->getDefiningInst(); } - using use_iterator = ValueUseIterator; + using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; inline use_iterator use_begin() const; diff --git a/mlir/include/mlir/Support/Functional.h b/mlir/include/mlir/Support/Functional.h index e1b1ee5ce58..19baeccbf43 100644 --- a/mlir/include/mlir/Support/Functional.h +++ b/mlir/include/mlir/Support/Functional.h @@ -81,7 +81,7 @@ void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) { /// Unwraps a pointer type to another type (possibly the same). /// Used in particular to allow easier compositions of -/// llvm::iterator_range types. +/// llvm::iterator_range types. template inline std::function makePtrDynCaster() { return [](T *val) { return llvm::dyn_cast(val); }; diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 2694433d5a0..e0cf3039f07 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -29,7 +29,7 @@ namespace mlir { class AffineMap; -class ForStmt; +class ForInst; class Function; class FuncBuilder; @@ -42,53 +42,53 @@ struct LLVM_NODISCARD UtilResult { operator bool() const { return value == Failure; } }; -/// Unrolls this for statement completely if the trip count is known to be +/// Unrolls this for instruction completely if the trip count is known to be /// constant. Returns false otherwise. -bool loopUnrollFull(ForStmt *forStmt); -/// Unrolls this for statement by the specified unroll factor. Returns false if -/// the loop cannot be unrolled either due to restrictions or due to invalid +bool loopUnrollFull(ForInst *forInst); +/// Unrolls this for instruction by the specified unroll factor. Returns false +/// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. -bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor); +bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor); /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. -bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor); +bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor); /// Unrolls and jams this loop by the specified factor. Returns true if the loop /// is successfully unroll-jammed. -bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor); +bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor); /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant), whichever is lower. -bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor); +bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor); -/// Promotes the loop body of a ForStmt to its containing block if the ForStmt +/// Promotes the loop body of a ForInst to its containing block if the ForInst /// was known to have a single iteration. Returns false otherwise. -bool promoteIfSingleIteration(ForStmt *forStmt); +bool promoteIfSingleIteration(ForInst *forInst); -/// Promotes all single iteration ForStmt's in the Function, i.e., moves +/// Promotes all single iteration ForInst's in the Function, i.e., moves /// their body into the containing Block. 
void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop /// with the specified unroll factor. -AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt, +AffineMap getCleanupLoopLowerBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder); /// Returns the upper bound of an unrolled loop when unrolling with /// the specified trip count, stride, and unroll factor. -AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt, +AffineMap getUnrolledLoopUpperBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder); -/// Skew the statements in the body of a 'for' statement with the specified -/// statement-wise shifts. The shifts are with respect to the original execution -/// order, and are multiplied by the loop 'step' before being applied. -UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, +/// Skew the instructions in the body of a 'for' instruction with the specified +/// instruction-wise shifts. The shifts are with respect to the original +/// execution order, and are multiplied by the loop 'step' before being applied. +UtilResult instBodySkew(ForInst *forInst, ArrayRef shifts, bool unrollPrologueEpilogue = false); /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. -UtilResult tileCodeGen(ArrayRef band, ArrayRef tileSizes); +UtilResult tileCodeGen(ArrayRef band, ArrayRef tileSizes); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index c75fddbb4de..4e34889e077 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -66,7 +66,7 @@ public: /// must override). It will be passed the function-wise state, common to all /// matches, and the state returned by the `match` call, if any. The subclass /// must use `rewriter` to modify the function. - virtual void rewriteOpStmt(OperationInst *op, + virtual void rewriteOpInst(OperationInst *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const = 0; @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation statements (ForStmt and IfStmt) are ignored. +/// Non-operation instructions (ForInst and IfInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. 
Note that, as of the time of writing, worklist-based rewriter did not @@ -144,14 +144,14 @@ PassResult MLPatternLoweringPass::runOnMLFunction(Function *f) { MLFuncLoweringRewriter rewriter(&builder); llvm::SmallVector ops; - f->walk([&ops](OperationInst *stmt) { ops.push_back(stmt); }); + f->walk([&ops](OperationInst *inst) { ops.push_back(inst); }); - for (OperationInst *stmt : ops) { + for (OperationInst *inst : ops) { for (const auto &pattern : patterns) { - rewriter.getBuilder()->setInsertionPoint(stmt); - auto matchResult = pattern->match(stmt); + rewriter.getBuilder()->setInsertionPoint(inst); + auto matchResult = pattern->match(inst); if (matchResult) { - pattern->rewriteOpStmt(stmt, funcWiseState.get(), + pattern->rewriteOpInst(inst, funcWiseState.get(), std::move(*matchResult), &rewriter); break; } diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index fd376fbbb97..acf07c5143f 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -27,7 +27,7 @@ namespace mlir { -class ForStmt; +class ForInst; class FunctionPass; class ModulePass; @@ -59,7 +59,7 @@ FunctionPass *createMaterializeVectorsPass(); /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). FunctionPass *createLoopUnrollPass( int unrollFactor = -1, int unrollFull = -1, - const std::function &getUnrollFactor = nullptr); + const std::function &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index f33f774bb22..c63eb0349a7 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -32,7 +32,7 @@ namespace mlir { -class ForStmt; +class ForInst; class FuncBuilder; class Location; class Module; @@ -45,7 +45,7 @@ class Function; /// indices. Additional indices are added at the start. The new memref could be /// of a different shape or rank. 'extraOperands' is an optional argument that /// corresponds to additional operands (inputs) for indexRemap at the beginning -/// of its input list. An additional optional argument 'domStmtFilter' restricts +/// of its input list. An additional optional argument 'domInstFilter' restricts /// the replacement to only those operations that are dominated by the former. /// Returns true on success and false if the replacement is not possible /// (whenever a memref is used as an operand in a non-deferencing scenario). See @@ -56,7 +56,7 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap::Null(), ArrayRef extraOperands = {}, - const Statement *domStmtFilter = nullptr); + const Instruction *domInstFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition @@ -71,10 +71,10 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc, ArrayRef affineApplyOps, SmallVectorImpl *results); -/// Given an operation statement, inserts a new single affine apply operation, -/// that is exclusively used by this operation statement, and that provides all -/// operands that are results of an affine_apply as a function of loop iterators -/// and program parameters and whose results are. 
+/// Given an operation instruction, inserts a new single affine apply operation, +/// that is exclusively used by this operation instruction, and that provides +/// all operands that are results of an affine_apply as a function of loop +/// iterators and program parameters and whose results are. /// /// Before /// @@ -96,8 +96,8 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// /// Returns nullptr if none of the operands were the result of an affine_apply /// and thus there was no affine computation slice to create. Returns the newly -/// affine_apply operation statement otherwise. -OperationInst *createAffineComputationSlice(OperationInst *opStmt); +/// affine_apply operation instruction otherwise. +OperationInst *createAffineComputationSlice(OperationInst *opInst); /// Forward substitutes results from 'AffineApplyOp' into any users which /// are also AffineApplyOps. @@ -105,9 +105,9 @@ OperationInst *createAffineComputationSlice(OperationInst *opStmt); // TODO(mlir-team): extend this for Value/ CFGFunctions. void forwardSubstitute(OpPointer affineApplyOp); -/// Folds the lower and upper bounds of a 'for' stmt to constants if possible. +/// Folds the lower and upper bounds of a 'for' inst to constants if possible. /// Returns false if the folding happens for at least one bound, true otherwise. -bool constantFoldBounds(ForStmt *forStmt); +bool constantFoldBounds(ForInst *forInst); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index f01735f26e1..8058af06b55 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/MathExtras.h" @@ -498,22 +498,22 @@ void mlir::getReachableAffineApplyOps( while (!worklist.empty()) { State &state = worklist.back(); - auto *opStmt = state.value->getDefiningInst(); + auto *opInst = state.value->getDefiningInst(); // Note: getDefiningInst will return nullptr if the operand is not an - // OperationInst (i.e. ForStmt), which is a terminator for the search. - if (opStmt == nullptr || !opStmt->isa()) { + // OperationInst (i.e. ForInst), which is a terminator for the search. + if (opInst == nullptr || !opInst->isa()) { worklist.pop_back(); continue; } - if (auto affineApplyOp = opStmt->dyn_cast()) { + if (auto affineApplyOp = opInst->dyn_cast()) { if (state.operandIndex == 0) { - // Pre-Visit: Add 'opStmt' to reachable sequence. - affineApplyOps.push_back(opStmt); + // Pre-Visit: Add 'opInst' to reachable sequence. + affineApplyOps.push_back(opInst); } - if (state.operandIndex < opStmt->getNumOperands()) { + if (state.operandIndex < opInst->getNumOperands()) { // Visit: Add next 'affineApplyOp' operand to worklist. // Get next operand to visit at 'operandIndex'. - auto *nextOperand = opStmt->getOperand(state.operandIndex); + auto *nextOperand = opInst->getOperand(state.operandIndex); // Increment 'operandIndex' in 'state'. ++state.operandIndex; // Add 'nextOperand' to worklist. 
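A rough usage sketch for the routine above (not part of the patch; the exact parameter types are assumptions inferred from how the body uses them): starting from a set of operand values, for example the index operands of a load or store, it collects every affine_apply instruction they transitively depend on.

  // 'operands' holds the values whose provenance we want to trace
  // (hypothetical variable, for illustration only).
  SmallVector<OperationInst *, 4> affineApplyOps;
  getReachableAffineApplyOps(operands, affineApplyOps);
  // Each collected affine_apply can then be composed into an
  // AffineValueMap, which is what forwardSubstituteReachableOps does in
  // the next hunk.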
@@ -533,47 +533,47 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) { SmallVector affineApplyOps; getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps); // Compose AffineApplyOps in 'affineApplyOps'. - for (auto *opStmt : affineApplyOps) { - assert(opStmt->isa()); - auto affineApplyOp = opStmt->dyn_cast(); + for (auto *opInst : affineApplyOps) { + assert(opInst->isa()); + auto affineApplyOp = opInst->dyn_cast(); // Forward substitute 'affineApplyOp' into 'valueMap'. valueMap->forwardSubstitute(*affineApplyOp); } } // Builds a system of constraints with dimensional identifiers corresponding to -// the loop IVs of the forStmts appearing in that order. Any symbols founds in +// the loop IVs of the forInsts appearing in that order. Any symbols founds in // the bound operands are added as symbols in the system. Returns false for the // yet unimplemented cases. // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or // stride information in FlatAffineConstraints. (For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef expr, int64_t stride) -bool mlir::getIndexSet(ArrayRef forStmts, +bool mlir::getIndexSet(ArrayRef forInsts, FlatAffineConstraints *domain) { - SmallVector indices(forStmts.begin(), forStmts.end()); + SmallVector indices(forInsts.begin(), forInsts.end()); // Reset while associated Values in 'indices' to the domain. - domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); - for (auto *forStmt : forStmts) { - // Add constraints from forStmt's bounds. - if (!domain->addForStmtDomain(*forStmt)) + domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); + for (auto *forInst : forInsts) { + // Add constraints from forInst's bounds. + if (!domain->addForInstDomain(*forInst)) return false; } return true; } -// Computes the iteration domain for 'opStmt' and populates 'indexSet', which -// encapsulates the constraints involving loops surrounding 'opStmt' and +// Computes the iteration domain for 'opInst' and populates 'indexSet', which +// encapsulates the constraints involving loops surrounding 'opInst' and // potentially involving any Function symbols. The dimensional identifiers in -// 'indexSet' correspond to the loops surounding 'stmt' from outermost to +// 'indexSet' correspond to the loops surounding 'inst' from outermost to // innermost. -// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'. -static bool getStmtIndexSet(const Statement *stmt, +// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'. +static bool getInstIndexSet(const Instruction *inst, FlatAffineConstraints *indexSet) { - // TODO(andydavis) Extend this to gather enclosing IfStmts and consider + // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. 
- SmallVector loops; - getLoopIVs(*stmt, &loops); + SmallVector loops; + getLoopIVs(*inst, &loops); return getIndexSet(loops, indexSet); } @@ -672,7 +672,7 @@ static void buildDimAndSymbolPositionMaps( auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = values.size(); i < e; ++i) { auto *value = values[i]; - if (!isa(values[i])) + if (!isa(values[i])) valuePosMap->addSymbolValue(value); else if (isSrc) valuePosMap->addSrcValue(value); @@ -840,13 +840,13 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, // Add equality constraints for any operands that are defined by constant ops. auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (isa(operands[i])) + if (isa(operands[i])) continue; auto *symbol = operands[i]; assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opStmt = symbol->getDefiningInst()) { - if (auto constOp = opStmt->dyn_cast()) { + if (auto *opInst = symbol->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), constOp->getValue()); } @@ -909,8 +909,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (!isa(srcDomain.getIdValue(i)) || - !isa(dstDomain.getIdValue(i)) || + if (!isa(srcDomain.getIdValue(i)) || + !isa(dstDomain.getIdValue(i)) || srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) break; ++numCommonLoops; @@ -918,26 +918,26 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, return numCommonLoops; } -// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. +// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. static Block *getCommonBlock(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { if (numCommonLoops == 0) { - auto *block = srcAccess.opStmt->getBlock(); + auto *block = srcAccess.opInst->getBlock(); while (block->getContainingInst()) { block = block->getContainingInst()->getBlock(); } return block; } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); - assert(isa(commonForValue)); - return cast(commonForValue)->getBody(); + assert(isa(commonForValue)); + return cast(commonForValue)->getBody(); } -// Returns true if the ancestor operation statement of 'srcAccess' properly -// dominates the ancestor operation statement of 'dstAccess' in the same -// statement block. Returns false otherwise. +// Returns true if the ancestor operation instruction of 'srcAccess' properly +// dominates the ancestor operation instruction of 'dstAccess' in the same +// instruction block. Returns false otherwise. // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals, // the function is named 'srcMayExecuteBeforeDst'. // Note that 'numCommonLoops' is the number of contiguous surrounding outer @@ -946,16 +946,16 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { - // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. + // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. 
auto *commonBlock = getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); // Check the dominance relationship between the respective ancestors of the // src and dst in the Block of the innermost among the common loops. - auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt); - assert(srcStmt != nullptr); - auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt); - assert(dstStmt != nullptr); - return mlir::properlyDominates(*srcStmt, *dstStmt); + auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst); + assert(srcInst != nullptr); + auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst); + assert(dstInst != nullptr); + return mlir::properlyDominates(*srcInst, *dstInst); } // Adds ordering constraints to 'dependenceDomain' based on number of loops @@ -1119,7 +1119,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // until operands of the AffineValueMap are loop IVs or symbols. // *) Build iteration domain constraints for each access. Iteration domain // constraints are pairs of inequality contraints representing the -// upper/lower loop bounds for each ForStmt in the loop nest associated +// upper/lower loop bounds for each ForInst in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map // Values from access functions and iteration domains to their position @@ -1197,7 +1197,7 @@ bool mlir::checkMemrefAccessDependence( if (srcAccess.memref != dstAccess.memref) return false; // Return 'false' if one of these accesses is not a StoreOp. - if (!srcAccess.opStmt->isa() && !dstAccess.opStmt->isa()) + if (!srcAccess.opInst->isa() && !dstAccess.opInst->isa()) return false; // Get composed access function for 'srcAccess'. @@ -1208,19 +1208,19 @@ bool mlir::checkMemrefAccessDependence( AffineValueMap dstAccessMap; dstAccess.getAccessMap(&dstAccessMap); - // Get iteration domain for the 'srcAccess' statement. + // Get iteration domain for the 'srcAccess' instruction. FlatAffineConstraints srcDomain; - if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain)) + if (!getInstIndexSet(srcAccess.opInst, &srcDomain)) return false; - // Get iteration domain for 'dstAccess' statement. + // Get iteration domain for 'dstAccess' instruction. FlatAffineConstraints dstDomain; - if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain)) + if (!getInstIndexSet(dstAccess.opInst, &dstDomain)) return false; // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation - // statement of 'srcAccess' does not properly dominate the ancestor operation - // statement of 'dstAccess' in the same common statement block. + // instruction of 'srcAccess' does not properly dominate the ancestor + // operation instruction of 'dstAccess' in the same common instruction block. 
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); assert(loopDepth <= numCommonLoops + 1); if (loopDepth > numCommonLoops && diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index a45c5ffdf5e..d4b8a05dbf8 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -24,8 +24,8 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Statements.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Debug.h" @@ -1248,22 +1248,22 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { numSymbols = newSymbolCount; } -bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { +bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { unsigned pos; // Pre-condition for this method. - if (!findId(forStmt, &pos)) { + if (!findId(forInst, &pos)) { assert(0 && "Value not found"); return false; } - if (forStmt.getStep() != 1) + if (forInst.getStep() != 1) LLVM_DEBUG(llvm::dbgs() << "Domain conservative: non-unit stride not handled\n"); // Adds a lower or upper bound when the bounds aren't constant. auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = lower ? forStmt.getLowerBoundOperands() - : forStmt.getUpperBoundOperands(); + auto operands = lower ? forInst.getLowerBoundOperands() + : forInst.getUpperBoundOperands(); for (const auto &operand : operands) { unsigned loc; if (!findId(*operand, &loc)) { @@ -1271,8 +1271,8 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { addSymbolId(getNumSymbolIds(), const_cast(operand)); loc = getNumDimIds() + getNumSymbolIds() - 1; // Check if the symbol is a constant. - if (auto *opStmt = operand->getDefiningInst()) { - if (auto constOp = opStmt->dyn_cast()) { + if (auto *opInst = operand->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { setIdToConstant(*operand, constOp->getValue()); } } @@ -1292,7 +1292,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { } auto boundMap = - lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap(); + lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap(); FlatAffineConstraints localVarCst; std::vector> flatExprs; @@ -1322,16 +1322,16 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { return true; }; - if (forStmt.hasConstantLowerBound()) { - addConstantLowerBound(pos, forStmt.getConstantLowerBound()); + if (forInst.hasConstantLowerBound()) { + addConstantLowerBound(pos, forInst.getConstantLowerBound()); } else { // Non-constant lower bound case. if (!addLowerOrUpperBound(/*lower=*/true)) return false; } - if (forStmt.hasConstantUpperBound()) { - addConstantUpperBound(pos, forStmt.getConstantUpperBound() - 1); + if (forInst.hasConstantUpperBound()) { + addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1); return true; } // Non-constant upper bound case. 
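To make the constant-bound path above concrete (illustrative only, not part of the patch): for a unit-stride loop such as 'for %i = 0 to 128', addForInstDomain boils down to

  addConstantLowerBound(pos, 0);    // i >= 0
  addConstantUpperBound(pos, 127);  // i <= 127, since 'for' upper bounds are exclusive

so the identifier at 'pos' receives the closed iteration domain 0 <= i <= 127. Non-constant bounds instead go through the addLowerOrUpperBound lambda shown above, which flattens the bound map and may introduce the bound operands as new symbolic identifiers.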
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 0c8db07dbb4..4ee1b393068 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -21,7 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Dominance.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "llvm/Support/GenericDomTreeConstruction.h" using namespace mlir; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index dd14f38df55..b66b665c563 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -27,7 +27,7 @@ #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -42,27 +42,27 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) { +AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { // upper_bound - lower_bound int64_t loopSpan; - int64_t step = forStmt.getStep(); - auto *context = forStmt.getContext(); + int64_t step = forInst.getStep(); + auto *context = forInst.getContext(); - if (forStmt.hasConstantBounds()) { - int64_t lb = forStmt.getConstantLowerBound(); - int64_t ub = forStmt.getConstantUpperBound(); + if (forInst.hasConstantBounds()) { + int64_t lb = forInst.getConstantLowerBound(); + int64_t ub = forInst.getConstantUpperBound(); loopSpan = ub - lb; } else { - auto lbMap = forStmt.getLowerBoundMap(); - auto ubMap = forStmt.getUpperBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); + auto ubMap = forInst.getUpperBoundMap(); // TODO(bondhugula): handle max/min of multiple expressions. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return nullptr; // TODO(bondhugula): handle bounds with different operands. // Bounds have different operands, unhandled for now. - if (!forStmt.matchingBoundOperandList()) + if (!forInst.matchingBoundOperandList()) return nullptr; // ub_expr - lb_expr @@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) { /// Returns the trip count of the loop if it's a constant, None otherwise. This /// method uses affine expression analysis (in turn using getTripCount) and is /// able to determine constant trip count in non-trivial cases. -llvm::Optional mlir::getConstantTripCount(const ForStmt &forStmt) { - auto tripCountExpr = getTripCountExpr(forStmt); +llvm::Optional mlir::getConstantTripCount(const ForInst &forInst) { + auto tripCountExpr = getTripCountExpr(forInst); if (!tripCountExpr) return None; @@ -103,8 +103,8 @@ llvm::Optional mlir::getConstantTripCount(const ForStmt &forStmt) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. 
-uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { - auto tripCountExpr = getTripCountExpr(forStmt); +uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { + auto tripCountExpr = getTripCountExpr(forInst); if (!tripCountExpr) return 1; @@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isa(iv) && "iv must be a ForStmt"); + assert(isa(iv) && "iv must be a ForInst"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); @@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv, } /// Given: -/// 1. an induction variable `iv` of type ForStmt; +/// 1. an induction variable `iv` of type ForInst; /// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; /// 3. the index of the `fastestVaryingDim` along which to check; /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access @@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa(); } -static bool isVectorTransferReadOrWrite(const Statement &stmt) { - const auto *opStmt = cast(&stmt); - return opStmt->isa() || - opStmt->isa(); +static bool isVectorTransferReadOrWrite(const Instruction &inst) { + const auto *opInst = cast(&inst); + return opInst->isa() || + opInst->isa(); } -using VectorizableStmtFun = - std::function; +using VectorizableInstFun = + std::function; -static bool isVectorizableLoopWithCond(const ForStmt &loop, - VectorizableStmtFun isVectorizableStmt) { +static bool isVectorizableLoopWithCond(const ForInst &loop, + VectorizableInstFun isVectorizableInst) { if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) { return false; } // No vectorization across conditionals for now. auto conditionals = matcher::If(); - auto *forStmt = const_cast(&loop); - auto conditionalsMatched = conditionals.match(forStmt); + auto *forInst = const_cast(&loop); + auto conditionalsMatched = conditionals.match(forInst); if (!conditionalsMatched.empty()) { return false; } auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); - auto vectorTransfersMatched = vectorTransfers.match(forStmt); + auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { return false; } auto loadAndStores = matcher::Op(matcher::isLoadOrStore); - auto loadAndStoresMatched = loadAndStores.match(forStmt); + auto loadAndStoresMatched = loadAndStores.match(forInst); for (auto ls : loadAndStoresMatched) { auto *op = cast(ls.first); auto load = op->dyn_cast(); @@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, if (vector) { return false; } - if (!isVectorizableStmt(loop, *op)) { + if (!isVectorizableInst(loop, *op)) { return false; } } @@ -283,9 +283,9 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - const ForStmt &loop, unsigned fastestVaryingDim) { - VectorizableStmtFun fun( - [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) { + const ForInst &loop, unsigned fastestVaryingDim) { + VectorizableInstFun fun( + [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); return load ? 
isContiguousAccess(loop, *load, fastestVaryingDim) @@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(const ForStmt &loop) { - VectorizableStmtFun fun( +bool mlir::isVectorizableLoop(const ForInst &loop) { + VectorizableInstFun fun( // TODO: implement me - [](const ForStmt &loop, const OperationInst &op) { return true; }); + [](const ForInst &loop, const OperationInst &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } -/// Checks whether SSA dominance would be violated if a for stmt's body -/// statements are shifted by the specified shifts. This method checks if a +/// Checks whether SSA dominance would be violated if a for inst's body +/// instructions are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, +bool mlir::isInstwiseShiftValid(const ForInst &forInst, ArrayRef shifts) { - auto *forBody = forStmt.getBody(); + auto *forBody = forInst.getBody(); assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; - for (const auto &stmt : *forBody) { - // A for or if stmt does not produce any def/results (that are used + for (const auto &inst : *forBody) { + // A for or if inst does not produce any def/results (that are used // outside). - if (const auto *opStmt = dyn_cast(&stmt)) { - for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { - const Value *result = opStmt->getResult(i); + if (const auto *opInst = dyn_cast(&inst)) { + for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { + const Value *result = opInst->getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor statement doesn't lie in the block of forStmt, there - // is no shift to check. - // This is a naive way. If performance becomes an issue, a map can - // be used to store 'shifts' - to look up the shift for a statement in - // constant time. - if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner())) - if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)]) + // If an ancestor instruction doesn't lie in the block of forInst, + // there is no shift to check. This is a naive way. If performance + // becomes an issue, a map can be used to store 'shifts' - to look up + // the shift for a instruction in constant time. + if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)]) return false; } } diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp index 12ce8481516..5bb4548e670 100644 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp @@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage { /// Underlying storage for MLFunctionMatcher. 
struct MLFunctionMatcherStorage { - MLFunctionMatcherStorage(Statement::Kind k, + MLFunctionMatcherStorage(Instruction::Kind k, MutableArrayRef c, - FilterFunctionType filter, Statement *skip) + FilterFunctionType filter, Instruction *skip) : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter), skip(skip) {} - Statement::Kind kind; + Instruction::Kind kind; SmallVector childrenMLFunctionMatchers; FilterFunctionType filter; /// skip is needed so that we can implement match without switching on the - /// type of the Statement. + /// type of the Instruction. /// The idea is that a MLFunctionMatcher first checks if it matches locally /// and then recursively applies its children matchers to its elem->children. - /// Since we want to rely on the StmtWalker impl rather than duplicate its + /// Since we want to rely on the InstWalker impl rather than duplicate its /// the logic, we allow an off-by-one traversal to account for the fact that /// we write: /// - /// void match(Statement *elem) { + /// void match(Instruction *elem) { /// for (auto &c : getChildrenMLFunctionMatchers()) { /// MLFunctionMatcher childMLFunctionMatcher(...); /// ^~~~ Needs off-by-one skip. /// - Statement *skip; + Instruction *skip; }; } // end namespace mlir @@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() { return allocator; } -void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) { +void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) { if (!storage) { storage = allocator()->Allocate(); - new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children)); + new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children)); } else { - storage->matches.push_back(std::make_pair(stmt, children)); + storage->matches.push_back(std::make_pair(inst, children)); } } MLFunctionMatches::iterator MLFunctionMatches::begin() { @@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) { return matches; } -/// Calls walk on `statement`. -MLFunctionMatches MLFunctionMatcher::match(Statement *statement) { +/// Calls walk on `instruction`. +MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) { assert(!matches && "MLFunctionMatcher already matched!"); - this->walkPostOrder(statement); + this->walkPostOrder(instruction); return matches; } @@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() { return depth + 1; } -/// Matches a single statement in the following way: -/// 1. checks the kind of statement against the matcher, if different then +/// Matches a single instruction in the following way: +/// 1. checks the kind of instruction against the matcher, if different then /// there is no match; -/// 2. calls the customizable filter function to refine the single statement +/// 2. calls the customizable filter function to refine the single instruction /// match with extra semantic constraints; /// 3. if all is good, recursivey matches the children patterns; -/// 4. if all children match then the single statement matches too and is +/// 4. if all children match then the single instruction matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. 
-void MLFunctionMatcher::matchOne(Statement *elem) { +void MLFunctionMatcher::matchOne(Instruction *elem) { if (storage->skip == elem) { return; } @@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() { return allocator; } -MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child, +MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k, + MLFunctionMatcher child, FilterFunctionType filter) : storage(allocator()->Allocate()) { // Initialize with placement new. @@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child, } MLFunctionMatcher::MLFunctionMatcher( - Statement::Kind k, MutableArrayRef children, + Instruction::Kind k, MutableArrayRef children, FilterFunctionType filter) : storage(allocator()->Allocate()) { // Initialize with placement new. @@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher( MLFunctionMatcher MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl, - Statement *stmt) { + Instruction *inst) { MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(), tmpl.getFilterFunction()); - res.storage->skip = stmt; + res.storage->skip = inst; return res; } -Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; } +Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; } MutableArrayRef MLFunctionMatcher::getChildrenMLFunctionMatchers() { @@ -200,54 +201,55 @@ namespace mlir { namespace matcher { MLFunctionMatcher Op(FilterFunctionType filter) { - return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter); + return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter); } MLFunctionMatcher If(MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction); + return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction); } MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::If, child, filter); + return MLFunctionMatcher(Instruction::Kind::If, child, filter); } MLFunctionMatcher If(MutableArrayRef children) { - return MLFunctionMatcher(Statement::Kind::If, children, + return MLFunctionMatcher(Instruction::Kind::If, children, defaultFilterFunction); } MLFunctionMatcher If(FilterFunctionType filter, MutableArrayRef children) { - return MLFunctionMatcher(Statement::Kind::If, children, filter); + return MLFunctionMatcher(Instruction::Kind::If, children, filter); } MLFunctionMatcher For(MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction); + return MLFunctionMatcher(Instruction::Kind::For, child, + defaultFilterFunction); } MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::For, child, filter); + return MLFunctionMatcher(Instruction::Kind::For, child, filter); } MLFunctionMatcher For(MutableArrayRef children) { - return MLFunctionMatcher(Statement::Kind::For, children, + return MLFunctionMatcher(Instruction::Kind::For, children, defaultFilterFunction); } MLFunctionMatcher For(FilterFunctionType filter, MutableArrayRef children) { - return MLFunctionMatcher(Statement::Kind::For, children, filter); + return MLFunctionMatcher(Instruction::Kind::For, children, filter); } // TODO(ntv): parallel annotation on loops. 
-bool isParallelLoop(const Statement &stmt) { - const auto *loop = cast(&stmt); +bool isParallelLoop(const Instruction &inst) { + const auto *loop = cast(&inst); return (void *)loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. -bool isReductionLoop(const Statement &stmt) { - const auto *loop = cast(&stmt); +bool isReductionLoop(const Instruction &inst) { + const auto *loop = cast(&inst); return (void *)loop || true; // loop->isReduction(); }; -bool isLoadOrStore(const Statement &stmt) { - const auto *opStmt = dyn_cast(&stmt); - return opStmt && (opStmt->isa() || opStmt->isa()); +bool isLoadOrStore(const Instruction &inst) { + const auto *opInst = dyn_cast(&inst); + return opInst && (opInst->isa() || opInst->isa()); }; } // end namespace matcher diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index ad935faf05d..e8b668892b8 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -26,7 +26,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,14 +38,14 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass, StmtWalker { +struct MemRefBoundCheck : public FunctionPass, InstWalker { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} PassResult runOnMLFunction(Function *f) override; // Not applicable to CFG functions. PassResult runOnCFGFunction(Function *f) override { return success(); } - void visitOperationInst(OperationInst *opStmt); + void visitOperationInst(OperationInst *opInst); static char passID; }; @@ -58,10 +58,10 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) { - if (auto loadOp = opStmt->dyn_cast()) { +void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) { + if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opStmt->dyn_cast()) { + } else if (auto storeOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index bb668f78624..8391f15b6d3 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -39,7 +39,7 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. struct MemRefDependenceCheck : public FunctionPass, - StmtWalker { + InstWalker { SmallVector loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} @@ -48,9 +48,9 @@ struct MemRefDependenceCheck : public FunctionPass, // Not applicable to CFG functions. 
PassResult runOnCFGFunction(Function *f) override { return success(); } - void visitOperationInst(OperationInst *opStmt) { - if (opStmt->isa() || opStmt->isa()) { - loadsAndStores.push_back(opStmt); + void visitOperationInst(OperationInst *opInst) { + if (opInst->isa() || opInst->isa()) { + loadsAndStores.push_back(opInst); } } static char passID; @@ -74,17 +74,17 @@ static void addMemRefAccessIndices( } } -// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'. -static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt, +// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'. +static void getMemRefAccess(const OperationInst *loadOrStoreOpInst, MemRefAccess *access) { - access->opStmt = loadOrStoreOpStmt; - if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { + access->opInst = loadOrStoreOpInst; + if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { access->memref = loadOp->getMemRef(); addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(), access); } else { - assert(loadOrStoreOpStmt->isa()); - auto storeOp = loadOrStoreOpStmt->dyn_cast(); + assert(loadOrStoreOpInst->isa()); + auto storeOp = loadOrStoreOpInst->dyn_cast(); access->memref = storeOp->getMemRef(); addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(), access); @@ -93,8 +93,8 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt, // Returns the number of surrounding loops common to 'loopsA' and 'loopsB', // where each lists loops from outer-most to inner-most in loop nest. -static unsigned getNumCommonSurroundingLoops(ArrayRef loopsA, - ArrayRef loopsB) { +static unsigned getNumCommonSurroundingLoops(ArrayRef loopsA, + ArrayRef loopsB) { unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { @@ -133,18 +133,18 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, // the source access. static void checkDependences(ArrayRef loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { - auto *srcOpStmt = loadsAndStores[i]; + auto *srcOpInst = loadsAndStores[i]; MemRefAccess srcAccess; - getMemRefAccess(srcOpStmt, &srcAccess); - SmallVector srcLoops; - getLoopIVs(*srcOpStmt, &srcLoops); + getMemRefAccess(srcOpInst, &srcAccess); + SmallVector srcLoops; + getLoopIVs(*srcOpInst, &srcLoops); for (unsigned j = 0; j < e; ++j) { - auto *dstOpStmt = loadsAndStores[j]; + auto *dstOpInst = loadsAndStores[j]; MemRefAccess dstAccess; - getMemRefAccess(dstOpStmt, &dstAccess); + getMemRefAccess(dstOpInst, &dstAccess); - SmallVector dstLoops; - getLoopIVs(*dstOpStmt, &dstLoops); + SmallVector dstLoops; + getLoopIVs(*dstOpInst, &dstLoops); unsigned numCommonLoops = getNumCommonSurroundingLoops(srcLoops, dstLoops); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { @@ -156,7 +156,7 @@ static void checkDependences(ArrayRef loadsAndStores) { // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance // vectors from ([1, 1], [3, 3]) to (1, 3). 
- srcOpStmt->emitNote( + srcOpInst->emitNote( "dependence from " + Twine(i) + " to " + Twine(j) + " at depth " + Twine(d) + " = " + getDirectionVectorStr(ret, numCommonLoops, d, dependenceComponents) diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index f4c509a5132..07edb13d1a3 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -16,9 +16,9 @@ // ============================================================================= #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/raw_ostream.h" @@ -26,7 +26,7 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public FunctionPass, StmtWalker { +struct PrintOpStatsPass : public FunctionPass, InstWalker { explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : FunctionPass(&PrintOpStatsPass::passID), os(os) {} @@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker { // Process ML functions and operation statments in ML functions. PassResult runOnMLFunction(Function *function) override; - void visitOperationInst(OperationInst *stmt); + void visitOperationInst(OperationInst *inst); // Print summary of op stats. void printSummary(); @@ -69,8 +69,8 @@ PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) { return success(); } -void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) { - ++opCount[stmt->getName().getStringRef()]; +void PrintOpStatsPass::visitOperationInst(OperationInst *inst) { + ++opCount[inst->getName().getStringRef()]; } PassResult PrintOpStatsPass::runOnMLFunction(Function *function) { diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 393d7c59de0..a8cec771f0d 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -22,7 +22,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" @@ -38,36 +38,36 @@ using namespace mlir; using llvm::DenseSet; using llvm::SetVector; -void mlir::getForwardSlice(Statement *stmt, - SetVector *forwardSlice, +void mlir::getForwardSlice(Instruction *inst, + SetVector *forwardSlice, TransitiveFilter filter, bool topLevel) { - if (!stmt) { + if (!inst) { return; } // Evaluate whether we should keep this use. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. 
- if (!filter(stmt)) { + if (!filter(inst)) { return; } - if (auto *opStmt = dyn_cast(stmt)) { - assert(opStmt->getNumResults() <= 1 && "NYI: multiple results"); - if (opStmt->getNumResults() > 0) { - for (auto &u : opStmt->getResult(0)->getUses()) { - auto *ownerStmt = u.getOwner(); - if (forwardSlice->count(ownerStmt) == 0) { - getForwardSlice(ownerStmt, forwardSlice, filter, + if (auto *opInst = dyn_cast(inst)) { + assert(opInst->getNumResults() <= 1 && "NYI: multiple results"); + if (opInst->getNumResults() > 0) { + for (auto &u : opInst->getResult(0)->getUses()) { + auto *ownerInst = u.getOwner(); + if (forwardSlice->count(ownerInst) == 0) { + getForwardSlice(ownerInst, forwardSlice, filter, /*topLevel=*/false); } } } - } else if (auto *forStmt = dyn_cast(stmt)) { - for (auto &u : forStmt->getUses()) { - auto *ownerStmt = u.getOwner(); - if (forwardSlice->count(ownerStmt) == 0) { - getForwardSlice(ownerStmt, forwardSlice, filter, + } else if (auto *forInst = dyn_cast(inst)) { + for (auto &u : forInst->getUses()) { + auto *ownerInst = u.getOwner(); + if (forwardSlice->count(ownerInst) == 0) { + getForwardSlice(ownerInst, forwardSlice, filter, /*topLevel=*/false); } } @@ -80,61 +80,61 @@ void mlir::getForwardSlice(Statement *stmt, // std::reverse does not work out of the box on SetVector and I want an // in-place swap based thing (the real std::reverse, not the LLVM adapter). // TODO(clattner): Consider adding an extra method? - std::vector v(forwardSlice->takeVector()); + std::vector v(forwardSlice->takeVector()); forwardSlice->insert(v.rbegin(), v.rend()); } else { - forwardSlice->insert(stmt); + forwardSlice->insert(inst); } } -void mlir::getBackwardSlice(Statement *stmt, - SetVector *backwardSlice, +void mlir::getBackwardSlice(Instruction *inst, + SetVector *backwardSlice, TransitiveFilter filter, bool topLevel) { - if (!stmt) { + if (!inst) { return; } // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. - if (!filter(stmt)) { + if (!filter(inst)) { return; } - for (auto *operand : stmt->getOperands()) { - auto *stmt = operand->getDefiningInst(); - if (backwardSlice->count(stmt) == 0) { - getBackwardSlice(stmt, backwardSlice, filter, + for (auto *operand : inst->getOperands()) { + auto *inst = operand->getDefiningInst(); + if (backwardSlice->count(inst) == 0) { + getBackwardSlice(inst, backwardSlice, filter, /*topLevel=*/false); } } - // Don't insert the top level statement, we just queried on it and don't + // Don't insert the top level instruction, we just queried on it and don't // want it in the results. if (!topLevel) { - backwardSlice->insert(stmt); + backwardSlice->insert(inst); } } -SetVector mlir::getSlice(Statement *stmt, - TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter) { - SetVector slice; - slice.insert(stmt); +SetVector mlir::getSlice(Instruction *inst, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(inst); unsigned currentIndex = 0; - SetVector backwardSlice; - SetVector forwardSlice; + SetVector backwardSlice; + SetVector forwardSlice; while (currentIndex != slice.size()) { - auto *currentStmt = (slice)[currentIndex]; - // Compute and insert the backwardSlice starting from currentStmt. + auto *currentInst = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentInst. 
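// A minimal standalone sketch of the forward-slice walk above, over a toy
// def-use graph ('Node' with a 'users' list is a hypothetical stand-in for an
// MLIR instruction). The backward slice is symmetric, following the defining
// nodes of operands instead of users; result ordering is omitted here.
#include <functional>
#include <set>
#include <vector>

struct Node {
  std::vector<Node *> users; // nodes that consume this node's result
};

using Filter = std::function<bool(Node *)>;

static void forwardSlice(Node *node, std::set<Node *> *slice,
                         const Filter &filter, bool topLevel = true) {
  if (!node || !filter(node)) // scoping: the filter can cut the walk short
    return;
  for (Node *user : node->users)
    if (!slice->count(user))
      forwardSlice(user, slice, filter, /*topLevel=*/false);
  // The node we were queried on is not itself part of the returned slice.
  if (!topLevel)
    slice->insert(node);
}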
backwardSlice.clear(); - getBackwardSlice(currentStmt, &backwardSlice, backwardFilter); + getBackwardSlice(currentInst, &backwardSlice, backwardFilter); slice.insert(backwardSlice.begin(), backwardSlice.end()); - // Compute and insert the forwardSlice starting from currentStmt. + // Compute and insert the forwardSlice starting from currentInst. forwardSlice.clear(); - getForwardSlice(currentStmt, &forwardSlice, forwardFilter); + getForwardSlice(currentInst, &forwardSlice, forwardFilter); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } @@ -144,24 +144,24 @@ SetVector mlir::getSlice(Statement *stmt, namespace { /// DFS post-order implementation that maintains a global count to work across /// multiple invocations, to help implement topological sort on multi-root DAGs. -/// We traverse all statements but only record the ones that appear in `toSort` -/// for the final result. +/// We traverse all instructions but only record the ones that appear in +/// `toSort` for the final result. struct DFSState { - DFSState(const SetVector &set) + DFSState(const SetVector &set) : toSort(set), topologicalCounts(), seen() {} - const SetVector &toSort; - SmallVector topologicalCounts; - DenseSet seen; + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; }; } // namespace -static void DFSPostorder(Statement *current, DFSState *state) { - auto *opStmt = cast(current); - assert(opStmt->getNumResults() <= 1 && "NYI: multi-result"); - if (opStmt->getNumResults() > 0) { - for (auto &u : opStmt->getResult(0)->getUses()) { - auto *stmt = u.getOwner(); - DFSPostorder(stmt, state); +static void DFSPostorder(Instruction *current, DFSState *state) { + auto *opInst = cast(current); + assert(opInst->getNumResults() <= 1 && "NYI: multi-result"); + if (opInst->getNumResults() > 0) { + for (auto &u : opInst->getResult(0)->getUses()) { + auto *inst = u.getOwner(); + DFSPostorder(inst, state); } } bool inserted; @@ -175,8 +175,8 @@ static void DFSPostorder(Statement *current, DFSState *state) { } } -SetVector -mlir::topologicalSort(const SetVector &toSort) { +SetVector +mlir::topologicalSort(const SetVector &toSort) { if (toSort.empty()) { return toSort; } @@ -189,7 +189,7 @@ mlir::topologicalSort(const SetVector &toSort) { } // Reorder and return. - SetVector res; + SetVector res; for (auto it = state.topologicalCounts.rbegin(), eit = state.topologicalCounts.rend(); it != eit; ++it) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index f6191418f54..a7fc5ac619e 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -34,8 +34,8 @@ using namespace mlir; -/// Returns true if statement 'a' properly dominates statement b. -bool mlir::properlyDominates(const Statement &a, const Statement &b) { +/// Returns true if instruction 'a' properly dominates instruction b. +bool mlir::properlyDominates(const Instruction &a, const Instruction &b) { if (&a == &b) return false; @@ -64,24 +64,24 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) { return false; } -/// Returns true if statement A dominates statement B. -bool mlir::dominates(const Statement &a, const Statement &b) { +/// Returns true if instruction A dominates instruction B. +bool mlir::dominates(const Instruction &a, const Instruction &b) { return &a == &b || properlyDominates(a, b); } -/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from -/// the outermost 'for' statement to the innermost one. 
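// A standalone sketch of the DFS post-order numbering that topologicalSort
// above relies on: every reachable node gets a single post-order count, even
// across multiple roots, and the requested nodes are returned in reverse
// post-order. A toy DAG of 'Node's stands in for MLIR instructions.
#include <set>
#include <utility>
#include <vector>

struct Node {
  std::vector<Node *> users;
};

struct DFSState {
  explicit DFSState(const std::set<Node *> &set) : toSort(set) {}
  const std::set<Node *> &toSort;                  // nodes we want sorted
  std::vector<std::pair<Node *, unsigned>> counts; // (node, post-order count)
  std::set<Node *> seen;
  unsigned counter = 0;
};

static void dfsPostorder(Node *node, DFSState *state) {
  for (Node *user : node->users) // assumes a DAG, as the real code does
    dfsPostorder(user, state);
  if (state->seen.insert(node).second) { // number each node exactly once
    if (state->toSort.count(node))
      state->counts.emplace_back(node, state->counter);
    ++state->counter;
  }
}

static std::vector<Node *> topoSort(const std::set<Node *> &toSort) {
  DFSState state(toSort);
  for (Node *root : toSort)
    dfsPostorder(root, &state);
  std::vector<Node *> result;
  // Larger counts are closer to the roots, so walk the counts backwards.
  for (auto it = state.counts.rbegin(); it != state.counts.rend(); ++it)
    result.push_back(it->first);
  return result;
}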
-void mlir::getLoopIVs(const Statement &stmt, - SmallVectorImpl *loops) { - auto *currStmt = stmt.getParentStmt(); - ForStmt *currForStmt; - // Traverse up the hierarchy collecing all 'for' statement while skipping over - // 'if' statements. - while (currStmt && ((currForStmt = dyn_cast(currStmt)) || - isa(currStmt))) { - if (currForStmt) - loops->push_back(currForStmt); - currStmt = currStmt->getParentStmt(); +/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from +/// the outermost 'for' instruction to the innermost one. +void mlir::getLoopIVs(const Instruction &inst, + SmallVectorImpl *loops) { + auto *currInst = inst.getParentInst(); + ForInst *currForInst; + // Traverse up the hierarchy collecing all 'for' instruction while skipping + // over 'if' instructions. + while (currInst && ((currForInst = dyn_cast(currInst)) || + isa(currInst))) { + if (currForInst) + loops->push_back(currForInst); + currInst = currInst->getParentInst(); } std::reverse(loops->begin(), loops->end()); } @@ -129,7 +129,7 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opStmt and any additional Function symbols. Returns false if +/// surrounding opInst and any additional Function symbols. Returns false if /// this fails due to yet unimplemented cases. // For example, the memref region for this load operation at loopDepth = 1 will // be as below: @@ -145,21 +145,21 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, +bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, MemRefRegion *region) { OpPointer loadOp; OpPointer storeOp; unsigned rank; SmallVector indices; - if ((loadOp = opStmt->dyn_cast())) { + if ((loadOp = opInst->dyn_cast())) { rank = loadOp->getMemRefType().getRank(); for (auto *index : loadOp->getIndices()) { indices.push_back(index); } region->memref = loadOp->getMemRef(); region->setWrite(false); - } else if ((storeOp = opStmt->dyn_cast())) { + } else if ((storeOp = opInst->dyn_cast())) { rank = storeOp->getMemRefType().getRank(); for (auto *index : storeOp->getIndices()) { indices.push_back(index); @@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, // Build the constraints for this region. FlatAffineConstraints *regionCst = region->getConstraints(); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto idMap = b.getMultiDimIdentityMap(rank); // Initialize 'accessValueMap' and compose with reachable AffineApplyOps. @@ -192,20 +192,20 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto *loop = dyn_cast(accessValueMap.getOperand(i))) { + if (auto *loop = dyn_cast(accessValueMap.getOperand(i))) { // Note that regionCst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. - // TODO(bondhugula): rewrite this to use getStmtIndexSet; this way + // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. 
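// A minimal sketch of the parent walk in getLoopIVs above: climb the parent
// chain, keep the enclosing 'for's, skip over 'if's, and reverse so the
// result runs outermost-to-innermost. 'Node' and its kinds are hypothetical.
#include <algorithm>
#include <vector>

struct Node {
  enum class Kind { For, If, Op } kind = Kind::Op;
  Node *parent = nullptr;
};

static void getEnclosingLoops(const Node &leaf, std::vector<Node *> *loops) {
  Node *curr = leaf.parent;
  while (curr &&
         (curr->kind == Node::Kind::For || curr->kind == Node::Kind::If)) {
    if (curr->kind == Node::Kind::For)
      loops->push_back(curr);
    curr = curr->parent;
  }
  std::reverse(loops->begin(), loops->end()); // outermost loop first
}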
- if (!regionCst->addForStmtDomain(*loop)) + if (!regionCst->addForInstDomain(*loop)) return false; } else { // Has to be a valid symbol. auto *symbol = accessValueMap.getOperand(i); assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opStmt = symbol->getDefiningInst()) { - if (auto constOp = opStmt->dyn_cast()) { + if (auto *opInst = symbol->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { regionCst->setIdToConstant(*symbol, constOp->getValue()); } } @@ -220,12 +220,12 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. - SmallVector outerIVs; - getLoopIVs(*opStmt, &outerIVs); + SmallVector outerIVs; + getLoopIVs(*opInst, &outerIVs); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { - ForStmt *iv; - if ((iv = dyn_cast(operand)) && + ForInst *iv; + if ((iv = dyn_cast(operand)) && std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) { regionCst->projectOut(operand); } @@ -282,9 +282,9 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same>::value, "function argument should be either a LoadOp or a StoreOp"); - OperationInst *opStmt = loadOrStoreOp->getInstruction(); + OperationInst *opInst = loadOrStoreOp->getInstruction(); MemRefRegion region; - if (!getMemRefRegion(opStmt, /*loopDepth=*/0, ®ion)) + if (!getMemRefRegion(opInst, /*loopDepth=*/0, ®ion)) return false; LLVM_DEBUG(llvm::dbgs() << "Memory region"); LLVM_DEBUG(region.getConstraints()->dump()); @@ -333,43 +333,43 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, template bool mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, bool emitError); -// Returns in 'positions' the Block positions of 'stmt' in each ancestor -// Block from the Block containing statement, stopping at 'limitBlock'. -static void findStmtPosition(const Statement *stmt, Block *limitBlock, +// Returns in 'positions' the Block positions of 'inst' in each ancestor +// Block from the Block containing instruction, stopping at 'limitBlock'. +static void findInstPosition(const Instruction *inst, Block *limitBlock, SmallVectorImpl *positions) { - Block *block = stmt->getBlock(); + Block *block = inst->getBlock(); while (block != limitBlock) { - int stmtPosInBlock = block->findInstPositionInBlock(*stmt); - assert(stmtPosInBlock >= 0); - positions->push_back(stmtPosInBlock); - stmt = block->getContainingInst(); - block = stmt->getBlock(); + int instPosInBlock = block->findInstPositionInBlock(*inst); + assert(instPosInBlock >= 0); + positions->push_back(instPosInBlock); + inst = block->getContainingInst(); + block = inst->getBlock(); } std::reverse(positions->begin(), positions->end()); } -// Returns the Statement in a possibly nested set of Blocks, where the -// position of the statement is represented by 'positions', which has a +// Returns the Instruction in a possibly nested set of Blocks, where the +// position of the instruction is represented by 'positions', which has a // Block position for each level of nesting. 
-static Statement *getStmtAtPosition(ArrayRef positions, - unsigned level, Block *block) { +static Instruction *getInstAtPosition(ArrayRef positions, + unsigned level, Block *block) { unsigned i = 0; - for (auto &stmt : *block) { + for (auto &inst : *block) { if (i != positions[level]) { ++i; continue; } if (level == positions.size() - 1) - return &stmt; - if (auto *childForStmt = dyn_cast(&stmt)) - return getStmtAtPosition(positions, level + 1, childForStmt->getBody()); + return &inst; + if (auto *childForInst = dyn_cast(&inst)) + return getInstAtPosition(positions, level + 1, childForInst->getBody()); - if (auto *ifStmt = dyn_cast(&stmt)) { - auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen()); + if (auto *ifInst = dyn_cast(&inst)) { + auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); if (ret != nullptr) return ret; - if (auto *elseClause = ifStmt->getElse()) - return getStmtAtPosition(positions, level + 1, elseClause); + if (auto *elseClause = ifInst->getElse()) + return getInstAtPosition(positions, level + 1, elseClause); } } return nullptr; @@ -379,7 +379,7 @@ static Statement *getStmtAtPosition(ArrayRef positions, // dependence constraint system to create AffineMaps with which to adjust the // loop bounds of the inserted compution slice so that they are functions of the // loop IVs and symbols of the loops surrounding 'dstAccess'. -ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, +ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, MemRefAccess *dstAccess, unsigned srcLoopDepth, unsigned dstLoopDepth) { @@ -390,14 +390,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, return nullptr; } // Get loop nest surrounding src operation. - SmallVector srcLoopNest; - getLoopIVs(*srcAccess->opStmt, &srcLoopNest); + SmallVector srcLoopNest; + getLoopIVs(*srcAccess->opInst, &srcLoopNest); unsigned srcLoopNestSize = srcLoopNest.size(); assert(srcLoopDepth <= srcLoopNestSize); // Get loop nest surrounding dst operation. - SmallVector dstLoopNest; - getLoopIVs(*dstAccess->opStmt, &dstLoopNest); + SmallVector dstLoopNest; + getLoopIVs(*dstAccess->opInst, &dstLoopNest); unsigned dstLoopNestSize = dstLoopNest.size(); (void)dstLoopNestSize; assert(dstLoopDepth > 0); @@ -425,7 +425,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, } SmallVector nonZeroDimIds; SmallVector nonZeroSymbolIds; - srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(), + srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opInst->getContext(), &nonZeroDimIds, &nonZeroSymbolIds); if (srcIvMaps[i] == AffineMap::Null()) { continue; @@ -446,23 +446,23 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, // with a symbol identifiers in 'nonZeroSymbolIds'. } - // Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'. + // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'. SmallVector positions; - findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions); + findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions); - // Clone src loop nest and insert it a the beginning of the statement block + // Clone src loop nest and insert it a the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopNest'. 
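// A standalone sketch of the record/replay pattern in findInstPosition and
// getInstAtPosition above: remember the child index at each level between an
// instruction and a limit ancestor, then walk the same indices down a cloned
// nest to find the corresponding instruction. 'TreeNode' is a toy stand-in.
#include <algorithm>
#include <vector>

struct TreeNode {
  TreeNode *parent = nullptr;
  std::vector<TreeNode *> children;
};

// Record positions from 'node' up to (but not including) 'limit', outermost
// level first.
static void recordPositions(TreeNode *node, TreeNode *limit,
                            std::vector<unsigned> *positions) {
  while (node != limit) {
    TreeNode *parent = node->parent;
    auto it = std::find(parent->children.begin(), parent->children.end(), node);
    positions->push_back(static_cast<unsigned>(it - parent->children.begin()));
    node = parent;
  }
  std::reverse(positions->begin(), positions->end());
}

// Replay the recorded positions level by level, e.g. inside a cloned tree.
static TreeNode *getAtPosition(const std::vector<unsigned> &positions,
                               unsigned level, TreeNode *node) {
  if (level == positions.size())
    return node;
  return getAtPosition(positions, level + 1, node->children[positions[level]]);
}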
- auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; - FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin()); + auto *dstForInst = dstLoopNest[dstLoopDepth - 1]; + FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); DenseMap operandMap; - auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); - - // Lookup stmt in cloned 'sliceLoopNest' at 'positions'. - Statement *sliceStmt = - getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); - // Get loop nest surrounding 'sliceStmt'. - SmallVector sliceSurroundingLoops; - getLoopIVs(*sliceStmt, &sliceSurroundingLoops); + auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); + + // Lookup inst in cloned 'sliceLoopNest' at 'positions'. + Instruction *sliceInst = + getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); + // Get loop nest surrounding 'sliceInst'. + SmallVector sliceSurroundingLoops; + getLoopIVs(*sliceInst, &sliceSurroundingLoops); unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size(); (void)sliceSurroundingLoopsSize; @@ -470,18 +470,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize; assert(sliceLoopLimit <= sliceSurroundingLoopsSize); for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) { - auto *forStmt = sliceSurroundingLoops[i]; + auto *forInst = sliceSurroundingLoops[i]; unsigned index = i - dstLoopDepth; AffineMap lbMap = srcIvMaps[index]; if (lbMap == AffineMap::Null()) continue; - forStmt->setLowerBound(srcIvOperands[index], lbMap); + forInst->setLowerBound(srcIvOperands[index], lbMap); // Create upper bound map with is lower bound map + 1; assert(lbMap.getNumResults() == 1); AffineExpr ubResultExpr = lbMap.getResult(0) + 1; AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), {ubResultExpr}, {}); - forStmt->setUpperBound(srcIvOperands[index], ubMap); + forInst->setUpperBound(srcIvOperands[index], ubMap); } return sliceLoopNest; } diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index cd9451cd5e9..e092b29a13b 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -19,7 +19,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -105,7 +105,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, static AffineMap makePermutationMap( MLIRContext *context, llvm::iterator_range indices, - const DenseMap &enclosingLoopToVectorDim) { + const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster(), indices); @@ -137,10 +137,11 @@ static AffineMap makePermutationMap( /// the specified type. /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. 
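// A small sketch of the single-iteration slice bounds set just above: each
// sliced loop's upper bound is its lower bound plus one, so the cloned nest
// runs exactly once per destination-loop iteration. Plain lambdas stand in
// for one-result affine maps here; none of this is the AffineMap API.
#include <cassert>
#include <functional>

using BoundFn = std::function<long(long)>; // maps the dst IV to a bound value

static BoundFn makeUpperBound(const BoundFn &lower) {
  // Mirrors "ubResultExpr = lbMap.getResult(0) + 1" in the code above.
  return [lower](long iv) { return lower(iv) + 1; };
}

int main() {
  BoundFn lb = [](long iv) { return 2 * iv; }; // e.g. a map like (d0) -> (2 * d0)
  BoundFn ub = makeUpperBound(lb);
  for (long iv = 0; iv < 4; ++iv)
    assert(ub(iv) - lb(iv) == 1); // trip count of one for every iv value
  return 0;
}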
-template static SetVector getParentsOfType(Statement *stmt) { +template +static SetVector getParentsOfType(Instruction *inst) { SetVector res; - auto *current = stmt; - while (auto *parent = current->getParentStmt()) { + auto *current = inst; + while (auto *parent = current->getParentInst()) { auto *typedParent = dyn_cast(parent); if (typedParent) { assert(res.count(typedParent) == 0 && "Already inserted"); @@ -151,34 +152,34 @@ template static SetVector getParentsOfType(Statement *stmt) { return res; } -/// Returns the enclosing ForStmt, from closest to farthest. -static SetVector getEnclosingForStmts(Statement *stmt) { - return getParentsOfType(stmt); +/// Returns the enclosing ForInst, from closest to farthest. +static SetVector getEnclosingforInsts(Instruction *inst) { + return getParentsOfType(inst); } AffineMap -mlir::makePermutationMap(OperationInst *opStmt, - const DenseMap &loopToVectorDim) { - DenseMap enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingForStmts(opStmt); - for (auto *forStmt : enclosingLoops) { - auto it = loopToVectorDim.find(forStmt); +mlir::makePermutationMap(OperationInst *opInst, + const DenseMap &loopToVectorDim) { + DenseMap enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforInsts(opInst); + for (auto *forInst : enclosingLoops) { + auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { enclosingLoopToVectorDim.insert(*it); } } - if (auto load = opStmt->dyn_cast()) { - return ::makePermutationMap(opStmt->getContext(), load->getIndices(), + if (auto load = opInst->dyn_cast()) { + return ::makePermutationMap(opInst->getContext(), load->getIndices(), enclosingLoopToVectorDim); } - auto store = opStmt->cast(); - return ::makePermutationMap(opStmt->getContext(), store->getIndices(), + auto store = opInst->cast(); + return ::makePermutationMap(opInst->getContext(), store->getIndices(), enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, +bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, @@ -191,20 +192,20 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = opStmt.dyn_cast()) { + if (auto read = opInst.dyn_cast()) { superVectorType = read->getResultType(); mustDivide = true; - } else if (auto write = opStmt.dyn_cast()) { + } else if (auto write = opInst.dyn_cast()) { superVectorType = write->getVectorType(); mustDivide = true; - } else if (opStmt.getNumResults() == 0) { - if (!opStmt.isa()) { - opStmt.emitError("NYI: assuming only return statements can have 0 " + } else if (opInst.getNumResults() == 0) { + if (!opInst.isa()) { + opInst.emitError("NYI: assuming only return instructions can have 0 " " results at this point"); } return false; - } else if (opStmt.getNumResults() == 1) { - if (auto v = opStmt.getResult(0)->getType().dyn_cast()) { + } else if (opInst.getNumResults() == 1) { + if (auto v = opInst.getResult(0)->getType().dyn_cast()) { superVectorType = v; } else { // Not a vector type. 
@@ -213,7 +214,7 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, } else { // Not a vector_transfer and has more than 1 result, fail hard for now to // wake us up when something changes. - opStmt.emitError("NYI: statement has more than 1 result"); + opInst.emitError("NYI: instruction has more than 1 result"); return false; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 4cad531ecaa..7217c5492a6 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -36,9 +36,9 @@ #include "mlir/Analysis/Dominance.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" @@ -239,14 +239,14 @@ bool CFGFuncVerifier::verifyBlock(const Block &block) { //===----------------------------------------------------------------------===// namespace { -struct MLFuncVerifier : public Verifier, public StmtWalker { +struct MLFuncVerifier : public Verifier, public InstWalker { const Function &fn; bool hadError = false; MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {} - void visitOperationInst(OperationInst *opStmt) { - hadError |= verifyOperation(*opStmt); + void visitOperationInst(OperationInst *opInst) { + hadError |= verifyOperation(*opInst); } bool verify() { @@ -269,7 +269,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker { /// operations are properly dominated by their definitions. bool verifyDominance(); - /// Verify that function has a return statement that matches its signature. + /// Verify that function has a return instruction that matches its signature. bool verifyReturn(); }; } // end anonymous namespace @@ -285,48 +285,48 @@ bool MLFuncVerifier::verifyDominance() { for (auto *arg : fn.getArguments()) liveValues.insert(arg, true); - // This recursive function walks the statement list pushing scopes onto the + // This recursive function walks the instruction list pushing scopes onto the // stack as it goes, and popping them to remove them from the table. std::function walkBlock; walkBlock = [&](const Block &block) -> bool { HashTable::ScopeTy blockScope(liveValues); - // The induction variable of a for statement is live within its body. - if (auto *forStmt = dyn_cast_or_null(block.getContainingInst())) - liveValues.insert(forStmt, true); + // The induction variable of a for instruction is live within its body. + if (auto *forInst = dyn_cast_or_null(block.getContainingInst())) + liveValues.insert(forInst, true); - for (auto &stmt : block) { + for (auto &inst : block) { // Verify that each of the operands are live. unsigned operandNo = 0; - for (auto *opValue : stmt.getOperands()) { + for (auto *opValue : inst.getOperands()) { if (!liveValues.count(opValue)) { - stmt.emitError("operand #" + Twine(operandNo) + + inst.emitError("operand #" + Twine(operandNo) + " does not dominate this use"); - if (auto *useStmt = opValue->getDefiningInst()) - useStmt->emitNote("operand defined here"); + if (auto *useInst = opValue->getDefiningInst()) + useInst->emitNote("operand defined here"); return true; } ++operandNo; } - if (auto *opStmt = dyn_cast(&stmt)) { + if (auto *opInst = dyn_cast(&inst)) { // Operations define values, add them to the hash table. 
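// A self-contained sketch of the scoped dominance walk in
// MLFuncVerifier::verifyDominance here: definitions become visible as the
// walk reaches them, each nested block gets its own scope, and every use must
// refer to a value that is already live. 'ToyInst' is hypothetical, not MLIR.
#include <set>
#include <string>
#include <vector>

struct ToyInst {
  std::vector<std::string> operands;         // values used
  std::vector<std::string> results;          // values defined
  std::vector<std::vector<ToyInst>> regions; // nested blocks, if any
};

static bool blockDominanceOk(const std::vector<ToyInst> &block,
                             std::set<std::string> live) { // copy == new scope
  for (const ToyInst &inst : block) {
    for (const std::string &use : inst.operands)
      if (!live.count(use))
        return false; // operand does not dominate this use
    for (const std::string &def : inst.results)
      live.insert(def);
    for (const auto &nested : inst.regions)
      if (!blockDominanceOk(nested, live)) // inner blocks see outer defs
        return false;
  }
  return true;
}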
- for (auto *result : opStmt->getResults()) + for (auto *result : opInst->getResults()) liveValues.insert(result, true); continue; } // If this is an if or for, recursively walk the block they contain. - if (auto *ifStmt = dyn_cast(&stmt)) { - if (walkBlock(*ifStmt->getThen())) + if (auto *ifInst = dyn_cast(&inst)) { + if (walkBlock(*ifInst->getThen())) return true; - if (auto *elseClause = ifStmt->getElse()) + if (auto *elseClause = ifInst->getElse()) if (walkBlock(*elseClause)) return true; } - if (auto *forStmt = dyn_cast(&stmt)) - if (walkBlock(*forStmt->getBody())) + if (auto *forInst = dyn_cast(&inst)) + if (walkBlock(*forInst->getBody())) return true; } @@ -338,13 +338,14 @@ bool MLFuncVerifier::verifyDominance() { } bool MLFuncVerifier::verifyReturn() { - // TODO: fold return verification in the pass that verifies all statements. - const char missingReturnMsg[] = "ML function must end with return statement"; + // TODO: fold return verification in the pass that verifies all instructions. + const char missingReturnMsg[] = + "ML function must end with return instruction"; if (fn.getBody()->getInstructions().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getBody()->getInstructions().back(); - if (const auto *op = dyn_cast(&stmt)) { + const auto &inst = fn.getBody()->getInstructions().back(); + if (const auto *op = dyn_cast(&inst)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index daaaee7010c..cf822e025b8 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,11 +25,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" @@ -117,10 +117,10 @@ private: void visitExtFunction(const Function *fn); void visitCFGFunction(const Function *fn); void visitMLFunction(const Function *fn); - void visitStatement(const Statement *stmt); - void visitForStmt(const ForStmt *forStmt); - void visitIfStmt(const IfStmt *ifStmt); - void visitOperationInst(const OperationInst *opStmt); + void visitInstruction(const Instruction *inst); + void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); + void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const OperationInst *op); @@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) { if (auto *opInst = dyn_cast(&op)) visitOperation(opInst); else { - llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported"); + llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported"); } } } } -void ModuleState::visitIfStmt(const IfStmt *ifStmt) { - recordIntegerSetReference(ifStmt->getIntegerSet()); - for (auto &childStmt : *ifStmt->getThen()) - visitStatement(&childStmt); - if (ifStmt->hasElse()) - for (auto &childStmt : *ifStmt->getElse()) - visitStatement(&childStmt); +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); + for (auto &childInst : *ifInst->getThen()) + visitInstruction(&childInst); + if (ifInst->hasElse()) + for (auto &childInst : *ifInst->getElse()) + 
visitInstruction(&childInst); } -void ModuleState::visitForStmt(const ForStmt *forStmt) { - AffineMap lbMap = forStmt->getLowerBoundMap(); +void ModuleState::visitForInst(const ForInst *forInst) { + AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasShorthandForm(lbMap)) recordAffineMapReference(lbMap); - AffineMap ubMap = forStmt->getUpperBoundMap(); + AffineMap ubMap = forInst->getUpperBoundMap(); if (!hasShorthandForm(ubMap)) recordAffineMapReference(ubMap); - for (auto &childStmt : *forStmt->getBody()) - visitStatement(&childStmt); + for (auto &childInst : *forInst->getBody()) + visitInstruction(&childInst); } -void ModuleState::visitOperationInst(const OperationInst *opStmt) { - for (auto attr : opStmt->getAttrs()) +void ModuleState::visitOperationInst(const OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) visitAttribute(attr.second); } -void ModuleState::visitStatement(const Statement *stmt) { - switch (stmt->getKind()) { - case Statement::Kind::If: - return visitIfStmt(cast(stmt)); - case Statement::Kind::For: - return visitForStmt(cast(stmt)); - case Statement::Kind::OperationInst: - return visitOperationInst(cast(stmt)); +void ModuleState::visitInstruction(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast(inst)); + case Instruction::Kind::For: + return visitForInst(cast(inst)); + case Instruction::Kind::OperationInst: + return visitOperationInst(cast(inst)); default: return; } @@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const Function *fn) { visitType(fn->getType()); - for (auto &stmt : *fn->getBody()) { - ModuleState::visitStatement(&stmt); + for (auto &inst : *fn->getBody()) { + ModuleState::visitInstruction(&inst); } } @@ -909,11 +909,11 @@ public: void printMLFunctionSignature(); void printOtherFunctionSignature(); - // Methods to print statements. - void print(const Statement *stmt); + // Methods to print instructions. + void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForStmt *stmt); - void print(const IfStmt *stmt); + void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block); void printOperation(const OperationInst *op); @@ -959,7 +959,7 @@ public: void printDimAndSymbolList(ArrayRef ops, unsigned numDims); void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested statements. + // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; protected: @@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // We number instruction that have results, and we only number the first // result. switch (inst.getKind()) { - case Statement::Kind::OperationInst: { + case Instruction::Kind::OperationInst: { auto *opInst = cast(&inst); if (opInst->getNumResults() != 0) numberValueID(opInst->getResult(0)); break; } - case Statement::Kind::For: { - auto *forInst = cast(&inst); + case Instruction::Kind::For: { + auto *forInst = cast(&inst); // Number the induction variable. numberValueID(forInst); // Recursively number the stuff in the body. 
numberValuesInBlock(*forInst->getBody()); break; } - case Statement::Kind::If: { - auto *ifInst = cast(&inst); + case Instruction::Kind::If: { + auto *ifInst = cast(&inst); numberValuesInBlock(*ifInst->getThen()); if (auto *elseBlock = ifInst->getElse()) numberValuesInBlock(*elseBlock); @@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) { // done with it. valueIDs[value] = nextValueID++; return; - case Value::Kind::ForStmt: + case Value::Kind::ForInst: specialName << 'i' << nextLoopID++; break; } @@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) { currentIndent += indentWidth; - for (auto &stmt : block->getInstructions()) { - print(&stmt); + for (auto &inst : block->getInstructions()) { + print(&inst); os << '\n'; } currentIndent -= indentWidth; } -void FunctionPrinter::print(const Statement *stmt) { - switch (stmt->getKind()) { - case Statement::Kind::OperationInst: - return print(cast(stmt)); - case Statement::Kind::For: - return print(cast(stmt)); - case Statement::Kind::If: - return print(cast(stmt)); +void FunctionPrinter::print(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::OperationInst: + return print(cast(inst)); + case Instruction::Kind::For: + return print(cast(inst)); + case Instruction::Kind::If: + return print(cast(inst)); } } @@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) { printOperation(inst); } -void FunctionPrinter::print(const ForStmt *stmt) { +void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "for "; - printOperand(stmt); + printOperand(inst); os << " = "; - printBound(stmt->getLowerBound(), "max"); + printBound(inst->getLowerBound(), "max"); os << " to "; - printBound(stmt->getUpperBound(), "min"); + printBound(inst->getUpperBound(), "min"); - if (stmt->getStep() != 1) - os << " step " << stmt->getStep(); + if (inst->getStep() != 1) + os << " step " << inst->getStep(); os << " {\n"; - print(stmt->getBody()); + print(inst->getBody()); os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfStmt *stmt) { +void FunctionPrinter::print(const IfInst *inst) { os.indent(currentIndent) << "if "; - IntegerSet set = stmt->getIntegerSet(); + IntegerSet set = inst->getIntegerSet(); printIntegerSetReference(set); - printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims()); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); os << " {\n"; - print(stmt->getThen()); + print(inst->getThen()); os.indent(currentIndent) << "}"; - if (stmt->hasElse()) { + if (inst->hasElse()) { os << " else {\n"; - print(stmt->getElse()); + print(inst->getElse()); os.indent(currentIndent) << "}"; } } @@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value, auto lookupValue = value; // If this is a reference to the result of a multi-result instruction or - // statement, print out the # identifier and make sure to map our lookup + // instruction, print out the # identifier and make sure to map our lookup // to the first result of the instruction. 
if (auto *result = dyn_cast(value)) { if (result->getOwner()->getNumResults() != 1) { @@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const { return; case Value::Kind::InstResult: return getDefiningInst()->print(os); - case Value::Kind::ForStmt: - return cast(this)->print(os); + case Value::Kind::ForInst: + return cast(this)->print(os); } } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index c7e84194c35..2efba2bbf69 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -26,16 +26,16 @@ Block::~Block() { llvm::DeleteContainerPointers(arguments); } -/// Returns the closest surrounding statement that contains this block or -/// nullptr if this is a top-level statement block. -Statement *Block::getContainingInst() { +/// Returns the closest surrounding instruction that contains this block or +/// nullptr if this is a top-level instruction block. +Instruction *Block::getContainingInst() { return parent ? parent->getContainingInst() : nullptr; } Function *Block::getFunction() { Block *block = this; - while (auto *stmt = block->getContainingInst()) { - block = stmt->getBlock(); + while (auto *inst = block->getContainingInst()) { + block = inst->getBlock(); if (!block) return nullptr; } @@ -49,11 +49,11 @@ Function *Block::getFunction() { /// the latter fails. const Instruction * Block::findAncestorInstInBlock(const Instruction &inst) const { - // Traverse up the statement hierarchy starting from the owner of operand to - // find the ancestor statement that resides in the block of 'forStmt'. + // Traverse up the instruction hierarchy starting from the owner of operand to + // find the ancestor instruction that resides in the block of 'forInst'. const auto *currInst = &inst; while (currInst->getBlock() != this) { - currInst = currInst->getParentStmt(); + currInst = currInst->getParentInst(); if (!currInst) return nullptr; } @@ -106,10 +106,10 @@ OperationInst *Block::getTerminator() { // Check if the last instruction is a terminator. auto &backInst = back(); - auto *opStmt = dyn_cast(&backInst); - if (!opStmt || !opStmt->isTerminator()) + auto *opInst = dyn_cast(&backInst); + if (!opInst || !opInst->isTerminator()) return nullptr; - return opStmt; + return opInst; } /// Return true if this block has no predecessors. @@ -184,10 +184,10 @@ Block *Block::splitBlock(iterator splitBefore) { BlockList::BlockList(Function *container) : container(container) {} -BlockList::BlockList(Statement *container) : container(container) {} +BlockList::BlockList(Instruction *container) : container(container) {} -Statement *BlockList::getContainingInst() { - return container.dyn_cast(); +Instruction *BlockList::getContainingInst() { + return container.dyn_cast(); } Function *BlockList::getContainingFunction() { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a9eb6fe8c8a..4c7c8ddae81 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -268,7 +268,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { } //===----------------------------------------------------------------------===// -// Statements. +// Instructions. //===----------------------------------------------------------------------===// /// Add new basic block and set the insertion point to the end of it. 
If an @@ -298,25 +298,25 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { return op; } -ForStmt *FuncBuilder::createFor(Location location, ArrayRef lbOperands, +ForInst *FuncBuilder::createFor(Location location, ArrayRef lbOperands, AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step) { - auto *stmt = - ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getInstructions().insert(insertPoint, stmt); - return stmt; + auto *inst = + ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step); + block->getInstructions().insert(insertPoint, inst); + return inst; } -ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, +ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, int64_t step) { auto lbMap = AffineMap::getConstantMap(lb, context); auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } -IfStmt *FuncBuilder::createIf(Location location, ArrayRef operands, +IfInst *FuncBuilder::createIf(Location location, ArrayRef operands, IntegerSet set) { - auto *stmt = IfStmt::create(location, operands, set); - block->getInstructions().insert(insertPoint, stmt); - return stmt; + auto *inst = IfInst::create(location, operands, set); + block->getInstructions().insert(insertPoint, inst); + return inst; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index cbe84e10247..bacb504683b 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -18,9 +18,9 @@ #include "mlir/IR/Function.h" #include "AttributeListStorage.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringRef.h" @@ -161,21 +161,21 @@ bool Function::emitError(const Twine &message) const { // Function implementation. //===----------------------------------------------------------------------===// -const OperationInst *Function::getReturnStmt() const { +const OperationInst *Function::getReturn() const { return cast(&getBody()->back()); } -OperationInst *Function::getReturnStmt() { +OperationInst *Function::getReturn() { return cast(&getBody()->back()); } void Function::walk(std::function callback) { - struct Walker : public StmtWalker { + struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opInst) { callback(opInst); } }; Walker v(callback); @@ -183,12 +183,12 @@ void Function::walk(std::function callback) { } void Function::walkPostOrder(std::function callback) { - struct Walker : public StmtWalker { + struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opInst) { callback(opInst); } }; Walker v(callback); diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp new file mode 100644 index 00000000000..92f3c4ecba3 --- /dev/null +++ b/mlir/lib/IR/Instruction.cpp @@ -0,0 +1,829 @@ +//===- Instruction.cpp - MLIR Instruction Classes +//----------------------------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "AttributeListStorage.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/DenseMap.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// InstResult +//===----------------------------------------------------------------------===// + +/// Return the result number of this result. +unsigned InstResult::getResultNumber() const { + // Results are always stored consecutively, so use pointer subtraction to + // figure out what number this is. + return this - &getOwner()->getInstResults()[0]; +} + +//===----------------------------------------------------------------------===// +// InstOperand +//===----------------------------------------------------------------------===// + +/// Return which operand this is in the operand list. +template <> unsigned InstOperand::getOperandNumber() const { + return this - &getOwner()->getInstOperands()[0]; +} + +/// Return which operand this is in the operand list. +template <> unsigned BlockOperand::getOperandNumber() const { + return this - &getOwner()->getBlockOperands()[0]; +} + +//===----------------------------------------------------------------------===// +// Instruction +//===----------------------------------------------------------------------===// + +// Instructions are deleted through the destroy() member because we don't have +// a virtual destructor. +Instruction::~Instruction() { + assert(block == nullptr && "instruction destroyed but still in a block"); +} + +/// Destroy this instruction or one of its subclasses. +void Instruction::destroy() { + switch (this->getKind()) { + case Kind::OperationInst: + cast(this)->destroy(); + break; + case Kind::For: + delete cast(this); + break; + case Kind::If: + delete cast(this); + break; + } +} + +Instruction *Instruction::getParentInst() const { + return block ? block->getContainingInst() : nullptr; +} + +Function *Instruction::getFunction() const { + return block ? block->getFunction() : nullptr; +} + +Value *Instruction::getOperand(unsigned idx) { + return getInstOperand(idx).get(); +} + +const Value *Instruction::getOperand(unsigned idx) const { + return getInstOperand(idx).get(); +} + +// Value can be used as a dimension id if it is valid as a symbol, or +// it is an induction variable, or it is a result of affine apply operation +// with dimension id arguments. +bool Value::isValidDim() const { + if (auto *inst = getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa()) + return true; + // Affine apply operation is ok if all of its operands are ok. 
+ if (auto op = inst->dyn_cast()) + return op->isValidDim(); + return false; + } + // This value is either a function argument or an induction variable. Both + // are ok. + return true; +} + +// Value can be used as a symbol if it is a constant, or it is defined at +// the top level, or it is a result of affine apply operation with symbol +// arguments. +bool Value::isValidSymbol() const { + if (auto *inst = getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = inst->dyn_cast()) + return op->isValidSymbol(); + return false; + } + // This value is either a function argument or an induction variable. + // Function argument is ok, induction variable is not. + return isa(this); +} + +void Instruction::setOperand(unsigned idx, Value *value) { + getInstOperand(idx).set(value); +} + +unsigned Instruction::getNumOperands() const { + switch (getKind()) { + case Kind::OperationInst: + return cast(this)->getNumOperands(); + case Kind::For: + return cast(this)->getNumOperands(); + case Kind::If: + return cast(this)->getNumOperands(); + } +} + +MutableArrayRef Instruction::getInstOperands() { + switch (getKind()) { + case Kind::OperationInst: + return cast(this)->getInstOperands(); + case Kind::For: + return cast(this)->getInstOperands(); + case Kind::If: + return cast(this)->getInstOperands(); + } +} + +/// Emit a note about this instruction, reporting up to any diagnostic +/// handlers that may be listening. +void Instruction::emitNote(const Twine &message) const { + getContext()->emitDiagnostic(getLoc(), message, + MLIRContext::DiagnosticKind::Note); +} + +/// Emit a warning about this instruction, reporting up to any diagnostic +/// handlers that may be listening. +void Instruction::emitWarning(const Twine &message) const { + getContext()->emitDiagnostic(getLoc(), message, + MLIRContext::DiagnosticKind::Warning); +} + +/// Emit an error about fatal conditions with this operation, reporting up to +/// any diagnostic handlers that may be listening. This function always +/// returns true. NOTE: This may terminate the containing application, only +/// use when the IR is in an inconsistent state. +bool Instruction::emitError(const Twine &message) const { + return getContext()->emitError(getLoc(), message); +} + +// Returns whether the Instruction is a terminator. +bool Instruction::isTerminator() const { + if (auto *op = dyn_cast(this)) + return op->isTerminator(); + return false; +} + +//===----------------------------------------------------------------------===// +// ilist_traits for Instruction +//===----------------------------------------------------------------------===// + +void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) { + inst->destroy(); +} + +Block *llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() { + size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); + iplist *Anchor(static_cast *>(this)); + return reinterpret_cast(reinterpret_cast(Anchor) - Offset); +} + +/// This is a trait method invoked when a instruction is added to a block. We +/// keep the block pointer up to date. +void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) { + assert(!inst->getBlock() && "already in a instruction block!"); + inst->block = getContainingBlock(); +} + +/// This is a trait method invoked when a instruction is removed from a block. 
+/// We keep the block pointer up to date. +void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList( + Instruction *inst) { + assert(inst->block && "not already in a instruction block!"); + inst->block = nullptr; +} + +/// This is a trait method invoked when a instruction is moved from one block +/// to another. We keep the block pointer up to date. +void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList( + ilist_traits &otherList, inst_iterator first, + inst_iterator last) { + // If we are transferring instructions within the same block, the block + // pointer doesn't need to be updated. + Block *curParent = getContainingBlock(); + if (curParent == otherList.getContainingBlock()) + return; + + // Update the 'block' member of each instruction. + for (; first != last; ++first) + first->block = curParent; +} + +/// Remove this instruction (and its descendants) from its Block and delete +/// all of them. +void Instruction::erase() { + assert(getBlock() && "Instruction has no block"); + getBlock()->getInstructions().erase(this); +} + +/// Unlink this instruction from its current block and insert it right before +/// `existingInst` which may be in the same or another block in the same +/// function. +void Instruction::moveBefore(Instruction *existingInst) { + moveBefore(existingInst->getBlock(), existingInst->getIterator()); +} + +/// Unlink this operation instruction from its current basic block and insert +/// it right before `iterator` in the specified basic block. +void Instruction::moveBefore(Block *block, + llvm::iplist::iterator iterator) { + block->getInstructions().splice(iterator, getBlock()->getInstructions(), + getIterator()); +} + +/// This drops all operand uses from this instruction, which is an essential +/// step in breaking cyclic dependences between references when they are to +/// be deleted. +void Instruction::dropAllReferences() { + for (auto &op : getInstOperands()) + op.drop(); + + if (isTerminator()) + for (auto &dest : cast(this)->getBlockOperands()) + dest.drop(); +} + +//===----------------------------------------------------------------------===// +// OperationInst +//===----------------------------------------------------------------------===// + +/// Create a new OperationInst with the specific fields. +OperationInst *OperationInst::create(Location location, OperationName name, + ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes, + ArrayRef successors, + MLIRContext *context) { + unsigned numSuccessors = successors.size(); + + // Input operands are nullptr-separated for each successors in the case of + // terminators, the nullptr aren't actually stored. + unsigned numOperands = operands.size() - numSuccessors; + + auto byteSize = + totalSizeToAlloc( + resultTypes.size(), numSuccessors, numSuccessors, numOperands); + void *rawMem = malloc(byteSize); + + // Initialize the OperationInst part of the instruction. + auto inst = ::new (rawMem) + OperationInst(location, name, numOperands, resultTypes.size(), + numSuccessors, attributes, context); + + // Initialize the results and operands. + auto instResults = inst->getInstResults(); + for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) + new (&instResults[i]) InstResult(resultTypes[i], inst); + + auto InstOperands = inst->getInstOperands(); + + // Initialize normal operands. + unsigned operandIt = 0, operandE = operands.size(); + unsigned nextOperand = 0; + for (; operandIt != operandE; ++operandIt) { + // Null operands are used as sentinals between successor operand lists. 
If
+    // we encounter one here, break and handle the successor operand lists
+    // separately below.
+    if (!operands[operandIt])
+      break;
+    new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
+  }
+
+  unsigned currentSuccNum = 0;
+  if (operandIt == operandE) {
+    // Verify that the number of sentinel operands matches the number of
+    // successors.
+    assert(currentSuccNum == numSuccessors);
+    return inst;
+  }
+
+  assert(inst->isTerminator() &&
+         "Sentinel operand found in non-terminator operand list.");
+  auto instBlockOperands = inst->getBlockOperands();
+  unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
+  unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
+  (void)succOperandCountE;
+
+  for (; operandIt != operandE; ++operandIt) {
+    // If we encounter a sentinel, move on to the next successor's operand
+    // list and update the count variable.
+    if (!operands[operandIt]) {
+      assert(currentSuccNum < numSuccessors);
+
+      // After the first iteration update the successor operand count
+      // variable.
+      if (currentSuccNum != 0) {
+        ++succOperandCountIt;
+        assert(succOperandCountIt != succOperandCountE &&
+               "More sentinel operands than successors.");
+      }
+
+      new (&instBlockOperands[currentSuccNum])
+          BlockOperand(inst, successors[currentSuccNum]);
+      *succOperandCountIt = 0;
+      ++currentSuccNum;
+      continue;
+    }
+    new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
+    ++(*succOperandCountIt);
+  }
+
+  // Verify that the number of sentinel operands matches the number of
+  // successors.
+  assert(currentSuccNum == numSuccessors);
+
+  return inst;
+}
+
+OperationInst::OperationInst(Location location, OperationName name,
+                             unsigned numOperands, unsigned numResults,
+                             unsigned numSuccessors,
+                             ArrayRef<NamedAttribute> attributes,
+                             MLIRContext *context)
+    : Instruction(Kind::OperationInst, location), numOperands(numOperands),
+      numResults(numResults), numSuccs(numSuccessors), name(name) {
+#ifndef NDEBUG
+  for (auto elt : attributes)
+    assert(elt.second != nullptr && "Attributes cannot have null entries");
+#endif
+
+  this->attrs = AttributeListStorage::get(attributes, context);
+}
+
+OperationInst::~OperationInst() {
+  // Explicitly run the destructors for the operands and results.
+  for (auto &operand : getInstOperands())
+    operand.~InstOperand();
+
+  for (auto &result : getInstResults())
+    result.~InstResult();
+
+  // Explicitly run the destructors for the successors.
+  if (isTerminator())
+    for (auto &successor : getBlockOperands())
+      successor.~BlockOperand();
+}
+
+/// Return true if there are no users of any results of this operation.
+bool OperationInst::use_empty() const {
+  for (auto *result : getResults())
+    if (!result->use_empty())
+      return false;
+  return true;
+}
+
+ArrayRef<NamedAttribute> OperationInst::getAttrs() const {
+  if (!attrs)
+    return {};
+  return attrs->getElements();
+}
+
+void OperationInst::destroy() {
+  this->~OperationInst();
+  free(this);
+}
+
+/// Return the context this operation is associated with.
+MLIRContext *OperationInst::getContext() const {
+  // If we have a result or operand type, that is a constant time way to get
+  // to the context.
+  if (getNumResults())
+    return getResult(0)->getType().getContext();
+  if (getNumOperands())
+    return getOperand(0)->getType().getContext();
+
+  // In the very odd case where we have no operands or results, fall back to
+  // doing a find.
+ return getFunction()->getContext(); +} + +bool OperationInst::isReturn() const { return isa(); } + +void OperationInst::setSuccessor(Block *block, unsigned index) { + assert(index < getNumSuccessors()); + getBlockOperands()[index].set(block); +} + +void OperationInst::eraseOperand(unsigned index) { + assert(index < getNumOperands()); + auto Operands = getInstOperands(); + // Shift all operands down by 1. + std::rotate(&Operands[index], &Operands[index + 1], + &Operands[numOperands - 1]); + --numOperands; + Operands[getNumOperands()].~InstOperand(); +} + +auto OperationInst::getSuccessorOperands(unsigned index) const + -> llvm::iterator_range { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {const_operand_iterator(this, succOperandIndex), + const_operand_iterator(this, succOperandIndex + + getNumSuccessorOperands(index))}; +} +auto OperationInst::getSuccessorOperands(unsigned index) + -> llvm::iterator_range { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {operand_iterator(this, succOperandIndex), + operand_iterator(this, + succOperandIndex + getNumSuccessorOperands(index))}; +} + +/// If an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +void OperationInst::setAttr(Identifier name, Attribute value) { + assert(value && "attributes may never be null"); + auto origAttrs = getAttrs(); + + SmallVector newAttrs(origAttrs.begin(), origAttrs.end()); + auto *context = getContext(); + + // If we already have this attribute, replace it. + for (auto &elt : newAttrs) + if (elt.first == name) { + elt.second = value; + attrs = AttributeListStorage::get(newAttrs, context); + return; + } + + // Otherwise, add it. + newAttrs.push_back({name, value}); + attrs = AttributeListStorage::get(newAttrs, context); +} + +/// Remove the attribute with the specified name if it exists. The return +/// value indicates whether the attribute was present or not. +auto OperationInst::removeAttr(Identifier name) -> RemoveResult { + auto origAttrs = getAttrs(); + for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { + if (origAttrs[i].first == name) { + SmallVector newAttrs; + newAttrs.reserve(origAttrs.size() - 1); + newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); + newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); + attrs = AttributeListStorage::get(newAttrs, getContext()); + return RemoveResult::Removed; + } + } + return RemoveResult::NotFound; +} + +/// Attempt to constant fold this operation with the specified constant +/// operand values. If successful, this returns false and fills in the +/// results vector. If not, this returns true and results is unspecified. +bool OperationInst::constantFold(ArrayRef operands, + SmallVectorImpl &results) const { + if (auto *abstractOp = getAbstractOperation()) { + // If we have a registered operation definition matching this one, use it to + // try to constant fold the operation. + if (!abstractOp->constantFoldHook(llvm::cast(this), operands, + results)) + return false; + + // Otherwise, fall back on the dialect hook to handle it. + return abstractOp->dialect.constantFoldHook(llvm::cast(this), + operands, results); + } + + // If this operation hasn't been registered or doesn't have abstract + // operation, fall back to a dialect which matches the prefix. 
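+  // Op names follow the "dialect.opname" convention, so a hypothetical
+  // unregistered "foo.some_op" would be folded through the hook of the
+  // dialect registered for the "foo" prefix. A sketch of the contract a
+  // caller sees (names here are illustrative only, not part of this file):
+  //
+  //   SmallVector<Attribute, 1> folded;
+  //   if (!someOpInst->constantFold(operandConstants, folded)) {
+  //     // Folding succeeded; folded[i] is the constant for result i.
+  //   }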
+ auto opName = getName().getStringRef(); + if (auto *dialect = getContext()->getRegisteredDialect(opName)) { + return dialect->constantFoldHook(llvm::cast(this), operands, + results); + } + + return true; +} + +/// Emit an error with the op name prefixed, like "'dim' op " which is +/// convenient for verifiers. +bool OperationInst::emitOpError(const Twine &message) const { + return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); +} + +//===----------------------------------------------------------------------===// +// ForInst +//===----------------------------------------------------------------------===// + +ForInst *ForInst::create(Location location, ArrayRef lbOperands, + AffineMap lbMap, ArrayRef ubOperands, + AffineMap ubMap, int64_t step) { + assert(lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert(ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + unsigned numOperands = lbOperands.size() + ubOperands.size(); + ForInst *inst = new ForInst(location, numOperands, lbMap, ubMap, step); + + unsigned i = 0; + for (unsigned e = lbOperands.size(); i != e; ++i) + inst->operands.emplace_back(InstOperand(inst, lbOperands[i])); + + for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) + inst->operands.emplace_back(InstOperand(inst, ubOperands[j])); + + return inst; +} + +ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap, + AffineMap ubMap, int64_t step) + : Instruction(Instruction::Kind::For, location), + Value(Value::Kind::ForInst, + Type::getIndex(lbMap.getResult(0).getContext())), + body(this), lbMap(lbMap), ubMap(ubMap), step(step) { + + // The body of a for inst always has one block. 
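+  // Operands, reserved below, are laid out with the lower bound operands
+  // first and the upper bound operands after them (see create() above). For
+  // a hypothetical loop with lbOperands = {%a, %b} and ubOperands = {%c},
+  // getInstOperands() is [%a, %b, %c]; the bound accessors below slice this
+  // list at lbMap.getNumInputs().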
+ body.push_back(new Block()); + operands.reserve(numOperands); +} + +const AffineBound ForInst::getLowerBound() const { + return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); +} + +const AffineBound ForInst::getUpperBound() const { + return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); +} + +void ForInst::setLowerBound(ArrayRef lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector ubOperands(getUpperBoundOperands()); + + operands.clear(); + operands.reserve(lbOperands.size() + ubMap.getNumInputs()); + for (auto *operand : lbOperands) { + operands.emplace_back(InstOperand(this, operand)); + } + for (auto *operand : ubOperands) { + operands.emplace_back(InstOperand(this, operand)); + } + this->lbMap = map; +} + +void ForInst::setUpperBound(ArrayRef ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector lbOperands(getLowerBoundOperands()); + + operands.clear(); + operands.reserve(lbOperands.size() + ubOperands.size()); + for (auto *operand : lbOperands) { + operands.emplace_back(InstOperand(this, operand)); + } + for (auto *operand : ubOperands) { + operands.emplace_back(InstOperand(this, operand)); + } + this->ubMap = map; +} + +void ForInst::setLowerBoundMap(AffineMap map) { + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + this->lbMap = map; +} + +void ForInst::setUpperBoundMap(AffineMap map) { + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + this->ubMap = map; +} + +bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } + +bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } + +int64_t ForInst::getConstantLowerBound() const { + return lbMap.getSingleConstantResult(); +} + +int64_t ForInst::getConstantUpperBound() const { + return ubMap.getSingleConstantResult(); +} + +void ForInst::setConstantLowerBound(int64_t value) { + setLowerBound({}, AffineMap::getConstantMap(value, getContext())); +} + +void ForInst::setConstantUpperBound(int64_t value) { + setUpperBound({}, AffineMap::getConstantMap(value, getContext())); +} + +ForInst::operand_range ForInst::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +ForInst::const_operand_range ForInst::getLowerBoundOperands() const { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +ForInst::operand_range ForInst::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +ForInst::const_operand_range ForInst::getUpperBoundOperands() const { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool ForInst::matchingBoundOperandList() const { + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. 
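+    // Operand i is the i-th lower bound operand and operand
+    // (numOperands + i) is the corresponding upper bound operand, so this
+    // requires both bounds to be fed by the same values, position by
+    // position.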
+ if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +//===----------------------------------------------------------------------===// +// IfInst +//===----------------------------------------------------------------------===// + +IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) + : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), + set(set) { + operands.reserve(numOperands); + + // The then of an 'if' inst always has one block. + thenClause.push_back(new Block()); +} + +IfInst::~IfInst() { + if (elseClause) + delete elseClause; + + // An IfInst's IntegerSet 'set' should not be deleted since it is + // allocated through MLIRContext's bump pointer allocator. +} + +IfInst *IfInst::create(Location location, ArrayRef operands, + IntegerSet set) { + unsigned numOperands = operands.size(); + assert(numOperands == set.getNumOperands() && + "operand cound does not match the integer set operand count"); + + IfInst *inst = new IfInst(location, numOperands, set); + + for (auto *op : operands) + inst->operands.emplace_back(InstOperand(inst, op)); + + return inst; +} + +const AffineCondition IfInst::getCondition() const { + return AffineCondition(*this, set); +} + +MLIRContext *IfInst::getContext() const { + // Check for degenerate case of if instruction with no operands. + // This is unlikely, but legal. + if (operands.empty()) + return getFunction()->getContext(); + + return getOperand(0)->getType().getContext(); +} + +//===----------------------------------------------------------------------===// +// Instruction Cloning +//===----------------------------------------------------------------------===// + +/// Create a deep copy of this instruction, remapping any operands that use +/// values outside of the instruction using the map that is provided (leaving +/// them alone if no entry is present). Replaces references to cloned +/// sub-instructions to the corresponding instruction that is copied, and adds +/// those mappings to the map. +Instruction *Instruction::clone(DenseMap &operandMap, + MLIRContext *context) const { + // If the specified value is in operandMap, return the remapped value. + // Otherwise return the value itself. + auto remapOperand = [&](const Value *value) -> Value * { + auto it = operandMap.find(value); + return it != operandMap.end() ? it->second : const_cast(value); + }; + + SmallVector operands; + SmallVector successors; + if (auto *opInst = dyn_cast(this)) { + operands.reserve(getNumOperands() + opInst->getNumSuccessors()); + + if (!opInst->isTerminator()) { + // Non-terminators just add all the operands. + for (auto *opValue : getOperands()) + operands.push_back(remapOperand(opValue)); + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = opInst->getNumSuccessors() + ? opInst->getSuccessorOperandIndex(0) + : opInst->getNumOperands(); + auto InstOperands = opInst->getInstOperands(); + + unsigned i = 0; + for (; i != firstSuccOperand; ++i) + operands.push_back(remapOperand(InstOperands[i].get())); + + successors.reserve(opInst->getNumSuccessors()); + for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; + ++succ) { + successors.push_back(const_cast(opInst->getSuccessor(succ))); + + // Add sentinel to delineate successor operands. + operands.push_back(nullptr); + + // Remap the successors operands. 
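+        // The operand vector being built here mirrors the encoding consumed
+        // by OperationInst::create() above, e.g. with two successors:
+        //   [normal operands..., nullptr, succ0 operands...,
+        //    nullptr, succ1 operands...]
+        // where each nullptr sentinel starts the next successor's list.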
+ for (auto *operand : opInst->getSuccessorOperands(succ)) + operands.push_back(remapOperand(operand)); + } + } + + SmallVector resultTypes; + resultTypes.reserve(opInst->getNumResults()); + for (auto *result : opInst->getResults()) + resultTypes.push_back(result->getType()); + auto *newOp = OperationInst::create(getLoc(), opInst->getName(), operands, + resultTypes, opInst->getAttrs(), + successors, context); + // Remember the mapping of any results. + for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) + operandMap[opInst->getResult(i)] = newOp->getResult(i); + return newOp; + } + + operands.reserve(getNumOperands()); + for (auto *opValue : getOperands()) + operands.push_back(remapOperand(opValue)); + + if (auto *forInst = dyn_cast(this)) { + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); + + auto *newFor = ForInst::create( + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), + ubMap, forInst->getStep()); + + // Remember the induction variable mapping. + operandMap[forInst] = newFor; + + // Recursively clone the body of the for loop. + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(operandMap, context)); + + return newFor; + } + + // Otherwise, we must have an If instruction. + auto *ifInst = cast(this); + auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); + + auto *resultThen = newIf->getThen(); + for (auto &childInst : *ifInst->getThen()) + resultThen->push_back(childInst.clone(operandMap, context)); + + if (ifInst->hasElse()) { + auto *resultElse = newIf->createElse(); + for (auto &childInst : *ifInst->getElse()) + resultElse->push_back(childInst.clone(operandMap, context)); + } + + return newIf; +} + +Instruction *Instruction::clone(MLIRContext *context) const { + DenseMap operandMap; + return clone(operandMap, context); +} diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index ccd7d65f7c8..9cd4355e4aa 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -17,10 +17,10 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Statements.h" using namespace mlir; /// Form the OperationName for an op with the specified string. 
This either is @@ -279,7 +279,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (op->getFunction()->isML()) { Block *block = op->getBlock(); if (!block || block->getContainingInst() || &block->back() != op) - return op->emitOpError("must be the last statement in the ML function"); + return op->emitOpError("must be the last instruction in the ML function"); } else { const Block *block = op->getBlock(); if (!block || &block->back() != op) diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 8c41d488a8b..90d768c844e 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -16,7 +16,7 @@ // ============================================================================= #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/Value.h" using namespace mlir; diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp deleted file mode 100644 index 6bd9944bb65..00000000000 --- a/mlir/lib/IR/Statement.cpp +++ /dev/null @@ -1,826 +0,0 @@ -//===- Statement.cpp - MLIR Statement Classes ----------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "AttributeListStorage.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" -#include "llvm/ADT/DenseMap.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// InstResult -//===----------------------------------------------------------------------===// - -/// Return the result number of this result. -unsigned InstResult::getResultNumber() const { - // Results are always stored consecutively, so use pointer subtraction to - // figure out what number this is. - return this - &getOwner()->getInstResults()[0]; -} - -//===----------------------------------------------------------------------===// -// InstOperand -//===----------------------------------------------------------------------===// - -/// Return which operand this is in the operand list. -template <> unsigned InstOperand::getOperandNumber() const { - return this - &getOwner()->getInstOperands()[0]; -} - -/// Return which operand this is in the operand list. -template <> unsigned BlockOperand::getOperandNumber() const { - return this - &getOwner()->getBlockOperands()[0]; -} - -//===----------------------------------------------------------------------===// -// Statement -//===----------------------------------------------------------------------===// - -// Statements are deleted through the destroy() member because we don't have -// a virtual destructor. 
-Statement::~Statement() { - assert(block == nullptr && "statement destroyed but still in a block"); -} - -/// Destroy this statement or one of its subclasses. -void Statement::destroy() { - switch (this->getKind()) { - case Kind::OperationInst: - cast(this)->destroy(); - break; - case Kind::For: - delete cast(this); - break; - case Kind::If: - delete cast(this); - break; - } -} - -Statement *Statement::getParentStmt() const { - return block ? block->getContainingInst() : nullptr; -} - -Function *Statement::getFunction() const { - return block ? block->getFunction() : nullptr; -} - -Value *Statement::getOperand(unsigned idx) { return getInstOperand(idx).get(); } - -const Value *Statement::getOperand(unsigned idx) const { - return getInstOperand(idx).get(); -} - -// Value can be used as a dimension id if it is valid as a symbol, or -// it is an induction variable, or it is a result of affine apply operation -// with dimension id arguments. -bool Value::isValidDim() const { - if (auto *stmt = getDefiningInst()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->isa()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->dyn_cast()) - return op->isValidDim(); - return false; - } - // This value is either a function argument or an induction variable. Both - // are ok. - return true; -} - -// Value can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments. -bool Value::isValidSymbol() const { - if (auto *stmt = getDefiningInst()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->isa()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->dyn_cast()) - return op->isValidSymbol(); - return false; - } - // This value is either a function argument or an induction variable. - // Function argument is ok, induction variable is not. - return isa(this); -} - -void Statement::setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); -} - -unsigned Statement::getNumOperands() const { - switch (getKind()) { - case Kind::OperationInst: - return cast(this)->getNumOperands(); - case Kind::For: - return cast(this)->getNumOperands(); - case Kind::If: - return cast(this)->getNumOperands(); - } -} - -MutableArrayRef Statement::getInstOperands() { - switch (getKind()) { - case Kind::OperationInst: - return cast(this)->getInstOperands(); - case Kind::For: - return cast(this)->getInstOperands(); - case Kind::If: - return cast(this)->getInstOperands(); - } -} - -/// Emit a note about this statement, reporting up to any diagnostic -/// handlers that may be listening. -void Statement::emitNote(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Note); -} - -/// Emit a warning about this statement, reporting up to any diagnostic -/// handlers that may be listening. -void Statement::emitWarning(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Warning); -} - -/// Emit an error about fatal conditions with this operation, reporting up to -/// any diagnostic handlers that may be listening. This function always -/// returns true. NOTE: This may terminate the containing application, only -/// use when the IR is in an inconsistent state. 
-bool Statement::emitError(const Twine &message) const { - return getContext()->emitError(getLoc(), message); -} - -// Returns whether the Statement is a terminator. -bool Statement::isTerminator() const { - if (auto *op = dyn_cast(this)) - return op->isTerminator(); - return false; -} - -//===----------------------------------------------------------------------===// -// ilist_traits for Statement -//===----------------------------------------------------------------------===// - -void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) { - stmt->destroy(); -} - -Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() { - size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); - iplist *Anchor(static_cast *>(this)); - return reinterpret_cast(reinterpret_cast(Anchor) - Offset); -} - -/// This is a trait method invoked when a statement is added to a block. We -/// keep the block pointer up to date. -void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) { - assert(!stmt->getBlock() && "already in a statement block!"); - stmt->block = getContainingBlock(); -} - -/// This is a trait method invoked when a statement is removed from a block. -/// We keep the block pointer up to date. -void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList( - Statement *stmt) { - assert(stmt->block && "not already in a statement block!"); - stmt->block = nullptr; -} - -/// This is a trait method invoked when a statement is moved from one block -/// to another. We keep the block pointer up to date. -void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( - ilist_traits &otherList, stmt_iterator first, - stmt_iterator last) { - // If we are transferring statements within the same block, the block - // pointer doesn't need to be updated. - Block *curParent = getContainingBlock(); - if (curParent == otherList.getContainingBlock()) - return; - - // Update the 'block' member of each statement. - for (; first != last; ++first) - first->block = curParent; -} - -/// Remove this statement (and its descendants) from its Block and delete -/// all of them. -void Statement::erase() { - assert(getBlock() && "Statement has no block"); - getBlock()->getInstructions().erase(this); -} - -/// Unlink this statement from its current block and insert it right before -/// `existingStmt` which may be in the same or another block in the same -/// function. -void Statement::moveBefore(Statement *existingStmt) { - moveBefore(existingStmt->getBlock(), existingStmt->getIterator()); -} - -/// Unlink this operation instruction from its current basic block and insert -/// it right before `iterator` in the specified basic block. -void Statement::moveBefore(Block *block, - llvm::iplist::iterator iterator) { - block->getInstructions().splice(iterator, getBlock()->getInstructions(), - getIterator()); -} - -/// This drops all operand uses from this instruction, which is an essential -/// step in breaking cyclic dependences between references when they are to -/// be deleted. -void Statement::dropAllReferences() { - for (auto &op : getInstOperands()) - op.drop(); - - if (isTerminator()) - for (auto &dest : cast(this)->getBlockOperands()) - dest.drop(); -} - -//===----------------------------------------------------------------------===// -// OperationInst -//===----------------------------------------------------------------------===// - -/// Create a new OperationInst with the specific fields. 
-OperationInst *OperationInst::create(Location location, OperationName name, - ArrayRef operands, - ArrayRef resultTypes, - ArrayRef attributes, - ArrayRef successors, - MLIRContext *context) { - unsigned numSuccessors = successors.size(); - - // Input operands are nullptr-separated for each successors in the case of - // terminators, the nullptr aren't actually stored. - unsigned numOperands = operands.size() - numSuccessors; - - auto byteSize = - totalSizeToAlloc( - resultTypes.size(), numSuccessors, numSuccessors, numOperands); - void *rawMem = malloc(byteSize); - - // Initialize the OperationInst part of the statement. - auto stmt = ::new (rawMem) - OperationInst(location, name, numOperands, resultTypes.size(), - numSuccessors, attributes, context); - - // Initialize the results and operands. - auto instResults = stmt->getInstResults(); - for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) - new (&instResults[i]) InstResult(resultTypes[i], stmt); - - auto InstOperands = stmt->getInstOperands(); - - // Initialize normal operands. - unsigned operandIt = 0, operandE = operands.size(); - unsigned nextOperand = 0; - for (; operandIt != operandE; ++operandIt) { - // Null operands are used as sentinals between successor operand lists. If - // we encounter one here, break and handle the successor operands lists - // separately below. - if (!operands[operandIt]) - break; - new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]); - } - - unsigned currentSuccNum = 0; - if (operandIt == operandE) { - // Verify that the amount of sentinal operands is equivalent to the number - // of successors. - assert(currentSuccNum == numSuccessors); - return stmt; - } - - assert(stmt->isTerminator() && - "Sentinal operand found in non terminator operand list."); - auto instBlockOperands = stmt->getBlockOperands(); - unsigned *succOperandCountIt = stmt->getTrailingObjects(); - unsigned *succOperandCountE = succOperandCountIt + numSuccessors; - (void)succOperandCountE; - - for (; operandIt != operandE; ++operandIt) { - // If we encounter a sentinal branch to the next operand update the count - // variable. - if (!operands[operandIt]) { - assert(currentSuccNum < numSuccessors); - - // After the first iteration update the successor operand count - // variable. - if (currentSuccNum != 0) { - ++succOperandCountIt; - assert(succOperandCountIt != succOperandCountE && - "More sentinal operands than successors."); - } - - new (&instBlockOperands[currentSuccNum]) - BlockOperand(stmt, successors[currentSuccNum]); - *succOperandCountIt = 0; - ++currentSuccNum; - continue; - } - new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]); - ++(*succOperandCountIt); - } - - // Verify that the amount of sentinal operands is equivalent to the number of - // successors. - assert(currentSuccNum == numSuccessors); - - return stmt; -} - -OperationInst::OperationInst(Location location, OperationName name, - unsigned numOperands, unsigned numResults, - unsigned numSuccessors, - ArrayRef attributes, - MLIRContext *context) - : Statement(Kind::OperationInst, location), numOperands(numOperands), - numResults(numResults), numSuccs(numSuccessors), name(name) { -#ifndef NDEBUG - for (auto elt : attributes) - assert(elt.second != nullptr && "Attributes cannot have null entries"); -#endif - - this->attrs = AttributeListStorage::get(attributes, context); -} - -OperationInst::~OperationInst() { - // Explicitly run the destructors for the operands and results. 
- for (auto &operand : getInstOperands()) - operand.~InstOperand(); - - for (auto &result : getInstResults()) - result.~InstResult(); - - // Explicitly run the destructors for the successors. - if (isTerminator()) - for (auto &successor : getBlockOperands()) - successor.~BlockOperand(); -} - -/// Return true if there are no users of any results of this operation. -bool OperationInst::use_empty() const { - for (auto *result : getResults()) - if (!result->use_empty()) - return false; - return true; -} - -ArrayRef OperationInst::getAttrs() const { - if (!attrs) - return {}; - return attrs->getElements(); -} - -void OperationInst::destroy() { - this->~OperationInst(); - free(this); -} - -/// Return the context this operation is associated with. -MLIRContext *OperationInst::getContext() const { - // If we have a result or operand type, that is a constant time way to get - // to the context. - if (getNumResults()) - return getResult(0)->getType().getContext(); - if (getNumOperands()) - return getOperand(0)->getType().getContext(); - - // In the very odd case where we have no operands or results, fall back to - // doing a find. - return getFunction()->getContext(); -} - -bool OperationInst::isReturn() const { return isa(); } - -void OperationInst::setSuccessor(Block *block, unsigned index) { - assert(index < getNumSuccessors()); - getBlockOperands()[index].set(block); -} - -void OperationInst::eraseOperand(unsigned index) { - assert(index < getNumOperands()); - auto Operands = getInstOperands(); - // Shift all operands down by 1. - std::rotate(&Operands[index], &Operands[index + 1], - &Operands[numOperands - 1]); - --numOperands; - Operands[getNumOperands()].~InstOperand(); -} - -auto OperationInst::getSuccessorOperands(unsigned index) const - -> llvm::iterator_range { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = getSuccessorOperandIndex(index); - return {const_operand_iterator(this, succOperandIndex), - const_operand_iterator(this, succOperandIndex + - getNumSuccessorOperands(index))}; -} -auto OperationInst::getSuccessorOperands(unsigned index) - -> llvm::iterator_range { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = getSuccessorOperandIndex(index); - return {operand_iterator(this, succOperandIndex), - operand_iterator(this, - succOperandIndex + getNumSuccessorOperands(index))}; -} - -/// If an attribute exists with the specified name, change it to the new -/// value. Otherwise, add a new attribute with the specified name/value. -void OperationInst::setAttr(Identifier name, Attribute value) { - assert(value && "attributes may never be null"); - auto origAttrs = getAttrs(); - - SmallVector newAttrs(origAttrs.begin(), origAttrs.end()); - auto *context = getContext(); - - // If we already have this attribute, replace it. - for (auto &elt : newAttrs) - if (elt.first == name) { - elt.second = value; - attrs = AttributeListStorage::get(newAttrs, context); - return; - } - - // Otherwise, add it. - newAttrs.push_back({name, value}); - attrs = AttributeListStorage::get(newAttrs, context); -} - -/// Remove the attribute with the specified name if it exists. The return -/// value indicates whether the attribute was present or not. 
-auto OperationInst::removeAttr(Identifier name) -> RemoveResult { - auto origAttrs = getAttrs(); - for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { - if (origAttrs[i].first == name) { - SmallVector newAttrs; - newAttrs.reserve(origAttrs.size() - 1); - newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); - newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = AttributeListStorage::get(newAttrs, getContext()); - return RemoveResult::Removed; - } - } - return RemoveResult::NotFound; -} - -/// Attempt to constant fold this operation with the specified constant -/// operand values. If successful, this returns false and fills in the -/// results vector. If not, this returns true and results is unspecified. -bool OperationInst::constantFold(ArrayRef operands, - SmallVectorImpl &results) const { - if (auto *abstractOp = getAbstractOperation()) { - // If we have a registered operation definition matching this one, use it to - // try to constant fold the operation. - if (!abstractOp->constantFoldHook(llvm::cast(this), operands, - results)) - return false; - - // Otherwise, fall back on the dialect hook to handle it. - return abstractOp->dialect.constantFoldHook(llvm::cast(this), - operands, results); - } - - // If this operation hasn't been registered or doesn't have abstract - // operation, fall back to a dialect which matches the prefix. - auto opName = getName().getStringRef(); - if (auto *dialect = getContext()->getRegisteredDialect(opName)) { - return dialect->constantFoldHook(llvm::cast(this), operands, - results); - } - - return true; -} - -/// Emit an error with the op name prefixed, like "'dim' op " which is -/// convenient for verifiers. -bool OperationInst::emitOpError(const Twine &message) const { - return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); -} - -//===----------------------------------------------------------------------===// -// ForStmt -//===----------------------------------------------------------------------===// - -ForStmt *ForStmt::create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { - assert(lbOperands.size() == lbMap.getNumInputs() && - "lower bound operand count does not match the affine map"); - assert(ubOperands.size() == ubMap.getNumInputs() && - "upper bound operand count does not match the affine map"); - assert(step > 0 && "step has to be a positive integer constant"); - - unsigned numOperands = lbOperands.size() + ubOperands.size(); - ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step); - - unsigned i = 0; - for (unsigned e = lbOperands.size(); i != e; ++i) - stmt->operands.emplace_back(InstOperand(stmt, lbOperands[i])); - - for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) - stmt->operands.emplace_back(InstOperand(stmt, ubOperands[j])); - - return stmt; -} - -ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, - AffineMap ubMap, int64_t step) - : Statement(Statement::Kind::For, location), - Value(Value::Kind::ForStmt, - Type::getIndex(lbMap.getResult(0).getContext())), - body(this), lbMap(lbMap), ubMap(ubMap), step(step) { - - // The body of a for stmt always has one block. 
- body.push_back(new Block()); - operands.reserve(numOperands); -} - -const AffineBound ForStmt::getLowerBound() const { - return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); -} - -const AffineBound ForStmt::getUpperBound() const { - return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); -} - -void ForStmt::setLowerBound(ArrayRef lbOperands, AffineMap map) { - assert(lbOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector ubOperands(getUpperBoundOperands()); - - operands.clear(); - operands.reserve(lbOperands.size() + ubMap.getNumInputs()); - for (auto *operand : lbOperands) { - operands.emplace_back(InstOperand(this, operand)); - } - for (auto *operand : ubOperands) { - operands.emplace_back(InstOperand(this, operand)); - } - this->lbMap = map; -} - -void ForStmt::setUpperBound(ArrayRef ubOperands, AffineMap map) { - assert(ubOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector lbOperands(getLowerBoundOperands()); - - operands.clear(); - operands.reserve(lbOperands.size() + ubOperands.size()); - for (auto *operand : lbOperands) { - operands.emplace_back(InstOperand(this, operand)); - } - for (auto *operand : ubOperands) { - operands.emplace_back(InstOperand(this, operand)); - } - this->ubMap = map; -} - -void ForStmt::setLowerBoundMap(AffineMap map) { - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->lbMap = map; -} - -void ForStmt::setUpperBoundMap(AffineMap map) { - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->ubMap = map; -} - -bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } - -bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } - -int64_t ForStmt::getConstantLowerBound() const { - return lbMap.getSingleConstantResult(); -} - -int64_t ForStmt::getConstantUpperBound() const { - return ubMap.getSingleConstantResult(); -} - -void ForStmt::setConstantLowerBound(int64_t value) { - setLowerBound({}, AffineMap::getConstantMap(value, getContext())); -} - -void ForStmt::setConstantUpperBound(int64_t value) { - setUpperBound({}, AffineMap::getConstantMap(value, getContext())); -} - -ForStmt::operand_range ForStmt::getLowerBoundOperands() { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForStmt::const_operand_range ForStmt::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForStmt::operand_range ForStmt::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -ForStmt::const_operand_range ForStmt::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool ForStmt::matchingBoundOperandList() const { - if (lbMap.getNumDims() != ubMap.getNumDims() || - lbMap.getNumSymbols() != ubMap.getNumSymbols()) - return false; - - unsigned numOperands = lbMap.getNumInputs(); - for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. 
- if (getOperand(i) != getOperand(numOperands + i)) - return false; - } - return true; -} - -//===----------------------------------------------------------------------===// -// IfStmt -//===----------------------------------------------------------------------===// - -IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set) - : Statement(Kind::If, location), thenClause(this), elseClause(nullptr), - set(set) { - operands.reserve(numOperands); - - // The then of an 'if' stmt always has one block. - thenClause.push_back(new Block()); -} - -IfStmt::~IfStmt() { - if (elseClause) - delete elseClause; - - // An IfStmt's IntegerSet 'set' should not be deleted since it is - // allocated through MLIRContext's bump pointer allocator. -} - -IfStmt *IfStmt::create(Location location, ArrayRef operands, - IntegerSet set) { - unsigned numOperands = operands.size(); - assert(numOperands == set.getNumOperands() && - "operand cound does not match the integer set operand count"); - - IfStmt *stmt = new IfStmt(location, numOperands, set); - - for (auto *op : operands) - stmt->operands.emplace_back(InstOperand(stmt, op)); - - return stmt; -} - -const AffineCondition IfStmt::getCondition() const { - return AffineCondition(*this, set); -} - -MLIRContext *IfStmt::getContext() const { - // Check for degenerate case of if statement with no operands. - // This is unlikely, but legal. - if (operands.empty()) - return getFunction()->getContext(); - - return getOperand(0)->getType().getContext(); -} - -//===----------------------------------------------------------------------===// -// Statement Cloning -//===----------------------------------------------------------------------===// - -/// Create a deep copy of this statement, remapping any operands that use -/// values outside of the statement using the map that is provided (leaving -/// them alone if no entry is present). Replaces references to cloned -/// sub-statements to the corresponding statement that is copied, and adds -/// those mappings to the map. -Statement *Statement::clone(DenseMap &operandMap, - MLIRContext *context) const { - // If the specified value is in operandMap, return the remapped value. - // Otherwise return the value itself. - auto remapOperand = [&](const Value *value) -> Value * { - auto it = operandMap.find(value); - return it != operandMap.end() ? it->second : const_cast(value); - }; - - SmallVector operands; - SmallVector successors; - if (auto *opStmt = dyn_cast(this)) { - operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); - - if (!opStmt->isTerminator()) { - // Non-terminators just add all the operands. - for (auto *opValue : getOperands()) - operands.push_back(remapOperand(opValue)); - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = opStmt->getNumSuccessors() - ? opStmt->getSuccessorOperandIndex(0) - : opStmt->getNumOperands(); - auto InstOperands = opStmt->getInstOperands(); - - unsigned i = 0; - for (; i != firstSuccOperand; ++i) - operands.push_back(remapOperand(InstOperands[i].get())); - - successors.reserve(opStmt->getNumSuccessors()); - for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e; - ++succ) { - successors.push_back(const_cast(opStmt->getSuccessor(succ))); - - // Add sentinel to delineate successor operands. - operands.push_back(nullptr); - - // Remap the successors operands. 
- for (auto *operand : opStmt->getSuccessorOperands(succ)) - operands.push_back(remapOperand(operand)); - } - } - - SmallVector resultTypes; - resultTypes.reserve(opStmt->getNumResults()); - for (auto *result : opStmt->getResults()) - resultTypes.push_back(result->getType()); - auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands, - resultTypes, opStmt->getAttrs(), - successors, context); - // Remember the mapping of any results. - for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i) - operandMap[opStmt->getResult(i)] = newOp->getResult(i); - return newOp; - } - - operands.reserve(getNumOperands()); - for (auto *opValue : getOperands()) - operands.push_back(remapOperand(opValue)); - - if (auto *forStmt = dyn_cast(this)) { - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); - - auto *newFor = ForStmt::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), - ubMap, forStmt->getStep()); - - // Remember the induction variable mapping. - operandMap[forStmt] = newFor; - - // Recursively clone the body of the for loop. - for (auto &subStmt : *forStmt->getBody()) - newFor->getBody()->push_back(subStmt.clone(operandMap, context)); - - return newFor; - } - - // Otherwise, we must have an If statement. - auto *ifStmt = cast(this); - auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet()); - - auto *resultThen = newIf->getThen(); - for (auto &childStmt : *ifStmt->getThen()) - resultThen->push_back(childStmt.clone(operandMap, context)); - - if (ifStmt->hasElse()) { - auto *resultElse = newIf->createElse(); - for (auto &childStmt : *ifStmt->getElse()) - resultElse->push_back(childStmt.clone(operandMap, context)); - } - - return newIf; -} - -Statement *Statement::clone(MLIRContext *context) const { - DenseMap operandMap; - return clone(operandMap, context); -} diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index c7a5e42dd99..a213f05a932 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -17,7 +17,7 @@ #include "mlir/IR/Value.h" #include "mlir/IR/Function.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" using namespace mlir; /// If this value is the result of an Instruction, return the instruction @@ -35,8 +35,8 @@ Function *Value::getFunction() { return cast(this)->getFunction(); case Value::Kind::InstResult: return getDefiningInst()->getFunction(); - case Value::Kind::ForStmt: - return cast(this)->getFunction(); + case Value::Kind::ForInst: + return cast(this)->getFunction(); } } @@ -59,10 +59,10 @@ MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { case Kind::OperationInst: return cast(this)->getContext(); - case Kind::ForStmt: - return cast(this)->getContext(); - case Kind::IfStmt: - return cast(this)->getContext(); + case Kind::ForInst: + return cast(this)->getContext(); + case Kind::IfInst: + return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 6cc1aba72b3..3f05a4a145a 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -26,12 +26,12 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include 
"mlir/Support/STLExtras.h" #include "mlir/Transforms/Utils.h" @@ -2071,7 +2071,7 @@ FunctionParser::~FunctionParser() { } } -/// Parse a SSA operand for an instruction or statement. +/// Parse a SSA operand for an instruction or instruction. /// /// ssa-use ::= ssa-id /// @@ -2716,7 +2716,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() { /// Basic block declaration. /// -/// basic-block ::= bb-label instruction* terminator-stmt +/// basic-block ::= bb-label instruction* terminator-inst /// bb-label ::= bb-id bb-arg-list? `:` /// bb-id ::= bare-id /// bb-arg-list ::= `(` ssa-id-and-type-list? `)` @@ -2786,16 +2786,16 @@ private: /// more specific builder type. FuncBuilder builder; - ParseResult parseForStmt(); + ParseResult parseForInst(); ParseResult parseIntConstant(int64_t &val); ParseResult parseDimAndSymbolList(SmallVectorImpl &operands, unsigned numDims, unsigned numOperands, const char *affineStructName); ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); - ParseResult parseIfStmt(); + ParseResult parseIfInst(); ParseResult parseElseClause(Block *elseClause); - ParseResult parseStatements(Block *block); + ParseResult parseInstructions(Block *block); ParseResult parseBlock(Block *block); bool parseSuccessorAndUseList(Block *&dest, @@ -2809,19 +2809,19 @@ private: ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); - // Parse statements in this function. + // Parse instructions in this function. if (parseBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); } -/// For statement. +/// For instruction. /// -/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound -/// (`step` integer-literal)? `{` ml-stmt* `}` +/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound +/// (`step` integer-literal)? `{` ml-inst* `}` /// -ParseResult MLFunctionParser::parseForStmt() { +ParseResult MLFunctionParser::parseForInst() { consumeToken(Token::kw_for); // Parse induction variable. @@ -2862,23 +2862,23 @@ ParseResult MLFunctionParser::parseForStmt() { return emitError("step has to be a positive integer"); } - // Create for statement. - ForStmt *forStmt = + // Create for instruction. + ForInst *forInst = builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap, ubOperands, ubMap, step); // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, forStmt)) + if (addDefinition({inductionVariableName, 0, loc}, forInst)) return ParseFailure; - // If parsing of the for statement body fails, - // MLIR contains for statement with those nested statements that have been + // If parsing of the for instruction body fails, + // MLIR contains for instruction with those nested instructions that have been // successfully parsed. - if (parseBlock(forStmt->getBody())) + if (parseBlock(forInst->getBody())) return ParseFailure; // Reset insertion point to the current block. - builder.setInsertionPointToEnd(forStmt->getBlock()); + builder.setInsertionPointToEnd(forInst->getBlock()); return ParseSuccess; } @@ -3007,7 +3007,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, // Create an identity map using dim id for an induction variable and // symbol otherwise. This representation is optimized for storage. // Analysis passes may expand it into a multi-dimensional map if desired. 
- if (isa(operands[0])) + if (isa(operands[0])) map = builder.getDimIdentityMap(); else map = builder.getSymbolIdentityMap(); @@ -3095,14 +3095,14 @@ IntegerSet Parser::parseIntegerSetInline() { return set; } -/// If statement. +/// If instruction. /// -/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}` -/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}` -/// ml-if-stmt ::= ml-if-head -/// | ml-if-head `else` `{` ml-stmt* `}` +/// ml-if-head ::= `if` ml-if-cond `{` ml-inst* `}` +/// | ml-if-head `else` `if` ml-if-cond `{` ml-inst* `}` +/// ml-if-inst ::= ml-if-head +/// | ml-if-head `else` `{` ml-inst* `}` /// -ParseResult MLFunctionParser::parseIfStmt() { +ParseResult MLFunctionParser::parseIfInst() { auto loc = getToken().getLoc(); consumeToken(Token::kw_if); @@ -3115,25 +3115,25 @@ ParseResult MLFunctionParser::parseIfStmt() { "integer set")) return ParseFailure; - IfStmt *ifStmt = + IfInst *ifInst = builder.createIf(getEncodedSourceLocation(loc), operands, set); - Block *thenClause = ifStmt->getThen(); + Block *thenClause = ifInst->getThen(); - // When parsing of an if statement body fails, the IR contains - // the if statement with the portion of the body that has been + // When parsing of an if instruction body fails, the IR contains + // the if instruction with the portion of the body that has been // successfully parsed. if (parseBlock(thenClause)) return ParseFailure; if (consumeIf(Token::kw_else)) { - auto *elseClause = ifStmt->createElse(); + auto *elseClause = ifInst->createElse(); if (parseElseClause(elseClause)) return ParseFailure; } // Reset insertion point to the current block. - builder.setInsertionPointToEnd(ifStmt->getBlock()); + builder.setInsertionPointToEnd(ifInst->getBlock()); return ParseSuccess; } @@ -3141,25 +3141,25 @@ ParseResult MLFunctionParser::parseIfStmt() { ParseResult MLFunctionParser::parseElseClause(Block *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); - return parseIfStmt(); + return parseIfInst(); } return parseBlock(elseClause); } /// -/// Parse a list of statements ending with `return` or `}` +/// Parse a list of instructions ending with `return` or `}` /// -ParseResult MLFunctionParser::parseStatements(Block *block) { +ParseResult MLFunctionParser::parseInstructions(Block *block) { auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; builder.setInsertionPointToEnd(block); - // Parse statements till we see '}' or 'return'. - // Return statement is parsed separately to emit a more intuitive error - // when '}' is missing after the return statement. + // Parse instructions till we see '}' or 'return'. + // Return instruction is parsed separately to emit a more intuitive error + // when '}' is missing after the return instruction. while (getToken().isNot(Token::r_brace, Token::kw_return)) { switch (getToken().getKind()) { default: @@ -3167,17 +3167,17 @@ ParseResult MLFunctionParser::parseStatements(Block *block) { return ParseFailure; break; case Token::kw_for: - if (parseForStmt()) + if (parseForInst()) return ParseFailure; break; case Token::kw_if: - if (parseIfStmt()) + if (parseIfInst()) return ParseFailure; break; } // end switch } - // Parse the return statement. + // Parse the return instruction. 
if (getToken().is(Token::kw_return)) if (parseOperation(createOpFunc)) return ParseFailure; @@ -3186,12 +3186,12 @@ ParseResult MLFunctionParser::parseStatements(Block *block) { } /// -/// Parse `{` ml-stmt* `}` +/// Parse `{` ml-inst* `}` /// ParseResult MLFunctionParser::parseBlock(Block *block) { - if (parseToken(Token::l_brace, "expected '{' before statement list") || - parseStatements(block) || - parseToken(Token::r_brace, "expected '}' after statement list")) + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseInstructions(block) || + parseToken(Token::r_brace, "expected '}' after instruction list")) return ParseFailure; return ParseSuccess; @@ -3429,7 +3429,7 @@ ParseResult ModuleParser::parseCFGFunc() { /// ML function declarations. /// /// ml-func ::= `mlfunc` ml-func-signature -/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt +/// (`attributes` attribute-dict)? `{` ml-inst* ml-return-inst /// `}` /// ParseResult ModuleParser::parseMLFunc() { diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 0f130e19e26..20e8e0af214 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -21,9 +21,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Statements.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/FileUtilities.h" diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index a5b45ba4098..80e3dd955c3 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -24,7 +24,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" @@ -207,24 +207,24 @@ struct CFGCSE : public CSEImpl { }; /// Common sub-expression elimination for ML functions. -struct MLCSE : public CSEImpl, StmtWalker { - using StmtWalker::walk; +struct MLCSE : public CSEImpl, InstWalker { + using InstWalker::walk; void run(Function *f) { - // Walk the function statements. + // Walk the function instructions. walk(f); // Finally, erase any redundant operations. eraseDeadOperations(); } - // Insert a scope for each statement range. + // Insert a scope for each instruction range. 
template void walk(Iterator Start, Iterator End) { ScopedMapTy::ScopeTy scope(knownValues); - StmtWalker::walk(Start, End); + InstWalker::walk(Start, End); } - void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); } + void visitOperationInst(OperationInst *inst) { simplifyOperation(inst); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index c97b83f8485..f5edf2d8b81 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -36,20 +36,20 @@ using namespace mlir; namespace { -// ComposeAffineMaps walks stmt blocks in a Function, and for each +// ComposeAffineMaps walks inst blocks in a Function, and for each // AffineApplyOp, forward substitutes its results into any users which are // also AffineApplyOps. After forward subtituting its results, AffineApplyOps // with no remaining uses are collected and erased after the walk. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. -struct ComposeAffineMaps : public FunctionPass, StmtWalker { +struct ComposeAffineMaps : public FunctionPass, InstWalker { std::vector affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} - using InstListType = llvm::iplist; + using InstListType = llvm::iplist; void walk(InstListType::iterator Start, InstListType::iterator End); - void visitOperationInst(OperationInst *stmt); + void visitOperationInst(OperationInst *inst); PassResult runOnMLFunction(Function *f) override; - using StmtWalker::walk; + using InstWalker::walk; static char passID; }; @@ -66,14 +66,14 @@ void ComposeAffineMaps::walk(InstListType::iterator Start, InstListType::iterator End) { while (Start != End) { walk(&(*Start)); - // Increment iterator after walk as visit function can mutate stmt list + // Increment iterator after walk as visit function can mutate inst list // ahead of 'Start'. 
++Start; } } -void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { - if (auto affineApplyOp = opStmt->dyn_cast()) { +void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) { + if (auto affineApplyOp = opInst->dyn_cast()) { forwardSubstitute(affineApplyOp); bool allUsesEmpty = true; for (auto *result : affineApplyOp->getInstruction()->getResults()) { @@ -83,7 +83,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { } } if (allUsesEmpty) { - affineApplyOpsToErase.push_back(opStmt); + affineApplyOpsToErase.push_back(opInst); } } } @@ -91,8 +91,8 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { PassResult ComposeAffineMaps::runOnMLFunction(Function *f) { affineApplyOpsToErase.clear(); walk(f); - for (auto *opStmt : affineApplyOpsToErase) { - opStmt->erase(); + for (auto *opInst : affineApplyOpsToErase) { + opInst->erase(); } return success(); } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 08087777e72..f482e90d7ac 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -17,7 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -26,20 +26,20 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass, StmtWalker { +struct ConstantFold : public FunctionPass, InstWalker { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. SmallVector existingConstants; // Operations that were folded and that need to be erased. - std::vector opStmtsToErase; + std::vector opInstsToErase; using ConstantFactoryType = std::function; bool foldOperation(OperationInst *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory); - void visitOperationInst(OperationInst *stmt); - void visitForStmt(ForStmt *stmt); + void visitOperationInst(OperationInst *inst); + void visitForInst(ForInst *inst); PassResult runOnCFGFunction(Function *f) override; PassResult runOnMLFunction(Function *f) override; @@ -140,24 +140,24 @@ PassResult ConstantFold::runOnCFGFunction(Function *f) { } // Override the walker's operation visiter for constant folding. -void ConstantFold::visitOperationInst(OperationInst *stmt) { +void ConstantFold::visitOperationInst(OperationInst *inst) { auto constantFactory = [&](Attribute value, Type type) -> Value * { - FuncBuilder builder(stmt); - return builder.create(stmt->getLoc(), value, type); + FuncBuilder builder(inst); + return builder.create(inst->getLoc(), value, type); }; - if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) { - opStmtsToErase.push_back(stmt); + if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) { + opInstsToErase.push_back(inst); } } -// Override the walker's 'for' statement visit for constant folding. -void ConstantFold::visitForStmt(ForStmt *forStmt) { - constantFoldBounds(forStmt); +// Override the walker's 'for' instruction visit for constant folding. +void ConstantFold::visitForInst(ForInst *forInst) { + constantFoldBounds(forInst); } PassResult ConstantFold::runOnMLFunction(Function *f) { existingConstants.clear(); - opStmtsToErase.clear(); + opInstsToErase.clear(); walk(f); // At this point, these operations are dead, remove them. 
@@ -165,8 +165,8 @@ PassResult ConstantFold::runOnMLFunction(Function *f) { // side effects. When we have side effect modeling, we should verify that // the operation is effect-free before we remove it. Until then this is // close enough. - for (auto *stmt : opStmtsToErase) { - stmt->erase(); + for (auto *inst : opInstsToErase) { + inst->erase(); } // By the time we are done, we may have simplified a bunch of code, leaving diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 821f35ca539..abce624b06f 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -21,9 +21,9 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { // Generates CFG function equivalent to the given ML function. -class FunctionConverter : public StmtVisitor { +class FunctionConverter : public InstVisitor { public: FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {} Function *convert(Function *mlFunc); - void visitForStmt(ForStmt *forStmt); - void visitIfStmt(IfStmt *ifStmt); - void visitOperationInst(OperationInst *opStmt); + void visitForInst(ForInst *forInst); + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); private: Value *getConstantIndexValue(int64_t value); @@ -64,49 +64,49 @@ private: } // end anonymous namespace // Return a vector of OperationInst's arguments as Values. For each -// statement operands, represented as Value, lookup its Value conterpart in +// instruction operands, represented as Value, lookup its Value conterpart in // the valueRemapping table. static llvm::SmallVector -operandsAs(Statement *opStmt, +operandsAs(Instruction *opInst, const llvm::DenseMap &valueRemapping) { llvm::SmallVector operands; - for (const Value *operand : opStmt->getOperands()) { + for (const Value *operand : opInst->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } return operands; } -// Convert an operation statement into an operation instruction. +// Convert an operation instruction into an operation instruction. // // The operation description (name, number and types of operands or results) // remains the same but the values must be updated to be Values. Update the // mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). -void FunctionConverter::visitOperationInst(OperationInst *opStmt) { +void FunctionConverter::visitOperationInst(OperationInst *opInst) { // Set up basic operation state (context, name, operands). - OperationState state(cfgFunc->getContext(), opStmt->getLoc(), - opStmt->getName()); - state.addOperands(operandsAs(opStmt, valueRemapping)); + OperationState state(cfgFunc->getContext(), opInst->getLoc(), + opInst->getName()); + state.addOperands(operandsAs(opInst, valueRemapping)); // Set up operation return types. The corresponding Values will become // available after the operation is created. state.addTypes(functional::map( - [](Value *result) { return result->getType(); }, opStmt->getResults())); + [](Value *result) { return result->getType(); }, opInst->getResults())); // Copy attributes. 
- for (auto attr : opStmt->getAttrs()) { + for (auto attr : opInst->getAttrs()) { state.addAttribute(attr.first.strref(), attr.second); } - auto opInst = builder.createOperation(state); + auto op = builder.createOperation(state); // Make results of the operation accessible to the following operations // through remapping. - assert(opInst->getNumResults() == opStmt->getNumResults()); + assert(opInst->getNumResults() == op->getNumResults()); for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) { valueRemapping.insert( - std::make_pair(opStmt->getResult(i), opInst->getResult(i))); + std::make_pair(opInst->getResult(i), op->getResult(i))); } } @@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) { return op->getResult(); } -// Visit all statements in the given statement block. +// Visit all instructions in the given instruction block. void FunctionConverter::visitBlock(Block *Block) { - for (auto &stmt : *Block) - this->visit(&stmt); + for (auto &inst : *Block) + this->visit(&inst); } // Given a range of values, emit the code that reduces them with "min" or "max" @@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq( // | | // +--------------------------------+ // -void FunctionConverter::visitForStmt(ForStmt *forStmt) { +void FunctionConverter::visitForInst(ForInst *forInst) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). Block *loopInsertionPoint = builder.getInsertionBlock(); @@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // The loop condition block has an argument for loop induction variable. // Create it upfront and make the loop induction variable -> basic block - // argument remapping available to the following instructions. ForStatement + // argument remapping available to the following instructions. ForInstruction // is-a Value corresponding to the loop induction variable. builder.setInsertionPointToEnd(loopConditionBlock); Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); - valueRemapping.insert(std::make_pair(forStmt, iv)); + valueRemapping.insert(std::make_pair(forInst, iv)); // Recursively construct loop body region. // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); - visitBlock(forStmt->getBody()); + visitBlock(forInst->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit // condition block. Construct an affine map f : (x -> x+step) and apply this // map to the induction variable. 
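// Illustrative sketch of the control flow that visitForInst builds, written
// as plain C++ with hypothetical names (positive step assumed, min/max bound
// reduction elided). The loop condition lives in its own block that takes
// the induction variable as a block argument; the body branches back after
// bumping the IV, which is exactly what the affine map (d0) -> (d0 + step)
// computes through the affine_apply created below.
#include <cstdint>
#include <functional>

void loweredFor(int64_t lowerBound, int64_t upperBound, int64_t step,
                const std::function<void(int64_t)> &body) {
  int64_t iv = lowerBound;      // branch into the condition block with lb
  while (iv < upperBound) {     // condition block: cmpi "slt" + cond_br
    body(iv);                   // SESE region for the loop body
    iv = iv + step;             // affine_apply of (d0) -> (d0 + step)
  }
  // continuation: the post-loop block
}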
- auto affStep = builder.getAffineConstantExpr(forStmt->getStep()); + auto affStep = builder.getAffineConstantExpr(forInst->getStep()); auto affDim = builder.getAffineDimExpr(0); auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = - builder.create(forStmt->getLoc(), affStepMap, iv); + builder.create(forInst->getLoc(), affStepMap, iv); Value *nextIvValue = stepOp->getResult(0); builder.create(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); @@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { return valueRemapping.lookup(value); }; auto operands = - functional::map(remapOperands, forStmt->getLowerBoundOperands()); + functional::map(remapOperands, forInst->getLowerBoundOperands()); auto lbAffineApply = builder.create( - forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); + forInst->getLoc(), forInst->getLowerBoundMap(), operands); Value *lowerBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); - operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); + forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); + operands = functional::map(remapOperands, forInst->getUpperBoundOperands()); auto ubAffineApply = builder.create( - forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); + forInst->getLoc(), forInst->getUpperBoundMap(), operands); Value *upperBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); + forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create(builder.getUnknownLoc(), loopConditionBlock, lowerBound); builder.setInsertionPointToEnd(loopConditionBlock); auto comparisonOp = builder.create( - forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); + forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound); auto comparisonResult = comparisonOp->getResult(); builder.create(builder.getUnknownLoc(), comparisonResult, loopBodyFirstBlock, ArrayRef(), @@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPointToEnd(postLoopBlock); } -// Convert an "if" statement into a flow of basic blocks. +// Convert an "if" instruction into a flow of basic blocks. // -// Create an SESE region for the if statement (including its "then" and optional -// "else" statement blocks) and append it to the end of the current region. The -// conditional region consists of a sequence of condition-checking blocks that -// implement the short-circuit scheme, followed by a "then" SESE region and an -// "else" SESE region, and the continuation block that post-dominates all blocks -// of the "if" statement. The flow of blocks that correspond to the "then" and -// "else" clauses are constructed recursively, enabling easy nesting of "if" -// statements and if-then-else-if chains. +// Create an SESE region for the if instruction (including its "then" and +// optional "else" instruction blocks) and append it to the end of the current +// region. The conditional region consists of a sequence of condition-checking +// blocks that implement the short-circuit scheme, followed by a "then" SESE +// region and an "else" SESE region, and the continuation block that +// post-dominates all blocks of the "if" instruction. The flow of blocks that +// correspond to the "then" and "else" clauses are constructed recursively, +// enabling easy nesting of "if" instructions and if-then-else-if chains. 
// // +--------------------------------+ // | | @@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // | | // +--------------------------------+ // -void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { - assert(ifStmt != nullptr); +void FunctionConverter::visitIfInst(IfInst *ifInst) { + assert(ifInst != nullptr); - auto integerSet = ifStmt->getCondition().getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Create basic blocks for the 'then' block and for the 'else' block. // Although 'else' block may be empty in absence of an 'else' clause, create // it anyway for the sake of consistency and output IR readability. Also // create extra blocks for condition checking to prepare for short-circuit - // logic: conditions in the 'if' statement are conjunctive, so we can jump to - // the false branch as soon as one condition fails. `cond_br` requires + // logic: conditions in the 'if' instruction are conjunctive, so we can jump + // to the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. Block *ifInsertionBlock = builder.getInsertionBlock(); @@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { builder.getAffineMap(integerSet.getNumDims(), integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create( - ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); + ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping)); Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create( - ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, + ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); - builder.create(ifStmt->getLoc(), comparisonOp->getResult(), + builder.create(ifInst->getLoc(), comparisonOp->getResult(), nextBlock, /*trueArgs*/ ArrayRef(), elseBlock, /*falseArgs*/ ArrayRef()); @@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Recursively traverse the 'then' block. builder.setInsertionPointToEnd(thenBlock); - visitBlock(ifStmt->getThen()); + visitBlock(ifInst->getThen()); Block *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); - if (ifStmt->hasElse()) - visitBlock(ifStmt->getElse()); + if (ifInst->hasElse()) + visitBlock(ifInst->getElse()); Block *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the @@ -443,9 +443,9 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // to the continuation block. Block *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); - builder.create(ifStmt->getLoc(), continuationBlock); + builder.create(ifInst->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); - builder.create(ifStmt->getLoc(), continuationBlock); + builder.create(ifInst->getLoc(), continuationBlock); // Make sure building can continue by setting up the continuation block as the // insertion point. @@ -454,12 +454,12 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Entry point of the function convertor. // -// Conversion is performed by recursively visiting statements of a Function. +// Conversion is performed by recursively visiting instructions of a Function. 
// It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the // IR builder since we never change the first block after its creation. "Block" -// statements such as loops and branches create new SESE regions for their +// instructions such as loops and branches create new SESE regions for their // bodies, and surround them with additional basic blocks for the control flow. // Individual operations are simply appended to the end of the last basic block // of the current region. The SESE invariant allows us to easily handle nested @@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) { valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } - // Convert statements in order. - for (auto &stmt : *mlFunc->getBody()) { - visit(&stmt); + // Convert instructions in order. + for (auto &inst : *mlFunc->getBody()) { + visit(&inst); } return cfgFunc; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 69344819ed8..bc7f31f0434 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -49,7 +49,7 @@ namespace { /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. -struct DmaGeneration : public FunctionPass, StmtWalker { +struct DmaGeneration : public FunctionPass, InstWalker { explicit DmaGeneration(unsigned slowMemorySpace = 0, unsigned fastMemorySpaceArg = 1, int minDmaTransferSize = 1024) @@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker { // Not applicable to CFG functions. PassResult runOnCFGFunction(Function *f) override { return success(); } PassResult runOnMLFunction(Function *f) override; - void runOnForStmt(ForStmt *forStmt); + void runOnForInst(ForInst *forInst); - void visitOperationInst(OperationInst *opStmt); - bool generateDma(const MemRefRegion ®ion, ForStmt *forStmt, + void visitOperationInst(OperationInst *opInst); + bool generateDma(const MemRefRegion ®ion, ForInst *forInst, uint64_t *sizeInBytes); // List of memory regions to DMA for. @@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. -void DmaGeneration::visitOperationInst(OperationInst *opStmt) { - if (auto loadOp = opStmt->dyn_cast()) { +void DmaGeneration::visitOperationInst(OperationInst *opInst) { + if (auto loadOp = opInst->dyn_cast()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = opStmt->dyn_cast()) { + } else if (auto storeOp = opInst->dyn_cast()) { if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) { // This way we would be allocating O(num of memref's) sets instead of // O(num of load/store op's). 
auto region = std::make_unique(); - if (!getMemRefRegion(opStmt, dmaDepth, region.get())) { + if (!getMemRefRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); return; } @@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, // Creates a buffer in the faster memory space for the specified region; // generates a DMA from the lower memory space to this one, and replaces all // loads to load from that buffer. Returns true if DMAs are generated. -bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, +bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, uint64_t *sizeInBytes) { // DMAs for read regions are going to be inserted just before the for loop. - FuncBuilder prologue(forStmt); + FuncBuilder prologue(forInst); // DMAs for write regions are going to be inserted just after the for loop. - FuncBuilder epilogue(forStmt->getBlock(), - std::next(Block::iterator(forStmt))); + FuncBuilder epilogue(forInst->getBlock(), + std::next(Block::iterator(forInst))); FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. - FuncBuilder top(forStmt->getFunction()); + FuncBuilder top(forInst->getFunction()); - auto loc = forStmt->getLoc(); + auto loc = forInst->getLoc(); auto *memref = region.memref; auto memRefType = memref->getType().cast(); @@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: "); LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n"); - // Create the fast memory space buffer just before the 'for' statement. + // Create the fast memory space buffer just before the 'for' instruction. fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); // Record it. fastBufferMap[memref] = fastMemRef; @@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forStmt' are replaced. + // *Only* those uses within the body of 'forInst' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domStmtFilter=*/&*forStmt->getBody()->begin()); + /*domInstFilter=*/&*forInst->getBody()->begin()); return true; } -/// Returns the nesting depth of this statement, i.e., the number of loops -/// surrounding this statement. +/// Returns the nesting depth of this instruction, i.e., the number of loops +/// surrounding this instruction. // TODO(bondhugula): move this to utilities later. -static unsigned getNestingDepth(const Statement &stmt) { - const Statement *currStmt = &stmt; +static unsigned getNestingDepth(const Instruction &inst) { + const Instruction *currInst = &inst; unsigned depth = 0; - while ((currStmt = currStmt->getParentStmt())) { - if (isa(currStmt)) + while ((currInst = currInst->getParentInst())) { + if (isa(currInst)) depth++; } return depth; } -// TODO(bondhugula): make this run on a Block instead of a 'for' stmt. -void DmaGeneration::runOnForStmt(ForStmt *forStmt) { +// TODO(bondhugula): make this run on a Block instead of a 'for' inst. 
+void DmaGeneration::runOnForInst(ForInst *forInst) { // For now (for testing purposes), we'll run this on the outermost among 'for' - // stmt's with unit stride, i.e., right at the top of the tile if tiling has + // inst's with unit stride, i.e., right at the top of the tile if tiling has // been done. In the future, the DMA generation has to be done at a level // where the generated data fits in a higher level of the memory hierarchy; so // the pass has to be instantiated with additional information that we aren't // provided with at the moment. - if (forStmt->getStep() != 1) { - if (auto *innerFor = dyn_cast(&*forStmt->getBody()->begin())) { - runOnForStmt(innerFor); + if (forInst->getStep() != 1) { + if (auto *innerFor = dyn_cast(&*forInst->getBody()->begin())) { + runOnForInst(innerFor); } return; } // DMAs will be generated for this depth, i.e., for all data accessed by this // loop. - dmaDepth = getNestingDepth(*forStmt); + dmaDepth = getNestingDepth(*forInst); regions.clear(); fastBufferMap.clear(); - // Walk this 'for' statement to gather all memory regions. - walk(forStmt); + // Walk this 'for' instruction to gather all memory regions. + walk(forInst); uint64_t totalSizeInBytes = 0; bool ret = false; for (const auto ®ion : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*region, forStmt, &sizeInBytes); + bool iRet = generateDma(*region, forInst, &sizeInBytes); if (iRet) totalSizeInBytes += sizeInBytes; ret = ret | iRet; @@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { } PassResult DmaGeneration::runOnMLFunction(Function *f) { - for (auto &stmt : *f->getBody()) { - if (auto *forStmt = dyn_cast(&stmt)) { - runOnForStmt(forStmt); + for (auto &inst : *f->getBody()) { + if (auto *forInst = dyn_cast(&inst)) { + runOnForInst(forInst); } } // This function never leaves the IR in an invalid state. 
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d31337437ad..97dea753f88 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -80,20 +80,20 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt, +static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst, MemRefAccess *access) { - if (auto loadOp = loadOrStoreOpStmt->dyn_cast()) { + if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { access->memref = loadOp->getMemRef(); - access->opStmt = loadOrStoreOpStmt; + access->opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp->getMemRefType(); access->indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { access->indices.push_back(index); } } else { - assert(loadOrStoreOpStmt->isa()); - auto storeOp = loadOrStoreOpStmt->dyn_cast(); - access->opStmt = loadOrStoreOpStmt; + assert(loadOrStoreOpInst->isa()); + auto storeOp = loadOrStoreOpInst->dyn_cast(); + access->opInst = loadOrStoreOpInst; access->memref = storeOp->getMemRef(); auto storeMemrefType = storeOp->getMemRefType(); access->indices.reserve(storeMemrefType.getRank()); @@ -112,24 +112,24 @@ struct FusionCandidate { MemRefAccess dstAccess; }; -static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt, - OperationInst *dstLoadOpStmt) { +static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst, + OperationInst *dstLoadOpInst) { FusionCandidate candidate; // Get store access for src loop nest. - getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess); + getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess); // Get load access for dst loop nest. - getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess); + getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess); return candidate; } -// Returns the loop depth of the loop nest surrounding 'opStmt'. -static unsigned getLoopDepth(OperationInst *opStmt) { +// Returns the loop depth of the loop nest surrounding 'opInst'. +static unsigned getLoopDepth(OperationInst *opInst) { unsigned loopDepth = 0; - auto *currStmt = opStmt->getParentStmt(); - ForStmt *currForStmt; - while (currStmt && (currForStmt = dyn_cast(currStmt))) { + auto *currInst = opInst->getParentInst(); + ForInst *currForInst; + while (currInst && (currForInst = dyn_cast(currInst))) { ++loopDepth; - currStmt = currStmt->getParentStmt(); + currInst = currInst->getParentInst(); } return loopDepth; } @@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) { namespace { // LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not an IfStmt was encountered in the loop nest. -class LoopNestStateCollector : public StmtWalker { +// operations, and whether or not an IfInst was encountered in the loop nest. 
+class LoopNestStateCollector : public InstWalker { public: - SmallVector forStmts; - SmallVector loadOpStmts; - SmallVector storeOpStmts; - bool hasIfStmt = false; + SmallVector forInsts; + SmallVector loadOpInsts; + SmallVector storeOpInsts; + bool hasIfInst = false; - void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } + void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } + void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opStmt) { - if (opStmt->isa()) - loadOpStmts.push_back(opStmt); - if (opStmt->isa()) - storeOpStmts.push_back(opStmt); + void visitOperationInst(OperationInst *opInst) { + if (opInst->isa()) + loadOpInsts.push_back(opInst); + if (opInst->isa()) + storeOpInsts.push_back(opInst); } }; // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level statements in a Function which contain load/store ops, and edges +// top-level instructions in a Function which contain load/store ops, and edges // are memref dependences between the nodes. // TODO(andydavis) Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { @@ -170,18 +170,18 @@ public: // The unique identifier of this node in the graph. unsigned id; // The top-level statment which is (or contains) loads/stores. - Statement *stmt; + Instruction *inst; // List of load operations. SmallVector loads; - // List of store op stmts. + // List of store op insts. SmallVector stores; - Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} + Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} // Returns the load op count for 'memref'. unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; - for (auto *loadOpStmt : loads) { - if (memref == loadOpStmt->cast()->getMemRef()) + for (auto *loadOpInst : loads) { + if (memref == loadOpInst->cast()->getMemRef()) ++loadOpCount; } return loadOpCount; @@ -190,8 +190,8 @@ public: // Returns the store op count for 'memref'. unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; - for (auto *storeOpStmt : stores) { - if (memref == storeOpStmt->cast()->getMemRef()) + for (auto *storeOpInst : stores) { + if (memref == storeOpInst->cast()->getMemRef()) ++storeOpCount; } return storeOpCount; @@ -315,10 +315,10 @@ public: void addToNode(unsigned id, const SmallVectorImpl &loads, const SmallVectorImpl &stores) { Node *node = getNode(id); - for (auto *loadOpStmt : loads) - node->loads.push_back(loadOpStmt); - for (auto *storeOpStmt : stores) - node->stores.push_back(storeOpStmt); + for (auto *loadOpInst : loads) + node->loads.push_back(loadOpInst); + for (auto *storeOpInst : stores) + node->stores.push_back(storeOpInst); } void print(raw_ostream &os) const { @@ -341,55 +341,55 @@ public: void dump() const { print(llvm::errs()); } }; -// Intializes the data dependence graph by walking statements in 'f'. +// Intializes the data dependence graph by walking instructions in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. 
bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; DenseMap> memrefAccesses; - for (auto &stmt : *f->getBody()) { - if (auto *forStmt = dyn_cast(&stmt)) { - // Create graph node 'id' to represent top-level 'forStmt' and record + for (auto &inst : *f->getBody()) { + if (auto *forInst = dyn_cast(&inst)) { + // Create graph node 'id' to represent top-level 'forInst' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walkForStmt(forStmt); - // Return false if IfStmts are found (not currently supported). - if (collector.hasIfStmt) + collector.walkForInst(forInst); + // Return false if IfInsts are found (not currently supported). + if (collector.hasIfInst) return false; - Node node(id++, &stmt); - for (auto *opStmt : collector.loadOpStmts) { - node.loads.push_back(opStmt); - auto *memref = opStmt->cast()->getMemRef(); + Node node(id++, &inst); + for (auto *opInst : collector.loadOpInsts) { + node.loads.push_back(opInst); + auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } - for (auto *opStmt : collector.storeOpStmts) { - node.stores.push_back(opStmt); - auto *memref = opStmt->cast()->getMemRef(); + for (auto *opInst : collector.storeOpInsts) { + node.stores.push_back(opInst); + auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } nodes.insert({node.id, node}); } - if (auto *opStmt = dyn_cast(&stmt)) { - if (auto loadOp = opStmt->dyn_cast()) { + if (auto *opInst = dyn_cast(&inst)) { + if (auto loadOp = opInst->dyn_cast()) { // Create graph node for top-level load op. - Node node(id++, &stmt); - node.loads.push_back(opStmt); - auto *memref = opStmt->cast()->getMemRef(); + Node node(id++, &inst); + node.loads.push_back(opInst); + auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } - if (auto storeOp = opStmt->dyn_cast()) { + if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. - Node node(id++, &stmt); - node.stores.push_back(opStmt); - auto *memref = opStmt->cast()->getMemRef(); + Node node(id++, &inst); + node.stores.push_back(opInst); + auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } } - // Return false if IfStmts are found (not currently supported). - if (isa(&stmt)) + // Return false if IfInsts are found (not currently supported). + if (isa(&inst)) return false; } @@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) { // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: -// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate -// destination ForStmt into which fusion will be attempted. -// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'. +// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate +// destination ForInst into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: // *) Lookup dependent loop nests at earlier positions in the Function // which have a single store op to the same memref. @@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) { // bounds to be functions of 'dstLoopNest' IVs and symbols. // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', // just before the dst load op user. 
-// *) Add the newly fused load/store operation statements to the state, +// *) Add the newly fused load/store operation instructions to the state, // and also add newly fuse load ops to 'dstLoopOps' to be considered // as fusion dst load ops in another iteration. // *) Remove old src loop nest and its associated state. // -// Given a graph where top-level statements are vertices in the set 'V' and +// Given a graph where top-level instructions are vertices in the set 'V' and // edges in the set 'E' are dependences between vertices, this algorithm // takes O(V) time for initialization, and has runtime O(V + E). // @@ -471,14 +471,14 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!isa(dstNode->stmt)) + if (!isa(dstNode->inst)) continue; SmallVector loads = dstNode->loads; while (!loads.empty()) { - auto *dstLoadOpStmt = loads.pop_back_val(); - auto *memref = dstLoadOpStmt->cast()->getMemRef(); - // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. + auto *dstLoadOpInst = loads.pop_back_val(); + auto *memref = dstLoadOpInst->cast()->getMemRef(); + // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'. if (dstNode->getLoadOpCount(memref) != 1) continue; // Skip if no input edges along which to fuse. @@ -491,7 +491,7 @@ public: continue; auto *srcNode = mdg->getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. - if (!isa(srcNode->stmt)) + if (!isa(srcNode->inst)) continue; // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) @@ -508,17 +508,17 @@ public: if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId) continue; // Get unique 'srcNode' store op. - auto *srcStoreOpStmt = srcNode->stores.front(); - // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'. + auto *srcStoreOpInst = srcNode->stores.front(); + // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'. FusionCandidate candidate = - buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); + buildFusionCandidate(srcStoreOpInst, dstLoadOpInst); // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0 ? clSrcLoopDepth - : getLoopDepth(srcStoreOpStmt); + : getLoopDepth(srcStoreOpInst); unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0 ? clDstLoopDepth - : getLoopDepth(dstLoadOpStmt); + : getLoopDepth(dstLoadOpInst); auto *sliceLoopNest = mlir::insertBackwardComputationSlice( &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth, dstLoopDepth); @@ -527,19 +527,19 @@ public: mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. LoopNestStateCollector collector; - collector.walkForStmt(sliceLoopNest); - mdg->addToNode(dstId, collector.loadOpStmts, - collector.storeOpStmts); + collector.walkForInst(sliceLoopNest); + mdg->addToNode(dstId, collector.loadOpInsts, + collector.storeOpInsts); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. - for (auto *loadOpStmt : collector.loadOpStmts) - loads.push_back(loadOpStmt); + for (auto *loadOpInst : collector.loadOpInsts) + loads.push_back(loadOpInst); // Promote single iteration loops to single IV value. - for (auto *forStmt : collector.forStmts) { - promoteIfSingleIteration(forStmt); + for (auto *forInst : collector.forInsts) { + promoteIfSingleIteration(forInst); } // Remove old src loop nest. 
- cast(srcNode->stmt)->erase(); + cast(srcNode->inst)->erase(); } } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 109953f2296..8f3be8a3d45 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -55,16 +55,16 @@ char LoopTiling::passID = 0; /// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForStmt 'src' from 'src' into the specified location in +// Move the loop body of ForInst 'src' from 'src' into the specified location in // destination's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest, +static inline void moveLoopBody(ForInst *src, ForInst *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { +// Move the loop body of ForInst 'src' from 'src' to the start of dest's body. +static inline void moveLoopBody(ForInst *src, ForInst *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef origLoops, - ArrayRef newLoops, +static void constructTiledIndexSetHyperRect(ArrayRef origLoops, + ArrayRef newLoops, ArrayRef tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); @@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef band, +UtilResult mlir::tileCodeGen(ArrayRef band, ArrayRef tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); - // Check if the supplied for stmt's are all successively nested. + // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentStmt() == band[i - 1]); + assert(band[i]->getParentInst() == band[i - 1]); } auto origLoops = band; - ForStmt *rootForStmt = origLoops[0]; - auto loc = rootForStmt->getLoc(); + ForInst *rootForInst = origLoops[0]; + auto loc = rootForInst->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector newLoops(2 * width); - ForStmt *innermostPointLoop; + SmallVector newLoops(2 * width); + ForInst *innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootForStmt; + auto *topLoop = rootForInst; // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { @@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForStmt->emitError("tiled code generation unimplemented for the" + rootForInst->emitError("tiled code generation unimplemented for the" "non-hyperrectangular case"); return UtilResult::Failure; } @@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, } // Erase the old loop nest. 
- rootForStmt->erase(); + rootForInst->erase(); return UtilResult::Success; } @@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. static void getTileableBands(Function *f, - std::vector> *bands) { - // Get maximal perfect nest of 'for' stmts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](ForStmt *root) { - SmallVector band; - ForStmt *currStmt = root; + std::vector> *bands) { + // Get maximal perfect nest of 'for' insts starting from root (inclusive). + auto getMaximalPerfectLoopNest = [&](ForInst *root) { + SmallVector band; + ForInst *currInst = root; do { - band.push_back(currStmt); - } while (currStmt->getBody()->getInstructions().size() == 1 && - (currStmt = dyn_cast(&*currStmt->getBody()->begin()))); + band.push_back(currInst); + } while (currInst->getBody()->getInstructions().size() == 1 && + (currInst = dyn_cast(&*currInst->getBody()->begin()))); bands->push_back(band); }; - for (auto &stmt : *f->getBody()) { - auto *forStmt = dyn_cast(&stmt); - if (!forStmt) + for (auto &inst : *f->getBody()) { + auto *forInst = dyn_cast(&inst); + if (!forInst) continue; - getMaximalPerfectLoopNest(forStmt); + getMaximalPerfectLoopNest(forInst); } } PassResult LoopTiling::runOnMLFunction(Function *f) { - std::vector> bands; + std::vector> bands; getTileableBands(f, &bands); // Temporary tile sizes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 15ea0f841cc..69431bf6349 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass { const Optional unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function getUnrollFactor; + const std::function getUnrollFactor; explicit LoopUnroll( Optional unrollFactor = None, Optional unrollFull = None, - const std::function &getUnrollFactor = nullptr) + const std::function &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnMLFunction(Function *f) override; - /// Unroll this for stmt. Returns false if nothing was done. - bool runOnForStmt(ForStmt *forStmt); + /// Unroll this for inst. Returns false if nothing was done. + bool runOnForInst(ForInst *forInst); static const unsigned kDefaultUnrollFactor = 4; @@ -85,13 +85,13 @@ char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnMLFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class InnermostLoopGatherer : public StmtWalker { + class InnermostLoopGatherer : public InstWalker { public: // Store innermost loops as we walk. - std::vector loops; + std::vector loops; // This method specialized to encode custom return logic. 
- using InstListType = llvm::iplist; + using InstListType = llvm::iplist; bool walkPostOrder(InstListType::iterator Start, InstListType::iterator End) { bool hasInnerLoops = false; @@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return hasInnerLoops; } - bool walkForStmtPostOrder(ForStmt *forStmt) { + bool walkForInstPostOrder(ForInst *forInst) { bool hasInnerLoops = - walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end()); + walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); if (!hasInnerLoops) - loops.push_back(forStmt); + loops.push_back(forInst); return true; } - bool walkIfStmtPostOrder(IfStmt *ifStmt) { + bool walkIfInstPostOrder(IfInst *ifInst) { bool hasInnerLoops = - walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end()); - if (ifStmt->hasElse()) + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) hasInnerLoops |= - walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end()); + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); return hasInnerLoops; } - bool visitOperationInst(OperationInst *opStmt) { return false; } + bool visitOperationInst(OperationInst *opInst) { return false; } // FIXME: can't use base class method for this because that in turn would // need to use the derived class method above. CRTP doesn't allow it, and // the compiler error resulting from it is also misleading. - using StmtWalker::walkPostOrder; + using InstWalker::walkPostOrder; }; // Gathers all loops with trip count <= minTripCount. - class ShortLoopGatherer : public StmtWalker { + class ShortLoopGatherer : public InstWalker { public: // Store short loops as we walk. - std::vector loops; + std::vector loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForStmt(ForStmt *forStmt) { - Optional tripCount = getConstantTripCount(*forStmt); + void visitForInst(ForInst *forInst) { + Optional tripCount = getConstantTripCount(*forInst); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forStmt); + loops.push_back(forInst); } }; @@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forStmt : loops) - loopUnrollFull(forStmt); + for (auto *forInst : loops) + loopUnrollFull(forInst); return success(); } @@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forStmt : loops) - unrolled |= runOnForStmt(forStmt); + for (auto *forInst : loops) + unrolled |= runOnForInst(forInst); if (!unrolled) // Break out if nothing was unrolled. break; @@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return success(); } -/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false +/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForStmt(ForStmt *forStmt) { +bool LoopUnroll::runOnForInst(ForInst *forInst) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt)); + return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); } // Unroll by the factor passed, if any. 
if (unrollFactor.hasValue()) - return loopUnrollByFactor(forStmt, unrollFactor.getValue()); + return loopUnrollByFactor(forInst, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forStmt, clUnrollFactor); + return loopUnrollByFactor(forInst, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forStmt); + return loopUnrollFull(forInst); // Unroll by four otherwise. - return loopUnrollByFactor(forStmt, kDefaultUnrollFactor); + return loopUnrollByFactor(forInst, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function &getUnrollFactor) { + const std::function &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 60e8d154f98..f59659cf234 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -40,7 +40,7 @@ // S6(i+1); // // Note: 'if/else' blocks are not jammed. So, if there are loops inside if -// stmt's, bodies of those loops will not be jammed. +// inst's, bodies of those loops will not be jammed. //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" @@ -49,7 +49,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnMLFunction(Function *f) override; - bool runOnForStmt(ForStmt *forStmt); + bool runOnForInst(ForInst *forInst); static char passID; }; @@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForStmt can be called on any - // for Stmt. - auto *forStmt = dyn_cast(f->getBody()->begin()); - if (!forStmt) + // unroll-and-jammed by this pass. However, runOnForInst can be called on any + // for Inst. + auto *forInst = dyn_cast(f->getBody()->begin()); + if (!forInst) return success(); - runOnForStmt(forStmt); + runOnForInst(forInst); return success(); } -/// Unroll and jam a 'for' stmt. Default unroll jam factor is +/// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) { +bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. 
if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forStmt, clUnrollJamFactor); + return loopUnrollJamByFactor(forInst, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forStmt); +bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(*forInst); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forStmt, unrollJamFactor); + return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. -bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - // Gathers all maximal sub-blocks of statements that do not themselves include - // a for stmt (a statement could have a descendant for stmt though in its - // tree). - class JamBlockGatherer : public StmtWalker { +bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { + // Gathers all maximal sub-blocks of instructions that do not themselves + // include a for inst (a instruction could have a descendant for inst though + // in its tree). + class JamBlockGatherer : public InstWalker { public: - using InstListType = llvm::iplist; + using InstListType = llvm::iplist; - // Store iterators to the first and last stmt of each sub-block found. + // Store iterators to the first and last inst of each sub-block found. std::vector> subBlocks; // This is a linear time walk. void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa(it)) + while (it != End && !isa(it)) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); - // Process all for stmts that appear next. - while (it != End && isa(it)) - walkForStmt(cast(it++)); + // Process all for insts that appear next. + while (it != End && isa(it)) + walkForInst(cast(it++)); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forStmt->getBody()->empty()) + if (unrollJamFactor == 1 || forInst->getBody()->empty()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forStmt); + Optional mayBeConstantTripCount = getConstantTripCount(*forInst); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) return false; - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. 
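// Illustrative sketch of the unroll-and-jam rewrite performed by the
// surrounding function, in plain C++ with a hypothetical factor of 2 and
// sub-blocks S1/S2: the main loop's step is scaled by the factor and each
// sub-block is replayed for the shifted IVs, while a cleanup loop covers the
// leftover iterations when the trip count is not a multiple of the factor.
#include <cstdint>
#include <functional>

void unrollJamByTwo(int64_t tripCount, const std::function<void(int64_t)> &S1,
                    const std::function<void(int64_t)> &S2) {
  int64_t iv = 0;
  for (; iv + 2 <= tripCount; iv += 2) { // step scaled by the jam factor
    S1(iv); S1(iv + 1);                  // sub-block S1 jammed
    S2(iv); S2(iv + 1);                  // sub-block S2 jammed
  }
  for (; iv < tripCount; ++iv) {         // cleanup loop (tripCount % 2 iters)
    S1(iv); S2(iv);
  }
}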
- if (!forStmt->matchingBoundOperandList()) + if (!forInst->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForStmt(forStmt); + jbg.walkForInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of @@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { DenseMap operandMap; - // Insert the cleanup loop right after 'forStmt'. - FuncBuilder builder(forStmt->getBlock(), - std::next(Block::iterator(forStmt))); - auto *cleanupForStmt = cast(builder.clone(*forStmt, operandMap)); - cleanupForStmt->setLowerBoundMap( - getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder)); + // Insert the cleanup loop right after 'forInst'. + FuncBuilder builder(forInst->getBlock(), + std::next(Block::iterator(forInst))); + auto *cleanupForInst = cast(builder.clone(*forInst, operandMap)); + cleanupForInst->setLowerBoundMap( + getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. - forStmt->setUpperBoundMap( - getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder)); + forInst->setUpperBoundMap( + getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForStmt); + promoteIfSingleIteration(cleanupForInst); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forStmt->getStep(); - forStmt->setStep(step * unrollJamFactor); + int64_t step = forInst->getStep(); + forInst->setStep(step * unrollJamFactor); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of @@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forStmt->use_empty()) { + if (!forInst->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create(forStmt->getLoc(), bumpMap, forStmt) + builder.create(forInst->getLoc(), bumpMap, forInst) ->getResult(0); - operandMapping[forStmt] = ivUnroll; + operandMapping[forInst] = ivUnroll; } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forStmt); + promoteIfSingleIteration(forInst); return true; } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 51577009abb..bcb2abf11dd 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // Get the ML function builder. 
// We need access to the Function builder stored internally in the // MLFunctionLoweringRewriter general rewriting API does not provide - // ML-specific functions (ForStmt and Block manipulation). While we could + // ML-specific functions (ForInst and Block manipulation). While we could // forward them or define a whole rewriting chain based on MLFunctionBuilder // instead of Builer, the code for it would be duplicate boilerplate. As we // go towards unifying ML and CFG functions, this separation will disappear. @@ -137,13 +137,13 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // memory. // TODO(ntv): Handle broadcast / slice properly. auto permutationMap = transfer->getPermutationMap(); - SetVector loops; + SetVector loops; SmallVector accessIndices(transfer->getIndices()); for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) { auto composed = composeWithUnboundedMap( getAffineDimExpr(it.index(), b.getContext()), permutationMap); - auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value()); - loops.insert(forStmt); + auto *forInst = b.createFor(transfer->getLoc(), 0, it.value()); + loops.insert(forInst); // Setting the insertion point to the innermost loop achieves nesting. b.setInsertionPointToStart(loops.back()->getBody()); if (composed == getAffineConstantExpr(0, b.getContext())) { @@ -196,7 +196,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, b.setInsertionPoint(transfer->getInstruction()); b.create(transfer->getLoc(), tmpScalarAlloc); - // 7. It is now safe to erase the statement. + // 7. It is now safe to erase the instruction. rewriter->replaceOp(transfer->getInstruction(), newResults); } @@ -213,7 +213,7 @@ public: return matchFailure(); } - void rewriteOpStmt(OperationInst *op, + void rewriteOpInst(OperationInst *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index a30e8164760..37f0f571a0f 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -73,7 +73,7 @@ /// Implementation details /// ====================== /// The current decisions made by the super-vectorization pass guarantee that -/// use-def chains do not escape an enclosing vectorized ForStmt. In other +/// use-def chains do not escape an enclosing vectorized ForInst. In other /// words, this pass operates on a scoped program slice. 
Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top @@ -247,7 +247,7 @@ static SmallVector delinearize(unsigned linearIndex, } static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, +instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -263,10 +263,10 @@ static Value *substitute(Value *v, VectorType hwVectorType, DenseMap *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { - auto *opStmt = v->getDefiningInst(); - if (opStmt->isa()) { - FuncBuilder b(opStmt); - auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); + auto *opInst = v->getDefiningInst(); + if (opInst->isa()) { + FuncBuilder b(opInst); + auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap); auto res = substitutionsMap->insert(std::make_pair(v, inst->getResult(0))); assert(res.second && "Insertion failed"); @@ -285,7 +285,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// /// The general problem this pass solves is as follows: /// Assume a vector_transfer operation at the super-vector granularity that has -/// `l` enclosing loops (ForStmt). Assume the vector transfer operation operates +/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates /// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of /// rank `h`. /// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the @@ -347,7 +347,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector affineExprs; // TODO(ntv): support a concrete map and composition. unsigned i = 0; - // The first numMemRefIndices correspond to ForStmt that have not been + // The first numMemRefIndices correspond to ForInst that have not been // vectorized, the transformation is the identity on those. for (i = 0; i < numMemRefIndices; ++i) { auto d_i = b->getAffineDimExpr(i); @@ -384,9 +384,9 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector -materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { +materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { SmallVector res; - for (auto a : opStmt->getAttrs()) { + for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast()) { auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); res.push_back(NamedAttribute(a.first, attr)); @@ -397,7 +397,7 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { return res; } -/// Creates an instantiated version of `opStmt`. +/// Creates an instantiated version of `opInst`. /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no /// affine reindexing. Just substitute their Value operands and be done. For /// this case the actual instance is irrelevant. Just use the values in @@ -405,11 +405,11 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { /// /// If the underlying substitution fails, this fails too and returns nullptr. 
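As a concrete illustration of the shape-ratio bookkeeping above, a hardware-vector instance of a super-vector is identified by its multi-dimensional position within the ratio of the two shapes; the sketch below shows that delinearization in plain C++, assuming a row-major (last dimension fastest) convention, with names that are illustrative rather than the pass's actual helpers.

    #include <cassert>
    #include <vector>

    // Converts a linear instance index into per-dimension coordinates within
    // 'ratio'. E.g. a vector<3x32xf32> super-vector tiled by a vector<8xf32>
    // hardware vector has ratio {3, 4}: 12 instances, and instance 7 -> {1, 3}.
    std::vector<unsigned> delinearizeIndex(unsigned linearIndex,
                                           const std::vector<unsigned> &ratio) {
      std::vector<unsigned> coords(ratio.size());
      for (int d = static_cast<int>(ratio.size()) - 1; d >= 0; --d) {
        coords[d] = linearIndex % ratio[d];
        linearIndex /= ratio[d];
      }
      assert(linearIndex == 0 && "linear index out of range");
      return coords;
    }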
static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, +instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, DenseMap *substitutionsMap) { - assert(!opStmt->isa() && + assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); - assert(!opStmt->isa() && + assert(!opInst->isa() && "Should call the function specialized for VectorTransferWriteOp"); bool fail = false; auto operands = map( @@ -419,14 +419,14 @@ instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, fail |= !res; return res; }, - opStmt->getOperands()); + opInst->getOperands()); if (fail) return nullptr; - auto attrs = materializeAttributes(opStmt, hwVectorType); + auto attrs = materializeAttributes(opInst, hwVectorType); - OperationState state(b->getContext(), opStmt->getLoc(), - opStmt->getName().getStringRef(), operands, + OperationState state(b->getContext(), opInst->getLoc(), + opInst->getName().getStringRef(), operands, {hwVectorType}, attrs); return b->createOperation(state); } @@ -511,11 +511,11 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, return cloned->getInstruction(); } -/// Returns `true` if stmt instance is properly cloned and inserted, false +/// Returns `true` if inst instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of /// super-vector type to hw vector type. -/// A cloned instance of `stmt` is formed as follows: +/// A cloned instance of `inst` is formed as follows: /// 1. vector_transfer_read: the return `superVectorType` is replaced by /// `hwVectorType`. Additionally, affine indices are reindexed with /// `reindexAffineIndices` using `hwVectorInstance` and vector type @@ -532,24 +532,24 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, /// possible. /// /// Returns true on failure. -static bool instantiateMaterialization(Statement *stmt, +static bool instantiateMaterialization(Instruction *inst, MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); - if (isa(stmt)) - return stmt->emitError("NYI path ForStmt"); + if (isa(inst)) + return inst->emitError("NYI path ForInst"); - if (isa(stmt)) - return stmt->emitError("NYI path IfStmt"); + if (isa(inst)) + return inst->emitError("NYI path IfInst"); // Create a builder here for unroll-and-jam effects. - FuncBuilder b(stmt); - auto *opStmt = cast(stmt); - if (auto write = opStmt->dyn_cast()) { + FuncBuilder b(inst); + auto *opInst = cast(inst); + if (auto write = opInst->dyn_cast()) { instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return false; - } else if (auto read = opStmt->dyn_cast()) { + } else if (auto read = opInst->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); state->substitutionsMap->insert( @@ -559,17 +559,17 @@ static bool instantiateMaterialization(Statement *stmt, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. 
- if (opStmt->getNumResults() != 1) - return stmt->emitError("NYI: ops with != 1 results"); - if (opStmt->getResult(0)->getType() != state->superVectorType) - return stmt->emitError("Op does not return a supervector."); + if (opInst->getNumResults() != 1) + return inst->emitError("NYI: ops with != 1 results"); + if (opInst->getResult(0)->getType() != state->superVectorType) + return inst->emitError("Op does not return a supervector."); auto *clone = - instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap); + instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(opStmt->getResult(0), clone->getResult(0))); + std::make_pair(opInst->getResult(0), clone->getResult(0))); return false; } @@ -595,7 +595,7 @@ static bool instantiateMaterialization(Statement *stmt, /// TODO(ntv): full loops + materialized allocs. /// TODO(ntv): partial unrolling + materialized allocs. static bool emitSlice(MaterializationState *state, - SetVector *slice) { + SetVector *slice) { auto ratio = shapeRatio(state->superVectorType, state->hwVectorType); assert(ratio.hasValue() && "ratio of super-vector to HW-vector shape is not integral"); @@ -610,10 +610,10 @@ static bool emitSlice(MaterializationState *state, DenseMap substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. - for (auto *stmt : *slice) { - auto fail = instantiateMaterialization(stmt, &scopedState); + for (auto *inst : *slice) { + auto fail = instantiateMaterialization(inst, &scopedState); if (fail) { - stmt->emitError("Unhandled super-vector materialization failure"); + inst->emitError("Unhandled super-vector materialization failure"); return true; } } @@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state, /// Materializes super-vector types into concrete hw vector types as follows: /// 1. start from super-vector terminators (current vector_transfer_write /// ops); -/// 2. collect all the statements that can be reached by transitive use-defs +/// 2. collect all the instructions that can be reached by transitive use-defs /// chains; /// 3. get the superVectorType for this particular terminator and the /// corresponding hardware vector type (for now limited to F32) @@ -647,13 +647,13 @@ static bool emitSlice(MaterializationState *state, /// Notes /// ===== /// The `slice` is sorted in topological order by construction. -/// Additionally, this set is limited to statements in the same lexical scope +/// Additionally, this set is limited to instructions in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. static bool materialize(Function *f, const SetVector &terminators, MaterializationState *state) { - DenseSet seen; + DenseSet seen; for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. @@ -668,16 +668,16 @@ static bool materialize(Function *f, // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. 
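The slice fed to emitSlice is the set of instructions reachable from a terminator through use-def chains, kept in an order where definitions precede their uses; the following self-contained C++ sketch shows that idea on a toy graph (the graph representation and names are assumptions for illustration, not the actual getSlice implementation).

    #include <functional>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // Toy use-def graph: maps an instruction name to the names of the
    // instructions defining its operands.
    using UseDefGraph = std::map<std::string, std::vector<std::string>>;

    // Returns everything reachable from 'terminator' via operands, with
    // definitions before uses (a valid order in which to clone the slice).
    std::vector<std::string> backwardSlice(const UseDefGraph &graph,
                                           const std::string &terminator) {
      std::vector<std::string> order;
      std::set<std::string> visited;
      std::function<void(const std::string &)> visit =
          [&](const std::string &name) {
            if (!visited.insert(name).second)
              return;
            auto it = graph.find(name);
            if (it != graph.end())
              for (const auto &def : it->second)
                visit(def); // visit defining instructions first
            order.push_back(name); // post-order: defs land before uses
          };
      visit(terminator);
      return order;
    }

    // Example: {{"write", {"apply"}}, {"apply", {"read"}}, {"read", {}}}
    // yields the order read, apply, write.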
- auto *enclosingScope = term->getParentStmt(); - auto keepIfInSameScope = [enclosingScope](Statement *stmt) { - assert(stmt && "NULL stmt"); + auto *enclosingScope = term->getParentInst(); + auto keepIfInSameScope = [enclosingScope](Instruction *inst) { + assert(inst && "NULL inst"); if (!enclosingScope) { // by construction, everyone is always under the top scope (null scope). return true; } - return properlyDominates(*enclosingScope, *stmt); + return properlyDominates(*enclosingScope, *inst); }; - SetVector slice = + SetVector slice = getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); @@ -722,12 +722,12 @@ PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. - auto filter = [subVectorType](const Statement &stmt) { - const auto &opStmt = cast(stmt); - if (!opStmt.isa()) { + auto filter = [subVectorType](const Instruction &inst) { + const auto &opInst = cast(inst); + if (!opInst.isa()) { return false; } - return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType); + return matcher::operatesOnStrictSuperVectors(opInst, subVectorType); }; auto pat = Op(filter); auto matches = pat.match(f); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index c8a6ced4ed1..debaac3a33c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { struct PipelineDataTransfer : public FunctionPass, - StmtWalker { + InstWalker { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnMLFunction(Function *f) override; - PassResult runOnForStmt(ForStmt *forStmt); + PassResult runOnForInst(ForInst *forInst); - // Collect all 'for' statements. - void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } - std::vector forStmts; + // Collect all 'for' instructions. + void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + std::vector forInsts; static char passID; }; @@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Returns the position of the tag memref operand given a DMA statement. +// Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaStmt) { - assert(dmaStmt.isa() || dmaStmt.isa()); - if (dmaStmt.isa()) { +static unsigned getTagMemRefPos(const OperationInst &dmaInst) { + assert(dmaInst.isa() || dmaInst.isa()); + if (dmaInst.isa()) { // Second to last operand. - return dmaStmt.getNumOperands() - 2; + return dmaInst.getNumOperands() - 2; } - // First operand for a dma finish statement. + // First operand for a dma finish instruction. return 0; } -/// Doubles the buffer of the supplied memref on the specified 'for' statement +/// Doubles the buffer of the supplied memref on the specified 'for' instruction /// by adding a leading dimension of size two to the memref. 
Replaces all uses /// of the old memref by the new one while indexing the newly added dimension by -/// the loop IV of the specified 'for' statement modulo 2. Returns false if such -/// a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { - auto *forBody = forStmt->getBody(); +/// the loop IV of the specified 'for' instruction modulo 2. Returns false if +/// such a replacement cannot be performed. +static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { + auto *forBody = forInst->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. - FuncBuilder bOuter(forStmt); + FuncBuilder bOuter(forInst); SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) - allocOperands.push_back(bOuter.create(forStmt->getLoc(), oldMemRef, + allocOperands.push_back(bOuter.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create and place the alloc right before the 'for' statement. + // Create and place the alloc right before the 'for' instruction. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. Value *newMemRef = - bOuter.create(forStmt->getLoc(), newMemRefType, allocOperands); + bOuter.create(forInst->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {}); auto ivModTwoOp = - bInner.create(forStmt->getLoc(), modTwoMap, forStmt); + bInner.create(forInst->getLoc(), modTwoMap, forInst); - // replaceAllMemRefUsesWith will always succeed unless the forStmt body has + // replaceAllMemRefUsesWith will always succeed unless the forInst body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0), AffineMap::Null(), {}, - &*forStmt->getBody()->begin())) { + &*forInst->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { /// Returns success if the IR is in a valid state. PassResult PipelineDataTransfer::runOnMLFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'for' statements nested within would otherwise become + // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forStmts.clear(); + forInsts.clear(); walkPostOrder(f); bool ret = false; - for (auto *forStmt : forStmts) { - ret = ret | runOnForStmt(forStmt); + for (auto *forInst : forInsts) { + ret = ret | runOnForInst(forInst); } return ret ? failure() : success(); } @@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer startOp, return true; } -// Identify matching DMA start/finish statements to overlap computation with. 
-static void findMatchingStartFinishStmts( - ForStmt *forStmt, +// Identify matching DMA start/finish instructions to overlap computation with. +static void findMatchingStartFinishInsts( + ForInst *forInst, SmallVectorImpl> &startWaitPairs) { - // Collect outgoing DMA statements - needed to check for dependences below. + // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast(&stmt); - if (!opStmt) + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast(&inst); + if (!opInst) continue; OpPointer dmaStartOp; - if ((dmaStartOp = opStmt->dyn_cast()) && + if ((dmaStartOp = opInst->dyn_cast()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector dmaStartStmts, dmaFinishStmts; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast(&stmt); - if (!opStmt) + SmallVector dmaStartInsts, dmaFinishInsts; + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast(&inst); + if (!opInst) continue; - // Collect DMA finish statements. - if (opStmt->isa()) { - dmaFinishStmts.push_back(opStmt); + // Collect DMA finish instructions. + if (opInst->isa()) { + dmaFinishInsts.push_back(opInst); continue; } OpPointer dmaStartOp; - if (!(dmaStartOp = opStmt->dyn_cast())) + if (!(dmaStartOp = opInst->dyn_cast())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { + if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts( } } if (!escapingUses) - dmaStartStmts.push_back(opStmt); + dmaStartInsts.push_back(opInst); } - // For each start statement, we look for a matching finish statement. - for (auto *dmaStartStmt : dmaStartStmts) { - for (auto *dmaFinishStmt : dmaFinishStmts) { - if (checkTagMatch(dmaStartStmt->cast(), - dmaFinishStmt->cast())) { - startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt}); + // For each start instruction, we look for a matching finish instruction. + for (auto *dmaStartInst : dmaStartInsts) { + for (auto *dmaFinishInst : dmaFinishInsts) { + if (checkTagMatch(dmaStartInst->cast(), + dmaFinishInst->cast())) { + startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); break; } } @@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. 
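The double-buffering step above gives the faster-memory buffer a leading dimension of size two and indexes it with the loop IV modulo 2, so the transfer for iteration i+1 can be issued while iteration i computes on the other half; below is a plain C++ sketch of the resulting prologue/steady-state structure, where fetch() and compute() are hypothetical stand-ins for the DMA start/finish pair and the loop body.

    #include <array>
    #include <cstddef>

    constexpr std::size_t kTileSize = 256;
    constexpr std::size_t kNumTiles = 16;

    // Hypothetical stand-ins for the DMA transfer and the computation.
    void fetch(std::size_t tile, std::array<float, kTileSize> &dst) {
      dst.fill(static_cast<float>(tile)); // placeholder for the DMA transfer
    }
    void compute(const std::array<float, kTileSize> &src) {
      (void)src; // placeholder for the computation on the fetched tile
    }

    void pipelined() {
      // Leading dimension of size two: one half is filled while the other is used.
      std::array<std::array<float, kTileSize>, 2> buffer;

      fetch(0, buffer[0]); // prologue: first transfer issued before the loop
      for (std::size_t i = 0; i < kNumTiles; ++i) {
        if (i + 1 < kNumTiles)
          fetch(i + 1, buffer[(i + 1) % 2]); // start instruction, shifted earlier
        compute(buffer[i % 2]);              // computation uses the other half
      }
    }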
-PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { - auto mayBeConstTripCount = getConstantTripCount(*forStmt); +PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { + auto mayBeConstTripCount = getConstantTripCount(*forInst); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector, 4> startWaitPairs; - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } // Double the buffers for the higher memory space memref's. - // Identify memref's to replace by scanning through all DMA start statements. - // A DMA start statement has two memref's - the one from the higher level of - // memory hierarchy is the one to double buffer. + // Identify memref's to replace by scanning through all DMA start + // instructions. A DMA start instruction has two memref's - the one from the + // higher level of memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that // the dimension we are adding here for the double buffering is the outermost // dimension. for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - Value *oldMemRef = dmaStartStmt->getOperand( - dmaStartStmt->cast()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forStmt)) { + auto *dmaStartInst = pair.first; + Value *oldMemRef = dmaStartInst->getOperand( + dmaStartInst->cast()->getFasterMemPos()); + if (!doubleBuffer(oldMemRef, forInst)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); - LLVM_DEBUG(dmaStartStmt->dump()); + LLVM_DEBUG(dmaStartInst->dump()); // IR still in a valid state. return success(); } @@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // operation could have been used on it if it was dynamically shaped in // order to create the double buffer above) if (oldMemRef->use_empty()) - if (auto *allocStmt = oldMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldMemRef->getDefiningInst()) + allocInst->erase(); } // Double the buffers for tag memrefs. for (auto &pair : startWaitPairs) { - auto *dmaFinishStmt = pair.second; + auto *dmaFinishInst = pair.second; Value *oldTagMemRef = - dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)); - if (!doubleBuffer(oldTagMemRef, forStmt)) { + dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); + if (!doubleBuffer(oldTagMemRef, forInst)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocStmt = oldTagMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldTagMemRef->getDefiningInst()) + allocInst->erase(); } - // Double buffering would have invalidated all the old DMA start/wait stmts. + // Double buffering would have invalidated all the old DMA start/wait insts. 
startWaitPairs.clear(); - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); - // Store shift for statement for later lookup for AffineApplyOp's. - DenseMap stmtShiftMap; + // Store shift for instruction for later lookup for AffineApplyOp's. + DenseMap instShiftMap; for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - assert(dmaStartStmt->isa()); - stmtShiftMap[dmaStartStmt] = 0; - // Set shifts for DMA start stmt's affine operand computation slices to 0. - if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) { - stmtShiftMap[slice] = 0; + auto *dmaStartInst = pair.first; + assert(dmaStartInst->isa()); + instShiftMap[dmaStartInst] = 0; + // Set shifts for DMA start inst's affine operand computation slices to 0. + if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) { + instShiftMap[slice] = 0; } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector affineApplyStmts; - SmallVector operands(dmaStartStmt->getOperands()); - getReachableAffineApplyOps(operands, affineApplyStmts); - for (const auto *stmt : affineApplyStmts) { - stmtShiftMap[stmt] = 0; + SmallVector affineApplyInsts; + SmallVector operands(dmaStartInst->getOperands()); + getReachableAffineApplyOps(operands, affineApplyInsts); + for (const auto *inst : affineApplyInsts) { + instShiftMap[inst] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &stmt : *forStmt->getBody()) { - if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) { - stmtShiftMap[&stmt] = 1; + for (const auto &inst : *forInst->getBody()) { + if (instShiftMap.find(&inst) == instShiftMap.end()) { + instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector shifts(forStmt->getBody()->getInstructions().size()); + std::vector shifts(forInst->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &stmt : *forStmt->getBody()) { - assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); - shifts[s++] = stmtShiftMap[&stmt]; + for (auto &inst : *forInst->getBody()) { + assert(instShiftMap.find(&inst) != instShiftMap.end()); + shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( - // Tagging statements with shifts for debugging purposes. - if (auto *opStmt = dyn_cast(&stmt)) { - FuncBuilder b(opStmt); - opStmt->setAttr(b.getIdentifier("shift"), + // Tagging instructions with shifts for debugging purposes. + if (auto *opInst = dyn_cast(&inst)) { + FuncBuilder b(opInst); + opInst->setAttr(b.getIdentifier("shift"), b.getI64IntegerAttr(shifts[s - 1])); }); } - if (!isStmtwiseShiftValid(*forStmt, shifts)) { + if (!isInstwiseShiftValid(*forInst, shifts)) { // Violates dependences. 
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (stmtBodySkew(forStmt, shifts)) { - LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";); + if (instBodySkew(forInst, shifts)) { + LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 853a814e516..2a643eb690a 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -21,7 +21,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" @@ -32,12 +32,12 @@ using llvm::report_fatal_error; namespace { -/// Simplifies all affine expressions appearing in the operation statements of +/// Simplifies all affine expressions appearing in the operation instructions of /// the Function. This is mainly to test the simplifyAffineExpr method. // TODO(someone): Gradually, extend this to all affine map references found in // ML functions and CFG functions. struct SimplifyAffineStructures : public FunctionPass, - StmtWalker { + InstWalker { explicit SimplifyAffineStructures() : FunctionPass(&SimplifyAffineStructures::passID) {} @@ -46,8 +46,8 @@ struct SimplifyAffineStructures : public FunctionPass, // for this yet? TODO(someone). PassResult runOnCFGFunction(Function *f) override { return success(); } - void visitIfStmt(IfStmt *ifStmt); - void visitOperationInst(OperationInst *opStmt); + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); static char passID; }; @@ -70,18 +70,18 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) { - auto set = ifStmt->getCondition().getIntegerSet(); - ifStmt->setIntegerSet(simplifyIntegerSet(set)); +void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { + auto set = ifInst->getCondition().getIntegerSet(); + ifInst->setIntegerSet(simplifyIntegerSet(set)); } -void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) { - for (auto attr : opStmt->getAttrs()) { +void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); mMap.simplify(); auto map = mMap.getAffineMap(); - opStmt->setAttr(attr.first, AffineMapAttr::get(map)); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); } } } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a4116667794..6064d1feff3 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -271,7 +271,7 @@ static void processMLFunction(Function *fn, } void setInsertionPoint(OperationInst *op) override { - // Any new operations should be added before this statement. + // Any new operations should be added before this instruction. 
builder.setInsertionPoint(cast(op)); } @@ -280,7 +280,7 @@ static void processMLFunction(Function *fn, }; GreedyPatternRewriteDriver driver(std::move(patterns)); - fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); }); + fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); }); FuncBuilder mlBuilder(fn); MLFuncRewriter rewriter(driver, mlBuilder); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 03b4bb29e19..93039372121 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -26,8 +26,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" @@ -38,22 +38,22 @@ using namespace mlir; /// Returns the upper bound of an unrolled loop with lower bound 'lb' and with /// the specified trip count, stride, and unroll factor. Returns nullptr when /// the trip count can't be expressed as an affine expression. -AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, +AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forStmt.getLowerBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap::Null(); // Sometimes, the trip count cannot be expressed as an affine expression. - auto tripCount = getTripCountExpr(forStmt); + auto tripCount = getTripCountExpr(forInst); if (!tripCount) return AffineMap::Null(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forStmt.getStep(); + unsigned step = forInst.getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), @@ -64,122 +64,122 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, /// bound 'lb' and with the specified trip count, stride, and unroll factor. /// Returns an AffinMap with nullptr storage (that evaluates to false) /// when the trip count can't be expressed as an affine expression. -AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt, +AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forStmt.getLowerBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap::Null(); // Sometimes the trip count cannot be expressed as an affine expression. - AffineExpr tripCount(getTripCountExpr(forStmt)); + AffineExpr tripCount(getTripCountExpr(forInst)); if (!tripCount) return AffineMap::Null(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forStmt.getStep(); + unsigned step = forInst.getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), {newLb}, {}); } -/// Promotes the loop body of a forStmt to its containing block if the forStmt +/// Promotes the loop body of a forInst to its containing block if the forInst /// was known to have a single iteration. Returns false otherwise. // TODO(bondhugula): extend this for arbitrary affine bounds. 
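The two bound helpers above split the iteration space into a main portion whose trip count is a multiple of the unroll factor and a cleanup remainder; the following worked C++ check plugs in lb = 0, step = 2, tripCount = 10 and unrollFactor = 4 (all illustrative values) to show where the split lands.

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t lb = 0, step = 2, tripCount = 10, unrollFactor = 4;

      // Unrolled-loop upper bound:
      // lb + (tripCount - tripCount % unrollFactor - 1) * step.
      const int64_t unrolledUb =
          lb + (tripCount - tripCount % unrollFactor - 1) * step;
      assert(unrolledUb == 14); // main loop covers original IVs 0, 2, ..., 14

      // Cleanup-loop lower bound:
      // lb + (tripCount - tripCount % unrollFactor) * step.
      const int64_t cleanupLb =
          lb + (tripCount - tripCount % unrollFactor) * step;
      assert(cleanupLb == 16); // cleanup loop handles the remaining IVs 16 and 18

      return 0;
    }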
-bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { - Optional tripCount = getConstantTripCount(*forStmt); +bool mlir::promoteIfSingleIteration(ForInst *forInst) { + Optional tripCount = getConstantTripCount(*forInst); if (!tripCount.hasValue() || tripCount.getValue() != 1) return false; // TODO(mlir-team): there is no builder for a max. - if (forStmt->getLowerBoundMap().getNumResults() != 1) + if (forInst->getLowerBoundMap().getNumResults() != 1) return false; // Replaces all IV uses to its single iteration value. - if (!forStmt->use_empty()) { - if (forStmt->hasConstantLowerBound()) { - auto *mlFunc = forStmt->getFunction(); + if (!forInst->use_empty()) { + if (forInst->hasConstantLowerBound()) { + auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(&mlFunc->getBody()->front()); auto constOp = topBuilder.create( - forStmt->getLoc(), forStmt->getConstantLowerBound()); - forStmt->replaceAllUsesWith(constOp); + forInst->getLoc(), forInst->getConstantLowerBound()); + forInst->replaceAllUsesWith(constOp); } else { - const AffineBound lb = forStmt->getLowerBound(); + const AffineBound lb = forInst->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt)); + FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); auto affineApplyOp = builder.create( - forStmt->getLoc(), lb.getMap(), lbOperands); - forStmt->replaceAllUsesWith(affineApplyOp->getResult(0)); + forInst->getLoc(), lb.getMap(), lbOperands); + forInst->replaceAllUsesWith(affineApplyOp->getResult(0)); } } - // Move the loop body statements to the loop's containing block. - auto *block = forStmt->getBlock(); - block->getInstructions().splice(Block::iterator(forStmt), - forStmt->getBody()->getInstructions()); - forStmt->erase(); + // Move the loop body instructions to the loop's containing block. + auto *block = forInst->getBlock(); + block->getInstructions().splice(Block::iterator(forInst), + forInst->getBody()->getInstructions()); + forInst->erase(); return true; } -/// Promotes all single iteration for stmt's in the Function, i.e., moves +/// Promotes all single iteration for inst's in the Function, i.e., moves /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class LoopBodyPromoter : public StmtWalker { + class LoopBodyPromoter : public InstWalker { public: - void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); } + void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); } }; LoopBodyPromoter fsw; fsw.walkPostOrder(f); } -/// Generates a 'for' stmt with the specified lower and upper bounds while -/// generating the right IV remappings for the shifted statements. The -/// statement blocks that go into the loop are specified in stmtGroupQueue +/// Generates a 'for' inst with the specified lower and upper bounds while +/// generating the right IV remappings for the shifted instructions. The +/// instruction blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of -/// the pair specifies the shift applied to that group of statements; note that -/// the shift is multiplied by the loop step before being applied. Returns +/// the pair specifies the shift applied to that group of instructions; note +/// that the shift is multiplied by the loop step before being applied. 
Returns /// nullptr if the generated loop simplifies to a single iteration one. -static ForStmt * +static ForInst * generateLoop(AffineMap lbMap, AffineMap ubMap, - const std::vector>> - &stmtGroupQueue, - unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) { - SmallVector lbOperands(srcForStmt->getLowerBoundOperands()); - SmallVector ubOperands(srcForStmt->getUpperBoundOperands()); + const std::vector>> + &instGroupQueue, + unsigned offset, ForInst *srcForInst, FuncBuilder *b) { + SmallVector lbOperands(srcForInst->getLowerBoundOperands()); + SmallVector ubOperands(srcForInst->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); - auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForStmt->getStep()); + auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, + ubOperands, ubMap, srcForInst->getStep()); OperationInst::OperandMapTy operandMap; - for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end(); + for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end(); it != e; ++it) { uint64_t shift = it->first; - auto stmts = it->second; - // All 'same shift' statements get added with their operands being remapped - // to results of cloned statements, and their IV used remapped. + auto insts = it->second; + // All 'same shift' instructions get added with their operands being + // remapped to results of cloned instructions, and their IV used remapped. // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. - if (!srcForStmt->use_empty() && shift != 0) { - auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk); + if (!srcForInst->use_empty() && shift != 0) { + auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); auto *ivRemap = b.create( - srcForStmt->getLoc(), + srcForInst->getLoc(), b.getSingleDimShiftAffineMap(-static_cast( - srcForStmt->getStep() * shift)), + srcForInst->getStep() * shift)), loopChunk) ->getResult(0); - operandMap[srcForStmt] = ivRemap; + operandMap[srcForInst] = ivRemap; } else { - operandMap[srcForStmt] = loopChunk; + operandMap[srcForInst] = loopChunk; } - for (auto *stmt : stmts) { - loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext())); + for (auto *inst : insts) { + loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext())); } } if (promoteIfSingleIteration(loopChunk)) @@ -187,63 +187,63 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the statements in the body of a 'for' statement with the specified -/// statement-wise shifts. The shifts are with respect to the original execution -/// order, and are multiplied by the loop 'step' before being applied. A shift -/// of zero for each statement will lead to no change. -// The skewing of statements with respect to one another can be used for example -// to allow overlap of asynchronous operations (such as DMA communication) with -// computation, or just relative shifting of statements for better register -// reuse, locality or parallelism. As such, the shifts are typically expected to -// be at most of the order of the number of statements. This method should not -// be used as a substitute for loop distribution/fission. -// This method uses an algorithm// in time linear in the number of statements in -// the body of the for loop - (using the 'sweep line' paradigm). 
This method +/// Skew the instructions in the body of a 'for' instruction with the specified +/// instruction-wise shifts. The shifts are with respect to the original +/// execution order, and are multiplied by the loop 'step' before being applied. +/// A shift of zero for each instruction will lead to no change. +// The skewing of instructions with respect to one another can be used for +// example to allow overlap of asynchronous operations (such as DMA +// communication) with computation, or just relative shifting of instructions +// for better register reuse, locality or parallelism. As such, the shifts are +// typically expected to be at most of the order of the number of instructions. +// This method should not be used as a substitute for loop distribution/fission. +// This method uses an algorithm// in time linear in the number of instructions +// in the body of the for loop - (using the 'sweep line' paradigm). This method // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this // method. -UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, +UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, bool unrollPrologueEpilogue) { - if (forStmt->getBody()->empty()) + if (forInst->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and // conditional guards (or context information to prevent such versioning). The // better way to pipeline for such loops is to first tile them and extract // constant trip count "full tiles" before applying this. - auto mayBeConstTripCount = getConstantTripCount(*forStmt); + auto mayBeConstTripCount = getConstantTripCount(*forInst); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); return UtilResult::Success; } uint64_t tripCount = mayBeConstTripCount.getValue(); - assert(isStmtwiseShiftValid(*forStmt, shifts) && + assert(isInstwiseShiftValid(*forInst, shifts) && "shifts will lead to an invalid transformation\n"); - int64_t step = forStmt->getStep(); + int64_t step = forInst->getStep(); - unsigned numChildStmts = forStmt->getBody()->getInstructions().size(); + unsigned numChildInsts = forInst->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; - for (unsigned i = 0; i < numChildStmts; i++) { + for (unsigned i = 0; i < numChildInsts; i++) { maxShift = std::max(maxShift, shifts[i]); } // Such large shifts are not the typical use case. - if (maxShift >= numChildStmts) { - LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";); + if (maxShift >= numChildInsts) { + LLVM_DEBUG(llvm::dbgs() << "inst shifts too large - unexpected\n";); return UtilResult::Success; } - // An array of statement groups sorted by shift amount; each group has all - // statements with the same shift in the order in which they appear in the - // body of the 'for' stmt. - std::vector> sortedStmtGroups(maxShift + 1); + // An array of instruction groups sorted by shift amount; each group has all + // instructions with the same shift in the order in which they appear in the + // body of the 'for' inst. 
+ std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &stmt : *forStmt->getBody()) { + for (auto &inst : *forInst->getBody()) { auto shift = shifts[pos++]; - sortedStmtGroups[shift].push_back(&stmt); + sortedInstGroups[shift].push_back(&inst); } // Unless the shifts have a specific pattern (which actually would be the @@ -251,40 +251,40 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first // loop generated as the prologue and the last as epilogue and unroll these // fully. - ForStmt *prologue = nullptr; - ForStmt *epilogue = nullptr; + ForInst *prologue = nullptr; + ForInst *epilogue = nullptr; // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block - // of statements is paired with its shift. - std::vector>> stmtGroupQueue; + // of instructions is paired with its shift. + std::vector>> instGroupQueue; - auto origLbMap = forStmt->getLowerBoundMap(); + auto origLbMap = forInst->getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forStmt); - for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) { + FuncBuilder b(forInst); + for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. - if (sortedStmtGroups[d].empty()) + if (sortedInstGroups[d].empty()) continue; - if (!stmtGroupQueue.empty()) { + if (!instGroupQueue.empty()) { assert(d >= 1 && "Queue expected to be empty when the first block is found"); // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the - // loop needs to have all statements in stmtQueue in that order. - ForStmt *res; + // loop needs to have all instructions in instQueue in that order. + ForInst *res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), - stmtGroupQueue, 0, forStmt, &b); - // Entire loop for the queued stmt groups generated, empty it. - stmtGroupQueue.clear(); + instGroupQueue, 0, forInst, &b); + // Entire loop for the queued inst groups generated, empty it. + instGroupQueue.clear(); lbShift += tripCount * step; } else { res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), - b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue, - 0, forStmt, &b); + b.getShiftedAffineMap(origLbMap, d), instGroupQueue, + 0, forInst, &b); lbShift = d * step; } if (!prologue && res) @@ -294,24 +294,24 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, // Start of first interval. lbShift = d * step; } - // Augment the list of statements that get into the current open interval. - stmtGroupQueue.push_back({d, sortedStmtGroups[d]}); + // Augment the list of instructions that get into the current open interval. + instGroupQueue.push_back({d, sortedInstGroups[d]}); } - // Those statements groups left in the queue now need to be processed (FIFO) + // Those instructions groups left in the queue now need to be processed (FIFO) // and their loops completed. 
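For a body with two instructions where the first keeps shift 0 and the second is shifted by 1, the sweep above emits a prologue containing only the shift-0 group, a steady-state loop containing both groups (the shifted one lagging by an iteration), and an epilogue containing only the shift-1 group; below is a plain C++ analogue with a producer/consumer body and unit step, using illustrative names.

    #include <array>

    constexpr int N = 8;

    void skewedByOne(std::array<int, N> &produced, std::array<int, N> &consumed) {
      // Original body of iteration i: produce(i); consume(i);
      // Shifts {0, 1} move the consume of iteration i one iteration later.

      produced[0] = 0; // prologue: iteration 0 of the shift-0 group only

      for (int i = 1; i < N; ++i) {
        produced[i] = i;                   // shift-0 group, iteration i
        consumed[i - 1] = produced[i - 1]; // shift-1 group, iteration i - 1
      }

      consumed[N - 1] = produced[N - 1]; // epilogue: last iteration, shift-1 group
    }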
- for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) { - uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step; + for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) { + uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, ubShift), - stmtGroupQueue, i, forStmt, &b); + instGroupQueue, i, forInst, &b); lbShift = ubShift; if (!prologue) prologue = epilogue; } - // Erase the original for stmt. - forStmt->erase(); + // Erase the original for inst. + forInst->erase(); if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); @@ -322,39 +322,39 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef shifts, } /// Unrolls this loop completely. -bool mlir::loopUnrollFull(ForStmt *forStmt) { - Optional mayBeConstantTripCount = getConstantTripCount(*forStmt); +bool mlir::loopUnrollFull(ForInst *forInst) { + Optional mayBeConstantTripCount = getConstantTripCount(*forInst); if (mayBeConstantTripCount.hasValue()) { uint64_t tripCount = mayBeConstantTripCount.getValue(); if (tripCount == 1) { - return promoteIfSingleIteration(forStmt); + return promoteIfSingleIteration(forInst); } - return loopUnrollByFactor(forStmt, tripCount); + return loopUnrollByFactor(forInst, tripCount); } return false; } /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant) whichever is lower. -bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forStmt); +bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(*forInst); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) - return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue()); - return loopUnrollByFactor(forStmt, unrollFactor); + return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forInst, unrollFactor); } /// Unrolls this loop by the specified factor. Returns true if the loop /// is successfully unrolled. -bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { +bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); - if (unrollFactor == 1 || forStmt->getBody()->empty()) + if (unrollFactor == 1 || forInst->getBody()->empty()) return false; - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -365,10 +365,10 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different operand lists. - if (!forStmt->matchingBoundOperandList()) + if (!forInst->matchingBoundOperandList()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forStmt); + Optional mayBeConstantTripCount = getConstantTripCount(*forInst); // If the trip count is lower than the unroll factor, no unrolled body. // TODO(bondhugula): option to specify cleanup loop unrolling. 
@@ -377,64 +377,64 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { + if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) { DenseMap operandMap; - FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt)); - auto *cleanupForStmt = cast(builder.clone(*forStmt, operandMap)); - auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); + FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); + auto *cleanupForInst = cast(builder.clone(*forInst, operandMap)); + auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " "always be determined"); - cleanupForStmt->setLowerBoundMap(clLbMap); + cleanupForInst->setLowerBoundMap(clLbMap); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForStmt); + promoteIfSingleIteration(cleanupForInst); // Adjust upper bound. auto unrolledUbMap = - getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder); + getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder); assert(unrolledUbMap && "upper bound map can alwayys be determined for an unrolled loop " "with single result bounds"); - forStmt->setUpperBoundMap(unrolledUbMap); + forInst->setUpperBoundMap(unrolledUbMap); } // Scale the step of loop being unrolled by unroll factor. - int64_t step = forStmt->getStep(); - forStmt->setStep(step * unrollFactor); + int64_t step = forInst->getStep(); + forInst->setStep(step * unrollFactor); - // Builder to insert unrolled bodies right after the last statement in the - // body of 'forStmt'. - FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end()); + // Builder to insert unrolled bodies right after the last instruction in the + // body of 'forInst'. + FuncBuilder builder(forInst->getBody(), forInst->getBody()->end()); - // Keep a pointer to the last statement in the original block so that we know - // what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); + // Keep a pointer to the last instruction in the original block so that we + // know what to clone (since we are doing this in-place). + Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end()); - // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). + // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { DenseMap operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forStmt->use_empty()) { + if (!forInst->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create(forStmt->getLoc(), bumpMap, forStmt) + builder.create(forInst->getLoc(), bumpMap, forInst) ->getResult(0); - operandMap[forStmt] = ivUnroll; + operandMap[forInst] = ivUnroll; } - // Clone the original body of 'forStmt'. - for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd); + // Clone the original body of 'forInst'. 
+ for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd); it++) { builder.clone(*it, operandMap); } } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forStmt); + promoteIfSingleIteration(forInst); return true; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 3661c1bdbbc..8cfe2619e2a 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -26,8 +26,8 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" @@ -66,7 +66,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - const Statement *domStmtFilter) { + const Instruction *domInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -85,41 +85,41 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // Walk all uses of old memref. Operation using the memref gets replaced. for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { InstOperand &use = *(it++); - auto *opStmt = cast(use.getOwner()); + auto *opInst = cast(use.getOwner()); - // Skip this use if it's not dominated by domStmtFilter. - if (domStmtFilter && !dominates(*domStmtFilter, *opStmt)) + // Skip this use if it's not dominated by domInstFilter. + if (domInstFilter && !dominates(*domInstFilter, *opInst)) continue; // Check if the memref was used in a non-deferencing context. It is fine for // the memref to be used in a non-deferencing way outside of the region // where this replacement is happening. - if (!isMemRefDereferencingOp(*opStmt)) + if (!isMemRefDereferencingOp(*opInst)) // Failure: memref used in a non-deferencing op (potentially escapes); no // replacement in these cases. return false; auto getMemRefOperandPos = [&]() -> unsigned { unsigned i, e; - for (i = 0, e = opStmt->getNumOperands(); i < e; i++) { - if (opStmt->getOperand(i) == oldMemRef) + for (i = 0, e = opInst->getNumOperands(); i < e; i++) { + if (opInst->getOperand(i) == oldMemRef) break; } - assert(i < opStmt->getNumOperands() && "operand guaranteed to be found"); + assert(i < opInst->getNumOperands() && "operand guaranteed to be found"); return i; }; unsigned memRefOperandPos = getMemRefOperandPos(); - // Construct the new operation statement using this memref. - OperationState state(opStmt->getContext(), opStmt->getLoc(), - opStmt->getName()); - state.operands.reserve(opStmt->getNumOperands() + extraIndices.size()); + // Construct the new operation instruction using this memref. + OperationState state(opInst->getContext(), opInst->getLoc(), + opInst->getName()); + state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. 
- state.operands.insert(state.operands.end(), opStmt->operand_begin(), - opStmt->operand_begin() + memRefOperandPos); + state.operands.insert(state.operands.end(), opInst->operand_begin(), + opInst->operand_begin() + memRefOperandPos); state.operands.push_back(newMemRef); - FuncBuilder builder(opStmt); + FuncBuilder builder(opInst); for (auto *extraIndex : extraIndices) { // TODO(mlir-team): An operation/SSA value should provide a method to // return the position of an SSA result in its defining @@ -139,10 +139,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, remapOperands.insert(remapOperands.end(), extraOperands.begin(), extraOperands.end()); remapOperands.insert( - remapOperands.end(), opStmt->operand_begin() + memRefOperandPos + 1, - opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank); + remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1, + opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank); if (indexRemap) { - auto remapOp = builder.create(opStmt->getLoc(), indexRemap, + auto remapOp = builder.create(opInst->getLoc(), indexRemap, remapOperands); // Remapped indices. for (auto *index : remapOp->getInstruction()->getResults()) @@ -155,27 +155,27 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // Insert the remaining operands unmodified. state.operands.insert(state.operands.end(), - opStmt->operand_begin() + memRefOperandPos + 1 + + opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank, - opStmt->operand_end()); + opInst->operand_end()); // Result types don't change. Both memref's are of the same elemental type. - state.types.reserve(opStmt->getNumResults()); - for (const auto *result : opStmt->getResults()) + state.types.reserve(opInst->getNumResults()); + for (const auto *result : opInst->getResults()) state.types.push_back(result->getType()); // Attributes also do not change. - state.attributes.insert(state.attributes.end(), opStmt->getAttrs().begin(), - opStmt->getAttrs().end()); + state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(), + opInst->getAttrs().end()); // Create the new operation. auto *repOp = builder.createOperation(state); // Replace old memref's dereferencing op's uses. unsigned r = 0; - for (auto *res : opStmt->getResults()) { + for (auto *res : opInst->getResults()) { res->replaceAllUsesWith(repOp->getResult(r++)); } - opStmt->erase(); + opInst->erase(); } return true; } @@ -196,9 +196,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, // Initialize AffineValueMap with identity map. AffineValueMap valueMap(map, operands); - for (auto *opStmt : affineApplyOps) { - assert(opStmt->isa()); - auto affineApplyOp = opStmt->cast(); + for (auto *opInst : affineApplyOps) { + assert(opInst->isa()); + auto affineApplyOp = opInst->cast(); // Forward substitute 'affineApplyOp' into 'valueMap'. valueMap.forwardSubstitute(*affineApplyOp); } @@ -219,10 +219,10 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, return affineApplyOp->getInstruction(); } -/// Given an operation statement, inserts a new single affine apply operation, -/// that is exclusively used by this operation statement, and that provides all -/// operands that are results of an affine_apply as a function of loop iterators -/// and program parameters and whose results are.
+/// Given an operation instruction, inserts a new single affine apply operation, +/// that is exclusively used by this operation instruction, and that provides +/// all operands that are results of an affine_apply as a function of loop +/// iterators and program parameters and whose results are. /// /// Before /// @@ -242,18 +242,18 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// This allows applying different transformations on send and compute (for eg. /// different shifts/delays). /// -/// Returns nullptr either if none of opStmt's operands were the result of an +/// Returns nullptr either if none of opInst's operands were the result of an /// affine_apply and thus there was no affine computation slice to create, or if -/// all the affine_apply op's supplying operands to this opStmt do not have any -/// uses besides this opStmt. Returns the new affine_apply operation statement +/// all the affine_apply op's supplying operands to this opInst do not have any +/// uses besides this opInst. Returns the new affine_apply operation instruction /// otherwise. -OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { +OperationInst *mlir::createAffineComputationSlice(OperationInst *opInst) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; - subOperands.reserve(opStmt->getNumOperands()); - for (auto *operand : opStmt->getOperands()) { - auto *defStmt = operand->getDefiningInst(); - if (defStmt && defStmt->isa()) { + subOperands.reserve(opInst->getNumOperands()); + for (auto *operand : opInst->getOperands()) { + auto *defInst = operand->getDefiningInst(); + if (defInst && defInst->isa()) { subOperands.push_back(operand); } } @@ -265,13 +265,13 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { if (affineApplyOps.empty()) return nullptr; - // Check if all uses of the affine apply op's lie only in this op stmt, in + // Check if all uses of the affine apply op's lie only in this op inst, in // which case there would be nothing to do. bool localized = true; for (auto *op : affineApplyOps) { for (auto *result : op->getResults()) { for (auto &use : result->getUses()) { - if (use.getOwner() != opStmt) { + if (use.getOwner() != opInst) { localized = false; break; } @@ -281,18 +281,18 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { if (localized) return nullptr; - FuncBuilder builder(opStmt); + FuncBuilder builder(opInst); SmallVector results; - auto *affineApplyStmt = createComposedAffineApplyOp( - &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results); + auto *affineApplyInst = createComposedAffineApplyOp( + &builder, opInst->getLoc(), subOperands, affineApplyOps, &results); assert(results.size() == subOperands.size() && "number of results should be the same as the number of subOperands"); // Construct the new operands that include the results from the composed // affine apply op above instead of existing ones (subOperands). So, they - // differ from opStmt's operands only for those operands in 'subOperands', for + // differ from opInst's operands only for those operands in 'subOperands', for // which they will be replaced by the corresponding one from 'results'. - SmallVector newOperands(opStmt->getOperands()); + SmallVector newOperands(opInst->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. 
unsigned j, f; @@ -306,10 +306,10 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { } for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { - opStmt->setOperand(idx, newOperands[idx]); + opInst->setOperand(idx, newOperands[idx]); } - return affineApplyStmt; + return affineApplyInst; } void mlir::forwardSubstitute(OpPointer affineApplyOp) { @@ -317,26 +317,26 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { // TODO: Support forward substitution for CFG style functions. return; } - auto *opStmt = affineApplyOp->getInstruction(); - // Iterate through all uses of all results of 'opStmt', forward substituting + auto *opInst = affineApplyOp->getInstruction(); + // Iterate through all uses of all results of 'opInst', forward substituting // into any uses which are AffineApplyOps. - for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; + for (unsigned resultIndex = 0, e = opInst->getNumResults(); resultIndex < e; ++resultIndex) { - const Value *result = opStmt->getResult(resultIndex); + const Value *result = opInst->getResult(resultIndex); for (auto it = result->use_begin(); it != result->use_end();) { InstOperand &use = *(it++); - auto *useStmt = use.getOwner(); - auto *useOpStmt = dyn_cast(useStmt); + auto *useInst = use.getOwner(); + auto *useOpInst = dyn_cast(useInst); // Skip if use is not AffineApplyOp. - if (useOpStmt == nullptr || !useOpStmt->isa()) + if (useOpInst == nullptr || !useOpInst->isa()) continue; - // Advance iterator past 'opStmt' operands which also use 'result'. - while (it != result->use_end() && it->getOwner() == useStmt) + // Advance iterator past 'opInst' operands which also use 'result'. + while (it != result->use_end() && it->getOwner() == useInst) ++it; - FuncBuilder builder(useOpStmt); + FuncBuilder builder(useOpInst); // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'. - auto oldAffineApplyOp = useOpStmt->cast(); + auto oldAffineApplyOp = useOpInst->cast(); AffineValueMap valueMap(*oldAffineApplyOp); // Forward substitute 'result' at index 'i' into 'valueMap'. valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex); @@ -348,10 +348,10 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { operands[i] = valueMap.getOperand(i); } auto newAffineApplyOp = builder.create( - useOpStmt->getLoc(), valueMap.getAffineMap(), operands); + useOpInst->getLoc(), valueMap.getAffineMap(), operands); // Update all uses to use results from 'newAffineApplyOp'. - for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) { + for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) { oldAffineApplyOp->getResult(i)->replaceAllUsesWith( newAffineApplyOp->getResult(i)); } @@ -364,19 +364,19 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { /// Folds the specified (lower or upper) bound to a constant if possible /// considering its operands. Returns false if the folding happens for any of /// the bounds, true otherwise. -bool mlir::constantFoldBounds(ForStmt *forStmt) { - auto foldLowerOrUpperBound = [forStmt](bool lower) { +bool mlir::constantFoldBounds(ForInst *forInst) { + auto foldLowerOrUpperBound = [forInst](bool lower) { // Check if the bound is already a constant. - if (lower && forStmt->hasConstantLowerBound()) + if (lower && forInst->hasConstantLowerBound()) return true; - if (!lower && forStmt->hasConstantUpperBound()) + if (!lower && forInst->hasConstantUpperBound()) return true; // Check to see if each of the operands is the result of a constant. 
If so, // get the value. If not, ignore it. SmallVector operandConstants; - auto boundOperands = lower ? forStmt->getLowerBoundOperands() - : forStmt->getUpperBoundOperands(); + auto boundOperands = lower ? forInst->getLowerBoundOperands() + : forInst->getUpperBoundOperands(); for (const auto *operand : boundOperands) { Attribute operandCst; if (auto *operandOp = operand->getDefiningInst()) { @@ -387,7 +387,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { } AffineMap boundMap = - lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap(); + lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap(); assert(boundMap.getNumResults() >= 1 && "bound maps should have at least one result"); SmallVector foldedResults; @@ -402,8 +402,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } - lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue()) - : forStmt->setConstantUpperBound(maxOrMin.getSExtValue()); + lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue()) + : forInst->setConstantUpperBound(maxOrMin.getSExtValue()); // Return false on success. return false; @@ -449,11 +449,11 @@ void mlir::remapFunctionAttrs( if (!fn.isML()) return; - struct MLFnWalker : public StmtWalker { + struct MLFnWalker : public InstWalker { MLFnWalker(const DenseMap &remappingTable) : remappingTable(remappingTable) {} - void visitOperationInst(OperationInst *opStmt) { - remapFunctionAttrs(*opStmt, remappingTable); + void visitOperationInst(OperationInst *opInst) { + remapFunctionAttrs(*opInst, remappingTable); } const DenseMap &remappingTable; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 78d048b4778..9aa11682ebb 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -95,20 +95,20 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext())); - // Only filter statements that operate on a strict super-vector and have one + // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. - auto filter = [subVectorType](const Statement &stmt) { - auto *opStmt = dyn_cast(&stmt); - if (!opStmt) { + auto filter = [subVectorType](const Instruction &inst) { + auto *opInst = dyn_cast(&inst); + if (!opInst) { return false; } assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnStrictSuperVectors(*opStmt, subVectorType)) { + if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) { return false; } - if (opStmt->getNumResults() != 1) { + if (opInst->getNumResults() != 1) { return false; } return true; @@ -116,26 +116,26 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { auto pat = Op(filter); auto matches = pat.match(f); for (auto m : matches) { - auto *opStmt = cast(m.first); + auto *opInst = cast(m.first); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the // future we can always extend. 
- auto superVectorType = opStmt->getResult(0)->getType().cast(); + auto superVectorType = opInst->getResult(0)->getType().cast(); auto ratio = shapeRatio(superVectorType, subVectorType); if (!ratio.hasValue()) { - opStmt->emitNote("NOT MATCHED"); + opInst->emitNote("NOT MATCHED"); } else { - outs() << "\nmatched: " << *opStmt << " with shape ratio: "; + outs() << "\nmatched: " << *opInst << " with shape ratio: "; interleaveComma(MutableArrayRef(*ratio), outs()); } } } -static std::string toString(Statement *stmt) { +static std::string toString(Instruction *inst) { std::string res; auto os = llvm::raw_string_ostream(res); - stmt->print(os); + inst->print(os); return res; } @@ -144,10 +144,10 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) { constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; using matcher::Op; - // Match all OpStatements with the kTestSlicingOpName name. - auto filter = [](const Statement &stmt) { - const auto &opStmt = cast(stmt); - return opStmt.getName().getStringRef() == kTestSlicingOpName; + // Match all OpInstructions with the kTestSlicingOpName name. + auto filter = [](const Instruction &inst) { + const auto &opInst = cast(inst); + return opInst.getName().getStringRef() == kTestSlicingOpName; }; auto pat = Op(filter); return pat.match(f); @@ -156,7 +156,7 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) { void VectorizerTestPass::testBackwardSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { - SetVector backwardSlice; + SetVector backwardSlice; getBackwardSlice(m.first, &backwardSlice); auto strs = map(toString, backwardSlice); outs() << "\nmatched: " << *m.first << " backward static slice: "; @@ -169,7 +169,7 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) { void VectorizerTestPass::testForwardSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { - SetVector forwardSlice; + SetVector forwardSlice; getForwardSlice(m.first, &forwardSlice); auto strs = map(toString, forwardSlice); outs() << "\nmatched: " << *m.first << " forward static slice: "; @@ -182,7 +182,7 @@ void VectorizerTestPass::testForwardSlicing(Function *f) { void VectorizerTestPass::testSlicing(Function *f) { auto matches = matchTestSlicingOps(f); for (auto m : matches) { - SetVector staticSlice = getSlice(m.first); + SetVector staticSlice = getSlice(m.first); auto strs = map(toString, staticSlice); outs() << "\nmatched: " << *m.first << " static slice: "; for (const auto &s : strs) { @@ -191,9 +191,9 @@ void VectorizerTestPass::testSlicing(Function *f) { } } -bool customOpWithAffineMapAttribute(const Statement &stmt) { - const auto &opStmt = cast(stmt); - return opStmt.getName().getStringRef() == +bool customOpWithAffineMapAttribute(const Instruction &inst) { + const auto &opInst = cast(inst); + return opInst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -205,8 +205,8 @@ void VectorizerTestPass::testComposeMaps(Function *f) { maps.reserve(matches.size()); std::reverse(matches.begin(), matches.end()); for (auto m : matches) { - auto *opStmt = cast(m.first); - auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName) + auto *opInst = cast(m.first); + auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast() .getValue(); maps.push_back(map); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ddbd6256782..bbb703cd627 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ 
b/mlir/lib/Transforms/Vectorize.cpp @@ -252,7 +252,7 @@ using namespace mlir; /// ========== /// The algorithm proceeds in a few steps: /// 1. defining super-vectorization patterns and matching them on the tree of -/// ForStmt. A super-vectorization pattern is defined as a recursive data +/// ForInst. A super-vectorization pattern is defined as a recursive data /// structure that matches and captures nested, imperfectly-nested loops /// that have a. conformable loop annotations attached (e.g. parallel, /// reduction, vectorizable, ...) as well as b. all contiguous load/store @@ -279,7 +279,7 @@ using namespace mlir; /// it by its vector form. Otherwise, if the scalar value is a constant, /// it is vectorized into a splat. In all other cases, vectorization for /// the pattern currently fails. -/// e. if everything under the root ForStmt in the current pattern vectorizes +/// e. if everything under the root ForInst in the current pattern vectorizes /// properly, we commit that loop to the IR. Otherwise we discard it and /// restore a previously cloned version of the loop. Thanks to the /// recursive scoping nature of matchers and captured patterns, this is @@ -668,12 +668,12 @@ namespace { struct VectorizationStrategy { ArrayRef vectorSizes; - DenseMap loopToVectorDim; + DenseMap loopToVectorDim; }; } // end anonymous namespace -static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern, +static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { assert(patternDepth > depthInPattern && @@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast(m.first); + auto *loop = cast(m.first); bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth, strategy); if (fail) { @@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches, namespace { struct VectorizationState { - /// Adds an entry of pre/post vectorization statements in the state. + /// Adds an entry of pre/post vectorization instructions in the state. void registerReplacement(OperationInst *key, OperationInst *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets @@ -733,7 +733,7 @@ struct VectorizationState { SmallVector toErase; // Set of OperationInst that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in - // particular to filter the statements that have already been vectorized by + // particular to filter the instructions that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. DenseSet vectorizedSet; // Map of old scalar OperationInst to new vectorized OperationInst. @@ -747,16 +747,16 @@ struct VectorizationState { // that have been vectorized. They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. DenseSet roots; - // Terminator statements for the worklist in the vectorizeOperations function. - // They consist of the subset of store operations that have been vectorized. - // They can be retrieved from `vectorizationMap` but it is convenient to keep - // track of them in a separate data structure.
Since they do not necessarily - // belong to use-def chains starting from loads (e.g storing a constant), we - // need to handle them in a post-pass. + // Terminator instructions for the worklist in the vectorizeOperations + // function. They consist of the subset of store operations that have been + // vectorized. They can be retrieved from `vectorizationMap` but it is + // convenient to keep track of them in a separate data structure. Since they + // do not necessarily belong to use-def chains starting from loads (e.g + // storing a constant), we need to handle them in a post-pass. DenseSet terminators; - // Checks that the type of `stmt` is StoreOp and adds it to the terminators + // Checks that the type of `inst` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationInst *stmt); + void registerTerminator(OperationInst *inst); private: void registerReplacement(const Value *key, Value *value); @@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key, } } -void VectorizationState::registerTerminator(OperationInst *stmt) { - assert(stmt->isa() && "terminator must be a StoreOp"); - assert(terminators.count(stmt) == 0 && +void VectorizationState::registerTerminator(OperationInst *inst) { + assert(inst->isa() && "terminator must be a StoreOp"); + assert(terminators.count(inst) == 0 && "terminator was already inserted previously"); - terminators.insert(stmt); + terminators.insert(inst); } void VectorizationState::finishVectorizationPattern() { while (!toErase.empty()) { - auto *stmt = toErase.pop_back_val(); + auto *inst = toErase.pop_back_val(); LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: "); - LLVM_DEBUG(stmt->print(dbgs())); - stmt->erase(); + LLVM_DEBUG(inst->print(dbgs())); + inst->erase(); } } @@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. - auto *opStmt = memoryOp->getInstruction(); + auto *opInst = memoryOp->getInstruction(); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector_transfer operations // as needed by various targets. - if (opStmt->template isa()) { + if (opInst->template isa()) { auto permutationMap = - makePermutationMap(opStmt, state->strategy->loopToVectorDim); + makePermutationMap(opInst, state->strategy->loopToVectorDim); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto transfer = b.create( - opStmt->getLoc(), vectorType, memoryOp->getMemRef(), + opInst->getLoc(), vectorType, memoryOp->getMemRef(), map(makePtrDynCaster(), memoryOp->getIndices()), permutationMap); - state->registerReplacement(opStmt, transfer->getInstruction()); + state->registerReplacement(opInst, transfer->getInstruction()); } else { - state->registerTerminator(opStmt); + state->registerTerminator(opInst); } return false; } @@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, /// Coarsens the loops bounds and transforms all remaining load and store /// operations into the appropriate vector_transfer. 
-static bool vectorizeForStmt(ForStmt *loop, int64_t step, +static bool vectorizeForInst(ForInst *loop, int64_t step, VectorizationState *state) { using namespace functional; loop->setStep(step); - FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) { - if (!matcher::isLoadOrStore(stmt)) { - return false; - } - auto *opStmt = cast(&stmt); - return state->vectorizationMap.count(opStmt) == 0 && - state->vectorizedSet.count(opStmt) == 0 && - state->roots.count(opStmt) == 0 && - state->terminators.count(opStmt) == 0; - }; + FilterFunctionType notVectorizedThisPattern = + [state](const Instruction &inst) { + if (!matcher::isLoadOrStore(inst)) { + return false; + } + auto *opInst = cast(&inst); + return state->vectorizationMap.count(opInst) == 0 && + state->vectorizedSet.count(opInst) == 0 && + state->roots.count(opInst) == 0 && + state->terminators.count(opInst) == 0; + }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); auto matches = loadAndStores.match(loop); for (auto ls : matches) { - auto *opStmt = cast(ls.first); - auto load = opStmt->dyn_cast(); - auto store = opStmt->dyn_cast(); - LLVM_DEBUG(opStmt->print(dbgs())); + auto *opInst = cast(ls.first); + auto load = opInst->dyn_cast(); + auto store = opInst->dyn_cast(); + LLVM_DEBUG(opInst->print(dbgs())); auto fail = load ? vectorizeRootOrTerminal(loop, load, state) : vectorizeRootOrTerminal(loop, store, state); if (fail) { @@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, /// we can build a cost model and a search procedure. static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](const Statement &forStmt) { - const auto &loop = cast(forStmt); + return [fastestVaryingMemRefDimension](const Instruction &forInst) { + const auto &loop = cast(forInst); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// recursively in DFS post-order. static bool doVectorize(MLFunctionMatches::EntryType oneMatch, VectorizationState *state) { - ForStmt *loop = cast(oneMatch.first); + ForInst *loop = cast(oneMatch.first); MLFunctionMatches childrenMatches = oneMatch.second; // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch, // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.: // | ub -> ub // | step -> step * vectorSize - LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize << " : "); LLVM_DEBUG(loop->print(dbgs())); - return vectorizeForStmt(loop, loop->getStep() * vectorSize, state); + return vectorizeForInst(loop, loop->getStep() * vectorSize, state); } /// Non-root pattern iterates over the matches at this level, calls doVectorize @@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. 
-static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, +static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, Type type) { if (!type || !type.isa() || !VectorType::isValidElementType(constant.getType())) { return nullptr; } - FuncBuilder b(stmt); - Location loc = stmt->getLoc(); + FuncBuilder b(inst); + Location loc = inst->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpStmt = cast(constant.getInstruction()); + auto *constantOpInst = cast(constant.getInstruction()); OperationState state( - b.getContext(), loc, constantOpStmt->getName().getStringRef(), {}, + b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); @@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, } /// Returns a uniqu'ed VectorType. -/// In the case `v`'s defining statement is already part of the `state`'s +/// In the case `v`'s defining instruction is already part of the `state`'s /// vectorizedSet, just returns the type of `v`. /// Otherwise, constructs a new VectorType of shape defined by `state.strategy` /// and of elemental type the type of `v`. @@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpStmt = cast(v->getDefiningInst()); - if (state.vectorizedSet.count(definingOpStmt) > 0) { + auto *definingOpInst = cast(v->getDefiningInst()); + if (state.vectorizedSet.count(definingOpInst) > 0) { return v->getType().cast(); } return VectorType::get(state.strategy->vectorSizes, v->getType()); }; -/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain -/// propagation or during terminator vectorization, by applying the following -/// logic: -/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized +/// Tries to vectorize a given operand `op` of Instruction `inst` during +/// def-chain propagation or during terminator vectorization, by applying the +/// following logic: +/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized /// useby -def propagation), `op` is already in the proper vector form; /// 2. otherwise, the `op` may be in some other vector form that fails to /// vectorize atm (i.e. broadcasting required), returns nullptr to indicate @@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) { /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static Value *vectorizeOperand(Value *operand, Statement *stmt, +static Value *vectorizeOperand(Value *operand, Instruction *inst, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingStatement = cast(operand->getDefiningInst()); + auto *definingInstruction = cast(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. - if (state->vectorizedSet.count(definingStatement) > 0) { + if (state->vectorizedSet.count(definingInstruction) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, } // 3. vectorize constant. 
if (auto constant = operand->getDefiningInst()->dyn_cast()) { - return vectorizeConstant(stmt, *constant, + return vectorizeConstant(inst, *constant, getVectorType(operand, *state).cast()); } // 4. currently non-vectorizable. @@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, - OperationInst *opStmt, + OperationInst *opInst, VectorizationState *state) { // Sanity checks. - assert(!opStmt->isa() && + assert(!opInst->isa() && "all loads must have already been fully vectorized independently"); - assert(!opStmt->isa() && + assert(!opInst->isa() && "vector_transfer_read cannot be further vectorized"); - assert(!opStmt->isa() && + assert(!opInst->isa() && "vector_transfer_write cannot be further vectorized"); - if (auto store = opStmt->dyn_cast()) { + if (auto store = opInst->dyn_cast()) { auto *memRef = store->getMemRef(); auto *value = store->getValueToStore(); - auto *vectorValue = vectorizeOperand(value, opStmt, state); + auto *vectorValue = vectorizeOperand(value, opInst, state); auto indices = map(makePtrDynCaster(), store->getIndices()); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto permutationMap = - makePermutationMap(opStmt, state->strategy->loopToVectorDim); + makePermutationMap(opInst, state->strategy->loopToVectorDim); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( - opStmt->getLoc(), vectorValue, memRef, indices, permutationMap); + opInst->getLoc(), vectorValue, memRef, indices, permutationMap); auto *res = cast(transfer->getInstruction()); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. - opStmt->erase(); + opInst->erase(); return res; } auto types = map([state](Value *v) { return getVectorType(v, *state); }, - opStmt->getResults()); - auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * { - return vectorizeOperand(op, opStmt, state); + opInst->getResults()); + auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * { + return vectorizeOperand(op, opInst, state); }; - auto operands = map(vectorizeOneOperand, opStmt->getOperands()); + auto operands = map(vectorizeOneOperand, opInst->getOperands()); // Check whether a single operand is null. If so, vectorization failed. bool success = llvm::all_of(operands, [](Value *op) { return op; }); if (!success) { @@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, // TODO(ntv): Is it worth considering an OperationInst.clone operation // which changes the type so we can promote an OperationInst with less // boilerplate? - OperationState newOp(b->getContext(), opStmt->getLoc(), - opStmt->getName().getStringRef(), operands, types, - opStmt->getAttrs()); + OperationState newOp(b->getContext(), opInst->getLoc(), + opInst->getName().getStringRef(), operands, types, + opInst->getAttrs()); return b->createOperation(newOp); } @@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) { auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *stmt = cast(u.getOwner()); + auto *inst = cast(u.getOwner()); // Don't propagate to terminals, a separate pass is needed for those. 
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented. - if (state->terminators.count(stmt) > 0) { + if (state->terminators.count(inst) > 0) { continue; } - worklist.insert(stmt); + worklist.insert(inst); } }; apply(insertUsesOf, state->roots); @@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) { // size again. By construction, the order of elements in the worklist is // consistent across iterations. for (unsigned i = 0; i < worklist.size(); ++i) { - auto *stmt = worklist[i]; + auto *inst = worklist[i]; LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: "); - LLVM_DEBUG(stmt->print(dbgs())); + LLVM_DEBUG(inst->print(dbgs())); - // 2. Create vectorized form of the statement. - // Insert it just before stmt, on success register stmt as replaced. - FuncBuilder b(stmt); - auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state); - if (!vectorizedStmt) { + // 2. Create vectorized form of the instruction. + // Insert it just before inst, on success register inst as replaced. + FuncBuilder b(inst); + auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state); + if (!vectorizedInst) { return true; } @@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(stmt, vectorizedStmt); + state->registerReplacement(inst, vectorizedInst); - // 4. Augment the worklist with uses of the statement we just vectorized. + // 4. Augment the worklist with uses of the instruction we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef{stmt}); + apply(insertUsesOf, ArrayRef{inst}); } return false; } @@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) { static bool vectorizeRootMatches(MLFunctionMatches matches, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast(m.first); + auto *loop = cast(m.first); VectorizationState state; state.strategy = strategy; @@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } FuncBuilder builder(loop); // builder to insert in place of loop DenseMap nomap; - ForStmt *clonedLoop = cast(builder.clone(*loop, nomap)); + ForInst *clonedLoop = cast(builder.clone(*loop, nomap)); auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. This is how the root match /// maintains a clone for handling failure and restores the proper state via @@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); // Vectorize the root operations and everything reached by use-def chains - // except the terminators (store statements) that need to be post-processed - // separately. + // except the terminators (store instructions) that need to be + // post-processed separately. fail = vectorizeOperations(&state); if (fail) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); @@ -1239,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } // Finally, vectorize the terminators. If anything fails to vectorize, skip. 
- auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) { + auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { if (fail) { return; } - FuncBuilder b(stmt); - auto *res = vectorizeOneOperationInst(&b, stmt, &state); + FuncBuilder b(inst); + auto *res = vectorizeOneOperationInst(&b, inst, &state); if (res == nullptr) { fail = true; } @@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) { if (fail) { continue; } - auto *loop = cast(m.first); + auto *loop = cast(m.first); vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 1ef759ebedc..2e9c3915504 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -160,11 +160,11 @@ bb42: // ----- mlfunc @foo() -mlfunc @bar() // expected-error {{expected '{' before statement list}} +mlfunc @bar() // expected-error {{expected '{' before instruction list}} // ----- -mlfunc @empty() { // expected-error {{ML function must end with return statement}} +mlfunc @empty() { // expected-error {{ML function must end with return instruction}} } // ----- @@ -177,7 +177,7 @@ bb42: // ----- -mlfunc @no_return() { // expected-error {{ML function must end with return statement}} +mlfunc @no_return() { // expected-error {{ML function must end with return instruction}} "foo"() : () -> () } @@ -231,7 +231,7 @@ mlfunc @malformed_for_to() { mlfunc @incomplete_for() { for %i = 1 to 10 step 2 -} // expected-error {{expected '{' before statement list}} +} // expected-error {{expected '{' before instruction list}} // ----- @@ -246,7 +246,7 @@ mlfunc @for_negative_stride() { // ----- -mlfunc @non_statement() { +mlfunc @non_instruction() { asd // expected-error {{custom op 'asd' is unknown}} } @@ -339,7 +339,7 @@ bb42: mlfunc @missing_rbrace() { return -mlfunc @d() {return} // expected-error {{expected '}' after statement list}} +mlfunc @d() {return} // expected-error {{expected '}' after instruction list}} // ----- @@ -478,7 +478,7 @@ mlfunc @return_inside_loop() -> i8 { for %i = 1 to 100 { %a = "foo"() : ()->i8 return %a : i8 - // expected-error@-1 {{'return' op must be the last statement in the ML function}} + // expected-error@-1 {{'return' op must be the last instruction in the ML function}} } } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 44274391951..8cebf1df717 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -283,8 +283,8 @@ mlfunc @loop_bounds(%N : index) { return // CHECK: return } // CHECK: } -// CHECK-LABEL: mlfunc @ifstmt(%arg0 : index) { -mlfunc @ifstmt(%N: index) { +// CHECK-LABEL: mlfunc @ifinst(%arg0 : index) { +mlfunc @ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] { @@ -304,8 +304,8 @@ mlfunc @ifstmt(%N: index) { return // CHECK return } // CHECK } -// CHECK-LABEL: mlfunc @simple_ifstmt(%arg0 : index) { -mlfunc @simple_ifstmt(%N: index) { +// CHECK-LABEL: mlfunc @simple_ifinst(%arg0 : index) { +mlfunc @simple_ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] { @@ -349,8 +349,8 @@ bb42: // CHECK: bb0: // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> () "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> 
() - // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> () - "foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> () + // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> () + "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> () return } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index f26041ed169..34057e2e98a 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -452,8 +452,8 @@ mlfunc @should_fuse_no_top_level_access() { #set0 = (d0) : (1 == 0) -// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_at_top_level() { -mlfunc @should_not_fuse_if_stmt_at_top_level() { +// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_at_top_level() { +mlfunc @should_not_fuse_if_inst_at_top_level() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -466,7 +466,7 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfStmt should prevent fusion. + // Top-level IfInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -480,8 +480,8 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() { #set0 = (d0) : (1 == 0) -// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_in_loop_nest() { -mlfunc @should_not_fuse_if_stmt_in_loop_nest() { +// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_in_loop_nest() { +mlfunc @should_not_fuse_if_inst_in_loop_nest() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 %c4 = constant 4 : index @@ -495,7 +495,7 @@ mlfunc @should_not_fuse_if_stmt_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfStmt in ForStmt should prevent fusion. + // IfInst in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index a59fc18ea69..73e5cfa1664 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ mlfunc @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfStmt of the store, dominates the ancestor ForSmt of the load, + // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 349b65b4b7b..61dd30ea882 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -226,7 +226,7 @@ mlfunc @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { memref<32 x 32 x f32, 2>, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } - // Use live out of 'for' stmt; no DMA pipelining will be done. + // Use live out of 'for' inst; no DMA pipelining will be done. 
%v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> return %v : f32 // CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> diff --git a/mlir/utils/vim/mlir.vim b/mlir/utils/vim/mlir.vim index 7bc7f115474..e220ca7a1da 100644 --- a/mlir/utils/vim/mlir.vim +++ b/mlir/utils/vim/mlir.vim @@ -23,8 +23,8 @@ syn region mlirComment start="//" skip="\\$" end="$" syn region mlirString matchgroup=mlirString start=+"+ end=+"+ hi def link mlirComment Comment -hi def link mlirKeywords Statement -hi def link mlirCoreOps Statement +hi def link mlirKeywords Instruction +hi def link mlirCoreOps Instruction hi def link mlirInt Constant hi def link mlirType Type hi def link mlirMapOutline PreProc -- cgit v1.2.3 From dffc589ad2783dd14dd5cdfdfccb624ccee94a6c Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 29 Dec 2018 15:33:43 -0800 Subject: Extend InstVisitor and Walker to handle arbitrary CFG functions, expand the Function::walk functionality into f->walkInsts/Ops which allows visiting all instructions, not just ops. Eliminate Function::getBody() and Function::getReturn() helpers which crash in CFG functions, and were only kept around as a bridge. This is step 25/n towards merging instructions and statements. PiperOrigin-RevId: 227243966 --- mlir/include/mlir/IR/Function.h | 23 ++--- mlir/include/mlir/IR/InstVisitor.h | 26 +++--- .../mlir/Transforms/MLPatternLoweringPass.h | 8 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 6 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 6 +- mlir/lib/Analysis/OpStats.cpp | 24 ++--- mlir/lib/IR/AsmPrinter.cpp | 100 +++++++-------------- mlir/lib/IR/Function.cpp | 32 +++++-- mlir/lib/Transforms/ConvertToCFG.cpp | 6 +- mlir/lib/Transforms/DmaGeneration.cpp | 14 +-- mlir/lib/Transforms/LoopFusion.cpp | 7 +- mlir/lib/Transforms/LoopTiling.cpp | 14 +-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 12 +-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 1 - .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 4 +- 17 files changed, 124 insertions(+), 163 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index c03fba61d2d..4f3a2abec39 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -114,26 +114,15 @@ public: Block &front() { return blocks.front(); } const Block &front() const { return const_cast(this)->front(); } - /// Return the 'return' instruction of this Function. - const OperationInst *getReturn() const; - OperationInst *getReturn(); - - // These should only be used on MLFunctions. - Block *getBody() { - assert(isML()); - return &blocks.front(); - } - const Block *getBody() const { - return const_cast(this)->getBody(); - } - /// Walk the instructions in the function in preorder, calling the callback - /// for each operation instruction. - void walk(std::function callback); + /// for each instruction or operation. + void walkInsts(std::function callback); + void walkOps(std::function callback); /// Walk the instructions in the function in postorder, calling the callback - /// for each operation instruction. - void walkPostOrder(std::function callback); + /// for each instruction or operation. 
+ void walkInstsPostOrder(std::function callback); + void walkOpsPostOrder(std::function callback); //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 3ce7d25cafe..589e5bbdfdc 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -131,15 +131,13 @@ public: // Define walkers for Function and all Function instruction kinds. void walk(Function *f) { - static_cast(this)->visitMLFunction(f); - static_cast(this)->walk(f->getBody()->begin(), - f->getBody()->end()); + for (auto &block : *f) + static_cast(this)->walk(block.begin(), block.end()); } void walkPostOrder(Function *f) { - static_cast(this)->walkPostOrder(f->getBody()->begin(), - f->getBody()->end()); - static_cast(this)->visitMLFunction(f); + for (auto it = f->rbegin(), e = f->rend(); it != e; ++it) + static_cast(this)->walkPostOrder(it->begin(), it->end()); } RetTy walkOpInst(OperationInst *opInst) { @@ -162,17 +160,16 @@ public: static_cast(this)->visitIfInst(ifInst); static_cast(this)->walk(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - static_cast(this)->walk(ifInst->getElse()->begin(), - ifInst->getElse()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walk(elseBlock->begin(), elseBlock->end()); } void walkIfInstPostOrder(IfInst *ifInst) { static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - static_cast(this)->walkPostOrder(ifInst->getElse()->begin(), - ifInst->getElse()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walkPostOrder(elseBlock->begin(), + elseBlock->end()); static_cast(this)->visitIfInst(ifInst); } @@ -181,6 +178,8 @@ public: static_assert(std::is_base_of::value, "Must pass the derived type to this template!"); + static_cast(this)->visitInstruction(s); + switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->walkForInst(cast(s)); @@ -195,6 +194,7 @@ public: RetTy walkPostOrder(Instruction *s) { static_assert(std::is_base_of::value, "Must pass the derived type to this template!"); + static_cast(this)->visitInstruction(s); switch (s->getKind()) { case Instruction::Kind::For: @@ -219,10 +219,10 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. When using RetTy, all of these // need to be overridden. 
- void visitMLFunction(Function *f) {} void visitForInst(ForInst *forInst) {} void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} + void visitInstruction(Instruction *inst) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 4e34889e077..978fa45ab23 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -108,7 +108,7 @@ public: return nullptr; } - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; }; ///////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ template struct ListAdder { } // namespace detail template -PassResult MLPatternLoweringPass::runOnMLFunction(Function *f) { +PassResult MLPatternLoweringPass::runOnFunction(Function *f) { detail::OwningMLLoweringPatternList patterns; detail::ListAdder::addPatternsToList(&patterns, f->getContext()); auto funcWiseState = makeFuncWiseState(f); @@ -143,8 +143,8 @@ PassResult MLPatternLoweringPass::runOnMLFunction(Function *f) { FuncBuilder builder(f); MLFuncLoweringRewriter rewriter(&builder); - llvm::SmallVector ops; - f->walk([&ops](OperationInst *inst) { ops.push_back(inst); }); + llvm::SmallVector ops; + f->walkOps([&ops](OperationInst *inst) { ops.push_back(inst); }); for (OperationInst *inst : ops) { for (const auto &pattern : patterns) { diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index e8b668892b8..d21f2f8035b 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -41,9 +41,7 @@ namespace { struct MemRefBoundCheck : public FunctionPass, InstWalker { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} - PassResult runOnMLFunction(Function *f) override; - // Not applicable to CFG functions. - PassResult runOnCFGFunction(Function *f) override { return success(); } + PassResult runOnFunction(Function *f) override; void visitOperationInst(OperationInst *opInst); @@ -67,7 +65,7 @@ void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) { // TODO(bondhugula): do this for DMA ops as well. } -PassResult MemRefBoundCheck::runOnMLFunction(Function *f) { +PassResult MemRefBoundCheck::runOnFunction(Function *f) { return walk(f), success(); } diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 8391f15b6d3..1df935f544e 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -44,9 +44,7 @@ struct MemRefDependenceCheck : public FunctionPass, explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} - PassResult runOnMLFunction(Function *f) override; - // Not applicable to CFG functions. - PassResult runOnCFGFunction(Function *f) override { return success(); } + PassResult runOnFunction(Function *f) override; void visitOperationInst(OperationInst *opInst) { if (opInst->isa() || opInst->isa()) { @@ -168,7 +166,7 @@ static void checkDependences(ArrayRef loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. 
-PassResult MemRefDependenceCheck::runOnMLFunction(Function *f) { +PassResult MemRefDependenceCheck::runOnFunction(Function *f) { loadsAndStores.clear(); walk(f); checkDependences(loadsAndStores); diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 07edb13d1a3..a8cad416fd8 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/Function.h" #include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Pass.h" #include "llvm/ADT/DenseMap.h" @@ -33,11 +33,7 @@ struct PrintOpStatsPass : public FunctionPass, InstWalker { // Prints the resultant operation stats post iterating over the module. PassResult runOnModule(Module *m) override; - // Process CFG function considering the instructions in basic blocks. - PassResult runOnCFGFunction(Function *function) override; - - // Process ML functions and operation statments in ML functions. - PassResult runOnMLFunction(Function *function) override; + PassResult runOnFunction(Function *function) override; void visitOperationInst(OperationInst *inst); // Print summary of op stats. @@ -55,17 +51,9 @@ private: char PrintOpStatsPass::passID = 0; PassResult PrintOpStatsPass::runOnModule(Module *m) { - auto result = FunctionPass::runOnModule(m); - if (!result) - printSummary(); - return result; -} - -PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) { - for (const auto &bb : *function) - for (const auto &inst : bb) - if (auto *op = dyn_cast(&inst)) - ++opCount[op->getName().getStringRef()]; + for (auto &fn : *m) + (void)runOnFunction(&fn); + printSummary(); return success(); } @@ -73,7 +61,7 @@ void PrintOpStatsPass::visitOperationInst(OperationInst *inst) { ++opCount[inst->getName().getStringRef()]; } -PassResult PrintOpStatsPass::runOnMLFunction(Function *function) { +PassResult PrintOpStatsPass::runOnFunction(Function *function) { walk(function); return success(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4bc2c94a128..098439ba115 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -113,17 +113,12 @@ private: } // Visit functions. - void visitFunction(const Function *fn); - void visitExtFunction(const Function *fn); - void visitCFGFunction(const Function *fn); - void visitMLFunction(const Function *fn); void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); - void visitOperation(const OperationInst *op); DenseMap affineMapIds; std::vector affineMapsById; @@ -161,42 +156,8 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitOperation(const OperationInst *op) { - // Visit all the types used in the operation. - for (auto *operand : op->getOperands()) - visitType(operand->getType()); - for (auto *result : op->getResults()) - visitType(result->getType()); - - // Visit each of the attributes. 
- for (auto elt : op->getAttrs()) - visitAttribute(elt.second); -} - -void ModuleState::visitExtFunction(const Function *fn) { - visitType(fn->getType()); -} - -void ModuleState::visitCFGFunction(const Function *fn) { - visitType(fn->getType()); - for (auto &block : *fn) { - for (auto &op : block.getInstructions()) { - if (auto *opInst = dyn_cast(&op)) - visitOperation(opInst); - else { - llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported"); - } - } - } -} - void ModuleState::visitIfInst(const IfInst *ifInst) { recordIntegerSetReference(ifInst->getIntegerSet()); - for (auto &childInst : *ifInst->getThen()) - visitInstruction(&childInst); - if (ifInst->hasElse()) - for (auto &childInst : *ifInst->getElse()) - visitInstruction(&childInst); } void ModuleState::visitForInst(const ForInst *forInst) { @@ -207,14 +168,18 @@ void ModuleState::visitForInst(const ForInst *forInst) { AffineMap ubMap = forInst->getUpperBoundMap(); if (!hasShorthandForm(ubMap)) recordAffineMapReference(ubMap); - - for (auto &childInst : *forInst->getBody()) - visitInstruction(&childInst); } -void ModuleState::visitOperationInst(const OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) - visitAttribute(attr.second); +void ModuleState::visitOperationInst(const OperationInst *op) { + // Visit all the types used in the operation. + for (auto *operand : op->getOperands()) + visitType(operand->getType()); + for (auto *result : op->getResults()) + visitType(result->getType()); + + // Visit each of the attributes. + for (auto elt : op->getAttrs()) + visitAttribute(elt.second); } void ModuleState::visitInstruction(const Instruction *inst) { @@ -225,33 +190,16 @@ void ModuleState::visitInstruction(const Instruction *inst) { return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: return visitOperationInst(cast(inst)); - default: - return; - } -} - -void ModuleState::visitMLFunction(const Function *fn) { - visitType(fn->getType()); - for (auto &inst : *fn->getBody()) { - ModuleState::visitInstruction(&inst); - } -} - -void ModuleState::visitFunction(const Function *fn) { - switch (fn->getKind()) { - case Function::Kind::ExtFunc: - return visitExtFunction(fn); - case Function::Kind::CFGFunc: - return visitCFGFunction(fn); - case Function::Kind::MLFunc: - return visitMLFunction(fn); } } // Initializes module state, populating affine map and integer set state. void ModuleState::initialize(const Module *module) { for (auto &fn : *module) { - visitFunction(&fn); + visitType(fn.getType()); + + const_cast(fn).walkInsts( + [&](Instruction *op) { ModuleState::visitInstruction(op); }); } } @@ -1167,12 +1115,26 @@ void FunctionPrinter::printFunctionSignature() { } } +/// Return true if the introducer for the specified block should be printed. +static bool shouldPrintBlockArguments(const Block *block) { + // Never print the entry block of the function - it is included in the + // argument list. + if (block == &block->getFunction()->front()) + return false; + + // If this is the first block in a nested region, and if there are no + // arguments, then we can omit it. + if (block == &block->getParent()->front() && block->getNumArguments() == 0) + return false; + + // Otherwise print it. + return true; +} + void FunctionPrinter::print(const Block *block) { // Print the block label and argument list, unless this is the first block of // the function, or the first block of an IfInst/ForInst with no arguments. 
- if (block != &block->getFunction()->front() && - (block != &block->getParent()->front() || - block->getNumArguments() != 0)) { + if (shouldPrintBlockArguments(block)) { os.indent(currentIndent); printBlockName(block); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index bacb504683b..b7346e9389d 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -161,15 +161,20 @@ bool Function::emitError(const Twine &message) const { // Function implementation. //===----------------------------------------------------------------------===// -const OperationInst *Function::getReturn() const { - return cast(&getBody()->back()); -} +void Function::walkInsts(std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitInstruction(Instruction *inst) { callback(inst); } + }; -OperationInst *Function::getReturn() { - return cast(&getBody()->back()); + Walker v(callback); + v.walk(this); } -void Function::walk(std::function callback) { +void Function::walkOps(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) @@ -182,7 +187,20 @@ void Function::walk(std::function callback) { v.walk(this); } -void Function::walkPostOrder(std::function callback) { +void Function::walkInstsPostOrder(std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(Instruction *inst) { callback(inst); } + }; + + Walker v(callback); + v.walkPostOrder(this); +} + +void Function::walkOpsPostOrder(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index a9124b0bcb8..0ecd248cc89 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -485,8 +485,10 @@ Function *FunctionConverter::convert(Function *mlFunc) { } // Convert instructions in order. - for (auto &inst : *mlFunc->getBody()) { - visit(&inst); + for (auto &block : *mlFunc) { + for (auto &inst : block) { + visit(&inst); + } } return cfgFunc; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index bc7f31f0434..b5e7653d2b9 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -62,9 +62,7 @@ struct DmaGeneration : public FunctionPass, InstWalker { } } - // Not applicable to CFG functions. - PassResult runOnCFGFunction(Function *f) override { return success(); } - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; void runOnForInst(ForInst *forInst); void visitOperationInst(OperationInst *opInst); @@ -425,10 +423,12 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { << " KiB of DMA buffers in fast memory space\n";); } -PassResult DmaGeneration::runOnMLFunction(Function *f) { - for (auto &inst : *f->getBody()) { - if (auto *forInst = dyn_cast(&inst)) { - runOnForInst(forInst); +PassResult DmaGeneration::runOnFunction(Function *f) { + for (auto &block : *f) { + for (auto &inst : block) { + if (auto *forInst = dyn_cast(&inst)) { + runOnForInst(forInst); + } } } // This function never leaves the IR in an invalid state. 
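[Illustrative sketch, not part of any patch in this series] The Function.cpp changes in this patch rename the callback walkers (walk becomes walkOps, walkPostOrder becomes walkOpsPostOrder) and add walkInsts/walkInstsPostOrder for visiting every Instruction. A minimal pass built on them might look like the following reconstruction; the pass name and the LoadOp/StoreOp filter are invented for illustration, and the template arguments (stripped to bare names elsewhere in this archive) are a best-effort guess at the original signatures.

#include "mlir/IR/Function.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

namespace {
// Hypothetical pass that exercises the renamed walker entry points.
struct WalkerDemoPass : public FunctionPass {
  WalkerDemoPass() : FunctionPass(&WalkerDemoPass::passID) {}

  PassResult runOnFunction(Function *f) override {
    // Visit only OperationInsts, the way MLPatternLoweringPass and the greedy
    // pattern rewrite driver build their worklists after this patch.
    llvm::SmallVector<OperationInst *, 16> loadsAndStores;
    f->walkOps([&](OperationInst *inst) {
      if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
        loadsAndStores.push_back(inst);
    });

    // Visit every Instruction (ForInst, IfInst, and OperationInst alike), as
    // the AsmPrinter's ModuleState::initialize now does via walkInsts.
    unsigned numInsts = 0;
    f->walkInsts([&](Instruction *inst) { ++numInsts; });

    (void)loadsAndStores;
    (void)numInsts;
    return success();
  }

  static char passID;
};
} // end anonymous namespace

char WalkerDemoPass::passID = 0;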
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 97dea753f88..1854cd99ab5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -348,7 +348,12 @@ public: bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; DenseMap> memrefAccesses; - for (auto &inst : *f->getBody()) { + + // TODO: support multi-block functions. + if (f->getBlocks().size() != 1) + return false; + + for (auto &inst : f->front()) { if (auto *forInst = dyn_cast(&inst)) { // Create graph node 'id' to represent top-level 'forInst' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 8f3be8a3d45..fa39b7dfa5b 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -224,15 +224,17 @@ static void getTileableBands(Function *f, do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = dyn_cast(&*currInst->getBody()->begin()))); + (currInst = dyn_cast(&currInst->getBody()->front()))); bands->push_back(band); }; - for (auto &inst : *f->getBody()) { - auto *forInst = dyn_cast(&inst); - if (!forInst) - continue; - getMaximalPerfectLoopNest(forInst); + for (auto &block : *f) { + for (auto &inst : block) { + auto *forInst = dyn_cast(&inst); + if (!forInst) + continue; + getMaximalPerfectLoopNest(forInst); + } } } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index f59659cf234..12975608370 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -74,7 +74,7 @@ struct LoopUnrollAndJam : public FunctionPass { : FunctionPass(&LoopUnrollAndJam::passID), unrollJamFactor(unrollJamFactor) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; bool runOnForInst(ForInst *forInst); static char passID; @@ -88,15 +88,15 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); } -PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) { +PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnForInst can be called on any // for Inst. 
- auto *forInst = dyn_cast(f->getBody()->begin()); - if (!forInst) - return success(); + auto &entryBlock = f->front(); + if (!entryBlock.empty()) + if (auto *forInst = dyn_cast(&entryBlock.front())) + runOnForInst(forInst); - runOnForInst(forInst); return success(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index bcb2abf11dd..5d55800acf3 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -236,7 +236,6 @@ struct LowerVectorTransfersPass makeFuncWiseState(Function *f) const override { auto state = llvm::make_unique(); auto builder = FuncBuilder(f); - builder.setInsertionPointToStart(f->getBody()); state->zero = builder.create(builder.getUnknownLoc(), 0); return state; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 6064d1feff3..c37b997734c 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -280,7 +280,7 @@ static void processMLFunction(Function *fn, }; GreedyPatternRewriteDriver driver(std::move(patterns)); - fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); }); + fn->walkOps([&](OperationInst *inst) { driver.addToWorklist(inst); }); FuncBuilder mlBuilder(fn); MLFuncRewriter rewriter(driver, mlBuilder); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 93039372121..4168dda064a 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { if (!forInst->use_empty()) { if (forInst->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); - FuncBuilder topBuilder(&mlFunc->getBody()->front()); + FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( forInst->getLoc(), forInst->getConstantLowerBound()); forInst->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index bbb703cd627..58bb3901947 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -651,7 +651,7 @@ namespace { struct Vectorize : public FunctionPass { Vectorize() : FunctionPass(&Vectorize::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext MLContext; @@ -1267,7 +1267,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. -PassResult Vectorize::runOnMLFunction(Function *f) { +PassResult Vectorize::runOnFunction(Function *f) { for (auto pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); -- cgit v1.2.3 From b9fe6be6d4cefdd6aefeaee2c7ab5c475efd8e4f Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sat, 29 Dec 2018 19:16:55 -0800 Subject: Introduce memref store to load forwarding - a simple memref dataflow analysis - the load/store forwarding relies on memref dependence routines as well as SSA/dominance to identify the memref store instance uniquely supplying a value to a memref load, and replaces the result of that load with the value being stored. 
The memref is also deleted when possible if only stores remain. - add methods for post dominance for MLFunction blocks. - remove duplicated getLoopDepth/getNestingDepth - move getNestingDepth, getMemRefAccess, getNumCommonSurroundingLoops into Analysis/Utils (were earlier static) - add a helper method in FlatAffineConstraints - isRangeOneToOne. PiperOrigin-RevId: 227252907 --- mlir/include/mlir/Analysis/AffineStructures.h | 6 + mlir/include/mlir/Analysis/Utils.h | 20 ++- mlir/include/mlir/Transforms/Passes.h | 4 + mlir/lib/Analysis/AffineStructures.cpp | 45 +++++ mlir/lib/Analysis/MemRefDependenceCheck.cpp | 20 +-- mlir/lib/Analysis/Utils.cpp | 95 ++++++++++ mlir/lib/Transforms/DmaGeneration.cpp | 13 -- mlir/lib/Transforms/LoopFusion.cpp | 43 +---- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 243 ++++++++++++++++++++++++++ mlir/test/Transforms/memref-dataflow-opt.mlir | 239 +++++++++++++++++++++++++ 10 files changed, 656 insertions(+), 72 deletions(-) create mode 100644 mlir/lib/Transforms/MemRefDataFlowOpt.cpp create mode 100644 mlir/test/Transforms/memref-dataflow-opt.mlir (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index c644a33c938..006d73e9be8 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -497,6 +497,12 @@ public: /// 'num' identifiers starting at position 'pos'. void constantFoldIdRange(unsigned pos, unsigned num); + /// Returns true if all the identifiers in the specified range [start, limit) + /// can only take a single value each if the remaining identifiers are treated + /// as symbols/parameters, i.e., for given values of the latter, there only + /// exists a unique value for each of the dimensions in the specified range. + bool isRangeOneToOne(unsigned start, unsigned limit) const; + unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); } diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index fe04d401bcd..e4ab4ffda1c 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -50,6 +50,16 @@ bool properlyDominates(const Instruction &a, const Instruction &b); // TODO(bondhugula): handle 'if' inst's. void getLoopIVs(const Instruction &inst, SmallVectorImpl *loops); +/// Returns true if instruction 'a' postdominates instruction b. +bool postDominates(const Instruction &a, const Instruction &b); + +/// Returns true if instruction 'a' properly postdominates instruction b. +bool properlyPostDominates(const Instruction &a, const Instruction &b); + +/// Returns the nesting depth of this instruction, i.e., the number of loops +/// surrounding this instruction. +unsigned getNestingDepth(const Instruction &stmt); + /// A region of a memref's data space; this is typically constructed by /// analyzing load/store op's on this memref and the index space of loops /// surrounding such op's. @@ -83,7 +93,8 @@ struct MemRefRegion { /// minor) which matches 1:1 with the dimensional identifier positions in //'cst'. 
Optional - getConstantBoundOnDimSize(unsigned pos, SmallVectorImpl *lb) const { + getConstantBoundOnDimSize(unsigned pos, + SmallVectorImpl *lb = nullptr) const { assert(pos < getRank() && "invalid position"); return cst.getConstantBoundOnDimSize(pos, lb); } @@ -142,6 +153,13 @@ template bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); +/// Constructs a MemRefAccess from a load or store operation instruction. +void getMemRefAccess(OperationInst *loadOrStoreOpInst, MemRefAccess *access); + +/// Returns the number of surrounding loops common to both A and B. +unsigned getNumCommonSurroundingLoops(const Instruction &A, + const Instruction &B); + /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop /// IVs, and inserts the computation slice at the beginning of the instruction diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index acf07c5143f..dc79eba0d32 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -102,6 +102,10 @@ FunctionPass *createLowerAffineApplyPass(); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. FunctionPass *createLowerVectorTransfersPass(); +/// Creates a pass to perform optimizations relying on memref dataflow such as +/// store to load forwarding, elimination of dead stores, and dead allocs. +FunctionPass *createMemRefDataFlowOptPass(); + } // end namespace mlir #endif // MLIR_TRANSFORMS_PASSES_H diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 11d0f170550..b67e18901d8 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1950,3 +1950,48 @@ void FlatAffineConstraints::projectOut(Value *id) { (void)ret; FourierMotzkinEliminate(pos); } + +bool FlatAffineConstraints::isRangeOneToOne(unsigned start, + unsigned limit) const { + assert(start <= getNumIds() - 1 && "invalid start position"); + assert(limit > start && limit <= getNumIds() && "invalid limit"); + + FlatAffineConstraints tmpCst(*this); + + if (start != 0) { + // Move [start, limit) to the left. + for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atIneq(r, c - start) = atIneq(r, c); + else if (c < start) + tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); + else + tmpCst.atIneq(r, c) = atIneq(r, c); + } + } + for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atEq(r, c - start) = atEq(r, c); + else if (c < start) + tmpCst.atEq(r, c + limit - start) = atEq(r, c); + else + tmpCst.atEq(r, c) = atEq(r, c); + } + } + } + + // Mark everything to the right as symbols so that we can check the extents in + // a symbolic way below. + tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); + + // Check if the extents of all the specified dimensions are just one (when + // treating the rest as symbols). 
+ for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { + auto extent = tmpCst.getConstantBoundOnDimSize(pos); + if (!extent.hasValue() || extent.getValue() != 1) + return false; + } + return true; +} diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 1df935f544e..c7bf2abd8d6 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -89,20 +89,6 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpInst, } } -// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', -// where each lists loops from outer-most to inner-most in loop nest. -static unsigned getNumCommonSurroundingLoops(ArrayRef loopsA, - ArrayRef loopsB) { - unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); - unsigned numCommonLoops = 0; - for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i] != loopsB[i]) - break; - ++numCommonLoops; - } - return numCommonLoops; -} - // Returns a result string which represents the direction vector (if there was // a dependence), returns the string "false" otherwise. static string @@ -134,17 +120,13 @@ static void checkDependences(ArrayRef loadsAndStores) { auto *srcOpInst = loadsAndStores[i]; MemRefAccess srcAccess; getMemRefAccess(srcOpInst, &srcAccess); - SmallVector srcLoops; - getLoopIVs(*srcOpInst, &srcLoops); for (unsigned j = 0; j < e; ++j) { auto *dstOpInst = loadsAndStores[j]; MemRefAccess dstAccess; getMemRefAccess(dstOpInst, &dstAccess); - SmallVector dstLoops; - getLoopIVs(*dstOpInst, &dstLoops); unsigned numCommonLoops = - getNumCommonSurroundingLoops(srcLoops, dstLoops); + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { FlatAffineConstraints dependenceConstraints; llvm::SmallVector dependenceComponents; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 6c70ee22df2..75aec132060 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -64,11 +64,53 @@ bool mlir::properlyDominates(const Instruction &a, const Instruction &b) { return false; } +/// Returns true if statement 'a' properly postdominates statement b. +bool mlir::properlyPostDominates(const Instruction &a, const Instruction &b) { + // Only applicable to ML functions. + assert(a.getFunction()->isML() && b.getFunction()->isML()); + + if (&a == &b) + return false; + + if (a.getFunction() != b.getFunction()) + return false; + + if (a.getBlock() == b.getBlock()) { + // Do a linear scan to determine whether a comes after b. + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); + auto bBlockStart = b.getBlock()->begin(); + while (aIter != bBlockStart) { + --aIter; + if (aIter == bIter) + return true; + } + return false; + } + + // Traverse up b's hierarchy to check if b's block is contained in a's. + if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b)) + // a and bAncestor are in the same block; check if 'a' postdominates + // bAncestor. + return postDominates(a, *bAncestor); + + // b's block is not contained in A's. + return false; +} + /// Returns true if instruction A dominates instruction B. bool mlir::dominates(const Instruction &a, const Instruction &b) { return &a == &b || properlyDominates(a, b); } +/// Returns true if statement A postdominates statement B. +bool mlir::postDominates(const Instruction &a, const Instruction &b) { + // Only applicable to ML functions. 
+ assert(a.getFunction()->isML() && b.getFunction()->isML()); + + return &a == &b || properlyPostDominates(a, b); +} + /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, @@ -485,3 +527,56 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, } return sliceLoopNest; } + +void mlir::getMemRefAccess(OperationInst *loadOrStoreOpInst, + MemRefAccess *access) { + if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { + access->memref = loadOp->getMemRef(); + access->opInst = loadOrStoreOpInst; + auto loadMemrefType = loadOp->getMemRefType(); + access->indices.reserve(loadMemrefType.getRank()); + for (auto *index : loadOp->getIndices()) { + access->indices.push_back(index); + } + } else { + assert(loadOrStoreOpInst->isa() && "load/store op expected"); + auto storeOp = loadOrStoreOpInst->dyn_cast(); + access->opInst = loadOrStoreOpInst; + access->memref = storeOp->getMemRef(); + auto storeMemrefType = storeOp->getMemRefType(); + access->indices.reserve(storeMemrefType.getRank()); + for (auto *index : storeOp->getIndices()) { + access->indices.push_back(index); + } + } +} + +/// Returns the nesting depth of this statement, i.e., the number of loops +/// surrounding this statement. +unsigned mlir::getNestingDepth(const Instruction &stmt) { + const Instruction *currInst = &stmt; + unsigned depth = 0; + while ((currInst = currInst->getParentInst())) { + if (isa(currInst)) + depth++; + } + return depth; +} + +/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', +/// where each lists loops from outer-most to inner-most in loop nest. +unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, + const Instruction &B) { + SmallVector loopsA, loopsB; + getLoopIVs(A, &loopsA); + getLoopIVs(B, &loopsB); + + unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); + unsigned numCommonLoops = 0; + for (unsigned i = 0; i < minNumLoops; ++i) { + if (loopsA[i] != loopsB[i]) + break; + ++numCommonLoops; + } + return numCommonLoops; +} diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index b5e7653d2b9..e60f3531b62 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -367,19 +367,6 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, return true; } -/// Returns the nesting depth of this instruction, i.e., the number of loops -/// surrounding this instruction. -// TODO(bondhugula): move this to utilities later. -static unsigned getNestingDepth(const Instruction &inst) { - const Instruction *currInst = &inst; - unsigned depth = 0; - while ((currInst = currInst->getParentInst())) { - if (isa(currInst)) - depth++; - } - return depth; -} - // TODO(bondhugula): make this run on a Block instead of a 'for' inst. 
void DmaGeneration::runOnForInst(ForInst *forInst) { // For now (for testing purposes), we'll run this on the outermost among 'for' diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1854cd99ab5..1610932918f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -80,29 +80,6 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst, - MemRefAccess *access) { - if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { - access->memref = loadOp->getMemRef(); - access->opInst = loadOrStoreOpInst; - auto loadMemrefType = loadOp->getMemRefType(); - access->indices.reserve(loadMemrefType.getRank()); - for (auto *index : loadOp->getIndices()) { - access->indices.push_back(index); - } - } else { - assert(loadOrStoreOpInst->isa()); - auto storeOp = loadOrStoreOpInst->dyn_cast(); - access->opInst = loadOrStoreOpInst; - access->memref = storeOp->getMemRef(); - auto storeMemrefType = storeOp->getMemRefType(); - access->indices.reserve(storeMemrefType.getRank()); - for (auto *index : storeOp->getIndices()) { - access->indices.push_back(index); - } - } -} - // FusionCandidate encapsulates source and destination memref access within // loop nests which are candidates for loop fusion. struct FusionCandidate { @@ -116,24 +93,12 @@ static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst, OperationInst *dstLoadOpInst) { FusionCandidate candidate; // Get store access for src loop nest. - getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess); + getMemRefAccess(srcStoreOpInst, &candidate.srcAccess); // Get load access for dst loop nest. - getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess); + getMemRefAccess(dstLoadOpInst, &candidate.dstAccess); return candidate; } -// Returns the loop depth of the loop nest surrounding 'opInst'. -static unsigned getLoopDepth(OperationInst *opInst) { - unsigned loopDepth = 0; - auto *currInst = opInst->getParentInst(); - ForInst *currForInst; - while (currInst && (currForInst = dyn_cast(currInst))) { - ++loopDepth; - currInst = currInst->getParentInst(); - } - return loopDepth; -} - namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -520,10 +485,10 @@ public: // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0 ? clSrcLoopDepth - : getLoopDepth(srcStoreOpInst); + : getNestingDepth(*srcStoreOpInst); unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0 ? clDstLoopDepth - : getLoopDepth(dstLoadOpInst); + : getNestingDepth(*dstLoadOpInst); auto *sliceLoopNest = mlir::insertBackwardComputationSlice( &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth, dstLoopDepth); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp new file mode 100644 index 00000000000..d1af131d383 --- /dev/null +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -0,0 +1,243 @@ +//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a pass to forward memref stores to loads, thereby +// potentially getting rid of intermediate memref's entirely. +// TODO(mlir-team): In the future, similar techniques could be used to eliminate +// dead memref store's and perform more complex forwarding when support for +// SSA scalars live out of 'for'/'if' statements is available. +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/Pass.h" +#include "mlir/StandardOps/StandardOps.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "memref-dataflow-opt" + +using namespace mlir; + +namespace { + +// The store to load forwarding relies on three conditions: +// +// 1) there has to be a dependence from the store to the load satisfied at the +// block immediately within the innermost common surrounding loop of the load op +// and the store op, and such a dependence should associate with a single load +// location for a given source store iteration. +// +// 2) the store op should dominate the load op, +// +// 3) among all candidate store op's that satisfy (1) and (2), if there exists a +// store op that postdominates all those that satisfy (1), such a store op is +// provably the last writer to the particular memref location being loaded from +// by the load op, and its store value can be forwarded to the load. +// +// The above conditions are simple to check, sufficient, and powerful for most +// cases in practice - conditions (1) and (3) are precise and necessary, while +// condition (2) is a sufficient one but not necessary (since it doesn't reason +// about loops that are guaranteed to execute at least once). +// +// TODO(mlir-team): more forwarding can be done when support for +// loop/conditional live-out SSA values is available. +// TODO(mlir-team): do general dead store elimination for memref's. This pass +// currently eliminates the stores only if no other loads/uses (other +// than dealloc) remain. +// +struct MemRefDataFlowOpt : public FunctionPass, InstWalker { + explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} + + // Not applicable to CFG functions. + PassResult runOnCFGFunction(Function *f) override { return success(); } + PassResult runOnMLFunction(Function *f) override; + + void visitOperationInst(OperationInst *opInst); + + // A list of memref's that are potentially dead / could be eliminated. + std::vector memrefsToErase; + + static char passID; +}; + +} // end anonymous namespace + +char MemRefDataFlowOpt::passID = 0; + +/// Creates a pass to perform optimizations relying on memref dataflow such as +/// store to load forwarding, elimination of dead stores, and dead allocs. +FunctionPass *mlir::createMemRefDataFlowOptPass() { + return new MemRefDataFlowOpt(); +} + +// This is a straightforward implementation not optimized for speed.
Optimize +// this in the future if needed. +void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { + OperationInst *lastWriteStoreOp = nullptr; + + auto loadOp = opInst->dyn_cast(); + if (!loadOp) + return; + + OperationInst *loadOpInst = opInst; + + // First pass over the use list to get minimum number of surrounding + // loops common between the load op and the store op, with min taken across + // all store ops. + SmallVector storeOps; + unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); + for (InstOperand &use : loadOp->getMemRef()->getUses()) { + auto storeOp = cast(use.getOwner())->dyn_cast(); + if (!storeOp) + continue; + auto *storeOpInst = storeOp->getInstruction(); + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + minSurroundingLoops = std::min(nsLoops, minSurroundingLoops); + storeOps.push_back(storeOpInst); + } + + // 1. Check if there is a dependence satisfied at depth equal to the depth + // of the loop body of the innermost common surrounding loop of the storeOp + // and loadOp. + // The list of store op candidates for forwarding - need to satisfy the + // conditions listed at the top. + SmallVector fwdingCandidates; + // Store ops that have a dependence into the load (even if they aren't + // forwarding candidates). Each fwding candidate will be checked for a + // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. + SmallVector depSrcStores; + for (auto *storeOpInst : storeOps) { + MemRefAccess srcAccess, destAccess; + getMemRefAccess(storeOpInst, &srcAccess); + getMemRefAccess(loadOpInst, &destAccess); + FlatAffineConstraints dependenceConstraints; + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + // Dependences at loop depth <= minSurroundingLoops do NOT matter. + for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) { + if (!checkMemrefAccessDependence(srcAccess, destAccess, d, + &dependenceConstraints, + /*dependenceComponents=*/nullptr)) + continue; + depSrcStores.push_back(storeOpInst); + // Check if this store is a candidate for forwarding; we only forward if + // the dependence from the store is carried by the *body* of innermost + // common surrounding loop. As an example this filters out cases like: + // for %i0 + // for %i1 + // %idx = affine_apply (d0) -> (d0 + 1) (%i0) + // store %A[%idx] + // load %A[%i0] + // + if (d != nsLoops + 1) + break; + + // 2. The store has to dominate the load op to be candidate. This is not + // strictly a necessary condition since dominance isn't a prerequisite for + // a memref element store to reach a load, but this is sufficient and + // reasonably powerful in practice. + if (!dominates(*storeOpInst, *loadOpInst)) + break; + + // Finally, forwarding is only possible if the load touches a single + // location in the memref across the enclosing loops *not* common with the + // store. This is filtering out cases like: + // for (i ...) + // a [i] = ... + // for (j ...) + // ... = a[j] + MemRefRegion region; + getMemRefRegion(loadOpInst, nsLoops, ®ion); + if (!region.getConstraints()->isRangeOneToOne( + /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) + break; + + // After all these conditions, we have a candidate for forwarding! + fwdingCandidates.push_back(storeOpInst); + break; + } + } + + // Note: this can implemented in a cleaner way with postdominator tree + // traversals. Consider this for the future if needed. + for (auto *storeOpInst : fwdingCandidates) { + // 3. 
Of all the store op's that meet the above criteria, the store + // that postdominates all 'depSrcStores' (if such a store exists) is the + // unique store providing the value to the load, i.e., provably the last + // writer to that memref loc. + if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) { + return postDominates(*storeOpInst, *depStore); + })) { + lastWriteStoreOp = storeOpInst; + break; + } + } + // TODO: optimization for future: those store op's that are determined to be + // postdominated above can actually be recorded and skipped on the 'i' loop + // iteration above --- since they can never postdominate everything. + + if (!lastWriteStoreOp) + return; + + // Perform the actual store to load forwarding. + Value *storeVal = lastWriteStoreOp->cast()->getValueToStore(); + loadOp->getResult()->replaceAllUsesWith(storeVal); + // Record the memref for a later sweep to optimize away. + memrefsToErase.push_back(loadOp->getMemRef()); + loadOp->erase(); +} + +PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) { + memrefsToErase.clear(); + + // Walk all load's and perform load/store forwarding. + walk(f); + + // Check if the store fwd'ed memrefs are now left with only stores and can + // thus be completely deleted. Note: the canonicalize pass should be able + // to do this as well, but we'll do it here since we collected these anyway. + for (auto *memref : memrefsToErase) { + // If the memref hasn't been alloc'ed in this function, skip. + OperationInst *defInst = memref->getDefiningInst(); + if (!defInst || !cast(defInst)->isa()) + // TODO(mlir-team): if the memref was returned by a 'call' instruction, we + // could still erase it if the call has no side-effects. + continue; + if (std::any_of(memref->use_begin(), memref->use_end(), + [&](InstOperand &use) { + auto *ownerInst = cast(use.getOwner()); + return (!ownerInst->isa() && + !ownerInst->isa()); + })) + continue; + + // Erase all stores, the dealloc, and the alloc on the memref. + for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) { + auto &use = *(it++); + cast(use.getOwner())->erase(); + } + defInst->erase(); + } + + // This function never leaves the IR in an invalid state.
+ return success(); +} + +static PassRegistration + pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs"); diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir new file mode 100644 index 00000000000..c864873f5b2 --- /dev/null +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -0,0 +1,239 @@ +// RUN: mlir-opt %s -memref-dataflow-opt -verify | FileCheck %s + +// CHECK-LABEL: mlfunc @simple_store_load() { +mlfunc @simple_store_load() { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + return +// CHECK: %cst = constant 7.000000e+00 : f32 +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: %0 = addf %cst, %cst : f32 +// CHECK-NEXT: } +// CHECK-NEXT: return +} + +// CHECK-LABEL: mlfunc @multi_store_load() { +mlfunc @multi_store_load() { + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + %cf8 = constant 8.0 : f32 + %cf9 = constant 9.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + store %cf8, %m[%i0] : memref<10xf32> + store %cf9, %m[%i0] : memref<10xf32> + %v2 = load %m[%i0] : memref<10xf32> + %v3 = load %m[%i0] : memref<10xf32> + %v4 = mulf %v2, %v3 : f32 + } + return +// CHECK: %c0 = constant 0 : index +// CHECK-NEXT: %cst = constant 7.000000e+00 : f32 +// CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 +// CHECK-NEXT: %cst_1 = constant 9.000000e+00 : f32 +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: %0 = addf %cst, %cst : f32 +// CHECK-NEXT: %1 = mulf %cst_1, %cst_1 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: return + +} + +// The store-load forwarding can see through affine apply's since it relies on +// dependence information. +// CHECK-LABEL: mlfunc @store_load_affine_apply +mlfunc @store_load_affine_apply() -> memref<10x10xf32> { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10x10xf32> + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + %t = affine_apply (d0, d1) -> (d1 + 1, d0)(%i0, %i1) + %idx = affine_apply (d0, d1) -> (d1, d0 - 1) (%t#0, %t#1) + store %cf7, %m[%idx#0, %idx#1] : memref<10x10xf32> + // CHECK-NOT: load %{{[0-9]+}} + %v0 = load %m[%i0, %i1] : memref<10x10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + // The memref and its stores won't be erased due to this memref return. 
+ return %m : memref<10x10xf32> +// CHECK: %cst = constant 7.000000e+00 : f32 +// CHECK-NEXT: %0 = alloc() : memref<10x10xf32> +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) +// CHECK-NEXT: %2 = affine_apply #map1(%1#0, %1#1) +// CHECK-NEXT: store %cst, %0[%2#0, %2#1] : memref<10x10xf32> +// CHECK-NEXT: %3 = addf %cst, %cst : f32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : memref<10x10xf32> +} + +// CHECK-LABEL: mlfunc @store_load_nested +mlfunc @store_load_nested(%N : index) { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to %N { + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + return +// CHECK: %cst = constant 7.000000e+00 : f32 +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: %0 = addf %cst, %cst : f32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +} + +// No forwarding happens here since either of the two stores could be the last +// writer; store/load forwarding will however be possible here once loop live +// out SSA scalars are available. +// CHECK-LABEL: mlfunc @multi_store_load_nested_no_fwd +mlfunc @multi_store_load_nested_no_fwd(%N : index) { + %cf7 = constant 7.0 : f32 + %cf8 = constant 8.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to %N { + store %cf8, %m[%i1] : memref<10xf32> + } + for %i2 = 0 to %N { + // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + return +} + +// No forwarding happens here since both stores have a value going into +// the load. +// CHECK-LABEL: mlfunc @store_load_store_nested_no_fwd +mlfunc @store_load_store_nested_no_fwd(%N : index) { + %cf7 = constant 7.0 : f32 + %cf9 = constant 9.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to %N { + // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + store %cf9, %m[%i0] : memref<10xf32> + } + } + return +} + +// Forwarding happens here since the last store postdominates all other stores +// and other forwarding criteria are satisfied. +// CHECK-LABEL: mlfunc @multi_store_load_nested_fwd +mlfunc @multi_store_load_nested_fwd(%N : index) { + %cf7 = constant 7.0 : f32 + %cf8 = constant 8.0 : f32 + %cf9 = constant 9.0 : f32 + %cf10 = constant 10.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to %N { + store %cf8, %m[%i1] : memref<10xf32> + } + for %i2 = 0 to %N { + store %cf9, %m[%i2] : memref<10xf32> + } + store %cf10, %m[%i0] : memref<10xf32> + for %i3 = 0 to %N { + // CHECK-NOT: %{{[0-9]+}} = load + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + return +} + +// No one-to-one dependence here between the store and load. +// CHECK-LABEL: mlfunc @store_load_no_fwd +mlfunc @store_load_no_fwd() { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { + // CHECK: load %{{[0-9]+}} + %v0 = load %m[%i2] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + } + return +} + +// Forwarding happens here as there is a one-to-one store-load correspondence. 
+// CHECK-LABEL: mlfunc @store_load_fwd +mlfunc @store_load_fwd() { + %cf7 = constant 7.0 : f32 + %c0 = constant 0 : index + %m = alloc() : memref<10xf32> + store %cf7, %m[%c0] : memref<10xf32> + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { + // CHECK-NOT: load %{{[0-9]}}+ + %v0 = load %m[%c0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + } + } + } + return +} + +// Although there is a dependence from the second store to the load, it is +// satisfied by the outer surrounding loop, and does not prevent the first +// store to be forwarded to the load. +mlfunc @store_load_store_nested_fwd(%N : index) -> f32 { + %cf7 = constant 7.0 : f32 + %cf9 = constant 9.0 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + for %i1 = 0 to %N { + %v0 = load %m[%i0] : memref<10xf32> + %v1 = addf %v0, %v0 : f32 + %idx = affine_apply (d0) -> (d0 + 1) (%i0) + store %cf9, %m[%idx] : memref<10xf32> + } + } + // Due to this load, the memref isn't optimized away. + %v3 = load %m[%c1] : memref<10xf32> + return %v3 : f32 +// CHECK: %0 = alloc() : memref<10xf32> +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> +// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: %1 = addf %cst, %cst : f32 +// CHECK-NEXT: %2 = affine_apply #map2(%i0) +// CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %3 = load %0[%c1] : memref<10xf32> +// CHECK-NEXT: return %3 : f32 +} -- cgit v1.2.3 From 7974889f549a445890435950208ab3863722a3c5 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 30 Dec 2018 23:10:35 -0800 Subject: Update and generalize various passes to work on both CFG and ML functions, simplifying them in minor ways. The only significant cleanup here is the constant folding pass. All the other changes are simple and easy, but this is still enough to shrink the compiler by 45LOC. The one pass left to merge is the CSE pass, which will be move involved, so I'm splitting it out to its own patch (which I'll tackle right after this). This is step 28/n towards merging instructions and statements. PiperOrigin-RevId: 227328115 --- mlir/include/mlir/Transforms/Utils.h | 3 - mlir/lib/Transforms/ComposeAffineMaps.cpp | 4 +- mlir/lib/Transforms/ConstantFold.cpp | 80 +++++----------------- mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LowerAffineApply.cpp | 41 +++++------ mlir/lib/Transforms/MaterializeVectors.cpp | 8 ++- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 12 ++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/SimplifyAffineExpr.cpp | 24 ++++--- mlir/lib/Transforms/Utils/Utils.cpp | 2 - .../Vectorization/VectorizerTestPass.cpp | 8 ++- mlir/lib/Transforms/ViewFunctionGraph.cpp | 2 +- 14 files changed, 75 insertions(+), 125 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index c63eb0349a7..af72d01ee3d 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -50,8 +50,6 @@ class Function; /// Returns true on success and false if the replacement is not possible /// (whenever a memref is used as an operand in a non-deferencing scenario). See /// comments at function definition for an example. -// TODO(mlir-team): extend this for Value/ CFGFunctions. 
Can also be easily -// extended to add additional indices at any position. bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap::Null(), @@ -102,7 +100,6 @@ OperationInst *createAffineComputationSlice(OperationInst *opInst); /// Forward substitutes results from 'AffineApplyOp' into any users which /// are also AffineApplyOps. // NOTE: This method may modify users of results of this operation. -// TODO(mlir-team): extend this for Value/ CFGFunctions. void forwardSubstitute(OpPointer affineApplyOp); /// Folds the lower and upper bounds of a 'for' inst to constants if possible. diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index f5edf2d8b81..7902c34b066 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -48,7 +48,7 @@ struct ComposeAffineMaps : public FunctionPass, InstWalker { using InstListType = llvm::iplist; void walk(InstListType::iterator Start, InstListType::iterator End); void visitOperationInst(OperationInst *inst); - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; using InstWalker::walk; static char passID; @@ -88,7 +88,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) { } } -PassResult ComposeAffineMaps::runOnMLFunction(Function *f) { +PassResult ComposeAffineMaps::runOnFunction(Function *f) { affineApplyOpsToErase.clear(); walk(f); for (auto *opInst : affineApplyOpsToErase) { diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index f482e90d7ac..8369e3ad43b 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -33,15 +33,12 @@ struct ConstantFold : public FunctionPass, InstWalker { SmallVector existingConstants; // Operations that were folded and that need to be erased. std::vector opInstsToErase; - using ConstantFactoryType = std::function; bool foldOperation(OperationInst *op, - SmallVectorImpl &existingConstants, - ConstantFactoryType constantFactory); + SmallVectorImpl &existingConstants); void visitOperationInst(OperationInst *inst); void visitForInst(ForInst *inst); - PassResult runOnCFGFunction(Function *f) override; - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; static char passID; }; @@ -52,15 +49,12 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -/// This returns false if the operation was successfully folded. -bool ConstantFold::foldOperation(OperationInst *op, - SmallVectorImpl &existingConstants, - ConstantFactoryType constantFactory) { +void ConstantFold::visitOperationInst(OperationInst *op) { // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. if (auto constant = op->dyn_cast()) { existingConstants.push_back(constant); - return true; + return; } // Check to see if each of the operands is a trivial constant. If so, get @@ -78,7 +72,7 @@ bool ConstantFold::foldOperation(OperationInst *op, // Attempt to constant fold the operation. SmallVector resultConstants; if (op->constantFold(operandConstants, resultConstants)) - return true; + return; // Ok, if everything succeeded, then we can create constants corresponding // to the result of the call. 
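[Illustrative sketch, not part of the patch] The folding step that the rewritten ConstantFold::visitOperationInst performs across this hunk and the next can be reconstructed roughly as below. It is intended to live inside ConstantFold.cpp, where the needed headers and using-directives already exist. The operand-gathering loop and ConstantOp::getValue() are assumptions inferred from the surrounding context, the free-function name is invented, and the template arguments lost in this archive are guessed.

// Fold a single operation and forward its constant results to all users.
static void foldSingleOp(OperationInst *op) {
  // Ops that are already constants are only remembered for later cleanup.
  if (op->dyn_cast<ConstantOp>())
    return;

  // Collect an Attribute for each operand defined by a ConstantOp; operands
  // that are not trivially constant contribute a null Attribute. (Assumed
  // shape of the loop elided in the hunk above.)
  SmallVector<Attribute, 8> operandConstants;
  for (auto *operand : op->getOperands()) {
    Attribute cst;
    if (auto *defInst = operand->getDefiningInst())
      if (auto constOp = defInst->dyn_cast<ConstantOp>())
        cst = constOp->getValue();
    operandConstants.push_back(cst);
  }

  // In this API, constantFold() returns true on failure.
  SmallVector<Attribute, 8> resultConstants;
  if (op->constantFold(operandConstants, resultConstants))
    return;

  // Materialize the folded results right before 'op' and replace all uses,
  // matching the builder-based code in the next hunk.
  FuncBuilder builder(op);
  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
    auto *res = op->getResult(i);
    if (res->use_empty()) // ignore dead uses
      continue;
    auto cst = builder.create<ConstantOp>(op->getLoc(), resultConstants[i],
                                          res->getType());
    res->replaceAllUsesWith(cst);
  }
}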
@@ -87,67 +81,21 @@ bool ConstantFold::foldOperation(OperationInst *op, assert(resultConstants.size() == op->getNumResults() && "constant folding produced the wrong number of results"); + FuncBuilder builder(op); for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { auto *res = op->getResult(i); if (res->use_empty()) // ignore dead uses. continue; - auto *cst = constantFactory(resultConstants[i], res->getType()); + auto cst = builder.create(op->getLoc(), resultConstants[i], + res->getType()); existingConstants.push_back(cst); res->replaceAllUsesWith(cst); } - return false; -} - -// For now, we do a simple top-down pass over a function folding constants. We -// don't handle conditional control flow, constant PHI nodes, folding -// conditional branches, or anything else fancy. -PassResult ConstantFold::runOnCFGFunction(Function *f) { - existingConstants.clear(); - FuncBuilder builder(f); - - for (auto &bb : *f) { - for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) { - auto *inst = dyn_cast(&*instIt++); - if (!inst) - continue; - - auto constantFactory = [&](Attribute value, Type type) -> Value * { - builder.setInsertionPoint(inst); - return builder.create(inst->getLoc(), value, type); - }; - - if (!foldOperation(inst, existingConstants, constantFactory)) { - // At this point the operation is dead, remove it. - // TODO: This is assuming that all constant foldable operations have no - // side effects. When we have side effect modeling, we should verify - // that the operation is effect-free before we remove it. Until then - // this is close enough. - inst->erase(); - } - } - } - - // By the time we are done, we may have simplified a bunch of code, leaving - // around dead constants. Check for them now and remove them. - for (auto *cst : existingConstants) { - if (cst->use_empty()) - cst->getDefiningInst()->erase(); - } - - return success(); -} - -// Override the walker's operation visiter for constant folding. -void ConstantFold::visitOperationInst(OperationInst *inst) { - auto constantFactory = [&](Attribute value, Type type) -> Value * { - FuncBuilder builder(inst); - return builder.create(inst->getLoc(), value, type); - }; - if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) { - opInstsToErase.push_back(inst); - } + // At this point the operation is dead, so we can remove it. We add it to + // a vector to avoid invalidating our walker. + opInstsToErase.push_back(op); } // Override the walker's 'for' instruction visit for constant folding. @@ -155,11 +103,15 @@ void ConstantFold::visitForInst(ForInst *forInst) { constantFoldBounds(forInst); } -PassResult ConstantFold::runOnMLFunction(Function *f) { +// For now, we do a simple top-down pass over a function folding constants. We +// don't handle conditional control flow, block arguments, folding +// conditional branches, or anything else fancy. +PassResult ConstantFold::runOnFunction(Function *f) { existingConstants.clear(); opInstsToErase.clear(); walk(f); + // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no // side effects. 
When we have side effect modeling, we should verify that diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1610932918f..31b59d85e14 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,7 +70,7 @@ namespace { struct LoopFusion : public FunctionPass { LoopFusion() : FunctionPass(&LoopFusion::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; static char passID; }; @@ -519,7 +519,7 @@ public: } // end anonymous namespace -PassResult LoopFusion::runOnMLFunction(Function *f) { +PassResult LoopFusion::runOnFunction(Function *f) { MemRefDependenceGraph g; if (g.init(f)) GreedyFusion(&g).run(); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index fa39b7dfa5b..085a9c0b0fe 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -41,7 +41,7 @@ namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. struct LoopTiling : public FunctionPass { LoopTiling() : FunctionPass(&LoopTiling::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; constexpr static unsigned kDefaultTileSize = 4; static char passID; @@ -238,7 +238,7 @@ static void getTileableBands(Function *f, } } -PassResult LoopTiling::runOnMLFunction(Function *f) { +PassResult LoopTiling::runOnFunction(Function *f) { std::vector> bands; getTileableBands(f, &bands); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 69431bf6349..a0472754ceb 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -70,7 +70,7 @@ struct LoopUnroll : public FunctionPass { : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; /// Unroll this for inst. Returns false if nothing was done. bool runOnForInst(ForInst *forInst); @@ -83,7 +83,7 @@ struct LoopUnroll : public FunctionPass { char LoopUnroll::passID = 0; -PassResult LoopUnroll::runOnMLFunction(Function *f) { +PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. class InnermostLoopGatherer : public InstWalker { public: diff --git a/mlir/lib/Transforms/LowerAffineApply.cpp b/mlir/lib/Transforms/LowerAffineApply.cpp index 747733de41e..75e77436bd2 100644 --- a/mlir/lib/Transforms/LowerAffineApply.cpp +++ b/mlir/lib/Transforms/LowerAffineApply.cpp @@ -31,13 +31,11 @@ using namespace mlir; namespace { +// TODO: This shouldn't be its own pass, it should be a legalization (once we +// have the proper infra). 
struct LowerAffineApply : public FunctionPass { - explicit LowerAffineApply() : FunctionPass(&LowerAffineApply::passID) {} - - PassResult runOnMLFunction(Function *f) override; - PassResult runOnCFGFunction(Function *f) override; - + PassResult runOnFunction(Function *f) override; static char passID; }; @@ -45,28 +43,21 @@ struct LowerAffineApply : public FunctionPass { char LowerAffineApply::passID = 0; -PassResult LowerAffineApply::runOnMLFunction(Function *f) { - f->emitError("ML Functions contain syntactically hidden affine_apply's that " - "cannot be expanded"); - return failure(); -} +PassResult LowerAffineApply::runOnFunction(Function *f) { + SmallVector, 8> affineApplyInsts; -PassResult LowerAffineApply::runOnCFGFunction(Function *f) { - for (Block &bb : *f) { - // Handle iterators with care because we erase in the same loop. - // In particular, step to the next element before erasing the current one. - for (auto it = bb.begin(); it != bb.end();) { - auto *inst = dyn_cast(&*it++); - if (!inst) - continue; + // Find all the affine_apply operations. + f->walkOps([&](OperationInst *inst) { + auto applyOp = inst->dyn_cast(); + if (applyOp) + affineApplyInsts.push_back(applyOp); + }); - auto affineApplyOp = inst->dyn_cast(); - if (!affineApplyOp) - continue; - if (expandAffineApply(&*affineApplyOp)) - return failure(); - } - } + // Rewrite them in a second pass, avoiding invalidation of the walker + // iterator. + for (auto applyOp : affineApplyInsts) + if (expandAffineApply(applyOp)) + return failure(); return success(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 1ab1f6361d3..e95bb7307e3 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -197,7 +197,7 @@ struct MaterializationState { struct MaterializeVectorsPass : public FunctionPass { MaterializeVectorsPass() : FunctionPass(&MaterializeVectorsPass::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. MLFunctionMatcherContext mlContext; @@ -712,7 +712,11 @@ static bool materialize(Function *f, return false; } -PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) { +PassResult MaterializeVectorsPass::runOnFunction(Function *f) { + // TODO(ntv): Check to see if this supports arbitrary top-level code. + if (f->getBlocks().size() != 1) + return success(); + using matcher::Op; LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); LLVM_DEBUG(f->print(dbgs())); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 1a30e2b289d..49b33b0596b 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -66,9 +66,7 @@ namespace { struct MemRefDataFlowOpt : public FunctionPass, InstWalker { explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} - // Not applicable to CFG functions. 
- PassResult runOnCFGFunction(Function *f) override { return success(); } - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; void visitOperationInst(OperationInst *opInst); @@ -210,7 +208,11 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { loadOpsToErase.push_back(loadOpInst); } -PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) { +PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { + // Only supports single block functions at the moment. + if (f->getBlocks().size() != 1) + return success(); + DominanceInfo theDomInfo(f); domInfo = &theDomInfo; PostDominanceInfo thePostDomInfo(f); @@ -233,7 +235,7 @@ PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) { for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. OperationInst *defInst = memref->getDefiningInst(); - if (!defInst || !cast(defInst)->isa()) + if (!defInst || !defInst->isa()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we // could still erase it if the call has no side-effects. continue; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 321bf20cf0b..33523df1b4d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -41,7 +41,7 @@ namespace { struct PipelineDataTransfer : public FunctionPass, InstWalker { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; PassResult runOnForInst(ForInst *forInst); // Collect all 'for' instructions. @@ -137,7 +137,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { } /// Returns success if the IR is in a valid state. -PassResult PipelineDataTransfer::runOnMLFunction(Function *f) { +PassResult PipelineDataTransfer::runOnFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 2a643eb690a..086e8891aac 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -21,7 +21,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" @@ -34,17 +34,13 @@ namespace { /// Simplifies all affine expressions appearing in the operation instructions of /// the Function. This is mainly to test the simplifyAffineExpr method. -// TODO(someone): Gradually, extend this to all affine map references found in -// ML functions and CFG functions. -struct SimplifyAffineStructures : public FunctionPass, - InstWalker { +/// TODO(someone): This should just be defined as a canonicalization pattern +/// on AffineMap and driven from the existing canonicalization pass. +struct SimplifyAffineStructures : public FunctionPass { explicit SimplifyAffineStructures() : FunctionPass(&SimplifyAffineStructures::passID) {} - PassResult runOnMLFunction(Function *f) override; - // Does nothing on CFG functions for now. No reusable walkers/visitors exist - // for this yet? TODO(someone). 
- PassResult runOnCFGFunction(Function *f) override { return success(); } + PassResult runOnFunction(Function *f) override; void visitIfInst(IfInst *ifInst); void visitOperationInst(OperationInst *opInst); @@ -86,8 +82,14 @@ void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { } } -PassResult SimplifyAffineStructures::runOnMLFunction(Function *f) { - walk(f); +PassResult SimplifyAffineStructures::runOnFunction(Function *f) { + f->walkInsts([&](Instruction *inst) { + if (auto *opInst = dyn_cast(inst)) + visitOperationInst(opInst); + if (auto *ifInst = dyn_cast(inst)) + visitIfInst(ifInst); + }); + return success(); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index b196695c45a..4af9436b44d 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -60,8 +60,6 @@ static bool isMemRefDereferencingOp(const OperationInst &op) { // extra operands, note that 'indexRemap' would just be applied to the existing // indices (%i, %j). // -// TODO(mlir-team): extend this for CFG Functions. Can also be easily -// extended to add additional indices at any position. bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 9aa11682ebb..f4020f3e1c7 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -73,7 +73,7 @@ struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapAttrName = "affine_map"; VectorizerTestPass() : FunctionPass(&VectorizerTestPass::passID) {} - PassResult runOnMLFunction(Function *f) override; + PassResult runOnFunction(Function *f) override; void testVectorShapeRatio(Function *f); void testForwardSlicing(Function *f); void testBackwardSlicing(Function *f); @@ -218,7 +218,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { res.print(outs() << "\nComposed map: "); } -PassResult VectorizerTestPass::runOnMLFunction(Function *f) { +PassResult VectorizerTestPass::runOnFunction(Function *f) { + // Only support single block functions at this point. + if (f->getBlocks().size() != 1) + return success(); + if (!clTestVectorShapeRatio.empty()) { testVectorShapeRatio(f); } diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 50a3cf5a595..e46dc503ea9 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -78,7 +78,7 @@ struct PrintCFGPass : public FunctionPass { const llvm::Twine &title = "") : FunctionPass(&PrintCFGPass::passID), os(os), shortNames(shortNames), title(title) {} - PassResult runOnCFGFunction(Function *function) override { + PassResult runOnFunction(Function *function) override { mlir::writeGraph(os, function, shortNames, title); return success(); } -- cgit v1.2.3 From 56b3640b945c38c1a761a8811f30c04deabb5e67 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 7 Jan 2019 15:06:32 -0800 Subject: Misc readability and doc / code comment related improvements - NFC - when SSAValue/MLValue existed, code at several places was forced to create additional aggregate temporaries of SmallVector to handle the conversion; get rid of such redundant code - use filling ctors instead of explicit loops - for smallvectors, change insert(list.end(), ...) -> append(... 
- improve comments at various places - turn getMemRefAccess into MemRefAccess ctor and drop duplicated getMemRefAccess. In the next CL, provide getAccess() accessors for load, store, DMA op's to return a MemRefAccess. PiperOrigin-RevId: 228243638 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 9 ++- mlir/include/mlir/Analysis/AffineStructures.h | 18 +++-- mlir/lib/Analysis/AffineAnalysis.cpp | 99 ++++++++++++++++++--------- mlir/lib/Analysis/AffineStructures.cpp | 8 +-- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 33 +-------- mlir/lib/Analysis/Utils.cpp | 29 ++++---- mlir/lib/EDSC/Types.cpp | 4 +- mlir/lib/Parser/Parser.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 9 +-- mlir/lib/Transforms/LoopTiling.cpp | 6 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 5 +- mlir/lib/Transforms/Utils/Utils.cpp | 32 +++------ 12 files changed, 125 insertions(+), 129 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index b769841b451..588be4ea351 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -121,11 +121,18 @@ bool getFlattenedAffineExprs( bool getIndexSet(llvm::ArrayRef forInsts, FlatAffineConstraints *domain); +/// Encapsulates a memref load or store access information. struct MemRefAccess { const Value *memref; const OperationInst *opInst; llvm::SmallVector indices; - // Populates 'accessMap' with composition of AffineApplyOps reachable from + + /// Constructs a MemRefAccess from a load or store operation instruction. + // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to + // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess. + explicit MemRefAccess(OperationInst *opInst); + + /// Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. void getAccessMap(AffineValueMap *accessMap) const; }; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index b1133520d74..ae8cda997a1 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -233,10 +233,16 @@ private: /// /// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, /// symbolic identifiers, and local identifiers. The local identifiers -/// correspond to local/internal variables created temporarily when converting -/// from tree AffineExpr's that have mod's and div's and are thus needed -/// to increase representational power. -// +/// correspond to local/internal variables created when converting from +/// AffineExpr's containing mod's and div's; they are thus needed to increase +/// representational power. Each local identifier is always (by construction) a +/// floordiv of a pure add/mul affine function of dimensional, symbolic, and +/// other local identifiers, in a non-mutually recursive way. Hence, every local +/// identifier can ultimately always be recovered as an affine function of +/// dimensional and symbolic identifiers (involving floordiv's); note however +/// that some floordiv combinations are converted to mod's by AffineExpr +/// construction. 
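/// For example, rewriting "d0 mod 4" introduces one local identifier
/// q (= d0 floordiv 4) together with the constraints 4*q <= d0 <= 4*q + 3;
/// "d0 mod 4" is then the pure add/mul affine function d0 - 4*q over the
/// dimensional and local identifiers, as described above.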
+/// class FlatAffineConstraints { public: enum IdKind { Dimension, Symbol, Local }; @@ -259,7 +265,7 @@ public: if (idArgs.empty()) ids.resize(numIds, None); else - ids.insert(ids.end(), idArgs.begin(), idArgs.end()); + ids.append(idArgs.begin(), idArgs.end()); } /// Constructs a constraint system with the specified number of @@ -276,7 +282,7 @@ public: if (idArgs.empty()) ids.resize(numIds, None); else - ids.insert(ids.end(), idArgs.begin(), idArgs.end()); + ids.append(idArgs.begin(), idArgs.end()); } explicit FlatAffineConstraints(const HyperRectangularSet &set); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 89148139fb4..4485326c897 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -81,47 +81,59 @@ namespace { // This class is used to flatten a pure affine expression (AffineExpr, // which is in a tree form) into a sum of products (w.r.t constants) when -// possible, and in that process simplifying the expression. The simplification -// performed includes the accumulation of contributions for each dimensional and -// symbolic identifier together, the simplification of floordiv/ceildiv/mod -// expressions and other simplifications that in turn happen as a result. A -// simplification that this flattening naturally performs is of simplifying the -// numerator and denominator of floordiv/ceildiv, and folding a modulo -// expression to a zero, if possible. Three examples are below: +// possible, and in that process simplifying the expression. For a modulo, +// floordiv, or a ceildiv expression, an additional identifier, called a local +// identifier, is introduced to rewrite the expression as a sum of product +// affine expression. Each local identifier is always and by construction a +// floordiv of a pure add/mul affine function of dimensional, symbolic, and +// other local identifiers, in a non-mutually recursive way. Hence, every local +// identifier can ultimately always be recovered as an affine function of +// dimensional and symbolic identifiers (involving floordiv's); note however +// that by AffineExpr construction, some floordiv combinations are converted to +// mod's. The result of the flattening is a flattened expression and a set of +// constraints involving just the local variables. // -// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 -// (d0 - d0 mod 4 + 4) mod 4 simplified to 0. -// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 +// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local +// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. // -// For a modulo, floordiv, or a ceildiv expression, an additional identifier -// (called a local identifier) is introduced to rewrite it as a sum of products -// (w.r.t constants). For example, for the second example above, d0 % 4 is +// The simplification performed includes the accumulation of contributions for +// each dimensional and symbolic identifier together, the simplification of +// floordiv/ceildiv/mod expressions and other simplifications that in turn +// happen as a result. A simplification that this flattening naturally performs +// is of simplifying the numerator and denominator of floordiv/ceildiv, and +// folding a modulo expression to a zero, if possible. 
Three examples are below: +// +// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 +// (d0 - d0 mod 4 + 4) mod 4 simplified to 0 +// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 +// +// The way the flattening works for the second example is as follows: d0 % 4 is // replaced by d0 - 4*q with q being introduced: the expression then simplifies // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to -// zero. Note that an affine expression may not always be expressible in a sum -// of products form involving just the original dimensional and symbolic -// identifiers, due to the presence of modulo/floordiv/ceildiv expressions -// that may not be eliminated after simplification; in such cases, the final +// zero. Note that an affine expression may not always be expressible purely as +// a sum of products involving just the original dimensional and symbolic +// identifiers due to the presence of modulo/floordiv/ceildiv expressions that +// may not be eliminated after simplification; in such cases, the final // expression can be reconstructed by replacing the local identifiers with their -// corresponding explicit form stored in 'localExprs' (note that the explicit -// form itself would have been simplified). +// corresponding explicit form stored in 'localExprs' (note that each of the +// explicit forms itself would have been simplified). // -// This is a linear time post order walk for an affine expression that attempts -// the above simplifications through visit methods, with partial results being -// stored in 'operandExprStack'. When a parent expr is visited, the flattened -// expressions corresponding to its two operands would already be on the stack - -// the parent expression looks at the two flattened expressions and combines the -// two. It pops off the operand expressions and pushes the combined result -// (although this is done in-place on its LHS operand expr). When the walk is -// completed, the flattened form of the top-level expression would be left on -// the stack. +// The expression walk method here performs a linear time post order walk that +// performs the above simplifications through visit methods, with partial +// results being stored in 'operandExprStack'. When a parent expr is visited, +// the flattened expressions corresponding to its two operands would already be +// on the stack - the parent expression looks at the two flattened expressions +// and combines the two. It pops off the operand expressions and pushes the +// combined result (although this is done in-place on its LHS operand expr). +// When the walk is completed, the flattened form of the top-level expression +// would be left on the stack. // // A flattener can be repeatedly used for multiple affine expressions that bind // to the same operands, for example, for all result expressions of an // AffineMap or AffineValueMap. In such cases, using it for multiple expressions // is more efficient than creating a new flattener for each expression since // common idenical div and mod expressions appearing across different -// expressions are mapped to the local identifier (same column position in +// expressions are mapped to the same local identifier (same column position in // 'localVarCst'). 
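// A sketch of the reuse described in the previous paragraph; `map` and
// `context` are placeholders, and the constructor arguments as well as the
// post-order walk entry point inherited from AffineExprVisitor are assumed:

AffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols(), context);
for (auto result : map.getResults())
  flattener.walkPostOrder(result);
// Identical div/mod subexpressions appearing in different results now share
// the same local identifier, i.e. the same column of 'localVarCst'.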
struct AffineExprFlattener : public AffineExprVisitor { public: @@ -143,11 +155,11 @@ public: unsigned numLocals; // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for // which new identifiers were introduced; if the latter do not get canceled - // out, these expressions are needed to reconstruct the AffineExpr / tree - // form. Note that these expressions themselves would have been simplified - // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be - // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2 - // would be the local expression stored for q. + // out, these expressions can be readily used to reconstruct the AffineExpr + // (tree) form. Note that these expressions themselves would have been + // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 + // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) + // ceildiv 2 would be the local expression stored for q. SmallVector localExprs; MLIRContext *context; @@ -186,6 +198,12 @@ public: operandExprStack.pop_back(); } + // + // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 + // + // A mod expression "expr mod c" is thus flattened by introducing a new local + // variable q (= expr floordiv c), such that expr mod c is replaced with + // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. void visitModExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); // This is a pure affine expr; the RHS will be a constant. @@ -231,18 +249,21 @@ public: void visitFloorDivExpr(AffineBinaryOpExpr expr) { visitDivExpr(expr, /*isCeil=*/false); } + void visitDimExpr(AffineDimExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); assert(expr.getPosition() < numDims && "Inconsistent number of dims"); eq[getDimStartIndex() + expr.getPosition()] = 1; } + void visitSymbolExpr(AffineSymbolExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); eq[getSymbolStartIndex() + expr.getPosition()] = 1; } + void visitConstantExpr(AffineConstantExpr expr) { operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); auto &eq = operandExprStack.back(); @@ -250,9 +271,19 @@ public: } private: + // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 + // A floordiv is thus flattened by introducing a new local variable q, and + // replacing that expression with 'q' while adding the constraints + // c * q <= expr <= c * q + c - 1 to localVarCst. + // + // A ceildiv is similarly flattened: + // t = expr ceildiv c <=> t = q, c * q - (c - 1) <= expr <= c * q + // Note that although t = expr ceildiv c, it is equivalent to + // (expr + c - 1) floordiv c. void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { assert(operandExprStack.size() >= 2); assert(expr.getRHS().isa()); + // This is a pure affine expr; the RHS is a positive constant. 
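    // Worked instance of the third example above: for
    // (3*d0 + 2*d1 + d0) floordiv 2, the flattened dividend on the stack is
    // 4*d0 + 2*d1; the constant RHS 2 divides every coefficient, so the
    // division is exact and simplifies to 2*d0 + d1 with no new local
    // variable. Only an inexact division introduces a local q with
    // c*q <= expr <= c*q + c - 1 added to localVarCst, as documented above.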
auto rhsConst = operandExprStack.back()[getConstantIndex()]; // TODO(bondhugula): handle division by zero at the same time the issue is diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 3dbbfa7a49d..f4f525bc470 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -484,7 +484,7 @@ FlatAffineConstraints::FlatAffineConstraints( auto otherIds = other.getIds(); ids.reserve(numReservedCols); - ids.insert(ids.end(), otherIds.begin(), otherIds.end()); + ids.append(otherIds.begin(), otherIds.end()); unsigned numReservedEqualities = other.getNumReservedEqualities(); unsigned numReservedInequalities = other.getNumReservedInequalities(); @@ -562,7 +562,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, ids.resize(numIds, None); } else { ids.reserve(idArgs.size()); - ids.insert(ids.end(), idArgs.begin(), idArgs.end()); + ids.append(idArgs.begin(), idArgs.end()); } } @@ -1817,8 +1817,8 @@ void FlatAffineConstraints::FourierMotzkinEliminate( SmallVector, 8> newIds; newIds.reserve(numIds - 1); - newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos); - newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end()); + newIds.append(ids.begin(), ids.begin() + pos); + newIds.append(ids.begin() + pos + 1, ids.end()); /// Create the new system which has one identifier less. FlatAffineConstraints newFac( diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index c7bf2abd8d6..043d62d0cc9 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -62,33 +62,6 @@ FunctionPass *mlir::createMemRefDependenceCheckPass() { return new MemRefDependenceCheck(); } -// Adds memref access indices 'opIndices' from 'memrefType' to 'access'. -static void addMemRefAccessIndices( - llvm::iterator_range opIndices, - MemRefType memrefType, MemRefAccess *access) { - access->indices.reserve(memrefType.getRank()); - for (auto *index : opIndices) { - access->indices.push_back(const_cast(index)); - } -} - -// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'. -static void getMemRefAccess(const OperationInst *loadOrStoreOpInst, - MemRefAccess *access) { - access->opInst = loadOrStoreOpInst; - if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { - access->memref = loadOp->getMemRef(); - addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(), - access); - } else { - assert(loadOrStoreOpInst->isa()); - auto storeOp = loadOrStoreOpInst->dyn_cast(); - access->memref = storeOp->getMemRef(); - addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(), - access); - } -} - // Returns a result string which represents the direction vector (if there was // a dependence), returns the string "false" otherwise. 
static string @@ -118,12 +91,10 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, static void checkDependences(ArrayRef loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpInst = loadsAndStores[i]; - MemRefAccess srcAccess; - getMemRefAccess(srcOpInst, &srcAccess); + MemRefAccess srcAccess(srcOpInst); for (unsigned j = 0; j < e; ++j) { auto *dstOpInst = loadsAndStores[j]; - MemRefAccess dstAccess; - getMemRefAccess(dstOpInst, &dstAccess); + MemRefAccess dstAccess(dstOpInst); unsigned numCommonLoops = getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index d94e0967dcd..9d89f04d41d 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -119,16 +119,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, if ((loadOp = opInst->dyn_cast())) { rank = loadOp->getMemRefType().getRank(); - for (auto *index : loadOp->getIndices()) { - indices.push_back(index); - } + indices.append(loadOp->getIndices().begin(), loadOp->getIndices().end()); region->memref = loadOp->getMemRef(); region->setWrite(false); } else if ((storeOp = opInst->dyn_cast())) { rank = storeOp->getMemRefType().getRank(); - for (auto *index : storeOp->getIndices()) { - indices.push_back(index); - } + indices.append(storeOp->getIndices().begin(), storeOp->getIndices().end()); region->memref = storeOp->getMemRef(); region->setWrite(true); } else { @@ -442,25 +438,26 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, return sliceLoopNest; } -void mlir::getMemRefAccess(OperationInst *loadOrStoreOpInst, - MemRefAccess *access) { +// Constructs MemRefAccess populating it with the memref, its indices and +// opinst from 'loadOrStoreOpInst'. 
+MemRefAccess::MemRefAccess(OperationInst *loadOrStoreOpInst) { if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { - access->memref = loadOp->getMemRef(); - access->opInst = loadOrStoreOpInst; + memref = loadOp->getMemRef(); + opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp->getMemRefType(); - access->indices.reserve(loadMemrefType.getRank()); + indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { - access->indices.push_back(index); + indices.push_back(index); } } else { assert(loadOrStoreOpInst->isa() && "load/store op expected"); auto storeOp = loadOrStoreOpInst->dyn_cast(); - access->opInst = loadOrStoreOpInst; - access->memref = storeOp->getMemRef(); + opInst = loadOrStoreOpInst; + memref = storeOp->getMemRef(); auto storeMemrefType = storeOp->getMemRefType(); - access->indices.reserve(storeMemrefType.getRank()); + indices.reserve(storeMemrefType.getRank()); for (auto *index : storeOp->getIndices()) { - access->indices.push_back(index); + indices.push_back(index); } } } diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index d762e3f732e..30ae5ab00ff 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -178,7 +178,7 @@ Stmt ForNest(MutableArrayRef indices, ArrayRef lbs, Expr load(Expr m, llvm::ArrayRef indices) { SmallVector exprs; exprs.push_back(m); - exprs.insert(exprs.end(), indices.begin(), indices.end()); + exprs.append(indices.begin(), indices.end()); return VariadicExpr(ExprKind::Load, exprs); } @@ -186,7 +186,7 @@ Expr store(Expr val, Expr m, llvm::ArrayRef indices) { SmallVector exprs; exprs.push_back(val); exprs.push_back(m); - exprs.insert(exprs.end(), indices.begin(), indices.end()); + exprs.append(indices.begin(), indices.end()); return VariadicExpr(ExprKind::Store, exprs); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c90a0d40056..5790b1ad938 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -798,7 +798,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { // Return the sublists' dimensions with 'size' prepended. dims.clear(); dims.push_back(size); - dims.insert(dims.end(), newDims.begin(), newDims.end()); + dims.append(newDims.begin(), newDims.end()); return ParseSuccess; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 31b59d85e14..2a004492d84 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -87,16 +87,13 @@ struct FusionCandidate { MemRefAccess srcAccess; // Load or store access within dst loop nest. MemRefAccess dstAccess; + explicit FusionCandidate(OperationInst *src, OperationInst *dst) + : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {} }; static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst, OperationInst *dstLoadOpInst) { - FusionCandidate candidate; - // Get store access for src loop nest. - getMemRefAccess(srcStoreOpInst, &candidate.srcAccess); - // Get load access for dst loop nest. 
- getMemRefAccess(dstLoadOpInst, &candidate.dstAccess); - return candidate; + return FusionCandidate(srcStoreOpInst, dstLoadOpInst); } namespace { diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 085a9c0b0fe..ee66c9b17b1 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -86,8 +86,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, for (unsigned i = 0; i < width; i++) { auto lbOperands = origLoops[i]->getLowerBoundOperands(); auto ubOperands = origLoops[i]->getUpperBoundOperands(); - SmallVector newLbOperands(lbOperands.begin(), lbOperands.end()); - SmallVector newUbOperands(ubOperands.begin(), ubOperands.end()); + SmallVector newLbOperands(lbOperands); + SmallVector newUbOperands(ubOperands); newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap()); newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap()); newLoops[i]->setStep(tileSizes[i]); @@ -121,7 +121,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, // The new upper bound map is the original one with an additional // expression i + tileSize appended. boundExprs.push_back(dim + tileSizes[i]); - boundExprs.insert(boundExprs.end(), origUbMap.getResults().begin(), + boundExprs.append(origUbMap.getResults().begin(), origUbMap.getResults().end()); auto ubMap = b.getAffineMap(origUbMap.getNumInputs() + 1, 0, boundExprs, {}); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 49b33b0596b..adf91b76276 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -128,9 +128,8 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. SmallVector depSrcStores; for (auto *storeOpInst : storeOps) { - MemRefAccess srcAccess, destAccess; - getMemRefAccess(storeOpInst, &srcAccess); - getMemRefAccess(loadOpInst, &destAccess); + MemRefAccess srcAccess(storeOpInst); + MemRefAccess destAccess(loadOpInst); FlatAffineConstraints dependenceConstraints; unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); // Dependences at loop depth <= minSurroundingLoops do NOT matter. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 4af9436b44d..cf9da344b82 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -117,7 +117,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, opInst->getName()); state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. - state.operands.insert(state.operands.end(), opInst->operand_begin(), + state.operands.append(opInst->operand_begin(), opInst->operand_begin() + memRefOperandPos); state.operands.push_back(newMemRef); @@ -138,11 +138,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // at position memRefOperandPos + 1. 
SmallVector remapOperands; remapOperands.reserve(oldMemRefRank + extraOperands.size()); - remapOperands.insert(remapOperands.end(), extraOperands.begin(), - extraOperands.end()); - remapOperands.insert( - remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1, - opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank); + remapOperands.append(extraOperands.begin(), extraOperands.end()); + remapOperands.append(opInst->operand_begin() + memRefOperandPos + 1, + opInst->operand_begin() + memRefOperandPos + 1 + + oldMemRefRank); if (indexRemap) { auto remapOp = builder.create(opInst->getLoc(), indexRemap, remapOperands); @@ -156,8 +155,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, } // Insert the remaining operands unmodified. - state.operands.insert(state.operands.end(), - opInst->operand_begin() + memRefOperandPos + 1 + + state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank, opInst->operand_end()); @@ -167,7 +165,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, state.types.push_back(result->getType()); // Attributes also do not change. - state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(), + state.attributes.append(opInst->getAttrs().begin(), opInst->getAttrs().end()); // Create the new operation. @@ -206,14 +204,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, } // Compose affine maps from all ancestor AffineApplyOps. // Create new AffineApplyOp from 'valueMap'. - unsigned numOperands = valueMap.getNumOperands(); - SmallVector outOperands(numOperands); - for (unsigned i = 0; i < numOperands; ++i) { - outOperands[i] = valueMap.getOperand(i); - } // Create new AffineApplyOp based on 'valueMap'. - auto affineApplyOp = - builder->create(loc, valueMap.getAffineMap(), outOperands); + auto affineApplyOp = builder->create( + loc, valueMap.getAffineMap(), valueMap.getOperands()); results->resize(operands.size()); for (unsigned i = 0, e = operands.size(); i < e; ++i) { (*results)[i] = affineApplyOp->getResult(i); @@ -340,13 +333,8 @@ void mlir::forwardSubstitute(OpPointer affineApplyOp) { valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex); // Create new AffineApplyOp from 'valueMap'. - unsigned numOperands = valueMap.getNumOperands(); - SmallVector operands(numOperands); - for (unsigned i = 0; i < numOperands; ++i) { - operands[i] = valueMap.getOperand(i); - } auto newAffineApplyOp = builder.create( - useOpInst->getLoc(), valueMap.getAffineMap(), operands); + useOpInst->getLoc(), valueMap.getAffineMap(), valueMap.getOperands()); // Update all uses to use results from 'newAffineApplyOp'. 
for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) { -- cgit v1.2.3 From 21baf86a2f454fb1387f3f670ade7c507a53e2e6 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 7 Jan 2019 17:34:26 -0800 Subject: Extend loop-fusion's slicing utility + other fixes / updates - refactor toAffineFromEq and the code surrounding it; refactor code into FlatAffineConstraints::getSliceBounds - add FlatAffineConstraints methods to detect identifiers as mod's and div's of other identifiers - add FlatAffineConstraints::getConstantLower/UpperBound - Address b/122118218 (don't assert on invalid fusion depths cmdline flags - instead, don't do anything; change cmdline flags src-loop-depth -> fusion-src-loop-depth - AffineExpr/Map print method update: don't fail on null instances (since we have a wrapper around a pointer, it's avoidable); rationale: dump/print methods should never fail if possible. - Update memref-dataflow-opt to add an optimization to avoid a unnecessary call to IsRangeOneToOne when it's trivially going to be true. - Add additional test cases to exercise the new support - update a few existing test cases since the maps are now generated uniformly with all destination loop operands appearing for the backward slice - Fix projectOut - fix wrong range for getBestElimCandidate. - Fix for getConstantBoundOnDimSize() - didn't show up in any test cases since we didn't have any non-hyperrectangular ones. PiperOrigin-RevId: 228265152 --- mlir/include/mlir/Analysis/AffineStructures.h | 78 +++-- mlir/include/mlir/IR/AffineExpr.h | 2 + mlir/lib/Analysis/AffineAnalysis.cpp | 101 +++++-- mlir/lib/Analysis/AffineStructures.cpp | 397 ++++++++++++++++++++++---- mlir/lib/Analysis/Utils.cpp | 121 ++++---- mlir/lib/IR/AsmPrinter.cpp | 8 + mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 17 +- mlir/test/Transforms/loop-fusion.mlir | 176 ++++++++++-- mlir/test/Transforms/memref-dataflow-opt.mlir | 2 +- 10 files changed, 715 insertions(+), 191 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index ae8cda997a1..4d78a24c8f1 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -387,18 +387,15 @@ public: AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); - // Returns an AffineMap that expresses the identifier at pos as a function of - // other dimensional and symbolic identifiers. - // If 'nonZeroDimIds' and 'nonZeroSymbolIds' are non-null, they are populated - // with the positions of the non-zero equality constraint coefficients which - // were used to build the returned AffineMap. - // Returns AffineMap::Null if such an expression can't be constructed. - // TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this - // API when we can manage the mapping of Values and ids in the constraint - // system. - AffineMap toAffineMapFromEq(unsigned pos, MLIRContext *context, - SmallVectorImpl *nonZeroDimIds, - SmallVectorImpl *nonZeroSymbolIds); + /// Computes the lower and upper bounds of the first 'num' dimensional + /// identifiers as an affine map of the remaining identifiers (dimensional and + /// symbolic). This method is able to detect identifiers as floordiv's + /// and mod's of affine expressions of other identifiers with respect to + /// (positive) constants. Sets bound map to AffineMap::Null if such a bound + /// can't be found (or yet unimplemented). 
+ void getSliceBounds(unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps); // Adds an inequality (>= 0) from the coefficients specified in inEq. void addInequality(ArrayRef inEq); @@ -513,6 +510,7 @@ public: inline unsigned getNumIds() const { return numIds; } inline unsigned getNumDimIds() const { return numDims; } inline unsigned getNumSymbolIds() const { return numSymbols; } + inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } inline unsigned getNumLocalIds() const { return numIds - numDims - numSymbols; } @@ -521,24 +519,43 @@ public: return {ids.data(), ids.size()}; } - /// Returns the Value's associated with the identifiers. Asserts if - /// no Value was associated with an identifier. - inline void getIdValues(SmallVectorImpl *values) const { - values->clear(); - values->reserve(numIds); - for (unsigned i = 0; i < numIds; i++) { - assert(ids[i].hasValue() && "identifier's Value not set"); - values->push_back(ids[i].getValue()); - } - } - /// Returns the Value associated with the pos^th identifier. Asserts if /// no Value identifier was associated. inline Value *getIdValue(unsigned pos) const { - assert(ids[pos].hasValue() && "identifier's ML Value not set"); + assert(ids[pos].hasValue() && "identifier's Value not set"); return ids[pos].getValue(); } + /// Returns the Values associated with identifiers in range [start, end). + /// Asserts if no Value was associated with one of these identifiers. + void getIdValues(unsigned start, unsigned end, + SmallVectorImpl *values) const { + assert((start < numIds || start == end) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + values->clear(); + values->reserve(end - start); + for (unsigned i = start; i < end; i++) { + values->push_back(getIdValue(i)); + } + } + inline void getAllIdValues(SmallVectorImpl *values) const { + getIdValues(0, numIds, values); + } + + /// Sets Value associated with the pos^th identifier. + inline void setIdValue(unsigned pos, Value *val) { + assert(pos < numIds && "invalid id position"); + ids[pos] = val; + } + /// Sets Values associated with identifiers in the range [start, end). + void setIdValues(unsigned start, unsigned end, ArrayRef values) { + assert((start < numIds || end == start) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + assert(values.size() == end - start); + for (unsigned i = start; i < end; ++i) + ids[i] = values[i - start]; + } + /// Clears this list of constraints and copies other into it. void clearAndCopyFrom(const FlatAffineConstraints &other); @@ -555,6 +572,14 @@ public: getConstantBoundOnDimSize(unsigned pos, SmallVectorImpl *lb = nullptr) const; + /// Returns the constant lower bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantLowerBound(unsigned pos) const; + + /// Returns the constant upper bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantUpperBound(unsigned pos) const; + /// Returns true if the set can be trivially detected as being /// hyper-rectangular on the specified contiguous set of identifiers. bool isHyperRectangular(unsigned pos, unsigned num) const; @@ -579,6 +604,11 @@ private: /// 'false'otherwise. bool hasInvalidConstraint() const; + /// Returns the constant lower bound bound if isLower is true, and the upper + /// bound if isLower is false. 
+ template + Optional getConstantLowerOrUpperBound(unsigned pos) const; + // Eliminates a single identifier at 'position' from equality and inequality // constraints. Returns 'true' if the identifier was eliminated, and false // otherwise. diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 5de382d09a3..51703acb55b 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -83,6 +83,8 @@ public: return *this; } + static AffineExpr Null() { return AffineExpr(nullptr); } + bool operator==(AffineExpr other) const { return expr == other.expr; } bool operator!=(AffineExpr other) const { return !(*this == other); } explicit operator bool() const { return expr; } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 4485326c897..99c65962ec7 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -690,9 +690,9 @@ private: // Builds a map from Value to identifier position in a new merged identifier // list, which is the result of merging dim/symbol lists from src/dst -// iteration domains. The format of the new merged list is as follows: +// iteration domains, the format of which is as follows: // -// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers] +// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] // // This method populates 'valuePosMap' with mappings from operand Values in // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') @@ -700,22 +700,26 @@ private: static void buildDimAndSymbolPositionMaps( const FlatAffineConstraints &srcDomain, const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, - const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) { + const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap, + FlatAffineConstraints *dependenceConstraints) { auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = values.size(); i < e; ++i) { auto *value = values[i]; - if (!isa(values[i])) + if (!isa(values[i])) { + assert(values[i]->isValidSymbol() && + "access operand has to be either a loop IV or a symbol"); valuePosMap->addSymbolValue(value); - else if (isSrc) + } else if (isSrc) { valuePosMap->addSrcValue(value); - else + } else { valuePosMap->addDstValue(value); + } } }; SmallVector srcValues, destValues; - srcDomain.getIdValues(&srcValues); - dstDomain.getIdValues(&destValues); + srcDomain.getAllIdValues(&srcValues); + dstDomain.getAllIdValues(&destValues); // Update value position map with identifiers from src iteration domain. updateValuePosMap(srcValues, /*isSrc=*/true); @@ -727,6 +731,65 @@ static void buildDimAndSymbolPositionMaps( updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); } +// Sets up dependence constraints columns appropriately, in the format: +// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] +void initDependenceConstraints(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + const AffineValueMap &srcAccessMap, + const AffineValueMap &dstAccessMap, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceConstraints) { + // Calculate number of equalities/inequalities and columns required to + // initialize FlatAffineConstraints for 'dependenceDomain'. 
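  // As an illustration of the layout being set up here (hypothetical sizes):
  // for a source nest with IVs (src_i, src_j), a destination nest with IVs
  // (dst_i, dst_j), and one symbol s0, the columns of 'dependenceConstraints'
  // are [src_i, src_j, dst_i, dst_j, s0, const], i.e. numDims = 4,
  // numSymbols = 1, and numCols = 6.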
+ unsigned numIneq = + srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); + AffineMap srcMap = srcAccessMap.getAffineMap(); + assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); + unsigned numEq = srcMap.getNumResults(); + unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); + unsigned numSymbols = valuePosMap.getNumSymbols(); + unsigned numIds = numDims + numSymbols; + unsigned numCols = numIds + 1; + + // Set flat affine constraints sizes and reserving space for constraints. + dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, + /*numLocals=*/0); + + // Set values corresponding to dependence constraint identifiers. + SmallVector srcLoopIVs, dstLoopIVs; + srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); + dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); + + dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs); + dependenceConstraints->setIdValues( + srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); + + // Set values for the symbolic identifier dimensions. + auto setSymbolIds = [&](ArrayRef values) { + for (auto *value : values) { + if (!isa(value)) { + assert(value->isValidSymbol() && "expected symbol"); + dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value); + } + } + }; + + setSymbolIds(srcAccessMap.getOperands()); + setSymbolIds(dstAccessMap.getOperands()); + + SmallVector srcSymbolValues, dstSymbolValues; + srcDomain.getIdValues(srcDomain.getNumDimIds(), + srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); + dstDomain.getIdValues(dstDomain.getNumDimIds(), + dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); + setSymbolIds(srcSymbolValues); + setSymbolIds(dstSymbolValues); + + for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds(); + i < e; i++) + assert(dependenceConstraints->getIds()[i].hasValue()); +} + // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into // 'dependenceDomain'. // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a @@ -1278,25 +1341,15 @@ bool mlir::checkMemrefAccessDependence( // Value to position in merged contstraint system. ValuePositionMap valuePosMap; buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, - dstAccessMap, &valuePosMap); + dstAccessMap, &valuePosMap, + dependenceConstraints); + + initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap, + valuePosMap, dependenceConstraints); + assert(valuePosMap.getNumDims() == srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); - // Calculate number of equalities/inequalities and columns required to - // initialize FlatAffineConstraints for 'dependenceDomain'. - unsigned numIneq = - srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); - AffineMap srcMap = srcAccessMap.getAffineMap(); - assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); - unsigned numEq = srcMap.getNumResults(); - unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); - unsigned numSymbols = valuePosMap.getNumSymbols(); - unsigned numIds = numDims + numSymbols; - unsigned numCols = numIds + 1; - - // Create flat affine constraints reserving space for 'numEq' and 'numIneq'. - dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, - /*numLocals=*/0); // Create memref access constraint by equating src/dst access functions. 
// Note that this check is conservative, and will failure in the future // when local variables for mod/div exprs are supported. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index f4f525bc470..38cc612505d 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1137,50 +1137,268 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, return posLimit - posStart; } -AffineMap FlatAffineConstraints::toAffineMapFromEq( - unsigned pos, MLIRContext *context, - SmallVectorImpl *nonZeroDimIds, - SmallVectorImpl *nonZeroSymbolIds) { - - // For now just project out local IDs, and return null if we can't - // find an equality. TODO(bondhugula): infer as a function of other - // dims/symbols involving mod/div. - projectOut(getNumIds() - getNumLocalIds(), getNumLocalIds()); - - unsigned idx; - if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) - return AffineMap::Null(); - - // Build AffineExpr solving for identifier 'pos' in terms of all others. - auto expr = getAffineConstantExpr(0, context); - unsigned mapNumDims = 0; - unsigned mapNumSymbols = 0; - for (unsigned j = 0, e = getNumIds(); j < e; ++j) { - if (j == pos) - continue; - int64_t c = atEq(idx, j); - if (c == 0) +// Detect the identifier at 'pos' (say id_r) as modulo of another identifier +// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) +// could be detected as the floordiv of n. For eg: +// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> +// id_r = id_n mod 4, id_q = id_n floordiv 4. +// lbConst and ubConst are the constant lower and upper bounds for 'pos' - +// pre-detected at the caller. +static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, + int64_t lbConst, int64_t ubConst, + SmallVectorImpl *memo) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to + // id_n - divisor * id_q. If these are true, then id_n becomes the dividend + // and id_q the quotient when dividing id_n by the divisor. + + if (lbConst != 0 || ubConst < 1) + return false; + + int64_t divisor = ubConst + 1; + + // Now check for: id_r = id_n - divisor * id_q. As an example, we + // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. + unsigned seenQuotient = 0, seenDividend = 0; + int quotientPos = -1, dividendPos = -1; + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + // id_n should have coeff 1 or -1. + if (std::abs(cst.atEq(r, pos)) != 1) continue; - if (j < numDims) { - expr = expr - getAffineDimExpr(mapNumDims++, context) * c; - nonZeroDimIds->push_back(j); - } else { - expr = - expr - getAffineSymbolExpr(mapNumDims + mapNumSymbols++, context) * c; - nonZeroSymbolIds->push_back(j); + for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { + // The coeff of the quotient should be -divisor if the coefficient of + // the pos^th identifier is -1, and divisor if the latter is -1. + if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) { + seenQuotient++; + quotientPos = c; + } else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) { + seenDividend++; + dividendPos = c; + } + } + // We are looking for exactly one identifier as part of the dividend. + // TODO(bondhugula): could be extended to cover multiple ones in the + // dividend to detect mod of an affine function of identifiers. 
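    // Trace of the documented example at this point: with the equality
    // id_n - 4*id_q - id_r = 0 and 'pos' referring to id_r, the coefficient
    // at 'pos' is -1; scanning the row then classifies id_q (coefficient -4,
    // and -4 * -1 == 4 == divisor) as the quotient and id_n (coefficient 1,
    // and 1 * -1 == -1) as the dividend, so seenDividend == 1 and
    // seenQuotient == 1 when the check below runs.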
+ if (seenDividend == 1 && seenQuotient >= 1) { + if (!(*memo)[dividendPos]) + return false; + // Successfully detected a mod. + (*memo)[pos] = (*memo)[dividendPos] % divisor; + if (seenQuotient == 1 && !(*memo)[quotientPos]) + // Successfully detected a floordiv as well. + (*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor); + return true; } } - // Add constant term to AffineExpr. - expr = expr - atEq(idx, getNumIds()); - int64_t v = atEq(idx, pos); - assert(v != 0 && "expected non-zero here"); - if (v > 0) - expr = expr.floorDiv(v); - else - // v < 0. - expr = (-expr).floorDiv(-v); + return false; +} + +// Check if the pos^th identifier can be expressed as a floordiv of an affine +// function of other identifiers (where the divisor is a positive constant). +// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. +bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, + SmallVectorImpl *memo, MLIRContext *context) { + assert(pos < cst.getNumIds() && "invalid position"); + SmallVector lbIndices, ubIndices; - return AffineMap::get(mapNumDims, mapNumSymbols, {expr}, {}); + // Gather all lower bounds and upper bound constraints of this identifier. + // Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint + // is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) >= 1) + // Lower bound. + lbIndices.push_back(r); + else if (cst.atIneq(r, pos) <= -1) + // Upper bound. + ubIndices.push_back(r); + } + + // Check if any lower bound, upper bound pair is of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). + // + // For example, if -32*k + 16*i + j >= 0 + // 32*k - 16*i - j + 31 >= 0 <=> + // k = ( 16*i + j ) floordiv 32 + unsigned seenDividends = 0; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Check if lower bound's constant term is 'divisor - 1'. The 'divisor' + // here is cst.atIneq(lbPos, pos) and we already know that it's positive + // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. + if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) + continue; + // Check if upper bound's constant term is 0. + if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) + continue; + // For the remaining part, check if the lower bound expr's coeff's are + // negations of corresponding upper bound ones'. + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) + break; + if (c != pos && cst.atIneq(lbPos, c) != 0) + seenDividends++; + } + // Lb coeff's aren't negative of ub coeff's (for the non constant term + // part). + if (c < f) + continue; + if (seenDividends >= 1) { + // The divisor is the constant term of the lower bound expression. + // We already know that cst.atIneq(lbPos, pos) > 0. + int64_t divisor = cst.atIneq(lbPos, pos); + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(0, context); + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (c == pos) + continue; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + if (!(*memo)[c]) + break; + dividendExpr = dividendExpr + ubVal * (*memo)[c]; + } + // Expression can't be constructed as it depends on a yet unknown + // identifier. 
+ // TODO(mlir-team): Visit/compute the identifiers in an order so that + // this doesn't happen. More complex but much more efficient. + if (c < f) + continue; + // Successfully detected the floordiv. + (*memo)[pos] = dividendExpr.floorDiv(divisor); + return true; + } + } + } + return false; +} + +/// Computes the lower and upper bounds of the first 'num' dimensional +/// identifiers as affine maps of the remaining identifiers (dimensional and +/// symbolic identifiers). Local identifiers are themselves explicitly computed +/// as affine functions of other identifiers in this process if needed. +void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps) { + assert(num < getNumDimIds() && "invalid range"); + + // Basic simplification. + normalizeConstraintsByGCD(); + + LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n"); + LLVM_DEBUG(dump()); + + // Record computed/detected identifiers. + SmallVector memo(getNumIds(), AffineExpr::Null()); + // Initialize dimensional and symbolic identifiers. + for (unsigned i = num, e = getNumDimIds(); i < e; i++) + memo[i] = getAffineDimExpr(i - num, context); + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) + memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); + + bool changed; + do { + changed = false; + // Identify yet unknown identifiers as constants or mod's / floordiv's of + // other identifiers if possible. + for (unsigned pos = 0; pos < getNumIds(); pos++) { + if (memo[pos]) + continue; + + auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + // Detect equality to a constant. + if (lbConst.getValue() == ubConst.getValue()) { + memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); + changed = true; + continue; + } + + // Detect an identifier as modulo of another identifier w.r.t a + // constant. + if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), + &memo)) { + changed = true; + continue; + } + } + + // Detect an identifier as floordiv of another identifier w.r.t a + // constant. + if (detectAsFloorDiv(*this, pos, &memo, context)) { + changed = true; + continue; + } + + // Detect an identifier as an expression of other identifiers. + unsigned idx; + if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { + continue; + } + + // Build AffineExpr solving for identifier 'pos' in terms of all others. + auto expr = getAffineConstantExpr(0, context); + unsigned j, e; + for (j = 0, e = getNumIds(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = atEq(idx, j); + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!memo[j]) + break; + expr = expr + memo[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + continue; + + // Add constant term to AffineExpr. + expr = expr + atEq(idx, getNumIds()); + int64_t vPos = atEq(idx, pos); + assert(vPos != 0 && "expected non-zero here"); + if (vPos > 0) + expr = (-expr).floorDiv(vPos); + else + // vPos < 0. + expr = expr.floorDiv(-vPos); + // Successfully constructed expression. + memo[pos] = expr; + changed = true; + } + // This loop is guaranteed to reach a fixed point - since once an + // identifier's explicit form is computed (in memo[pos]), it's not updated + // again. 
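The equality-solving step above (building expr from already-known identifiers, then dividing by the coefficient of 'pos') reduces to x = -(sum + c0) floordiv c_pos, handled on two sign branches exactly as in the code; a small standalone illustration with made-up coefficients:

  #include <cassert>
  #include <cstdint>

  // Floor division that also works for negative numerators.
  static int64_t floorDiv(int64_t a, int64_t b) {
    int64_t q = a / b, r = a % b;
    return (r != 0 && (r < 0) != (b < 0)) ? q - 1 : q;
  }

  int main() {
    // Hypothetical equality row: 1*i0 + (-2)*x + 4 == 0, i.e. x = (i0 + 4) / 2.
    const int64_t cI0 = 1, cPos = -2, c0 = 4;
    for (int64_t i0 = 0; i0 <= 20; i0 += 2) {  // pick i0 so x is integral
      int64_t expr = cI0 * i0 + c0;            // sum over the other ids + cst
      int64_t x = cPos > 0 ? floorDiv(-expr, cPos) : floorDiv(expr, -cPos);
      assert(cI0 * i0 + cPos * x + c0 == 0);   // solves the original equality
    }
    return 0;
  }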
+ } while (changed); + + // Set the lower and upper bound maps for all the identifiers that were + // computed as affine expressions of the rest as the "detected expr" and + // "detected expr + 1" respectively; set the undetected ones to Null(). + for (unsigned pos = 0; pos < num; pos++) { + unsigned numMapDims = getNumDimIds() - num; + unsigned numMapSymbols = getNumSymbolIds(); + AffineExpr expr = memo[pos]; + if (expr) + expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); + + if (expr) { + (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); + (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); + } else { + (*lbMaps)[pos] = AffineMap::Null(); + (*ubMaps)[pos] = AffineMap::Null(); + } + LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); + LLVM_DEBUG(expr.dump();); + } } void FlatAffineConstraints::addEquality(ArrayRef eq) { @@ -1456,7 +1674,7 @@ bool FlatAffineConstraints::constantFoldId(unsigned pos) { // atEq(rowIdx, pos) is either -1 or 1. assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); - int64_t constVal = atEq(rowIdx, getNumCols() - 1) / -atEq(rowIdx, pos); + int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); setAndEliminate(pos, constVal); return true; } @@ -1513,19 +1731,24 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( if (atIneq(r, pos) != 0) break; } - if (r == e) { - // If it doesn't appear, just remove the column and return. - // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. + if (r == e) + // If it doesn't, there isn't a bound on it. return None; - } // Positions of constraints that are lower/upper bounds on the variable. SmallVector lbIndices, ubIndices; - // Gather all lower bounds and upper bounds of the variable. Since the - // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower - // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + // Gather all symbolic lower bounds and upper bounds of the variable. Since + // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a + // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned c, f; + for (c = 0, f = getNumDimIds(); c < f; c++) { + if (c != pos && atIneq(r, c) != 0) + break; + } + if (c < getNumDimIds()) + continue; if (atIneq(r, pos) >= 1) // Lower bound. lbIndices.push_back(r); @@ -1554,10 +1777,10 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( } if (j < getNumCols() - 1) continue; - int64_t mayDiff = + int64_t diff = atIneq(ubPos, getNumCols() - 1) + atIneq(lbPos, getNumCols() - 1) + 1; - if (minDiff == None || mayDiff < minDiff) { - minDiff = mayDiff; + if (minDiff == None || diff < minDiff) { + minDiff = diff; minLbPosition = lbPos; } } @@ -1572,6 +1795,71 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( return minDiff; } +template +Optional +FlatAffineConstraints::getConstantLowerOrUpperBound(unsigned pos) const { + // Check if there's an equality equating the 'pos'^th identifier to a + // constant. + int eqRowIdx = findEqualityToConstant(*this, pos, /*symbolic=*/false); + if (eqRowIdx != -1) + // atEq(rowIdx, pos) is either -1 or 1. + return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, pos); + + // Check if the identifier appears at all in any of the inequalities. 
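The 'expr' / 'expr + 1' bound pair produced at the end of getSliceBounds describes a single-iteration slice for a unit-step loop; a toy illustration (plain C++, hypothetical detected expression d0 - 1):

  #include <cassert>
  #include <cstdint>

  int main() {
    // Detected expression for the source IV, e.g. (d0) -> (d0 - 1).
    auto expr = [](int64_t d0) { return d0 - 1; };
    for (int64_t d0 = 1; d0 < 10; ++d0) {
      int64_t lb = expr(d0);      // lower bound map result
      int64_t ub = expr(d0) + 1;  // upper bound map result
      int64_t count = 0;
      for (int64_t iv = lb; iv < ub; ++iv) { // unit-step loop over the slice
        assert(iv == d0 - 1);
        ++count;
      }
      assert(count == 1); // a single-iteration slice per destination iteration
    }
    return 0;
  }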
+ unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + Optional minOrMaxConst = None; + + // Take the max across all const lower bounds (or min across all constant + // upper bounds). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (isLower) { + if (atIneq(r, pos) <= 0) + // Not a lower bound. + continue; + } else if (atIneq(r, pos) >= 0) { + // Not an upper bound. + continue; + } + unsigned c, f; + for (c = 0, f = getNumCols() - 1; c < f; c++) + if (c != pos && atIneq(r, c) != 0) + break; + if (c < getNumCols() - 1) + // Not a constant bound. + continue; + + int64_t boundConst = + isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos)) + : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos)); + if (isLower) { + if (minOrMaxConst == None || boundConst > minOrMaxConst) + minOrMaxConst = boundConst; + } else { + if (minOrMaxConst == None || boundConst < minOrMaxConst) + minOrMaxConst = boundConst; + } + } + return minOrMaxConst; +} + +Optional +FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { + return getConstantLowerOrUpperBound(pos); +} + +Optional +FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { + return getConstantLowerOrUpperBound(pos); +} + // A simple (naive and conservative) check for hyper-rectangularlity. bool FlatAffineConstraints::isHyperRectangular(unsigned pos, unsigned num) const { @@ -1912,7 +2200,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { return; // 'pos' can be at most getNumCols() - 2 if num > 0. - assert(pos <= getNumCols() - 2 && "invalid position"); + assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position"); assert(pos + num < getNumCols() && "invalid range"); // Eliminate as many identifiers as possible using Gaussian elimination. @@ -1930,8 +2218,9 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { // Eliminate the remaining using Fourier-Motzkin. for (unsigned i = 0; i < num - numGaussianEliminated; i++) { - unsigned elimId = getBestIdToEliminate(*this, pos, getNumIds()); - FourierMotzkinEliminate(elimId); + unsigned numToEliminate = num - numGaussianEliminated - i; + FourierMotzkinEliminate( + getBestIdToEliminate(*this, pos, pos + numToEliminate)); } // Fast/trivial simplifications. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 9d89f04d41d..d3beb78eb92 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -340,6 +340,14 @@ static Instruction *getInstAtPosition(ArrayRef positions, // dependence constraint system to create AffineMaps with which to adjust the // loop bounds of the inserted compution slice so that they are functions of the // loop IVs and symbols of the loops surrounding 'dstAccess'. +// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that +// aren't necessarily a one-to-one relation b/w the source and destination. The +// relation between the source and destination could be many-to-many in general. +// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases +// where the dependence from the source to the destination does not cover the +// entire destination index set. Subtract out the dependent destination +// iterations from destination index set and check for emptiness --- this is one +// solution. 
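The constant-bound arithmetic in getConstantLowerOrUpperBound above boils down to: for an inequality c*x + c0 >= 0 in which only x appears, c > 0 yields the lower bound ceil(-c0/c) and c < 0 yields the upper bound floor(c0/(-c)). A standalone check with hypothetical constraints 3*x - 7 >= 0 and -2*x + 9 >= 0:

  #include <cassert>
  #include <cstdint>

  static int64_t floorDiv(int64_t a, int64_t b) {
    int64_t q = a / b, r = a % b;
    return (r != 0 && (r < 0) != (b < 0)) ? q - 1 : q;
  }
  static int64_t ceilDiv(int64_t a, int64_t b) { return -floorDiv(-a, b); }

  int main() {
    // Lower bound from 3*x - 7 >= 0:   x >= 7/3  ->  ceil(7/3)  = 3.
    // Upper bound from -2*x + 9 >= 0:  x <= 9/2  ->  floor(9/2) = 4.
    assert(ceilDiv(7, 3) == 3);   // ceilDiv(-c0, c) with c = 3,  c0 = -7
    assert(floorDiv(9, 2) == 4);  // floorDiv(c0, -c) with c = -2, c0 = 9
    // Every integer x in [3, 4] satisfies both constraints; 2 and 5 do not.
    for (int64_t x = 3; x <= 4; ++x)
      assert(3 * x - 7 >= 0 && -2 * x + 9 >= 0);
    assert(!(3 * 2 - 7 >= 0) || !(-2 * 2 + 9 >= 0));
    assert(!(3 * 5 - 7 >= 0) || !(-2 * 5 + 9 >= 0));
    return 0;
  }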
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, MemRefAccess *dstAccess, unsigned srcLoopDepth, @@ -351,89 +359,74 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, return nullptr; } // Get loop nest surrounding src operation. - SmallVector srcLoopNest; - getLoopIVs(*srcAccess->opInst, &srcLoopNest); - unsigned srcLoopNestSize = srcLoopNest.size(); - assert(srcLoopDepth <= srcLoopNestSize); + SmallVector srcLoopIVs; + getLoopIVs(*srcAccess->opInst, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + if (srcLoopDepth > numSrcLoopIVs) { + srcAccess->opInst->emitError("invalid source loop depth"); + return nullptr; + } // Get loop nest surrounding dst operation. - SmallVector dstLoopNest; - getLoopIVs(*dstAccess->opInst, &dstLoopNest); - unsigned dstLoopNestSize = dstLoopNest.size(); - (void)dstLoopNestSize; - assert(dstLoopDepth > 0); - assert(dstLoopDepth <= dstLoopNestSize); - - // Solve for src IVs in terms of dst IVs, symbols and constants. - SmallVector srcIvMaps(srcLoopNestSize, AffineMap::Null()); - std::vector> srcIvOperands(srcLoopNestSize); - for (unsigned i = 0; i < srcLoopNestSize; ++i) { - // Skip IVs which are greater than requested loop depth. - if (i >= srcLoopDepth) { - srcIvMaps[i] = AffineMap::Null(); - continue; - } - auto cst = dependenceConstraints.clone(); - for (int j = srcLoopNestSize - 1; j >= 0; --j) { - if (i != j) - cst->projectOut(j); - } - SmallVector nonZeroDimIds; - SmallVector nonZeroSymbolIds; - srcIvMaps[i] = cst->toAffineMapFromEq(0, srcAccess->opInst->getContext(), - &nonZeroDimIds, &nonZeroSymbolIds); - // Add operands for all non-zero dst dims and symbols. - // TODO(andydavis) Add local variable support. - for (auto dimId : nonZeroDimIds) { - if (dimId - 1 >= dstLoopDepth) { - // This src IV has a dependence on dst IV dstLoopDepth where it will - // be inserted. So we cannot slice the iteration space at srcLoopDepth, - // and also insert it into the dst loop nest at 'dstLoopDepth'. - return nullptr; - } - srcIvOperands[i].push_back(dstLoopNest[dimId - 1]); - } - // TODO(andydavis) Add symbols from the access function. Ideally, we - // should be able to query the constaint system for the Value associated - // with a symbol identifiers in 'nonZeroSymbolIds'. + SmallVector dstLoopIVs; + getLoopIVs(*dstAccess->opInst, &dstLoopIVs); + unsigned dstLoopIVsSize = dstLoopIVs.size(); + if (dstLoopDepth > dstLoopIVsSize) { + dstAccess->opInst->emitError("invalid destination loop depth"); + return nullptr; } - // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'. + // Project out dimensions other than those up to src/dstLoopDepth's. + dependenceConstraints.projectOut(srcLoopDepth, numSrcLoopIVs - srcLoopDepth); + dependenceConstraints.projectOut(srcLoopDepth + dstLoopDepth, + dstLoopIVsSize - dstLoopDepth); + + // Set up lower/upper bound affine maps for the slice. + SmallVector sliceLbs(srcLoopDepth, AffineMap::Null()); + SmallVector sliceUbs(srcLoopDepth, AffineMap::Null()); + + // Get bounds for src IVs in terms of dst IVs, symbols, and constants. + dependenceConstraints.getSliceBounds(std::min(srcLoopDepth, numSrcLoopIVs), + srcAccess->opInst->getContext(), + &sliceLbs, &sliceUbs); + + // Set up bound operands for the slice's lower and upper bounds. 
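The bound-rewriting step further below replaces the bounds of the cloned source loops at positions dstLoopDepth .. dstLoopDepth + srcLoopDepth - 1 with the slice maps where those are non-null; a toy model of that update (plain C++, constants standing in for the affine bound maps):

  #include <cassert>
  #include <cstdint>
  #include <optional>
  #include <vector>

  struct Loop { int64_t lb, ub; };  // stand-in for a ForInst's bounds

  int main() {
    // Destination loop of depth 1 followed by the cloned 2-deep source nest
    // (all values hypothetical).
    std::vector<Loop> sliceSurroundingLoops = {{0, 10}, {0, 100}, {0, 100}};
    unsigned dstLoopDepth = 1, srcLoopDepth = 2;
    // Slice bounds computed by getSliceBounds; null means "keep full bounds".
    std::vector<std::optional<Loop>> slice = {Loop{3, 4}, std::nullopt};
    for (unsigned i = 0; i < srcLoopDepth; ++i)
      if (slice[i])
        sliceSurroundingLoops[dstLoopDepth + i] = *slice[i];
    assert(sliceSurroundingLoops[1].lb == 3 && sliceSurroundingLoops[1].ub == 4);
    assert(sliceSurroundingLoops[2].ub == 100); // unsliced loop keeps bounds
    return 0;
  }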
+ SmallVector sliceBoundOperands; + dependenceConstraints.getIdValues( + srcLoopDepth, dependenceConstraints.getNumDimAndSymbolIds(), + &sliceBoundOperands); + + // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'. SmallVector positions; - findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions); + // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. + findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions); // Clone src loop nest and insert it a the beginning of the instruction block - // of the loop at 'dstLoopDepth' in 'dstLoopNest'. - auto *dstForInst = dstLoopNest[dstLoopDepth - 1]; + // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. + auto *dstForInst = dstLoopIVs[dstLoopDepth - 1]; FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); DenseMap operandMap; - auto *sliceLoopNest = cast(b.clone(*srcLoopNest[0], operandMap)); + auto *sliceLoopNest = cast(b.clone(*srcLoopIVs[0], operandMap)); - // Lookup inst in cloned 'sliceLoopNest' at 'positions'. Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceInst'. SmallVector sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); + + // Sanity check. unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size(); (void)sliceSurroundingLoopsSize; + unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs; + assert(sliceLoopLimit >= sliceSurroundingLoopsSize); // Update loop bounds for loops in 'sliceLoopNest'. - unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize; - assert(sliceLoopLimit <= sliceSurroundingLoopsSize); - for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) { - auto *forInst = sliceSurroundingLoops[i]; - unsigned index = i - dstLoopDepth; - AffineMap lbMap = srcIvMaps[index]; - if (lbMap == AffineMap::Null()) - continue; - forInst->setLowerBound(srcIvOperands[index], lbMap); - // Create upper bound map with is lower bound map + 1; - assert(lbMap.getNumResults() == 1); - AffineExpr ubResultExpr = lbMap.getResult(0) + 1; - AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), - {ubResultExpr}, {}); - forInst->setUpperBound(srcIvOperands[index], ubMap); + for (unsigned i = 0; i < srcLoopDepth; ++i) { + auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; + if (AffineMap lbMap = sliceLbs[i]) + forInst->setLowerBound(sliceBoundOperands, lbMap); + if (AffineMap ubMap = sliceUbs[i]) + forInst->setUpperBound(sliceBoundOperands, ubMap); } return sliceLoopNest; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index face781a395..2dcb9b15c8d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1409,6 +1409,10 @@ void IntegerSet::dump() const { } void AffineExpr::print(raw_ostream &os) const { + if (expr == nullptr) { + os << "null affine expr"; + return; + } ModuleState state(getContext()); ModulePrinter(os, state).printAffineExpr(*this); } @@ -1419,6 +1423,10 @@ void AffineExpr::dump() const { } void AffineMap::print(raw_ostream &os) const { + if (map == nullptr) { + os << "null affine map"; + return; + } ModuleState state(getContext()); ModulePrinter(os, state).printAffineMap(*this); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2a004492d84..d3832378264 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -47,12 +47,12 @@ using namespace mlir; // depth per-loop nest, or depth per 
load/store op) for this pass utilizing a // cost model. static llvm::cl::opt clSrcLoopDepth( - "src-loop-depth", llvm::cl::Hidden, + "fusion-src-loop-depth", llvm::cl::Hidden, llvm::cl::desc("Controls the depth of the source loop nest at which " "to apply loop iteration slicing before fusion.")); static llvm::cl::opt clDstLoopDepth( - "dst-loop-depth", llvm::cl::Hidden, + "fusion-dst-loop-depth", llvm::cl::Hidden, llvm::cl::desc("Controls the depth of the destination loop nest at which " "to fuse the source loop nest slice.")); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index adf91b76276..1a1502de738 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -117,6 +117,8 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { storeOps.push_back(storeOpInst); } + unsigned loadOpDepth = getNestingDepth(*loadOpInst); + // 1. Check if there is a dependence satisfied at depth equal to the depth // of the loop body of the innermost common surrounding loop of the storeOp // and loadOp. @@ -165,11 +167,16 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // a [i] = ... // for (j ...) // ... = a[j] - MemRefRegion region; - getMemRefRegion(loadOpInst, nsLoops, ®ion); - if (!region.getConstraints()->isRangeOneToOne( - /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) - break; + // If storeOpInst and loadOpDepth at the same nesting depth, the load Op + // is trivially loading from a single location at that depth; so there + // isn't a need to call isRangeOneToOne. + if (getNestingDepth(*storeOpInst) < loadOpDepth) { + MemRefRegion region; + getMemRefRegion(loadOpInst, nsLoops, ®ion); + if (!region.getConstraints()->isRangeOneToOne( + /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) + break; + } // After all these conditions, we have a candidate for forwarding! 
fwdingCandidates.push_back(storeOpInst); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 886072fc30c..57172b1cf9a 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s -// RUN: mlir-opt %s -loop-fusion -src-loop-depth=1 -dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1 +// RUN: mlir-opt %s -loop-fusion -fusion-src-loop-depth=1 -fusion-dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1 // TODO(andydavis) Add more tests: // *) Add nested fusion test cases when non-constant loop bound support is @@ -76,7 +76,8 @@ func @should_fuse_reduction_to_pointwise() { // ----- -// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) +// CHECK: [[MAP_SHIFT_MINUS_ONE_D0:#map[0-9]+]] = (d0, d1) -> (d0 - 1) +// CHECK: [[MAP_SHIFT_MINUS_ONE_D1:#map[0-9]+]] = (d0, d1) -> (d1 - 1) // CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { @@ -98,8 +99,8 @@ func @should_fuse_loop_nests_with_shifts() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) - // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1) + // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_D0]](%i0, %i1) + // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_D1]](%i0, %i1) // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) // CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32> // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32> @@ -111,7 +112,8 @@ func @should_fuse_loop_nests_with_shifts() { // ----- -// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0) -> (d0) +// CHECK-DAG: [[MAP_DIM0:#map[0-9]+]] = (d0, d1) -> (d0) +// CHECK-DAG: [[MAP_DIM1:#map[0-9]+]] = (d0, d1) -> (d1) // CHECK-LABEL: func @should_fuse_loop_nest() { func @should_fuse_loop_nest() { @@ -138,11 +140,11 @@ func @should_fuse_loop_nest() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP_IDENTITY]](%i1) - // CHECK-NEXT: %3 = affine_apply [[MAP_IDENTITY]](%i0) + // CHECK-NEXT: %2 = affine_apply [[MAP_DIM1]](%i0, %i1) + // CHECK-NEXT: %3 = affine_apply [[MAP_DIM0]](%i0, %i1) // CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32> - // CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%i0) - // CHECK-NEXT: %5 = affine_apply [[MAP_IDENTITY]](%i1) + // CHECK-NEXT: %4 = affine_apply [[MAP_DIM0]](%i0, %i1) + // CHECK-NEXT: %5 = affine_apply [[MAP_DIM1]](%i0, %i1) // CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32> // CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32> // CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32> @@ -509,9 +511,11 @@ func @should_not_fuse_if_inst_in_loop_nest() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) -// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) +// CHECK-DAG: [[MAP_D0:#map[0-9]+]] = (d0, d1, d2) -> (d0) +// CHECK-DAG: [[MAP_D1:#map[0-9]+]] = (d0, d1, d2) -> (d1) +// CHECK-DAG: [[MAP_D2:#map[0-9]+]] = (d0, d1, d2) -> (d2) +// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) +// CHECK: [[MAP_PERMUTE:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) // CHECK-LABEL: func @remap_ivs() { func @remap_ivs() { @@ -537,12 +541,12 @@ func @remap_ivs() { // CHECK: for %i0 = 0 to 
30 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: for %i2 = 0 to 20 { -// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1) -// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2) -// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) -// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3) +// CHECK-NEXT: %1 = affine_apply [[MAP_D1]](%i0, %i1, %i2) +// CHECK-NEXT: %2 = affine_apply [[MAP_D2]](%i0, %i1, %i2) +// CHECK-NEXT: %3 = affine_apply [[MAP_D0]](%i0, %i1, %i2) +// CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%1, %2, %3) // CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32> -// CHECK-NEXT: %5 = affine_apply [[MAP2]](%i0, %i1, %i2) +// CHECK-NEXT: %5 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2) // CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -627,3 +631,141 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) { // CHECK-NEXT: } // CHECK-NEXT: return } + +// ----- +// CHECK: #map0 = (d0) -> (d0 floordiv 4) +// CHECK: #map1 = (d0) -> (d0 mod 4) + +// Reshape a 16x4xf32 to 64xf32. +// CHECK-LABEL: func @fuse_reshape_16_4_64 +func @fuse_reshape_16_4_64() { + %in = alloc() : memref<16x4xf32> + %out = alloc() : memref<64xf32> + + for %i0 = 0 to 16 { + for %i1 = 0 to 4 { + %v = load %in[%i0, %i1] : memref<16x4xf32> + %idx = affine_apply (d0, d1) -> (4*d0 + d1) (%i0, %i1) + store %v, %out[%idx] : memref<64xf32> + } + } + + for %i2 = 0 to 64 { + %w = load %out[%i2] : memref<64xf32> + "foo"(%w) : (f32) -> () + } +// CHECK: for %i0 = 0 to 64 { +// CHECK-NEXT: %2 = affine_apply #map0(%i0) +// CHECK-NEXT: %3 = affine_apply #map1(%i0) +// CHECK-NEXT: %4 = load %0[%2, %3] : memref<16x4xf32> +// CHECK-NEXT: %5 = affine_apply #map2(%2, %3) +// CHECK-NEXT: store %4, %1[%5] : memref<64xf32> +// CHECK-NEXT: %6 = load %1[%i0] : memref<64xf32> +// CHECK-NEXT: "foo"(%6) : (f32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return + return +} + + +// ----- + +// All three loop nests below (6-d one, 2-d one, 2-d one is fused into a single +// 2-d loop nest). +// CHECK-LABEL: func @R6_to_R2_reshape +func @R6_to_R2_reshape_square() -> memref<64x9xi32> { + %in = alloc() : memref<2x2x3x3x16x1xi32> + %out = alloc() : memref<64x9xi32> + + // Initialize input with a different value for each 8x128 chunk. + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + for %i4 = 0 to 16 { + for %i5 = 0 to 1 { + %val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32 + store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> + } + } + } + } + } + } + + for %ii = 0 to 64 { + for %jj = 0 to 9 { + // Convert output coordinates to linear index. + %a0 = affine_apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj) + %a1 = affine_apply (d0) -> ( + d0 floordiv (2 * 3 * 3 * 16 * 1), + (d0 mod 288) floordiv (3 * 3 * 16 * 1), + ((d0 mod 288) mod 144) floordiv 48, + (((d0 mod 288) mod 144) mod 48) floordiv 16, + ((((d0 mod 288) mod 144) mod 48) mod 16), + (((d0 mod 144) mod 144) mod 48) mod 16 + ) (%a0) + %v = load %in[%a1#0, %a1#1, %a1#3, %a1#4, %a1#2, %a1#5] + : memref<2x2x3x3x16x1xi32> + store %v, %out[%ii, %jj] : memref<64x9xi32> + } + } + + for %i = 0 to 64 { + for %j = 0 to 9 { + %a = load %out[%i, %j] : memref<64x9xi32> + %b = muli %a, %a : i32 + store %b, %out[%i, %j] : memref<64x9xi32> + } + } + return %out : memref<64x9xi32> +} +// Everything above is fused to a single 2-d loop nest, and the 6-d tensor %in +// is eliminated if -memref-dataflow-opt is also supplied. 
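The index algebra the reshape tests above rely on is easy to check in isolation (plain C++): linearizing (i0, i1) of a 16x4 buffer as 4*i0 + i1 and delinearizing with floordiv/mod 4 round-trips exactly and covers the 64-element output:

  #include <cassert>
  #include <cstdint>

  int main() {
    for (int64_t i0 = 0; i0 < 16; ++i0) {
      for (int64_t i1 = 0; i1 < 4; ++i1) {
        int64_t idx = 4 * i0 + i1;        // (d0, d1) -> (4*d0 + d1)
        assert(idx / 4 == i0);            // (d0) -> (d0 floordiv 4)
        assert(idx % 4 == i1);            // (d0) -> (d0 mod 4)
        assert(0 <= idx && idx < 64);     // covers the 64xf32 output exactly
      }
    }
    return 0;
  }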
+// +// CHECK: for %i0 = 0 to 64 { +// CHECK-NEXT: for %i1 = 0 to 9 { +// CHECK-NEXT: %2 = affine_apply #map0(%i0, %i1) +// CHECK-NEXT: %3 = affine_apply #map1(%i0, %i1) +// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1) +// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1) +// CHECK-NEXT: %6 = affine_apply #map4(%i0, %i1) +// CHECK-NEXT: %7 = "foo"(%2, %3, %4, %5, %6, %c0) : (index, index, index, index, index, index) -> i32 +// CHECK-NEXT: store %7, %0[%2, %3, %4, %5, %6, %c0] : memref<2x2x3x3x16x1xi32> +// CHECK-NEXT: %8 = affine_apply #map5(%i0, %i1) +// CHECK-NEXT: %9 = affine_apply #map6(%i0, %i1) +// CHECK-NEXT: %10 = affine_apply #map7(%8, %9) +// CHECK-NEXT: %11 = affine_apply #map8(%10) +// CHECK-NEXT: %12 = load %0[%11#0, %11#1, %11#3, %11#4, %11#2, %11#5] : memref<2x2x3x3x16x1xi32> +// CHECK-NEXT: store %12, %1[%8, %9] : memref<64x9xi32> +// CHECK-NEXT: %13 = load %1[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: %14 = muli %13, %13 : i32 +// CHECK-NEXT: store %14, %1[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : memref<64x9xi32> + +// ----- + +// CHECK-LABEL: func @fuse_symbolic_bounds +func @fuse_symbolic_bounds(%M : index, %N : index) { + %m = alloc() : memref<800x800xf32> + + %c0 = constant 0.0 : f32 + %s = constant 5 : index + + for %i0 = 0 to %M { + for %i1 = 0 to (d0) -> (d0 + 5) (%N) { + store %c0, %m[%i0, %i1] : memref<800 x 800 x f32> + } + } + + for %i2 = 0 to %M { + for %i3 = 0 to %N { + %idx = affine_apply (d0, d1)[s0] -> (d0, d1 + s0) (%i2, %i3)[%s] + %v = load %m[%idx#0, %idx#1] : memref<800 x 800 x f32> + } + } + + return +} diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index 888e2e7d24d..2c9b2380192 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -167,7 +167,7 @@ func @multi_store_load_nested_fwd(%N : index) { return } -// No one-to-one dependence here between the store and load. +// There is no unique load location for the store to forward to. // CHECK-LABEL: func @store_load_no_fwd func @store_load_no_fwd() { %cf7 = constant 7.0 : f32 -- cgit v1.2.3 From 38c2fe3158f7e5e8115e6c030cf618b5f6e5ef6a Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Mon, 14 Jan 2019 11:26:25 -0800 Subject: LoopFusion: automate selection of source loop nest slice depth and destination loop nest insertion depth based on a simple cost model (cost model can be extended/replaced at a later time). *) LoopFusion: Adds fusion cost function which compares the cost of the fused loop nest, with the cost of the two unfused loop nests to determine if it is profitable to fuse the candidate loop nests. The fusion cost function is run for various combinations for src/dst loop depths attempting find the minimum cost setting for src/dst loop depths which does not increase the computational cost when the loop nests are fused. Combinations of src/dst loop depth are evaluated attempting to maximize loop depth (i.e. take a bigger computation slice from the source loop nest, and insert it deeper in the destination loop nest for better locality). *) LoopFusion: Adds utility to compute op instance count for loop nests, sliced loop nests, and to compute the cost of a loop nest fused with another sliced loop nest. *) LoopFusion: canonicalizes slice bound AffineMaps (and updates related tests). 
*) Analysis::Utils: Splits getBackwardComputationSlice into two functions: one which calculates and returns the slice loop bounds for analysis by LoopFusion, and the other for insertion of the computation slice (ones fusion has calculated the min-cost src/dst loop depths). *) Test: Adds multiple unit tests to test the new functionality. PiperOrigin-RevId: 229219757 --- mlir/include/mlir/Analysis/Utils.h | 27 ++- mlir/lib/Analysis/Utils.cpp | 78 ++++--- mlir/lib/Transforms/LoopFusion.cpp | 375 ++++++++++++++++++++++++++++++++-- mlir/test/Transforms/loop-fusion.mlir | 344 +++++++++++++++++++++++++------ 4 files changed, 703 insertions(+), 121 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 900a7ac229b..7cd30ba86ab 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -26,6 +26,7 @@ #define MLIR_ANALYSIS_UTILS_H #include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/AffineMap.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include @@ -151,6 +152,29 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, unsigned getNumCommonSurroundingLoops(const Instruction &A, const Instruction &B); +/// ComputationSliceState aggregates loop bound AffineMaps and their associated +/// operands for a set of loops within a loop nest (typically the set of loops +/// surrounding a store operation). Loop bound AffineMaps which are non-null +/// represent slices of that loop's iteration space. +struct ComputationSliceState { + // List of lower bound AffineMaps. + SmallVector lbs; + // List of upper bound AffineMaps. + SmallVector ubs; + // List of lower bound operands (lbOperands[i] are used by 'lbs[i]'). + std::vector> lbOperands; + // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). + std::vector> ubOperands; +}; + +/// Computes computation slice loop bounds for the loop nest surrounding +/// 'srcAccess', where the returned loop bound AffineMaps are functions of +/// loop IVs from the loop nest surrounding 'dstAccess'. +/// Returns true on success, false otherwise. +bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + ComputationSliceState *sliceState); + /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop /// IVs, and inserts the computation slice at the beginning of the instruction @@ -159,10 +183,11 @@ unsigned getNumCommonSurroundingLoops(const Instruction &A, /// success, returns nullptr otherwise. // Loop depth is a crucial optimization choice that determines where to // materialize the results of the backward slice - presenting a trade-off b/w -// storage and redundant computation in several cases +// storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. 
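A compressed sketch of how ComputationSliceState is meant to flow between the analysis and transformation entry points declared above: compute per-loop slice bounds once, let the caller pick depths, then apply them (plain C++ with stand-in types; the real bounds are AffineMaps over destination IVs):

  #include <cstdint>
  #include <optional>
  #include <vector>

  // Stand-in for an AffineMap-valued slice bound (null == unsliced loop).
  using Bound = std::optional<int64_t>;

  struct SliceState {            // mirrors ComputationSliceState
    std::vector<Bound> lbs, ubs; // one entry per source loop IV
  };

  // Analysis: fill per-loop slice bounds (here: a fixed toy answer).
  static bool getSliceState(SliceState *state) {
    state->lbs = {Bound{0}, std::nullopt};
    state->ubs = {Bound{1}, std::nullopt};
    return true;
  }

  // Transformation: consume the bounds at the depth chosen by the caller.
  static void applySlice(const SliceState &state, unsigned srcLoopDepth) {
    for (unsigned i = 0; i < srcLoopDepth; ++i)
      if (state.lbs[i] && state.ubs[i]) { /* setLowerBound/setUpperBound */ }
  }

  int main() {
    SliceState state;
    if (getSliceState(&state))
      applySlice(state, /*srcLoopDepth=*/1); // depth picked by a cost model
    return 0;
  }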
ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess, MemRefAccess *dstAccess, + ComputationSliceState *sliceState, unsigned srcLoopDepth, unsigned dstLoopDepth); } // end namespace mlir diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index b4e3cabd256..12ac0cc44ec 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -348,6 +348,47 @@ static Instruction *getInstAtPosition(ArrayRef positions, // dependence constraint system to create AffineMaps with which to adjust the // loop bounds of the inserted compution slice so that they are functions of the // loop IVs and symbols of the loops surrounding 'dstAccess'. +bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + ComputationSliceState *sliceState) { + FlatAffineConstraints dependenceConstraints; + if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1, + &dependenceConstraints, + /*dependenceComponents=*/nullptr)) { + return false; + } + // Get loop nest surrounding src operation. + SmallVector srcLoopIVs; + getLoopIVs(*srcAccess.opInst, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + + // Set up lower/upper bound affine maps for the slice. + sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null()); + sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null()); + + // Get bounds for src IVs in terms of dst IVs, symbols, and constants. + dependenceConstraints.getSliceBounds(numSrcLoopIVs, + srcAccess.opInst->getContext(), + &sliceState->lbs, &sliceState->ubs); + + // Set up bound operands for the slice's lower and upper bounds. + SmallVector sliceBoundOperands; + dependenceConstraints.getIdValues( + numSrcLoopIVs, dependenceConstraints.getNumDimAndSymbolIds(), + &sliceBoundOperands); + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + return true; +} + +/// Creates a computation slice of the loop nest surrounding 'srcAccess' +/// utilizing slice loop bounds in 'sliceState' (for src loops up to +/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding +/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth +/// greater than 'srcLoopDepth' their full loop bounds will be used in the +/// slice. // TODO(andydavis,bondhugula): extend the slicing utility to compute slices that // aren't necessarily a one-to-one relation b/w the source and destination. The // relation between the source and destination could be many-to-many in general. @@ -356,16 +397,13 @@ static Instruction *getInstAtPosition(ArrayRef positions, // entire destination index set. Subtract out the dependent destination // iterations from destination index set and check for emptiness --- this is one // solution. +// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project +// out loop IVs we don't care about and produce smaller slice. ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, MemRefAccess *dstAccess, + ComputationSliceState *sliceState, unsigned srcLoopDepth, unsigned dstLoopDepth) { - FlatAffineConstraints dependenceConstraints; - if (!checkMemrefAccessDependence(*srcAccess, *dstAccess, /*loopDepth=*/1, - &dependenceConstraints, - /*dependenceComponents=*/nullptr)) { - return nullptr; - } // Get loop nest surrounding src operation. 
SmallVector srcLoopIVs; getLoopIVs(*srcAccess->opInst, &srcLoopIVs); @@ -384,26 +422,6 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, return nullptr; } - // Project out dimensions other than those up to src/dstLoopDepth's. - dependenceConstraints.projectOut(srcLoopDepth, numSrcLoopIVs - srcLoopDepth); - dependenceConstraints.projectOut(srcLoopDepth + dstLoopDepth, - dstLoopIVsSize - dstLoopDepth); - - // Set up lower/upper bound affine maps for the slice. - SmallVector sliceLbs(srcLoopDepth, AffineMap::Null()); - SmallVector sliceUbs(srcLoopDepth, AffineMap::Null()); - - // Get bounds for src IVs in terms of dst IVs, symbols, and constants. - dependenceConstraints.getSliceBounds(std::min(srcLoopDepth, numSrcLoopIVs), - srcAccess->opInst->getContext(), - &sliceLbs, &sliceUbs); - - // Set up bound operands for the slice's lower and upper bounds. - SmallVector sliceBoundOperands; - dependenceConstraints.getIdValues( - srcLoopDepth, dependenceConstraints.getNumDimAndSymbolIds(), - &sliceBoundOperands); - // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. @@ -433,10 +451,10 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, // Update loop bounds for loops in 'sliceLoopNest'. for (unsigned i = 0; i < srcLoopDepth; ++i) { auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; - if (AffineMap lbMap = sliceLbs[i]) - forInst->setLowerBound(sliceBoundOperands, lbMap); - if (AffineMap ubMap = sliceUbs[i]) - forInst->setUpperBound(sliceBoundOperands, ubMap); + if (AffineMap lbMap = sliceState->lbs[i]) + forInst->setLowerBound(sliceState->lbOperands[i], lbMap); + if (AffineMap ubMap = sliceState->ubs[i]) + forInst->setUpperBound(sliceState->ubOperands[i], ubMap); } return sliceLoopNest; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d3832378264..dffa292af3c 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -36,26 +36,15 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "loop-fusion" + using llvm::SetVector; using namespace mlir; -// TODO(andydavis) These flags are global for the pass to be used for -// experimentation. Find a way to provide more fine grained control (i.e. -// depth per-loop nest, or depth per load/store op) for this pass utilizing a -// cost model. -static llvm::cl::opt clSrcLoopDepth( - "fusion-src-loop-depth", llvm::cl::Hidden, - llvm::cl::desc("Controls the depth of the source loop nest at which " - "to apply loop iteration slicing before fusion.")); - -static llvm::cl::opt clDstLoopDepth( - "fusion-dst-loop-depth", llvm::cl::Hidden, - llvm::cl::desc("Controls the depth of the destination loop nest at which " - "to fuse the source loop nest slice.")); - namespace { /// Loop fusion pass. This pass currently supports a greedy fusion policy, @@ -379,6 +368,347 @@ bool MemRefDependenceGraph::init(Function *f) { return true; } +namespace { + +// LoopNestStats aggregates various per-loop statistics (eg. loop trip count +// and operation count) for a loop nest up until the innermost loop body. +struct LoopNestStats { + // Map from ForInst to immediate child ForInsts in its loop body. + DenseMap> loopMap; + // Map from ForInst to count of operations in its loop body. 
+ DenseMap opCountMap; + // Map from ForInst to its constant trip count. + DenseMap tripCountMap; +}; + +// LoopNestStatsCollector walks a single loop nest and gathers per-loop +// trip count and operation count statistics and records them in 'stats'. +class LoopNestStatsCollector : public InstWalker { +public: + LoopNestStats *stats; + bool hasLoopWithNonConstTripCount = false; + + LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} + + void visitForInst(ForInst *forInst) { + auto *parentInst = forInst->getParentInst(); + if (parentInst != nullptr) { + assert(isa(parentInst) && "Expected parent ForInst"); + // Add mapping to 'forInst' from its parent ForInst. + stats->loopMap[cast(parentInst)].push_back(forInst); + } + // Record the number of op instructions in the body of 'forInst'. + unsigned count = 0; + stats->opCountMap[forInst] = 0; + for (auto &inst : *forInst->getBody()) { + if (isa(&inst)) + ++count; + } + stats->opCountMap[forInst] = count; + // Record trip count for 'forInst'. Set flag if trip count is not constant. + Optional maybeConstTripCount = getConstantTripCount(*forInst); + if (!maybeConstTripCount.hasValue()) { + hasLoopWithNonConstTripCount = true; + return; + } + stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); + } +}; + +// Computes the total cost of the loop nest rooted at 'forInst'. +// Currently, the total cost is computed by counting the total operation +// instance count (i.e. total number of operations in the loop bodyloop +// operation count * loop trip count) for the entire loop nest. +// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops +// specified in the map when computing the total op instance count. +// NOTE: this is used to compute the cost of computation slices, which are +// sliced along the iteration dimension, and thus reduce the trip count. +// If 'computeCostMap' is non-null, the total op count for forInsts specified +// in the map is increased (not overridden) by adding the op count from the +// map to the existing op count for the for loop. This is done before +// multiplying by the loop's trip count, and is used to model the cost of +// inserting a sliced loop nest of known cost into the loop's body. +// NOTE: this is used to compute the cost of fusing a slice of some loop nest +// within another loop. +static uint64_t +getComputeCost(ForInst *forInst, LoopNestStats *stats, + DenseMap *tripCountOverrideMap, + DenseMap *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forInst' body + uint64_t opCount = stats->opCountMap[forInst]; + if (stats->loopMap.count(forInst) > 0) { + for (auto *childForInst : stats->loopMap[forInst]) { + opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, + computeCostMap); + } + } + // Add in additional op instances from slice (if specified in map). + if (computeCostMap != nullptr) { + auto it = computeCostMap->find(forInst); + if (it != computeCostMap->end()) { + opCount += it->second; + } + } + // Override trip count (if specified in map). + uint64_t tripCount = stats->tripCountMap[forInst]; + if (tripCountOverrideMap != nullptr) { + auto it = tripCountOverrideMap->find(forInst); + if (it != tripCountOverrideMap->end()) { + tripCount = it->second; + } + } + // Returns the total number of dynamic instances of operations in loop body. 
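Numerically, the recursion in getComputeCost amounts to tripCount * (own op count + cost of nested loops), with trip counts optionally overridden for sliced loops; a standalone example with made-up counts:

  #include <cassert>
  #include <cstdint>
  #include <optional>

  struct LoopStats { uint64_t ops, tripCount; };

  // Innermost loop cost, with an optional slice trip-count override.
  static uint64_t innerCost(LoopStats inner, std::optional<uint64_t> tripOverride) {
    uint64_t trip = tripOverride.value_or(inner.tripCount);
    return trip * inner.ops;
  }

  int main() {
    // Hypothetical two-deep nest: outer 100 x (2 ops + inner), inner 50 x 3 ops.
    LoopStats outer{2, 100}, inner{3, 50};
    uint64_t unfused = outer.tripCount * (outer.ops + innerCost(inner, {}));
    assert(unfused == 100 * (2 + 150));        // 15200 dynamic op instances
    // Slicing the inner loop to a single iteration (trip-count override = 1).
    uint64_t sliced = outer.tripCount * (outer.ops + innerCost(inner, 1));
    assert(sliced == 100 * (2 + 3));           // 500 dynamic op instances
    return 0;
  }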
+ return tripCount * opCount; +} + +} // end anonymous namespace + +// Builds a map 'tripCountMap' from ForInst to constant trip count for loop +// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'. +// Returns true on success, false otherwise (if a non-constant trip count +// was encountered). +// TODO(andydavis) Make this work with non-unit step loops. +static bool +buildSliceTripCountMap(MemRefAccess *srcAccess, + ComputationSliceState *sliceState, + DenseMap *tripCountMap) { + SmallVector srcLoopIVs; + getLoopIVs(*srcAccess->opInst, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + // Populate map from ForInst -> trip count + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + AffineMap lbMap = sliceState->lbs[i]; + AffineMap ubMap = sliceState->ubs[i]; + if (lbMap == AffineMap::Null() || ubMap == AffineMap::Null()) { + // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. + if (srcLoopIVs[i]->hasConstantLowerBound() && + srcLoopIVs[i]->hasConstantUpperBound()) { + (*tripCountMap)[srcLoopIVs[i]] = + srcLoopIVs[i]->getConstantUpperBound() - + srcLoopIVs[i]->getConstantLowerBound(); + continue; + } + return false; + } + // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'. + // ub_expr - lb_expr + AffineExpr lbExpr(lbMap.getResult(0)); + AffineExpr ubExpr(ubMap.getResult(0)); + auto loopSpanExpr = simplifyAffineExpr( + ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()), + std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols())); + auto cExpr = loopSpanExpr.dyn_cast(); + if (!cExpr) + return false; + (*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue(); + } + return true; +} + +// Returns the maximum loop depth within the source loop nest at which a +// sliced loop bound is detected in 'sliceState'. +static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit, + ComputationSliceState *sliceState) { + unsigned maxSrcPos = 0; + for (unsigned i = 0; i < srcLoopDepthLimit; ++i) { + if (sliceState->lbs[i] != AffineMap::Null() && + sliceState->ubs[i] != AffineMap::Null()) { + maxSrcPos = std::max(maxSrcPos, i); + } + } + return maxSrcPos + 1; +} + +// Returns the minimum loop depth within the destination loop nest at which the +// computation slice can be inserted (based on the destination loop IVs that +// the source slice actually depends on / is a function of). +static unsigned getMinDstLoopDepth(unsigned srcLoopDepth, + ComputationSliceState *sliceState) { + // Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest + // IV, which is used in a sliced loop bound in the src loop nest. + unsigned maxDstLoopDepth = 0; + for (unsigned i = 0; i < srcLoopDepth; ++i) { + if (AffineMap lbMap = sliceState->lbs[i]) { + lbMap.walkExprs([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) { + maxDstLoopDepth = + std::max(maxDstLoopDepth, dimExpr.getPosition() + 1); + } + }); + } + if (AffineMap ubMap = sliceState->ubs[i]) { + ubMap.walkExprs([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) { + maxDstLoopDepth = + std::max(maxDstLoopDepth, dimExpr.getPosition() + 1); + } + }); + } + } + return maxDstLoopDepth; +} + +// Checks the profitability of fusion candidate 'candidate'. Returns true if it +// profitable to fuse the candidate loop nests. Returns false otherwise. +// The profitability model executes the following steps: +// *) Computes the backward computation slice at 'candidate.srcAccess'. 
This +// computation slice of the loop nest surrounding 'candidate.srcAccess' is +// represented by modified src loop bounds in 'sliceState', which are +// functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'. +// *) Computes the cost of unfused src/dst loop nests (currently the cost of a +// loop nest is the total number of dynamic operation instances in the loop +// nest). +// *) Computes the cost of fusing a slice of the src loop nest into the dst +// loop nest at various values of src/dst loop depth, attempting to fuse +// the biggest compution slice (max src loop depth) at the maximal dst loop +// depth (closest to the load) to minimize reuse distance and opportunity for +// subsequent load/store forwarding. +// NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest +// at which we slice the loops bounds (all src loops below this depth will +// utilize full loop bounds). +// NOTE: 'dstLoopDepth' refers the loop depth within the destination loop +// nest, at which the src computation slice is inserted/fused. +// NOTE: We attempt to maximize the source loop depth, but there are cases +// where a particular setting for 'dstLoopNest' might fused an unsliced +// loop (within the src computation slice) at a depth which results in +// execessive recomputation (see unit tests for examples). +// *) Compares the total cost of the unfused loop nests to the min cost fused +// loop nest computed in the previous step, and returns true if the latter +// is lower. +static bool isFusionProfitable(FusionCandidate *candidate, + ComputationSliceState *sliceState, + unsigned *srcLoopDepth, unsigned *dstLoopDepth) { + // Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc. + if (!mlir::getBackwardComputationSliceState( + candidate->srcAccess, candidate->dstAccess, sliceState)) { + return false; + } + + // Build trip count map for src loops with sliced loop bounds in 'sliceState'. + DenseMap sliceTripCountMap; + if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState, + &sliceTripCountMap)) + return false; + + // Compute cost of sliced and unsliced src loop nest. + SmallVector srcLoopIVs; + getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + + // Walk src loop nest and collect stats. + LoopNestStats srcLoopNestStats; + LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); + srcStatsCollector.walk(srcLoopIVs[0]); + // Currently only constant trip count loop nests are supported. + if (srcStatsCollector.hasLoopWithNonConstTripCount) + return false; + + // Compute cost of dst loop nest. + SmallVector dstLoopIVs; + getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs); + unsigned numDstLoopIVs = dstLoopIVs.size(); + + LoopNestStats dstLoopNestStats; + LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); + dstStatsCollector.walk(dstLoopIVs[0]); + // Currently only constant trip count loop nests are supported. + if (dstStatsCollector.hasLoopWithNonConstTripCount) + return false; + + // Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'. + // This search is O(n^2) where 'n' is very small (eg. six). + // TODO(andydavis) Consider a solution where we just iteration through + // dstLoopDepth possibilities and project out IVs we do not need (remove + // dependence on 'srcLoopDepth'. 
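The depth search that follows can be modeled in isolation: evaluate the fused cost for each (srcLoopDepth, dstLoopDepth) pair, keep the minimum, and fuse only if it does not exceed the sum of the unfused costs (plain C++, arbitrary cost numbers in place of the real model):

  #include <cstdint>
  #include <cstdio>
  #include <limits>

  // Hypothetical cost of fusing a slice taken at src depth i into dst depth j.
  static uint64_t fusedCost(unsigned i, unsigned j) {
    return 1000 / (i + 1) + 200 * (4 - j) + 300; // made-up model
  }

  int main() {
    const unsigned maxSrcLoopDepth = 2, numDstLoopIVs = 3, minDstLoopDepth = 1;
    const uint64_t srcCost = 600, dstCost = 700; // unfused nest costs
    uint64_t best = std::numeric_limits<uint64_t>::max();
    unsigned bestSrc = 0, bestDst = 0;
    for (unsigned i = maxSrcLoopDepth; i >= 1; --i) {
      for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) {
        uint64_t cost = fusedCost(i, j);
        if (cost < best) { best = cost; bestSrc = i; bestDst = j; }
      }
    }
    if (best > srcCost + dstCost)
      std::puts("not profitable: keep the nests unfused");
    else
      std::printf("fuse at srcDepth=%u dstDepth=%u, cost=%llu\n", bestSrc,
                  bestDst, (unsigned long long)best);
    return 0;
  }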
+ DenseMap tripCountMap; + DenseMap computeCostMap; + unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState); + unsigned minFusedLoopNestComputeCost = std::numeric_limits::max(); + unsigned bestSrcLoopDepth; + unsigned bestDstLoopDepth; + for (unsigned i = maxSrcLoopDepth; i >= 1; --i) { + // Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds. + unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState); + assert(minDstLoopDepth <= numDstLoopIVs); + if (minDstLoopDepth == 0) { + // TODO(andydavis) Support inserting computation slices at top-level. + continue; + } + // Copy elements from slice trip count map up to src loop depth 'i'. + tripCountMap.clear(); + for (unsigned k = 0; k < i; ++k) { + auto *forInst = srcLoopIVs[k]; + auto it = sliceTripCountMap.find(forInst); + if (it != sliceTripCountMap.end()) { + tripCountMap[forInst] = it->second; + } + } + // Compute op instance count for the src loop nest with iteration slicing. + uint64_t sliceComputeCost = + getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap, + /*computeCostMap=*/nullptr); + + for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) { + // Compute cost of fusion for these values of 'i' and 'j'. + computeCostMap.clear(); + computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost; + uint64_t fusedLoopNestComputeCost = + getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, &computeCostMap); + if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + bestSrcLoopDepth = i; + bestDstLoopDepth = j; + } + } + } + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); + // Compute op instance count for the src loop nest. + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); + + LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics " + << " bestSrcLoopDepth: " << bestSrcLoopDepth + << " bestDstLoopDepth: " << bestDstLoopDepth + << " srcLoopNestCost: " << srcLoopNestCost + << " dstLoopNestCost: " << dstLoopNestCost + << " minFusedLoopNestComputeCost: " + << minFusedLoopNestComputeCost << "\n"); + + // Do not fuse if fused loop would increase the total cost of the computation. + // TODO(andydavis) Use locality/reduction in slice memref size/opportunity + // for load/store forwarding in cost model. + if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) + return false; + // Set src/dstLoopDepth based on best values from search. + *srcLoopDepth = bestSrcLoopDepth; + *dstLoopDepth = bestDstLoopDepth; + // Update 'sliceState' bounds based on computed 'srcLoopDepth': + // *) Canonicalize affine map now that 'srcLoopDepth' has been chosen. 
+ // *) Replace slice bound maps at depth > 'srcLoopDepth' withAffineMap::Null() + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + if (i < bestSrcLoopDepth) { + if (sliceState->lbs[i] != AffineMap::Null()) { + canonicalizeMapAndOperands(sliceState->lbs[i], + sliceState->lbOperands[i]); + } + if (sliceState->ubs[i] != AffineMap::Null()) { + canonicalizeMapAndOperands(sliceState->ubs[i], + sliceState->ubOperands[i]); + } + } else { + sliceState->lbs[i] = AffineMap::Null(); + sliceState->ubs[i] = AffineMap::Null(); + } + } + return true; +} + // GreedyFusion greedily fuses loop nests which have a producer/consumer // relationship on a memref, with the goal of improving locality. Currently, // this the producer/consumer relationship is required to be unique in the @@ -479,16 +809,17 @@ public: // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'. FusionCandidate candidate = buildFusionCandidate(srcStoreOpInst, dstLoadOpInst); + // Check if fusion would be profitable. + unsigned srcLoopDepth; + unsigned dstLoopDepth; + mlir::ComputationSliceState sliceState; + if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth, + &dstLoopDepth)) + continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0 - ? clSrcLoopDepth - : getNestingDepth(*srcStoreOpInst); - unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0 - ? clDstLoopDepth - : getNestingDepth(*dstLoadOpInst); auto *sliceLoopNest = mlir::insertBackwardComputationSlice( - &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth, - dstLoopDepth); + &candidate.srcAccess, &candidate.dstAccess, &sliceState, + srcLoopDepth, dstLoopDepth); if (sliceLoopNest != nullptr) { // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode' mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 57172b1cf9a..525c9d63ad0 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,5 +1,4 @@ // RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s -// RUN: mlir-opt %s -loop-fusion -fusion-src-loop-depth=1 -fusion-dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1 // TODO(andydavis) Add more tests: // *) Add nested fusion test cases when non-constant loop bound support is @@ -76,8 +75,7 @@ func @should_fuse_reduction_to_pointwise() { // ----- -// CHECK: [[MAP_SHIFT_MINUS_ONE_D0:#map[0-9]+]] = (d0, d1) -> (d0 - 1) -// CHECK: [[MAP_SHIFT_MINUS_ONE_D1:#map[0-9]+]] = (d0, d1) -> (d1 - 1) +// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) // CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { @@ -99,8 +97,8 @@ func @should_fuse_loop_nests_with_shifts() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_D0]](%i0, %i1) - // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_D1]](%i0, %i1) + // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) + // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1) // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) // CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32> // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32> @@ -112,8 +110,7 @@ func @should_fuse_loop_nests_with_shifts() { // ----- -// CHECK-DAG: [[MAP_DIM0:#map[0-9]+]] = 
(d0, d1) -> (d0) -// CHECK-DAG: [[MAP_DIM1:#map[0-9]+]] = (d0, d1) -> (d1) +// CHECK-DAG: [[MAP_ID:#map[0-9]+]] = (d0) -> (d0) // CHECK-LABEL: func @should_fuse_loop_nest() { func @should_fuse_loop_nest() { @@ -140,11 +137,11 @@ func @should_fuse_loop_nest() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP_DIM1]](%i0, %i1) - // CHECK-NEXT: %3 = affine_apply [[MAP_DIM0]](%i0, %i1) + // CHECK-NEXT: %2 = affine_apply [[MAP_ID]](%i1) + // CHECK-NEXT: %3 = affine_apply [[MAP_ID]](%i0) // CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32> - // CHECK-NEXT: %4 = affine_apply [[MAP_DIM0]](%i0, %i1) - // CHECK-NEXT: %5 = affine_apply [[MAP_DIM1]](%i0, %i1) + // CHECK-NEXT: %4 = affine_apply [[MAP_ID]](%i0) + // CHECK-NEXT: %5 = affine_apply [[MAP_ID]](%i1) // CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32> // CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32> // CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32> @@ -511,10 +508,8 @@ func @should_not_fuse_if_inst_in_loop_nest() { // ----- -// CHECK-DAG: [[MAP_D0:#map[0-9]+]] = (d0, d1, d2) -> (d0) -// CHECK-DAG: [[MAP_D1:#map[0-9]+]] = (d0, d1, d2) -> (d1) -// CHECK-DAG: [[MAP_D2:#map[0-9]+]] = (d0, d1, d2) -> (d2) -// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) +// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) // CHECK: [[MAP_PERMUTE:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) // CHECK-LABEL: func @remap_ivs() { @@ -541,10 +536,10 @@ func @remap_ivs() { // CHECK: for %i0 = 0 to 30 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: for %i2 = 0 to 20 { -// CHECK-NEXT: %1 = affine_apply [[MAP_D1]](%i0, %i1, %i2) -// CHECK-NEXT: %2 = affine_apply [[MAP_D2]](%i0, %i1, %i2) -// CHECK-NEXT: %3 = affine_apply [[MAP_D0]](%i0, %i1, %i2) -// CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%1, %2, %3) +// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1) +// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2) +// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) +// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3) // CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32> // CHECK-NEXT: %5 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2) // CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32> @@ -558,51 +553,6 @@ func @remap_ivs() { // ----- -// DEPTH1: #map0 = (d0) -> (d0) -// DEPTH1: #map1 = (d0, d1, d2) -> (d0, d1, d2) - -// DEPTH1-LABEL: func @fuse_slice_at_depth1() { -func @fuse_slice_at_depth1() { - %m = alloc() : memref<100x16x100xf32> - - %cf7 = constant 7.0 : f32 - for %i0 = 0 to 100 { - for %i1 = 0 to 16 { - for %i2 = 0 to 100 { - %a0 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i0, %i1, %i2) - store %cf7, %m[%a0#0, %a0#1, %a0#2] : memref<100x16x100xf32> - } - } - } - for %i3 = 0 to 100 { - for %i4 = 0 to 16 { - for %i5 = 0 to 100 { - %a1 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i3, %i4, %i5) - %v0 = load %m[%a1#0, %a1#1, %a1#2] : memref<100x16x100xf32> - } - } - } -// DEPTH1: for %i0 = 0 to 100 { -// DEPTH1-NEXT: %1 = affine_apply #map0(%i0) -// DEPTH1-NEXT: for %i1 = 0 to 16 { -// DEPTH1-NEXT: for %i2 = 0 to 100 { -// DEPTH1-NEXT: %2 = affine_apply #map1(%1, %i1, %i2) -// DEPTH1-NEXT: store %cst, %0[%2#0, %2#1, %2#2] : memref<100x16x100xf32> -// DEPTH1-NEXT: } -// DEPTH1-NEXT: } -// DEPTH1-NEXT: for %i3 = 0 to 16 { -// DEPTH1-NEXT: for %i4 = 0 to 100 { -// DEPTH1-NEXT: %3 = affine_apply #map1(%i0, %i3, %i4) -// DEPTH1-NEXT: %4 = load %0[%3#0, %3#1, %3#2] : memref<100x16x100xf32> -// 
DEPTH1-NEXT: } -// DEPTH1-NEXT: } -// DEPTH1-NEXT: } -// DEPTH1-NEXT: return - return -} - -// ----- - // CHECK-DAG: #map0 = (d0, d1) -> (d0 * 4 + d1) // CHECK-DAG: #map1 = (d0) -> (d0 floordiv 4, d0 mod 4) @@ -732,10 +682,10 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %6 = affine_apply #map4(%i0, %i1) // CHECK-NEXT: %7 = "foo"(%2, %3, %4, %5, %6, %c0) : (index, index, index, index, index, index) -> i32 // CHECK-NEXT: store %7, %0[%2, %3, %4, %5, %6, %c0] : memref<2x2x3x3x16x1xi32> -// CHECK-NEXT: %8 = affine_apply #map5(%i0, %i1) -// CHECK-NEXT: %9 = affine_apply #map6(%i0, %i1) -// CHECK-NEXT: %10 = affine_apply #map7(%8, %9) -// CHECK-NEXT: %11 = affine_apply #map8(%10) +// CHECK-NEXT: %8 = affine_apply #map5(%i0) +// CHECK-NEXT: %9 = affine_apply #map5(%i1) +// CHECK-NEXT: %10 = affine_apply #map6(%8, %9) +// CHECK-NEXT: %11 = affine_apply #map7(%10) // CHECK-NEXT: %12 = load %0[%11#0, %11#1, %11#3, %11#4, %11#2, %11#5] : memref<2x2x3x3x16x1xi32> // CHECK-NEXT: store %12, %1[%8, %9] : memref<64x9xi32> // CHECK-NEXT: %13 = load %1[%i0, %i1] : memref<64x9xi32> @@ -769,3 +719,261 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { return } + +// ----- +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @should_fuse_reduction_at_depth1 +func @should_fuse_reduction_at_depth1() { + %a = alloc() : memref<10x100xf32> + %b = alloc() : memref<10xf32> + + for %i0 = 0 to 10 { + for %i1 = 0 to 100 { + %v0 = load %b[%i0] : memref<10xf32> + %v1 = load %a[%i0, %i1] : memref<10x100xf32> + %v2 = "maxf"(%v0, %v1) : (f32, f32) -> f32 + store %v2, %b[%i0] : memref<10xf32> + } + } + for %i2 = 0 to 10 { + for %i3 = 0 to 100 { + %v3 = load %b[%i2] : memref<10xf32> + %v4 = load %a[%i2, %i3] : memref<10x100xf32> + %v5 = subf %v4, %v3 : f32 + store %v5, %b[%i2] : memref<10xf32> + } + } + // This test should fuse the src reduction loop at depth 1 in the destination + // loop nest, which improves locality and enables subsequence passes to + // decrease the reduction memref size and possibly place it in a faster + // memory space. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply #map0(%i0) + // CHECK-NEXT: for %i1 = 0 to 100 { + // CHECK-NEXT: %3 = load %1[%2] : memref<10xf32> + // CHECK-NEXT: %4 = load %0[%2, %i1] : memref<10x100xf32> + // CHECK-NEXT: %5 = "maxf"(%3, %4) : (f32, f32) -> f32 + // CHECK-NEXT: store %5, %1[%2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 100 { + // CHECK-NEXT: %6 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: %7 = load %0[%i0, %i2] : memref<10x100xf32> + // CHECK-NEXT: %8 = subf %7, %6 : f32 + // CHECK-NEXT: store %8, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @should_fuse_at_src_depth1_and_dst_depth1 +func @should_fuse_at_src_depth1_and_dst_depth1() { + %a = alloc() : memref<100x16xf32> + %b = alloc() : memref<100x16xf32> + + for %i0 = 0 to 100 { + for %i1 = 0 to 16 { + %v0 = load %a[%i0, %i1] : memref<100x16xf32> + "op0"(%v0) : (f32) -> () + } + for %i2 = 0 to 16 { + %v1 = "op1"() : () -> (f32) + store %v1, %b[%i0, %i2] : memref<100x16xf32> + } + } + + for %i3 = 0 to 100 { + for %i4 = 0 to 16 { + %v2 = load %b[%i3, %i4] : memref<100x16xf32> + "op2"(%v2) : (f32) -> () + } + } + // We can slice iterations of the '%i0' and '%i1' loops in the the source + // loop nest, but slicing at depth 2 and inserting the slice in the + // destination loop nest at depth2 causes extra computation. 
Instead, + // the fusion algorithm should detect that the source loop should be sliced + // at depth 1 and the slice should be inserted at depth 1. + // CHECK: for %i0 = 0 to 100 { + // CHECK-NEXT: %2 = affine_apply #map0(%i0) + // CHECK-NEXT: for %i1 = 0 to 16 { + // CHECK-NEXT: %3 = load %0[%2, %i1] : memref<100x16xf32> + // CHECK-NEXT: "op0"(%3) : (f32) -> () + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: %4 = "op1"() : () -> f32 + // CHECK-NEXT: store %4, %1[%2, %i2] : memref<100x16xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: %5 = load %1[%i0, %i3] : memref<100x16xf32> + // CHECK-NEXT: "op2"(%5) : (f32) -> () + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- +// CHECK: #map0 = (d0, d1) -> (d0 * 10 + d1) + +// CHECK-LABEL: func @should_fuse_src_depth1_at_dst_depth2 +func @should_fuse_src_depth1_at_dst_depth2() { + %a = alloc() : memref<100xf32> + %c0 = constant 0.0 : f32 + + for %i0 = 0 to 100 { + store %c0, %a[%i0] : memref<100xf32> + } + + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { + %a0 = affine_apply (d0, d1) -> (d0 * 10 + d1) (%i1, %i2) + %v0 = load %a[%a0] : memref<100xf32> + } + } + // The source loop nest slice loop bound is a function of both destination + // loop IVs, so we should slice at depth 1 and insert the slice at depth 2. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) + // CHECK-NEXT: store %cst, %0[%1] : memref<100xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i1) + // CHECK-NEXT: %3 = load %0[%2] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @fusion_at_depth0_not_currently_supported +func @fusion_at_depth0_not_currently_supported() { + %0 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cst = constant 0.000000e+00 : f32 + for %i0 = 0 to 10 { + store %cst, %0[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %1 = load %0[%c0] : memref<10xf32> + } + // CHECK:for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%c0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @should_fuse_deep_loop_nests +func @should_fuse_deep_loop_nests() { + %0 = alloc() : memref<2x2x3x3x16x10xf32, 2> + %1 = alloc() : memref<2x2x3x3x16x10xf32, 2> + %2 = alloc() : memref<3x3x3x3x16x10xf32, 2> + %c0 = constant 0 : index + %c1 = constant 1 : index + %c1_0 = constant 1 : index + %cst = constant 0.000000e+00 : f32 + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + for %i4 = 0 to 16 { + for %i5 = 0 to 10 { + %3 = load %0[%i0, %i1, %i2, %i3, %i4, %i5] + : memref<2x2x3x3x16x10xf32, 2> + } + } + for %i6 = 0 to 16 { + for %i7 = 0 to 10 { + store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] + : memref<2x2x3x3x16x10xf32, 2> + } + } + } + } + } + } + for %i8 = 0 to 3 { + for %i9 = 0 to 3 { + for %i10 = 0 to 2 { + for %i11 = 0 to 2 { + for %i12 = 0 to 3 { + for %i13 = 0 to 3 { + for %i14 = 0 to 2 { + for %i15 = 0 to 2 { + for %i16 = 0 to 16 { + for %i17 = 0 to 10 { + %5 = load %0[%i14, %i15, %i12, %i13, %i16, %i17] + : memref<2x2x3x3x16x10xf32, 2> + } + } + for %i18 = 0 to 16 { + for %i19 = 0 to 10 { + %6 = load %1[%i10, %i11, %i8, %i9, %i18, %i19] + : memref<2x2x3x3x16x10xf32, 2> + } + } + } + } + } + } + } + } + } + } +// The 
first four loops of the source loop nest can be sliced with iteration +// bounds which are a function of the first four loops of destination loop nest, +// where the destination loops nests have been interchanged. +// CHECK: for %i0 = 0 to 3 { +// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK-NEXT: for %i2 = 0 to 2 { +// CHECK-NEXT: for %i3 = 0 to 2 { +// CHECK-NEXT: %3 = affine_apply #map0(%i2) +// CHECK-NEXT: %4 = affine_apply #map0(%i3) +// CHECK-NEXT: %5 = affine_apply #map0(%i0) +// CHECK-NEXT: %6 = affine_apply #map0(%i1) +// CHECK-NEXT: for %i4 = 0 to 16 { +// CHECK-NEXT: for %i5 = 0 to 10 { +// CHECK-NEXT: %7 = load %0[%3, %4, %5, %6, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: for %i6 = 0 to 16 { +// CHECK-NEXT: for %i7 = 0 to 10 { +// CHECK-NEXT: store %cst, %1[%3, %4, %5, %6, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: for %i8 = 0 to 3 { +// CHECK-NEXT: for %i9 = 0 to 3 { +// CHECK-NEXT: for %i10 = 0 to 2 { +// CHECK-NEXT: for %i11 = 0 to 2 { +// CHECK-NEXT: for %i12 = 0 to 16 { +// CHECK-NEXT: for %i13 = 0 to 10 { +// CHECK-NEXT: %8 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: for %i14 = 0 to 16 { +// CHECK-NEXT: for %i15 = 0 to 10 { +// CHECK-NEXT: %9 = load %1[%i2, %i3, %i0, %i1, %i14, %i15] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return + return +} -- cgit v1.2.3 From 03e15e1b9f84a7bed35ff4065ad7135b571d2d65 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 15 Jan 2019 14:41:56 -0800 Subject: Minor code cleanup - NFC. 
- readability changes PiperOrigin-RevId: 229443430 --- mlir/include/mlir/IR/BuiltinOps.h | 4 +-- mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 14 +++++---- mlir/lib/IR/BuiltinOps.cpp | 45 ++++++++++++++-------------- mlir/lib/Transforms/DmaGeneration.cpp | 7 +---- mlir/lib/Transforms/LoopFusion.cpp | 8 ++--- mlir/lib/Transforms/PipelineDataTransfer.cpp | 10 +++---- mlir/lib/Transforms/Utils/Utils.cpp | 12 +++----- 8 files changed, 48 insertions(+), 54 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index b182d1f709a..d9ed7cfd9c4 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -399,8 +399,8 @@ bool parseDimAndSymbolList(OpAsmParser *parser, SmallVector &operands, unsigned &numDims); -void canonicalizeMapAndOperands(AffineMap &map, - llvm::SmallVectorImpl &operands); +void canonicalizeMapAndOperands(AffineMap *map, + llvm::SmallVectorImpl *operands); } // end namespace mlir diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index cc0071d6b5d..19283b319b6 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -1535,7 +1535,7 @@ static void composeAffineMapAndOperands(AffineMap *map, AffineNormalizer normalizer(*map, *operands); auto normalizedMap = normalizer.getAffineMap(); auto normalizedOperands = normalizer.getOperands(); - canonicalizeMapAndOperands(normalizedMap, normalizedOperands); + canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); *map = normalizedMap; *operands = normalizedOperands; assert(*map); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 12ac0cc44ec..49e1e31f55d 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -112,22 +112,22 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( // (dma_start, dma_wait). bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, MemRefRegion *region) { - OpPointer loadOp; - OpPointer storeOp; unsigned rank; SmallVector indices; - - if ((loadOp = opInst->dyn_cast())) { + if (auto loadOp = opInst->dyn_cast()) { rank = loadOp->getMemRefType().getRank(); + indices.reserve(rank); indices.append(loadOp->getIndices().begin(), loadOp->getIndices().end()); region->memref = loadOp->getMemRef(); region->setWrite(false); - } else if ((storeOp = opInst->dyn_cast())) { + } else if (auto storeOp = opInst->dyn_cast()) { rank = storeOp->getMemRefType().getRank(); + indices.reserve(rank); indices.append(storeOp->getIndices().begin(), storeOp->getIndices().end()); region->memref = storeOp->getMemRef(); region->setWrite(true); } else { + assert(false && "expected load or store op"); return false; } @@ -191,6 +191,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // this memref region is symbolic. 
SmallVector outerIVs; getLoopIVs(*opInst, &outerIVs); + assert(loopDepth <= outerIVs.size() && "invalid loop depth"); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { ForInst *iv; @@ -249,12 +250,13 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, static_assert( std::is_same>::value || std::is_same>::value, - "function argument should be either a LoadOp or a StoreOp"); + "argument should be either a LoadOp or a StoreOp"); OperationInst *opInst = loadOrStoreOp->getInstruction(); MemRefRegion region; if (!getMemRefRegion(opInst, /*loopDepth=*/0, ®ion)) return false; + LLVM_DEBUG(llvm::dbgs() << "Memory region"); LLVM_DEBUG(region.getConstraints()->dump()); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 94fa58139af..da570f4b805 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -198,61 +198,62 @@ struct SimplifyAffineApplyState : public PatternState { } // end anonymous namespace. void mlir::canonicalizeMapAndOperands( - AffineMap &map, llvm::SmallVectorImpl &operands) { - if (!map || operands.empty()) + AffineMap *map, llvm::SmallVectorImpl *operands) { + if (!map || operands->empty()) return; - assert(map.getNumInputs() == operands.size() && + assert(map->getNumInputs() == operands->size() && "map inputs must match number of operands"); // Check to see what dims are used. - llvm::SmallBitVector usedDims(map.getNumDims()); - llvm::SmallBitVector usedSyms(map.getNumSymbols()); - map.walkExprs([&](AffineExpr expr) { + llvm::SmallBitVector usedDims(map->getNumDims()); + llvm::SmallBitVector usedSyms(map->getNumSymbols()); + map->walkExprs([&](AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) usedDims[dimExpr.getPosition()] = true; else if (auto symExpr = expr.dyn_cast()) usedSyms[symExpr.getPosition()] = true; }); - auto *context = map.getContext(); + auto *context = map->getContext(); SmallVector resultOperands; - resultOperands.reserve(operands.size()); + resultOperands.reserve(operands->size()); llvm::SmallDenseMap seenDims; - SmallVector dimRemapping(map.getNumDims()); + SmallVector dimRemapping(map->getNumDims()); unsigned nextDim = 0; - for (unsigned i = 0, e = map.getNumDims(); i != e; ++i) { + for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { if (usedDims[i]) { - auto it = seenDims.find(operands[i]); + auto it = seenDims.find((*operands)[i]); if (it == seenDims.end()) { dimRemapping[i] = getAffineDimExpr(nextDim++, context); - resultOperands.push_back(operands[i]); - seenDims.insert(std::make_pair(operands[i], dimRemapping[i])); + resultOperands.push_back((*operands)[i]); + seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); } else { dimRemapping[i] = it->second; } } } llvm::SmallDenseMap seenSymbols; - SmallVector symRemapping(map.getNumSymbols()); + SmallVector symRemapping(map->getNumSymbols()); unsigned nextSym = 0; - for (unsigned i = 0, e = map.getNumSymbols(); i != e; ++i) { + for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { if (usedSyms[i]) { - auto it = seenSymbols.find(operands[i + map.getNumDims()]); + auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); if (it == seenSymbols.end()) { symRemapping[i] = getAffineSymbolExpr(nextSym++, context); - resultOperands.push_back(operands[i + map.getNumDims()]); - seenSymbols.insert( - std::make_pair(operands[i + map.getNumDims()], symRemapping[i])); + resultOperands.push_back((*operands)[i + map->getNumDims()]); + seenSymbols.insert(std::make_pair((*operands)[i + 
map->getNumDims()], + symRemapping[i])); } else { symRemapping[i] = it->second; } } } - map = map.replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); - operands = resultOperands; + *map = + map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); + *operands = resultOperands; } PatternMatchResult SimplifyAffineApply::match(OperationInst *op) const { @@ -262,7 +263,7 @@ PatternMatchResult SimplifyAffineApply::match(OperationInst *op) const { AffineMap oldMap = map; SmallVector resultOperands(apply->getOperands().begin(), apply->getOperands().end()); - canonicalizeMapAndOperands(map, resultOperands); + canonicalizeMapAndOperands(&map, &resultOperands); if (map != oldMap) return matchSuccess( std::make_unique(map, resultOperands)); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index e60f3531b62..df4aa84b039 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -223,13 +223,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, // on; this would correspond to loop IVs surrounding the level at which the // DMA generation is being done. const FlatAffineConstraints *cst = region.getConstraints(); - auto ids = cst->getIds(); SmallVector outerIVs; - for (unsigned i = rank, e = ids.size(); i < e; i++) { - auto id = cst->getIds()[i]; - assert(id.hasValue() && "Value id expected"); - outerIVs.push_back(id.getValue()); - } + cst->getIdValues(rank, cst->getNumIds(), &outerIVs); // Construct the index expressions for the fast memory buffer. The index // expression for a particular dimension of the fast buffer is obtained by diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index dffa292af3c..c097473de3f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -694,12 +694,12 @@ static bool isFusionProfitable(FusionCandidate *candidate, for (unsigned i = 0; i < numSrcLoopIVs; ++i) { if (i < bestSrcLoopDepth) { if (sliceState->lbs[i] != AffineMap::Null()) { - canonicalizeMapAndOperands(sliceState->lbs[i], - sliceState->lbOperands[i]); + canonicalizeMapAndOperands(&sliceState->lbs[i], + &sliceState->lbOperands[i]); } if (sliceState->ubs[i] != AffineMap::Null()) { - canonicalizeMapAndOperands(sliceState->ubs[i], - sliceState->ubOperands[i]); + canonicalizeMapAndOperands(&sliceState->ubs[i], + &sliceState->ubOperands[i]); } } else { sliceState->lbs[i] = AffineMap::Null(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 495c9c181fd..989af0071d7 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -87,12 +87,12 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Doubles the shape with a leading dimension extent of 2. auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { // Add the leading dimension in the shape for the double buffer. 
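    // Editorial sketch (not part of this change): for an incoming
    // memref<256x32xf32>, the double-buffered type becomes
    // memref<2x256x32xf32>, i.e. newShape = {2, 256, 32}, with the element
    // type and memory space preserved.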
- ArrayRef shape = oldMemRefType.getShape(); - SmallVector shapeSizes(shape.begin(), shape.end()); - shapeSizes.insert(shapeSizes.begin(), 2); - + ArrayRef oldShape = oldMemRefType.getShape(); + SmallVector newShape(1 + oldMemRefType.getRank()); + newShape[0] = 2; + std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); auto newMemRefType = - bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {}, + bInner.getMemRefType(newShape, oldMemRefType.getElementType(), {}, oldMemRefType.getMemorySpace()); return newMemRefType; }; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index f85847ff066..4f4aeabb26d 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -123,9 +123,6 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, FuncBuilder builder(opInst); for (auto *extraIndex : extraIndices) { - // TODO(mlir-team): An operation/SSA value should provide a method to - // return the position of an SSA result in its defining - // operation. assert(extraIndex->getDefiningInst()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) && @@ -137,7 +134,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. SmallVector remapOperands; - remapOperands.reserve(oldMemRefRank + extraOperands.size()); + remapOperands.reserve(extraOperands.size() + oldMemRefRank); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(opInst->operand_begin() + memRefOperandPos + 1, opInst->operand_begin() + memRefOperandPos + 1 + @@ -146,12 +143,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, auto remapOp = builder.create(opInst->getLoc(), indexRemap, remapOperands); // Remapped indices. - for (auto *index : remapOp->getInstruction()->getResults()) - state.operands.push_back(index); + state.operands.append(remapOp->getInstruction()->result_begin(), + remapOp->getInstruction()->result_end()); } else { // No remapping specified. - for (auto *index : remapOperands) - state.operands.push_back(index); + state.operands.append(remapOperands.begin(), remapOperands.end()); } // Insert the remaining operands unmodified. -- cgit v1.2.3 From f99a44a7cd96516fe4552b0d9bc464094a4d40ae Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 16 Jan 2019 08:29:16 -0800 Subject: Address documentation/readability related comments from cl/227252907 on memref store forwarding - NFC. 
PiperOrigin-RevId: 229561933 --- mlir/lib/Transforms/LoopFusion.cpp | 8 +------- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c097473de3f..91e8d2946a6 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -80,11 +80,6 @@ struct FusionCandidate { : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {} }; -static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst, - OperationInst *dstLoadOpInst) { - return FusionCandidate(srcStoreOpInst, dstLoadOpInst); -} - namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -807,8 +802,7 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'. - FusionCandidate candidate = - buildFusionCandidate(srcStoreOpInst, dstLoadOpInst); + FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst); // Check if fusion would be profitable. unsigned srcLoopDepth; unsigned dstLoopDepth; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 1a1502de738..4191a9cc279 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -41,9 +41,8 @@ namespace { // The store to load forwarding relies on three conditions: // // 1) there has to be a dependence from the store to the load satisfied at the -// block immediately within the innermost common surrounding loop of the load op -// and the store op, and such a dependence should associate with a single load -// location for a given source store iteration. +// block* immediately within the innermost loop enclosing both the load op and +// the store op, // // 2) the store op should dominate the load op, // @@ -52,10 +51,18 @@ namespace { // provably the last writer to the particular memref location being loaded from // by the load op, and its store value can be forwarded to the load. // +// 4) the load should touch a single location in the memref for a given +// iteration of the innermost loop enclosing both the store op and the load op. +// +// (* A dependence being satisfied at a block: a dependence that is satisfied by +// virtue of the destination instruction appearing textually / lexically after +// the source instruction within the body of a 'for' instruction; thus, a +// dependence is always either satisfied by a loop or by a block). +// // The above conditions are simple to check, sufficient, and powerful for most // cases in practice - condition (1) and (3) are precise and necessary, while // condition (2) is a sufficient one but not necessary (since it doesn't reason -// about loops that are guaranteed to execute at least one). +// about loops that are guaranteed to execute at least once). // // TODO(mlir-team): more forwarding can be done when support for // loop/conditional live-out SSA values is available. @@ -126,7 +133,7 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // conditions listed at the top. SmallVector fwdingCandidates; // Store ops that have a dependence into the load (even if they aren't - // forwarding candidates). Each fwding candidate will be checked for a + // forwarding candidates). Each forwarding candidate will be checked for a // post-dominance on these. 
'fwdingCandidates' are a subset of depSrcStores. SmallVector depSrcStores; for (auto *storeOpInst : storeOps) { @@ -243,7 +250,7 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { OperationInst *defInst = memref->getDefiningInst(); if (!defInst || !defInst->isa()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we - // could still erase it if the call has no side-effects. + // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), [&](InstOperand &use) { -- cgit v1.2.3 From 27d067e16451da80f6b53dc90740a2238e3f4ee7 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 16 Jan 2019 09:55:02 -0800 Subject: LoopFusion improvements: *) Adds support for fusing into consumer loop nests with multiple loads from the same memref. *) Adds support for reducing slice loop trip count by projecting out destination loop IVs greater than destination loop depth. *) Removes dependence on src loop depth and simplifies cost model computation. PiperOrigin-RevId: 229575126 --- mlir/include/mlir/Analysis/Utils.h | 18 +- mlir/lib/Analysis/AffineStructures.cpp | 17 +- mlir/lib/Analysis/Utils.cpp | 56 +++-- mlir/lib/Transforms/LoopFusion.cpp | 408 ++++++++++++++++++++------------- mlir/test/Transforms/loop-fusion.mlir | 150 +++++++++++- 5 files changed, 439 insertions(+), 210 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 7cd30ba86ab..4e304067411 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -173,23 +173,23 @@ struct ComputationSliceState { /// Returns true on success, false otherwise. bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned dstLoopDepth, ComputationSliceState *sliceState); /// Creates a clone of the computation contained in the loop nest surrounding -/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop -/// IVs, and inserts the computation slice at the beginning of the instruction -/// block of the loop at 'dstLoopDepth' in the loop nest surrounding -/// 'dstAccess'. Returns the top-level loop of the computation slice on +/// 'srcOpInst', slices the iteration space of src loop based on slice bounds +/// in 'sliceState', and inserts the computation slice at the beginning of the +/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding +/// 'dstOpInst'. Returns the top-level loop of the computation slice on /// success, returns nullptr otherwise. // Loop depth is a crucial optimization choice that determines where to // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. 
-ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess, - MemRefAccess *dstAccess, - ComputationSliceState *sliceState, - unsigned srcLoopDepth, - unsigned dstLoopDepth); +ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst, + OperationInst *dstOpInst, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index bf915dbbf5b..af9252c279c 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1101,8 +1101,21 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); } else { - (*lbMaps)[pos] = AffineMap::Null(); - (*ubMaps)[pos] = AffineMap::Null(); + // TODO(andydavis, bondhugula) Add support for computing slice bounds + // symbolic in the identifies [num, numIds). + auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + (*lbMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(lbConst.getValue(), context), {}); + (*ubMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(ubConst.getValue() + 1, context), {}); + } else { + (*lbMaps)[pos] = AffineMap::Null(); + (*ubMaps)[pos] = AffineMap::Null(); + } } LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); LLVM_DEBUG(expr.dump();); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 49e1e31f55d..c003a641311 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -346,12 +346,13 @@ static Instruction *getInstAtPosition(ArrayRef positions, return nullptr; } -// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the -// dependence constraint system to create AffineMaps with which to adjust the -// loop bounds of the inserted compution slice so that they are functions of the -// loop IVs and symbols of the loops surrounding 'dstAccess'. +// Computes memref dependence between 'srcAccess' and 'dstAccess', projects +// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice +// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs, +// symbols and constants. bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned dstLoopDepth, ComputationSliceState *sliceState) { FlatAffineConstraints dependenceConstraints; if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1, @@ -364,6 +365,19 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, getLoopIVs(*srcAccess.opInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); + // Get loop nest surrounding dst operation. + SmallVector dstLoopIVs; + getLoopIVs(*dstAccess.opInst, &dstLoopIVs); + unsigned numDstLoopIVs = dstLoopIVs.size(); + if (dstLoopDepth > numDstLoopIVs) { + dstAccess.opInst->emitError("invalid destination loop depth"); + return false; + } + + // Project out dimensions other than those up to 'dstLoopDepth'. + dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth, + numDstLoopIVs - dstLoopDepth); + // Set up lower/upper bound affine maps for the slice. 
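  // Editorial example (sketch, not from this change): for a source store to
  // %A[%i0] and a destination load from %A[%d0] with dstLoopDepth == 1, the
  // slice bounds typically come out as lbs[0] = (d0) -> (d0) and
  // ubs[0] = (d0) -> (d0 + 1), pinning the source IV to the destination IV.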
sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null()); sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null()); @@ -385,12 +399,10 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, return true; } -/// Creates a computation slice of the loop nest surrounding 'srcAccess' -/// utilizing slice loop bounds in 'sliceState' (for src loops up to -/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding -/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth -/// greater than 'srcLoopDepth' their full loop bounds will be used in the -/// slice. +/// Creates a computation slice of the loop nest surrounding 'srcOpInst', +/// updates the slice loop bounds with any non-null bound maps specified in +/// 'sliceState', and inserts this slice into the loop nest surrounding +/// 'dstOpInst' at loop depth 'dstLoopDepth'. // TODO(andydavis,bondhugula): extend the slicing utility to compute slices that // aren't necessarily a one-to-one relation b/w the source and destination. The // relation between the source and destination could be many-to-many in general. @@ -401,33 +413,27 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // solution. // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project // out loop IVs we don't care about and produce smaller slice. -ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, - MemRefAccess *dstAccess, - ComputationSliceState *sliceState, - unsigned srcLoopDepth, - unsigned dstLoopDepth) { +ForInst *mlir::insertBackwardComputationSlice( + OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, + ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. SmallVector srcLoopIVs; - getLoopIVs(*srcAccess->opInst, &srcLoopIVs); + getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); - if (srcLoopDepth > numSrcLoopIVs) { - srcAccess->opInst->emitError("invalid source loop depth"); - return nullptr; - } // Get loop nest surrounding dst operation. SmallVector dstLoopIVs; - getLoopIVs(*dstAccess->opInst, &dstLoopIVs); + getLoopIVs(*dstOpInst, &dstLoopIVs); unsigned dstLoopIVsSize = dstLoopIVs.size(); if (dstLoopDepth > dstLoopIVsSize) { - dstAccess->opInst->emitError("invalid destination loop depth"); + dstOpInst->emitError("invalid destination loop depth"); return nullptr; } - // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'. + // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. - findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions); + findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions); // Clone src loop nest and insert it a the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. @@ -451,7 +457,7 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, assert(sliceLoopLimit >= sliceSurroundingLoopsSize); // Update loop bounds for loops in 'sliceLoopNest'. 
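  // Editorial note (not part of this change): each non-null bound map in
  // 'sliceState' is a function of the destination loop IVs and symbols, so
  // the cloned source loops end up iterating only over the portion of the
  // source iteration space required by the enclosing destination iteration.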
- for (unsigned i = 0; i < srcLoopDepth; ++i) { + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; if (AffineMap lbMap = sliceState->lbs[i]) forInst->setLowerBound(sliceState->lbOperands[i], lbMap); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 91e8d2946a6..cdd1c77f302 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -69,17 +69,6 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -// FusionCandidate encapsulates source and destination memref access within -// loop nests which are candidates for loop fusion. -struct FusionCandidate { - // Load or store access within src loop nest to be fused into dst loop nest. - MemRefAccess srcAccess; - // Load or store access within dst loop nest. - MemRefAccess dstAccess; - explicit FusionCandidate(OperationInst *src, OperationInst *dst) - : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {} -}; - namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -172,10 +161,27 @@ public: return &it->second; } + // Returns true iff there is an edge from node 'srcId' to node 'dstId' for + // 'memref'. Returns false otherwise. + bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) { + if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { + return false; + } + bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { + return edge.id == dstId && edge.memref == memref; + }); + bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { + return edge.id == srcId && edge.memref == memref; + }); + return hasOutEdge && hasInEdge; + } + // Adds an edge from node 'srcId' to node 'dstId' for 'memref'. void addEdge(unsigned srcId, unsigned dstId, Value *memref) { - outEdges[srcId].push_back({dstId, memref}); - inEdges[dstId].push_back({srcId, memref}); + if (!hasEdge(srcId, dstId, memref)) { + outEdges[srcId].push_back({dstId, memref}); + inEdges[dstId].push_back({srcId, memref}); + } } // Removes an edge from node 'srcId' to node 'dstId' for 'memref'. @@ -425,10 +431,10 @@ public: // inserting a sliced loop nest of known cost into the loop's body. // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. -static uint64_t -getComputeCost(ForInst *forInst, LoopNestStats *stats, - DenseMap *tripCountOverrideMap, - DenseMap *computeCostMap) { +static uint64_t getComputeCost( + ForInst *forInst, LoopNestStats *stats, + llvm::SmallDenseMap *tripCountOverrideMap, + DenseMap *computeCostMap) { // 'opCount' is the total number operations in one iteration of 'forInst' body uint64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { @@ -458,17 +464,33 @@ getComputeCost(ForInst *forInst, LoopNestStats *stats, } // end anonymous namespace +static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { + assert(lbMap.getNumResults() == 1); + assert(ubMap.getNumResults() == 1); + assert(lbMap.getNumDims() == ubMap.getNumDims()); + assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); + // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'. 
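  // Editorial example (sketch): for lbMap = (d0) -> (d0) and
  // ubMap = (d0) -> (d0 + 8), 'ubExpr - lbExpr' simplifies to the constant 8,
  // which is returned; a non-constant difference yields None.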
+ // ub_expr - lb_expr + AffineExpr lbExpr(lbMap.getResult(0)); + AffineExpr ubExpr(ubMap.getResult(0)); + auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), + lbMap.getNumSymbols()); + auto cExpr = loopSpanExpr.dyn_cast(); + if (!cExpr) + return None; + return cExpr.getValue(); +} + // Builds a map 'tripCountMap' from ForInst to constant trip count for loop // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'. // Returns true on success, false otherwise (if a non-constant trip count // was encountered). // TODO(andydavis) Make this work with non-unit step loops. -static bool -buildSliceTripCountMap(MemRefAccess *srcAccess, - ComputationSliceState *sliceState, - DenseMap *tripCountMap) { +static bool buildSliceTripCountMap( + OperationInst *srcOpInst, ComputationSliceState *sliceState, + llvm::SmallDenseMap *tripCountMap) { SmallVector srcLoopIVs; - getLoopIVs(*srcAccess->opInst, &srcLoopIVs); + getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Populate map from ForInst -> trip count for (unsigned i = 0; i < numSrcLoopIVs; ++i) { @@ -485,109 +507,166 @@ buildSliceTripCountMap(MemRefAccess *srcAccess, } return false; } - // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'. - // ub_expr - lb_expr - AffineExpr lbExpr(lbMap.getResult(0)); - AffineExpr ubExpr(ubMap.getResult(0)); - auto loopSpanExpr = simplifyAffineExpr( - ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()), - std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols())); - auto cExpr = loopSpanExpr.dyn_cast(); - if (!cExpr) + Optional tripCount = getConstDifference(lbMap, ubMap); + if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue(); + (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue(); } return true; } -// Returns the maximum loop depth within the source loop nest at which a -// sliced loop bound is detected in 'sliceState'. -static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit, - ComputationSliceState *sliceState) { - unsigned maxSrcPos = 0; - for (unsigned i = 0; i < srcLoopDepthLimit; ++i) { - if (sliceState->lbs[i] != AffineMap::Null() && - sliceState->ubs[i] != AffineMap::Null()) { - maxSrcPos = std::max(maxSrcPos, i); +// Removes load operations from 'srcLoads' which operate on 'memref', and +// adds them to 'dstLoads'. +static void +moveLoadsAccessingMemrefTo(Value *memref, + SmallVectorImpl *srcLoads, + SmallVectorImpl *dstLoads) { + dstLoads->clear(); + SmallVector srcLoadsToKeep; + for (auto *load : *srcLoads) { + if (load->cast()->getMemRef() == memref) + dstLoads->push_back(load); + else + srcLoadsToKeep.push_back(load); + } + srcLoads->swap(srcLoadsToKeep); +} + +// Returns the innermost common loop depth for the set of operations in 'ops'. 
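// Editorial example (not part of this change): if one op is nested under
// loops (%i0, %i1, %i2) and another under (%i0, %i1, %i3), the innermost
// common loop depth is 2, since only %i0 and %i1 enclose both ops.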
+static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { + unsigned numOps = ops.size(); + assert(numOps > 0); + + std::vector> loops(numOps); + unsigned loopDepthLimit = std::numeric_limits::max(); + for (unsigned i = 0; i < numOps; ++i) { + getLoopIVs(*ops[i], &loops[i]); + loopDepthLimit = + std::min(loopDepthLimit, static_cast(loops[i].size())); + } + + unsigned loopDepth = 0; + for (unsigned d = 0; d < loopDepthLimit; ++d) { + unsigned i; + for (i = 1; i < numOps; ++i) { + if (loops[i - 1][d] != loops[i][d]) { + break; + } } + if (i != numOps) + break; + ++loopDepth; } - return maxSrcPos + 1; + return loopDepth; } -// Returns the minimum loop depth within the destination loop nest at which the -// computation slice can be inserted (based on the destination loop IVs that -// the source slice actually depends on / is a function of). -static unsigned getMinDstLoopDepth(unsigned srcLoopDepth, - ComputationSliceState *sliceState) { - // Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest - // IV, which is used in a sliced loop bound in the src loop nest. - unsigned maxDstLoopDepth = 0; - for (unsigned i = 0; i < srcLoopDepth; ++i) { - if (AffineMap lbMap = sliceState->lbs[i]) { - lbMap.walkExprs([&](AffineExpr expr) { - if (auto dimExpr = expr.dyn_cast()) { - maxDstLoopDepth = - std::max(maxDstLoopDepth, dimExpr.getPosition() + 1); - } - }); +// Returns true if 'map' is a single result constant or single result +// dim expr where its corresponding loop IV in 'operands' has zero constant +// lower bound. +static bool hasZeroMinValue(AffineMap map, ArrayRef operands) { + if (map.isSingleConstant() && map.getSingleConstantResult() == 0) + return true; + if (map.getNumResults() != 1 || !map.getResult(0).isa()) + return false; + // Get operand position of single dim expr result. + unsigned pos = map.getResult(0).cast().getPosition(); + // Check if loop IV at 'pos' has zero constant lower bound. + auto *operand = operands[pos]; + assert(isa(operand)); + auto *forInst = cast(operand); + return forInst->hasConstantLowerBound() && + forInst->getConstantLowerBound() == 0; +} +// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in +// 'sliceStateB'. +// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' +// and 'sliceStateB' are aligned. +// Specifically, when taking the union of overlapping intervals, it assumes +// that both intervals start at zero. Support needs to be added to take into +// account interval start offset when computing the union. +// TODO(andydavis) Move this function to an analysis library. +static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, + ComputationSliceState *sliceStateB) { + assert(sliceStateA.lbs.size() == sliceStateB->lbs.size()); + assert(sliceStateA.ubs.size() == sliceStateB->ubs.size()); + + for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) { + AffineMap lbMapA = sliceStateA.lbs[i]; + AffineMap ubMapA = sliceStateA.ubs[i]; + if (lbMapA == AffineMap::Null()) { + assert(ubMapA == AffineMap::Null()); + continue; } - if (AffineMap ubMap = sliceState->ubs[i]) { - ubMap.walkExprs([&](AffineExpr expr) { - if (auto dimExpr = expr.dyn_cast()) { - maxDstLoopDepth = - std::max(maxDstLoopDepth, dimExpr.getPosition() + 1); - } - }); + assert(ubMapA != AffineMap::Null()); + // Validate that constant lower bounds are aligned at zero. 
+ if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i])) + return false; + + AffineMap lbMapB = sliceStateB->lbs[i]; + AffineMap ubMapB = sliceStateB->ubs[i]; + if (lbMapB == AffineMap::Null()) { + assert(ubMapB == AffineMap::Null()); + // Union 'sliceStateB' does not have a bound for 'i' so copy from A. + sliceStateB->lbs[i] = lbMapA; + sliceStateB->ubs[i] = ubMapA; + continue; + } + // Validate that constant lower bounds are aligned at zero. + if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i])) + return false; + + // Add bound with the largest trip count to union. + Optional tripCountA = getConstDifference(lbMapA, ubMapA); + Optional tripCountB = getConstDifference(lbMapB, ubMapB); + if (!tripCountA.hasValue() || !tripCountB.hasValue()) + return false; + // TODO(andydavis) Change this code to take the min across all lower bounds + // and max across all upper bounds for each dimension. This code can for + // cases where a unique min or max could not be statically determined. + if (tripCountA.getValue() > tripCountB.getValue()) { + sliceStateB->lbs[i] = lbMapA; + sliceStateB->ubs[i] = ubMapA; } } - return maxDstLoopDepth; + return true; } -// Checks the profitability of fusion candidate 'candidate'. Returns true if it -// profitable to fuse the candidate loop nests. Returns false otherwise. +// Checks the profitability of fusing a backwards slice of the loop nest +// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. +// Returns true if it profitable to fuse the candidate loop nests. Returns +// false otherwise. // The profitability model executes the following steps: -// *) Computes the backward computation slice at 'candidate.srcAccess'. This -// computation slice of the loop nest surrounding 'candidate.srcAccess' is +// *) Computes the backward computation slice at 'srcOpInst'. This +// computation slice of the loop nest surrounding 'srcOpInst' is // represented by modified src loop bounds in 'sliceState', which are -// functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'. +// functions of loop IVs in the loop nest surrounding 'srcOpInst'. // *) Computes the cost of unfused src/dst loop nests (currently the cost of a // loop nest is the total number of dynamic operation instances in the loop // nest). // *) Computes the cost of fusing a slice of the src loop nest into the dst -// loop nest at various values of src/dst loop depth, attempting to fuse -// the biggest compution slice (max src loop depth) at the maximal dst loop -// depth (closest to the load) to minimize reuse distance and opportunity for -// subsequent load/store forwarding. -// NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest -// at which we slice the loops bounds (all src loops below this depth will -// utilize full loop bounds). +// loop nest at various values of dst loop depth, attempting to fuse +// the largest compution slice at the maximal dst loop depth (closest to the +// load) to minimize reuse distance and potentially enable subsequent +// load/store forwarding. +// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for +// the same memref as is written by 'srcOpInst', then the union of slice +// loop bounds is used to compute the slice and associated slice cost. // NOTE: 'dstLoopDepth' refers the loop depth within the destination loop // nest, at which the src computation slice is inserted/fused. 
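// Editorial example (sketch, not from this change): if two destination loads
// require source slices covering [0, 10) and [0, 16) of the same source IV,
// the union used for the cost computation is [0, 16), assuming both lower
// bounds are aligned at zero (see getSliceBoundUnion above).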
-// NOTE: We attempt to maximize the source loop depth, but there are cases -// where a particular setting for 'dstLoopNest' might fused an unsliced +// NOTE: We attempt to maximize the dst loop depth, but there are cases +// where a particular setting for 'dstLoopNest' might fuse an unsliced // loop (within the src computation slice) at a depth which results in // execessive recomputation (see unit tests for examples). // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -static bool isFusionProfitable(FusionCandidate *candidate, +static bool isFusionProfitable(OperationInst *srcOpInst, + ArrayRef dstOpInsts, ComputationSliceState *sliceState, - unsigned *srcLoopDepth, unsigned *dstLoopDepth) { - // Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc. - if (!mlir::getBackwardComputationSliceState( - candidate->srcAccess, candidate->dstAccess, sliceState)) { - return false; - } - - // Build trip count map for src loops with sliced loop bounds in 'sliceState'. - DenseMap sliceTripCountMap; - if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState, - &sliceTripCountMap)) - return false; - + unsigned *dstLoopDepth) { // Compute cost of sliced and unsliced src loop nest. SmallVector srcLoopIVs; - getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs); + getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. @@ -600,8 +679,7 @@ static bool isFusionProfitable(FusionCandidate *candidate, // Compute cost of dst loop nest. SmallVector dstLoopIVs; - getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs); - unsigned numDstLoopIVs = dstLoopIVs.size(); + getLoopIVs(*dstOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); @@ -610,51 +688,60 @@ static bool isFusionProfitable(FusionCandidate *candidate, if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; - // Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'. - // This search is O(n^2) where 'n' is very small (eg. six). - // TODO(andydavis) Consider a solution where we just iteration through - // dstLoopDepth possibilities and project out IVs we do not need (remove - // dependence on 'srcLoopDepth'. - DenseMap tripCountMap; - DenseMap computeCostMap; - unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState); + // Compute the innermost common loop for ops in 'dstOpInst'. + unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts); + if (maxDstLoopDepth == 0) + return false; + + // Search for min cost value for 'dstLoopDepth'. At each value of + // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice + // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union + // of these bounds). Next the union slice bounds are used to calculate + // the cost of the slice and the cost of the slice inserted into the dst + // loop nest at 'dstLoopDepth'. unsigned minFusedLoopNestComputeCost = std::numeric_limits::max(); - unsigned bestSrcLoopDepth; unsigned bestDstLoopDepth; - for (unsigned i = maxSrcLoopDepth; i >= 1; --i) { - // Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds. - unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState); - assert(minDstLoopDepth <= numDstLoopIVs); - if (minDstLoopDepth == 0) { - // TODO(andydavis) Support inserting computation slices at top-level. 
- continue; - } - // Copy elements from slice trip count map up to src loop depth 'i'. - tripCountMap.clear(); - for (unsigned k = 0; k < i; ++k) { - auto *forInst = srcLoopIVs[k]; - auto it = sliceTripCountMap.find(forInst); - if (it != sliceTripCountMap.end()) { - tripCountMap[forInst] = it->second; - } + SmallVector sliceStates; + sliceStates.resize(maxDstLoopDepth); + + llvm::SmallDenseMap sliceTripCountMap; + DenseMap computeCostMap; + for (unsigned i = maxDstLoopDepth; i >= 1; --i) { + MemRefAccess srcAccess(srcOpInst); + // Handle the common case of one dst load without a copy. + if (!mlir::getBackwardComputationSliceState( + srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1])) + return false; + // Compute the union of slice bound of all ops in 'dstOpInsts'. + for (int j = 1, e = dstOpInsts.size(); j < e; ++j) { + MemRefAccess dstAccess(dstOpInsts[j]); + ComputationSliceState tmpSliceState; + if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i, + &tmpSliceState)) + return false; + // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'. + getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]); } + // Build trip count map for computation slice. + sliceTripCountMap.clear(); + if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], + &sliceTripCountMap)) + return false; + // Compute op instance count for the src loop nest with iteration slicing. uint64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap, + getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap, /*computeCostMap=*/nullptr); - for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) { - // Compute cost of fusion for these values of 'i' and 'j'. - computeCostMap.clear(); - computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost; - uint64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, &computeCostMap); - if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { - minFusedLoopNestComputeCost = fusedLoopNestComputeCost; - bestSrcLoopDepth = i; - bestDstLoopDepth = j; - } + // Compute cost of fusion for these values of 'i' and 'j'. + computeCostMap.clear(); + computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; + uint64_t fusedLoopNestComputeCost = + getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, &computeCostMap); + if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + bestDstLoopDepth = i; } } @@ -668,7 +755,6 @@ static bool isFusionProfitable(FusionCandidate *candidate, /*computeCostMap=*/nullptr); LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics " - << " bestSrcLoopDepth: " << bestSrcLoopDepth << " bestDstLoopDepth: " << bestDstLoopDepth << " srcLoopNestCost: " << srcLoopNestCost << " dstLoopNestCost: " << dstLoopNestCost @@ -680,25 +766,23 @@ static bool isFusionProfitable(FusionCandidate *candidate, // for load/store forwarding in cost model. if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) return false; - // Set src/dstLoopDepth based on best values from search. - *srcLoopDepth = bestSrcLoopDepth; + // Update return parameter 'sliceState' with 'bestSliceState'. 
+ ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1]; + sliceState->lbs = bestSliceState->lbs; + sliceState->ubs = bestSliceState->ubs; + sliceState->lbOperands = bestSliceState->lbOperands; + sliceState->ubOperands = bestSliceState->ubOperands; + // Set dstLoopDepth based on best values from search. *dstLoopDepth = bestDstLoopDepth; - // Update 'sliceState' bounds based on computed 'srcLoopDepth': - // *) Canonicalize affine map now that 'srcLoopDepth' has been chosen. - // *) Replace slice bound maps at depth > 'srcLoopDepth' withAffineMap::Null() + // Canonicalize slice bound affine maps. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - if (i < bestSrcLoopDepth) { - if (sliceState->lbs[i] != AffineMap::Null()) { - canonicalizeMapAndOperands(&sliceState->lbs[i], - &sliceState->lbOperands[i]); - } - if (sliceState->ubs[i] != AffineMap::Null()) { - canonicalizeMapAndOperands(&sliceState->ubs[i], - &sliceState->ubOperands[i]); - } - } else { - sliceState->lbs[i] = AffineMap::Null(); - sliceState->ubs[i] = AffineMap::Null(); + if (sliceState->lbs[i] != AffineMap::Null()) { + canonicalizeMapAndOperands(&sliceState->lbs[i], + &sliceState->lbOperands[i]); + } + if (sliceState->ubs[i] != AffineMap::Null()) { + canonicalizeMapAndOperands(&sliceState->ubs[i], + &sliceState->ubOperands[i]); } } return true; @@ -767,12 +851,12 @@ public: continue; SmallVector loads = dstNode->loads; + SmallVector dstLoadOpInsts; while (!loads.empty()) { - auto *dstLoadOpInst = loads.pop_back_val(); - auto *memref = dstLoadOpInst->cast()->getMemRef(); - // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'. - if (dstNode->getLoadOpCount(memref) != 1) - continue; + // Get memref of load on top of the stack. + auto *memref = loads.back()->cast()->getMemRef(); + // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. + moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); // Skip if no input edges along which to fuse. if (mdg->inEdges.count(dstId) == 0) continue; @@ -801,19 +885,15 @@ public: continue; // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); - // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'. - FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst); // Check if fusion would be profitable. - unsigned srcLoopDepth; unsigned dstLoopDepth; mlir::ComputationSliceState sliceState; - if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth, + if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, &dstLoopDepth)) continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 
auto *sliceLoopNest = mlir::insertBackwardComputationSlice( - &candidate.srcAccess, &candidate.dstAccess, &sliceState, - srcLoopDepth, dstLoopDepth); + srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode' mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 525c9d63ad0..61335be227f 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -95,13 +95,18 @@ func @should_fuse_loop_nests_with_shifts() { } } + // The cost of fusing the src loop nest at dst loop depth 1 is less expensive + // than fusing at dst loop depth 2, because at dst loop depth 1, we are + // able to reduce the trip count around the %i1 loop by one (because the + // dst loop never reads the last element written by the src loop). // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) - // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1) - // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) - // CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32> - // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) + // CHECK-NEXT: for %i1 = 0 to 9 { + // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1) + // CHECK-NEXT: store %cst, %0[%2#0, %2#1] : memref<10x10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: %3 = load %0[%i0, %i2] : memref<10x10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -849,6 +854,7 @@ func @should_fuse_src_depth1_at_dst_depth2() { } // ----- +// CHECK: #map0 = ()[s0] -> (s0) // CHECK-LABEL: func @fusion_at_depth0_not_currently_supported func @fusion_at_depth0_not_currently_supported() { @@ -862,10 +868,9 @@ func @fusion_at_depth0_not_currently_supported() { %1 = load %0[%c0] : memref<10xf32> } // CHECK:for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%c0] : memref<10xf32> + // CHECK-NEXT: %1 = affine_apply #map0()[%c0] + // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> + // CHECK-NEXT: %2 = load %0[%c0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -977,3 +982,128 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: return return } + +// ----- +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count +func @should_fuse_at_depth1_and_reduce_slice_trip_count() { + %a = alloc() : memref<4x256xf32> + %b = alloc() : memref<4x256xf32> + + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + + for %i0 = 0 to 4 { + for %i1 = 0 to 256 { + %v0 = load %b[%i0, %i1] : memref<4x256xf32> + } + for %i2 = 0 to 256 { + store %cf0, %a[%i0, %i2] : memref<4x256xf32> + } + } + + for %d0 = 0 to 4 { + for %d1 = 0 to 16 { + %v1 = load %a[%d0, %d1] : memref<4x256xf32> + } + } + // The cost of fusing at depth 2 is greater than the cost of fusing at depth 1 + // for two reasons: + // 1) Inserting the unsliceable src loop %i1 to a higher depth removes + // redundant computation and reduces costs. + // 2) Inserting the sliceable src loop %i2 at depth 1, we can still reduce + // its trip count to 16 (from 256) reducing costs. 
+ // CHECK: for %i0 = 0 to 4 { + // CHECK-NEXT: %2 = affine_apply #map0(%i0) + // CHECK-NEXT: for %i1 = 0 to 256 { + // CHECK-NEXT: %3 = load %1[%2, %i1] : memref<4x256xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: store %cst, %0[%2, %i2] : memref<4x256xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: %4 = load %0[%i0, %i3] : memref<4x256xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_20 +func @should_fuse_at_depth1_with_trip_count_20() { + %a = alloc() : memref<100xf32> + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + + for %i0 = 0 to 100 { + store %cf0, %a[%i0]: memref<100xf32> + } + + for %i1 = 0 to 5 { + for %i2 = 0 to 10 { + %v0 = load %a[%i2]: memref<100xf32> + } + for %i3 = 0 to 10 { + for %i4 = 0 to 20 { + %v1 = load %a[%i4]: memref<100xf32> + } + } + } + // CHECK: for %i0 = 0 to 5 { + // CHECK-NEXT: for %i1 = 0 to 20 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i3 = 0 to 10 { + // CHECK-NEXT: for %i4 = 0 to 20 { + // CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_19 +func @should_fuse_at_depth1_with_trip_count_19() { + %a = alloc() : memref<100xf32> + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + + for %i0 = 0 to 100 { + store %cf0, %a[%i0]: memref<100xf32> + } + + for %i1 = 0 to 5 { + for %i2 = 0 to 19 { + %v0 = load %a[%i2]: memref<100xf32> + } + for %i3 = 0 to 10 { + for %i4 = 0 to 10 { + %v1 = load %a[%i4]: memref<100xf32> + } + } + } + // CHECK: for %i0 = 0 to 5 { + // CHECK-NEXT: for %i1 = 0 to 19 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 19 { + // CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i3 = 0 to 10 { + // CHECK-NEXT: for %i4 = 0 to 10 { + // CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From c1ca23ef6efab414879352e84302a6b52de721c2 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 16 Jan 2019 13:13:00 -0800 Subject: Some loop fusion code cleanup/simplification post cl/229575126 - enforce the assumptions better / in a simpler way PiperOrigin-RevId: 229612424 --- mlir/lib/Transforms/LoopFusion.cpp | 49 +++++++++++++------------------------- 1 file changed, 16 insertions(+), 33 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cdd1c77f302..804acba0d5a 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -465,8 +465,8 @@ static uint64_t getComputeCost( } // end anonymous namespace static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { - assert(lbMap.getNumResults() == 1); - assert(ubMap.getNumResults() == 1); + assert(lbMap.getNumResults() == 1 && "expected single result bound map"); + assert(ubMap.getNumResults() == 1 && "expected single result bound map"); assert(lbMap.getNumDims() == ubMap.getNumDims()); assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); // TODO(andydavis) Merge 
this code with 'mlir::getTripCountExpr'. @@ -560,33 +560,16 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { return loopDepth; } -// Returns true if 'map' is a single result constant or single result -// dim expr where its corresponding loop IV in 'operands' has zero constant -// lower bound. -static bool hasZeroMinValue(AffineMap map, ArrayRef operands) { - if (map.isSingleConstant() && map.getSingleConstantResult() == 0) - return true; - if (map.getNumResults() != 1 || !map.getResult(0).isa()) - return false; - // Get operand position of single dim expr result. - unsigned pos = map.getResult(0).cast().getPosition(); - // Check if loop IV at 'pos' has zero constant lower bound. - auto *operand = operands[pos]; - assert(isa(operand)); - auto *forInst = cast(operand); - return forInst->hasConstantLowerBound() && - forInst->getConstantLowerBound() == 0; -} -// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in -// 'sliceStateB'. +// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB' +// using a rectangular bounding box. // TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' // and 'sliceStateB' are aligned. // Specifically, when taking the union of overlapping intervals, it assumes // that both intervals start at zero. Support needs to be added to take into // account interval start offset when computing the union. // TODO(andydavis) Move this function to an analysis library. -static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, - ComputationSliceState *sliceStateB) { +static bool getSliceUnion(const ComputationSliceState &sliceStateA, + ComputationSliceState *sliceStateB) { assert(sliceStateA.lbs.size() == sliceStateB->lbs.size()); assert(sliceStateA.ubs.size() == sliceStateB->ubs.size()); @@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, assert(ubMapA == AffineMap::Null()); continue; } - assert(ubMapA != AffineMap::Null()); - // Validate that constant lower bounds are aligned at zero. - if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i])) - return false; + assert(ubMapA && "expected non-null ub map"); AffineMap lbMapB = sliceStateB->lbs[i]; AffineMap ubMapB = sliceStateB->ubs[i]; @@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, sliceStateB->ubs[i] = ubMapA; continue; } - // Validate that constant lower bounds are aligned at zero. - if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i])) + + // TODO(andydavis) Change this code to take the min across all lower bounds + // and max across all upper bounds for each dimension. This code can for + // cases where a unique min or max could not be statically determined. + + // Assumption: both lower bounds are the same. + if (lbMapA != lbMapB) return false; // Add bound with the largest trip count to union. @@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, Optional tripCountB = getConstDifference(lbMapB, ubMapB); if (!tripCountA.hasValue() || !tripCountB.hasValue()) return false; - // TODO(andydavis) Change this code to take the min across all lower bounds - // and max across all upper bounds for each dimension. This code can for - // cases where a unique min or max could not be statically determined. 
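  // For illustration, with aligned lower bounds two slices bounded by
  // [0, 10) and [0, 20) union to [0, 20), since 20 is the larger constant
  // trip count (see @should_fuse_at_depth1_with_trip_count_20 in the tests);
  // a true per-dimension min/max bounding box remains the TODO above.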
+ if (tripCountA.getValue() > tripCountB.getValue()) { sliceStateB->lbs[i] = lbMapA; sliceStateB->ubs[i] = ubMapA; @@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, &tmpSliceState)) return false; // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'. - getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]); + getSliceUnion(tmpSliceState, &sliceStates[i - 1]); } // Build trip count map for computation slice. sliceTripCountMap.clear(); -- cgit v1.2.3 From c4237ae99048f28df3e1776d63500bc9833f7a65 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 18 Jan 2019 08:56:27 -0800 Subject: LoopFusion: Creates private MemRefs which are used only by operations in the fused loop. *) Enables reduction of private memref size based on MemRef region accessed by fused slice. *) Enables maximal fusion by creating a private memref to break a fusion-preventing dependence. *) Adds maximal fusion flag to enable fusing as much as possible (though it still fuses the minimum cost computation slice). PiperOrigin-RevId: 229936698 --- mlir/lib/Transforms/LoopFusion.cpp | 231 ++++++++++++++++++++++----- mlir/test/Transforms/loop-fusion.mlir | 283 ++++++++++++++++++++-------------- 2 files changed, 357 insertions(+), 157 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 804acba0d5a..55bf025f500 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -32,6 +32,7 @@ #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -45,6 +46,10 @@ using llvm::SetVector; using namespace mlir; +static llvm::cl::opt + clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, + llvm::cl::desc("Enables maximal loop fusion.")); + namespace { /// Loop fusion pass. This pass currently supports a greedy fusion policy, @@ -95,6 +100,7 @@ public: // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level instructions in a Function which contain load/store ops, and edges // are memref dependences between the nodes. +// TODO(andydavis) Add a more flexible dependece graph representation. // TODO(andydavis) Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { public: @@ -147,6 +153,9 @@ public: DenseMap> inEdges; // Map from node id to list of output edges. DenseMap> outEdges; + // Map from memref to a count on the dependence edges associated with that + // memref. + DenseMap memrefEdgeCount; MemRefDependenceGraph() {} @@ -161,6 +170,32 @@ public: return &it->second; } + // Remove node 'id' (and its associated edges) from graph. + void removeNode(unsigned id) { + // Remove each edge in 'inEdges[id]'. + if (inEdges.count(id) > 0) { + SmallVector oldInEdges = inEdges[id]; + for (auto &inEdge : oldInEdges) { + removeEdge(inEdge.id, id, inEdge.memref); + } + } + // Remove each edge in 'outEdges[id]'. + if (outEdges.count(id) > 0) { + SmallVector oldOutEdges = outEdges[id]; + for (auto &outEdge : oldOutEdges) { + removeEdge(id, outEdge.id, outEdge.memref); + } + } + // Erase remaining node state. 
+ inEdges.erase(id); + outEdges.erase(id); + nodes.erase(id); + } + + bool hasOutEdges(unsigned id) { + return outEdges.count(id) > 0 && !outEdges[id].empty(); + } + // Returns true iff there is an edge from node 'srcId' to node 'dstId' for // 'memref'. Returns false otherwise. bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) { @@ -181,6 +216,7 @@ public: if (!hasEdge(srcId, dstId, memref)) { outEdges[srcId].push_back({dstId, memref}); inEdges[dstId].push_back({srcId, memref}); + memrefEdgeCount[memref]++; } } @@ -188,6 +224,8 @@ public: void removeEdge(unsigned srcId, unsigned dstId, Value *memref) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); + assert(memrefEdgeCount.count(memref) > 0); + memrefEdgeCount[memref]--; // Remove 'srcId' from 'inEdges[dstId]'. for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { if ((*it).id == srcId && (*it).memref == memref) { @@ -224,43 +262,36 @@ public: return outEdgeCount; } - // Returns the min node id of all output edges from node 'id'. - unsigned getMinOutEdgeNodeId(unsigned id) { + // Returns the min node id across all outgoing edges from node 'id', skipping + // edges with 'memrefToSkip'. + unsigned getMinOutEdgeNodeId(unsigned id, Value *memrefToSkip) { unsigned minId = std::numeric_limits::max(); if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) - minId = std::min(minId, outEdge.id); + if (outEdge.memref != memrefToSkip) + minId = std::min(minId, outEdge.id); return minId; } - // Updates edge mappings from node 'srcId' to node 'dstId' and removes - // state associated with node 'srcId'. - void updateEdgesAndRemoveSrcNode(unsigned srcId, unsigned dstId) { + // Updates edge mappings from node 'srcId' to node 'dstId'. + void updateEdges(unsigned srcId, unsigned dstId) { // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { - // Remove edge from 'inEdge.id' to 'srcId'. - removeEdge(inEdge.id, srcId, inEdge.memref); // Add edge from 'inEdge.id' to 'dstId'. addEdge(inEdge.id, dstId, inEdge.memref); } } - // For each edge in 'outEdges[srcId]': add new edge remaping to 'dstId'. + // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. if (outEdges.count(srcId) > 0) { SmallVector oldOutEdges = outEdges[srcId]; for (auto &outEdge : oldOutEdges) { - // Remove edge from 'srcId' to 'outEdge.id'. - removeEdge(srcId, outEdge.id, outEdge.memref); - // Add edge from 'dstId' to 'outEdge.id' (if 'outEdge.id' != 'dstId'). - if (outEdge.id != dstId) - addEdge(dstId, outEdge.id, outEdge.memref); + // Remove any out edges from 'srcId' to 'dstId' across memrefs. + if (outEdge.id == dstId) + removeEdge(srcId, outEdge.id, outEdge.memref); } } - // Remove 'srcId' from graph state. - inEdges.erase(srcId); - outEdges.erase(srcId); - nodes.erase(srcId); } // Adds ops in 'loads' and 'stores' to node at 'id'. 
@@ -273,6 +304,12 @@ public: node->stores.push_back(storeOpInst); } + void clearNodeLoadAndStores(unsigned id) { + Node *node = getNode(id); + node->loads.clear(); + node->stores.clear(); + } + void print(raw_ostream &os) const { os << "\nMemRefDependenceGraph\n"; os << "\nNodes:\n"; @@ -614,6 +651,82 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, return true; } +// Creates and returns a private (single-user) memref for fused loop rooted +// at 'forInst', with (potentially reduced) memref size based on the +// MemRefRegion written to by 'srcStoreOpInst'. +static Value *createPrivateMemRef(ForInst *forInst, + OperationInst *srcStoreOpInst) { + // Create builder to insert alloc op just before 'forInst'. + FuncBuilder b(forInst); + // Builder to create constants at the top level. + FuncBuilder top(forInst->getFunction()); + // Create new memref type based on slice bounds. + auto *oldMemRef = srcStoreOpInst->cast()->getMemRef(); + auto oldMemRefType = oldMemRef->getType().cast(); + unsigned rank = oldMemRefType.getRank(); + + // Compute MemRefRegion for 'srcStoreOpInst'. + MemRefRegion region; + getMemRefRegion(srcStoreOpInst, 0, ®ion); + SmallVector newShape; + std::vector> lbs; + lbs.reserve(rank); + // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed + // by 'srcStoreOpInst'. + Optional numElements = + region.getBoundingConstantSizeAndShape(&newShape, &lbs); + assert(numElements.hasValue()); + + // Build 'rank' AffineExprs from MemRefRegion 'lbs' + const FlatAffineConstraints *cst = region.getConstraints(); + SmallVector offsets; + offsets.reserve(rank); + for (unsigned d = 0; d < rank; ++d) { + AffineExpr offset = top.getAffineConstantExpr(0); + for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { + offset = offset + lbs[d][j] * top.getAffineDimExpr(j); + } + offset = offset + lbs[d][cst->getNumCols() - 1 - rank]; + offsets.push_back(offset); + } + + // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed + // by 'srcStoreOpInst'. + auto newMemRefType = b.getMemRefType(newShape, oldMemRefType.getElementType(), + {}, oldMemRefType.getMemorySpace()); + // Gather alloc operands for the dynamic dimensions of the memref. + SmallVector allocOperands; + unsigned dynamicDimCount = 0; + for (auto dimSize : oldMemRefType.getShape()) { + if (dimSize == -1) + allocOperands.push_back( + b.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); + } + + // Create new private memref for fused loop 'forInst'. + Value *newMemRef = + b.create(forInst->getLoc(), newMemRefType, allocOperands); + + // Build an AffineMap to remap access functions based on lower bound offsets. + SmallVector remapExprs; + remapExprs.reserve(rank); + unsigned zeroOffsetCount = 0; + for (unsigned i = 0; i < rank; i++) { + if (auto constExpr = offsets[i].dyn_cast()) + if (constExpr.getValue() == 0) + ++zeroOffsetCount; + auto dimExpr = b.getAffineDimExpr(i); + remapExprs.push_back(dimExpr - offsets[i]); + } + auto indexRemap = zeroOffsetCount == rank + ? AffineMap::Null() + : b.getAffineMap(rank, 0, remapExprs, {}); + // Replace all users of 'oldMemRef' with 'newMemRef'. + assert(replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {}, + &*forInst->getBody()->begin())); + return newMemRef; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. // Returns true if it profitable to fuse the candidate loop nests. 
Returns @@ -744,10 +857,12 @@ static bool isFusionProfitable(OperationInst *srcOpInst, << " minFusedLoopNestComputeCost: " << minFusedLoopNestComputeCost << "\n"); - // Do not fuse if fused loop would increase the total cost of the computation. + // Do not fuse if fused loop would increase the total cost of the computation, + // unless 'clMaximalLoopFusion' flag is set. // TODO(andydavis) Use locality/reduction in slice memref size/opportunity // for load/store forwarding in cost model. - if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) + if (!clMaximalLoopFusion && + minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) return false; // Update return parameter 'sliceState' with 'bestSliceState'. ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1]; @@ -835,9 +950,13 @@ public: SmallVector loads = dstNode->loads; SmallVector dstLoadOpInsts; + DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. auto *memref = loads.back()->cast()->getMemRef(); + if (visitedMemrefs.count(memref) > 0) + continue; + visitedMemrefs.insert(memref); // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); // Skip if no input edges along which to fuse. @@ -855,16 +974,13 @@ public: // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) continue; - // Skip 'srcNode' if it has out edges on 'memref' other than 'dstId'. - if (mdg->getOutEdgeCount(srcNode->id, memref) != 1) - continue; // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly // TODO(andydavis) Track dependence type with edges, and just check // for WAW dependence edge here. if (mdg->getInEdgeCount(srcNode->id, memref) != 0) continue; // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. - if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId) + if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId) continue; // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); @@ -878,27 +994,66 @@ public: auto *sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { - // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode' - mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); - // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. - LoopNestStateCollector collector; - collector.walkForInst(sliceLoopNest); - mdg->addToNode(dstId, collector.loadOpInsts, - collector.storeOpInsts); + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id); + + // Collect slice loop stats. + LoopNestStateCollector sliceCollector; + sliceCollector.walkForInst(sliceLoopNest); + // Promote single iteration slice loops to single IV value. + for (auto *forInst : sliceCollector.forInsts) { + promoteIfSingleIteration(forInst); + } + + // Create private memref for 'memref' in 'dstForInst'. + auto *dstForInst = cast(dstNode->inst); + SmallVector storesForMemref; + for (auto *storeOpInst : sliceCollector.storeOpInsts) { + if (storeOpInst->cast()->getMemRef() == memref) + storesForMemref.push_back(storeOpInst); + } + assert(storesForMemref.size() == 1); + auto *newMemRef = + createPrivateMemRef(dstForInst, storesForMemref[0]); + visitedMemrefs.insert(newMemRef); + + // Collect dst loop stats after memref privatizaton transformation. 
+ LoopNestStateCollector dstLoopCollector; + dstLoopCollector.walkForInst(dstForInst); + // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. - for (auto *loadOpInst : collector.loadOpInsts) - loads.push_back(loadOpInst); - // Promote single iteration loops to single IV value. - for (auto *forInst : collector.forInsts) { - promoteIfSingleIteration(forInst); + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + auto *loadMemRef = loadOpInst->cast()->getMemRef(); + if (visitedMemrefs.count(loadMemRef) == 0) + loads.push_back(loadOpInst); + } + + // Clear and add back loads and stores + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old src loop nest if it no longer has users. + if (!mdg->hasOutEdges(srcNode->id)) { + mdg->removeNode(srcNode->id); + cast(srcNode->inst)->erase(); } - // Remove old src loop nest. - cast(srcNode->inst)->erase(); } } } } + // Clean up any allocs with no users. + for (auto &pair : mdg->memrefEdgeCount) { + if (pair.second > 0) + continue; + auto *memref = pair.first; + // Use list expected to match the dep graph info. + assert(memref->use_empty()); + auto *inst = memref->getDefiningInst(); + auto *opInst = dyn_cast_or_null(inst); + if (opInst && opInst->isa()) + opInst->erase(); + } } }; diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index b7859f7efbe..d240614b2ce 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -61,13 +61,13 @@ func @should_fuse_reduction_to_pointwise() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %4 = load %1[%3] : memref<10xf32> + // CHECK-NEXT: %4 = load %2[%3] : memref<10xf32> // CHECK-NEXT: %5 = load %0[%3, %i1] : memref<10x10xf32> // CHECK-NEXT: %6 = addf %4, %5 : f32 - // CHECK-NEXT: store %6, %1[%3] : memref<10xf32> + // CHECK-NEXT: store %6, %2[%3] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: %7 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: store %7, %2[%i0] : memref<10xf32> + // CHECK-NEXT: %7 = load %2[%i0] : memref<10xf32> + // CHECK-NEXT: store %7, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -75,38 +75,45 @@ func @should_fuse_reduction_to_pointwise() { // ----- -// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) +// CHECK: [[MAP_SHIFT_MINUS_ONE_R1:#map[0-9]+]] = (d0) -> (d0 - 1) // CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) +// CHECK: [[MAP_SHIFT_MINUS_ONE_R2:#map[0-9]+]] = (d0, d1) -> (d0 - 1, d1 - 1) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { func @should_fuse_loop_nests_with_shifts() { %a = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + for %i0 = 0 to 9 { + for %i1 = 0 to 9 { %a0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 1) (%i0, %i1) store %cf7, %a[%a0#0, %a0#1] : memref<10x10xf32> } } - for %i2 = 0 to 10 { - for %i3 = 0 to 10 { + for %i2 = 1 to 10 { + for %i3 = 1 to 10 { %v0 = load %a[%i2, %i3] : memref<10x10xf32> } } - // The cost of fusing the src loop nest at dst loop depth 1 is less expensive - // than fusing at dst loop depth 2, because at dst loop depth 1, we are - // able to reduce the trip count around the %i1 loop by one (because the - // dst loop never reads the last element written by the src loop). 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0) - // CHECK-NEXT: for %i1 = 0 to 9 { - // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1) - // CHECK-NEXT: store %cst, %0[%2#0, %2#1] : memref<10x10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 10 { - // CHECK-NEXT: %3 = load %0[%i0, %i2] : memref<10x10xf32> + // Source slice affine apply sequence: + // *) First two affine apply's map from the dst to src iteration space. + // *) Third affine apply is access function around src store. + // *) Fourth affine apply shifts the stores access function by '-1', because + // of the offset induced by reducing the memref shape from 10x10 to 9x9. + // *) Fifth affine apply shifts the loads access function by '-1', because + // of the offset induced by reducing the memref shape from 10x10 to 9x9. + // NOTE: Should create a private memref with reduced shape 9x9xf32. + // CHECK: %0 = alloc() : memref<9x9xf32> + // CHECK-NEXT: for %i0 = 1 to 10 { + // CHECK-NEXT: for %i1 = 1 to 10 { + // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) + // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) + // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) + // CHECK-NEXT: %4 = affine_apply [[MAP_SHIFT_MINUS_ONE_R2]](%3#0, %3#1) + // CHECK-NEXT: store %cst, %0[%4#0, %4#1] : memref<9x9xf32> + // CHECK-NEXT: %5 = affine_apply [[MAP_SHIFT_MINUS_ONE_R2]](%i0, %i1) + // CHECK-NEXT: %6 = load %0[%5#0, %5#1] : memref<9x9xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -139,17 +146,19 @@ func @should_fuse_loop_nest() { %v1 = load %b[%i4, %i5] : memref<10x10xf32> } } - - // CHECK: for %i0 = 0 to 10 { + // Expecting private memref for '%b' first, then private memref for '%a'. + // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<10x10xf32> + // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<10x10xf32> + // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply [[MAP_ID]](%i1) // CHECK-NEXT: %3 = affine_apply [[MAP_ID]](%i0) - // CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32> + // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<10x10xf32> // CHECK-NEXT: %4 = affine_apply [[MAP_ID]](%i0) // CHECK-NEXT: %5 = affine_apply [[MAP_ID]](%i1) - // CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32> - // CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32> - // CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: %6 = load [[NEWA]][%5, %4] : memref<10x10xf32> + // CHECK-NEXT: store %6, [[NEWB]][%4, %5] : memref<10x10xf32> + // CHECK-NEXT: %7 = load [[NEWB]][%i0, %i1] : memref<10x10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -180,14 +189,16 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { } // Should fuse first loop (past second loop with no dependences) into third. + // Note that fusion creates a private memref '%2' for the fused loop nest. 
// CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } + // CHECK: %2 = alloc() : memref<10xf32> // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1) // CHECK-NEXT: %4 = load %0[%3] : memref<10xf32> - // CHECK-NEXT: store %4, %1[%3] : memref<10xf32> - // CHECK-NEXT: %5 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: store %4, %2[%3] : memref<10xf32> + // CHECK-NEXT: %5 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -216,13 +227,16 @@ func @should_fuse_all_loops() { } // Should fuse first and second loops into third. - // CHECK: for %i0 = 0 to 10 { + // Expecting private memref for '%b' first, then private memref for '%a'. + // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<10xf32> + // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<10xf32> + // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, %0[%2] : memref<10xf32> + // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<10xf32> // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, %1[%3] : memref<10xf32> - // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: %5 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, [[NEWB]][%3] : memref<10xf32> + // CHECK-NEXT: %4 = load [[NEWA]][%i0] : memref<10xf32> + // CHECK-NEXT: %5 = load [[NEWB]][%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -252,14 +266,16 @@ func @should_fuse_first_and_second_loops() { } // Should fuse first loop into the second (last loop should not be fused). - // CHECK: for %i0 = 0 to 10 { + // Should create private memref '%2' for fused loop. + // CHECK: %2 = alloc() : memref<10xf32> + // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, %0[%3] : memref<10xf32> - // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %2[%3] : memref<10xf32> + // CHECK-NEXT: %4 = load %2[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %5 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: %5 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -310,39 +326,10 @@ func @should_not_fuse_would_create_cycle() { } // ----- +// CHECK: #map0 = (d0) -> (d0) -// CHECK-LABEL: func @should_not_fuse_raw_dep_would_be_violated() { -func @should_not_fuse_raw_dep_would_be_violated() { - %m = alloc() : memref<10xf32> - %cf7 = constant 7.0 : f32 - - for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> - } - for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> - } - for %i2 = 0 to 10 { - %v1 = load %m[%i2] : memref<10xf32> - } - // Fusing loop %i0 to %i2 would violate the RAW dependence between %i0 and %i1 - // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i2] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: return - return -} - -// ----- - -// CHECK-LABEL: func @should_not_fuse_waw_dep_would_be_violated() { -func @should_not_fuse_waw_dep_would_be_violated() { +// CHECK-LABEL: func @should_fuse_across_waw_dep_with_private_memref() { +func 
@should_fuse_across_waw_dep_with_private_memref() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -362,8 +349,11 @@ func @should_not_fuse_waw_dep_would_be_violated() { // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> + // CHECK: %1 = alloc() : memref<10xf32> + // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply #map0(%i2) + // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> + // CHECK-NEXT: %3 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -404,8 +394,8 @@ func @should_not_fuse_war_dep_would_be_violated() { // ----- -// CHECK-LABEL: func @should_not_fuse_if_top_level_access() { -func @should_not_fuse_if_top_level_access() { +// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() { +func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -422,8 +412,11 @@ func @should_not_fuse_if_top_level_access() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK: %1 = alloc() : memref<10xf32> + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply #map0(%i1) + // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> + // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: } return } @@ -625,12 +618,14 @@ func @fuse_reshape_16_4_64() { // ----- +// TODO(b/123072438) Re-enable test MemRefRegion bug is fixed. // All three loop nests below (6-d one, 2-d one, 2-d one is fused into a single // 2-d loop nest). -// CHECK-LABEL: func @R6_to_R2_reshape +// xCHECK-LABEL: func @R6_to_R2_reshape func @R6_to_R2_reshape_square() -> memref<64x9xi32> { %in = alloc() : memref<2x2x3x3x16x1xi32> %out = alloc() : memref<64x9xi32> + %live_out = alloc() : memref<64x9xi32> // Initialize input. for %i0 = 0 to 2 { @@ -670,35 +665,38 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { for %j = 0 to 9 { %a = load %out[%i, %j] : memref<64x9xi32> %b = muli %a, %a : i32 - store %b, %out[%i, %j] : memref<64x9xi32> + store %b, %live_out[%i, %j] : memref<64x9xi32> } } - return %out : memref<64x9xi32> + return %live_out : memref<64x9xi32> } // Everything above is fused to a single 2-d loop nest, and the 6-d tensor %in // is eliminated if -memref-dataflow-opt is also supplied. 
// -// CHECK: for %i0 = 0 to 64 { -// CHECK-NEXT: for %i1 = 0 to 9 { -// CHECK-NEXT: %2 = affine_apply #map0(%i0, %i1) -// CHECK-NEXT: %3 = affine_apply #map1(%i0, %i1) -// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1) -// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1) -// CHECK-NEXT: %6 = affine_apply #map4(%i0, %i1) -// CHECK-NEXT: %7 = "foo"(%2, %3, %4, %5, %6, %c0) : (index, index, index, index, index, index) -> i32 -// CHECK-NEXT: store %7, %0[%2, %3, %4, %5, %6, %c0] : memref<2x2x3x3x16x1xi32> -// CHECK-NEXT: %8 = affine_apply #map5(%i0) -// CHECK-NEXT: %9 = affine_apply #map5(%i1) -// CHECK-NEXT: %10 = affine_apply #map6(%8, %9) -// CHECK-NEXT: %11 = affine_apply #map7(%10) -// CHECK-NEXT: %12 = load %0[%11#0, %11#1, %11#3, %11#4, %11#2, %11#5] : memref<2x2x3x3x16x1xi32> -// CHECK-NEXT: store %12, %1[%8, %9] : memref<64x9xi32> -// CHECK-NEXT: %13 = load %1[%i0, %i1] : memref<64x9xi32> -// CHECK-NEXT: %14 = muli %13, %13 : i32 -// CHECK-NEXT: store %14, %1[%i0, %i1] : memref<64x9xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: return %1 : memref<64x9xi32> +// xCHECK: %0 = alloc() : memref<64x9xi32> +// xCHECK-NEXT: %1 = alloc() : memref<64x9xi32> +// xCHECK-NEXT: %2 = alloc() : memref<2x2x3x3x16x1xi32> +// xCHECK-NEXT: for %i0 = 0 to 64 { +// xCHECK-NEXT: for %i1 = 0 to 9 { +// xCHECK-NEXT: %3 = affine_apply #map0(%i0, %i1) +// xCHECK-NEXT: %4 = affine_apply #map1(%i0, %i1) +// xCHECK-NEXT: %5 = affine_apply #map2(%i0, %i1) +// xCHECK-NEXT: %6 = affine_apply #map3(%i0, %i1) +// xCHECK-NEXT: %7 = affine_apply #map4(%i0, %i1) +// xCHECK-NEXT: %8 = "foo"(%3, %4, %5, %6, %7, %c0) : (index, index, index, index, index, index) -> i32 +// xCHECK-NEXT: store %8, %2[%3, %4, %5, %6, %7, %c0] : memref<2x2x3x3x16x1xi32> +// xCHECK-NEXT: %9 = affine_apply #map5(%i0) +// xCHECK-NEXT: %10 = affine_apply #map5(%i1) +// xCHECK-NEXT: %11 = affine_apply #map6(%8, %9) +// xCHECK-NEXT: %12 = affine_apply #map7(%10) +// xCHECK-NEXT: %13 = load %2[%12#0, %12#1, %12#3, %12#4, %12#2, %12#5] : memref<2x2x3x3x16x1xi32> +// xCHECK-NEXT: store %12, %1[%9, %10] : memref<64x9xi32> +// xCHECK-NEXT: %14 = load %1[%i0, %i1] : memref<64x9xi32> +// xCHECK-NEXT: %15 = muli %14, %14 : i32 +// xCHECK-NEXT: store %15, %0[%i0, %i1] : memref<64x9xi32> +// xCHECK-NEXT: } +// xCHECK-NEXT: } +// xCHECK-NEXT: return %0 : memref<64x9xi32> // ----- @@ -867,10 +865,13 @@ func @fusion_at_depth0_not_currently_supported() { for %i1 = 0 to 10 { %1 = load %0[%c0] : memref<10xf32> } - // CHECK:for %i0 = 0 to 10 { + // NOTE: Should shrink memref size to 1 element access by load in dst loop + // nest, and make the store in the slice store to the same element. 
+ // CHECK: %0 = alloc() : memref<1xf32> + // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine_apply #map0()[%c0] - // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> - // CHECK-NEXT: %2 = load %0[%c0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: %2 = load %0[%c0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -954,7 +955,7 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: } // CHECK-NEXT: for %i6 = 0 to 16 { // CHECK-NEXT: for %i7 = 0 to 10 { -// CHECK-NEXT: store %cst, %1[%3, %4, %5, %6, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: store %cst, %2[%3, %4, %5, %6, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i8 = 0 to 3 { @@ -968,7 +969,7 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { // CHECK-NEXT: for %i15 = 0 to 10 { -// CHECK-NEXT: %9 = load %1[%i2, %i3, %i0, %i1, %i14, %i15] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %9 = load %2[%i2, %i3, %i0, %i1, %i14, %i15] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1014,16 +1015,20 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // redundant computation and reduces costs. // 2) Inserting the sliceable src loop %i2 at depth 1, we can still reduce // its trip count to 16 (from 256) reducing costs. - // CHECK: for %i0 = 0 to 4 { + // NOTE: the size of the private memref created for the fused loop nest + // is reduced from the original shape from 4x256 to 4x16 because of the + // data accessed by the load. + // CHECK: %1 = alloc() : memref<4x16xf32> + // CHECK-NEXT: for %i0 = 0 to 4 { // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: for %i1 = 0 to 256 { - // CHECK-NEXT: %3 = load %1[%2, %i1] : memref<4x256xf32> + // CHECK-NEXT: %3 = load %0[%2, %i1] : memref<4x256xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { - // CHECK-NEXT: store %cst, %0[%2, %i2] : memref<4x256xf32> + // CHECK-NEXT: store %cst, %1[%2, %i2] : memref<4x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { - // CHECK-NEXT: %4 = load %0[%i0, %i3] : memref<4x256xf32> + // CHECK-NEXT: %4 = load %1[%i0, %i3] : memref<4x16xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1052,16 +1057,18 @@ func @should_fuse_at_depth1_with_trip_count_20() { } } } - // CHECK: for %i0 = 0 to 5 { + // NOTE: The size of the private memref created for fusion is shrunk to 20xf32 + // CHECK: %0 = alloc() : memref<20xf32> + // CHECK-NEXT: for %i0 = 0 to 5 { // CHECK-NEXT: for %i1 = 0 to 20 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32> + // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32> + // CHECK-NEXT: %1 = load %0[%i2] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 10 { // CHECK-NEXT: for %i4 = 0 to 20 { - // CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32> + // CHECK-NEXT: %2 = load %0[%i4] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1091,19 +1098,57 @@ func @should_fuse_at_depth1_with_trip_count_19() { } } } - // CHECK: for %i0 = 0 to 5 { + // NOTE: The size of the private memref created for fusion is shrunk to 19xf32 + // CHECK: %0 = alloc() : memref<19xf32> + // CHECK-NEXT: for %i0 = 0 to 5 { // CHECK-NEXT: for %i1 = 0 to 19 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32> + // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: 
for %i2 = 0 to 19 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32> + // CHECK-NEXT: %1 = load %0[%i2] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 10 { // CHECK-NEXT: for %i4 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32> + // CHECK-NEXT: %2 = load %0[%i4] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return return } + + +// ----- +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @should_fuse_with_private_memrefs_with_diff_shapes() { +func @should_fuse_with_private_memrefs_with_diff_shapes() { + %m = alloc() : memref<100xf32> + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 100 { + store %cf7, %m[%i0] : memref<100xf32> + } + for %i1 = 0 to 17 { + %v0 = load %m[%i1] : memref<100xf32> + } + for %i2 = 0 to 82 { + %v1 = load %m[%i2] : memref<100xf32> + } + // Should create two new private memrefs customized to the shapes accessed + // by loops %i1 and %i2. + // CHECK: %0 = alloc() : memref<17xf32> + // CHECK-NEXT: for %i0 = 0 to 17 { + // CHECK-NEXT: %1 = affine_apply #map0(%i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<17xf32> + // CHECK-NEXT: %2 = load %0[%i0] : memref<17xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %3 = alloc() : memref<82xf32> + // CHECK-NEXT: for %i1 = 0 to 82 { + // CHECK-NEXT: %4 = affine_apply #map0(%i1) + // CHECK-NEXT: store %cst, %3[%4] : memref<82xf32> + // CHECK-NEXT: %5 = load %3[%i1] : memref<82xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} \ No newline at end of file -- cgit v1.2.3 From 1e484b5ef4b9c896fa5aa04d389f7aac3b40deef Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 18 Jan 2019 10:56:12 -0800 Subject: Mark (void)indexRemap to please compiler for unused variable check PiperOrigin-RevId: 229957023 --- mlir/lib/Transforms/LoopFusion.cpp | 1 + 1 file changed, 1 insertion(+) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 55bf025f500..e82214b7b48 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -724,6 +724,7 @@ static Value *createPrivateMemRef(ForInst *forInst, // Replace all users of 'oldMemRef' with 'newMemRef'. assert(replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {}, &*forInst->getBody()->begin())); + (void)indexRemap; return newMemRef; } -- cgit v1.2.3 From 71495d58a7b17f3e3f6a8c54ed704114b56c2374 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Tue, 22 Jan 2019 13:23:37 -0800 Subject: Handle escaping memrefs in loop fusion pass: *) Do not remove loop nests which write to memrefs which escape the function. *) Do not remove memrefs which escape the function (e.g. are used in the return instruction). 
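A condensed sketch of the escape check this relies on (illustrative only; the
helper name 'isEscapingMemRef' is hypothetical, and the actual logic lands in
writesToLiveInOrEscapingMemrefs in the LoopFusion.cpp diff below). A memref is
treated as escaping if it is a function argument or has a user that does not
dereference it:

  static bool isEscapingMemRef(Value *memref) {
    // Function arguments have no defining OperationInst.
    if (dyn_cast_or_null<OperationInst>(memref->getDefiningInst()) == nullptr)
      return true;
    // Any user that is not a dereferencing load/store (e.g. a use in a return
    // instruction) makes the memref escape.
    for (auto &use : memref->getUses()) {
      auto *user = dyn_cast<OperationInst>(use.getOwner());
      if (!user || !isMemRefDereferencingOp(*user))
        return true;
    }
    return false;
  }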
PiperOrigin-RevId: 230398630 --- mlir/lib/Transforms/LoopFusion.cpp | 45 ++++++++++++++++++++++++--- mlir/test/Transforms/loop-fusion.mlir | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index e82214b7b48..5c80367cfa0 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -97,6 +97,13 @@ public: } }; +// TODO(b/117228571) Replace when this is modeled through side-effects/op traits +static bool isMemRefDereferencingOp(const OperationInst &op) { + if (op.isa() || op.isa() || op.isa() || + op.isa()) + return true; + return false; +} // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level instructions in a Function which contain load/store ops, and edges // are memref dependences between the nodes. @@ -196,6 +203,27 @@ public: return outEdges.count(id) > 0 && !outEdges[id].empty(); } + // Returns true if node 'id' writes to any memref which escapes (or is an + // argument to) the function/block. Returns false otherwise. + bool writesToLiveInOrEscapingMemrefs(unsigned id) { + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { + auto *memref = storeOpInst->cast()->getMemRef(); + auto *inst = memref->getDefiningInst(); + auto *opInst = dyn_cast_or_null(inst); + // Return false if 'memref' is a function argument. + if (opInst == nullptr) + return true; + // Return false if any use of 'memref' escapes the function. + for (auto &use : memref->getUses()) { + auto *user = dyn_cast(use.getOwner()); + if (!user || !isMemRefDereferencingOp(*user)) + return true; + } + } + return false; + } + // Returns true iff there is an edge from node 'srcId' to node 'dstId' for // 'memref'. Returns false otherwise. bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) { @@ -722,8 +750,10 @@ static Value *createPrivateMemRef(ForInst *forInst, ? AffineMap::Null() : b.getAffineMap(rank, 0, remapExprs, {}); // Replace all users of 'oldMemRef' with 'newMemRef'. - assert(replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {}, - &*forInst->getBody()->begin())); + bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {}, + &*forInst->getBody()->begin()); + assert(ret); + (void)ret; (void)indexRemap; return newMemRef; } @@ -1034,8 +1064,11 @@ public: mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has users. - if (!mdg->hasOutEdges(srcNode->id)) { + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and it does not write to a memref which escapes the + // function. + if (!mdg->hasOutEdges(srcNode->id) && + !mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) { mdg->removeNode(srcNode->id); cast(srcNode->inst)->erase(); } @@ -1048,8 +1081,10 @@ public: if (pair.second > 0) continue; auto *memref = pair.first; + // Skip if there exist other uses (return instruction or function calls). + if (!memref->use_empty()) + continue; // Use list expected to match the dep graph info. 
- assert(memref->use_empty()); auto *inst = memref->getDefiningInst(); auto *opInst = dyn_cast_or_null(inst); if (opInst && opInst->isa()) diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 864ea3c10ac..12e0a8a456a 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1153,3 +1153,61 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // CHECK-NEXT: return return } + +// ----- + +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { +func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %arg0[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %arg0[%i1] : memref<10xf32> + } + // This tests that the loop nest '%i0' should not be removed after fusion + // because it writes to memref argument '%arg0'. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %0 = alloc() : memref<10xf32> + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %1 = affine_apply #map0(%i1) + // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> + // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: #map0 = (d0) -> (d0) + +// CHECK-LABEL: func @fusion_should_not_remove_escaping_memref() +func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { + %cf7 = constant 7.0 : f32 + %m = alloc() : memref<10xf32> + for %i0 = 0 to 10 { + store %cf7, %m[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %m[%i1] : memref<10xf32> + } + // This tests that the loop nest '%i0' should not be removed after fusion + // because it writes to memref '%m' which is returned by the function. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %1 = alloc() : memref<10xf32> + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply #map0(%i1) + // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> + // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %0 : memref<10xf32> + return %m : memref<10xf32> +} -- cgit v1.2.3 From 94a03f864f53ec51137c522c76cecf73abb2f6e7 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 22 Jan 2019 13:58:52 -0800 Subject: Allocate private/local buffers for slices accurately during fusion - the size of the private memref created for the slice should be based on the memref region accessed at the depth at which the slice is being materialized, i.e., symbolic in the outer IVs up until that depth, as opposed to the region accessed based on the entire domain. - leads to a significant contraction of the temporary / intermediate memref whenever the memref isn't reduced to a single scalar (through store fwd'ing). Other changes - update to promoteIfSingleIteration - avoid introducing unnecessary identity map affine_apply from IV; makes it much easier to write and read test cases and pass output for all passes that use promoteIfSingleIteration; loop-fusion test cases become much simpler - fix replaceAllMemrefUsesWith bug that was exposed by the above update - 'domInstFilter' could be one of the ops erased due to a memref replacement in it. 
- fix getConstantBoundOnDimSize bug: a division by the coefficient of the identifier was missing (the latter need not always be 1); add lbFloorDivisors output argument - rename getBoundingConstantSizeAndShape -> getConstantBoundingSizeAndShape PiperOrigin-RevId: 230405218 --- mlir/include/mlir/Analysis/AffineStructures.h | 18 +- mlir/include/mlir/Analysis/Utils.h | 8 +- mlir/include/mlir/Transforms/Utils.h | 33 ++- mlir/lib/Analysis/AffineAnalysis.cpp | 1 + mlir/lib/Analysis/AffineStructures.cpp | 24 +- mlir/lib/Analysis/Utils.cpp | 13 +- mlir/lib/Transforms/DmaGeneration.cpp | 12 +- mlir/lib/Transforms/LoopFusion.cpp | 57 ++-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 11 +- mlir/lib/Transforms/Utils/Utils.cpp | 28 +- mlir/test/Transforms/loop-fusion.mlir | 368 ++++++++++++++++---------- 11 files changed, 355 insertions(+), 218 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index c6963c5662f..f263ca13d79 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -547,18 +547,22 @@ public: /// Clears this list of constraints and copies other into it. void clearAndCopyFrom(const FlatAffineConstraints &other); - /// Returns the smallest known constant bound for the extent of the - /// specified identifier (pos^th), i.e., the smallest known constant that is - /// greater than or equal to 'exclusive upper bound' - 'lower bound' of the - /// identifier; returns None if it's not a constant. This method employs + /// Returns the smallest known constant bound for the extent of the specified + /// identifier (pos^th), i.e., the smallest known constant that is greater + /// than or equal to 'exclusive upper bound' - 'lower bound' of the + /// identifier. Returns None if it's not a constant. This method employs /// trivial (low complexity / cost) checks and detection. Symbolic identifiers /// are treated specially, i.e., it looks for constant differences between /// affine expressions involving only the symbolic identifiers. See comments - /// at function definition for examples. 'lb', if provided, is set to the - /// lower bound associated with the constant difference. + /// at function definition for examples. 'lb' and 'lbDivisor', if provided, + /// are used to express the lower bound associated with the constant + /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., + /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three + /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. Optional getConstantBoundOnDimSize(unsigned pos, - SmallVectorImpl *lb = nullptr) const; + SmallVectorImpl *lb = nullptr, + int64_t *lbDivisor = nullptr) const; /// Returns the constant lower bound for the pos^th identifier if there is /// one; None otherwise. diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 4e304067411..bfdf4d40b34 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -79,9 +79,10 @@ struct MemRefRegion { /// bounded by a known constant, None otherwise. The 'shape' vector is set to /// the corresponding dimension-wise bounds major to minor. We use int64_t /// instead of uint64_t since index types can be at most int64_t. 
- Optional getBoundingConstantSizeAndShape( + Optional getConstantBoundingSizeAndShape( SmallVectorImpl *shape = nullptr, - std::vector> *lbs = nullptr) const; + std::vector> *lbs = nullptr, + SmallVectorImpl *lbDivisors = nullptr) const; /// A wrapper around FlatAffineConstraints::getConstantBoundOnDimSize(). 'pos' /// corresponds to the position of the memref shape's dimension (major to @@ -89,7 +90,8 @@ struct MemRefRegion { //'cst'. Optional getConstantBoundOnDimSize(unsigned pos, - SmallVectorImpl *lb = nullptr) const { + SmallVectorImpl *lb = nullptr, + int64_t *lbDivisor = nullptr) const { assert(pos < getRank() && "invalid position"); return cst.getConstantBoundOnDimSize(pos, lb); } diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index ccf549f89ae..2a1505aecb7 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -40,16 +40,29 @@ class OperationInst; class Function; -/// Replace all uses of oldMemRef with newMemRef while optionally remapping the -/// old memref's indices using the supplied affine map and adding any additional -/// indices. Additional indices are added at the start. The new memref could be -/// of a different shape or rank. 'extraOperands' is an optional argument that -/// corresponds to additional operands (inputs) for indexRemap at the beginning -/// of its input list. An additional optional argument 'domInstFilter' restricts -/// the replacement to only those operations that are dominated by the former. -/// Returns true on success and false if the replacement is not possible -/// (whenever a memref is used as an operand in a non-deferencing scenario). See -/// comments at function definition for an example. +/// Replaces all uses of oldMemRef with newMemRef while optionally remapping the +/// old memref's indices using the supplied affine map, 'indexRemap'. The new +/// memref could be of a different shape or rank. 'extraIndices' provides +/// additional access indices to be added to the start. 'indexRemap' remaps +/// indices of the old memref access to a new set of indices that are used to +/// index the memref. Additional input operands to indexRemap can be optionally +/// provided, and they are added at the start of its input list. 'indexRemap' is +/// expected to have only dimensional inputs, and the number of its inputs equal +/// to extraOperands.size() plus rank of the memref. 'extraOperands' is an +/// optional argument that corresponds to additional operands (inputs) for +/// indexRemap at the beginning of its input list. An additional optional +/// argument 'domInstFilter' restricts the replacement to only those operations +/// that are dominated by the former. Returns true on success and false if the +/// replacement is not possible (whenever a memref is used as an operand in a +/// non-deferencing scenario). See comments at function definition for an +/// example. +// Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]: +// The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and +// index remap will perform (%i, %j) -> (%ii - %i, %j), i.e., indexRemap = (d0, +// d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any +// extra operands, note that 'indexRemap' would just be applied to existing +// indices (%i, %j). +// TODO(bondhugula): allow extraIndices to be added at any position. 
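// An illustrative call for the example above ('tMod2' and 'ii' stand for the
// SSA values of '%t mod 2' and '%ii'; the remap map would be built with
// getAffineMap):
//   replaceAllMemRefUsesWith(A, Abuf,
//                            /*extraIndices=*/{tMod2},
//                            /*indexRemap=*/(d0, d1, d2) -> (d0 - d1, d2),
//                            /*extraOperands=*/{ii});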
bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap::Null(), diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 6f0b72e3f09..0e49303c778 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -1531,6 +1531,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, affineMap = simplifyAffineMap(map.compose(exprsMap)); LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); + LLVM_DEBUG(dbgs() << "\n"); } /// Implements `map` and `operands` composition and simplification to support diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index ebb92e8577f..44daaf1459b 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1438,9 +1438,11 @@ void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { /// the coefficients of the symbolic identifiers and the constant coefficient. // Egs: 0 <= i <= 15, return 16. // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) -// i + s0 + 16 <= d0 <= i + s0 + 31, returns 16. +// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. +// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = +// ceil(s0 - 7 / 8) = floor(s0 / 8)). Optional FlatAffineConstraints::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb) const { + unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor) const { assert(pos < getNumDimIds() && "Invalid identifier position"); assert(getNumLocalIds() == 0); @@ -1463,6 +1465,9 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v : -atEq(eqRow, getNumDimIds() + c) / v; } + assert(lbFloorDivisor && + "both lb and divisor or none should be provided"); + *lbFloorDivisor = 1; } return 1; } @@ -1519,8 +1524,9 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( } if (j < getNumCols() - 1) continue; - int64_t diff = - atIneq(ubPos, getNumCols() - 1) + atIneq(lbPos, getNumCols() - 1) + 1; + int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); if (minDiff == None || diff < minDiff) { minDiff = diff; minLbPosition = lbPos; @@ -1530,8 +1536,16 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( if (lb && minDiff.hasValue()) { // Set lb to the symbolic lower bound. lb->resize(getNumSymbolIds() + 1); + // The lower bound is the ceildiv of the lb constraint over the coefficient + // of the variable at 'pos'. We express the ceildiv equivalently as a floor + // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + + // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). + *lbFloorDivisor = atIneq(minLbPosition, pos); for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { - (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c); + // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of + // 'atIneq(minLbPosition, pos) - 1'. 
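    // For example, for the bound 32*d0 - N + 31 >= 0 mentioned above:
    // ceildiv(N - 31, 32) = floordiv((N - 31) + 32 - 1, 32) = floordiv(N, 32).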
+ (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c) + + atIneq(minLbPosition, pos) - 1; } } return minDiff; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index c003a641311..592fad4ab29 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -55,9 +55,9 @@ unsigned MemRefRegion::getRank() const { return memref->getType().cast().getRank(); } -Optional MemRefRegion::getBoundingConstantSizeAndShape( - SmallVectorImpl *shape, - std::vector> *lbs) const { +Optional MemRefRegion::getConstantBoundingSizeAndShape( + SmallVectorImpl *shape, std::vector> *lbs, + SmallVectorImpl *lbDivisors) const { auto memRefType = memref->getType().cast(); unsigned rank = memRefType.getRank(); shape->reserve(rank); @@ -66,11 +66,13 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( // dimension. int64_t numElements = 1; int64_t diffConstant; + int64_t lbDivisor; for (unsigned d = 0; d < rank; d++) { SmallVector lb; - Optional diff = cst.getConstantBoundOnDimSize(d, &lb); + Optional diff = cst.getConstantBoundOnDimSize(d, &lb, &lbDivisor); if (diff.hasValue()) { diffConstant = diff.getValue(); + assert(lbDivisor > 0); } else { // If no constant bound is found, then it can always be bound by the // memref's dim size if the latter has a constant size along this dim. @@ -80,10 +82,13 @@ Optional MemRefRegion::getBoundingConstantSizeAndShape( diffConstant = dimSize; // Lower bound becomes 0. lb.resize(cst.getNumSymbolIds() + 1, 0); + lbDivisor = 1; } numElements *= diffConstant; if (lbs) { lbs->push_back(lb); + assert(lbDivisors && "both lbs and lbDivisor or none"); + lbDivisors->push_back(lbDivisor); } if (shape) { shape->push_back(diffConstant); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index df4aa84b039..91614a386e6 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -204,9 +204,10 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, // Compute the extents of the buffer. std::vector> lbs; + SmallVector lbDivisors; lbs.reserve(rank); - Optional numElements = - region.getBoundingConstantSizeAndShape(&fastBufferShape, &lbs); + Optional numElements = region.getConstantBoundingSizeAndShape( + &fastBufferShape, &lbs, &lbDivisors); if (!numElements.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); *sizeInBytes = 0; @@ -219,10 +220,11 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, return false; } + const FlatAffineConstraints *cst = region.getConstraints(); + // 'outerIVs' holds the values that this memory region is symbolic/paramteric // on; this would correspond to loop IVs surrounding the level at which the // DMA generation is being done. - const FlatAffineConstraints *cst = region.getConstraints(); SmallVector outerIVs; cst->getIdValues(rank, cst->getNumIds(), &outerIVs); @@ -241,7 +243,9 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { offset = offset + lbs[d][j] * top.getAffineDimExpr(j); } - offset = offset + lbs[d][cst->getNumCols() - 1 - rank]; + assert(lbDivisors[d] > 0); + offset = + (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); // Set DMA start location for this dimension in the lower memory space // memref. 
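The lower-bound change above can be checked with a small standalone sketch (not part of this patch), using plain integers: the identity ceildiv(a, d) = floordiv(a + d - 1, d) is what lets the patch fold the '+ divisor - 1' bias into the stored lower bound and have callers such as generateDma apply a single floorDiv by lbDivisor. The example reuses the constraint from the comment, 32*d0 - N + 31 >= 0, whose lower bound ceil((N - 31)/32) equals floor(N/32):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Floor division that also handles negative numerators.
int64_t floorDiv(int64_t a, int64_t b) {
  return a >= 0 ? a / b : -((-a + b - 1) / b);
}

// ceildiv(a, d) written as floordiv(a + d - 1, d), the identity the patch
// uses when it biases the stored lower bound by 'divisor - 1'.
int64_t ceilDivViaFloor(int64_t a, int64_t d) { return floorDiv(a + d - 1, d); }

int main() {
  // Constraint from the comment: 32*d0 - N + 31 >= 0, i.e. d0 >= (N - 31)/32.
  // The lower bound on d0 is ceil((N - 31)/32) = floor(N/32).
  for (int64_t N : {0, 1, 31, 32, 33, 100}) {
    assert(ceilDivViaFloor(N - 31, 32) == floorDiv(N, 32));
    printf("N = %lld -> lower bound floor(N/32) = %lld\n", (long long)N,
           (long long)floorDiv(N, 32));
  }
  return 0;
}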
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 5c80367cfa0..520b89ded48 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -681,9 +681,12 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, // Creates and returns a private (single-user) memref for fused loop rooted // at 'forInst', with (potentially reduced) memref size based on the -// MemRefRegion written to by 'srcStoreOpInst'. +// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. +// TODO(bondhugula): consider refactoring the common code from generateDma and +// this one. static Value *createPrivateMemRef(ForInst *forInst, - OperationInst *srcStoreOpInst) { + OperationInst *srcStoreOpInst, + unsigned dstLoopDepth) { // Create builder to insert alloc op just before 'forInst'. FuncBuilder b(forInst); // Builder to create constants at the top level. @@ -693,28 +696,39 @@ static Value *createPrivateMemRef(ForInst *forInst, auto oldMemRefType = oldMemRef->getType().cast(); unsigned rank = oldMemRefType.getRank(); - // Compute MemRefRegion for 'srcStoreOpInst'. + // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. MemRefRegion region; - getMemRefRegion(srcStoreOpInst, 0, ®ion); + getMemRefRegion(srcStoreOpInst, dstLoopDepth, ®ion); SmallVector newShape; std::vector> lbs; + SmallVector lbDivisors; lbs.reserve(rank); // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed - // by 'srcStoreOpInst'. + // by 'srcStoreOpInst' at depth 'dstLoopDepth'. Optional numElements = - region.getBoundingConstantSizeAndShape(&newShape, &lbs); + region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); assert(numElements.hasValue()); - // Build 'rank' AffineExprs from MemRefRegion 'lbs' const FlatAffineConstraints *cst = region.getConstraints(); + // 'outerIVs' holds the values that this memory region is symbolic/paramteric + // on; this would correspond to loop IVs surrounding the level at which the + // slice is being materialized. + SmallVector outerIVs; + cst->getIdValues(rank, cst->getNumIds(), &outerIVs); + + // Build 'rank' AffineExprs from MemRefRegion 'lbs' SmallVector offsets; offsets.reserve(rank); for (unsigned d = 0; d < rank; ++d) { + assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); + AffineExpr offset = top.getAffineConstantExpr(0); for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { offset = offset + lbs[d][j] * top.getAffineDimExpr(j); } - offset = offset + lbs[d][cst->getNumCols() - 1 - rank]; + assert(lbDivisors[d] > 0); + offset = + (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); offsets.push_back(offset); } @@ -743,18 +757,23 @@ static Value *createPrivateMemRef(ForInst *forInst, if (auto constExpr = offsets[i].dyn_cast()) if (constExpr.getValue() == 0) ++zeroOffsetCount; - auto dimExpr = b.getAffineDimExpr(i); - remapExprs.push_back(dimExpr - offsets[i]); + auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); + + auto remapExpr = + simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); + remapExprs.push_back(remapExpr); } - auto indexRemap = zeroOffsetCount == rank - ? AffineMap::Null() - : b.getAffineMap(rank, 0, remapExprs, {}); + auto indexRemap = + zeroOffsetCount == rank + ? AffineMap::Null() + : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); // Replace all users of 'oldMemRef' with 'newMemRef'. 
- bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {}, - &*forInst->getBody()->begin()); - assert(ret); + bool ret = + replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, + /*extraOperands=*/outerIVs, + /*domInstFilter=*/&*forInst->getBody()->begin()); + assert(ret && "replaceAllMemrefUsesWith should always succeed here"); (void)ret; - (void)indexRemap; return newMemRef; } @@ -1044,8 +1063,8 @@ public: storesForMemref.push_back(storeOpInst); } assert(storesForMemref.size() == 1); - auto *newMemRef = - createPrivateMemRef(dstForInst, storesForMemref[0]); + auto *newMemRef = createPrivateMemRef( + dstForInst, storesForMemref[0], dstLoopDepth); visitedMemrefs.insert(newMemRef); // Collect dst loop stats after memref privatizaton transformation. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4168dda064a..9d928fc0709 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -109,9 +109,14 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { const AffineBound lb = forInst->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); - auto affineApplyOp = builder.create( - forInst->getLoc(), lb.getMap(), lbOperands); - forInst->replaceAllUsesWith(affineApplyOp->getResult(0)); + if (lb.getMap() == builder.getDimIdentityMap()) { + // No need of generating an affine_apply. + forInst->replaceAllUsesWith(lbOperands[0]); + } else { + auto affineApplyOp = builder.create( + forInst->getLoc(), lb.getMap(), lbOperands); + forInst->replaceAllUsesWith(affineApplyOp->getResult(0)); + } } } // Move the loop body instructions to the loop's containing block. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 4f4aeabb26d..136cf20c12a 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -43,23 +43,6 @@ static bool isMemRefDereferencingOp(const OperationInst &op) { return false; } -/// Replaces all uses of oldMemRef with newMemRef while optionally remapping -/// old memref's indices to the new memref using the supplied affine map -/// and adding any additional indices. The new memref could be of a different -/// shape or rank, but of the same elemental type. Additional indices are added -/// at the start. 'extraOperands' is another optional argument that corresponds -/// to additional operands (inputs) for indexRemap at the beginning of its input -/// list. An optional argument 'domOpFilter' restricts the replacement to only -/// those operations that are dominated by the former. The replacement succeeds -/// and returns true if all uses of the memref in the region where the -/// replacement is asked for are "dereferencing" memref uses. -// Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]: -// The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and -// index remap will (%i, %j) -> (%ii - %i, %j), i.e., (d0, d1, d2) -> (d0 - d1, -// d2) will be the 'indexRemap', and %ii is the extra operand. Without any -// extra operands, note that 'indexRemap' would just be applied to the existing -// indices (%i, %j). 
-// bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, @@ -84,6 +67,9 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, if (domInstFilter) domInfo = std::make_unique(domInstFilter->getFunction()); + // The ops where memref replacement succeeds are replaced with new ones. + SmallVector opsToErase; + // Walk all uses of old memref. Operation using the memref gets replaced. for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { InstOperand &use = *(it++); @@ -171,8 +157,14 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, for (auto *res : opInst->getResults()) { res->replaceAllUsesWith(repOp->getResult(r++)); } - opInst->erase(); + // Collect and erase at the end since one of these op's could be + // domInstFilter! + opsToErase.push_back(opInst); } + + for (auto *opInst : opsToErase) + opInst->erase(); + return true; } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 12e0a8a456a..8e5b706835e 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -9,7 +9,7 @@ // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_raw_dep_for_locality() { func @should_fuse_raw_dep_for_locality() { @@ -23,9 +23,10 @@ func @should_fuse_raw_dep_for_locality() { %v0 = load %m[%i1] : memref<10xf32> } // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> - // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -33,7 +34,7 @@ func @should_fuse_raw_dep_for_locality() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_reduction_to_pointwise() { func @should_fuse_reduction_to_pointwise() { @@ -59,15 +60,17 @@ func @should_fuse_reduction_to_pointwise() { // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. 
// CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %4 = load %2[%3] : memref<10xf32> - // CHECK-NEXT: %5 = load %0[%3, %i1] : memref<10x10xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %4 = load %2[%3] : memref<1xf32> + // CHECK-NEXT: %5 = load %0[%i0, %i1] : memref<10x10xf32> // CHECK-NEXT: %6 = addf %4, %5 : f32 - // CHECK-NEXT: store %6, %2[%3] : memref<10xf32> + // CHECK-NEXT: %7 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %6, %2[%7] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: %7 = load %2[%i0] : memref<10xf32> - // CHECK-NEXT: store %7, %1[%i0] : memref<10xf32> + // CHECK-NEXT: %8 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %9 = load %2[%8] : memref<1xf32> + // CHECK-NEXT: store %9, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -75,9 +78,9 @@ func @should_fuse_reduction_to_pointwise() { // ----- -// CHECK: [[MAP_SHIFT_MINUS_ONE_R1:#map[0-9]+]] = (d0) -> (d0 - 1) -// CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) -// CHECK: [[MAP_SHIFT_MINUS_ONE_R2:#map[0-9]+]] = (d0, d1) -> (d0 - 1, d1 - 1) +// CHECK-DAG: [[MAP_SHIFT_MINUS_ONE_R1:#map[0-9]+]] = (d0) -> (d0 - 1) +// CHECK-DAG: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) +// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2, -d1 + d3) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { func @should_fuse_loop_nests_with_shifts() { @@ -104,16 +107,16 @@ func @should_fuse_loop_nests_with_shifts() { // *) Fifth affine apply shifts the loads access function by '-1', because // of the offset induced by reducing the memref shape from 10x10 to 9x9. // NOTE: Should create a private memref with reduced shape 9x9xf32. - // CHECK: %0 = alloc() : memref<9x9xf32> + // CHECK: %0 = alloc() : memref<1x1xf32> // CHECK-NEXT: for %i0 = 1 to 10 { // CHECK-NEXT: for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) - // CHECK-NEXT: %4 = affine_apply [[MAP_SHIFT_MINUS_ONE_R2]](%3#0, %3#1) - // CHECK-NEXT: store %cst, %0[%4#0, %4#1] : memref<9x9xf32> - // CHECK-NEXT: %5 = affine_apply [[MAP_SHIFT_MINUS_ONE_R2]](%i0, %i1) - // CHECK-NEXT: %6 = load %0[%5#0, %5#1] : memref<9x9xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP_SHIFT_MINUS_IV_R2]](%i0, %i1, %3#0, %3#1) + // CHECK-NEXT: store %cst, %0[%4#0, %4#1] : memref<1x1xf32> + // CHECK-NEXT: %5 = affine_apply [[MAP_SHIFT_MINUS_IV_R2]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %6 = load %0[%5#0, %5#1] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -122,7 +125,7 @@ func @should_fuse_loop_nests_with_shifts() { // ----- -// CHECK-DAG: [[MAP_ID:#map[0-9]+]] = (d0) -> (d0) +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2, -d1 + d3) // CHECK-LABEL: func @should_fuse_loop_nest() { func @should_fuse_loop_nest() { @@ -147,18 +150,18 @@ func @should_fuse_loop_nest() { } } // Expecting private memref for '%b' first, then private memref for '%a'. 
- // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<10x10xf32> - // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<10x10xf32> + // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> + // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP_ID]](%i1) - // CHECK-NEXT: %3 = affine_apply [[MAP_ID]](%i0) - // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<10x10xf32> - // CHECK-NEXT: %4 = affine_apply [[MAP_ID]](%i0) - // CHECK-NEXT: %5 = affine_apply [[MAP_ID]](%i1) - // CHECK-NEXT: %6 = load [[NEWA]][%5, %4] : memref<10x10xf32> - // CHECK-NEXT: store %6, [[NEWB]][%4, %5] : memref<10x10xf32> - // CHECK-NEXT: %7 = load [[NEWB]][%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: store %cst, [[NEWA]][%2#0, %2#1] : memref<1x1xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: %4 = load [[NEWA]][%3#0, %3#1] : memref<1x1xf32> + // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: store %4, [[NEWB]][%5#0, %5#1] : memref<1x1xf32> + // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %7 = load [[NEWB]][%6#0, %6#1] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -167,7 +170,7 @@ func @should_fuse_loop_nest() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_across_intermediate_loop_with_no_deps() { func @should_fuse_across_intermediate_loop_with_no_deps() { @@ -193,12 +196,13 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %2 = alloc() : memref<10xf32> + // CHECK: %2 = alloc() : memref<1xf32> // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1) - // CHECK-NEXT: %4 = load %0[%3] : memref<10xf32> - // CHECK-NEXT: store %4, %2[%3] : memref<10xf32> - // CHECK-NEXT: %5 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: %3 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %3, %2[%4] : memref<1xf32> + // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %6 = load %2[%5] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -206,7 +210,7 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_all_loops() { func @should_fuse_all_loops() { @@ -228,15 +232,17 @@ func @should_fuse_all_loops() { // Should fuse first and second loops into third. // Expecting private memref for '%b' first, then private memref for '%a'. 
- // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<10xf32> - // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<10xf32> + // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> + // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> // CHECK-NEXT: for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<10xf32> - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, [[NEWB]][%3] : memref<10xf32> - // CHECK-NEXT: %4 = load [[NEWA]][%i0] : memref<10xf32> - // CHECK-NEXT: %5 = load [[NEWB]][%i0] : memref<10xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, [[NEWB]][%3] : memref<1xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %5 = load [[NEWA]][%4] : memref<1xf32> + // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %7 = load [[NEWB]][%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -244,7 +250,7 @@ func @should_fuse_all_loops() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_first_and_second_loops() { func @should_fuse_first_and_second_loops() { @@ -267,15 +273,16 @@ func @should_fuse_first_and_second_loops() { // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. - // CHECK: %2 = alloc() : memref<10xf32> + // CHECK: %2 = alloc() : memref<1xf32> // CHECK-NEXT: for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: store %cst, %2[%3] : memref<10xf32> - // CHECK-NEXT: %4 = load %2[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, %2[%3] : memref<1xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %5 = load %2[%4] : memref<1xf32> // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %5 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: %6 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -326,7 +333,7 @@ func @should_not_fuse_would_create_cycle() { } // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_across_waw_dep_with_private_memref() { func @should_fuse_across_waw_dep_with_private_memref() { @@ -349,11 +356,12 @@ func @should_fuse_across_waw_dep_with_private_memref() { // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %1 = alloc() : memref<10xf32> + // CHECK: %1 = alloc() : memref<1xf32> // CHECK-NEXT: for %i2 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply #map0(%i2) - // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> - // CHECK-NEXT: %3 = load %1[%i2] : memref<10xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i2, %i2) + // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i2, %i2) + // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -412,18 +420,19 @@ func @should_fuse_with_private_memref_if_top_level_access() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %1 = alloc() : memref<10xf32> + // CHECK: %1 = alloc() : memref<1xf32> // CHECK-NEXT: for 
%i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply #map0(%i1) - // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> - // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } return } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_no_top_level_access() { func @should_fuse_no_top_level_access() { @@ -437,9 +446,10 @@ func @should_fuse_no_top_level_access() { %v0 = load %m[%i1] : memref<10xf32> } // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply #map0(%i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> - // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -506,10 +516,12 @@ func @should_not_fuse_if_inst_in_loop_nest() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d0 + d3, -d1 + d4, -d2 + d5) // CHECK: [[MAP_PERMUTE:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) +#map0 = (d0, d1, d2) -> (d0, d1, d2) + // CHECK-LABEL: func @remap_ivs() { func @remap_ivs() { %m = alloc() : memref<10x20x30xf32> @@ -534,13 +546,12 @@ func @remap_ivs() { // CHECK: for %i0 = 0 to 30 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: for %i2 = 0 to 20 { -// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1) -// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2) -// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0) -// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3) -// CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32> -// CHECK-NEXT: %5 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2) -// CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32> +// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i2, %i0) +// CHECK-NEXT: %2 = affine_apply [[MAP1]](%i1, %i2, %i0, %1#0, %1#1, %1#2) +// CHECK-NEXT: store %cst, %0[%2#0, %2#1, %2#2] : memref<1x1x1xf32> +// CHECK-NEXT: %3 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2) +// CHECK-NEXT: %4 = affine_apply [[MAP1]](%i1, %i2, %i0, %3#0, %3#1, %3#2) +// CHECK-NEXT: %5 = load %0[%4#0, %4#1, %4#2] : memref<1x1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -581,8 +592,10 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) { } // ----- -// CHECK: #map0 = (d0) -> (d0 floordiv 4) -// CHECK: #map1 = (d0) -> (d0 mod 4) +// CHECK-DAG: #map0 = (d0) -> (d0 floordiv 4) +// CHECK-DAG: #map1 = (d0) -> (d0 mod 4) +// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1) +// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // Reshape a 16x4xf32 to 64xf32. 
// CHECK-LABEL: func @fuse_reshape_16_4_64 @@ -606,10 +619,12 @@ func @fuse_reshape_16_4_64() { // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: %3 = affine_apply #map1(%i0) // CHECK-NEXT: %4 = load %0[%2, %3] : memref<16x4xf32> -// CHECK-NEXT: %5 = affine_apply #map2(%2, %3) -// CHECK-NEXT: store %4, %1[%5] : memref<64xf32> -// CHECK-NEXT: %6 = load %1[%i0] : memref<64xf32> -// CHECK-NEXT: "foo"(%6) : (f32) -> () +// CHECK-NEXT: %5 = affine_apply [[MAP2]](%2, %3) +// CHECK-NEXT: %6 = affine_apply [[MAP3]](%i0, %5) +// CHECK-NEXT: store %4, %1[%6] : memref<1xf32> +// CHECK-NEXT: %7 = affine_apply [[MAP3]](%i0, %i0) +// CHECK-NEXT: %8 = load %1[%7] : memref<1xf32> +// CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: return return @@ -674,8 +689,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // is eliminated if -memref-dataflow-opt is also supplied. // // CHECK: %0 = alloc() : memref<64x9xi32> -// CHECK-NEXT: %1 = alloc() : memref<64x9xi32> -// CHECK-NEXT: %2 = alloc() : memref<2x2x3x3x16x1xi32> +// CHECK-NEXT: %1 = alloc() : memref<1x1xi32> +// CHECK-NEXT: %2 = alloc() : memref<1x2x3x3x16x1xi32> // CHECK-NEXT: for %i0 = 0 to 64 { // CHECK-NEXT: for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i1) @@ -684,16 +699,18 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %6 = affine_apply #map3(%i0, %i1) // CHECK-NEXT: %7 = affine_apply #map4(%i0, %i1) // CHECK-NEXT: %8 = "foo"(%3, %4, %5, %6, %7, %c0) : (index, index, index, index, index, index) -> i32 -// CHECK-NEXT: store %8, %2[%3, %4, %5, %6, %7, %c0] : memref<2x2x3x3x16x1xi32> -// CHECK-NEXT: %9 = affine_apply #map5(%i0) -// CHECK-NEXT: %10 = affine_apply #map5(%i1) -// CHECK-NEXT: %11 = affine_apply #map6(%9, %10) -// CHECK-NEXT: %12 = affine_apply #map7(%11) -// CHECK-NEXT: %13 = load %2[%12#0, %12#1, %12#2, %12#3, %12#4, %12#5] : memref<2x2x3x3x16x1xi32> -// CHECK-NEXT: store %13, %1[%9, %10] : memref<64x9xi32> -// CHECK-NEXT: %14 = load %1[%i0, %i1] : memref<64x9xi32> -// CHECK-NEXT: %15 = muli %14, %14 : i32 -// CHECK-NEXT: store %15, %0[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: %9 = affine_apply #map5(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: store %8, %2[%9#0, %9#1, %9#2, %9#3, %9#4, %9#5] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %10 = affine_apply #map6(%i0, %i1) +// CHECK-NEXT: %11 = affine_apply #map7(%10) +// CHECK-NEXT: %12 = affine_apply #map5(%i0, %i1, %11#0, %11#1, %11#2, %11#3, %11#4, %11#5) +// CHECK-NEXT: %13 = load %2[%12#0, %12#1, %12#2, %12#3, %12#4, %12#5] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %14 = affine_apply #map8(%i0, %i1, %i0, %i1) +// CHECK-NEXT: store %13, %1[%14#0, %14#1] : memref<1x1xi32> +// CHECK-NEXT: %15 = affine_apply #map8(%i0, %i1, %i0, %i1) +// CHECK-NEXT: %16 = load %1[%15#0, %15#1] : memref<1x1xi32> +// CHECK-NEXT: %17 = muli %16, %16 : i32 +// CHECK-NEXT: store %17, %0[%i0, %i1] : memref<64x9xi32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<64x9xi32> @@ -725,7 +742,7 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { } // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK-DAG: #map0 = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_reduction_at_depth1 func @should_fuse_reduction_at_depth1() { @@ -753,18 +770,21 @@ func @should_fuse_reduction_at_depth1() { // decrease the reduction memref size and possibly place it in a faster // memory space. 
// CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: for %i1 = 0 to 100 { - // CHECK-NEXT: %3 = load %1[%2] : memref<10xf32> - // CHECK-NEXT: %4 = load %0[%2, %i1] : memref<10x100xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %3 = load %1[%2] : memref<1xf32> + // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x100xf32> // CHECK-NEXT: %5 = "maxf"(%3, %4) : (f32, f32) -> f32 - // CHECK-NEXT: store %5, %1[%2] : memref<10xf32> + // CHECK-NEXT: %6 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %5, %1[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 100 { - // CHECK-NEXT: %6 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: %7 = load %0[%i0, %i2] : memref<10x100xf32> - // CHECK-NEXT: %8 = subf %7, %6 : f32 - // CHECK-NEXT: store %8, %1[%i0] : memref<10xf32> + // CHECK-NEXT: %7 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %8 = load %1[%7] : memref<1xf32> + // CHECK-NEXT: %9 = load %0[%i0, %i2] : memref<10x100xf32> + // CHECK-NEXT: %10 = subf %9, %8 : f32 + // CHECK-NEXT: %11 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %10, %1[%11] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -772,7 +792,7 @@ func @should_fuse_reduction_at_depth1() { } // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1, d2) // CHECK-LABEL: func @should_fuse_at_src_depth1_and_dst_depth1 func @should_fuse_at_src_depth1_and_dst_depth1() { @@ -802,18 +822,19 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // the fusion algorithm should detect that the source loop should be sliced // at depth 1 and the slice should be inserted at depth 1. // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: for %i1 = 0 to 16 { - // CHECK-NEXT: %3 = load %0[%2, %i1] : memref<100x16xf32> - // CHECK-NEXT: "op0"(%3) : (f32) -> () + // CHECK-NEXT: %2 = load %0[%i0, %i1] : memref<100x16xf32> + // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { - // CHECK-NEXT: %4 = "op1"() : () -> f32 - // CHECK-NEXT: store %4, %1[%2, %i2] : memref<100x16xf32> + // CHECK-NEXT: %3 = "op1"() : () -> f32 + // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0, %i2) + // CHECK-NEXT: store %3, %1[%4#0, %4#1] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { - // CHECK-NEXT: %5 = load %1[%i0, %i3] : memref<100x16xf32> - // CHECK-NEXT: "op2"(%5) : (f32) -> () + // CHECK-NEXT: %5 = affine_apply #map0(%i0, %i0, %i3) + // CHECK-NEXT: %6 = load %1[%5#0, %5#1] : memref<1x16xf32> + // CHECK-NEXT: "op2"(%6) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -822,6 +843,7 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // ----- // CHECK: #map0 = (d0, d1) -> (d0 * 10 + d1) +// CHECK: #map1 = (d0, d1, d2) -> (d0 * -10 - d1 + d2) // CHECK-LABEL: func @should_fuse_src_depth1_at_dst_depth2 func @should_fuse_src_depth1_at_dst_depth2() { @@ -843,9 +865,11 @@ func @should_fuse_src_depth1_at_dst_depth2() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) - // CHECK-NEXT: store %cst, %0[%1] : memref<100xf32> - // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i1) - // CHECK-NEXT: %3 = load %0[%2] : memref<100xf32> + // CHECK-NEXT: %2 = affine_apply #map1(%i0, %i1, %1) + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i1) + // CHECK-NEXT: %4 = affine_apply #map1(%i0, %i1, %3) + // CHECK-NEXT: %5 = load 
%0[%4] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -879,7 +903,8 @@ func @fusion_at_depth0_not_currently_supported() { } // ----- -// CHECK: #map0 = (d0) -> (d0) + +// CHECK-DAG: #map0 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d0 + d4, -d1 + d5, -d2 + d6, -d3 + d7, d8, d9) // CHECK-LABEL: func @should_fuse_deep_loop_nests func @should_fuse_deep_loop_nests() { @@ -945,18 +970,15 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: for %i2 = 0 to 2 { // CHECK-NEXT: for %i3 = 0 to 2 { -// CHECK-NEXT: %3 = affine_apply #map0(%i2) -// CHECK-NEXT: %4 = affine_apply #map0(%i3) -// CHECK-NEXT: %5 = affine_apply #map0(%i0) -// CHECK-NEXT: %6 = affine_apply #map0(%i1) // CHECK-NEXT: for %i4 = 0 to 16 { // CHECK-NEXT: for %i5 = 0 to 10 { -// CHECK-NEXT: %7 = load %0[%3, %4, %5, %6, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i6 = 0 to 16 { // CHECK-NEXT: for %i7 = 0 to 10 { -// CHECK-NEXT: store %cst, %2[%3, %4, %5, %6, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i6, %i7) +// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i8 = 0 to 3 { @@ -965,12 +987,13 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i11 = 0 to 2 { // CHECK-NEXT: for %i12 = 0 to 16 { // CHECK-NEXT: for %i13 = 0 to 10 { -// CHECK-NEXT: %8 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %5 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { // CHECK-NEXT: for %i15 = 0 to 10 { -// CHECK-NEXT: %9 = load %2[%i2, %i3, %i0, %i1, %i14, %i15] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %6 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %7 = load %2[%6#0, %6#1, %6#2, %6#3, %6#4, %6#5] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -986,7 +1009,7 @@ func @should_fuse_deep_loop_nests() { } // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1, d2) // CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count func @should_fuse_at_depth1_and_reduce_slice_trip_count() { @@ -1019,17 +1042,18 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // NOTE: the size of the private memref created for the fused loop nest // is reduced from the original shape from 4x256 to 4x16 because of the // data accessed by the load. 
- // CHECK: %1 = alloc() : memref<4x16xf32> + // CHECK: %1 = alloc() : memref<1x16xf32> // CHECK-NEXT: for %i0 = 0 to 4 { - // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: for %i1 = 0 to 256 { - // CHECK-NEXT: %3 = load %0[%2, %i1] : memref<4x256xf32> + // CHECK-NEXT: %2 = load %0[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { - // CHECK-NEXT: store %cst, %1[%2, %i2] : memref<4x16xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i0, %i2) + // CHECK-NEXT: store %cst, %1[%3#0, %3#1] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { - // CHECK-NEXT: %4 = load %1[%i0, %i3] : memref<4x16xf32> + // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0, %i3) + // CHECK-NEXT: %5 = load %1[%4#0, %4#1] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1120,7 +1144,7 @@ func @should_fuse_at_depth1_with_trip_count_19() { // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_with_private_memrefs_with_diff_shapes() { func @should_fuse_with_private_memrefs_with_diff_shapes() { @@ -1138,17 +1162,19 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. - // CHECK: %0 = alloc() : memref<17xf32> + // CHECK: %0 = alloc() : memref<1xf32> // CHECK-NEXT: for %i0 = 0 to 17 { - // CHECK-NEXT: %1 = affine_apply #map0(%i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<17xf32> - // CHECK-NEXT: %2 = load %0[%i0] : memref<17xf32> + // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: %3 = alloc() : memref<82xf32> + // CHECK-NEXT: %4 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 82 { - // CHECK-NEXT: %4 = affine_apply #map0(%i1) - // CHECK-NEXT: store %cst, %3[%4] : memref<82xf32> - // CHECK-NEXT: %5 = load %3[%i1] : memref<82xf32> + // CHECK-NEXT: %5 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: store %cst, %4[%5] : memref<1xf32> + // CHECK-NEXT: %6 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: %7 = load %4[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1156,7 +1182,7 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { @@ -1173,11 +1199,12 @@ func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: %0 = alloc() : memref<10xf32> + // CHECK-NEXT: %0 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply #map0(%i1) - // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> - // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1185,7 +1212,7 @@ func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { // ----- -// CHECK: #map0 = (d0) -> (d0) +// CHECK: 
[[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @fusion_should_not_remove_escaping_memref() func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { @@ -1202,12 +1229,63 @@ func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: %1 = alloc() : memref<10xf32> + // CHECK-NEXT: %1 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply #map0(%i1) - // CHECK-NEXT: store %cst, %1[%2] : memref<10xf32> - // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> return %m : memref<10xf32> } + +// ----- + +// This should fuse with the %in becoming a 1x1x1. +func @R3_to_R2_reshape() { + %in = alloc() : memref<2x3x16xi32> + + %c0 = constant 0 : index + + for %i0 = 0 to 2 { + for %i1 = 0 to 3 { + for %i2 = 0 to 16 { + %val = "foo"(%i0, %i1, %i2) : (index, index, index) -> i32 + store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> + } + } + } + + for %ii = 0 to 32 { + for %jj = 0 to 3 { + %a0 = affine_apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) + %a1 = affine_apply (d0) -> (d0 floordiv (3 * 16)) (%a0) + %v = load %in[%a1#0, %jj, %c0] + : memref<2x3x16xi32> + } + } + return +} +// CHECK: #map0 = (d0, d1) -> ((d0 * 3 + d1) floordiv 48) +// CHECK-NEXT: #map1 = ()[s0] -> (s0) +// CHECK-NEXT: #map2 = (d0, d1, d2, d3, d4) -> (d2 - (d0 * 25 + d1 * 24) floordiv 24, -d1 + d3, d4) +// CHECK-NEXT: #map3 = (d0, d1) -> (d0 * 3 + d1) +// CHECK-NEXT: #map4 = (d0) -> (d0 floordiv 48) +// CHECK-LABEL: func @R3_to_R2_reshape() +// CHECK: %0 = alloc() : memref<1x1x1xi32> +// CHECK-NEXT: for %i0 = 0 to 32 { +// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) +// CHECK-NEXT: %2 = affine_apply #map1()[%c0] +// CHECK-NEXT: %3 = "foo"(%1, %i1, %2) : (index, index, index) -> i32 +// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1, %1, %i1, %2) +// CHECK-NEXT: store %3, %0[%4#0, %4#1, %4#2] : memref<1x1x1xi32> +// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1) +// CHECK-NEXT: %6 = affine_apply #map4(%5) +// CHECK-NEXT: %7 = affine_apply #map2(%i0, %i1, %6, %i1, %c0) +// CHECK-NEXT: %8 = load %0[%7#0, %7#1, %7#2] : memref<1x1x1xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } -- cgit v1.2.3 From 864d9e02a17f02fb396e5ea047623ed4207153f1 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 23 Jan 2019 09:16:24 -0800 Subject: Update fusion cost model + some additional infrastructure and debug information for -loop-fusion - update fusion cost model to fuse while tolerating a certain amount of redundant computation; add cl option -fusion-compute-tolerance evaluate memory footprint and intermediate memory reduction - emit debug info from -loop-fusion showing what was fused and why - introduce function to compute memory footprint for a loop nest - getMemRefRegion readability update - NFC PiperOrigin-RevId: 230541857 --- mlir/include/mlir/Analysis/Utils.h | 11 +- mlir/lib/Analysis/AffineStructures.cpp | 11 +- mlir/lib/Analysis/Utils.cpp | 107 ++++++++++++-- mlir/lib/Transforms/DmaGeneration.cpp | 3 - mlir/lib/Transforms/LoopFusion.cpp | 245 +++++++++++++++++++++++++++------ mlir/test/Transforms/loop-fusion.mlir | 29 ++-- 
6 files changed, 328 insertions(+), 78 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index bfdf4d40b34..8c8f73da409 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -76,9 +76,10 @@ struct MemRefRegion { void setWrite(bool flag) { write = flag; } /// Returns a constant upper bound on the number of elements in this region if - /// bounded by a known constant, None otherwise. The 'shape' vector is set to - /// the corresponding dimension-wise bounds major to minor. We use int64_t - /// instead of uint64_t since index types can be at most int64_t. + /// bounded by a known constant (always possible for static shapes), None + /// otherwise. The 'shape' vector is set to the corresponding dimension-wise + /// bounds major to minor. We use int64_t instead of uint64_t since index + /// types can be at most int64_t. Optional getConstantBoundingSizeAndShape( SmallVectorImpl *shape = nullptr, std::vector> *lbs = nullptr, @@ -192,6 +193,10 @@ ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState); + +Optional getMemoryFootprintBytes(const ForInst &forInst, + int memorySpace = -1); + } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 44daaf1459b..baab283ab25 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -395,7 +395,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { FlatAffineConstraints cst; if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { LLVM_DEBUG(llvm::dbgs() - << "composition unimplemented for semi-affine maps"); + << "composition unimplemented for semi-affine maps\n"); return false; } assert(flatExprs.size() == vMap->getNumResults()); @@ -823,6 +823,9 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, if (posStart >= posLimit) return 0; + LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", " + << posLimit << ")\n"); + GCDTightenInequalities(); unsigned pivotCol = 0; @@ -1749,6 +1752,9 @@ getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { return {newNumDims, newNumSymbols}; } +#undef DEBUG_TYPE +#define DEBUG_TYPE "fm" + /// Eliminates identifier at the specified position using Fourier-Motzkin /// variable elimination. This technique is exact for rational spaces but /// conservative (in "rare" cases) for integer spaces. The operation corresponds @@ -1951,6 +1957,9 @@ void FlatAffineConstraints::FourierMotzkinEliminate( LLVM_DEBUG(dump()); } +#undef DEBUG_TYPE +#define DEBUG_TYPE "affine-structures" + void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { if (num == 0) return; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 592fad4ab29..79d1b696612 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -60,7 +60,8 @@ Optional MemRefRegion::getConstantBoundingSizeAndShape( SmallVectorImpl *lbDivisors) const { auto memRefType = memref->getType().cast(); unsigned rank = memRefType.getRank(); - shape->reserve(rank); + if (shape) + shape->reserve(rank); // Find a constant upper bound on the extent of this memref region along each // dimension. 
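A small standalone sketch (not part of this patch) of the size arithmetic this bounding-shape computation feeds, assuming the per-dimension extents are already known constants; it mirrors the getMemRefEltSizeInBytes/getRegionSize helpers added further below, where the element count is the product of the dimension-wise bounds and the byte footprint multiplies that by the element size rounded up to whole bytes:

#include <cstdint>
#include <cstdio>
#include <vector>

// Element count of a region given dimension-wise constant extents
// (major to minor), e.g. {1, 16} for a 1x16 bounding shape.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t extent : shape)
    n *= extent;
  return n;
}

// Byte footprint: element bit width rounded up to whole bytes, times the
// element count.
int64_t footprintBytes(const std::vector<int64_t> &shape,
                       unsigned elementBitWidth) {
  int64_t eltBytes = (elementBitWidth + 7) / 8;
  return eltBytes * numElements(shape);
}

int main() {
  // Hypothetical region bounded by a 1x16 shape of f32: 16 elements, 64 bytes.
  printf("%lld elements, %lld bytes\n", (long long)numElements({1, 16}),
         (long long)footprintBytes({1, 16}, /*elementBitWidth=*/32));
  return 0;
}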
@@ -189,6 +190,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // Add access function equalities to connect loop IVs to data dimensions. if (!regionCst->composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); + LLVM_DEBUG(accessValueMap.getAffineMap().dump()); return false; } @@ -207,14 +209,13 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, } // Project out any local variables (these would have been added for any // mod/divs). - regionCst->projectOut(regionCst->getNumDimIds() + - regionCst->getNumSymbolIds(), + regionCst->projectOut(regionCst->getNumDimAndSymbolIds(), regionCst->getNumLocalIds()); // Set all identifiers appearing after the first 'rank' identifiers as // symbolic identifiers - so that the ones correspoding to the memref // dimensions are the dimensional identifiers for the memref region. - regionCst->setDimSymbolSeparation(regionCst->getNumIds() - rank); + regionCst->setDimSymbolSeparation(regionCst->getNumDimAndSymbolIds() - rank); // Constant fold any symbolic identifiers. regionCst->constantFoldIdRange(/*pos=*/regionCst->getNumDimIds(), @@ -222,12 +223,31 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, assert(regionCst->getNumDimIds() == rank && "unexpected MemRefRegion format"); + LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); + LLVM_DEBUG(region->getConstraints()->dump()); + return true; } +// TODO(mlir-team): improve/complete this when we have target data. +static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. If the element of the memref has vector type, takes into account /// size of the vector as well. +// TODO(mlir-team): improve/complete this when we have target data. Optional mlir::getMemRefSizeInBytes(MemRefType memRefType) { if (memRefType.getNumDynamicDims() > 0) return None; @@ -235,18 +255,11 @@ Optional mlir::getMemRefSizeInBytes(MemRefType memRefType) { if (!elementType.isIntOrFloat() && !elementType.isa()) return None; - uint64_t sizeInBits; - if (elementType.isIntOrFloat()) { - sizeInBits = elementType.getIntOrFloatBitWidth(); - } else { - auto vectorType = elementType.cast(); - sizeInBits = - vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - } + unsigned sizeInBytes = getMemRefEltSizeInBytes(memRefType); for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) { - sizeInBits = sizeInBits * memRefType.getDimSize(i); + sizeInBytes = sizeInBytes * memRefType.getDimSize(i); } - return llvm::divideCeil(sizeInBits, 8); + return sizeInBytes; } template @@ -525,3 +538,69 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, } return numCommonLoops; } + +// Returns the size of the region. 
+static Optional getRegionSize(const MemRefRegion ®ion) { + auto *memref = region.memref; + auto memRefType = memref->getType().cast(); + + auto layoutMaps = memRefType.getAffineMaps(); + if (layoutMaps.size() > 1 || + (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + return false; + } + + // Indices to use for the DmaStart op. + // Indices for the original memref being DMAed from/to. + SmallVector memIndices; + // Indices for the faster buffer being DMAed into/from. + SmallVector bufIndices; + + // Compute the extents of the buffer. + Optional numElements = region.getConstantBoundingSizeAndShape(); + if (!numElements.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + return None; + } + return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); +} + +Optional mlir::getMemoryFootprintBytes(const ForInst &forInst, + int memorySpace) { + std::vector> regions; + + // Walk this 'for' instruction to gather all memory regions. + bool error = false; + const_cast(&forInst)->walkOps([&](OperationInst *opInst) { + if (!opInst->isa() && !opInst->isa()) { + // Neither load nor a store op. + return; + } + + // TODO(bondhugula): eventually, we need to be performing a union across + // all regions for a given memref instead of creating one region per + // memory op. This way we would be allocating O(num of memref's) sets + // instead of O(num of load/store op's). + auto region = std::make_unique(); + if (!getMemRefRegion(opInst, 0, region.get())) { + LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); + // TODO: stop the walk if an error occurred. + error = true; + return; + } + regions.push_back(std::move(region)); + }); + + if (error) + return None; + + int64_t totalSizeInBytes = 0; + for (const auto ®ion : regions) { + auto size = getRegionSize(*region); + if (!size.hasValue()) + return None; + totalSizeInBytes += size.getValue(); + } + return totalSizeInBytes; +} diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 3b829fc55e5..8b86056c8a9 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -179,18 +179,15 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, &fastBufferShape, &lbs, &lbDivisors); if (!numElements.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); - *sizeInBytes = 0; return false; } if (numElements.getValue() == 0) { LLVM_DEBUG(llvm::dbgs() << "Nothing to DMA\n"); - *sizeInBytes = 0; return false; } const FlatAffineConstraints *cst = region.getConstraints(); - // 'outerIVs' holds the values that this memory region is symbolic/paramteric // on; this would correspond to loop IVs surrounding the level at which the // DMA generation is being done. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 520b89ded48..239915b1d4b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "loop-fusion" @@ -46,9 +47,16 @@ using llvm::SetVector; using namespace mlir; +/// Disables fusion profitability check and fuses if valid. 
static llvm::cl::opt clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, - llvm::cl::desc("Enables maximal loop fusion.")); + llvm::cl::desc("Enables maximal loop fusion")); + +/// A threshold in percent of additional computation allowed when fusing. +static llvm::cl::opt clFusionAddlComputeTolerance( + "fusion-compute-tolerance", llvm::cl::Hidden, + llvm::cl::desc("Fractional increase in additional " + "computation tolerated while fusing")); namespace { @@ -66,6 +74,10 @@ struct LoopFusion : public FunctionPass { PassResult runOnFunction(Function *f) override; static char passID; + + // The amount of additional computation that is tolerated while fusing + // pair-wise as a fraction of the total computation. + constexpr static double kComputeToleranceThreshold = 0.30f; }; } // end anonymous namespace @@ -496,12 +508,12 @@ public: // inserting a sliced loop nest of known cost into the loop's body. // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. -static uint64_t getComputeCost( +static int64_t getComputeCost( ForInst *forInst, LoopNestStats *stats, llvm::SmallDenseMap *tripCountOverrideMap, - DenseMap *computeCostMap) { + DenseMap *computeCostMap) { // 'opCount' is the total number operations in one iteration of 'forInst' body - uint64_t opCount = stats->opCountMap[forInst]; + int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { for (auto *childForInst : stats->loopMap[forInst]) { opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, @@ -516,7 +528,7 @@ static uint64_t getComputeCost( } } // Override trip count (if specified in map). - uint64_t tripCount = stats->tripCountMap[forInst]; + int64_t tripCount = stats->tripCountMap[forInst]; if (tripCountOverrideMap != nullptr) { auto it = tripCountOverrideMap->find(forInst); if (it != tripCountOverrideMap->end()) { @@ -777,6 +789,16 @@ static Value *createPrivateMemRef(ForInst *forInst, return newMemRef; } +// Returns the total iteration count of the slice (the product of its loop trip counts). +static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. // Returns true if it profitable to fuse the candidate loop nests. Returns @@ -810,6 +832,14 @@ static bool isFusionProfitable(OperationInst *srcOpInst, ArrayRef dstOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n"; + llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n"; + for (auto dstOpInst + : dstOpInsts) { + llvm::dbgs() << " "; + dstOpInst->dump(); + }); + // Compute cost of sliced and unsliced src loop nest. SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -845,13 +875,27 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst // loop nest at 'dstLoopDepth'.
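A small standalone sketch (not part of this patch) of the compute-cost recursion the depth search below relies on, using a simplified loop-nest representation; the real getComputeCost works on ForInst/LoopNestStats and additionally consults a computeCostMap (the -1 entries used further below model a store removed by store-to-load forwarding). Overriding a loop's trip count is how slicing it to fewer iterations is modeled:

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Simplified loop nest: the op count of the loop's own body plus nested loops.
struct Loop {
  int64_t tripCount;
  int64_t bodyOpCount;          // ops directly inside this loop's body
  std::vector<Loop *> children; // nested loops
};

// cost(loop) = tripCount * (bodyOpCount + sum of the children's costs); an
// entry in 'tripCountOverride' models slicing that loop to fewer iterations.
int64_t computeCost(const Loop &l,
                    const std::map<const Loop *, int64_t> &tripCountOverride) {
  int64_t ops = l.bodyOpCount;
  for (const Loop *child : l.children)
    ops += computeCost(*child, tripCountOverride);
  auto it = tripCountOverride.find(&l);
  int64_t tripCount = it != tripCountOverride.end() ? it->second : l.tripCount;
  return tripCount * ops;
}

int main() {
  Loop inner{/*tripCount=*/100, /*bodyOpCount=*/4, {}};
  Loop outer{/*tripCount=*/10, /*bodyOpCount=*/1, {&inner}};
  // Unsliced: 10 * (1 + 100 * 4) = 4010 ops.
  printf("unsliced cost: %lld\n", (long long)computeCost(outer, {}));
  // Inner loop sliced to a single iteration: 10 * (1 + 1 * 4) = 50 ops.
  printf("sliced cost:   %lld\n",
         (long long)computeCost(outer, {{&inner, 1}}));
  return 0;
}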
- unsigned minFusedLoopNestComputeCost = std::numeric_limits::max(); - unsigned bestDstLoopDepth; + uint64_t minFusedLoopNestComputeCost = std::numeric_limits::max(); + uint64_t maxStorageReduction = 0; + Optional sliceMemEstimate = None; + SmallVector sliceStates; sliceStates.resize(maxDstLoopDepth); + // The best loop depth at which to materialize the slice. + Optional bestDstLoopDepth = None; + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); + + // Compute op instance count for the src loop nest. + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); llvm::SmallDenseMap sliceTripCountMap; - DenseMap computeCostMap; + DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -872,56 +916,167 @@ static bool isFusionProfitable(OperationInst *srcOpInst, sliceTripCountMap.clear(); if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], &sliceTripCountMap)) - return false; + // We'll skip cases where we the trip count was non-constant. + continue; - // Compute op instance count for the src loop nest with iteration slicing. - uint64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap, - /*computeCostMap=*/nullptr); + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + + assert(sliceIterationCount > 0); + + // Compute cost of fusion for this dest loop depth. - // Compute cost of fusion for these values of 'i' and 'j'. computeCostMap.clear(); + + // The store and loads to this memref will disappear. + if (storeLoadFwdGuaranteed) { + // A single store disappears: -1 for that. + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + for (auto *loadOp : dstOpInsts) { + if (auto *loadLoop = dyn_cast_or_null(loadOp->getParentInst())) + computeCostMap[loadLoop] = -1; + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = + getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/&sliceTripCountMap, + /*computeCostMap=*/&computeCostMap); + + // Compute cost of fusion for this depth. computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; - uint64_t fusedLoopNestComputeCost = + + int64_t fusedLoopNestComputeCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); - if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { - minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + + double additionalComputeFraction = + fusedLoopNestComputeCost / + (static_cast(srcLoopNestCost) + dstLoopNestCost) - + 1; + + // TODO(bondhugula): This is an ugly approximation. Fix this by finding a + // good way to calculate the footprint of the memref in the slice and + // divide it by the total memory footprint of the fused computation. 
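A small standalone sketch (not part of this patch) of the per-depth acceptance test, with made-up op counts; the real pass derives srcCost, dstCost and fusedCost from getComputeCost and sliceIterationCount from the slice trip counts. A depth is acceptable only if the extra recomputation stays under the tolerance (kComputeToleranceThreshold, or the -fusion-compute-tolerance override), and among acceptable depths the one with the largest storage reduction wins:

#include <cstdint>
#include <cstdio>

// Redundant-computation fraction and (crude) storage-reduction proxy for one
// candidate destination loop depth.
struct DepthEval {
  double additionalComputeFraction;
  double storageReduction;
};

DepthEval evaluateDepth(int64_t srcCost, int64_t dstCost, int64_t fusedCost,
                        int64_t sliceIterationCount) {
  DepthEval e;
  // Extra work introduced by fusing at this depth, as a fraction of the
  // unfused src + dst cost.
  e.additionalComputeFraction =
      fusedCost / static_cast<double>(srcCost + dstCost) - 1.0;
  // Proxy for how much the intermediate memref shrinks.
  e.storageReduction = static_cast<double>(srcCost) / sliceIterationCount;
  return e;
}

int main() {
  const double tolerance = 0.30; // mirrors kComputeToleranceThreshold
  // Hypothetical depth: 20% extra compute, slice shrunk to one iteration.
  DepthEval e = evaluateDepth(/*srcCost=*/1000, /*dstCost=*/2000,
                              /*fusedCost=*/3600, /*sliceIterationCount=*/1);
  printf("extra compute: %.0f%%, storage reduction: %.0fx, acceptable: %s\n",
         100.0 * e.additionalComputeFraction, e.storageReduction,
         e.additionalComputeFraction < tolerance ? "yes" : "no");
  return 0;
}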
+ double storageReduction = + static_cast(srcLoopNestCost) / sliceIterationCount; + + LLVM_DEBUG( + std::stringstream msg; + msg << " evaluating fusion profitability at depth : " << i << "\n" + << std::setprecision(2) << " additional compute fraction: " + << 100.0 * additionalComputeFraction << "%\n" + << " storage reduction factor: " << storageReduction << "x\n" + << " fused nest cost: " << fusedLoopNestComputeCost << "\n" + << " slice iteration count: " << sliceIterationCount << "\n"; + llvm::dbgs() << msg.str()); + + double computeToleranceThreshold = + clFusionAddlComputeTolerance.getNumOccurrences() > 0 + ? clFusionAddlComputeTolerance + : LoopFusion::kComputeToleranceThreshold; + + // TODO(b/123247369): This is a placeholder cost model. + // Among all choices that add an acceptable amount of redundant computation + // (as per computeToleranceThreshold), we will simply pick the one that + // reduces the intermediary size the most. + if ((storageReduction > maxStorageReduction) && + (clMaximalLoopFusion || + (additionalComputeFraction < computeToleranceThreshold))) { + maxStorageReduction = storageReduction; bestDstLoopDepth = i; + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + // TODO(bondhugula,andydavis): find a good way to compute the memory + // footprint of the materialized slice. + // Approximating this to the compute cost of the slice. This could be an + // under-approximation or an overapproximation, but in many cases + // accurate. + sliceMemEstimate = sliceIterationCount; } } - // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); - // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + // A simple cost model: fuse if it reduces the memory footprint. If + // -maximal-fusion is set, fuse nevertheless. - LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics " - << " bestDstLoopDepth: " << bestDstLoopDepth - << " srcLoopNestCost: " << srcLoopNestCost - << " dstLoopNestCost: " << dstLoopNestCost - << " minFusedLoopNestComputeCost: " - << minFusedLoopNestComputeCost << "\n"); - - // Do not fuse if fused loop would increase the total cost of the computation, - // unless 'clMaximalLoopFusion' flag is set. - // TODO(andydavis) Use locality/reduction in slice memref size/opportunity - // for load/store forwarding in cost model. - if (!clMaximalLoopFusion && - minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) + if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << "All fusion choices involve more than the threshold amount of" + "redundant computation; NOT fusing.\n"); return false; + } + + assert(bestDstLoopDepth.hasValue() && + "expected to have a value per logic above"); + + // Set dstLoopDepth based on best values from search. 
+ *dstLoopDepth = bestDstLoopDepth.getValue(); + + LLVM_DEBUG( + llvm::dbgs() << " LoopFusion fusion stats:\n" + << "\n Best loop depth: " << bestDstLoopDepth + << "\n src loop nest compute cost: " << srcLoopNestCost + << "\n dst loop nest compute cost: " << dstLoopNestCost + << "\n fused loop nest compute cost: " + << minFusedLoopNestComputeCost << "\n"); + + auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]); + + Optional storageReduction = None; + + if (!clMaximalLoopFusion) { + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG( + llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } + + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); + + assert(sliceMemEstimate.hasValue() && "expected value"); + // This is an inaccurate estimate since sliceMemEstimate is isaccurate. + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + << " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); + + if (fusedMem > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; + } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast(srcMemSizeVal) + dstMemSizeVal)); + } + + double additionalComputeFraction = + 100.0 * (minFusedLoopNestComputeCost / + (static_cast(srcLoopNestCost) + dstLoopNestCost) - + 1); + + std::stringstream msg; + msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " + << setprecision(2) << additionalComputeFraction + << "% redundant computation and a "; + msg << (storageReduction.hasValue() + ? std::to_string(storageReduction.getValue()) + : ""); + msg << "% storage reduction.\n"; + LLVM_DEBUG(llvm::dbgs() << msg.str()); + // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1]; + ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; sliceState->lbs = bestSliceState->lbs; sliceState->ubs = bestSliceState->ubs; sliceState->lbOperands = bestSliceState->lbOperands; sliceState->ubOperands = bestSliceState->ubOperands; - // Set dstLoopDepth based on best values from search. - *dstLoopDepth = bestDstLoopDepth; + // Canonicalize slice bound affine maps. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { if (sliceState->lbs[i] != AffineMap::Null()) { @@ -1017,29 +1172,35 @@ public: // Skip 'srcEdge' if not for 'memref'. if (srcEdge.memref != memref) continue; + auto *srcNode = mdg->getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. if (!isa(srcNode->inst)) continue; + // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) continue; + // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly // TODO(andydavis) Track dependence type with edges, and just check // for WAW dependence edge here. if (mdg->getInEdgeCount(srcNode->id, memref) != 0) continue; + // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId) continue; + + // Check if fusion would be profitable. // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); - // Check if fusion would be profitable. 
unsigned dstLoopDepth; mlir::ComputationSliceState sliceState; if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, &dstLoopDepth)) continue; + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto *sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 8e5b706835e..57b5d8dd0ef 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -633,7 +633,6 @@ func @fuse_reshape_16_4_64() { // ----- -// TODO(b/123072438) Re-enable test MemRefRegion bug is fixed. // All three loop nests below (6-d one, 2-d one, 2-d one is fused into a single // 2-d loop nest). // CHECK-LABEL: func @R6_to_R2_reshape @@ -970,24 +969,24 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: for %i2 = 0 to 2 { // CHECK-NEXT: for %i3 = 0 to 2 { -// CHECK-NEXT: for %i4 = 0 to 16 { -// CHECK-NEXT: for %i5 = 0 to 10 { -// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: for %i6 = 0 to 16 { -// CHECK-NEXT: for %i7 = 0 to 10 { -// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i6, %i7) -// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 3 { -// CHECK-NEXT: for %i9 = 0 to 3 { +// CHECK-NEXT: for %i4 = 0 to 3 { +// CHECK-NEXT: for %i5 = 0 to 3 { +// CHECK-NEXT: for %i6 = 0 to 16 { +// CHECK-NEXT: for %i7 = 0 to 10 { +// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: for %i8 = 0 to 16 { +// CHECK-NEXT: for %i9 = 0 to 10 { +// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: for %i10 = 0 to 2 { // CHECK-NEXT: for %i11 = 0 to 2 { // CHECK-NEXT: for %i12 = 0 to 16 { // CHECK-NEXT: for %i13 = 0 to 10 { -// CHECK-NEXT: %5 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %5 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { -- cgit v1.2.3 From b28009b681a45a0e5b925ed90f33f8285c87e6af Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 23 Jan 2019 11:11:43 -0800 Subject: Fix single producer check in loop fusion pass. PiperOrigin-RevId: 230565482 --- mlir/lib/Transforms/LoopFusion.cpp | 6 +++--- mlir/test/Transforms/loop-fusion.mlir | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 239915b1d4b..94d763fcbd1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1177,9 +1177,9 @@ public: // Skip if 'srcNode' is not a loop nest. if (!isa(srcNode->inst)) continue; - - // Skip if 'srcNode' has more than one store to 'memref'. - if (srcNode->getStoreOpCount(memref) != 1) + // Skip if 'srcNode' has more than one store to any memref. + // TODO(andydavis) Support fusing multi-output src loop nests. 
+ if (srcNode->stores.size() != 1) continue; // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 57b5d8dd0ef..86a24cf7796 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1288,3 +1288,31 @@ func @R3_to_R2_reshape() { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @should_not_fuse_multi_output_producer() { +func @should_not_fuse_multi_output_producer() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %a[%i0] : memref<10xf32> + store %cf7, %b[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %a[%i1] : memref<10xf32> + } + + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From 6859f33292a7c4d908f211b9d9d8be200276d383 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 23 Jan 2019 14:39:45 -0800 Subject: Migrate VectorOrTensorType/MemRefType shape api to use int64_t instead of int. PiperOrigin-RevId: 230605756 --- mlir/include/mlir/Analysis/Utils.h | 2 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 2 +- mlir/include/mlir/IR/Builders.h | 6 ++-- mlir/include/mlir/IR/StandardTypes.h | 40 +++++++++++---------- mlir/lib/Analysis/Utils.cpp | 4 +-- mlir/lib/Analysis/VectorAnalysis.cpp | 8 ++--- mlir/lib/Dialect/Traits.cpp | 4 +-- mlir/lib/IR/Attributes.cpp | 4 +-- mlir/lib/IR/Builders.cpp | 7 ++-- mlir/lib/IR/StandardTypes.cpp | 42 ++++++++++------------ mlir/lib/IR/TypeDetail.h | 37 +++++++++---------- mlir/lib/Parser/Parser.cpp | 38 ++++++++++---------- mlir/lib/StandardOps/StandardOps.cpp | 10 +++--- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 6 ++-- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 8 ++--- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +-- .../Vectorization/VectorizerTestPass.cpp | 4 +-- mlir/lib/Transforms/Vectorize.cpp | 5 +-- 22 files changed, 123 insertions(+), 118 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 8c8f73da409..c32583ebc6e 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -81,7 +81,7 @@ struct MemRefRegion { /// bounds major to minor. We use int64_t instead of uint64_t since index /// types can be at most int64_t. 
Optional getConstantBoundingSizeAndShape( - SmallVectorImpl *shape = nullptr, + SmallVectorImpl *shape = nullptr, std::vector> *lbs = nullptr, SmallVectorImpl *lbDivisors = nullptr) const; diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index e34e3433f2f..dfb1164750a 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -48,7 +48,7 @@ class VectorType; /// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None /// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16} llvm::Optional> -shapeRatio(ArrayRef superShape, ArrayRef subShape); +shapeRatio(ArrayRef superShape, ArrayRef subShape); /// Computes and returns the multi-dimensional ratio of the shapes of /// `superVector` to `subVector`. If integral division is not possible, returns diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3fa2e3ff4a9..3814e5fc6d9 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -80,11 +80,11 @@ public: IntegerType getI1Type(); IntegerType getIntegerType(unsigned width); FunctionType getFunctionType(ArrayRef inputs, ArrayRef results); - MemRefType getMemRefType(ArrayRef shape, Type elementType, + MemRefType getMemRefType(ArrayRef shape, Type elementType, ArrayRef affineMapComposition = {}, unsigned memorySpace = 0); - VectorType getVectorType(ArrayRef shape, Type elementType); - RankedTensorType getTensorType(ArrayRef shape, Type elementType); + VectorType getVectorType(ArrayRef shape, Type elementType); + RankedTensorType getTensorType(ArrayRef shape, Type elementType); UnrankedTensorType getTensorType(Type elementType); /// Get or construct an instance of the type 'ty' with provided arguments. diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 823ea75d403..3a12f8016f1 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -186,11 +186,11 @@ public: /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. - int getRank() const; + int64_t getRank() const; /// If this is ranked tensor or vector type, return the shape. If it is an /// unranked tensor, abort. - ArrayRef getShape() const; + ArrayRef getShape() const; /// If this is unranked tensor or any dimension has unknown size (<0), /// it doesn't have static shape. If all dimensions have known size (>= 0), @@ -200,7 +200,7 @@ public: /// If this is ranked tensor or vector type, return the size of the specified /// dimension. It aborts if the tensor is unranked (this can be checked by /// the getRank call method). - int getDimSize(unsigned i) const; + int64_t getDimSize(unsigned i) const; /// Get the total amount of bits occupied by a value of this type. This does /// not take into account any memory layout or widening constraints, e.g. a @@ -208,7 +208,7 @@ public: /// it will likely be stored as in a 4xi64 vector register. Fail an assertion /// if the size cannot be computed statically, i.e. if the tensor has a /// dynamic shape or if its elemental type does not have a known bit width. - long getSizeInBits() const; + int64_t getSizeInBits() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { @@ -227,26 +227,26 @@ public: /// Get or create a new VectorType of the provided shape and element type. /// Assumes the arguments define a well-formed VectorType. 
- static VectorType get(ArrayRef shape, Type elementType); + static VectorType get(ArrayRef shape, Type elementType); /// Get or create a new VectorType of the provided shape and element type /// declared at the given, potentially unknown, location. If the VectorType /// defined by the arguments would be ill-formed, emit errors and return /// nullptr-wrapping type. - static VectorType getChecked(ArrayRef shape, Type elementType, + static VectorType getChecked(ArrayRef shape, Type elementType, Location location); /// Verify the construction of a vector type. static bool verifyConstructionInvariants(llvm::Optional loc, MLIRContext *context, - ArrayRef shape, + ArrayRef shape, Type elementType); /// Returns true of the given type can be used as an element of a vector type. /// In particular, vectors can consist of integer or float primitives. static bool isValidElementType(Type t) { return t.isIntOrFloat(); } - ArrayRef getShape() const; + ArrayRef getShape() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } @@ -290,22 +290,22 @@ public: /// Get or create a new RankedTensorType of the provided shape and element /// type. Assumes the arguments define a well-formed type. - static RankedTensorType get(ArrayRef shape, Type elementType); + static RankedTensorType get(ArrayRef shape, Type elementType); /// Get or create a new RankedTensorType of the provided shape and element /// type declared at the given, potentially unknown, location. If the /// RankedTensorType defined by the arguments would be ill-formed, emit errors /// and return a nullptr-wrapping type. - static RankedTensorType getChecked(ArrayRef shape, Type elementType, + static RankedTensorType getChecked(ArrayRef shape, Type elementType, Location location); /// Verify the construction of a ranked tensor type. static bool verifyConstructionInvariants(llvm::Optional loc, MLIRContext *context, - ArrayRef shape, + ArrayRef shape, Type elementType); - ArrayRef getShape() const; + ArrayRef getShape() const; static bool kindof(unsigned kind) { return kind == StandardTypes::RankedTensor; @@ -338,7 +338,7 @@ public: MLIRContext *context, Type elementType); - ArrayRef getShape() const { return ArrayRef(); } + ArrayRef getShape() const { return llvm::None; } static bool kindof(unsigned kind) { return kind == StandardTypes::UnrankedTensor; @@ -361,7 +361,7 @@ public: /// map composition, and memory space. Assumes the arguments define a /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType /// construction failures. - static MemRefType get(ArrayRef shape, Type elementType, + static MemRefType get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, @@ -376,7 +376,7 @@ public: /// UnknownLoc. If the MemRefType defined by the arguments would be /// ill-formed, emits errors (to the handler registered with the context or to /// the error stream) and returns nullptr. - static MemRefType getChecked(ArrayRef shape, Type elementType, + static MemRefType getChecked(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, Location location) { return getImpl(shape, elementType, affineMapComposition, memorySpace, @@ -386,10 +386,10 @@ public: unsigned getRank() const { return getShape().size(); } /// Returns an array of memref shape dimension sizes. 
- ArrayRef getShape() const; + ArrayRef getShape() const; /// Return the size of the specified dimension, or -1 if unspecified. - int getDimSize(unsigned i) const { return getShape()[i]; } + int64_t getDimSize(unsigned i) const { return getShape()[i]; } /// Returns the elemental type for this memref shape. Type getElementType() const; @@ -404,6 +404,10 @@ public: /// Returns the number of dimensions with dynamic size. unsigned getNumDynamicDims() const; + /// If any dimension of the shape has unknown size (<0), it doesn't have + /// static shape. + bool hasStaticShape() const { return getNumDynamicDims() == 0; } + static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } /// Unique identifier for this type class. @@ -413,7 +417,7 @@ private: /// Get or create a new MemRefType defined by the arguments. If the resulting /// type would be ill-formed, return nullptr. If the location is provided, /// emit detailed error messages. - static MemRefType getImpl(ArrayRef shape, Type elementType, + static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, Optional location); }; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 79d1b696612..6c33bdee6aa 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -56,7 +56,7 @@ unsigned MemRefRegion::getRank() const { } Optional MemRefRegion::getConstantBoundingSizeAndShape( - SmallVectorImpl *shape, std::vector> *lbs, + SmallVectorImpl *shape, std::vector> *lbs, SmallVectorImpl *lbDivisors) const { auto memRefType = memref->getType().cast(); unsigned rank = memRefType.getRank(); @@ -289,7 +289,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, // of upper and out of lower), and check if the constraint system is // feasible. If it is, there is at least one point out of bounds. SmallVector ineq(rank + 1, 0); - int dimSize = loadOrStoreOp->getMemRefType().getDimSize(r); + int64_t dimSize = loadOrStoreOp->getMemRefType().getDimSize(r); // TODO(bondhugula): handle dynamic dim sizes. if (dimSize == -1) continue; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index bc43e2ca5eb..37eed71508f 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -37,8 +37,8 @@ using namespace mlir; using llvm::SetVector; -Optional> mlir::shapeRatio(ArrayRef superShape, - ArrayRef subShape) { +Optional> +mlir::shapeRatio(ArrayRef superShape, ArrayRef subShape) { if (superShape.size() < subShape.size()) { return Optional>(); } @@ -55,8 +55,8 @@ Optional> mlir::shapeRatio(ArrayRef superShape, result.push_back(superSize / subSize); }; functional::zipApply( - divide, SmallVector{superShape.rbegin(), superShape.rend()}, - SmallVector{subShape.rbegin(), subShape.rend()}); + divide, SmallVector{superShape.rbegin(), superShape.rend()}, + SmallVector{subShape.rbegin(), subShape.rend()}); // If integral division does not occur, return and let the caller decide. if (!divides) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index 6beb8cba41d..0fa57c23364 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -86,7 +86,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { } // Returns the shape of the given type. 
- auto getShape = [](Type type) -> ArrayRef { + auto getShape = [](Type type) -> ArrayRef { if (auto vtType = type.dyn_cast()) return vtType.getShape(); return {}; @@ -104,7 +104,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { // The result shape has the maximum among the two inputs at every // dimension index. - SmallVector resultShape; + SmallVector resultShape; if (shape1.size() > shape2.size()) { std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); } else { diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index b3ff25035de..e00297a6bb0 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -168,7 +168,7 @@ Attribute DenseElementsAttr::getValue(ArrayRef index) const { // Reduce the provided multidimensional index into a 1D index. uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; - for (int i = rank - 1; i >= 0; --i) { + for (auto i = rank - 1; i >= 0; --i) { valueIndex += index[i] * dimMultiplier; dimMultiplier *= shape[i]; } @@ -346,7 +346,7 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { // Build a mapping between known indices and the offset of the stored element. llvm::SmallDenseMap, size_t> mappedIndices; - size_t numSparseIndices = sparseIndices.getType().getDimSize(0); + auto numSparseIndices = sparseIndices.getType().getDimSize(0); for (size_t i = 0, e = numSparseIndices; i != e; ++i) mappedIndices.try_emplace( {sparseIndexValues + (i * rank), static_cast(rank)}, i); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2e453f0c6e8..6a513352afc 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -77,17 +77,18 @@ FunctionType Builder::getFunctionType(ArrayRef inputs, return FunctionType::get(inputs, results, context); } -MemRefType Builder::getMemRefType(ArrayRef shape, Type elementType, +MemRefType Builder::getMemRefType(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } -VectorType Builder::getVectorType(ArrayRef shape, Type elementType) { +VectorType Builder::getVectorType(ArrayRef shape, Type elementType) { return VectorType::get(shape, elementType); } -RankedTensorType Builder::getTensorType(ArrayRef shape, Type elementType) { +RankedTensorType Builder::getTensorType(ArrayRef shape, + Type elementType) { return RankedTensorType::get(shape, elementType); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index f031a11859f..7b7041fe428 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -112,6 +112,7 @@ unsigned VectorOrTensorType::getNumElements() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: { + assert(hasStaticShape() && "expected type to have static shape"); auto shape = getShape(); unsigned num = 1; for (auto dim : shape) @@ -125,7 +126,7 @@ unsigned VectorOrTensorType::getNumElements() const { /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. 
-int VectorOrTensorType::getRank() const { +int64_t VectorOrTensorType::getRank() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: @@ -137,7 +138,7 @@ int VectorOrTensorType::getRank() const { } } -int VectorOrTensorType::getDimSize(unsigned i) const { +int64_t VectorOrTensorType::getDimSize(unsigned i) const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: @@ -150,7 +151,7 @@ int VectorOrTensorType::getDimSize(unsigned i) const { // Get the number of number of bits require to store a value of the given vector // or tensor types. Compute the value recursively since tensors are allowed to // have vectors as elements. -long VectorOrTensorType::getSizeInBits() const { +int64_t VectorOrTensorType::getSizeInBits() const { assert(hasStaticShape() && "cannot get the bit size of an aggregate with a dynamic shape"); @@ -165,7 +166,7 @@ long VectorOrTensorType::getSizeInBits() const { return getNumElements() * elementVectorOrTensorType.getSizeInBits(); } -ArrayRef VectorOrTensorType::getShape() const { +ArrayRef VectorOrTensorType::getShape() const { switch (getKind()) { case StandardTypes::Vector: return cast().getShape(); @@ -179,18 +180,17 @@ ArrayRef VectorOrTensorType::getShape() const { bool VectorOrTensorType::hasStaticShape() const { if (isa()) return false; - auto dims = getShape(); - return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); + return llvm::none_of(getShape(), [](int64_t i) { return i < 0; }); } /// VectorType -VectorType VectorType::get(ArrayRef shape, Type elementType) { +VectorType VectorType::get(ArrayRef shape, Type elementType) { return Base::get(elementType.getContext(), StandardTypes::Vector, shape, elementType); } -VectorType VectorType::getChecked(ArrayRef shape, Type elementType, +VectorType VectorType::getChecked(ArrayRef shape, Type elementType, Location location) { return Base::getChecked(location, elementType.getContext(), StandardTypes::Vector, shape, elementType); @@ -198,7 +198,7 @@ VectorType VectorType::getChecked(ArrayRef shape, Type elementType, bool VectorType::verifyConstructionInvariants(llvm::Optional loc, MLIRContext *context, - ArrayRef shape, + ArrayRef shape, Type elementType) { if (shape.empty()) { if (loc) @@ -212,7 +212,7 @@ bool VectorType::verifyConstructionInvariants(llvm::Optional loc, return true; } - if (any_of(shape, [](int i) { return i < 0; })) { + if (any_of(shape, [](int64_t i) { return i < 0; })) { if (loc) context->emitError(*loc, "vector types must have static shape"); return true; @@ -220,7 +220,7 @@ bool VectorType::verifyConstructionInvariants(llvm::Optional loc, return false; } -ArrayRef VectorType::getShape() const { +ArrayRef VectorType::getShape() const { return static_cast(type)->getShape(); } @@ -241,12 +241,13 @@ static inline bool checkTensorElementType(Optional location, /// RankedTensorType -RankedTensorType RankedTensorType::get(ArrayRef shape, Type elementType) { +RankedTensorType RankedTensorType::get(ArrayRef shape, + Type elementType) { return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape, elementType); } -RankedTensorType RankedTensorType::getChecked(ArrayRef shape, +RankedTensorType RankedTensorType::getChecked(ArrayRef shape, Type elementType, Location location) { return Base::getChecked(location, elementType.getContext(), @@ -254,16 +255,16 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef shape, } bool RankedTensorType::verifyConstructionInvariants( - llvm::Optional loc, 
MLIRContext *context, ArrayRef shape, + llvm::Optional loc, MLIRContext *context, ArrayRef shape, Type elementType) { return checkTensorElementType(loc, context, elementType); } -ArrayRef RankedTensorType::getShape() const { +ArrayRef RankedTensorType::getShape() const { return static_cast(type)->getShape(); } -ArrayRef MemRefType::getShape() const { +ArrayRef MemRefType::getShape() const { return static_cast(type)->getShape(); } @@ -291,7 +292,7 @@ bool UnrankedTensorType::verifyConstructionInvariants( /// type would be ill-formed, return nullptr. If the location is provided, /// emit detailed error messages. To emit errors when the location is unknown, /// pass in an instance of UnknownLoc. -MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, +MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, Optional location) { @@ -346,12 +347,7 @@ unsigned MemRefType::getMemorySpace() const { } unsigned MemRefType::getNumDynamicDims() const { - unsigned numDynamicDims = 0; - for (int dimSize : getShape()) { - if (dimSize == -1) - ++numDynamicDims; - } - return numDynamicDims; + return llvm::count_if(getShape(), [](int64_t i) { return i < 0; }); } // Define type identifiers. diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 32a30a6275f..91762df53d6 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -131,12 +131,12 @@ struct VectorOrTensorTypeStorage : public TypeStorage { /// Vector Type Storage and Uniquing. struct VectorTypeStorage : public VectorOrTensorTypeStorage { VectorTypeStorage(unsigned shapeSize, Type elementTy, - const int *shapeElements) + const int64_t *shapeElements) : VectorOrTensorTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. - using KeyTy = std::pair, Type>; + using KeyTy = std::pair, Type>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType); } @@ -145,28 +145,28 @@ struct VectorTypeStorage : public VectorOrTensorTypeStorage { static VectorTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { // Copy the shape into the bump pointer. - ArrayRef shape = allocator.copyInto(key.first); + ArrayRef shape = allocator.copyInto(key.first); // Initialize the memory using placement new. return new (allocator.allocate()) VectorTypeStorage(shape.size(), key.second, shape.data()); } - ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + ArrayRef getShape() const { + return ArrayRef(shapeElements, getSubclassData()); } - const int *shapeElements; + const int64_t *shapeElements; }; struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { RankedTensorTypeStorage(unsigned shapeSize, Type elementTy, - const int *shapeElements) + const int64_t *shapeElements) : VectorOrTensorTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. - using KeyTy = std::pair, Type>; + using KeyTy = std::pair, Type>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType); } @@ -175,18 +175,18 @@ struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { static RankedTensorTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { // Copy the shape into the bump pointer. - ArrayRef shape = allocator.copyInto(key.first); + ArrayRef shape = allocator.copyInto(key.first); // Initialize the memory using placement new. 
return new (allocator.allocate()) RankedTensorTypeStorage(shape.size(), key.second, shape.data()); } - ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + ArrayRef getShape() const { + return ArrayRef(shapeElements, getSubclassData()); } - const int *shapeElements; + const int64_t *shapeElements; }; struct UnrankedTensorTypeStorage : public VectorOrTensorTypeStorage { @@ -203,7 +203,7 @@ struct UnrankedTensorTypeStorage : public VectorOrTensorTypeStorage { struct MemRefTypeStorage : public TypeStorage { MemRefTypeStorage(unsigned shapeSize, Type elementType, - const int *shapeElements, const unsigned numAffineMaps, + const int64_t *shapeElements, const unsigned numAffineMaps, AffineMap const *affineMapList, const unsigned memorySpace) : TypeStorage(shapeSize), elementType(elementType), shapeElements(shapeElements), numAffineMaps(numAffineMaps), @@ -212,7 +212,8 @@ struct MemRefTypeStorage : public TypeStorage { /// The hash key used for uniquing. // MemRefs are uniqued based on their shape, element type, affine map // composition, and memory space. - using KeyTy = std::tuple, Type, ArrayRef, unsigned>; + using KeyTy = + std::tuple, Type, ArrayRef, unsigned>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace); } @@ -221,7 +222,7 @@ struct MemRefTypeStorage : public TypeStorage { static MemRefTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { // Copy the shape into the bump pointer. - ArrayRef shape = allocator.copyInto(std::get<0>(key)); + ArrayRef shape = allocator.copyInto(std::get<0>(key)); // Copy the affine map composition into the bump pointer. ArrayRef affineMapComposition = @@ -234,8 +235,8 @@ struct MemRefTypeStorage : public TypeStorage { affineMapComposition.data(), std::get<3>(key)); } - ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + ArrayRef getShape() const { + return ArrayRef(shapeElements, getSubclassData()); } ArrayRef getAffineMaps() const { @@ -245,7 +246,7 @@ struct MemRefTypeStorage : public TypeStorage { /// The type of each scalar element of the memref. Type elementType; /// An array of integers which stores the shape dimension sizes. - const int *shapeElements; + const int64_t *shapeElements; /// The number of affine maps in the 'affineMapList' array. const unsigned numAffineMaps; /// List of affine maps in the memref's layout/index map composition. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7766c5900c2..6a96021dcb5 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -183,7 +183,7 @@ public: // Type parsing. VectorType parseVectorType(); ParseResult parseXInDimensionList(); - ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions); + ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions); Type parseExtendedType(); Type parseTensorType(); Type parseMemRefType(); @@ -386,13 +386,13 @@ VectorType Parser::parseVectorType() { if (getToken().isNot(Token::integer)) return (emitError("expected dimension size in vector type"), nullptr); - SmallVector dimensions; + SmallVector dimensions; while (getToken().is(Token::integer)) { // Make sure this integer value is in bound and valid. 
auto dimension = getToken().getUnsignedIntegerValue(); if (!dimension.hasValue()) return (emitError("invalid dimension in vector type"), nullptr); - dimensions.push_back((int)dimension.getValue()); + dimensions.push_back((int64_t)dimension.getValue()); consumeToken(Token::integer); @@ -442,16 +442,17 @@ ParseResult Parser::parseXInDimensionList() { /// dimension-list-ranked ::= (dimension `x`)* /// dimension ::= `?` | integer-literal /// -ParseResult Parser::parseDimensionListRanked(SmallVectorImpl &dimensions) { +ParseResult +Parser::parseDimensionListRanked(SmallVectorImpl &dimensions) { while (getToken().isAny(Token::integer, Token::question)) { if (consumeIf(Token::question)) { dimensions.push_back(-1); } else { // Make sure this integer value is in bound and valid. auto dimension = getToken().getUnsignedIntegerValue(); - if (!dimension.hasValue() || (int)dimension.getValue() < 0) + if (!dimension.hasValue() || (int64_t)dimension.getValue() < 0) return emitError("invalid dimension"); - dimensions.push_back((int)dimension.getValue()); + dimensions.push_back((int64_t)dimension.getValue()); consumeToken(Token::integer); } @@ -540,7 +541,7 @@ Type Parser::parseTensorType() { return nullptr; bool isUnranked; - SmallVector dimensions; + SmallVector dimensions; if (consumeIf(Token::star)) { // This is an unranked tensor type. @@ -580,7 +581,7 @@ Type Parser::parseMemRefType() { if (parseToken(Token::less, "expected '<' in memref type")) return nullptr; - SmallVector dimensions; + SmallVector dimensions; if (parseDimensionListRanked(dimensions)) return nullptr; @@ -706,12 +707,12 @@ public: ArrayRef getValues() const { return storage; } - ArrayRef getShape() const { return shape; } + ArrayRef getShape() const { return shape; } private: /// Parse either a single element or a list of elements. Return the dimensions /// of the parsed sub-tensor in dims. - ParseResult parseElementOrList(llvm::SmallVectorImpl &dims); + ParseResult parseElementOrList(llvm::SmallVectorImpl &dims); /// Parse a list of either lists or elements, returning the dimensions of the /// parsed sub-tensors in dims. For example: @@ -719,11 +720,11 @@ private: /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] /// parseList([[1, 2], 3]) -> Failure /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure - ParseResult parseList(llvm::SmallVectorImpl &dims); + ParseResult parseList(llvm::SmallVectorImpl &dims); Parser &p; Type eltTy; - SmallVector shape; + SmallVector shape; std::vector storage; }; } // namespace @@ -731,7 +732,7 @@ private: /// Parse either a single element or a list of elements. Return the dimensions /// of the parsed sub-tensor in dims. 
ParseResult -TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { +TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { switch (p.getToken().getKind()) { case Token::l_square: return parseList(dims); @@ -789,11 +790,12 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] /// parseList([[1, 2], 3]) -> Failure /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure -ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { +ParseResult +TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { p.consumeToken(Token::l_square); - auto checkDims = [&](const llvm::SmallVectorImpl &prevDims, - const llvm::SmallVectorImpl &newDims) { + auto checkDims = [&](const llvm::SmallVectorImpl &prevDims, + const llvm::SmallVectorImpl &newDims) { if (prevDims == newDims) return ParseSuccess; return p.emitError("tensor literal is invalid; ranks are not consistent " @@ -801,10 +803,10 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { }; bool first = true; - llvm::SmallVector newDims; + llvm::SmallVector newDims; unsigned size = 0; auto parseCommaSeparatedList = [&]() { - llvm::SmallVector thisDims; + llvm::SmallVector thisDims; if (parseElementOrList(thisDims)) return ParseFailure; ++size; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index d9476a7ea9c..328e0c1623b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -258,14 +258,14 @@ struct SimplifyAllocConst : public RewritePattern { // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. - SmallVector newShapeConstants; + SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); SmallVector newOperands; SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { - int dimSize = memrefType.getDimSize(dim); + int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); @@ -794,7 +794,7 @@ Attribute DimOp::constantFold(ArrayRef operands, MLIRContext *context) const { // Constant fold dim when the size along the index referred to is a constant. 
auto opType = getOperand()->getType(); - int indexSize = -1; + int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast()) { indexSize = tensorType.getShape()[getIndex()]; } else if (auto memrefType = opType.dyn_cast()) { @@ -1268,7 +1268,7 @@ bool MemRefCastOp::verify() const { return emitOpError("requires input and result ranks to match"); for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { - int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); + int64_t opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } @@ -1628,7 +1628,7 @@ bool TensorCastOp::verify() const { return emitOpError("requires input and result ranks to match"); for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { - int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); + int64_t opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 4a106b066d6..3ffff21155b 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -484,7 +484,7 @@ bool VectorTypeCastOp::verify() const { if (!dstVectorType) return emitOpError( "expects vector as an element of the target memref type"); - if (llvm::any_of(dstMemrefType.getShape(), [](int s) { return s == -1; })) + if (!dstMemrefType.hasStaticShape()) return emitOpError("does not support dynamic shapes"); if (!getOperand()->getType().isa()) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 9dfa86c9d94..b91139afe8e 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -360,7 +360,7 @@ llvm::Value *ModuleLowerer::emitMemRefAlloc(ConstOpPointer allocOp) { SmallVector sizes; sizes.reserve(allocOp->getNumOperands()); unsigned i = 0; - for (int s : type.getShape()) { + for (int64_t s : type.getShape()) { llvm::Value *value = (s == -1) ? valueMapping.lookup(allocOp->getOperand(i++)) : getIndexConstant(s); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 8b86056c8a9..e9d66ef74c3 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -114,7 +114,7 @@ struct StrideInfo { /// successively nested. // TODO(bondhugula): make this work with non-identity layout maps. static void getMultiLevelStrides(const MemRefRegion ®ion, - ArrayRef bufferShape, + ArrayRef bufferShape, SmallVectorImpl *strideInfos) { if (bufferShape.size() <= 1) return; @@ -122,7 +122,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, int64_t numEltPerStride = 1; int64_t stride = 1; for (int d = bufferShape.size() - 1; d >= 1; d--) { - int dimSize = region.memref->getType().cast().getDimSize(d); + int64_t dimSize = region.memref->getType().cast().getDimSize(d); stride *= dimSize; numEltPerStride *= bufferShape[d]; // A stride is needed only if the region has a shorter extent than the @@ -169,7 +169,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, Value *zeroIndex = top.create(loc, 0); unsigned rank = memRefType.getRank(); - SmallVector fastBufferShape; + SmallVector fastBufferShape; // Compute the extents of the buffer. 
std::vector> lbs; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 94d763fcbd1..24914878656 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -711,7 +711,7 @@ static Value *createPrivateMemRef(ForInst *forInst, // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. MemRefRegion region; getMemRefRegion(srcStoreOpInst, dstLoopDepth, ®ion); - SmallVector newShape; + SmallVector newShape; std::vector> lbs; SmallVector lbDivisors; lbs.reserve(rank); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index ccda1385df4..19208d4c268 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -96,9 +96,9 @@ private: MLFuncGlobalLoweringState *state; MemRefType memrefType; - ArrayRef memrefShape; + ArrayRef memrefShape; VectorType vectorType; - ArrayRef vectorShape; + ArrayRef vectorShape; AffineMap permutationMap; /// Used for staging the transfer in a local scalar buffer. @@ -232,9 +232,9 @@ VectorTransferRewriter::makeVectorTransferAccessInfo() { } emitter .template bindZipRangeConstants( - llvm::zip(lbs, SmallVector(ivs.size(), 0))) + llvm::zip(lbs, SmallVector(ivs.size(), 0))) .template bindZipRangeConstants( - llvm::zip(steps, SmallVector(ivs.size(), 1))); + llvm::zip(steps, SmallVector(ivs.size(), 1))); return VectorTransferAccessInfo{ivs, makeExprs(lbs), diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6085edd8e8e..e82390f8db9 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -187,7 +187,7 @@ struct MaterializationState { MaterializationState() : hwVectorSize(clVectorSize.size(), 0) { std::copy(clVectorSize.begin(), clVectorSize.end(), hwVectorSize.begin()); } - SmallVector hwVectorSize; + SmallVector hwVectorSize; VectorType superVectorType; VectorType hwVectorType; SmallVector hwVectorInstance; @@ -458,7 +458,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, SmallVector keep; MLIRContext *context = transfer->getInstruction()->getContext(); functional::zipApply( - [&dim, &keep, context](int shape, int ratio) { + [&dim, &keep, context](int64_t shape, int64_t ratio) { assert(shape >= ratio && "shape dim must be greater than ratio dim"); if (shape != ratio) { // HW vector is not full instantiated along this dim, keep it. diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 989af0071d7..9e7c928070f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -87,8 +87,8 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Doubles the shape with a leading dimension extent of 2. auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { // Add the leading dimension in the shape for the double buffer. 
- ArrayRef oldShape = oldMemRefType.getShape(); - SmallVector newShape(1 + oldMemRefType.getRank()); + ArrayRef oldShape = oldMemRefType.getShape(); + SmallVector newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); auto newMemRefType = diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 7e5cac0d87c..ad966e8d280 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -101,8 +101,8 @@ char VectorizerTestPass::passID = 0; void VectorizerTestPass::testVectorShapeRatio(Function *f) { using matcher::Op; - SmallVector shape(clTestVectorShapeRatio.begin(), - clTestVectorShapeRatio.end()); + SmallVector shape(clTestVectorShapeRatio.begin(), + clTestVectorShapeRatio.end()); auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext())); // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 58bb3901947..8a6d965ce0d 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -667,7 +667,7 @@ char Vectorize::passID = 0; namespace { struct VectorizationStrategy { - ArrayRef vectorSizes; + SmallVector vectorSizes; DenseMap loopToVectorDim; }; @@ -1280,7 +1280,8 @@ PassResult Vectorize::runOnFunction(Function *f) { for (auto m : matches) { VectorizationStrategy strategy; // TODO(ntv): depending on profitability, elect to reduce the vector size. - strategy.vectorSizes = clVirtualVectorSize; + strategy.vectorSizes.assign(clVirtualVectorSize.begin(), + clVirtualVectorSize.end()); auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy); if (fail) { continue; -- cgit v1.2.3 From 06d21d9f64559e8f7b9fcd330b55cd22df5a204c Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Thu, 24 Jan 2019 17:01:49 -0800 Subject: loop-fusion: debug info cleanup PiperOrigin-RevId: 230817383 --- mlir/lib/Transforms/LoopFusion.cpp | 60 +++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 27 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 24914878656..6a8c4fd0230 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -832,13 +832,16 @@ static bool isFusionProfitable(OperationInst *srcOpInst, ArrayRef dstOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n"; - llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n"; - for (auto dstOpInst - : dstOpInsts) { - llvm::dbgs() << " "; - dstOpInst->dump(); - }); + LLVM_DEBUG({ + llvm::dbgs() << "Checking whether fusion is profitable between:\n"; + llvm::dbgs() << " "; + srcOpInst->dump(); + llvm::dbgs() << " and \n"; + for (auto dstOpInst : dstOpInsts) { + llvm::dbgs() << " "; + dstOpInst->dump(); + }; + }); // Compute cost of sliced and unsliced src loop nest. 
SmallVector srcLoopIVs; @@ -963,15 +966,16 @@ static bool isFusionProfitable(OperationInst *srcOpInst, double storageReduction = static_cast(srcLoopNestCost) / sliceIterationCount; - LLVM_DEBUG( - std::stringstream msg; - msg << " evaluating fusion profitability at depth : " << i << "\n" - << std::setprecision(2) << " additional compute fraction: " - << 100.0 * additionalComputeFraction << "%\n" - << " storage reduction factor: " << storageReduction << "x\n" - << " fused nest cost: " << fusedLoopNestComputeCost << "\n" - << " slice iteration count: " << sliceIterationCount << "\n"; - llvm::dbgs() << msg.str()); + LLVM_DEBUG({ + std::stringstream msg; + msg << " evaluating fusion profitability at depth : " << i << "\n" + << std::setprecision(2) << " additional compute fraction: " + << 100.0 * additionalComputeFraction << "%\n" + << " storage reduction factor: " << storageReduction << "x\n" + << " fused nest cost: " << fusedLoopNestComputeCost << "\n" + << " slice iteration count: " << sliceIterationCount << "\n"; + llvm::dbgs() << msg.str(); + }); double computeToleranceThreshold = clFusionAddlComputeTolerance.getNumOccurrences() > 0 @@ -1014,8 +1018,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst, *dstLoopDepth = bestDstLoopDepth.getValue(); LLVM_DEBUG( - llvm::dbgs() << " LoopFusion fusion stats:\n" - << "\n Best loop depth: " << bestDstLoopDepth + llvm::dbgs() << " LoopFusion fusion stats:" + << "\n best loop depth: " << bestDstLoopDepth << "\n src loop nest compute cost: " << srcLoopNestCost << "\n dst loop nest compute cost: " << dstLoopNestCost << "\n fused loop nest compute cost: " @@ -1060,15 +1064,17 @@ static bool isFusionProfitable(OperationInst *srcOpInst, (static_cast(srcLoopNestCost) + dstLoopNestCost) - 1); - std::stringstream msg; - msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " - << setprecision(2) << additionalComputeFraction - << "% redundant computation and a "; - msg << (storageReduction.hasValue() - ? std::to_string(storageReduction.getValue()) - : ""); - msg << "% storage reduction.\n"; - LLVM_DEBUG(llvm::dbgs() << msg.str()); + LLVM_DEBUG({ + std::stringstream msg; + msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " + << setprecision(2) << additionalComputeFraction + << "% redundant computation and a "; + msg << (storageReduction.hasValue() + ? std::to_string(storageReduction.getValue()) + : ""); + msg << "% storage reduction.\n"; + llvm::dbgs() << msg.str(); + }); // Update return parameter 'sliceState' with 'bestSliceState'. ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; -- cgit v1.2.3 From 5c5739d42b2f6a00572f757602341b0dc4ad2569 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Thu, 24 Jan 2019 22:27:40 -0800 Subject: Change the dependence check in the loop fusion pass to use the MLIR instruction list ordering (instead of the dependence graph node id ordering). This breaks the overloading of dependence graph node ids as both edge endpoints and instruction list position. 
PiperOrigin-RevId: 230849232 --- mlir/lib/Transforms/LoopFusion.cpp | 45 +++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 13 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6a8c4fd0230..4df3a8c71d1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -302,15 +302,33 @@ public: return outEdgeCount; } - // Returns the min node id across all outgoing edges from node 'id', skipping - // edges with 'memrefToSkip'. - unsigned getMinOutEdgeNodeId(unsigned id, Value *memrefToSkip) { - unsigned minId = std::numeric_limits::max(); - if (outEdges.count(id) > 0) - for (auto &outEdge : outEdges[id]) - if (outEdge.memref != memrefToSkip) - minId = std::min(minId, outEdge.id); - return minId; + // Check for a dependence in Block instruction list range (srcId, dstId) on + // memrefs other than 'memrefToSkip' (which will be privatized for the fused + // loop). + bool hasDependenceTargetInRange(unsigned srcId, unsigned dstId, + Value *memrefToSkip) { + if (outEdges.count(srcId) == 0) + return false; + // Check if any of the outgoing edge targets from srcId lie in + // (srcId, dstId). + SmallPtrSet depInsts; + for (auto &outEdge : outEdges[srcId]) { + if (outEdge.id != dstId && outEdge.memref != memrefToSkip) { + Node *node = getNode(outEdge.id); + depInsts.insert(node->inst); + } + } + // Do a linear walk from 'srcNode.inst' to 'dstNode.inst' and for each + // instruction 'inst' in range ('srcNode.inst', 'dstNode.inst') test + // if 'depInsts' contains 'inst', and return true if it does. + // TODO(andydavis) If this linear search becomes a compile time issue, + // create a data structure which allows a faster search through ForInsts + // in a Block. + Block::iterator it = std::next(Block::iterator(getNode(srcId)->inst)); + Block::iterator itEnd = Block::iterator(getNode(dstId)->inst); + return std::any_of(it, itEnd, [&](Instruction &inst) { + return depInsts.count(&inst) > 0; + }); } // Updates edge mappings from node 'srcId' to node 'dstId'. @@ -1063,7 +1081,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, 100.0 * (minFusedLoopNestComputeCost / (static_cast(srcLoopNestCost) + dstLoopNestCost) - 1); - + (void)additionalComputeFraction; LLVM_DEBUG({ std::stringstream msg; msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " @@ -1134,7 +1152,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // TODO(andydavis) Experiment with other fusion policies. // TODO(andydavis) Add support for fusing for input reuse (perhaps by // constructing a graph with edges which represent loads from the same memref -// in two different loop nestst. +// in two different loop nests. struct GreedyFusion { public: MemRefDependenceGraph *mdg; @@ -1194,8 +1212,9 @@ public: if (mdg->getInEdgeCount(srcNode->id, memref) != 0) continue; - // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. - if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId) + // Skip if 'srcNode' has out edges on memrefs other than 'memref' + // for nodes in instruction list range (srcNode.inst, dstNode.inst). + if (mdg->hasDependenceTargetInRange(srcNode->id, dstNode->id, memref)) continue; // Check if fusion would be profitable. 
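
The new check reduces to: collect the instructions reached by srcNode's outgoing edges (ignoring the memref that will be privatized), then walk the block strictly between the two candidate nests and test whether any collected instruction occurs there. A standalone analogue using standard-library containers in place of MLIR's Block and Instruction types (all names here are illustrative, and srcIt is assumed to precede dstIt):

#include <algorithm>
#include <iterator>
#include <list>
#include <unordered_set>

// The iterators point into a std::list<int> standing in for a Block's
// instruction list; 'depTargets' holds the instructions targeted by
// srcNode's out-edges on memrefs other than the one being privatized.
bool hasTargetInRange(std::list<int>::const_iterator srcIt,
                      std::list<int>::const_iterator dstIt,
                      const std::unordered_set<int> &depTargets) {
  // Linear scan of the open interval (srcIt, dstIt) in list order, just
  // like the std::any_of walk over Block::iterator in the patch above.
  return std::any_of(std::next(srcIt), dstIt,
                     [&](int inst) { return depTargets.count(inst) > 0; });
}

Because the walk is linear in the number of instructions between the two nests, the TODO in the patch suggests a faster lookup structure should this ever become a compile-time issue.
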
-- cgit v1.2.3 From b4a1443508d8a22b33ce892847260b241e3909e7 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 25 Jan 2019 16:00:50 -0800 Subject: Update replaceAllMemRefUsesWith to generate single result affine_apply's for index remapping - generate a sequence of single result affine_apply's for the index remapping (instead of one multi result affine_apply) - update dma-generate and loop-fusion test cases; while on this, change test cases to use single result affine apply ops - some fusion comment fix/cleanup PiperOrigin-RevId: 230985830 --- mlir/lib/Transforms/LoopFusion.cpp | 24 ++-- mlir/lib/Transforms/Utils/Utils.cpp | 13 +- mlir/test/Transforms/dma-generate.mlir | 80 ++++++----- mlir/test/Transforms/loop-fusion.mlir | 245 ++++++++++++++++++++------------- 4 files changed, 218 insertions(+), 144 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 4df3a8c71d1..0add7972420 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -819,8 +819,9 @@ static uint64_t getSliceIterationCount( // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. -// Returns true if it profitable to fuse the candidate loop nests. Returns -// false otherwise. +// Returns true if it is profitable to fuse the candidate loop nests. Returns +// false otherwise. `dstLoopDepth` is set to the most profitable depth at which +// to materialize the source loop nest slice. // The profitability model executes the following steps: // *) Computes the backward computation slice at 'srcOpInst'. This // computation slice of the loop nest surrounding 'srcOpInst' is @@ -837,7 +838,7 @@ static uint64_t getSliceIterationCount( // NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for // the same memref as is written by 'srcOpInst', then the union of slice // loop bounds is used to compute the slice and associated slice cost. -// NOTE: 'dstLoopDepth' refers the loop depth within the destination loop +// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop // nest, at which the src computation slice is inserted/fused. // NOTE: We attempt to maximize the dst loop depth, but there are cases // where a particular setting for 'dstLoopNest' might fuse an unsliced @@ -933,18 +934,17 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'. getSliceUnion(tmpSliceState, &sliceStates[i - 1]); } - // Build trip count map for computation slice. + // Build trip count map for computation slice. We'll skip cases where the + // trip count was non-constant. sliceTripCountMap.clear(); if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], &sliceTripCountMap)) - // We'll skip cases where we the trip count was non-constant. continue; // Checks whether a store to load forwarding will happen. int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); - bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); - assert(sliceIterationCount > 0); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); // Compute cost of fusion for this dest loop depth. @@ -1217,18 +1217,18 @@ public: if (mdg->hasDependenceTargetInRange(srcNode->id, dstNode->id, memref)) continue; - // Check if fusion would be profitable. + // Check if fusion would be profitable and at what depth. 
// Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); - unsigned dstLoopDepth; + unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, - &dstLoopDepth)) + &bestDstLoopDepth)) continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto *sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); + srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id); @@ -1250,7 +1250,7 @@ public: } assert(storesForMemref.size() == 1); auto *newMemRef = createPrivateMemRef( - dstForInst, storesForMemref[0], dstLoopDepth); + dstForInst, storesForMemref[0], bestDstLoopDepth); visitedMemrefs.insert(newMemRef); // Collect dst loop stats after memref privatizaton transformation. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 4101a07a33d..03c2a9df1e4 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -128,11 +128,16 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, oldMemRefRank); if (indexRemap && indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { - auto remapOp = builder.create(opInst->getLoc(), indexRemap, - remapOperands); + // Remapped indices. - state.operands.append(remapOp->getInstruction()->result_begin(), - remapOp->getInstruction()->result_end()); + for (auto resultExpr : indexRemap.getResults()) { + auto singleResMap = + builder.getAffineMap(indexRemap.getNumDims(), + indexRemap.getNumSymbols(), resultExpr, {}); + auto afOp = builder.create(opInst->getLoc(), + singleResMap, remapOperands); + state.operands.push_back(afOp->getResult(0)); + } } else { // No remapping specified. state.operands.append(remapOperands.begin(), remapOperands.end()); diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 0777e4abcd9..2a21ed12414 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -1,12 +1,16 @@ // RUN: mlir-opt %s -split-input-file -dma-generate -verify | FileCheck %s // Index of the buffer for the second DMA is remapped. 
-// CHECK-DAG: [[MAP:#map[0-9]+]] = (d0) -> (d0 - 256) +// CHECK-DAG: [[MAP_MINUS_256:#map[0-9]+]] = (d0) -> (d0 - 256) +// CHECK-DAG: [[MAP_PLUS_256:#map[0-9]+]] = (d0) -> (d0 + 256) // CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 * 16 + d1) -// CHECK-DAG: [[MAP_INDEX_DIFF:#map[0-9]+]] = (d0, d1, d2, d3) -> (d2 - d0, d3 - d1) -// CHECK-DAG: [[MAP_MINUS_ONE:#map[0-9]+]] = (d0, d1) -> (d0 - 1, d1) -// CHECK-DAG: [[MAP_ORIG_ACCESS:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0, d1 + s0 + s1) -// CHECK-DAG: [[MAP_SUB_OFFSET:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2 - (d0 + 9)) +// CHECK-DAG: [[MAP_INDEX_DIFF_EVEN:#map[0-9]+]] = (d0, d1, d2, d3) -> (d2 - d0) +// CHECK-DAG: [[MAP_INDEX_DIFF_ODD:#map[0-9]+]] = (d0, d1, d2, d3) -> (d3 - d1) +// CHECK-DAG: [[MAP_D0_MINUS_ONE:#map[0-9]+]] = (d0, d1) -> (d0 - 1) +// CHECK-DAG: [[MAP_D1:#map[0-9]+]] = (d0, d1) -> (d1) +// CHECK-DAG: [[MAP_SYM_SHIFT:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d1 + s0 + s1) +// CHECK-DAG: [[MAP_3D_D1:#map[0-9]+]] = (d0, d1, d2) -> (d1) +// CHECK-DAG: [[MAP_SUB_OFFSET:#map[0-9]+]] = (d0, d1, d2) -> (d2 - (d0 + 9)) // CHECK-LABEL: func @loop_nest_1d() { func @loop_nest_1d() { @@ -30,8 +34,8 @@ func @loop_nest_1d() { // CHECK-NEXT: dma_wait %6[%c0], %c256_0 : memref<1xi32> // CHECK: for %i0 = 0 to 256 { // CHECK-NEXT: %7 = load %3[%i0] : memref<256xf32, 1> - // CHECK: %8 = affine_apply #map{{[0-9]+}}(%i0) - // CHECK: %9 = affine_apply [[MAP]](%8) + // CHECK: %8 = affine_apply [[MAP_PLUS_256]](%i0) + // CHECK: %9 = affine_apply [[MAP_MINUS_256]](%8) // CHECK-NEXT: %10 = load %5[%9] : memref<256xf32, 1> // Already in faster memory space. // CHECK: %11 = load %2[%i0] : memref<256xf32, 1> @@ -171,8 +175,9 @@ func @loop_nest_tiled() -> memref<256x1024xf32> { // CHECK-NEXT: for %i3 = #map for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { - // CHECK: %5 = affine_apply [[MAP_INDEX_DIFF]](%i0, %i1, %i2, %i3) - // CHECK-NEXT: %6 = load %3[%5#0, %5#1] : memref<32x32xf32, 1> + // CHECK-NEXT: %5 = affine_apply [[MAP_INDEX_DIFF_EVEN]](%i0, %i1, %i2, %i3) + // CHECK-NEXT: %6 = affine_apply [[MAP_INDEX_DIFF_ODD]](%i0, %i1, %i2, %i3) + // CHECK-NEXT: %7 = load %3[%5, %6] : memref<32x32xf32, 1> %1 = load %0[%i2, %i3] : memref<256x1024xf32> } // CHECK-NEXT: } } @@ -193,8 +198,9 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32> for %i = 0 to 100 { for %j = 0 to ()[s0] -> (s0) ()[%N] { - // CHECK: %2 = affine_apply [[MAP_MINUS_ONE]](%c1_0, %i1) - // CHECK-NEXT: %3 = load %0[%2#0, %2#1] : memref<1x100xf32, 1> + // CHECK: %2 = affine_apply [[MAP_D0_MINUS_ONE]](%c1_0, %i1) + // CHECK: %3 = affine_apply [[MAP_D1]](%c1_0, %i1) + // CHECK-NEXT: %4 = load %0[%2, %3] : memref<1x100xf32, 1> load %A[%one, %j] : memref<100 x 100 x f32> } } @@ -206,8 +212,8 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { %N = constant 9 : index for %i = 0 to 100 { for %j = 0 to 100 { - %idx = affine_apply (d0, d1) [s0, s1] -> (d0, d1 + s0 + s1)(%i, %j)[%M, %N] - load %A[%idx#0, %idx#1] : memref<100 x 100 x f32> + %idy = affine_apply (d0, d1) [s0, s1] -> (d1 + s0 + s1)(%i, %j)[%M, %N] + load %A[%i, %idy] : memref<100 x 100 x f32> } } return @@ -217,9 +223,10 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { // CHECK-NEXT: dma_wait %2[%c0], %c10000 // CHECK-NEXT: for %i0 = 0 to 100 { // CHECK-NEXT: for %i1 = 0 to 100 { -// CHECK-NEXT: %3 = affine_apply [[MAP_ORIG_ACCESS]](%i0, %i1)[%arg1, %c9] -// CHECK-NEXT: %4 = 
affine_apply [[MAP_SUB_OFFSET]](%arg1, %3#0, %3#1) -// CHECK-NEXT: %5 = load %1[%4#0, %4#1] : memref<100x100xf32, 1> +// CHECK-NEXT: %3 = affine_apply [[MAP_SYM_SHIFT]](%i0, %i1)[%arg1, %c9] +// CHECK-NEXT: %4 = affine_apply [[MAP_3D_D1]](%arg1, %i0, %3) +// CHECK-NEXT: %5 = affine_apply [[MAP_SUB_OFFSET]](%arg1, %i0, %3) +// CHECK-NEXT: %6 = load %1[%4, %5] : memref<100x100xf32, 1> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -236,8 +243,8 @@ func @dma_with_symbolic_loop_bounds(%A : memref<100x100xf32>, %M : index, %N: in // CHECK-NEXT: dma_wait %1[%c0], %c10000 : memref<1xi32> for %i = 0 to 100 { for %j = %M to %N { - %idx = affine_apply (d0, d1) [s0] -> (d0, d1 + s0)(%i, %j)[%K] - load %A[%idx#0, %idx#1] : memref<100 x 100 x f32> + %idy = affine_apply (d1) [s0] -> (d1 + s0)(%j)[%K] + load %A[%i, %idy] : memref<100 x 100 x f32> } } return @@ -268,12 +275,14 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { for %i = 0 to 1024 { for %j = 0 to 1024 { for %k = 0 to 1024 { - %idx = affine_apply (d0, d1, d2) -> (d0 mod 128, d1 mod 128, d2 mod 128)(%i, %j, %k) + %idx = affine_apply (d0) -> (d0 mod 128)(%i) + %idy = affine_apply (d0) -> (d0 mod 128)(%j) + %idz = affine_apply (d0) -> (d0 mod 128)(%k) // DMA with nested striding (or emulating with loop around strided DMA) // not yet implemented. - // CHECK: %3 = load %arg0[%2#0, %2#1, %2#2] : memref<1024x1024x1024xf32> - %v = load %arg0[%idx#0, %idx#1, %idx#2] : memref<1024 x 1024 x 1024 x f32> - // expected-error@-8 {{DMA generation failed for one or more memref's}} + // CHECK: %5 = load %arg0[%2, %3, %4] : memref<1024x1024x1024xf32> + %v = load %arg0[%idx, %idy, %idz] : memref<1024 x 1024 x 1024 x f32> + // expected-error@-10 {{DMA generation failed for one or more memref's}} } } } @@ -285,8 +294,9 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // CHECK: #map0 = (d0) -> (d0 + 64) // CHECK-NEXT: #map1 = (d0) -> (d0 + 128) // CHECK-NEXT: #map2 = (d0) -> (d0 + 2) -// CHECK-NEXT: #map3 = (d0, d1) -> (d0 - 2, d1 - 2) -// CHECK-NEXT: #map4 = (d0) -> (d0 + 192) +// CHECK-NEXT: #map3 = (d0, d1) -> (d0 - 2) +// CHECK-NEXT: #map4 = (d0, d1) -> (d1 - 2) +// CHECK-NEXT: #map5 = (d0) -> (d0 + 192) // The first load accesses ([2,258), [128,384)) // The second load accesses ([64,320), [2,258)) @@ -330,15 +340,19 @@ func @multi_load_store_union() { // CHECK-NEXT: %6 = affine_apply #map2(%i0) // CHECK-NEXT: %7 = affine_apply #map2(%i1) // CHECK-NEXT: %8 = affine_apply #map3(%6, %5) -// CHECK-NEXT: %9 = load %1[%8#0, %8#1] : memref<382x446xf32, 1> -// CHECK-NEXT: %10 = affine_apply #map3(%4, %7) -// CHECK-NEXT: %11 = load %1[%10#0, %10#1] : memref<382x446xf32, 1> -// CHECK-NEXT: %12 = affine_apply #map1(%i0) -// CHECK-NEXT: %13 = affine_apply #map4(%i1) -// CHECK-NEXT: %14 = affine_apply #map3(%6, %13) -// CHECK-NEXT: store %9, %1[%14#0, %14#1] : memref<382x446xf32, 1> -// CHECK-NEXT: %15 = affine_apply #map3(%12, %7) -// CHECK-NEXT: store %11, %1[%15#0, %15#1] : memref<382x446xf32, 1> +// CHECK-NEXT: %9 = affine_apply #map4(%6, %5) +// CHECK-NEXT: %10 = load %1[%8, %9] : memref<382x446xf32, 1> +// CHECK-NEXT: %11 = affine_apply #map3(%4, %7) +// CHECK-NEXT: %12 = affine_apply #map4(%4, %7) +// CHECK-NEXT: %13 = load %1[%11, %12] : memref<382x446xf32, 1> +// CHECK-NEXT: %14 = affine_apply #map1(%i0) +// CHECK-NEXT: %15 = affine_apply #map5(%i1) +// CHECK-NEXT: %16 = affine_apply #map3(%6, %15) +// CHECK-NEXT: %17 = affine_apply #map4(%6, %15) +// CHECK-NEXT: store %10, %1[%16, %17] : memref<382x446xf32, 1> +// CHECK-NEXT: %18 
= affine_apply #map3(%14, %7) +// CHECK-NEXT: %19 = affine_apply #map4(%14, %7) +// CHECK-NEXT: store %13, %1[%18, %19] : memref<382x446xf32, 1> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: dma_start %1[%c0, %c0], %0[%c2, %c2_0], %c170372, %3[%c0], %c512, %c446 : memref<382x446xf32, 1>, memref<512x512xf32>, memref<1xi32> diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 86a24cf7796..d170ce590f7 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -79,8 +79,9 @@ func @should_fuse_reduction_to_pointwise() { // ----- // CHECK-DAG: [[MAP_SHIFT_MINUS_ONE_R1:#map[0-9]+]] = (d0) -> (d0 - 1) -// CHECK-DAG: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1) -// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2, -d1 + d3) +// CHECK-DAG: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0) -> (d0 + 1) +// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2_EVEN:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) +// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2_ODD:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { func @should_fuse_loop_nests_with_shifts() { @@ -89,8 +90,9 @@ func @should_fuse_loop_nests_with_shifts() { for %i0 = 0 to 9 { for %i1 = 0 to 9 { - %a0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 1) (%i0, %i1) - store %cf7, %a[%a0#0, %a0#1] : memref<10x10xf32> + %idx = affine_apply (d0) -> (d0 + 1) (%i0) + %idy = affine_apply (d0) -> (d0 + 1) (%i1) + store %cf7, %a[%idx, %idy] : memref<10x10xf32> } } for %i2 = 1 to 10 { @@ -112,11 +114,14 @@ func @should_fuse_loop_nests_with_shifts() { // CHECK-NEXT: for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) - // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2) - // CHECK-NEXT: %4 = affine_apply [[MAP_SHIFT_MINUS_IV_R2]](%i0, %i1, %3#0, %3#1) - // CHECK-NEXT: store %cst, %0[%4#0, %4#1] : memref<1x1xf32> - // CHECK-NEXT: %5 = affine_apply [[MAP_SHIFT_MINUS_IV_R2]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %6 = load %0[%5#0, %5#1] : memref<1x1xf32> + // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1) + // CHECK-NEXT: %4 = affine_apply [[MAP_SHIFT_BY_ONE]](%2) + // CHECK-NEXT: %5 = affine_apply [[MAP_SHIFT_MINUS_IV_R2_EVEN]](%i0, %i1, %3, %4) + // CHECK-NEXT: %6 = affine_apply [[MAP_SHIFT_MINUS_IV_R2_ODD]](%i0, %i1, %3, %4) + // CHECK-NEXT: store %cst, %0[%5, %6] : memref<1x1xf32> + // CHECK-NEXT: %7 = affine_apply [[MAP_SHIFT_MINUS_IV_R2_EVEN]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %8 = affine_apply [[MAP_SHIFT_MINUS_IV_R2_ODD]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %9 = load %0[%7, %8] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -125,7 +130,8 @@ func @should_fuse_loop_nests_with_shifts() { // ----- -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2, -d1 + d3) +// CHECK-DAG: [[MAP_D2_D0_DIFF:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) +// CHECK-DAG: [[MAP_D3_D1_DIFF:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) // CHECK-LABEL: func @should_fuse_loop_nest() { func @should_fuse_loop_nest() { @@ -154,14 +160,18 @@ func @should_fuse_loop_nest() { // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i0, %i1, %i0) - // CHECK-NEXT: store %cst, [[NEWA]][%2#0, %2#1] : memref<1x1xf32> - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1, %i0, 
%i1, %i0) - // CHECK-NEXT: %4 = load [[NEWA]][%3#0, %3#1] : memref<1x1xf32> - // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: store %4, [[NEWB]][%5#0, %5#1] : memref<1x1xf32> - // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %7 = load [[NEWB]][%6#0, %6#1] : memref<1x1xf32> + // CHECK-NEXT: %2 = affine_apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: %3 = affine_apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<1x1xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: %5 = affine_apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) + // CHECK-NEXT: %6 = load [[NEWA]][%4, %5] : memref<1x1xf32> + // CHECK-NEXT: %7 = affine_apply [[MAP_D2_D0_DIFF]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %8 = affine_apply [[MAP_D3_D1_DIFF]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: store %6, [[NEWB]][%7, %8] : memref<1x1xf32> + // CHECK-NEXT: %9 = affine_apply [[MAP_D2_D0_DIFF]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %10 = affine_apply [[MAP_D3_D1_DIFF]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %11 = load [[NEWB]][%9, %10] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -516,42 +526,42 @@ func @should_not_fuse_if_inst_in_loop_nest() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d0 + d3, -d1 + d4, -d2 + d5) -// CHECK: [[MAP_PERMUTE:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0) +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d0 + d3) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d1 + d4) +// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d2 + d5) -#map0 = (d0, d1, d2) -> (d0, d1, d2) - -// CHECK-LABEL: func @remap_ivs() { -func @remap_ivs() { +// CHECK-LABEL: func @permute_and_fuse() { +func @permute_and_fuse() { %m = alloc() : memref<10x20x30xf32> %cf7 = constant 7.0 : f32 for %i0 = 0 to 10 { for %i1 = 0 to 20 { for %i2 = 0 to 30 { - %a0 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i0, %i1, %i2) - store %cf7, %m[%a0#0, %a0#1, %a0#2] : memref<10x20x30xf32> + store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> } } } for %i3 = 0 to 30 { for %i4 = 0 to 10 { for %i5 = 0 to 20 { - %a1 = affine_apply (d0, d1, d2) -> (d1, d2, d0) (%i3, %i4, %i5) - %v0 = load %m[%a1#0, %a1#1, %a1#2] : memref<10x20x30xf32> + %v0 = load %m[%i4, %i5, %i3] : memref<10x20x30xf32> + "foo"(%v0) : (f32) -> () } } } // CHECK: for %i0 = 0 to 30 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: for %i2 = 0 to 20 { -// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i2, %i0) -// CHECK-NEXT: %2 = affine_apply [[MAP1]](%i1, %i2, %i0, %1#0, %1#1, %1#2) -// CHECK-NEXT: store %cst, %0[%2#0, %2#1, %2#2] : memref<1x1x1xf32> -// CHECK-NEXT: %3 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2) -// CHECK-NEXT: %4 = affine_apply [[MAP1]](%i1, %i2, %i0, %3#0, %3#1, %3#2) -// CHECK-NEXT: %5 = load %0[%4#0, %4#1, %4#2] : memref<1x1x1xf32> +// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: %2 = affine_apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: %3 = affine_apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: store %cst, %0[%1, %2, %3] : memref<1x1x1xf32> +// CHECK-NEXT: %4 = affine_apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: %5 = affine_apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: %6 = affine_apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) +// CHECK-NEXT: %7 = load %0[%4, 
%5, %6] : memref<1x1x1xf32> +// CHECK-NEXT: "foo"(%7) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -563,7 +573,8 @@ func @remap_ivs() { // ----- // CHECK-DAG: #map0 = (d0, d1) -> (d0 * 4 + d1) -// CHECK-DAG: #map1 = (d0) -> (d0 floordiv 4, d0 mod 4) +// CHECK-DAG: #map1 = (d0) -> (d0 floordiv 4) +// CHECK-DAG: #map2 = (d0) -> (d0 mod 4) // Reshape from a 64 x f32 to 16 x 4 x f32. // CHECK-LABEL: func @fuse_reshape_64_16_4 @@ -572,8 +583,9 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) { for %i0 = 0 to 64 { %v = load %in[%i0] : memref<64xf32> - %idx = affine_apply (d0) -> (d0 floordiv 4, d0 mod 4) (%i0) - store %v, %out[%idx#0, %idx#1] : memref<16x4xf32> + %idx = affine_apply (d0) -> (d0 floordiv 4) (%i0) + %idy = affine_apply (d0) -> (d0 mod 4) (%i0) + store %v, %out[%idx, %idy] : memref<16x4xf32> } for %i1 = 0 to 16 { @@ -661,16 +673,13 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { for %jj = 0 to 9 { // Convert output coordinates to linear index. %a0 = affine_apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj) - %a1 = affine_apply (d0) -> ( - d0 floordiv (2 * 3 * 3 * 16 * 1), - (d0 mod 288) floordiv (3 * 3 * 16 * 1), - ((d0 mod 288) mod 144) floordiv (3 * 16 * 1), - (((d0 mod 288) mod 144) mod 48) floordiv (16 * 1), - ((((d0 mod 288) mod 144) mod 48) mod 16), - ((((d0 mod 144) mod 144) mod 48) mod 16) mod 1 - ) (%a0) - %v = load %in[%a1#0, %a1#1, %a1#2, %a1#3, %a1#4, %a1#5] - : memref<2x2x3x3x16x1xi32> + %0 = affine_apply (d0) -> (d0 floordiv (2 * 3 * 3 * 16 * 1))(%a0) + %1 = affine_apply (d0) -> ((d0 mod 288) floordiv (3 * 3 * 16 * 1))(%a0) + %2 = affine_apply (d0) -> (((d0 mod 288) mod 144) floordiv (3 * 16 * 1))(%a0) + %3 = affine_apply (d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv (16 * 1))(%a0) + %4 = affine_apply (d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)(%a0) + %5 = affine_apply (d0) -> (((((d0 mod 144) mod 144) mod 48) mod 16) mod 1)(%a0) + %v = load %in[%0, %1, %2, %3, %4, %5] : memref<2x2x3x3x16x1xi32> store %v, %out[%ii, %jj] : memref<64x9xi32> } } @@ -689,7 +698,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // // CHECK: %0 = alloc() : memref<64x9xi32> // CHECK-NEXT: %1 = alloc() : memref<1x1xi32> -// CHECK-NEXT: %2 = alloc() : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %2 = alloc() : memref<1x2x3x3x16x1xi32> // CHECK-NEXT: for %i0 = 0 to 64 { // CHECK-NEXT: for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i1) @@ -699,17 +708,34 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %7 = affine_apply #map4(%i0, %i1) // CHECK-NEXT: %8 = "foo"(%3, %4, %5, %6, %7, %c0) : (index, index, index, index, index, index) -> i32 // CHECK-NEXT: %9 = affine_apply #map5(%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: store %8, %2[%9#0, %9#1, %9#2, %9#3, %9#4, %9#5] : memref<1x2x3x3x16x1xi32> -// CHECK-NEXT: %10 = affine_apply #map6(%i0, %i1) -// CHECK-NEXT: %11 = affine_apply #map7(%10) -// CHECK-NEXT: %12 = affine_apply #map5(%i0, %i1, %11#0, %11#1, %11#2, %11#3, %11#4, %11#5) -// CHECK-NEXT: %13 = load %2[%12#0, %12#1, %12#2, %12#3, %12#4, %12#5] : memref<1x2x3x3x16x1xi32> -// CHECK-NEXT: %14 = affine_apply #map8(%i0, %i1, %i0, %i1) -// CHECK-NEXT: store %13, %1[%14#0, %14#1] : memref<1x1xi32> -// CHECK-NEXT: %15 = affine_apply #map8(%i0, %i1, %i0, %i1) -// CHECK-NEXT: %16 = load %1[%15#0, %15#1] : memref<1x1xi32> -// CHECK-NEXT: %17 = muli %16, %16 : i32 -// CHECK-NEXT: store %17, %0[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: %10 = affine_apply #map6(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: 
%11 = affine_apply #map7(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: %12 = affine_apply #map8(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: %13 = affine_apply #map9(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: %14 = affine_apply #map10(%i0, %i1, %3, %4, %5, %6, %7, %c0) +// CHECK-NEXT: store %8, %2[%9, %10, %11, %12, %13, %14] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %15 = affine_apply #map11(%i0, %i1) +// CHECK-NEXT: %16 = affine_apply #map12(%15) +// CHECK-NEXT: %17 = affine_apply #map13(%15) +// CHECK-NEXT: %18 = affine_apply #map14(%15) +// CHECK-NEXT: %19 = affine_apply #map15(%15) +// CHECK-NEXT: %20 = affine_apply #map16(%15) +// CHECK-NEXT: %21 = affine_apply #map17(%15) +// CHECK-NEXT: %22 = affine_apply #map5(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %23 = affine_apply #map6(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %24 = affine_apply #map7(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %25 = affine_apply #map8(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %26 = affine_apply #map9(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %27 = affine_apply #map10(%i0, %i1, %16, %17, %18, %19, %20, %21) +// CHECK-NEXT: %28 = load %2[%22, %23, %24, %25, %26, %27] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %29 = affine_apply #map18(%i0, %i1, %i0, %i1) +// CHECK-NEXT: %30 = affine_apply #map19(%i0, %i1, %i0, %i1) +// CHECK-NEXT: store %28, %1[%29, %30] : memref<1x1xi32> +// CHECK-NEXT: %31 = affine_apply #map18(%i0, %i1, %i0, %i1) +// CHECK-NEXT: %32 = affine_apply #map19(%i0, %i1, %i0, %i1) +// CHECK-NEXT: %33 = load %1[%31, %32] : memref<1x1xi32> +// CHECK-NEXT: %34 = muli %33, %33 : i32 +// CHECK-NEXT: store %34, %0[%i0, %i1] : memref<64x9xi32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<64x9xi32> @@ -732,8 +758,8 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { for %i2 = 0 to %M { for %i3 = 0 to %N { - %idx = affine_apply (d0, d1)[s0] -> (d0, d1 + s0) (%i2, %i3)[%s] - %v = load %m[%idx#0, %idx#1] : memref + %idy = affine_apply (d0)[s0] -> (d0 + s0) (%i3)[%s] + %v = load %m[%i2, %idy] : memref } } @@ -791,7 +817,8 @@ func @should_fuse_reduction_at_depth1() { } // ----- -// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1, d2) +// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1) +// CHECK: #map1 = (d0, d1, d2) -> (d2) // CHECK-LABEL: func @should_fuse_at_src_depth1_and_dst_depth1 func @should_fuse_at_src_depth1_and_dst_depth1() { @@ -828,12 +855,14 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0, %i2) - // CHECK-NEXT: store %3, %1[%4#0, %4#1] : memref<1x16xf32> + // CHECK-NEXT: %5 = affine_apply #map1(%i0, %i0, %i2) + // CHECK-NEXT: store %3, %1[%4, %5] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { - // CHECK-NEXT: %5 = affine_apply #map0(%i0, %i0, %i3) - // CHECK-NEXT: %6 = load %1[%5#0, %5#1] : memref<1x16xf32> - // CHECK-NEXT: "op2"(%6) : (f32) -> () + // CHECK-NEXT: %6 = affine_apply #map0(%i0, %i0, %i3) + // CHECK-NEXT: %7 = affine_apply #map1(%i0, %i0, %i3) + // CHECK-NEXT: %8 = load %1[%6, %7] : memref<1x16xf32> + // CHECK-NEXT: "op2"(%8) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -903,7 +932,12 @@ func @fusion_at_depth0_not_currently_supported() { // ----- -// CHECK-DAG: #map0 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d0 + d4, -d1 + d5, -d2 + d6, -d3 + d7, d8, d9) +// CHECK: #map0 = (d0, d1, d2, d3, d4, d5, 
d6, d7, d8, d9) -> (-d0 + d4) +// CHECK: #map1 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d1 + d5) +// CHECK: #map2 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d2 + d6) +// CHECK: #map3 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d3 + d7) +// CHECK: #map4 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d8) +// CHECK: #map5 = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d9) // CHECK-LABEL: func @should_fuse_deep_loop_nests func @should_fuse_deep_loop_nests() { @@ -965,7 +999,9 @@ func @should_fuse_deep_loop_nests() { // The first four loops of the source loop nest can be sliced with iteration // bounds which are a function of the first four loops of destination loop nest, // where the destination loops nests have been interchanged. -// CHECK: for %i0 = 0 to 3 { + +// CHECK: %2 = alloc() : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: for %i0 = 0 to 3 { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: for %i2 = 0 to 2 { // CHECK-NEXT: for %i3 = 0 to 2 { @@ -979,20 +1015,30 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i8 = 0 to 16 { // CHECK-NEXT: for %i9 = 0 to 10 { // CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: %5 = affine_apply #map1(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: %6 = affine_apply #map2(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: %7 = affine_apply #map3(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: %8 = affine_apply #map4(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: %9 = affine_apply #map5(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: store %cst, %2[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i10 = 0 to 2 { // CHECK-NEXT: for %i11 = 0 to 2 { // CHECK-NEXT: for %i12 = 0 to 16 { // CHECK-NEXT: for %i13 = 0 to 10 { -// CHECK-NEXT: %5 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %10 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { // CHECK-NEXT: for %i15 = 0 to 10 { -// CHECK-NEXT: %6 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %7 = load %2[%6#0, %6#1, %6#2, %6#3, %6#4, %6#5] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: %11 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %12 = affine_apply #map1(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %13 = affine_apply #map2(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %14 = affine_apply #map3(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %15 = affine_apply #map4(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %16 = affine_apply #map5(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) +// CHECK-NEXT: %17 = load %2[%11, %12, %13, %14, %15, %16] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1008,7 +1054,8 @@ func @should_fuse_deep_loop_nests() { } // ----- -// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1, d2) +// CHECK: #map0 = (d0, d1, d2) -> (-d0 + d1) +// CHECK: #map1 = (d0, d1, d2) -> (d2) // CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count func @should_fuse_at_depth1_and_reduce_slice_trip_count() { @@ -1048,11 +1095,13 @@ 
func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i0, %i2) - // CHECK-NEXT: store %cst, %1[%3#0, %3#1] : memref<1x16xf32> + // CHECK-NEXT: %4 = affine_apply #map1(%i0, %i0, %i2) + // CHECK-NEXT: store %cst, %1[%3, %4] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { - // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0, %i3) - // CHECK-NEXT: %5 = load %1[%4#0, %4#1] : memref<1x16xf32> + // CHECK-NEXT: %5 = affine_apply #map0(%i0, %i0, %i3) + // CHECK-NEXT: %6 = affine_apply #map1(%i0, %i0, %i3) + // CHECK-NEXT: %7 = load %1[%5, %6] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1259,8 +1308,8 @@ func @R3_to_R2_reshape() { for %ii = 0 to 32 { for %jj = 0 to 3 { %a0 = affine_apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) - %a1 = affine_apply (d0) -> (d0 floordiv (3 * 16)) (%a0) - %v = load %in[%a1#0, %jj, %c0] + %idx = affine_apply (d0) -> (d0 floordiv (3 * 16)) (%a0) + %v = load %in[%idx, %jj, %c0] : memref<2x3x16xi32> } } @@ -1268,26 +1317,32 @@ func @R3_to_R2_reshape() { } // CHECK: #map0 = (d0, d1) -> ((d0 * 3 + d1) floordiv 48) // CHECK-NEXT: #map1 = ()[s0] -> (s0) -// CHECK-NEXT: #map2 = (d0, d1, d2, d3, d4) -> (d2 - (d0 * 25 + d1 * 24) floordiv 24, -d1 + d3, d4) -// CHECK-NEXT: #map3 = (d0, d1) -> (d0 * 3 + d1) -// CHECK-NEXT: #map4 = (d0) -> (d0 floordiv 48) +// CHECK-NEXT: #map2 = (d0, d1, d2, d3, d4) -> (d2 - (d0 * 25 + d1 * 24) floordiv 24) +// CHECK-NEXT: #map3 = (d0, d1, d2, d3, d4) -> (-d1 + d3) +// CHECK-NEXT: #map4 = (d0, d1, d2, d3, d4) -> (d4) +// CHECK-NEXT: #map5 = (d0, d1) -> (d0 * 3 + d1) +// CHECK-NEXT: #map6 = (d0) -> (d0 floordiv 48) + // CHECK-LABEL: func @R3_to_R2_reshape() // CHECK: %0 = alloc() : memref<1x1x1xi32> // CHECK-NEXT: for %i0 = 0 to 32 { // CHECK-NEXT: for %i1 = 0 to 3 { -// CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) -// CHECK-NEXT: %2 = affine_apply #map1()[%c0] -// CHECK-NEXT: %3 = "foo"(%1, %i1, %2) : (index, index, index) -> i32 -// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1, %1, %i1, %2) -// CHECK-NEXT: store %3, %0[%4#0, %4#1, %4#2] : memref<1x1x1xi32> -// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1) -// CHECK-NEXT: %6 = affine_apply #map4(%5) -// CHECK-NEXT: %7 = affine_apply #map2(%i0, %i1, %6, %i1, %c0) -// CHECK-NEXT: %8 = load %0[%7#0, %7#1, %7#2] : memref<1x1x1xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) +// CHECK-NEXT: %2 = affine_apply #map1()[%c0] +// CHECK-NEXT: %3 = "foo"(%1, %i1, %2) : (index, index, index) -> i32 +// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1, %1, %i1, %2) +// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1, %1, %i1, %2) +// CHECK-NEXT: %6 = affine_apply #map4(%i0, %i1, %1, %i1, %2) +// CHECK-NEXT: store %3, %0[%4, %5, %6] : memref<1x1x1xi32> +// CHECK-NEXT: %7 = affine_apply #map5(%i0, %i1) +// CHECK-NEXT: %8 = affine_apply #map6(%7) +// CHECK-NEXT: %9 = affine_apply #map2(%i0, %i1, %8, %i1, %c0) +// CHECK-NEXT: %10 = affine_apply #map3(%i0, %i1, %8, %i1, %c0) +// CHECK-NEXT: %11 = affine_apply #map4(%i0, %i1, %8, %i1, %c0) +// CHECK-NEXT: %12 = load %0[%9, %10, %11] : memref<1x1x1xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return // ----- -- cgit v1.2.3 From 75c21e1de01079cb65654d1c88007dba8772e6a1 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 25 Jan 2019 22:14:04 -0800 Subject: Wrap cl::opt flags within passes in a category with the pass name. 
This improves the help output of tools like mlir-opt. Example: dma-generate options: -dma-fast-mem-capacity - Set fast memory space ... -dma-fast-mem-space= - Set fast memory space ... loop-fusion options: -fusion-compute-tolerance= - Fractional increase in ... -fusion-maximal - Enables maximal loop fusion loop-tile options: -tile-size= - Use this tile size for ... loop-unroll options: -unroll-factor= - Use this unroll factor ... -unroll-full - Fully unroll loops -unroll-full-threshold= - Unroll all loops with ... -unroll-num-reps= - Unroll innermost loops ... loop-unroll-jam options: -unroll-jam-factor= - Use this unroll jam factor ... PiperOrigin-RevId: 231019363 --- mlir/lib/Transforms/DmaGeneration.cpp | 8 ++++-- mlir/lib/Transforms/LoopFusion.cpp | 8 ++++-- mlir/lib/Transforms/LoopTiling.cpp | 7 ++++- mlir/lib/Transforms/LoopUnroll.cpp | 16 ++++++++--- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 7 ++++- .../Vectorization/VectorizerTestPass.cpp | 31 +++++++++++++--------- mlir/lib/Transforms/Vectorize.cpp | 6 +++-- 7 files changed, 58 insertions(+), 25 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 07965840dd7..0437fb143e0 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -39,13 +39,17 @@ using namespace mlir; using llvm::SmallMapVector; +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + static llvm::cl::opt clFastMemorySpace( "dma-fast-mem-space", llvm::cl::Hidden, - llvm::cl::desc("Set fast memory space id for DMA generation")); + llvm::cl::desc("Set fast memory space id for DMA generation"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clFastMemoryCapacity( "dma-fast-mem-capacity", llvm::cl::Hidden, - llvm::cl::desc("Set fast memory space capacity in KiB")); + llvm::cl::desc("Set fast memory space capacity in KiB"), + llvm::cl::cat(clOptionsCategory)); namespace { diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0add7972420..77a455993fb 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -47,16 +47,20 @@ using llvm::SetVector; using namespace mlir; +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + /// Disables fusion profitability check and fuses if valid. static llvm::cl::opt clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, - llvm::cl::desc("Enables maximal loop fusion")); + llvm::cl::desc("Enables maximal loop fusion"), + llvm::cl::cat(clOptionsCategory)); /// A threshold in percent of additional computation allowed when fusing. static llvm::cl::opt clFusionAddlComputeTolerance( "fusion-compute-tolerance", llvm::cl::Hidden, llvm::cl::desc("Fractional increase in additional" - "computation tolerated while fusing")); + " computation tolerated while fusing"), + llvm::cl::cat(clOptionsCategory)); namespace { diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index ee66c9b17b1..2a4b7bcd262 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -31,10 +31,15 @@ using namespace mlir; +#define DEBUG_TYPE "loop-tile" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + // Tile size for all loops. 
static llvm::cl::opt clTileSize("tile-size", llvm::cl::Hidden, - llvm::cl::desc("Use this tile size for all loops")); + llvm::cl::desc("Use this tile size for all loops"), + llvm::cl::cat(clOptionsCategory)); namespace { diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 554fbc26577..39ef758833b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -35,22 +35,30 @@ using namespace mlir; +#define DEBUG_TYPE "loop-unroll" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + // Loop unrolling factor. static llvm::cl::opt clUnrollFactor( "unroll-factor", llvm::cl::Hidden, - llvm::cl::desc("Use this unroll factor for all loops being unrolled")); + llvm::cl::desc("Use this unroll factor for all loops being unrolled"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clUnrollFull("unroll-full", llvm::cl::Hidden, - llvm::cl::desc("Fully unroll loops")); + llvm::cl::desc("Fully unroll loops"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clUnrollNumRepetitions( "unroll-num-reps", llvm::cl::Hidden, - llvm::cl::desc("Unroll innermost loops repeatedly this many times")); + llvm::cl::desc("Unroll innermost loops repeatedly this many times"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clUnrollFullThreshold( "unroll-full-threshold", llvm::cl::Hidden, llvm::cl::desc( - "Unroll all loops with trip count less than or equal to this")); + "Unroll all loops with trip count less than or equal to this"), + llvm::cl::cat(clOptionsCategory)); namespace { /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index d0cf27804d4..71d77817254 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -57,11 +57,16 @@ using namespace mlir; +#define DEBUG_TYPE "loop-unroll-jam" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + // Loop unroll and jam factor. static llvm::cl::opt clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden, llvm::cl::desc("Use this unroll jam factor for all loops" - " (default 4)")); + " (default 4)"), + llvm::cl::cat(clOptionsCategory)); namespace { /// Loop unroll jam pass. 
Currently, this just unroll jams the first diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index ad966e8d280..c08ffd4cd7d 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -43,36 +43,41 @@ using llvm::SetVector; using functional::map; +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + static llvm::cl::list clTestVectorShapeRatio( "vector-shape-ratio", llvm::cl::desc("Specify the HW vector size for vectorization"), - llvm::cl::ZeroOrMore); + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestForwardSlicingAnalysis( "forward-slicing", - llvm::cl::desc( - "Specify to enable testing forward static slicing and topological sort " - "functionalities")); + llvm::cl::desc("Enable testing forward static slicing and topological sort " + "functionalities"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestBackwardSlicingAnalysis( "backward-slicing", - llvm::cl::desc("Specify to enable testing backward static slicing and " - "topological sort functionalities")); + llvm::cl::desc("Enable testing backward static slicing and " + "topological sort functionalities"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestSlicingAnalysis( "slicing", - llvm::cl::desc( - "Specify to enable testing static slicing and topological sort " - "functionalities")); + llvm::cl::desc("Enable testing static slicing and topological sort " + "functionalities"), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestComposeMaps( "compose-maps", llvm::cl::desc( - "Specify to enable testing the composition of AffineMap where each " + "Enable testing the composition of AffineMap where each " "AffineMap in the composition is specified as the affine_map attribute " - "in a constant op.")); + "in a constant op."), + llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clTestNormalizeMaps( "normalize-maps", llvm::cl::desc( - "Specify to enable testing the normalization of AffineAffineApplyOp " + "Enable testing the normalization of AffineAffineApplyOp " "where each AffineAffineApplyOp in the composition is a single output " - "instruction.")); + "instruction."), + llvm::cl::cat(clOptionsCategory)); namespace { diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 3ad82faa7be..29a97991d5e 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -547,10 +547,12 @@ using llvm::dbgs; using llvm::DenseSet; using llvm::SetVector; +static llvm::cl::OptionCategory clOptionsCategory("vectorize options"); + static llvm::cl::list clVirtualVectorSize( "virtual-vector-size", llvm::cl::desc("Specify n-D virtual vector size for early vectorization"), - llvm::cl::ZeroOrMore); + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); static llvm::cl::list clFastestVaryingPattern( "test-fastest-varying", @@ -558,7 +560,7 @@ static llvm::cl::list clFastestVaryingPattern( "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory" " dimensions to match. See defaultPatterns in Vectorize.cpp for a" " description and examples. This is used for testing purposes"), - llvm::cl::ZeroOrMore); + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); /// Forward declaration. 
static FilterFunctionType -- cgit v1.2.3 From 0e7a8a9027c5f5862c3c78f41c777ba2930f9b23 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Sat, 26 Jan 2019 10:41:17 -0800 Subject: Drop AffineMap::Null and IntegerSet::Null Addresses b/122486036 This CL addresses some leftover crumbs in AffineMap and IntegerSet by removing the Null method and cleaning up the constructors. As the ::Null uses were tracked down, opportunities appeared to untangle some of the Parsing logic and make it explicit where AffineMap/IntegerSet have ambiguous syntax. Previously, ambiguous cases were hidden behind the implicit pointer values of AffineMap* and IntegerSet* that were passed as function parameters. Depending the values of those pointers one of 3 behaviors could occur. This parsing logic convolution is one of the rare cases where I would advocate for code duplication. The more proper fix would be to make the syntax unambiguous or to allow some lookahead. PiperOrigin-RevId: 231058512 --- mlir/include/mlir/Analysis/AffineStructures.h | 2 +- mlir/include/mlir/IR/AffineMap.h | 12 +- mlir/include/mlir/IR/IntegerSet.h | 12 +- mlir/include/mlir/Transforms/Utils.h | 2 +- mlir/lib/Analysis/AffineStructures.cpp | 4 +- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/IR/Instruction.cpp | 10 +- mlir/lib/IR/MLIRContext.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 318 +++++++++++++------------- mlir/lib/Transforms/LoopFusion.cpp | 16 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 8 +- 12 files changed, 195 insertions(+), 197 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 5735f9d4a9a..49202eb7cc5 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -378,7 +378,7 @@ public: /// identifiers as an affine map of the remaining identifiers (dimensional and /// symbolic). This method is able to detect identifiers as floordiv's /// and mod's of affine expressions of other identifiers with respect to - /// (positive) constants. Sets bound map to AffineMap::Null if such a bound + /// (positive) constants. Sets bound map to a null AffineMap if such a bound /// can't be found (or yet unimplemented). 
void getSliceBounds(unsigned num, MLIRContext *context, SmallVectorImpl *lbMaps, diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index e4d12f87301..cb2c315acea 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -46,8 +46,10 @@ class AffineMap { public: using ImplType = detail::AffineMapStorage; - explicit AffineMap(ImplType *map = nullptr) : map(map) {} - static AffineMap Null() { return AffineMap(nullptr); } + AffineMap() : map(nullptr) {} + explicit AffineMap(ImplType *map) : map(map) {} + AffineMap(const AffineMap &other) : map(other.map) {} + AffineMap &operator=(const AffineMap &other) = default; static AffineMap get(unsigned dimCount, unsigned symbolCount, ArrayRef results, @@ -62,9 +64,9 @@ public: MLIRContext *getContext() const; - explicit operator bool() { return map; } - bool operator==(const AffineMap &other) const { return other.map == map; } - bool operator!=(const AffineMap &other) const { return !(other.map == map); } + explicit operator bool() { return map != nullptr; } + bool operator==(AffineMap other) const { return other.map == map; } + bool operator!=(AffineMap other) const { return !(other.map == map); } /// Returns true if the co-domain (or more loosely speaking, range) of this /// map is bounded. Bounded affine maps have a size (extent) for each of diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index d0d82220922..6a97827934a 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -52,12 +52,10 @@ class IntegerSet { public: using ImplType = detail::IntegerSetStorage; - explicit IntegerSet(ImplType *set = nullptr) : set(set) {} - - IntegerSet &operator=(const IntegerSet other) { - set = other.set; - return *this; - } + IntegerSet() : set(nullptr) {} + explicit IntegerSet(ImplType *set) : set(set) {} + IntegerSet(const IntegerSet &other) : set(other.set) {} + IntegerSet &operator=(const IntegerSet &other) = default; static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef constraints, @@ -74,8 +72,6 @@ public: /// Returns true if this is the canonical integer set. bool isEmptyIntegerSet() const; - static IntegerSet Null() { return IntegerSet(nullptr); } - explicit operator bool() { return set; } bool operator==(IntegerSet other) const { return set == other.set; } diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index b3fe4471699..5c7260c9a58 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -65,7 +65,7 @@ class Function; // TODO(bondhugula): allow extraIndices to be added at any position. 
bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, - AffineMap indexRemap = AffineMap::Null(), + AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, const Instruction *domInstFilter = nullptr); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index d2adfcaa3d2..268fbe0c9c6 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1133,8 +1133,8 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, numMapDims, numMapSymbols, getAffineConstantExpr(ubConst.getValue() + 1, context), {}); } else { - (*lbMaps)[pos] = AffineMap::Null(); - (*ubMaps)[pos] = AffineMap::Null(); + (*lbMaps)[pos] = AffineMap(); + (*ubMaps)[pos] = AffineMap(); } } LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 91939c235da..39e58e8983c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -410,8 +410,8 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, numDstLoopIVs - dstLoopDepth); // Set up lower/upper bound affine maps for the slice. - sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null()); - sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null()); + sliceState->lbs.resize(numSrcLoopIVs, AffineMap()); + sliceState->ubs.resize(numSrcLoopIVs, AffineMap()); // Get bounds for src IVs in terms of dst IVs, symbols, and constants. dependenceConstraints.getSliceBounds(numSrcLoopIVs, diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6d84765e22c..b8a3e581329 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -612,10 +612,12 @@ bool OperationInst::emitOpError(const Twine &message) const { ForInst *ForInst::create(Location location, ArrayRef lbOperands, AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step) { - assert(lbOperands.size() == lbMap.getNumInputs() && - "lower bound operand count does not match the affine map"); - assert(ubOperands.size() == ubMap.getNumInputs() && - "upper bound operand count does not match the affine map"); + assert((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); unsigned numOperands = lbOperands.size() + ubOperands.size(); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 69642980677..b0a7979c626 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -1202,7 +1202,7 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, // Check if we already have this affine map. auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes); - auto existing = impl.affineMaps.insert_as(AffineMap(nullptr), key); + auto existing = impl.affineMaps.insert_as(AffineMap(), key); // If we already have it, return that value. if (!existing.second) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 88dedcf7e49..ecb7fbc779e 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -200,12 +200,10 @@ public: ParseResult parseAttributeDict(SmallVectorImpl &attributes); // Polyhedral structures. 
- void parseAffineStructureInline(AffineMap *map, IntegerSet *set); - void parseAffineStructureReference(AffineMap *map, IntegerSet *set); - AffineMap parseAffineMapInline(); AffineMap parseAffineMapReference(); - IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetReference(); + ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, + IntegerSet &set); DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type); DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector); VectorOrTensorType parseVectorOrTensorType(); @@ -997,13 +995,13 @@ Attribute Parser::parseAttribute(Type type) { // Try to parse an affine map or an integer set reference. AffineMap map; IntegerSet set; - parseAffineStructureReference(&map, &set); + if (parseAffineMapOrIntegerSetReference(map, set)) + return (emitError("expected affine map or integer set attribute value"), + nullptr); if (map) return builder.getAffineMapAttr(map); - if (set) - return builder.getIntegerSetAttr(set); - return (emitError("expected affine map or integer set attribute value"), - nullptr); + assert(set); + return builder.getIntegerSetAttr(set); } case Token::at_identifier: { @@ -1474,8 +1472,10 @@ class AffineParser : public Parser { public: explicit AffineParser(ParserState &state) : Parser(state) {} - void parseAffineStructureInline(AffineMap *map, IntegerSet *set); + AffineMap parseAffineMapInline(); AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); + IntegerSet parseIntegerSetInline(); + ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); private: @@ -1486,6 +1486,8 @@ private: // Identifier lists for polyhedral structures. ParseResult parseDimIdList(unsigned &numDims); ParseResult parseSymbolIdList(unsigned &numSymbols); + ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, + unsigned &numSymbols); ParseResult parseIdentifierDefinition(AffineExpr idExpr); AffineExpr parseAffineExpr(); @@ -1841,21 +1843,12 @@ ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { return ParseSuccess; } -/// Parse the list of symbolic identifiers to an affine map. -ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { - consumeToken(Token::l_square); - auto parseElt = [&]() -> ParseResult { - auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); - return parseIdentifierDefinition(symbol); - }; - return parseCommaSeparatedListUntil(Token::r_square, parseElt); -} - /// Parse the list of dimensional identifiers to an affine map. ParseResult AffineParser::parseDimIdList(unsigned &numDims) { if (parseToken(Token::l_paren, - "expected '(' at start of dimensional identifiers list")) + "expected '(' at start of dimensional identifiers list")) { return ParseFailure; + } auto parseElt = [&]() -> ParseResult { auto dimension = getAffineDimExpr(numDims++, getContext()); @@ -1864,10 +1857,31 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) { return parseCommaSeparatedListUntil(Token::r_paren, parseElt); } -/// Parses either an affine map or an integer set definition inline. If both -/// 'map' and 'set' are non-null, parses either an affine map or an integer set. -/// If 'map' is set to nullptr, parses an integer set. If 'set' is set to -/// nullptr, parses an affine map. 'map'/'set' are set to the parsed structure. +/// Parse the list of symbolic identifiers to an affine map. 
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { + consumeToken(Token::l_square); + auto parseElt = [&]() -> ParseResult { + auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); + return parseIdentifierDefinition(symbol); + }; + return parseCommaSeparatedListUntil(Token::r_square, parseElt); +} + +/// Parse the list of symbolic identifiers to an affine map. +ParseResult +AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, + unsigned &numSymbols) { + if (parseDimIdList(numDims)) { + return ParseResult::ParseFailure; + } + if (!getToken().is(Token::l_square)) { + numSymbols = 0; + return ParseResult::ParseSuccess; + } + return parseSymbolIdList(numSymbols); +} + +/// Parses an affine map definition inline. /// /// affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr /// (`size` `(` dim-size (`,` dim-size)* `)`)? @@ -1875,6 +1889,23 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) { /// /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) /// +AffineMap AffineParser::parseAffineMapInline() { + unsigned numDims = 0, numSymbols = 0; + + // List of dimensional and optional symbol identifiers. + if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { + return AffineMap(); + } + + if (parseToken(Token::arrow, "expected '->' or '['")) { + return AffineMap(); + } + + // Parse the affine map. + return parseAffineMapRange(numDims, numSymbols); +} + +/// Parses an integer set definition inline. /// /// integer-set-inline /// ::= dim-and-symbol-id-lists `:` @@ -1883,68 +1914,49 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) { /// | affine-constraint (`,` /// affine-constraint)* /// -void AffineParser::parseAffineStructureInline(AffineMap *map, IntegerSet *set) { - assert((map || set) && "one of map or set expected to be non-null"); - +IntegerSet AffineParser::parseIntegerSetInline() { unsigned numDims = 0, numSymbols = 0; - // List of dimensional identifiers. - if (parseDimIdList(numDims)) { - if (map) - *map = AffineMap::Null(); - if (set) - *set = IntegerSet::Null(); - return; + // List of dimensional and optional symbol identifiers. + if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { + return IntegerSet(); } - // Symbols are optional. - if (getToken().is(Token::l_square)) { - if (parseSymbolIdList(numSymbols)) { - if (map) - *map = AffineMap::Null(); - if (set) - *set = IntegerSet::Null(); - return; - } + if (parseToken(Token::colon, "expected ':' or '['")) { + return IntegerSet(); + } + + return parseIntegerSetConstraints(numDims, numSymbols); +} + +/// Parses an ambiguous affine map or integer set definition inline. +ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, + IntegerSet &set) { + unsigned numDims = 0, numSymbols = 0; + + // List of dimensional and optional symbol identifiers. + if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { + return ParseResult::ParseFailure; } // This is needed for parsing attributes as we wouldn't know whether we would // be parsing an integer set attribute or an affine map attribute. - if (map && set && getToken().isNot(Token::arrow) && - getToken().isNot(Token::colon)) { - emitError("expected '->' or ':' or '['"); - *map = AffineMap::Null(); - *set = IntegerSet::Null(); - return; - } - - if (map && (!set || getToken().is(Token::arrow))) { - // Parse an affine map. 
- if (parseToken(Token::arrow, "expected '->' or '['")) { - *map = AffineMap::Null(); - if (set) - *set = IntegerSet::Null(); - return; - } - *map = parseAffineMapRange(numDims, numSymbols); - if (set) - *set = IntegerSet::Null(); - return; - } - - if (set && (!map || getToken().is(Token::colon))) { - // Parse an integer set. - if (parseToken(Token::colon, "expected ':' or '['")) { - *set = IntegerSet::Null(); - if (map) - *map = AffineMap::Null(); - return; - } - *set = parseIntegerSetConstraints(numDims, numSymbols); - if (map) - *map = AffineMap::Null(); - return; + bool isArrow = getToken().is(Token::arrow); + bool isColon = getToken().is(Token::colon); + if (!isArrow && !isColon) { + return ParseFailure; + } else if (isArrow) { + parseToken(Token::arrow, "expected '->' or '['"); + map = parseAffineMapRange(numDims, numSymbols); + return map ? ParseSuccess : ParseFailure; + } else if (parseToken(Token::colon, "expected ':' or '['")) { + return ParseFailure; } + + if ((set = parseIntegerSetConstraints(numDims, numSymbols))) + return ParseSuccess; + + return ParseFailure; } /// Parse the range and sizes affine map definition inline. @@ -1970,7 +1982,7 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, // 1-d affine expressions); the list cannot be empty. Grammar: // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) - return AffineMap::Null(); + return AffineMap(); // Parse optional range sizes. // range-sizes ::= (`size` `(` dim-size (`,` dim-size)* `)`)? @@ -1982,7 +1994,7 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, // Location of the l_paren token (if it exists) for error reporting later. auto loc = getToken().getLoc(); if (parseToken(Token::l_paren, "expected '(' at start of affine map range")) - return AffineMap::Null(); + return AffineMap(); auto parseRangeSize = [&]() -> ParseResult { auto loc = getToken().getLoc(); @@ -1999,99 +2011,90 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, }; if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false)) - return AffineMap::Null(); + return AffineMap(); if (exprs.size() > rangeSizes.size()) return (emitError(loc, "fewer range sizes than range expressions"), - AffineMap::Null()); + AffineMap()); if (exprs.size() < rangeSizes.size()) return (emitError(loc, "more range sizes than range expressions"), - AffineMap::Null()); + AffineMap()); } // Parsed a valid affine map. return builder.getAffineMap(numDims, numSymbols, exprs, rangeSizes); } -void Parser::parseAffineStructureInline(AffineMap *map, IntegerSet *set) { - AffineParser(state).parseAffineStructureInline(map, set); -} +/// Parse a reference to an integer set. +/// integer-set ::= integer-set-id | integer-set-inline +/// integer-set-id ::= `#` suffix-id +/// +IntegerSet Parser::parseIntegerSetReference() { + if (getToken().isNot(Token::hash_identifier)) { + // Try to parse inline integer set. + return AffineParser(state).parseIntegerSetInline(); + } -AffineMap Parser::parseAffineMapInline() { - AffineMap map; - AffineParser(state).parseAffineStructureInline(&map, nullptr); - return map; + // Parse integer set identifier and verify that it exists. + StringRef id = getTokenSpelling().drop_front(); + if (getState().integerSetDefinitions.count(id) > 0) { + consumeToken(Token::hash_identifier); + return getState().integerSetDefinitions[id]; + } + + // The id isn't among any of the recorded definitions. 
+ emitError("undefined integer set id '" + id + "'"); + return IntegerSet(); } -/// Parse either an affine map reference or integer set reference. -/// -/// affine-structure ::= affine-structure-id | affine-structure-inline -/// affine-structure-id ::= `#` suffix-id -/// -/// affine-structure ::= affine-map | integer-set +/// Parse a reference to an affine map. +/// affine-map ::= affine-map-id | affine-map-inline +/// affine-map-id ::= `#` suffix-id /// -void Parser::parseAffineStructureReference(AffineMap *map, IntegerSet *set) { - assert((map || set) && "both map and set are non-null"); +AffineMap Parser::parseAffineMapReference() { + if (getToken().isNot(Token::hash_identifier)) { + // Try to parse inline affine map. + return AffineParser(state).parseAffineMapInline(); + } + + // Parse affine map identifier and verify that it exists. + StringRef id = getTokenSpelling().drop_front(); + if (getState().affineMapDefinitions.count(id) > 0) { + consumeToken(Token::hash_identifier); + return getState().affineMapDefinitions[id]; + } + + // The id isn't among any of the recorded definitions. + emitError("undefined affine map id '" + id + "'"); + return AffineMap(); +} + +/// Parse an ambiguous reference to either and affine map or an integer set. +ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, + IntegerSet &set) { if (getToken().isNot(Token::hash_identifier)) { - // Try to parse inline affine map or integer set. - return parseAffineStructureInline(map, set); + // Try to parse inline affine map. + return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); } // Parse affine map / integer set identifier and verify that it exists. // Note that an id can't be in both affineMapDefinitions and // integerSetDefinitions since they use the same sigil '#'. - StringRef affineStructId = getTokenSpelling().drop_front(); - if (getState().affineMapDefinitions.count(affineStructId) > 0) { + StringRef id = getTokenSpelling().drop_front(); + if (getState().affineMapDefinitions.count(id) > 0) { consumeToken(Token::hash_identifier); - if (map) - *map = getState().affineMapDefinitions[affineStructId]; - if (set) - *set = IntegerSet::Null(); - return; + map = getState().affineMapDefinitions[id]; + return ParseSuccess; } - - if (getState().integerSetDefinitions.count(affineStructId) > 0) { + if (getState().integerSetDefinitions.count(id) > 0) { consumeToken(Token::hash_identifier); - if (set) - *set = getState().integerSetDefinitions[affineStructId]; - if (map) - *map = AffineMap::Null(); - return; + set = getState().integerSetDefinitions[id]; + return ParseSuccess; } // The id isn't among any of the recorded definitions. - // Emit the right message depending on what the caller expected. - if (map && !set) - emitError("undefined affine map id '" + affineStructId + "'"); - else if (set && !map) - emitError("undefined integer set id '" + affineStructId + "'"); - else if (set && map) - emitError("undefined affine map or integer set id '" + affineStructId + - "'"); - - if (map) - *map = AffineMap::Null(); - if (set) - *set = IntegerSet::Null(); -} + emitError("undefined affine map or integer set id '" + id + "'"); -/// Parse a reference to an integer set. -/// affine-map ::= affine-map-id | affine-map-inline -/// affine-map-id ::= `#` suffix-id -/// -AffineMap Parser::parseAffineMapReference() { - AffineMap map; - parseAffineStructureReference(&map, nullptr); - return map; -} - -/// Parse a reference to an integer set. 
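As a small usage illustration (the identifiers are hypothetical), both kinds of '#' definitions share one namespace at module level, which is why the ambiguous reference parser probes affineMapDefinitions and integerSetDefinitions in turn before reporting an undefined id:

  #map0 = (d0)[s0] -> (d0 + s0)
  #set0 = (d0)[s0] : (s0 - d0 - 1 >= 0)

A later '#map0' resolves through affineMapDefinitions, '#set0' through integerSetDefinitions, and an unknown '#' id falls through to the undefined affine map or integer set error.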
-/// integer-set ::= integer-set-id | integer-set-inline -/// integer-set-id ::= `#` suffix-id -/// -IntegerSet Parser::parseIntegerSetReference() { - IntegerSet set; - parseAffineStructureReference(nullptr, &set); - return set; + return ParseFailure; } //===----------------------------------------------------------------------===// @@ -3402,7 +3405,7 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols) { if (parseToken(Token::l_paren, "expected '(' at start of integer set constraint list")) - return IntegerSet::Null(); + return IntegerSet(); SmallVector constraints; SmallVector isEqs; @@ -3422,24 +3425,18 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, // affine-constraint)* `) auto constraintListLoc = getToken().getLoc(); if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) - return IntegerSet::Null(); + return IntegerSet(); // Check that at least one constraint was parsed. if (constraints.empty()) { emitError(constraintListLoc, "expected a valid affine constraint"); - return IntegerSet::Null(); + return IntegerSet(); } // Parsed a valid integer set. return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -IntegerSet Parser::parseIntegerSetInline() { - IntegerSet set; - AffineParser(state).parseAffineStructureInline(nullptr, &set); - return set; -} - /// If instruction. /// /// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` @@ -3564,15 +3561,16 @@ ParseResult ModuleParser::parseAffineStructureDef() { AffineMap map; IntegerSet set; - parseAffineStructureInline(&map, &set); - if (!map && !set) + if (AffineParser(getState()).parseAffineMapOrIntegerSetInline(map, set)) return ParseFailure; - if (map) + if (map) { getState().affineMapDefinitions[affineStructureId] = map; - else - getState().integerSetDefinitions[affineStructureId] = set; + return ParseSuccess; + } + assert(set); + getState().integerSetDefinitions[affineStructureId] = set; return ParseSuccess; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 77a455993fb..cee0a08a63c 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -595,7 +595,7 @@ static bool buildSliceTripCountMap( for (unsigned i = 0; i < numSrcLoopIVs; ++i) { AffineMap lbMap = sliceState->lbs[i]; AffineMap ubMap = sliceState->ubs[i]; - if (lbMap == AffineMap::Null() || ubMap == AffineMap::Null()) { + if (lbMap == AffineMap() || ubMap == AffineMap()) { // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. if (srcLoopIVs[i]->hasConstantLowerBound() && srcLoopIVs[i]->hasConstantUpperBound()) { @@ -675,16 +675,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) { AffineMap lbMapA = sliceStateA.lbs[i]; AffineMap ubMapA = sliceStateA.ubs[i]; - if (lbMapA == AffineMap::Null()) { - assert(ubMapA == AffineMap::Null()); + if (lbMapA == AffineMap()) { + assert(ubMapA == AffineMap()); continue; } assert(ubMapA && "expected non-null ub map"); AffineMap lbMapB = sliceStateB->lbs[i]; AffineMap ubMapB = sliceStateB->ubs[i]; - if (lbMapB == AffineMap::Null()) { - assert(ubMapB == AffineMap::Null()); + if (lbMapB == AffineMap()) { + assert(ubMapB == AffineMap()); // Union 'sliceStateB' does not have a bound for 'i' so copy from A. 
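A minimal C++ sketch of the idiom this mechanical change settles on, assuming only AffineMap's default constructor and boolean conversion, both of which the surrounding code already relies on:

  AffineMap lbMap;              // default-constructed: no map
  assert(lbMap == AffineMap()); // the check that used to spell AffineMap::Null()
  if (!lbMap) {
    // Bound is unset; fall back to the full loop range as done above.
  }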
sliceStateB->lbs[i] = lbMapA; sliceStateB->ubs[i] = ubMapA; @@ -799,7 +799,7 @@ static Value *createPrivateMemRef(ForInst *forInst, } auto indexRemap = zeroOffsetCount == rank - ? AffineMap::Null() + ? AffineMap() : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); // Replace all users of 'oldMemRef' with 'newMemRef'. bool ret = @@ -1107,11 +1107,11 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // Canonicalize slice bound affine maps. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - if (sliceState->lbs[i] != AffineMap::Null()) { + if (sliceState->lbs[i] != AffineMap()) { canonicalizeMapAndOperands(&sliceState->lbs[i], &sliceState->lbOperands[i]); } - if (sliceState->ubs[i] != AffineMap::Null()) { + if (sliceState->ubs[i] != AffineMap()) { canonicalizeMapAndOperands(&sliceState->ubs[i], &sliceState->ubOperands[i]); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 101a00eaf61..e72b9ef80df 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -127,7 +127,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // replaceAllMemRefUsesWith will always succeed unless the forInst body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0), - AffineMap::Null(), {}, + AffineMap(), {}, &*forInst->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 66e7c5975da..d41614545d2 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -46,12 +46,12 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, // Single result lower bound map only. if (lbMap.getNumResults() != 1) - return AffineMap::Null(); + return AffineMap(); // Sometimes, the trip count cannot be expressed as an affine expression. auto tripCount = getTripCountExpr(forInst); if (!tripCount) - return AffineMap::Null(); + return AffineMap(); AffineExpr lb(lbMap.getResult(0)); unsigned step = forInst.getStep(); @@ -72,12 +72,12 @@ AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, // Single result lower bound map only. if (lbMap.getNumResults() != 1) - return AffineMap::Null(); + return AffineMap(); // Sometimes the trip count cannot be expressed as an affine expression. AffineExpr tripCount(getTripCountExpr(forInst)); if (!tripCount) - return AffineMap::Null(); + return AffineMap(); AffineExpr lb(lbMap.getResult(0)); unsigned step = forInst.getStep(); -- cgit v1.2.3 From 5ecef2b3f63c8391e8dd1e06209b1b8f3000c9c7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 28 Jan 2019 17:20:44 -0800 Subject: Define a AffineOps dialect as well as an AffineIfOp operation. Replace all instances of IfInst with AffineIfOp and delete IfInst. 
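The construct keeps its existing surface syntax but becomes an ordinary operation carrying two block lists, one for 'then' and one for the optional 'else', instead of a dedicated IfInst node. A made-up example of the form that now round-trips through AffineIfOp's custom parse and print hooks:

  #set0 = (d0)[s0] : (d0 - 10 >= 0, s0 - d0 - 9 >= 0)

  if #set0(%i)[%N] {
    // 'then' block list
  } else {
    // optional 'else' block list
  }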
PiperOrigin-RevId: 231318632 --- mlir/include/mlir/AffineOps/AffineOps.h | 91 +++++++++++++ mlir/include/mlir/Analysis/NestedMatcher.h | 1 - mlir/include/mlir/IR/Block.h | 11 +- mlir/include/mlir/IR/Builders.h | 4 - mlir/include/mlir/IR/InstVisitor.h | 28 +--- mlir/include/mlir/IR/Instruction.h | 1 - mlir/include/mlir/IR/Instructions.h | 124 ----------------- mlir/include/mlir/IR/OpImplementation.h | 21 ++- mlir/include/mlir/IR/UseDefLists.h | 3 +- .../mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 151 +++++++++++++++++++++ mlir/lib/AffineOps/DialectRegistration.cpp | 22 +++ mlir/lib/Analysis/LoopAnalysis.cpp | 11 ++ mlir/lib/Analysis/NestedMatcher.cpp | 20 ++- mlir/lib/Analysis/Utils.cpp | 22 +-- mlir/lib/Analysis/Verifier.cpp | 25 ---- mlir/lib/IR/AsmPrinter.cpp | 38 +----- mlir/lib/IR/Builders.cpp | 7 - mlir/lib/IR/Instruction.cpp | 109 ++------------- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/Value.cpp | 2 - mlir/lib/Parser/Parser.cpp | 129 +++++++----------- mlir/lib/Parser/TokenKinds.def | 2 - mlir/lib/Transforms/CSE.cpp | 10 -- mlir/lib/Transforms/LoopFusion.cpp | 24 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 9 -- mlir/lib/Transforms/LowerAffine.cpp | 59 +++++--- mlir/lib/Transforms/MaterializeVectors.cpp | 7 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 37 ++--- mlir/test/IR/invalid.mlir | 14 +- mlir/test/IR/locations.mlir | 6 +- mlir/test/IR/parser.mlir | 22 +-- mlir/test/IR/pretty-locations.mlir | 8 +- mlir/test/Transforms/loop-fusion.mlir | 4 +- mlir/test/Transforms/memref-dependence-check.mlir | 2 +- mlir/test/Transforms/strip-debug-info.mlir | 6 +- 36 files changed, 493 insertions(+), 541 deletions(-) create mode 100644 mlir/include/mlir/AffineOps/AffineOps.h create mode 100644 mlir/lib/AffineOps/AffineOps.cpp create mode 100644 mlir/lib/AffineOps/DialectRegistration.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h new file mode 100644 index 00000000000..d511f628c3c --- /dev/null +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -0,0 +1,91 @@ +//===- AffineOps.h - MLIR Affine Operations -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with Affine operations +// in the MLIR instruction set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_AFFINEOPS_AFFINEOPS_H +#define MLIR_AFFINEOPS_AFFINEOPS_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +class AffineOpsDialect : public Dialect { +public: + AffineOpsDialect(MLIRContext *context); +}; + +/// The "if" operation represents an if–then–else construct for conditionally +/// executing two regions of code. 
The operands to an if operation are an +/// IntegerSet condition and a set of symbol/dimension operands to the +/// condition set. The operation produces no results. For example: +/// +/// if #set(%i) { +/// ... +/// } else { +/// ... +/// } +/// +/// The 'else' blocks to the if operation are optional, and may be omitted. For +/// example: +/// +/// if #set(%i) { +/// ... +/// } +/// +class AffineIfOp + : public Op { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + IntegerSet condition, ArrayRef conditionOperands); + + static StringRef getOperationName() { return "if"; } + static StringRef getConditionAttrName() { return "condition"; } + + IntegerSet getIntegerSet() const; + void setIntegerSet(IntegerSet newSet); + + /// Returns the list of 'then' blocks. + BlockList &getThenBlocks(); + const BlockList &getThenBlocks() const { + return const_cast(this)->getThenBlocks(); + } + + /// Returns the list of 'else' blocks. + BlockList &getElseBlocks(); + const BlockList &getElseBlocks() const { + return const_cast(this)->getElseBlocks(); + } + + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + +private: + friend class OperationInst; + explicit AffineIfOp(const OperationInst *state) : Op(state) {} +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index c205d55488e..161bb217a10 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -128,7 +128,6 @@ private: void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } - void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// POD paylod. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 1b14d925d32..e85ea772d0b 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -26,7 +26,6 @@ #include "llvm/ADT/PointerUnion.h" namespace mlir { -class IfInst; class BlockList; class BlockAndValueMapping; @@ -62,7 +61,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an IfInst or ForInst. + /// nested under an OperationInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast(this)->getFunction(); @@ -325,7 +324,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or IfInst or ForInst. +/// is part of - a Function or OperationInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -365,14 +364,14 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is - /// part of an IfInst/ForInst, then return it, otherwise return null. + /// A BlockList is part of a Function or and OperationInst/ForInst. If it is + /// part of an OperationInst/ForInst, then return it, otherwise return null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is + /// A BlockList is part of a Function or and OperationInst/ForInst. If it is /// part of a Function, then return it, otherwise return null. 
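A short C++ sketch of how client code is expected to reach the new op where it previously would have cast to IfInst, mirroring the casts used later in this patch (for example in Analysis/Utils.cpp); 'inst' is assumed to be an Instruction * already in scope:

  if (auto *opInst = dyn_cast<OperationInst>(inst)) {
    if (auto ifOp = opInst->dyn_cast<AffineIfOp>()) {
      IntegerSet condition = ifOp->getIntegerSet();  // the guarding set
      BlockList &thenBlocks = ifOp->getThenBlocks(); // first block list
      (void)condition;
      (void)thenBlocks;
    }
  }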
Function *getContainingFunction(); const Function *getContainingFunction() const { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 156bd02bb52..3271c12afde 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -286,10 +286,6 @@ public: // Default step is 1. ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - /// Creates if instruction. - IfInst *createIf(Location location, ArrayRef operands, - IntegerSet set); - private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index b6a759e76f5..78810da909d 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, IfInst, and +// There are 'visit' methods for OperationInst, ForInst, and // Function, which recursively process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, @@ -85,8 +85,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->visitForInst(cast(s)); - case Instruction::Kind::If: - return static_cast(this)->visitIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -104,7 +102,6 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -166,23 +163,6 @@ public: static_cast(this)->visitForInst(forInst); } - void walkIfInst(IfInst *ifInst) { - static_cast(this)->visitIfInst(ifInst); - static_cast(this)->walk(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast(this)->walk(elseBlock->begin(), elseBlock->end()); - } - - void walkIfInstPostOrder(IfInst *ifInst) { - static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast(this)->walkPostOrder(elseBlock->begin(), - elseBlock->end()); - static_cast(this)->visitIfInst(ifInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -193,8 +173,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->walkForInst(cast(s)); - case Instruction::Kind::If: - return static_cast(this)->walkIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -210,9 +188,6 @@ public: case Instruction::Kind::For: return static_cast(this)->walkForInstPostOrder( cast(s)); - case Instruction::Kind::If: - return static_cast(this)->walkIfInstPostOrder( - cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -231,7 +206,6 @@ public: // processing their descendants in some way. When using RetTy, all of these // need to be overridden. 
void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 6a296b7348e..3dc1e76dd20 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -75,7 +75,6 @@ public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForInst, - If = (int)IROperandOwner::Kind::IfInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 71d832b8b90..fb6b1b97ca0 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -794,130 +794,6 @@ private: friend class ForInst; }; - -/// If instruction restricts execution to a subset of the loop iteration space. -class IfInst : public Instruction { -public: - static IfInst *create(Location location, ArrayRef operands, - IntegerSet set); - ~IfInst(); - - //===--------------------------------------------------------------------===// - // Then, else, condition. - //===--------------------------------------------------------------------===// - - Block *getThen() { return &thenClause.front(); } - const Block *getThen() const { return &thenClause.front(); } - Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const Block *getElse() const { - return elseClause ? &elseClause->front() : nullptr; - } - bool hasElse() const { return elseClause != nullptr; } - - Block *createElse() { - assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new BlockList(this); - elseClause->push_back(new Block()); - return &elseClause->front(); - } - - const AffineCondition getCondition() const; - - IntegerSet getIntegerSet() const { return set; } - void setIntegerSet(IntegerSet newSet) { - assert(newSet.getNumOperands() == operands.size()); - set = newSet; - } - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. 
- using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - unsigned getNumOperands() const { return operands.size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { return operands; } - MutableArrayRef getInstOperands() { return operands; } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - MLIRContext *getContext() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::IfInst; - } - -private: - // it is always present. - BlockList thenClause; - // 'else' clause of the if instruction. 'nullptr' if there is no else clause. - BlockList *elseClause; - - // The integer set capturing the conditional guard. - IntegerSet set; - - // Condition operands. - std::vector operands; - - explicit IfInst(Location location, unsigned numOperands, IntegerSet set); -}; - -/// AffineCondition represents a condition of the 'if' instruction. -/// Its life span should not exceed that of the objects it refers to. -/// AffineCondition does not provide its own methods for iterating over -/// the operands since the iterators of the if instruction accomplish -/// the same purpose. -/// -/// AffineCondition is trivially copyable, so it should be passed by value. -class AffineCondition { -public: - const IfInst *getIfInst() const { return &inst; } - IntegerSet getIntegerSet() const { return set; } - -private: - // 'if' instruction that contains this affine condition. - const IfInst &inst; - // Integer set for this affine condition. - IntegerSet set; - - AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} - - friend class IfInst; -}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 1e319db3571..d3a5d35427f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,6 +89,9 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const OperationInst *op) = 0; + /// Prints a block list. + virtual void printBlockList(const BlockList &blocks) = 0; + private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; @@ -195,7 +198,19 @@ public: virtual bool parseColonTypeList(SmallVectorImpl &result) = 0; /// Parse a keyword followed by a type. 
- virtual bool parseKeywordType(const char *keyword, Type &result) = 0; + bool parseKeywordType(const char *keyword, Type &result) { + return parseKeyword(keyword) || parseType(result); + } + + /// Parse a keyword. + bool parseKeyword(const char *keyword) { + if (parseOptionalKeyword(keyword)) + return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); + return false; + } + + /// If a keyword is present, then parse it. + virtual bool parseOptionalKeyword(const char *keyword) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and @@ -296,6 +311,10 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; + /// Parses a block list. Any parsed blocks are filled in to the + /// operation's block lists after the operation is created. + virtual bool parseBlockList() = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 053d3520103..80cd21362ce 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -81,10 +81,9 @@ public: enum class Kind { OperationInst, ForInst, - IfInst, /// These enums define ranges used for classof implementations. - INST_LAST = IfInst, + INST_LAST = ForInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 978fa45ab23..00c6577240c 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst and IfInst) are ignored. +/// Non-operation instructions (ForInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. Note that, as of the time of writing, worklist-based rewriter did not diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..5b29467fc44 --- /dev/null +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -0,0 +1,151 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(/*namePrefix=*/"", context) { + addOperations(); +} + +//===----------------------------------------------------------------------===// +// AffineIfOp +//===----------------------------------------------------------------------===// + +void AffineIfOp::build(Builder *builder, OperationState *result, + IntegerSet condition, + ArrayRef conditionOperands) { + result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); + result->addOperands(conditionOperands); + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); +} + +bool AffineIfOp::verify() const { + // Verify that we have a condition attribute. + auto conditionAttr = getAttrOfType(getConditionAttrName()); + if (!conditionAttr) + return emitOpError("requires an integer set attribute named 'condition'"); + + // Verify that the operands are valid dimension/symbols. + IntegerSet condition = conditionAttr.getValue(); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + const Value *operand = getOperand(i); + if (i < condition.getNumDims() && !operand->isValidDim()) + return emitOpError("operand cannot be used as a dimension id"); + if (i >= condition.getNumDims() && !operand->isValidSymbol()) + return emitOpError("operand cannot be used as a symbol"); + } + + // Verify that the entry of each child blocklist does not have arguments. + for (const auto &blockList : getInstruction()->getBlockLists()) { + if (blockList.empty()) + continue; + + // TODO(riverriddle) We currently do not allow multiple blocks in child + // block lists. + if (std::next(blockList.begin()) != blockList.end()) + return emitOpError( + "expects only one block per 'if' or 'else' block list"); + if (blockList.front().getTerminator()) + return emitOpError("expects region block to not have a terminator"); + + for (const auto &b : blockList) + if (b.getNumArguments() != 0) + return emitOpError( + "requires that child entry blocks have no arguments"); + } + return false; +} + +bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return true; + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Parse the 'then' block list. + if (parser->parseBlockList()) + return true; + + // If we find an 'else' keyword then parse the else block list. 
+ if (!parser->parseOptionalKeyword("else")) { + if (parser->parseBlockList()) + return true; + } + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); + return false; +} + +void AffineIfOp::print(OpAsmPrinter *p) const { + auto conditionAttr = getAttrOfType(getConditionAttrName()); + *p << "if " << conditionAttr; + printDimAndSymbolList(operand_begin(), operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printBlockList(getInstruction()->getBlockList(0)); + + // Print the 'else' block list if it has any blocks. + const auto &elseBlockList = getInstruction()->getBlockList(1); + if (!elseBlockList.empty()) { + *p << " else"; + p->printBlockList(elseBlockList); + } +} + +IntegerSet AffineIfOp::getIntegerSet() const { + return getAttrOfType(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr( + Identifier::get(getConditionAttrName(), getInstruction()->getContext()), + IntegerSetAttr::get(newSet)); +} + +/// Returns the list of 'then' blocks. +BlockList &AffineIfOp::getThenBlocks() { + return getInstruction()->getBlockList(0); +} + +/// Returns the list of 'else' blocks. +BlockList &AffineIfOp::getElseBlocks() { + return getInstruction()->getBlockList(1); +} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp new file mode 100644 index 00000000000..0afb32c1bd6 --- /dev/null +++ b/mlir/lib/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 219f356807a..07c903a6613 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } + // No vectorization across unknown regions. 
+ auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto &opInst = cast(inst); + return opInst.getNumBlockLists() != 0 && !opInst.isa(); + }); + auto regionsMatched = regions.match(forInst); + if (!regionsMatched.empty()) { + return false; + } + auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 4f32e9b22f4..491a9bef1b9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } +static bool isAffineIfOp(const Instruction &inst) { + return isa(inst) && + cast(inst).isa(); +} + namespace mlir { namespace matcher { @@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef nested) { - return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::If, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 939a2ede618..0e77d4d9084 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. 
while (currInst && ((currForInst = dyn_cast(currInst)) || - isa(currInst))) { + cast(currInst)->isa())) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef positions, if (auto *childForInst = dyn_cast(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - if (auto *ifInst = dyn_cast(&inst)) { - auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); - if (ret != nullptr) - return ret; - if (auto *elseClause = ifInst->getElse()) - return getInstAtPosition(positions, level + 1, elseClause); - } - if (auto *opInst = dyn_cast(&inst)) { - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; - } - return nullptr; + for (auto &blockList : cast(&inst)->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; } + return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 383a4878c35..474eeb2a28e 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,7 +73,6 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); - bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast(inst))) return true; break; - case Instruction::Kind::If: - if (verifyIfInst(cast(inst))) - return true; - break; } } @@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } -bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { - // TODO: check that if conditions are properly formed. - if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) - return true; - - if (auto *elseClause = ifInst.getElse()) - if (verifyBlock(*elseClause, /*isTopLevel*/ false)) - return true; - - return false; -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast(inst).getBody())) return true; break; - case Instruction::Kind::If: - auto &ifInst = cast(inst); - if (verifyDominance(*ifInst.getThen())) - return true; - if (auto *elseClause = ifInst.getElse()) - if (verifyDominance(*elseClause)) - return true; - break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 21bc3b824b1..cb4c1f0edce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,7 +145,6 @@ private: // Visit functions. 
void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); - void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitIfInst(const IfInst *ifInst) { - recordIntegerSetReference(ifInst->getIntegerSet()); -} - void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::If: - return visitIfInst(cast(inst)); case Instruction::Kind::For: return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: @@ -1077,7 +1070,6 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); - void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1125,6 +1117,9 @@ public: unsigned index) override; /// Print a block list. + void printBlockList(const BlockList &blocks) override { + printBlockList(blocks, /*printEntryBlockArgs=*/true); + } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast(&inst)->getBody()); break; - case Instruction::Kind::If: { - auto *ifInst = cast(&inst); - numberValuesInBlock(*ifInst->getThen()); - if (auto *elseBlock = ifInst->getElse()) - numberValuesInBlock(*elseBlock); - } } } } @@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list, unless this is the first block of - // the function, or the first block of an IfInst/ForInst with no arguments. + // Print the block label and argument list if requested. 
if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast(inst)); case Instruction::Kind::For: return print(cast(inst)); - case Instruction::Kind::If: - return print(cast(inst)); } } @@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfInst *inst) { - os.indent(currentIndent) << "if "; - IntegerSet set = inst->getIntegerSet(); - printIntegerSetReference(set); - printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); - printTrailingLocation(inst->getLoc()); - os << " {\n"; - print(inst->getThen(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - if (inst->hasElse()) { - os << " else {\n"; - print(inst->getElse(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - } -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4471ff25e94..e174fdc1d00 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } - -IfInst *FuncBuilder::createIf(Location location, ArrayRef operands, - IntegerSet set) { - auto *inst = IfInst::create(location, operands, set); - block->getInstructions().insert(insertPoint, inst); - return inst; -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6d74ed14257..0ccab2305ec 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,9 +73,6 @@ void Instruction::destroy() { case Kind::For: delete cast(this); break; - case Kind::If: - delete cast(this); - break; } } @@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const { return cast(this)->getNumOperands(); case Kind::For: return cast(this)->getNumOperands(); - case Kind::If: - return cast(this)->getNumOperands(); } } @@ -152,8 +147,6 @@ MutableArrayRef Instruction::getInstOperands() { return cast(this)->getInstOperands(); case Kind::For: return cast(this)->getInstOperands(); - case Kind::If: - return cast(this)->getInstOperands(); } } @@ -287,15 +280,6 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast(this)->getBody()->dropAllReferences(); break; - case Kind::If: { - // Make sure to drop references held by instructions within the 'then' and - // 'else' blocks. - auto *ifInst = cast(this); - ifInst->getThen()->dropAllReferences(); - if (auto *elseBlock = ifInst->getElse()) - elseBlock->dropAllReferences(); - break; - } case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -809,54 +793,6 @@ mlir::extractForInductionVars(ArrayRef forInsts) { results.push_back(forInst->getInductionVar()); return results; } -//===----------------------------------------------------------------------===// -// IfInst -//===----------------------------------------------------------------------===// - -IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) - : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), - set(set) { - operands.reserve(numOperands); - - // The then of an 'if' inst always has one block. 
- thenClause.push_back(new Block()); -} - -IfInst::~IfInst() { - if (elseClause) - delete elseClause; - - // An IfInst's IntegerSet 'set' should not be deleted since it is - // allocated through MLIRContext's bump pointer allocator. -} - -IfInst *IfInst::create(Location location, ArrayRef operands, - IntegerSet set) { - unsigned numOperands = operands.size(); - assert(numOperands == set.getNumOperands() && - "operand cound does not match the integer set operand count"); - - IfInst *inst = new IfInst(location, numOperands, set); - - for (auto *op : operands) - inst->operands.emplace_back(InstOperand(inst, op)); - - return inst; -} - -const AffineCondition IfInst::getCondition() const { - return AffineCondition(*this, set); -} - -MLIRContext *IfInst::getContext() const { - // Check for degenerate case of if instruction with no operands. - // This is unlikely, but legal. - if (operands.empty()) - return getFunction()->getContext(); - - return getOperand(0)->getType().getContext(); -} - //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); - if (auto *forInst = dyn_cast(this)) { - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + // Otherwise, this must be a ForInst. + auto *forInst = cast(this); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), - ubMap, forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, + forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - - return newFor; - } - - // Otherwise, we must have an If instruction. - auto *ifInst = cast(this); - auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); - - auto *resultThen = newIf->getThen(); - for (auto &childInst : *ifInst->getThen()) - resultThen->push_back(childInst.clone(mapper, context)); - - if (ifInst->hasElse()) { - auto *resultElse = newIf->createElse(); - for (auto &childInst : *ifInst->getElse()) - resultElse->push_back(childInst.clone(mapper, context)); - } + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - return newIf; + // Recursively clone the body of the for loop. 
+ for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + return newFor; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 099b218892f..2ab151f8913 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // Terminators may not exist in ForInst and IfInst. + // TODO(riverriddle) Terminators may not exist with an operation region. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 6418b062dc1..7103eeb7389 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const { return cast(this)->getContext(); case Kind::ForInst: return cast(this)->getContext(); - case Kind::IfInst: - return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c477ad1bbc5..e5d6aa46565 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return (emitError("expected affine map or integer set attribute value"), - nullptr); + return nullptr; if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2209,8 +2208,6 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); - ParseResult parseIfInst(); - ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; - case Token::kw_if: - if (parseIfInst()) - return ParseFailure; - break; } } @@ -2935,12 +2928,18 @@ public: return false; } - /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) override { - if (parser.getTokenSpelling() != keyword) - return parser.emitError("expected '" + Twine(keyword) + "'"); - parser.consumeToken(); - return !(result = parser.parseType()); + /// Parse an optional keyword. + bool parseOptionalKeyword(const char *keyword) override { + // Check that the current token is a bare identifier or keyword. + if (parser.getToken().isNot(Token::bare_identifier) && + !parser.getToken().isKeyword()) + return true; + + if (parser.getTokenSpelling() == keyword) { + parser.consumeToken(); + return false; + } + return true; } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3078,6 +3077,15 @@ public: return result == nullptr; } + /// Parses a list of blocks. + bool parseBlockList() override { + SmallVector results; + if (parser.parseOperationBlockList(results)) + return true; + parsedBlockLists.emplace_back(results); + return false; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3099,6 +3107,11 @@ public: /// Emit a diagnostic at the specified location and return true. 
bool emitError(llvm::SMLoc loc, const Twine &message) override { + // If we emit an error, then cleanup any parsed block lists. + for (auto &blockList : parsedBlockLists) + parser.cleanupInvalidBlocks(blockList); + parsedBlockLists.clear(); + parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3106,7 +3119,13 @@ public: bool didEmitError() const { return emittedError; } + /// Returns the block lists that were parsed. + MutableArrayRef> getParsedBlockLists() { + return parsedBlockLists; + } + private: + std::vector> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; + // Check that enough block lists were reserved for those that were parsed. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); + if (parsedBlockLists.size() > opState.numBlockLists) { + opAsmParser.emitError( + opLoc, + "parsed more block lists than those reserved in the operation state"); + return nullptr; + } + // Otherwise, we succeeded. Use the state it parsed as our op information. - return builder.createOperation(opState); + auto *opInst = builder.createOperation(opState); + + // Resolve any parsed block lists. + for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { + auto &opBlockList = opInst->getBlockList(i).getBlocks(); + opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), + parsedBlockLists[i].end()); + } + return opInst; } /// For instruction. @@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -/// If instruction. -/// -/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` -/// | ml-if-head `else` `if` ml-if-cond trailing-location? -/// `{` inst* `}` -/// ml-if-inst ::= ml-if-head -/// | ml-if-head `else` `{` inst* `}` -/// -ParseResult FunctionParser::parseIfInst() { - auto loc = getToken().getLoc(); - consumeToken(Token::kw_if); - - IntegerSet set = parseIntegerSetReference(); - if (!set) - return ParseFailure; - - SmallVector operands; - if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), - "integer set")) - return ParseFailure; - - IfInst *ifInst = - builder.createIf(getEncodedSourceLocation(loc), operands, set); - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(ifInst)) - return ParseFailure; - - Block *thenClause = ifInst->getThen(); - - // When parsing of an if instruction body fails, the IR contains - // the if instruction with the portion of the body that has been - // successfully parsed. - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(thenClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - if (consumeIf(Token::kw_else)) { - auto *elseClause = ifInst->createElse(); - if (parseElseClause(elseClause)) - return ParseFailure; - } - - // Reset insertion point to the current block. 
- builder.setInsertionPointToEnd(ifInst->getBlock()); - - return ParseSuccess; -} - -ParseResult FunctionParser::parseElseClause(Block *elseClause) { - if (getToken().is(Token::kw_if)) { - builder.setInsertionPointToEnd(elseClause); - return parseIfInst(); - } - - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(elseClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - return ParseSuccess; -} - //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 40e98b25cb3..ec00f98b3f5 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,7 +91,6 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) -TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -100,7 +99,6 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) -TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index c2e1636626d..afd18a49b79 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast(i).getBody()); break; } - case Instruction::Kind::If: { - auto &ifInst = cast(i); - if (auto *elseBlock = ifInst.getElse()) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(elseBlock); - } - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(ifInst.getThen()); - break; - } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cee0a08a63c..eebbbe9daa7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -99,16 +100,16 @@ public: SmallVector forInsts; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasIfInst = false; + bool hasNonForRegion = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opInst) { - if (opInst->isa()) + if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa()) loadOpInsts.push_back(opInst); - if (opInst->isa()) + else if (opInst->isa()) storeOpInsts.push_back(opInst); } }; @@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). - if (collector.hasIfInst) + // Return false if a non 'for' region was found (not currently supported). 
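Editorial aside: the LoopFusion hunk above replaces the IfInst-specific check with a generic test for operations that carry a block list (region). A rough, self-contained sketch of that test, using only APIs that appear elsewhere in this patch (Function::walkOps, OperationInst::getNumBlockLists); the function name and the Function.h include path are illustrative assumptions, not part of the change:

// --- editorial sketch, not part of the patch ---
#include "mlir/IR/Function.h"      // assumed include path
#include "mlir/IR/Instructions.h"

using namespace mlir;

// Returns true if 'f' contains any operation holding a region (block list),
// e.g. an AffineIfOp; the fusion pass currently bails out on such functions.
static bool hasUnsupportedRegions(Function *f) {
  bool found = false;
  f->walkOps([&](OperationInst *opInst) {
    if (opInst->getNumBlockLists() != 0)
      found = true;
  });
  return found;
}
// --- end sketch ---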
+ if (collector.hasNonForRegion) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } - if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); + } else if (opInst->getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; } } - // Return false if IfInsts are found (not currently supported). - if (isa(&inst)) - return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 39ef758833b..6d63e4afd2d 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } - bool walkIfInstPostOrder(IfInst *ifInst) { - bool hasInnerLoops = - walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - hasInnerLoops |= - walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); - return hasInnerLoops; - } - bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ab37ff63bad..f770684f519 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -246,7 +247,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerIfInst(IfInst *ifInst); + bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | | +// | | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { + auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the IfInst and add a + // continue blocks. Move the instructions over from the AffineIfOp and add a // branch to the continuation point. 
Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - auto *oldThen = ifInst->getThen(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + // If the 'then' block is not empty, then splice the instructions. + auto &oldThenBlocks = ifOp->getThenBlocks(); + if (!oldThenBlocks.empty()) { + // We currently only handle one 'then' block. + if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) + return true; + + Block *oldThen = &oldThenBlocks.front(); + + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); + } + FuncBuilder builder(thenBlock); builder.create(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - if (auto *oldElse = ifInst->getElse()) { + auto &oldElseBlocks = ifOp->getElseBlocks(); + if (!oldElseBlocks.empty()) { + // We currently only handle one 'else' block. + if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) + return true; + + auto *oldElse = &oldElseBlocks.front(); elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifInst->getCondition().getIntegerSet(); + auto integerSet = ifOp->getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector instsToRewrite; - // Collect all the If and For instructions as well as AffineApplyOps. We do - // this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa(inst) || isa(inst)) + if (isa(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast(inst); - if (op && op->isa()) + if (op && (op->isa() || op->isa())) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. 
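Editorial aside: the comments above describe a collect-then-rewrite prepass: nothing is mutated while the walker is live, and the rewrite runs afterwards from the saved list. A minimal sketch of that shape, mirroring the walker, cast, and include usage shown in this hunk; the function and variable names are illustrative:

// --- editorial sketch, not part of the patch ---
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"      // assumed include path
#include "mlir/IR/Instructions.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void collectThenRewrite(Function *function) {
  SmallVector<Instruction *, 8> toRewrite;

  // Pass 1: only record instructions of interest; no IR is modified here, so
  // the walker's traversal state stays valid.
  function->walkInsts([&](Instruction *inst) {
    if (isa<ForInst>(inst)) {
      toRewrite.push_back(inst);
    } else if (auto op = dyn_cast<OperationInst>(inst)) {
      if (op->isa<AffineIfOp>() || op->isa<AffineApplyOp>())
        toRewrite.push_back(inst);
    }
  });

  // Pass 2: rewrite outside the walk; splitting blocks or erasing
  // instructions here cannot invalidate an in-flight traversal.
  for (Instruction *inst : toRewrite) {
    (void)inst; // ... lower 'inst' as the pass does below ...
  }
}
// --- end sketch ---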
for (auto *inst : instsToRewrite) - if (auto *ifInst = dyn_cast(inst)) { - if (lowerIfInst(ifInst)) - return failure(); - } else if (auto *forInst = dyn_cast(inst)) { + if (auto *forInst = dyn_cast(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast(inst); - if (lowerAffineApply(op->cast())) + if (auto ifOp = op->dyn_cast()) { + if (lowerAffineIf(ifOp)) + return failure(); + } else if (lowerAffineApply(op->cast())) { return failure(); + } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 09d961f85cd..2744b1d624c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst, if (isa(inst)) return inst->emitError("NYI path ForInst"); - if (isa(inst)) - return inst->emitError("NYI path IfInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); @@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa()) { return false; } + if (opInst->getNumBlockLists() != 0) + return inst->emitError("NYI path Op with region"); + if (auto write = opInst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index bd39e47786a..ba59123c700 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,7 +28,6 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; -using llvm::report_fatal_error; namespace { @@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; - void visitIfInst(IfInst *ifInst); - void visitOperationInst(OperationInst *opInst); - static char passID; }; @@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { - auto set = ifInst->getCondition().getIntegerSet(); - ifInst->setIntegerSet(simplifyIntegerSet(set)); -} - -void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } - } -} - PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkInsts([&](Instruction *inst) { - if (auto *opInst = dyn_cast(inst)) - visitOperationInst(opInst); - if (auto *ifInst = dyn_cast(inst)) - visitIfInst(ifInst); + f->walkOps([&](OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); + } else if (auto setAttr = attr.second.dyn_cast()) { + auto simplified = simplifyIntegerSet(setAttr.getValue()); + opInst->setAttr(attr.first, 
IntegerSetAttr::get(simplified)); + } + } }); return success(); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index bae112dd3b9..595991c0109 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -243,14 +243,6 @@ func @non_instruction() { // ----- -func @invalid_if_conditional1() { - for %i = 1 to 10 { - if () { // expected-error {{expected ':' or '['}} - } -} - -// ----- - func @invalid_if_conditional2() { for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} @@ -664,7 +656,11 @@ func @invalid_if_operands2(%N : index) { func @invalid_if_operands3(%N : index) { for %i = 1 to 10 { if #set0(%i)[%i] { - // expected-error@-1 {{value '%i' cannot be used as a symbol}} + // expected-error@-1 {{operand cannot be used as a symbol}} + } + } + return +} // ----- // expected-error@+1 {{expected '"' in string literal}} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index e3e1bbbbfad..8a90d12bd03 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) loc(fused<"myPass">["foo", "foo2"]) - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 33109606538..626f24569c6 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -287,13 +287,15 @@ func @ifinst(%N: index) { // CHECK: %c1_i32 = constant 1 : i32 %y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32 %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32 - } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) { - // CHECK: %c1 = constant 1 : index - %u = constant 1 : index - // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] - %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] - } else { // CHECK } else { - %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } else { // CHECK } else { + if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) { + // CHECK: %c1 = constant 1 : index + %u = constant 1 : index + // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] + %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] + } else { // CHECK } else { + %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } } // CHECK } } // CHECK } return // CHECK return @@ -751,11 +753,11 @@ func @type_alias() -> !i32_type_alias { func @verbose_if(%N: index) { %c = constant 200 : index - // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () { - "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () { + // CHECK: if #set0(%c200)[%arg0, %c200] { + "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - // CHECK-NEXT: } { + // CHECK-NEXT: } else { } { // The else block list. 
// CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index cb2e14a56d5..69dace45165 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) <"myPass">["foo", "foo2"] - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } <"myPass">["foo", "foo2"] + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 [unknown] return %1 : i32 loc(unknown) -} \ No newline at end of file +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index d170ce590f7..162f193f662 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfInst should prevent fusion. + // Top-level IfOp should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfInst in ForInst should prevent fusion. + // IfOp in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 6f6ad3fafc7..628044ed77a 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ func @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, + // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir index 5509c7aba55..13f009deb70 100644 --- a/mlir/test/Transforms/strip-debug-info.mlir +++ b/mlir/test/Transforms/strip-debug-info.mlir @@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: if #set0(%c4) loc(unknown) + // CHECK: } loc(unknown) %2 = constant 4 : index - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc("bar") -- cgit v1.2.3 From ae772b79659afff6695170b7d404113c32e35a0d Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 28 Jan 2019 18:28:43 -0800 Subject: Automated rollback of changelist 231318632. 
PiperOrigin-RevId: 231327161 --- mlir/include/mlir/AffineOps/AffineOps.h | 91 ------------- mlir/include/mlir/Analysis/NestedMatcher.h | 1 + mlir/include/mlir/IR/Block.h | 11 +- mlir/include/mlir/IR/Builders.h | 4 + mlir/include/mlir/IR/InstVisitor.h | 28 +++- mlir/include/mlir/IR/Instruction.h | 1 + mlir/include/mlir/IR/Instructions.h | 124 +++++++++++++++++ mlir/include/mlir/IR/OpImplementation.h | 21 +-- mlir/include/mlir/IR/UseDefLists.h | 3 +- .../mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 151 --------------------- mlir/lib/AffineOps/DialectRegistration.cpp | 22 --- mlir/lib/Analysis/LoopAnalysis.cpp | 11 -- mlir/lib/Analysis/NestedMatcher.cpp | 20 +-- mlir/lib/Analysis/Utils.cpp | 22 ++- mlir/lib/Analysis/Verifier.cpp | 25 ++++ mlir/lib/IR/AsmPrinter.cpp | 38 +++++- mlir/lib/IR/Builders.cpp | 7 + mlir/lib/IR/Instruction.cpp | 109 +++++++++++++-- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/Value.cpp | 2 + mlir/lib/Parser/Parser.cpp | 129 +++++++++++------- mlir/lib/Parser/TokenKinds.def | 2 + mlir/lib/Transforms/CSE.cpp | 10 ++ mlir/lib/Transforms/LoopFusion.cpp | 24 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 9 ++ mlir/lib/Transforms/LowerAffine.cpp | 59 +++----- mlir/lib/Transforms/MaterializeVectors.cpp | 7 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 37 +++-- mlir/test/IR/invalid.mlir | 14 +- mlir/test/IR/locations.mlir | 6 +- mlir/test/IR/parser.mlir | 22 ++- mlir/test/IR/pretty-locations.mlir | 8 +- mlir/test/Transforms/loop-fusion.mlir | 4 +- mlir/test/Transforms/memref-dependence-check.mlir | 2 +- mlir/test/Transforms/strip-debug-info.mlir | 6 +- 36 files changed, 541 insertions(+), 493 deletions(-) delete mode 100644 mlir/include/mlir/AffineOps/AffineOps.h delete mode 100644 mlir/lib/AffineOps/AffineOps.cpp delete mode 100644 mlir/lib/AffineOps/DialectRegistration.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h deleted file mode 100644 index d511f628c3c..00000000000 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ /dev/null @@ -1,91 +0,0 @@ -//===- AffineOps.h - MLIR Affine Operations -------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines convenience types for working with Affine operations -// in the MLIR instruction set. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_AFFINEOPS_AFFINEOPS_H -#define MLIR_AFFINEOPS_AFFINEOPS_H - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { - -class AffineOpsDialect : public Dialect { -public: - AffineOpsDialect(MLIRContext *context); -}; - -/// The "if" operation represents an if–then–else construct for conditionally -/// executing two regions of code. 
The operands to an if operation are an -/// IntegerSet condition and a set of symbol/dimension operands to the -/// condition set. The operation produces no results. For example: -/// -/// if #set(%i) { -/// ... -/// } else { -/// ... -/// } -/// -/// The 'else' blocks to the if operation are optional, and may be omitted. For -/// example: -/// -/// if #set(%i) { -/// ... -/// } -/// -class AffineIfOp - : public Op { -public: - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, - IntegerSet condition, ArrayRef conditionOperands); - - static StringRef getOperationName() { return "if"; } - static StringRef getConditionAttrName() { return "condition"; } - - IntegerSet getIntegerSet() const; - void setIntegerSet(IntegerSet newSet); - - /// Returns the list of 'then' blocks. - BlockList &getThenBlocks(); - const BlockList &getThenBlocks() const { - return const_cast(this)->getThenBlocks(); - } - - /// Returns the list of 'else' blocks. - BlockList &getElseBlocks(); - const BlockList &getElseBlocks() const { - return const_cast(this)->getElseBlocks(); - } - - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - -private: - friend class OperationInst; - explicit AffineIfOp(const OperationInst *state) : Op(state) {} -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 161bb217a10..c205d55488e 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -128,6 +128,7 @@ private: void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } + void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// POD paylod. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index e85ea772d0b..1b14d925d32 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -26,6 +26,7 @@ #include "llvm/ADT/PointerUnion.h" namespace mlir { +class IfInst; class BlockList; class BlockAndValueMapping; @@ -61,7 +62,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an OperationInst or ForInst. + /// nested under an IfInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast(this)->getFunction(); @@ -324,7 +325,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or OperationInst or ForInst. +/// is part of - a Function or IfInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -364,14 +365,14 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and OperationInst/ForInst. If it is - /// part of an OperationInst/ForInst, then return it, otherwise return null. + /// A BlockList is part of a Function or and IfInst/ForInst. If it is + /// part of an IfInst/ForInst, then return it, otherwise return null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a Function or and OperationInst/ForInst. If it is + /// A BlockList is part of a Function or and IfInst/ForInst. If it is /// part of a Function, then return it, otherwise return null. 
Function *getContainingFunction(); const Function *getContainingFunction() const { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3271c12afde..156bd02bb52 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -286,6 +286,10 @@ public: // Default step is 1. ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); + /// Creates if instruction. + IfInst *createIf(Location location, ArrayRef operands, + IntegerSet set); + private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 78810da909d..b6a759e76f5 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, and +// There are 'visit' methods for OperationInst, ForInst, IfInst, and // Function, which recursively process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, @@ -85,6 +85,8 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->visitForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->visitIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -102,6 +104,7 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -163,6 +166,23 @@ public: static_cast(this)->visitForInst(forInst); } + void walkIfInst(IfInst *ifInst) { + static_cast(this)->visitIfInst(ifInst); + static_cast(this)->walk(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walk(elseBlock->begin(), elseBlock->end()); + } + + void walkIfInstPostOrder(IfInst *ifInst) { + static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walkPostOrder(elseBlock->begin(), + elseBlock->end()); + static_cast(this)->visitIfInst(ifInst); + } + // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -173,6 +193,8 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->walkForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -188,6 +210,9 @@ public: case Instruction::Kind::For: return static_cast(this)->walkForInstPostOrder( cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInstPostOrder( + cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -206,6 +231,7 @@ public: // processing their descendants in some way. When using RetTy, all of these // need to be overridden. 
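Editorial aside: the walker restored above dispatches through the derived class, so a client overrides only the hooks it needs. A small sketch of a client using the re-added visitIfInst hook; the class name is illustrative and InstWalker's template parameter list is assumed from the dispatch code in this hunk:

// --- editorial sketch, not part of the patch ---
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"

namespace {
// Counts 'if' instructions in whatever range it is asked to walk.
struct IfCounter : public mlir::InstWalker<IfCounter> {
  unsigned numIfs = 0;
  // Invoked once per IfInst; every other hook keeps its empty default.
  void visitIfInst(mlir::IfInst *ifInst) { ++numIfs; }
};
} // end anonymous namespace

// Usage, assuming 'block' is a mlir::Block*:
//   IfCounter counter;
//   counter.walk(block->begin(), block->end());
// --- end sketch ---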
void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 3dc1e76dd20..6a296b7348e 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -75,6 +75,7 @@ public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForInst, + If = (int)IROperandOwner::Kind::IfInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index fb6b1b97ca0..71d832b8b90 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -794,6 +794,130 @@ private: friend class ForInst; }; + +/// If instruction restricts execution to a subset of the loop iteration space. +class IfInst : public Instruction { +public: + static IfInst *create(Location location, ArrayRef operands, + IntegerSet set); + ~IfInst(); + + //===--------------------------------------------------------------------===// + // Then, else, condition. + //===--------------------------------------------------------------------===// + + Block *getThen() { return &thenClause.front(); } + const Block *getThen() const { return &thenClause.front(); } + Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } + const Block *getElse() const { + return elseClause ? &elseClause->front() : nullptr; + } + bool hasElse() const { return elseClause != nullptr; } + + Block *createElse() { + assert(elseClause == nullptr && "already has an else clause!"); + elseClause = new BlockList(this); + elseClause->push_back(new Block()); + return &elseClause->front(); + } + + const AffineCondition getCondition() const; + + IntegerSet getIntegerSet() const { return set; } + void setIntegerSet(IntegerSet newSet) { + assert(newSet.getNumOperands() == operands.size()); + set = newSet; + } + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + /// Operand iterators. + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; + + /// Operand iterator range. 
+ using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + unsigned getNumOperands() const { return operands.size(); } + + Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { + return getInstOperand(idx).get(); + } + void setOperand(unsigned idx, Value *value) { + getInstOperand(idx).set(value); + } + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + ArrayRef getInstOperands() const { return operands; } + MutableArrayRef getInstOperands() { return operands; } + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + MLIRContext *getContext() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::IfInst; + } + +private: + // it is always present. + BlockList thenClause; + // 'else' clause of the if instruction. 'nullptr' if there is no else clause. + BlockList *elseClause; + + // The integer set capturing the conditional guard. + IntegerSet set; + + // Condition operands. + std::vector operands; + + explicit IfInst(Location location, unsigned numOperands, IntegerSet set); +}; + +/// AffineCondition represents a condition of the 'if' instruction. +/// Its life span should not exceed that of the objects it refers to. +/// AffineCondition does not provide its own methods for iterating over +/// the operands since the iterators of the if instruction accomplish +/// the same purpose. +/// +/// AffineCondition is trivially copyable, so it should be passed by value. +class AffineCondition { +public: + const IfInst *getIfInst() const { return &inst; } + IntegerSet getIntegerSet() const { return set; } + +private: + // 'if' instruction that contains this affine condition. + const IfInst &inst; + // Integer set for this affine condition. + IntegerSet set; + + AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} + + friend class IfInst; +}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d3a5d35427f..1e319db3571 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,9 +89,6 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const OperationInst *op) = 0; - /// Prints a block list. - virtual void printBlockList(const BlockList &blocks) = 0; - private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; @@ -198,19 +195,7 @@ public: virtual bool parseColonTypeList(SmallVectorImpl &result) = 0; /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) { - return parseKeyword(keyword) || parseType(result); - } - - /// Parse a keyword. 
- bool parseKeyword(const char *keyword) { - if (parseOptionalKeyword(keyword)) - return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); - return false; - } - - /// If a keyword is present, then parse it. - virtual bool parseOptionalKeyword(const char *keyword) = 0; + virtual bool parseKeywordType(const char *keyword, Type &result) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and @@ -311,10 +296,6 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; - /// Parses a block list. Any parsed blocks are filled in to the - /// operation's block lists after the operation is created. - virtual bool parseBlockList() = 0; - //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 80cd21362ce..053d3520103 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -81,9 +81,10 @@ public: enum class Kind { OperationInst, ForInst, + IfInst, /// These enums define ranges used for classof implementations. - INST_LAST = ForInst, + INST_LAST = IfInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 00c6577240c..978fa45ab23 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst) are ignored. +/// Non-operation instructions (ForInst and IfInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. Note that, as of the time of writing, worklist-based rewriter did not diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp deleted file mode 100644 index 5b29467fc44..00000000000 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ /dev/null @@ -1,151 +0,0 @@ -//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpImplementation.h" -using namespace mlir; - -//===----------------------------------------------------------------------===// -// AffineOpsDialect -//===----------------------------------------------------------------------===// - -AffineOpsDialect::AffineOpsDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"", context) { - addOperations(); -} - -//===----------------------------------------------------------------------===// -// AffineIfOp -//===----------------------------------------------------------------------===// - -void AffineIfOp::build(Builder *builder, OperationState *result, - IntegerSet condition, - ArrayRef conditionOperands) { - result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); - result->addOperands(conditionOperands); - - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); -} - -bool AffineIfOp::verify() const { - // Verify that we have a condition attribute. - auto conditionAttr = getAttrOfType(getConditionAttrName()); - if (!conditionAttr) - return emitOpError("requires an integer set attribute named 'condition'"); - - // Verify that the operands are valid dimension/symbols. - IntegerSet condition = conditionAttr.getValue(); - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const Value *operand = getOperand(i); - if (i < condition.getNumDims() && !operand->isValidDim()) - return emitOpError("operand cannot be used as a dimension id"); - if (i >= condition.getNumDims() && !operand->isValidSymbol()) - return emitOpError("operand cannot be used as a symbol"); - } - - // Verify that the entry of each child blocklist does not have arguments. - for (const auto &blockList : getInstruction()->getBlockLists()) { - if (blockList.empty()) - continue; - - // TODO(riverriddle) We currently do not allow multiple blocks in child - // block lists. - if (std::next(blockList.begin()) != blockList.end()) - return emitOpError( - "expects only one block per 'if' or 'else' block list"); - if (blockList.front().getTerminator()) - return emitOpError("expects region block to not have a terminator"); - - for (const auto &b : blockList) - if (b.getNumArguments() != 0) - return emitOpError( - "requires that child entry blocks have no arguments"); - } - return false; -} - -bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { - // Parse the condition attribute set. - IntegerSetAttr conditionAttr; - unsigned numDims; - if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), - result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims)) - return true; - - // Verify the condition operands. - auto set = conditionAttr.getValue(); - if (set.getNumDims() != numDims) - return parser->emitError( - parser->getNameLoc(), - "dim operand count and integer set dim count must match"); - if (numDims + set.getNumSymbols() != result->operands.size()) - return parser->emitError( - parser->getNameLoc(), - "symbol operand count and integer set symbol count must match"); - - // Parse the 'then' block list. - if (parser->parseBlockList()) - return true; - - // If we find an 'else' keyword then parse the else block list. 
- if (!parser->parseOptionalKeyword("else")) { - if (parser->parseBlockList()) - return true; - } - - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); - return false; -} - -void AffineIfOp::print(OpAsmPrinter *p) const { - auto conditionAttr = getAttrOfType(getConditionAttrName()); - *p << "if " << conditionAttr; - printDimAndSymbolList(operand_begin(), operand_end(), - conditionAttr.getValue().getNumDims(), p); - p->printBlockList(getInstruction()->getBlockList(0)); - - // Print the 'else' block list if it has any blocks. - const auto &elseBlockList = getInstruction()->getBlockList(1); - if (!elseBlockList.empty()) { - *p << " else"; - p->printBlockList(elseBlockList); - } -} - -IntegerSet AffineIfOp::getIntegerSet() const { - return getAttrOfType(getConditionAttrName()).getValue(); -} -void AffineIfOp::setIntegerSet(IntegerSet newSet) { - setAttr( - Identifier::get(getConditionAttrName(), getInstruction()->getContext()), - IntegerSetAttr::get(newSet)); -} - -/// Returns the list of 'then' blocks. -BlockList &AffineIfOp::getThenBlocks() { - return getInstruction()->getBlockList(0); -} - -/// Returns the list of 'else' blocks. -BlockList &AffineIfOp::getElseBlocks() { - return getInstruction()->getBlockList(1); -} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp deleted file mode 100644 index 0afb32c1bd6..00000000000 --- a/mlir/lib/AffineOps/DialectRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -using namespace mlir; - -// Static initialization for Affine op dialect registration. -static DialectRegistration StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 07c903a6613..219f356807a 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,7 +21,6 @@ #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -247,16 +246,6 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } - // No vectorization across unknown regions. 
- auto regions = matcher::Op([](const Instruction &inst) -> bool { - auto &opInst = cast(inst); - return opInst.getNumBlockLists() != 0 && !opInst.isa(); - }); - auto regionsMatched = regions.match(forInst); - if (!regionsMatched.empty()) { - return false; - } - auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 491a9bef1b9..4f32e9b22f4 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,7 +16,6 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -187,11 +186,6 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } -static bool isAffineIfOp(const Instruction &inst) { - return isa(inst) && - cast(inst).isa(); -} - namespace mlir { namespace matcher { @@ -200,22 +194,16 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); + return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(Instruction::Kind::If, child, filter); } NestedPattern If(ArrayRef nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); + return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(Instruction::Kind::If, nested, filter); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e77d4d9084..939a2ede618 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,7 +22,6 @@ #include "mlir/Analysis/Utils.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -44,7 +43,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. 
while (currInst && ((currForInst = dyn_cast(currInst)) || - cast(currInst)->isa())) { + isa(currInst))) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -360,12 +359,21 @@ static Instruction *getInstAtPosition(ArrayRef positions, if (auto *childForInst = dyn_cast(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - for (auto &blockList : cast(&inst)->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; + if (auto *ifInst = dyn_cast(&inst)) { + auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); + if (ret != nullptr) + return ret; + if (auto *elseClause = ifInst->getElse()) + return getInstAtPosition(positions, level + 1, elseClause); + } + if (auto *opInst = dyn_cast(&inst)) { + for (auto &blockList : opInst->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; + } + return nullptr; } - return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 474eeb2a28e..383a4878c35 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,6 +73,7 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); + bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -179,6 +180,10 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast(inst))) return true; break; + case Instruction::Kind::If: + if (verifyIfInst(cast(inst))) + return true; + break; } } @@ -245,6 +250,18 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } +bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { + // TODO: check that if conditions are properly formed. + if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) + return true; + + if (auto *elseClause = ifInst.getElse()) + if (verifyBlock(*elseClause, /*isTopLevel*/ false)) + return true; + + return false; +} + bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -266,6 +283,14 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast(inst).getBody())) return true; break; + case Instruction::Kind::If: + auto &ifInst = cast(inst); + if (verifyDominance(*ifInst.getThen())) + return true; + if (auto *elseClause = ifInst.getElse()) + if (verifyDominance(*elseClause)) + return true; + break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edce..21bc3b824b1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,6 +145,7 @@ private: // Visit functions. 
void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,6 +197,10 @@ void ModuleState::visitAttribute(Attribute attr) { } } +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); +} + void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -220,6 +225,8 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast(inst)); case Instruction::Kind::For: return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: @@ -1070,6 +1077,7 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,9 +1125,6 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1209,6 +1214,12 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast(&inst)->getBody()); break; + case Instruction::Kind::If: { + auto *ifInst = cast(&inst); + numberValuesInBlock(*ifInst->getThen()); + if (auto *elseBlock = ifInst->getElse()) + numberValuesInBlock(*elseBlock); + } } } } @@ -1349,7 +1360,8 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list if requested. + // Print the block label and argument list, unless this is the first block of + // the function, or the first block of an IfInst/ForInst with no arguments. 
if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1406,6 +1418,8 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast(inst)); case Instruction::Kind::For: return print(cast(inst)); + case Instruction::Kind::If: + return print(cast(inst)); } } @@ -1433,6 +1447,22 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } +void FunctionPrinter::print(const IfInst *inst) { + os.indent(currentIndent) << "if "; + IntegerSet set = inst->getIntegerSet(); + printIntegerSetReference(set); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); + printTrailingLocation(inst->getLoc()); + os << " {\n"; + print(inst->getThen(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + if (inst->hasElse()) { + os << " else {\n"; + print(inst->getElse(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + } +} + void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index e174fdc1d00..4471ff25e94 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,3 +327,10 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } + +IfInst *FuncBuilder::createIf(Location location, ArrayRef operands, + IntegerSet set) { + auto *inst = IfInst::create(location, operands, set); + block->getInstructions().insert(insertPoint, inst); + return inst; +} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 0ccab2305ec..6d74ed14257 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,6 +73,9 @@ void Instruction::destroy() { case Kind::For: delete cast(this); break; + case Kind::If: + delete cast(this); + break; } } @@ -138,6 +141,8 @@ unsigned Instruction::getNumOperands() const { return cast(this)->getNumOperands(); case Kind::For: return cast(this)->getNumOperands(); + case Kind::If: + return cast(this)->getNumOperands(); } } @@ -147,6 +152,8 @@ MutableArrayRef Instruction::getInstOperands() { return cast(this)->getInstOperands(); case Kind::For: return cast(this)->getInstOperands(); + case Kind::If: + return cast(this)->getInstOperands(); } } @@ -280,6 +287,15 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast(this)->getBody()->dropAllReferences(); break; + case Kind::If: { + // Make sure to drop references held by instructions within the 'then' and + // 'else' blocks. + auto *ifInst = cast(this); + ifInst->getThen()->dropAllReferences(); + if (auto *elseBlock = ifInst->getElse()) + elseBlock->dropAllReferences(); + break; + } case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -793,6 +809,54 @@ mlir::extractForInductionVars(ArrayRef forInsts) { results.push_back(forInst->getInductionVar()); return results; } +//===----------------------------------------------------------------------===// +// IfInst +//===----------------------------------------------------------------------===// + +IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) + : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), + set(set) { + operands.reserve(numOperands); + + // The then of an 'if' inst always has one block. 
+ thenClause.push_back(new Block()); +} + +IfInst::~IfInst() { + if (elseClause) + delete elseClause; + + // An IfInst's IntegerSet 'set' should not be deleted since it is + // allocated through MLIRContext's bump pointer allocator. +} + +IfInst *IfInst::create(Location location, ArrayRef operands, + IntegerSet set) { + unsigned numOperands = operands.size(); + assert(numOperands == set.getNumOperands() && + "operand count does not match the integer set operand count"); + + IfInst *inst = new IfInst(location, numOperands, set); + + for (auto *op : operands) + inst->operands.emplace_back(InstOperand(inst, op)); + + return inst; +} + +const AffineCondition IfInst::getCondition() const { + return AffineCondition(*this, set); +} + +MLIRContext *IfInst::getContext() const { + // Check for degenerate case of if instruction with no operands. + // This is unlikely, but legal. + if (operands.empty()) + return getFunction()->getContext(); + + return getOperand(0)->getType().getContext(); +} + //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -867,23 +931,40 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); - // Otherwise, this must be a ForInst. - auto *forInst = cast(this); - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + if (auto *forInst = dyn_cast(this)) { + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, - forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), + ubMap, forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); + + // Recursively clone the body of the for loop. + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + + return newFor; + } + + // Otherwise, we must have an If instruction. + auto *ifInst = cast(this); + auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); + + auto *resultThen = newIf->getThen(); + for (auto &childInst : *ifInst->getThen()) + resultThen->push_back(childInst.clone(mapper, context)); + + if (ifInst->hasElse()) { + auto *resultElse = newIf->createElse(); + for (auto &childInst : *ifInst->getElse()) + resultElse->push_back(childInst.clone(mapper, context)); + } - // Recursively clone the body of the for loop.
- for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - return newFor; + return newIf; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2ab151f8913..099b218892f 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // TODO(riverriddle) Terminators may not exist with an operation region. + // Terminators may not exist in ForInst and IfInst. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7103eeb7389..6418b062dc1 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,6 +66,8 @@ MLIRContext *IROperandOwner::getContext() const { return cast(this)->getContext(); case Kind::ForInst: return cast(this)->getContext(); + case Kind::IfInst: + return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e5d6aa46565..c477ad1bbc5 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,7 +996,8 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return nullptr; + return (emitError("expected affine map or integer set attribute value"), + nullptr); if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2208,6 +2209,8 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); + ParseResult parseIfInst(); + ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2389,6 +2392,10 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; + case Token::kw_if: + if (parseIfInst()) + return ParseFailure; + break; } } @@ -2928,18 +2935,12 @@ public: return false; } - /// Parse an optional keyword. - bool parseOptionalKeyword(const char *keyword) override { - // Check that the current token is a bare identifier or keyword. - if (parser.getToken().isNot(Token::bare_identifier) && - !parser.getToken().isKeyword()) - return true; - - if (parser.getTokenSpelling() == keyword) { - parser.consumeToken(); - return false; - } - return true; + /// Parse a keyword followed by a type. + bool parseKeywordType(const char *keyword, Type &result) override { + if (parser.getTokenSpelling() != keyword) + return parser.emitError("expected '" + Twine(keyword) + "'"); + parser.consumeToken(); + return !(result = parser.parseType()); } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3077,15 +3078,6 @@ public: return result == nullptr; } - /// Parses a list of blocks. - bool parseBlockList() override { - SmallVector results; - if (parser.parseOperationBlockList(results)) - return true; - parsedBlockLists.emplace_back(results); - return false; - } - //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3107,11 +3099,6 @@ public: /// Emit a diagnostic at the specified location and return true. 
bool emitError(llvm::SMLoc loc, const Twine &message) override { - // If we emit an error, then cleanup any parsed block lists. - for (auto &blockList : parsedBlockLists) - parser.cleanupInvalidBlocks(blockList); - parsedBlockLists.clear(); - parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3119,13 +3106,7 @@ public: bool didEmitError() const { return emittedError; } - /// Returns the block lists that were parsed. - MutableArrayRef> getParsedBlockLists() { - return parsedBlockLists; - } - private: - std::vector> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3164,25 +3145,8 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; - // Check that enough block lists were reserved for those that were parsed. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - if (parsedBlockLists.size() > opState.numBlockLists) { - opAsmParser.emitError( - opLoc, - "parsed more block lists than those reserved in the operation state"); - return nullptr; - } - // Otherwise, we succeeded. Use the state it parsed as our op information. - auto *opInst = builder.createOperation(opState); - - // Resolve any parsed block lists. - for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { - auto &opBlockList = opInst->getBlockList(i).getBlocks(); - opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), - parsedBlockLists[i].end()); - } - return opInst; + return builder.createOperation(opState); } /// For instruction. @@ -3474,6 +3438,69 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } +/// If instruction. +/// +/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` +/// | ml-if-head `else` `if` ml-if-cond trailing-location? +/// `{` inst* `}` +/// ml-if-inst ::= ml-if-head +/// | ml-if-head `else` `{` inst* `}` +/// +ParseResult FunctionParser::parseIfInst() { + auto loc = getToken().getLoc(); + consumeToken(Token::kw_if); + + IntegerSet set = parseIntegerSetReference(); + if (!set) + return ParseFailure; + + SmallVector operands; + if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), + "integer set")) + return ParseFailure; + + IfInst *ifInst = + builder.createIf(getEncodedSourceLocation(loc), operands, set); + + // Try to parse the optional trailing location. + if (parseOptionalTrailingLocation(ifInst)) + return ParseFailure; + + Block *thenClause = ifInst->getThen(); + + // When parsing of an if instruction body fails, the IR contains + // the if instruction with the portion of the body that has been + // successfully parsed. + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseBlock(thenClause) || + parseToken(Token::r_brace, "expected '}' after instruction list")) + return ParseFailure; + + if (consumeIf(Token::kw_else)) { + auto *elseClause = ifInst->createElse(); + if (parseElseClause(elseClause)) + return ParseFailure; + } + + // Reset insertion point to the current block. 
+ builder.setInsertionPointToEnd(ifInst->getBlock()); + + return ParseSuccess; +} + +ParseResult FunctionParser::parseElseClause(Block *elseClause) { + if (getToken().is(Token::kw_if)) { + builder.setInsertionPointToEnd(elseClause); + return parseIfInst(); + } + + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseBlock(elseClause) || + parseToken(Token::r_brace, "expected '}' after instruction list")) + return ParseFailure; + return ParseSuccess; +} + //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index ec00f98b3f5..40e98b25cb3 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,6 +91,7 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) +TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -99,6 +100,7 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) +TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b79..c2e1636626d 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,6 +188,16 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast(i).getBody()); break; } + case Instruction::Kind::If: { + auto &ifInst = cast(i); + if (auto *elseBlock = ifInst.getElse()) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(elseBlock); + } + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(ifInst.getThen()); + break; + } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index eebbbe9daa7..cee0a08a63c 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,7 +19,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -100,16 +99,16 @@ public: SmallVector forInsts; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasNonForRegion = false; + bool hasIfInst = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + void visitIfInst(IfInst *ifInst) { hasIfInst = true; } + void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) - hasNonForRegion = true; - else if (opInst->isa()) + if (opInst->isa()) loadOpInsts.push_back(opInst); - else if (opInst->isa()) + if (opInst->isa()) storeOpInsts.push_back(opInst); } }; @@ -411,8 +410,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if a non 'for' region was found (not currently supported). - if (collector.hasNonForRegion) + // Return false if IfInsts are found (not currently supported). 
+ if (collector.hasIfInst) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -435,18 +434,19 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast()) { + } + if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; } } + // Return false if IfInsts are found (not currently supported). + if (isa(&inst)) + return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d..39ef758833b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,6 +119,15 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } + bool walkIfInstPostOrder(IfInst *ifInst) { + bool hasInnerLoops = + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) + hasInnerLoops |= + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); + return hasInnerLoops; + } + bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f519..ab37ff63bad 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -247,7 +246,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerAffineIf(AffineIfOp *ifOp); + bool lowerIfInst(IfInst *ifInst); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -410,7 +409,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | | +// | | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -454,11 +453,10 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { - auto *ifInst = ifOp->getInstruction(); +bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -468,38 +466,22 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the AffineIfOp and add a + // continue blocks. Move the instructions over from the IfInst and add a // branch to the continuation point. 
Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - // If the 'then' block is not empty, then splice the instructions. - auto &oldThenBlocks = ifOp->getThenBlocks(); - if (!oldThenBlocks.empty()) { - // We currently only handle one 'then' block. - if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) - return true; - - Block *oldThen = &oldThenBlocks.front(); - - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); - } - + auto *oldThen = ifInst->getThen(); + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); FuncBuilder builder(thenBlock); builder.create(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - auto &oldElseBlocks = ifOp->getElseBlocks(); - if (!oldElseBlocks.empty()) { - // We currently only handle one 'else' block. - if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) - return true; - - auto *oldElse = &oldElseBlocks.front(); + if (auto *oldElse = ifInst->getElse()) { elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -511,7 +493,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifOp->getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -611,30 +593,29 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector instsToRewrite; - // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. - // We do this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the If and For instructions as well as AffineApplyOps. We do + // this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa(inst)) + if (isa(inst) || isa(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast(inst); - if (op && (op->isa() || op->isa())) + if (op && op->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. 
for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast(inst)) { + if (auto *ifInst = dyn_cast(inst)) { + if (lowerIfInst(ifInst)) + return failure(); + } else if (auto *forInst = dyn_cast(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast())) { + if (lowerAffineApply(op->cast())) return failure(); - } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2744b1d624c..09d961f85cd 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -560,6 +559,9 @@ static bool instantiateMaterialization(Instruction *inst, if (isa(inst)) return inst->emitError("NYI path ForInst"); + if (isa(inst)) + return inst->emitError("NYI path IfInst"); + // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); @@ -568,9 +570,6 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa()) { return false; } - if (opInst->getNumBlockLists() != 0) - return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c700..bd39e47786a 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,6 +28,7 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; +using llvm::report_fatal_error; namespace { @@ -41,6 +42,9 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); + static char passID; }; @@ -62,19 +66,28 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkOps([&](OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } else if (auto setAttr = attr.second.dyn_cast()) { - auto simplified = simplifyIntegerSet(setAttr.getValue()); - opInst->setAttr(attr.first, IntegerSetAttr::get(simplified)); - } +void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { + auto set = ifInst->getCondition().getIntegerSet(); + ifInst->setIntegerSet(simplifyIntegerSet(set)); +} + +void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); } + } +} + +PassResult SimplifyAffineStructures::runOnFunction(Function *f) { + f->walkInsts([&](Instruction *inst) { + if (auto *opInst = dyn_cast(inst)) + 
visitOperationInst(opInst); + if (auto *ifInst = dyn_cast(inst)) + visitIfInst(ifInst); }); return success(); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 595991c0109..bae112dd3b9 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -243,6 +243,14 @@ func @non_instruction() { // ----- +func @invalid_if_conditional1() { + for %i = 1 to 10 { + if () { // expected-error {{expected ':' or '['}} + } +} + +// ----- + func @invalid_if_conditional2() { for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} @@ -656,11 +664,7 @@ func @invalid_if_operands2(%N : index) { func @invalid_if_operands3(%N : index) { for %i = 1 to 10 { if #set0(%i)[%i] { - // expected-error@-1 {{operand cannot be used as a symbol}} - } - } - return -} + // expected-error@-1 {{value '%i' cannot be used as a symbol}} // ----- // expected-error@+1 {{expected '"' in string literal}} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 8a90d12bd03..e3e1bbbbfad 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } loc(fused<"myPass">["foo", "foo2"]) - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + // CHECK: ) loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 626f24569c6..33109606538 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -287,15 +287,13 @@ func @ifinst(%N: index) { // CHECK: %c1_i32 = constant 1 : i32 %y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32 %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32 - } else { // CHECK } else { - if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) { - // CHECK: %c1 = constant 1 : index - %u = constant 1 : index - // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] - %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] - } else { // CHECK } else { - %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 - } + } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) { + // CHECK: %c1 = constant 1 : index + %u = constant 1 : index + // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] + %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] + } else { // CHECK } else { + %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 } // CHECK } } // CHECK } return // CHECK return @@ -753,11 +751,11 @@ func @type_alias() -> !i32_type_alias { func @verbose_if(%N: index) { %c = constant 200 : index - // CHECK: if #set0(%c200)[%arg0, %c200] { - "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { + // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () { + "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - // CHECK-NEXT: } else { + // CHECK-NEXT: } { } { // The else block list. 
// CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index 69dace45165..cb2e14a56d5 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } <"myPass">["foo", "foo2"] - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + // CHECK: ) <"myPass">["foo", "foo2"] + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 [unknown] return %1 : i32 loc(unknown) -} +} \ No newline at end of file diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 162f193f662..d170ce590f7 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfOp should prevent fusion. + // Top-level IfInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfOp in ForInst should prevent fusion. + // IfInst in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 628044ed77a..6f6ad3fafc7 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ func @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, + // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir index 13f009deb70..5509c7aba55 100644 --- a/mlir/test/Transforms/strip-debug-info.mlir +++ b/mlir/test/Transforms/strip-debug-info.mlir @@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } loc(unknown) + // CHECK: if #set0(%c4) loc(unknown) %2 = constant 4 : index - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc("bar") -- cgit v1.2.3 From 755538328b0651661323bee33cf18b4ea76ee92a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 28 Jan 2019 21:23:53 -0800 Subject: Recommit: Define a AffineOps dialect as well as an AffineIfOp operation. Replace all instances of IfInst with AffineIfOp and delete IfInst. 
PiperOrigin-RevId: 231342063 --- mlir/include/mlir/AffineOps/AffineOps.h | 91 +++++++++++++ mlir/include/mlir/Analysis/NestedMatcher.h | 1 - mlir/include/mlir/IR/Block.h | 14 +- mlir/include/mlir/IR/Builders.h | 4 - mlir/include/mlir/IR/InstVisitor.h | 28 +--- mlir/include/mlir/IR/Instruction.h | 1 - mlir/include/mlir/IR/Instructions.h | 124 ----------------- mlir/include/mlir/IR/OpImplementation.h | 21 ++- mlir/include/mlir/IR/UseDefLists.h | 3 +- .../mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 151 +++++++++++++++++++++ mlir/lib/AffineOps/DialectRegistration.cpp | 22 +++ mlir/lib/Analysis/LoopAnalysis.cpp | 11 ++ mlir/lib/Analysis/NestedMatcher.cpp | 20 ++- mlir/lib/Analysis/Utils.cpp | 22 +-- mlir/lib/Analysis/Verifier.cpp | 25 ---- mlir/lib/IR/AsmPrinter.cpp | 38 +----- mlir/lib/IR/Builders.cpp | 7 - mlir/lib/IR/Instruction.cpp | 109 ++------------- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/Value.cpp | 2 - mlir/lib/Parser/Parser.cpp | 129 +++++++----------- mlir/lib/Parser/TokenKinds.def | 2 - mlir/lib/Transforms/CSE.cpp | 10 -- mlir/lib/Transforms/LoopFusion.cpp | 24 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 9 -- mlir/lib/Transforms/LowerAffine.cpp | 59 +++++--- mlir/lib/Transforms/MaterializeVectors.cpp | 7 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 37 ++--- mlir/test/IR/invalid.mlir | 14 +- mlir/test/IR/locations.mlir | 6 +- mlir/test/IR/parser.mlir | 22 +-- mlir/test/IR/pretty-locations.mlir | 8 +- mlir/test/Transforms/loop-fusion.mlir | 4 +- mlir/test/Transforms/memref-dependence-check.mlir | 2 +- mlir/test/Transforms/strip-debug-info.mlir | 6 +- 36 files changed, 495 insertions(+), 542 deletions(-) create mode 100644 mlir/include/mlir/AffineOps/AffineOps.h create mode 100644 mlir/lib/AffineOps/AffineOps.cpp create mode 100644 mlir/lib/AffineOps/DialectRegistration.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h new file mode 100644 index 00000000000..d511f628c3c --- /dev/null +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -0,0 +1,91 @@ +//===- AffineOps.h - MLIR Affine Operations -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with Affine operations +// in the MLIR instruction set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_AFFINEOPS_AFFINEOPS_H +#define MLIR_AFFINEOPS_AFFINEOPS_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +class AffineOpsDialect : public Dialect { +public: + AffineOpsDialect(MLIRContext *context); +}; + +/// The "if" operation represents an if–then–else construct for conditionally +/// executing two regions of code. 
The operands to an if operation are an +/// IntegerSet condition and a set of symbol/dimension operands to the +/// condition set. The operation produces no results. For example: +/// +/// if #set(%i) { +/// ... +/// } else { +/// ... +/// } +/// +/// The 'else' blocks to the if operation are optional, and may be omitted. For +/// example: +/// +/// if #set(%i) { +/// ... +/// } +/// +class AffineIfOp + : public Op { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + IntegerSet condition, ArrayRef conditionOperands); + + static StringRef getOperationName() { return "if"; } + static StringRef getConditionAttrName() { return "condition"; } + + IntegerSet getIntegerSet() const; + void setIntegerSet(IntegerSet newSet); + + /// Returns the list of 'then' blocks. + BlockList &getThenBlocks(); + const BlockList &getThenBlocks() const { + return const_cast(this)->getThenBlocks(); + } + + /// Returns the list of 'else' blocks. + BlockList &getElseBlocks(); + const BlockList &getElseBlocks() const { + return const_cast(this)->getElseBlocks(); + } + + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + +private: + friend class OperationInst; + explicit AffineIfOp(const OperationInst *state) : Op(state) {} +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index c205d55488e..161bb217a10 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -128,7 +128,6 @@ private: void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } - void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// POD paylod. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 1b14d925d32..bc9563f847a 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -26,7 +26,6 @@ #include "llvm/ADT/PointerUnion.h" namespace mlir { -class IfInst; class BlockList; class BlockAndValueMapping; @@ -62,7 +61,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an IfInst or ForInst. + /// nested under an OperationInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast(this)->getFunction(); @@ -325,7 +324,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or IfInst or ForInst. +/// is part of - a Function or OperationInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -365,15 +364,16 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is - /// part of an IfInst/ForInst, then return it, otherwise return null. + /// A BlockList is part of a function or an operation region. If it is + /// part of an operation region, then return the operation, otherwise return + /// null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is - /// part of a Function, then return it, otherwise return null. + /// A BlockList is part of a function or an operation region. 
If it is part + /// of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { return const_cast(this)->getContainingFunction(); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 156bd02bb52..3271c12afde 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -286,10 +286,6 @@ public: // Default step is 1. ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - /// Creates if instruction. - IfInst *createIf(Location location, ArrayRef operands, - IntegerSet set); - private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index b6a759e76f5..78810da909d 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, IfInst, and +// There are 'visit' methods for OperationInst, ForInst, and // Function, which recursively process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, @@ -85,8 +85,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->visitForInst(cast(s)); - case Instruction::Kind::If: - return static_cast(this)->visitIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -104,7 +102,6 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -166,23 +163,6 @@ public: static_cast(this)->visitForInst(forInst); } - void walkIfInst(IfInst *ifInst) { - static_cast(this)->visitIfInst(ifInst); - static_cast(this)->walk(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast(this)->walk(elseBlock->begin(), elseBlock->end()); - } - - void walkIfInstPostOrder(IfInst *ifInst) { - static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast(this)->walkPostOrder(elseBlock->begin(), - elseBlock->end()); - static_cast(this)->visitIfInst(ifInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -193,8 +173,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->walkForInst(cast(s)); - case Instruction::Kind::If: - return static_cast(this)->walkIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -210,9 +188,6 @@ public: case Instruction::Kind::For: return static_cast(this)->walkForInstPostOrder( cast(s)); - case Instruction::Kind::If: - return static_cast(this)->walkIfInstPostOrder( - cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -231,7 +206,6 @@ public: // processing their descendants in some way. When using RetTy, all of these // need to be overridden. 
void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 6a296b7348e..3dc1e76dd20 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -75,7 +75,6 @@ public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForInst, - If = (int)IROperandOwner::Kind::IfInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 71d832b8b90..fb6b1b97ca0 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -794,130 +794,6 @@ private: friend class ForInst; }; - -/// If instruction restricts execution to a subset of the loop iteration space. -class IfInst : public Instruction { -public: - static IfInst *create(Location location, ArrayRef operands, - IntegerSet set); - ~IfInst(); - - //===--------------------------------------------------------------------===// - // Then, else, condition. - //===--------------------------------------------------------------------===// - - Block *getThen() { return &thenClause.front(); } - const Block *getThen() const { return &thenClause.front(); } - Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const Block *getElse() const { - return elseClause ? &elseClause->front() : nullptr; - } - bool hasElse() const { return elseClause != nullptr; } - - Block *createElse() { - assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new BlockList(this); - elseClause->push_back(new Block()); - return &elseClause->front(); - } - - const AffineCondition getCondition() const; - - IntegerSet getIntegerSet() const { return set; } - void setIntegerSet(IntegerSet newSet) { - assert(newSet.getNumOperands() == operands.size()); - set = newSet; - } - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. 
- using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - unsigned getNumOperands() const { return operands.size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { return operands; } - MutableArrayRef getInstOperands() { return operands; } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - MLIRContext *getContext() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::IfInst; - } - -private: - // it is always present. - BlockList thenClause; - // 'else' clause of the if instruction. 'nullptr' if there is no else clause. - BlockList *elseClause; - - // The integer set capturing the conditional guard. - IntegerSet set; - - // Condition operands. - std::vector operands; - - explicit IfInst(Location location, unsigned numOperands, IntegerSet set); -}; - -/// AffineCondition represents a condition of the 'if' instruction. -/// Its life span should not exceed that of the objects it refers to. -/// AffineCondition does not provide its own methods for iterating over -/// the operands since the iterators of the if instruction accomplish -/// the same purpose. -/// -/// AffineCondition is trivially copyable, so it should be passed by value. -class AffineCondition { -public: - const IfInst *getIfInst() const { return &inst; } - IntegerSet getIntegerSet() const { return set; } - -private: - // 'if' instruction that contains this affine condition. - const IfInst &inst; - // Integer set for this affine condition. - IntegerSet set; - - AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} - - friend class IfInst; -}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 1e319db3571..d3a5d35427f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,6 +89,9 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const OperationInst *op) = 0; + /// Prints a block list. + virtual void printBlockList(const BlockList &blocks) = 0; + private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; @@ -195,7 +198,19 @@ public: virtual bool parseColonTypeList(SmallVectorImpl &result) = 0; /// Parse a keyword followed by a type. 
- virtual bool parseKeywordType(const char *keyword, Type &result) = 0; + bool parseKeywordType(const char *keyword, Type &result) { + return parseKeyword(keyword) || parseType(result); + } + + /// Parse a keyword. + bool parseKeyword(const char *keyword) { + if (parseOptionalKeyword(keyword)) + return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); + return false; + } + + /// If a keyword is present, then parse it. + virtual bool parseOptionalKeyword(const char *keyword) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and @@ -296,6 +311,10 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; + /// Parses a block list. Any parsed blocks are filled in to the + /// operation's block lists after the operation is created. + virtual bool parseBlockList() = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 053d3520103..80cd21362ce 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -81,10 +81,9 @@ public: enum class Kind { OperationInst, ForInst, - IfInst, /// These enums define ranges used for classof implementations. - INST_LAST = IfInst, + INST_LAST = ForInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 978fa45ab23..00c6577240c 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst and IfInst) are ignored. +/// Non-operation instructions (ForInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. Note that, as of the time of writing, worklist-based rewriter did not diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..5b29467fc44 --- /dev/null +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -0,0 +1,151 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(/*namePrefix=*/"", context) { + addOperations(); +} + +//===----------------------------------------------------------------------===// +// AffineIfOp +//===----------------------------------------------------------------------===// + +void AffineIfOp::build(Builder *builder, OperationState *result, + IntegerSet condition, + ArrayRef conditionOperands) { + result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); + result->addOperands(conditionOperands); + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); +} + +bool AffineIfOp::verify() const { + // Verify that we have a condition attribute. + auto conditionAttr = getAttrOfType(getConditionAttrName()); + if (!conditionAttr) + return emitOpError("requires an integer set attribute named 'condition'"); + + // Verify that the operands are valid dimension/symbols. + IntegerSet condition = conditionAttr.getValue(); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + const Value *operand = getOperand(i); + if (i < condition.getNumDims() && !operand->isValidDim()) + return emitOpError("operand cannot be used as a dimension id"); + if (i >= condition.getNumDims() && !operand->isValidSymbol()) + return emitOpError("operand cannot be used as a symbol"); + } + + // Verify that the entry of each child blocklist does not have arguments. + for (const auto &blockList : getInstruction()->getBlockLists()) { + if (blockList.empty()) + continue; + + // TODO(riverriddle) We currently do not allow multiple blocks in child + // block lists. + if (std::next(blockList.begin()) != blockList.end()) + return emitOpError( + "expects only one block per 'if' or 'else' block list"); + if (blockList.front().getTerminator()) + return emitOpError("expects region block to not have a terminator"); + + for (const auto &b : blockList) + if (b.getNumArguments() != 0) + return emitOpError( + "requires that child entry blocks have no arguments"); + } + return false; +} + +bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return true; + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Parse the 'then' block list. + if (parser->parseBlockList()) + return true; + + // If we find an 'else' keyword then parse the else block list. 
+ if (!parser->parseOptionalKeyword("else")) { + if (parser->parseBlockList()) + return true; + } + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); + return false; +} + +void AffineIfOp::print(OpAsmPrinter *p) const { + auto conditionAttr = getAttrOfType(getConditionAttrName()); + *p << "if " << conditionAttr; + printDimAndSymbolList(operand_begin(), operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printBlockList(getInstruction()->getBlockList(0)); + + // Print the 'else' block list if it has any blocks. + const auto &elseBlockList = getInstruction()->getBlockList(1); + if (!elseBlockList.empty()) { + *p << " else"; + p->printBlockList(elseBlockList); + } +} + +IntegerSet AffineIfOp::getIntegerSet() const { + return getAttrOfType(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr( + Identifier::get(getConditionAttrName(), getInstruction()->getContext()), + IntegerSetAttr::get(newSet)); +} + +/// Returns the list of 'then' blocks. +BlockList &AffineIfOp::getThenBlocks() { + return getInstruction()->getBlockList(0); +} + +/// Returns the list of 'else' blocks. +BlockList &AffineIfOp::getElseBlocks() { + return getInstruction()->getBlockList(1); +} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp new file mode 100644 index 00000000000..0afb32c1bd6 --- /dev/null +++ b/mlir/lib/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 219f356807a..07c903a6613 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } + // No vectorization across unknown regions. 
+ auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto &opInst = cast(inst); + return opInst.getNumBlockLists() != 0 && !opInst.isa(); + }); + auto regionsMatched = regions.match(forInst); + if (!regionsMatched.empty()) { + return false; + } + auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 4f32e9b22f4..491a9bef1b9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } +static bool isAffineIfOp(const Instruction &inst) { + return isa(inst) && + cast(inst).isa(); +} + namespace mlir { namespace matcher { @@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef nested) { - return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::If, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 939a2ede618..0e77d4d9084 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecting all 'for' instructions while skipping // over 'if' instructions.
while (currInst && ((currForInst = dyn_cast(currInst)) || - isa(currInst))) { + cast(currInst)->isa())) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef positions, if (auto *childForInst = dyn_cast(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - if (auto *ifInst = dyn_cast(&inst)) { - auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); - if (ret != nullptr) - return ret; - if (auto *elseClause = ifInst->getElse()) - return getInstAtPosition(positions, level + 1, elseClause); - } - if (auto *opInst = dyn_cast(&inst)) { - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; - } - return nullptr; + for (auto &blockList : cast(&inst)->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; } + return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 383a4878c35..474eeb2a28e 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,7 +73,6 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); - bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast(inst))) return true; break; - case Instruction::Kind::If: - if (verifyIfInst(cast(inst))) - return true; - break; } } @@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } -bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { - // TODO: check that if conditions are properly formed. - if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) - return true; - - if (auto *elseClause = ifInst.getElse()) - if (verifyBlock(*elseClause, /*isTopLevel*/ false)) - return true; - - return false; -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast(inst).getBody())) return true; break; - case Instruction::Kind::If: - auto &ifInst = cast(inst); - if (verifyDominance(*ifInst.getThen())) - return true; - if (auto *elseClause = ifInst.getElse()) - if (verifyDominance(*elseClause)) - return true; - break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 21bc3b824b1..cb4c1f0edce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,7 +145,6 @@ private: // Visit functions. 
void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); - void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitIfInst(const IfInst *ifInst) { - recordIntegerSetReference(ifInst->getIntegerSet()); -} - void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::If: - return visitIfInst(cast(inst)); case Instruction::Kind::For: return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: @@ -1077,7 +1070,6 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); - void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1125,6 +1117,9 @@ public: unsigned index) override; /// Print a block list. + void printBlockList(const BlockList &blocks) override { + printBlockList(blocks, /*printEntryBlockArgs=*/true); + } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast(&inst)->getBody()); break; - case Instruction::Kind::If: { - auto *ifInst = cast(&inst); - numberValuesInBlock(*ifInst->getThen()); - if (auto *elseBlock = ifInst->getElse()) - numberValuesInBlock(*elseBlock); - } } } } @@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list, unless this is the first block of - // the function, or the first block of an IfInst/ForInst with no arguments. + // Print the block label and argument list if requested. 
if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast(inst)); case Instruction::Kind::For: return print(cast(inst)); - case Instruction::Kind::If: - return print(cast(inst)); } } @@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfInst *inst) { - os.indent(currentIndent) << "if "; - IntegerSet set = inst->getIntegerSet(); - printIntegerSetReference(set); - printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); - printTrailingLocation(inst->getLoc()); - os << " {\n"; - print(inst->getThen(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - if (inst->hasElse()) { - os << " else {\n"; - print(inst->getElse(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - } -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4471ff25e94..e174fdc1d00 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } - -IfInst *FuncBuilder::createIf(Location location, ArrayRef operands, - IntegerSet set) { - auto *inst = IfInst::create(location, operands, set); - block->getInstructions().insert(insertPoint, inst); - return inst; -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6d74ed14257..0ccab2305ec 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,9 +73,6 @@ void Instruction::destroy() { case Kind::For: delete cast(this); break; - case Kind::If: - delete cast(this); - break; } } @@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const { return cast(this)->getNumOperands(); case Kind::For: return cast(this)->getNumOperands(); - case Kind::If: - return cast(this)->getNumOperands(); } } @@ -152,8 +147,6 @@ MutableArrayRef Instruction::getInstOperands() { return cast(this)->getInstOperands(); case Kind::For: return cast(this)->getInstOperands(); - case Kind::If: - return cast(this)->getInstOperands(); } } @@ -287,15 +280,6 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast(this)->getBody()->dropAllReferences(); break; - case Kind::If: { - // Make sure to drop references held by instructions within the 'then' and - // 'else' blocks. - auto *ifInst = cast(this); - ifInst->getThen()->dropAllReferences(); - if (auto *elseBlock = ifInst->getElse()) - elseBlock->dropAllReferences(); - break; - } case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -809,54 +793,6 @@ mlir::extractForInductionVars(ArrayRef forInsts) { results.push_back(forInst->getInductionVar()); return results; } -//===----------------------------------------------------------------------===// -// IfInst -//===----------------------------------------------------------------------===// - -IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) - : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), - set(set) { - operands.reserve(numOperands); - - // The then of an 'if' inst always has one block. 
- thenClause.push_back(new Block()); -} - -IfInst::~IfInst() { - if (elseClause) - delete elseClause; - - // An IfInst's IntegerSet 'set' should not be deleted since it is - // allocated through MLIRContext's bump pointer allocator. -} - -IfInst *IfInst::create(Location location, ArrayRef operands, - IntegerSet set) { - unsigned numOperands = operands.size(); - assert(numOperands == set.getNumOperands() && - "operand cound does not match the integer set operand count"); - - IfInst *inst = new IfInst(location, numOperands, set); - - for (auto *op : operands) - inst->operands.emplace_back(InstOperand(inst, op)); - - return inst; -} - -const AffineCondition IfInst::getCondition() const { - return AffineCondition(*this, set); -} - -MLIRContext *IfInst::getContext() const { - // Check for degenerate case of if instruction with no operands. - // This is unlikely, but legal. - if (operands.empty()) - return getFunction()->getContext(); - - return getOperand(0)->getType().getContext(); -} - //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); - if (auto *forInst = dyn_cast(this)) { - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + // Otherwise, this must be a ForInst. + auto *forInst = cast(this); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), - ubMap, forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, + forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - - return newFor; - } - - // Otherwise, we must have an If instruction. - auto *ifInst = cast(this); - auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); - - auto *resultThen = newIf->getThen(); - for (auto &childInst : *ifInst->getThen()) - resultThen->push_back(childInst.clone(mapper, context)); - - if (ifInst->hasElse()) { - auto *resultElse = newIf->createElse(); - for (auto &childInst : *ifInst->getElse()) - resultElse->push_back(childInst.clone(mapper, context)); - } + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - return newIf; + // Recursively clone the body of the for loop. 
+ for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + return newFor; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 099b218892f..2ab151f8913 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // Terminators may not exist in ForInst and IfInst. + // TODO(riverriddle) Terminators may not exist with an operation region. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 6418b062dc1..7103eeb7389 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const { return cast(this)->getContext(); case Kind::ForInst: return cast(this)->getContext(); - case Kind::IfInst: - return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c477ad1bbc5..e5d6aa46565 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return (emitError("expected affine map or integer set attribute value"), - nullptr); + return nullptr; if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2209,8 +2208,6 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); - ParseResult parseIfInst(); - ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; - case Token::kw_if: - if (parseIfInst()) - return ParseFailure; - break; } } @@ -2935,12 +2928,18 @@ public: return false; } - /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) override { - if (parser.getTokenSpelling() != keyword) - return parser.emitError("expected '" + Twine(keyword) + "'"); - parser.consumeToken(); - return !(result = parser.parseType()); + /// Parse an optional keyword. + bool parseOptionalKeyword(const char *keyword) override { + // Check that the current token is a bare identifier or keyword. + if (parser.getToken().isNot(Token::bare_identifier) && + !parser.getToken().isKeyword()) + return true; + + if (parser.getTokenSpelling() == keyword) { + parser.consumeToken(); + return false; + } + return true; } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3078,6 +3077,15 @@ public: return result == nullptr; } + /// Parses a list of blocks. + bool parseBlockList() override { + SmallVector results; + if (parser.parseOperationBlockList(results)) + return true; + parsedBlockLists.emplace_back(results); + return false; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3099,6 +3107,11 @@ public: /// Emit a diagnostic at the specified location and return true. 
bool emitError(llvm::SMLoc loc, const Twine &message) override { + // If we emit an error, then cleanup any parsed block lists. + for (auto &blockList : parsedBlockLists) + parser.cleanupInvalidBlocks(blockList); + parsedBlockLists.clear(); + parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3106,7 +3119,13 @@ public: bool didEmitError() const { return emittedError; } + /// Returns the block lists that were parsed. + MutableArrayRef> getParsedBlockLists() { + return parsedBlockLists; + } + private: + std::vector> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; + // Check that enough block lists were reserved for those that were parsed. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); + if (parsedBlockLists.size() > opState.numBlockLists) { + opAsmParser.emitError( + opLoc, + "parsed more block lists than those reserved in the operation state"); + return nullptr; + } + // Otherwise, we succeeded. Use the state it parsed as our op information. - return builder.createOperation(opState); + auto *opInst = builder.createOperation(opState); + + // Resolve any parsed block lists. + for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { + auto &opBlockList = opInst->getBlockList(i).getBlocks(); + opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), + parsedBlockLists[i].end()); + } + return opInst; } /// For instruction. @@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -/// If instruction. -/// -/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` -/// | ml-if-head `else` `if` ml-if-cond trailing-location? -/// `{` inst* `}` -/// ml-if-inst ::= ml-if-head -/// | ml-if-head `else` `{` inst* `}` -/// -ParseResult FunctionParser::parseIfInst() { - auto loc = getToken().getLoc(); - consumeToken(Token::kw_if); - - IntegerSet set = parseIntegerSetReference(); - if (!set) - return ParseFailure; - - SmallVector operands; - if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), - "integer set")) - return ParseFailure; - - IfInst *ifInst = - builder.createIf(getEncodedSourceLocation(loc), operands, set); - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(ifInst)) - return ParseFailure; - - Block *thenClause = ifInst->getThen(); - - // When parsing of an if instruction body fails, the IR contains - // the if instruction with the portion of the body that has been - // successfully parsed. - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(thenClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - if (consumeIf(Token::kw_else)) { - auto *elseClause = ifInst->createElse(); - if (parseElseClause(elseClause)) - return ParseFailure; - } - - // Reset insertion point to the current block. 
- builder.setInsertionPointToEnd(ifInst->getBlock()); - - return ParseSuccess; -} - -ParseResult FunctionParser::parseElseClause(Block *elseClause) { - if (getToken().is(Token::kw_if)) { - builder.setInsertionPointToEnd(elseClause); - return parseIfInst(); - } - - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(elseClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - return ParseSuccess; -} - //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 40e98b25cb3..ec00f98b3f5 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,7 +91,6 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) -TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -100,7 +99,6 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) -TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index c2e1636626d..afd18a49b79 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast(i).getBody()); break; } - case Instruction::Kind::If: { - auto &ifInst = cast(i); - if (auto *elseBlock = ifInst.getElse()) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(elseBlock); - } - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(ifInst.getThen()); - break; - } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cee0a08a63c..eebbbe9daa7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -99,16 +100,16 @@ public: SmallVector forInsts; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasIfInst = false; + bool hasNonForRegion = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opInst) { - if (opInst->isa()) + if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa()) loadOpInsts.push_back(opInst); - if (opInst->isa()) + else if (opInst->isa()) storeOpInsts.push_back(opInst); } }; @@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). - if (collector.hasIfInst) + // Return false if a non 'for' region was found (not currently supported). 
+ if (collector.hasNonForRegion) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } - if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); + } else if (opInst->getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; } } - // Return false if IfInsts are found (not currently supported). - if (isa(&inst)) - return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 39ef758833b..6d63e4afd2d 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } - bool walkIfInstPostOrder(IfInst *ifInst) { - bool hasInnerLoops = - walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - hasInnerLoops |= - walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); - return hasInnerLoops; - } - bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ab37ff63bad..f770684f519 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -246,7 +247,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerIfInst(IfInst *ifInst); + bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | | +// | | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { + auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the IfInst and add a + // continue blocks. Move the instructions over from the AffineIfOp and add a // branch to the continuation point. 
Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - auto *oldThen = ifInst->getThen(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + // If the 'then' block is not empty, then splice the instructions. + auto &oldThenBlocks = ifOp->getThenBlocks(); + if (!oldThenBlocks.empty()) { + // We currently only handle one 'then' block. + if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) + return true; + + Block *oldThen = &oldThenBlocks.front(); + + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); + } + FuncBuilder builder(thenBlock); builder.create(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - if (auto *oldElse = ifInst->getElse()) { + auto &oldElseBlocks = ifOp->getElseBlocks(); + if (!oldElseBlocks.empty()) { + // We currently only handle one 'else' block. + if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) + return true; + + auto *oldElse = &oldElseBlocks.front(); elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifInst->getCondition().getIntegerSet(); + auto integerSet = ifOp->getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector instsToRewrite; - // Collect all the If and For instructions as well as AffineApplyOps. We do - // this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa(inst) || isa(inst)) + if (isa(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast(inst); - if (op && op->isa()) + if (op && (op->isa() || op->isa())) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. 
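  // (ForInsts are lowered directly; affine 'if' and affine_apply are now
  // operations and are dispatched through their OperationInst below.)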
for (auto *inst : instsToRewrite) - if (auto *ifInst = dyn_cast(inst)) { - if (lowerIfInst(ifInst)) - return failure(); - } else if (auto *forInst = dyn_cast(inst)) { + if (auto *forInst = dyn_cast(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast(inst); - if (lowerAffineApply(op->cast())) + if (auto ifOp = op->dyn_cast()) { + if (lowerAffineIf(ifOp)) + return failure(); + } else if (lowerAffineApply(op->cast())) { return failure(); + } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 09d961f85cd..2744b1d624c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst, if (isa(inst)) return inst->emitError("NYI path ForInst"); - if (isa(inst)) - return inst->emitError("NYI path IfInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); @@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa()) { return false; } + if (opInst->getNumBlockLists() != 0) + return inst->emitError("NYI path Op with region"); + if (auto write = opInst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index bd39e47786a..ba59123c700 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,7 +28,6 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; -using llvm::report_fatal_error; namespace { @@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; - void visitIfInst(IfInst *ifInst); - void visitOperationInst(OperationInst *opInst); - static char passID; }; @@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { - auto set = ifInst->getCondition().getIntegerSet(); - ifInst->setIntegerSet(simplifyIntegerSet(set)); -} - -void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } - } -} - PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkInsts([&](Instruction *inst) { - if (auto *opInst = dyn_cast(inst)) - visitOperationInst(opInst); - if (auto *ifInst = dyn_cast(inst)) - visitIfInst(ifInst); + f->walkOps([&](OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); + } else if (auto setAttr = attr.second.dyn_cast()) { + auto simplified = simplifyIntegerSet(setAttr.getValue()); + opInst->setAttr(attr.first, 
IntegerSetAttr::get(simplified)); + } + } }); return success(); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index bae112dd3b9..595991c0109 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -243,14 +243,6 @@ func @non_instruction() { // ----- -func @invalid_if_conditional1() { - for %i = 1 to 10 { - if () { // expected-error {{expected ':' or '['}} - } -} - -// ----- - func @invalid_if_conditional2() { for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} @@ -664,7 +656,11 @@ func @invalid_if_operands2(%N : index) { func @invalid_if_operands3(%N : index) { for %i = 1 to 10 { if #set0(%i)[%i] { - // expected-error@-1 {{value '%i' cannot be used as a symbol}} + // expected-error@-1 {{operand cannot be used as a symbol}} + } + } + return +} // ----- // expected-error@+1 {{expected '"' in string literal}} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index e3e1bbbbfad..8a90d12bd03 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) loc(fused<"myPass">["foo", "foo2"]) - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 33109606538..626f24569c6 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -287,13 +287,15 @@ func @ifinst(%N: index) { // CHECK: %c1_i32 = constant 1 : i32 %y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32 %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32 - } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) { - // CHECK: %c1 = constant 1 : index - %u = constant 1 : index - // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] - %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] - } else { // CHECK } else { - %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } else { // CHECK } else { + if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) { + // CHECK: %c1 = constant 1 : index + %u = constant 1 : index + // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] + %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] + } else { // CHECK } else { + %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } } // CHECK } } // CHECK } return // CHECK return @@ -751,11 +753,11 @@ func @type_alias() -> !i32_type_alias { func @verbose_if(%N: index) { %c = constant 200 : index - // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () { - "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () { + // CHECK: if #set0(%c200)[%arg0, %c200] { + "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - // CHECK-NEXT: } { + // CHECK-NEXT: } else { } { // The else block list. 
// CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index cb2e14a56d5..69dace45165 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) <"myPass">["foo", "foo2"] - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } <"myPass">["foo", "foo2"] + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 [unknown] return %1 : i32 loc(unknown) -} \ No newline at end of file +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index d170ce590f7..162f193f662 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfInst should prevent fusion. + // Top-level IfOp should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfInst in ForInst should prevent fusion. + // IfOp in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 6f6ad3fafc7..628044ed77a 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ func @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, + // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir index 5509c7aba55..13f009deb70 100644 --- a/mlir/test/Transforms/strip-debug-info.mlir +++ b/mlir/test/Transforms/strip-debug-info.mlir @@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: if #set0(%c4) loc(unknown) + // CHECK: } loc(unknown) %2 = constant 4 : index - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc("bar") -- cgit v1.2.3 From a0f3db4024fbe8a07cd84617dc337226ddde8ac1 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Tue, 29 Jan 2019 09:36:41 -0800 Subject: Support fusing loop nests which require insertion into a new instruction Block position while preserving dependences, opening up additional fusion opportunities. - Adds SSA Value edges to the data dependence graph used in the loop fusion pass. 
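- A minimal sketch of such an edge (the MLIR below is illustrative only, not
  part of this change):

    %m = alloc() : memref<10xf32>
    %cf7 = constant 7.0 : f32
    for %i0 = 0 to 10 {
      store %cf7, %m[%i0] : memref<10xf32>
    }

  Here the graph nodes created for '%m' and '%cf7' each get an SSA value edge
  to the node for the loop nest, since their results are used inside it.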
PiperOrigin-RevId: 231417649 --- mlir/lib/Transforms/LoopFusion.cpp | 292 +++++++++++++++++-------- mlir/test/Transforms/loop-fusion.mlir | 386 +++++++++++++++++++++++++--------- 2 files changed, 485 insertions(+), 193 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index eebbbe9daa7..f33afba3806 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -163,12 +163,18 @@ public: } }; - // Edge represents a memref data dependece between nodes in the graph. + // Edge represents a data dependece between nodes in the graph. struct Edge { // The id of the node at the other end of the edge. unsigned id; - // The memref on which this edge represents a dependence. - Value *memref; + // The SSA value on which this edge represents a dependence. + // If the value is a memref, then the dependence is between graph nodes + // which contain accesses to the same memref 'value'. If the value is a + // non-memref value, then the dependence is between a graph node which + // defines an SSA value and another graph node which uses the SSA value + // (e.g. a constant instruction defining a value which is used inside a loop + // nest). + Value *value; }; // Map from node id to Node. @@ -180,6 +186,8 @@ public: // Map from memref to a count on the dependence edges associated with that // memref. DenseMap memrefEdgeCount; + // The next unique identifier to use for newly created graph nodes. + unsigned nextNodeId = 0; MemRefDependenceGraph() {} @@ -194,20 +202,27 @@ public: return &it->second; } + // Adds a node with 'inst' to the graph and returns its unique identifier. + unsigned addNode(Instruction *inst) { + Node node(nextNodeId++, inst); + nodes.insert({node.id, node}); + return node.id; + } + // Remove node 'id' (and its associated edges) from graph. void removeNode(unsigned id) { // Remove each edge in 'inEdges[id]'. if (inEdges.count(id) > 0) { SmallVector oldInEdges = inEdges[id]; for (auto &inEdge : oldInEdges) { - removeEdge(inEdge.id, id, inEdge.memref); + removeEdge(inEdge.id, id, inEdge.value); } } // Remove each edge in 'outEdges[id]'. if (outEdges.count(id) > 0) { SmallVector oldOutEdges = outEdges[id]; for (auto &outEdge : oldOutEdges) { - removeEdge(id, outEdge.id, outEdge.memref); + removeEdge(id, outEdge.id, outEdge.value); } } // Erase remaining node state. @@ -216,13 +231,13 @@ public: nodes.erase(id); } - bool hasOutEdges(unsigned id) { - return outEdges.count(id) > 0 && !outEdges[id].empty(); - } - - // Returns true if node 'id' writes to any memref which escapes (or is an - // argument to) the function/block. Returns false otherwise. - bool writesToLiveInOrEscapingMemrefs(unsigned id) { + // Returns true if node 'id' can be removed from the graph. Returns false + // otherwise. A node can be removed from the graph iff the following + // conditions are met: + // *) The node does not write to any memref which escapes (or is an argument + // to) the function/block. + // *) The node has no successors in the dependence graph. + bool canRemoveNode(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast()->getMemRef(); @@ -230,70 +245,82 @@ public: auto *opInst = dyn_cast_or_null(inst); // Return false if 'memref' is a function argument. if (opInst == nullptr) - return true; + return false; // Return false if any use of 'memref' escapes the function. 
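        // ('Escapes' here means the memref is used by some operation that is
        // not a simple dereferencing load or store.)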
for (auto &use : memref->getUses()) { auto *user = dyn_cast(use.getOwner()); if (!user || !isMemRefDereferencingOp(*user)) - return true; + return false; } + // Return false if there exist out edges from 'id' on 'memref'. + if (getOutEdgeCount(id, memref) > 0) + return false; } - return false; + return true; } // Returns true iff there is an edge from node 'srcId' to node 'dstId' for - // 'memref'. Returns false otherwise. - bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) { + // 'value'. Returns false otherwise. + bool hasEdge(unsigned srcId, unsigned dstId, Value *value) { if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { return false; } bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { - return edge.id == dstId && edge.memref == memref; + return edge.id == dstId && edge.value == value; }); bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { - return edge.id == srcId && edge.memref == memref; + return edge.id == srcId && edge.value == value; }); return hasOutEdge && hasInEdge; } - // Adds an edge from node 'srcId' to node 'dstId' for 'memref'. - void addEdge(unsigned srcId, unsigned dstId, Value *memref) { - if (!hasEdge(srcId, dstId, memref)) { - outEdges[srcId].push_back({dstId, memref}); - inEdges[dstId].push_back({srcId, memref}); - memrefEdgeCount[memref]++; + // Adds an edge from node 'srcId' to node 'dstId' for 'value'. + void addEdge(unsigned srcId, unsigned dstId, Value *value) { + if (!hasEdge(srcId, dstId, value)) { + outEdges[srcId].push_back({dstId, value}); + inEdges[dstId].push_back({srcId, value}); + if (value->getType().isa()) + memrefEdgeCount[value]++; } } - // Removes an edge from node 'srcId' to node 'dstId' for 'memref'. - void removeEdge(unsigned srcId, unsigned dstId, Value *memref) { + // Removes an edge from node 'srcId' to node 'dstId' for 'value'. + void removeEdge(unsigned srcId, unsigned dstId, Value *value) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); - assert(memrefEdgeCount.count(memref) > 0); - memrefEdgeCount[memref]--; + if (value->getType().isa()) { + assert(memrefEdgeCount.count(value) > 0); + memrefEdgeCount[value]--; + } // Remove 'srcId' from 'inEdges[dstId]'. for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { - if ((*it).id == srcId && (*it).memref == memref) { + if ((*it).id == srcId && (*it).value == value) { inEdges[dstId].erase(it); break; } } // Remove 'dstId' from 'outEdges[srcId]'. for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { - if ((*it).id == dstId && (*it).memref == memref) { + if ((*it).id == dstId && (*it).value == value) { outEdges[srcId].erase(it); break; } } } - // Returns the input edge count for node 'id' and 'memref'. - unsigned getInEdgeCount(unsigned id, Value *memref) { + // Returns the input edge count for node 'id' and 'memref' from src nodes + // which access 'memref'. 
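+  // Edges from source nodes which only define SSA values, and neither load
+  // from nor store to 'memref', are not counted.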
+ unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) for (auto &inEdge : inEdges[id]) - if (inEdge.memref == memref) - ++inEdgeCount; + if (inEdge.value == memref) { + Node *srcNode = getNode(inEdge.id); + // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' + if (srcNode->getLoadOpCount(memref) > 0 || + srcNode->getStoreOpCount(memref) > 0) + ++inEdgeCount; + } return inEdgeCount; } @@ -302,48 +329,84 @@ public: unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) - if (outEdge.memref == memref) + if (outEdge.value == memref) ++outEdgeCount; return outEdgeCount; } - // Check for a dependence in Block instruction list range (srcId, dstId) on - // memrefs other than 'memrefToSkip' (which will be privatized for the fused - // loop). - bool hasDependenceTargetInRange(unsigned srcId, unsigned dstId, - Value *memrefToSkip) { + // Computes and returns an insertion point instruction, before which the + // the fused loop nest can be inserted while preserving + // dependences. Returns nullptr if no such insertion point is found. + Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId, + Value *memrefToSkip) { if (outEdges.count(srcId) == 0) - return false; - // Check if any of the outgoing edge targets from srcId lie in - // (srcId, dstId). - SmallPtrSet depInsts; - for (auto &outEdge : outEdges[srcId]) { - if (outEdge.id != dstId && outEdge.memref != memrefToSkip) { - Node *node = getNode(outEdge.id); - depInsts.insert(node->inst); + return getNode(dstId)->inst; + + // Build set of insts in range (srcId, dstId) which depend on 'srcId'. + SmallPtrSet srcDepInsts; + for (auto &outEdge : outEdges[srcId]) + if (outEdge.id != dstId && outEdge.value != memrefToSkip) + srcDepInsts.insert(getNode(outEdge.id)->inst); + + // Build set of insts in range (srcId, dstId) on which 'dstId' depends. + SmallPtrSet dstDepInsts; + for (auto &inEdge : inEdges[dstId]) + if (inEdge.id != srcId && inEdge.value != memrefToSkip) + dstDepInsts.insert(getNode(inEdge.id)->inst); + + Instruction *srcNodeInst = getNode(srcId)->inst; + Instruction *dstNodeInst = getNode(dstId)->inst; + + // Computing insertion point: + // *) Walk all instruction positions in Block instruction list in the + // range (src, dst). For each instruction 'inst' visited in this search: + // *) Store in 'firstSrcDepPos' the first position where 'inst' has a + // dependence edge from 'srcNode'. + // *) Store in 'lastDstDepPost' the last position where 'inst' has a + // dependence edge to 'dstNode'. + // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the + // instruction insertion point (or return null pointer if no such + // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos'). + SmallVector depInsts; + Optional firstSrcDepPos; + Optional lastDstDepPos; + unsigned pos = 0; + for (Block::iterator it = std::next(Block::iterator(srcNodeInst)); + it != Block::iterator(dstNodeInst); ++it) { + Instruction *inst = &(*it); + if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None) + firstSrcDepPos = pos; + if (dstDepInsts.count(inst) > 0) + lastDstDepPos = pos; + depInsts.push_back(inst); + ++pos; + } + + if (firstSrcDepPos.hasValue()) { + if (lastDstDepPos.hasValue()) { + if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) { + // No valid insertion point exists which preserves dependences. + return nullptr; + } } + // Return the insertion point at 'firstSrcDepPos'. 
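+      // Placing the fused loop nest before this instruction keeps it ahead of
+      // everything that depends on 'srcId', while every instruction that
+      // 'dstId' depends on remains before it (checked above).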
+ return depInsts[firstSrcDepPos.getValue()]; } - // Do a linear walk from 'srcNode.inst' to 'dstNode.inst' and for each - // instruction 'inst' in range ('srcNode.inst', 'dstNode.inst') test - // if 'depInsts' contains 'inst', and return true if it does. - // TODO(andydavis) If this linear search becomes a compile time issue, - // create a data structure which allows a faster search through ForInsts - // in a Block. - Block::iterator it = std::next(Block::iterator(getNode(srcId)->inst)); - Block::iterator itEnd = Block::iterator(getNode(dstId)->inst); - return std::any_of(it, itEnd, [&](Instruction &inst) { - return depInsts.count(&inst) > 0; - }); + // No dependence targets in range (or only dst deps in range), return + // 'dstNodInst' insertion point. + return dstNodeInst; } - // Updates edge mappings from node 'srcId' to node 'dstId'. - void updateEdges(unsigned srcId, unsigned dstId) { + // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' + // has been replaced in node at 'dstId' by a private memref. + void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) { // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { - // Add edge from 'inEdge.id' to 'dstId'. - addEdge(inEdge.id, dstId, inEdge.memref); + // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. + if (inEdge.value != oldMemRef) + addEdge(inEdge.id, dstId, inEdge.value); } } // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. @@ -352,9 +415,18 @@ public: for (auto &outEdge : oldOutEdges) { // Remove any out edges from 'srcId' to 'dstId' across memrefs. if (outEdge.id == dstId) - removeEdge(srcId, outEdge.id, outEdge.memref); + removeEdge(srcId, outEdge.id, outEdge.value); } } + // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being + // replaced by a private memref). These edges could come from nodes + // other than 'srcId' which were removed in the previous step. + if (inEdges.count(dstId) > 0) { + SmallVector oldInEdges = inEdges[dstId]; + for (auto &inEdge : oldInEdges) + if (inEdge.value == oldMemRef) + removeEdge(inEdge.id, dstId, inEdge.value); + } } // Adds ops in 'loads' and 'stores' to node at 'id'. @@ -381,12 +453,12 @@ public: auto it = inEdges.find(idAndNode.first); if (it != inEdges.end()) { for (const auto &e : it->second) - os << " InEdge: " << e.id << " " << e.memref << "\n"; + os << " InEdge: " << e.id << " " << e.value << "\n"; } it = outEdges.find(idAndNode.first); if (it != outEdges.end()) { for (const auto &e : it->second) - os << " OutEdge: " << e.id << " " << e.memref << "\n"; + os << " OutEdge: " << e.id << " " << e.value << "\n"; } } } @@ -398,23 +470,23 @@ public: // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(Function *f) { - unsigned id = 0; DenseMap> memrefAccesses; // TODO: support multi-block functions. if (f->getBlocks().size() != 1) return false; + DenseMap forToNodeMap; for (auto &inst : f->front()) { if (auto *forInst = dyn_cast(&inst)) { // Create graph node 'id' to represent top-level 'forInst' and record // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if a non 'for' region was found (not currently supported). + // Return false if IfInsts are found (not currently supported). 
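+      // ('hasNonForRegion' is set for any operation holding its own region,
+      // which includes the affine 'if' operation.)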
if (collector.hasNonForRegion) return false; - Node node(id++, &inst); + Node node(nextNodeId++, &inst); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); @@ -425,19 +497,20 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } + forToNodeMap[forInst] = node.id; nodes.insert({node.id, node}); } if (auto *opInst = dyn_cast(&inst)) { if (auto loadOp = opInst->dyn_cast()) { // Create graph node for top-level load op. - Node node(id++, &inst); + Node node(nextNodeId++, &inst); node.loads.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. - Node node(id++, &inst); + Node node(nextNodeId++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); @@ -445,6 +518,32 @@ bool MemRefDependenceGraph::init(Function *f) { } else if (opInst->getNumBlockLists() != 0) { // Return false if another region is found (not currently supported). return false; + } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) { + // Create graph node for top-level producer of SSA values, which + // could be used by loop nest nodes. + Node node(nextNodeId++, &inst); + nodes.insert({node.id, node}); + } + } + } + + // Add dependence edges between nodes which produce SSA values and their + // users. + for (auto &idAndNode : nodes) { + const Node &node = idAndNode.second; + if (!node.loads.empty() || !node.stores.empty()) + continue; + auto *opInst = cast(node.inst); + for (auto *value : opInst->getResults()) { + for (auto &use : value->getUses()) { + auto *userOpInst = cast(use.getOwner()); + SmallVector loops; + getLoopIVs(*userOpInst, &loops); + if (loops.empty()) + continue; + assert(forToNodeMap.count(loops[0]) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]]; + addEdge(node.id, userLoopNestId, value); } } } @@ -768,20 +867,25 @@ static Value *createPrivateMemRef(ForInst *forInst, // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed // by 'srcStoreOpInst'. - auto newMemRefType = b.getMemRefType(newShape, oldMemRefType.getElementType(), - {}, oldMemRefType.getMemorySpace()); + auto newMemRefType = + top.getMemRefType(newShape, oldMemRefType.getElementType(), {}, + oldMemRefType.getMemorySpace()); // Gather alloc operands for the dynamic dimensions of the memref. SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) allocOperands.push_back( - b.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); + top.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); } // Create new private memref for fused loop 'forInst'. + // TODO(andydavis) Create/move alloc ops for private memrefs closer to their + // consumer loop nests to reduce their live range. Currently they are added + // at the beginning of the function, because loop nests can be reordered + // during the fusion pass. Value *newMemRef = - b.create(forInst->getLoc(), newMemRefType, allocOperands); + top.create(forInst->getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector remapExprs; @@ -1198,7 +1302,7 @@ public: // Iterate through in edges for 'dstId'. 
for (auto &srcEdge : mdg->inEdges[dstId]) { // Skip 'srcEdge' if not for 'memref'. - if (srcEdge.memref != memref) + if (srcEdge.value != memref) continue; auto *srcNode = mdg->getNode(srcEdge.id); @@ -1210,22 +1314,25 @@ public: if (srcNode->stores.size() != 1) continue; - // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly + // Skip 'srcNode' if it has in edges on 'memref'. // TODO(andydavis) Track dependence type with edges, and just check - // for WAW dependence edge here. - if (mdg->getInEdgeCount(srcNode->id, memref) != 0) + // for WAW dependence edge here. Note that this check is overly + // conservative and will be removed in the future. + if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) continue; - // Skip if 'srcNode' has out edges on memrefs other than 'memref' - // for nodes in instruction list range (srcNode.inst, dstNode.inst). - if (mdg->hasDependenceTargetInRange(srcNode->id, dstNode->id, memref)) + // Compute an instruction list insertion point for the fused loop + // nest which preserves dependences. + Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint( + srcNode->id, dstNode->id, memref); + if (insertPointInst == nullptr) continue; - // Check if fusion would be profitable and at what depth. // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; + // Check if fusion would be profitable. if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, &bestDstLoopDepth)) continue; @@ -1234,8 +1341,13 @@ public: auto *sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { + // Move 'dstForInst' before 'insertPointInst' if needed. + auto *dstForInst = cast(dstNode->inst); + if (insertPointInst != dstForInst) { + dstForInst->moveBefore(insertPointInst); + } // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id); + mdg->updateEdges(srcNode->id, dstNode->id, memref); // Collect slice loop stats. LoopNestStateCollector sliceCollector; @@ -1244,9 +1356,7 @@ public: for (auto *forInst : sliceCollector.forInsts) { promoteIfSingleIteration(forInst); } - // Create private memref for 'memref' in 'dstForInst'. - auto *dstForInst = cast(dstNode->inst); SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast()->getMemRef() == memref) @@ -1256,6 +1366,11 @@ public: auto *newMemRef = createPrivateMemRef( dstForInst, storesForMemref[0], bestDstLoopDepth); visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = + mdg->addNode(newMemRef->getDefiningInst()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; @@ -1276,8 +1391,7 @@ public: // Remove old src loop nest if it no longer has outgoing dependence // edges, and it does not write to a memref which escapes the // function. 
- if (!mdg->hasOutEdges(srcNode->id) && - !mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) { + if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); cast(srcNode->inst)->erase(); } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 162f193f662..189968b3ffa 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -62,15 +62,15 @@ func @should_fuse_reduction_to_pointwise() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %4 = load %2[%3] : memref<1xf32> - // CHECK-NEXT: %5 = load %0[%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %5 = load %1[%i0, %i1] : memref<10x10xf32> // CHECK-NEXT: %6 = addf %4, %5 : f32 // CHECK-NEXT: %7 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %6, %2[%7] : memref<1xf32> + // CHECK-NEXT: store %6, %0[%7] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: %8 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %9 = load %2[%8] : memref<1xf32> - // CHECK-NEXT: store %9, %1[%i0] : memref<10xf32> + // CHECK-NEXT: %9 = load %0[%8] : memref<1xf32> + // CHECK-NEXT: store %9, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -109,8 +109,7 @@ func @should_fuse_loop_nests_with_shifts() { // *) Fifth affine apply shifts the loads access function by '-1', because // of the offset induced by reducing the memref shape from 10x10 to 9x9. // NOTE: Should create a private memref with reduced shape 9x9xf32. - // CHECK: %0 = alloc() : memref<1x1xf32> - // CHECK-NEXT: for %i0 = 1 to 10 { + // CHECK: for %i0 = 1 to 10 { // CHECK-NEXT: for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) @@ -155,10 +154,10 @@ func @should_fuse_loop_nest() { %v1 = load %b[%i4, %i5] : memref<10x10xf32> } } - // Expecting private memref for '%b' first, then private memref for '%a'. - // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> - // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> - // CHECK-NEXT: for %i0 = 0 to 10 { + // Expecting private memref for '%a' first, then private memref for '%b'. + // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> + // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: %3 = affine_apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) @@ -204,15 +203,14 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { // Should fuse first loop (past second loop with no dependences) into third. // Note that fusion creates a private memref '%2' for the fused loop nest. 
// CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %2 = alloc() : memref<1xf32> // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %3 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %3, %2[%4] : memref<1xf32> + // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %6 = load %2[%5] : memref<1xf32> + // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -241,10 +239,10 @@ func @should_fuse_all_loops() { } // Should fuse first and second loops into third. - // Expecting private memref for '%b' first, then private memref for '%a'. - // CHECK: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> - // CHECK-NEXT: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> - // CHECK-NEXT: for %i0 = 0 to 10 { + // Expecting private memref for '%a' first, then private memref for '%b'. + // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> + // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) @@ -283,16 +281,15 @@ func @should_fuse_first_and_second_loops() { // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. - // CHECK: %2 = alloc() : memref<1xf32> - // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %2[%3] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %5 = load %2[%4] : memref<1xf32> - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: %6 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: %6 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -361,17 +358,16 @@ func @should_fuse_across_waw_dep_with_private_memref() { } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %1 = alloc() : memref<1xf32> - // CHECK-NEXT: for %i2 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply #map0(%i2, %i2) - // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine_apply #map0(%i2, %i2) - // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> + // CHECK: for %i2 = 0 to 10 { + // CHECK-NEXT: %2 = affine_apply #map0(%i2, %i2) + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i2, %i2) + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -379,8 +375,8 @@ func @should_fuse_across_waw_dep_with_private_memref() { // ----- -// CHECK-LABEL: func @should_not_fuse_war_dep_would_be_violated() { -func @should_not_fuse_war_dep_would_be_violated() { 
+// CHECK-LABEL: func @should_fuse_and_move_to_preserve_war_dep() { +func @should_fuse_and_move_to_preserve_war_dep() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -395,18 +391,20 @@ func @should_not_fuse_war_dep_would_be_violated() { for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } - // Fusing loop %i0 to %i2 would violate the WAR dependence between %i0 and %i1 - // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %2, %1[%i0] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { - // CHECK-NEXT: %3 = load %1[%i2] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: return + // Loops '%i1' and '%i2' have no dependences. We can fuse a slice of '%i0' + // into '%i2' if we move the fused loop nest before '%i1', which preserves + // the WAR dependence from load '%a' in '%i0' to the store '%a' in loop '%i1'. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %2 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %2, %0[%3] : memref<1xf32> + // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return return } @@ -428,14 +426,13 @@ func @should_fuse_with_private_memref_if_top_level_access() { %v1 = load %m[%c0] : memref<10xf32> // Top-level load to '%m' should prevent fusion. // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: %1 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply #map0(%i1, %i1) - // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine_apply #map0(%i1, %i1) - // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } return } @@ -630,12 +627,12 @@ func @fuse_reshape_16_4_64() { // CHECK: for %i0 = 0 to 64 { // CHECK-NEXT: %2 = affine_apply #map0(%i0) // CHECK-NEXT: %3 = affine_apply #map1(%i0) -// CHECK-NEXT: %4 = load %0[%2, %3] : memref<16x4xf32> +// CHECK-NEXT: %4 = load %1[%2, %3] : memref<16x4xf32> // CHECK-NEXT: %5 = affine_apply [[MAP2]](%2, %3) // CHECK-NEXT: %6 = affine_apply [[MAP3]](%i0, %5) -// CHECK-NEXT: store %4, %1[%6] : memref<1xf32> +// CHECK-NEXT: store %4, %0[%6] : memref<1xf32> // CHECK-NEXT: %7 = affine_apply [[MAP3]](%i0, %i0) -// CHECK-NEXT: %8 = load %1[%7] : memref<1xf32> +// CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> // CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: return @@ -696,9 +693,9 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // Everything above is fused to a single 2-d loop nest, and the 6-d tensor %in // is eliminated if -memref-dataflow-opt is also supplied. 
// -// CHECK: %0 = alloc() : memref<64x9xi32> -// CHECK-NEXT: %1 = alloc() : memref<1x1xi32> -// CHECK-NEXT: %2 = alloc() : memref<1x2x3x3x16x1xi32> +// CHECK: %0 = alloc() : memref<1x2x3x3x16x1xi32> +// CHECK: %1 = alloc() : memref<1x1xi32> +// CHECK: %2 = alloc() : memref<64x9xi32> // CHECK-NEXT: for %i0 = 0 to 64 { // CHECK-NEXT: for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i1) @@ -713,7 +710,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %12 = affine_apply #map8(%i0, %i1, %3, %4, %5, %6, %7, %c0) // CHECK-NEXT: %13 = affine_apply #map9(%i0, %i1, %3, %4, %5, %6, %7, %c0) // CHECK-NEXT: %14 = affine_apply #map10(%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: store %8, %2[%9, %10, %11, %12, %13, %14] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: store %8, %0[%9, %10, %11, %12, %13, %14] : memref<1x2x3x3x16x1xi32> // CHECK-NEXT: %15 = affine_apply #map11(%i0, %i1) // CHECK-NEXT: %16 = affine_apply #map12(%15) // CHECK-NEXT: %17 = affine_apply #map13(%15) @@ -727,7 +724,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %25 = affine_apply #map8(%i0, %i1, %16, %17, %18, %19, %20, %21) // CHECK-NEXT: %26 = affine_apply #map9(%i0, %i1, %16, %17, %18, %19, %20, %21) // CHECK-NEXT: %27 = affine_apply #map10(%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %28 = load %2[%22, %23, %24, %25, %26, %27] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %28 = load %0[%22, %23, %24, %25, %26, %27] : memref<1x2x3x3x16x1xi32> // CHECK-NEXT: %29 = affine_apply #map18(%i0, %i1, %i0, %i1) // CHECK-NEXT: %30 = affine_apply #map19(%i0, %i1, %i0, %i1) // CHECK-NEXT: store %28, %1[%29, %30] : memref<1x1xi32> @@ -735,10 +732,10 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %32 = affine_apply #map19(%i0, %i1, %i0, %i1) // CHECK-NEXT: %33 = load %1[%31, %32] : memref<1x1xi32> // CHECK-NEXT: %34 = muli %33, %33 : i32 -// CHECK-NEXT: store %34, %0[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: store %34, %2[%i0, %i1] : memref<64x9xi32> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: return %0 : memref<64x9xi32> +// CHECK-NEXT: return %2 : memref<64x9xi32> // ----- @@ -797,19 +794,19 @@ func @should_fuse_reduction_at_depth1() { // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: for %i1 = 0 to 100 { // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: %3 = load %1[%2] : memref<1xf32> - // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x100xf32> + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: %4 = load %1[%i0, %i1] : memref<10x100xf32> // CHECK-NEXT: %5 = "maxf"(%3, %4) : (f32, f32) -> f32 // CHECK-NEXT: %6 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: store %5, %1[%6] : memref<1xf32> + // CHECK-NEXT: store %5, %0[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 100 { // CHECK-NEXT: %7 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: %8 = load %1[%7] : memref<1xf32> - // CHECK-NEXT: %9 = load %0[%i0, %i2] : memref<10x100xf32> + // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> + // CHECK-NEXT: %9 = load %1[%i0, %i2] : memref<10x100xf32> // CHECK-NEXT: %10 = subf %9, %8 : f32 // CHECK-NEXT: %11 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: store %10, %1[%11] : memref<1xf32> + // CHECK-NEXT: store %10, %0[%11] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -849,19 +846,19 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // at depth 1 and the slice should be inserted at depth 1. 
// CHECK: for %i0 = 0 to 100 { // CHECK-NEXT: for %i1 = 0 to 16 { - // CHECK-NEXT: %2 = load %0[%i0, %i1] : memref<100x16xf32> + // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<100x16xf32> // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0, %i2) // CHECK-NEXT: %5 = affine_apply #map1(%i0, %i0, %i2) - // CHECK-NEXT: store %3, %1[%4, %5] : memref<1x16xf32> + // CHECK-NEXT: store %3, %0[%4, %5] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %6 = affine_apply #map0(%i0, %i0, %i3) // CHECK-NEXT: %7 = affine_apply #map1(%i0, %i0, %i3) - // CHECK-NEXT: %8 = load %1[%6, %7] : memref<1x16xf32> + // CHECK-NEXT: %8 = load %0[%6, %7] : memref<1x16xf32> // CHECK-NEXT: "op2"(%8) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -920,8 +917,8 @@ func @fusion_at_depth0_not_currently_supported() { } // NOTE: Should shrink memref size to 1 element access by load in dst loop // nest, and make the store in the slice store to the same element. - // CHECK: %0 = alloc() : memref<1xf32> - // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-DAG: %0 = alloc() : memref<1xf32> + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine_apply #map0()[%c0] // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = load %0[%c0] : memref<1xf32> @@ -1000,8 +997,8 @@ func @should_fuse_deep_loop_nests() { // bounds which are a function of the first four loops of destination loop nest, // where the destination loops nests have been interchanged. -// CHECK: %2 = alloc() : memref<1x1x1x1x16x10xf32, 2> -// CHECK-NEXT: for %i0 = 0 to 3 { +// CHECK-DAG: %0 = alloc() : memref<1x1x1x1x16x10xf32, 2> +// CHECK: for %i0 = 0 to 3 { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: for %i2 = 0 to 2 { // CHECK-NEXT: for %i3 = 0 to 2 { @@ -1009,7 +1006,7 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i5 = 0 to 3 { // CHECK-NEXT: for %i6 = 0 to 16 { // CHECK-NEXT: for %i7 = 0 to 10 { -// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %3 = load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i8 = 0 to 16 { @@ -1020,14 +1017,14 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: %7 = affine_apply #map3(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %8 = affine_apply #map4(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %9 = affine_apply #map5(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: store %cst, %2[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: store %cst, %0[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i10 = 0 to 2 { // CHECK-NEXT: for %i11 = 0 to 2 { // CHECK-NEXT: for %i12 = 0 to 16 { // CHECK-NEXT: for %i13 = 0 to 10 { -// CHECK-NEXT: %10 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %10 = load %1[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { @@ -1038,7 +1035,7 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: %14 = affine_apply #map3(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %15 = affine_apply #map4(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %16 = affine_apply #map5(%i2, %i3, %i0, %i1, %i2, %i3, %i0, 
%i1, %i14, %i15) -// CHECK-NEXT: %17 = load %2[%11, %12, %13, %14, %15, %16] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: %17 = load %0[%11, %12, %13, %14, %15, %16] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1088,20 +1085,20 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // NOTE: the size of the private memref created for the fused loop nest // is reduced from the original shape from 4x256 to 4x16 because of the // data accessed by the load. - // CHECK: %1 = alloc() : memref<1x16xf32> - // CHECK-NEXT: for %i0 = 0 to 4 { + // CHECK-DAG: %0 = alloc() : memref<1x16xf32> + // CHECK: for %i0 = 0 to 4 { // CHECK-NEXT: for %i1 = 0 to 256 { - // CHECK-NEXT: %2 = load %0[%i0, %i1] : memref<4x256xf32> + // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i0, %i2) // CHECK-NEXT: %4 = affine_apply #map1(%i0, %i0, %i2) - // CHECK-NEXT: store %cst, %1[%3, %4] : memref<1x16xf32> + // CHECK-NEXT: store %cst, %0[%3, %4] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %5 = affine_apply #map0(%i0, %i0, %i3) // CHECK-NEXT: %6 = affine_apply #map1(%i0, %i0, %i3) - // CHECK-NEXT: %7 = load %1[%5, %6] : memref<1x16xf32> + // CHECK-NEXT: %7 = load %0[%5, %6] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1131,8 +1128,8 @@ func @should_fuse_at_depth1_with_trip_count_20() { } } // NOTE: The size of the private memref created for fusion is shrunk to 20xf32 - // CHECK: %0 = alloc() : memref<20xf32> - // CHECK-NEXT: for %i0 = 0 to 5 { + // CHECK-DAG: %0 = alloc() : memref<20xf32> + // CHECK: for %i0 = 0 to 5 { // CHECK-NEXT: for %i1 = 0 to 20 { // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } @@ -1172,8 +1169,8 @@ func @should_fuse_at_depth1_with_trip_count_19() { } } // NOTE: The size of the private memref created for fusion is shrunk to 19xf32 - // CHECK: %0 = alloc() : memref<19xf32> - // CHECK-NEXT: for %i0 = 0 to 5 { + // CHECK-DAG: %0 = alloc() : memref<19xf32> + // CHECK: for %i0 = 0 to 5 { // CHECK-NEXT: for %i1 = 0 to 19 { // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } @@ -1210,19 +1207,19 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. 
- // CHECK: %0 = alloc() : memref<1xf32> - // CHECK-NEXT: for %i0 = 0 to 17 { - // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-DAG: %0 = alloc() : memref<1xf32> + // CHECK-DAG: %1 = alloc() : memref<1xf32> + // CHECK: for %i0 = 0 to 17 { // CHECK-NEXT: %2 = affine_apply #map0(%i0, %i0) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: %4 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 82 { // CHECK-NEXT: %5 = affine_apply #map0(%i1, %i1) - // CHECK-NEXT: store %cst, %4[%5] : memref<1xf32> + // CHECK-NEXT: store %cst, %1[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine_apply #map0(%i1, %i1) - // CHECK-NEXT: %7 = load %4[%6] : memref<1xf32> + // CHECK-NEXT: %7 = load %1[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1244,10 +1241,10 @@ func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0'. + // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: %0 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> @@ -1274,17 +1271,17 @@ func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
+ // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: %1 = alloc() : memref<1xf32> // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %0 : memref<10xf32> + // CHECK-NEXT: return %1 : memref<10xf32> return %m : memref<10xf32> } @@ -1324,8 +1321,8 @@ func @R3_to_R2_reshape() { // CHECK-NEXT: #map6 = (d0) -> (d0 floordiv 48) // CHECK-LABEL: func @R3_to_R2_reshape() -// CHECK: %0 = alloc() : memref<1x1x1xi32> -// CHECK-NEXT: for %i0 = 0 to 32 { +// CHECK-DAG: %0 = alloc() : memref<1x1x1xi32> +// CHECK: for %i0 = 0 to 32 { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1) // CHECK-NEXT: %2 = affine_apply #map1()[%c0] @@ -1371,3 +1368,184 @@ func @should_not_fuse_multi_output_producer() { // CHECK-NEXT: return return } + +// ----- + +// CHECK-LABEL: func @fusion_preventing_deps_on_middle_loop() { +func @fusion_preventing_deps_on_middle_loop() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %v0, %b[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %a[%i1] : memref<10xf32> + %v1 = load %c[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v2 = load %b[%i2] : memref<10xf32> + store %v2, %c[%i2] : memref<10xf32> + } + // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref + // '%b', because of the WAR dep from '%i0' to '%i1' on memref '%a' and + // because of the WAR dep from '%i1' to '%i2' on memref '%c'. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: store %3, %1[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> + // CHECK-NEXT: store %5, %2[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_and_move_to_preserve_war_dep() { +func @should_fuse_and_move_to_preserve_war_dep() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + %v0 = load %b[%i0] : memref<10xf32> + store %v0, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 3 { + %v2 = load %c[%i1] : memref<10xf32> + } + for %i2 = 0 to 5 { + store %cf7, %b[%i2] : memref<10xf32> + } + for %i3 = 0 to 10 { + %v1 = load %a[%i3] : memref<10xf32> + store %cf7, %c[%i3] : memref<10xf32> + } + + // Dependence graph: + // + // %i0 --------- + // | | + // --- %i1 | %b | %a + // | | | + // %c | %i2 <-- | + // | | + // --> %i3 <-------- + // + // It is possible to fuse loop '%i0' into '%i3' and preserve dependences + // if the fused loop nest is inserted between loops '%i1' and '%i2'. 
+ + // CHECK-DAG: %0 = alloc() : memref<1xf32> + // CHECK: for %i0 = 0 to 3 { + // CHECK-NEXT: %3 = load %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %4 = load %1[%i1] : memref<10xf32> + // CHECK-NEXT: %5 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: store %4, %0[%5] : memref<1xf32> + // CHECK-NEXT: %6 = affine_apply #map0(%i1, %i1) + // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> + // CHECK-NEXT: store %cst, %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 5 { + // CHECK-NEXT: store %cst, %1[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @fusion_preventing_dep_on_constant() { +func @fusion_preventing_dep_on_constant() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + %v0 = load %b[%i0] : memref<10xf32> + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %b[%i1] : memref<10xf32> + } + %cf11 = constant 11.0 : f32 + for %i2 = 0 to 10 { + %v2 = load %a[%i2] : memref<10xf32> + store %cf11, %c[%i2] : memref<10xf32> + } + // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref + // '%a', because of the WAR dep from '%i0' to '%i1' on memref '%b' and + // because of the SSA value dep from '%cf11' def to use in '%i2'. + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: %cst_0 = constant 1.100000e+01 : f32 + // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: %4 = load %0[%i2] : memref<10xf32> + // CHECK-NEXT: store %cst_0, %2[%i2] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_and_preserve_dep_on_constant() { +func @should_fuse_and_preserve_dep_on_constant() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + %cf11 = constant 11.0 : f32 + for %i0 = 0 to 10 { + %v0 = load %b[%i0] : memref<10xf32> + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + store %cf7, %b[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v2 = load %a[%i2] : memref<10xf32> + store %cf11, %c[%i2] : memref<10xf32> + } + + // Loops '%i0' and '%i2' can fuse along producer/consumer edge on memref + // '%a', and preserve the WAR dep from '%i0' to '%i1' on memref '%b', and + // the SSA value dep from '%cf11' def to use in '%i2'. 
+ + // CHECK: %cst_0 = constant 1.100000e+01 : f32 + // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: %4 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: store %cst, %0[%4] : memref<1xf32> + // CHECK-NEXT: %5 = affine_apply #map0(%i0, %i0) + // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> + // CHECK-NEXT: store %cst_0, %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From d7c824451fa86466b48f48636270cde5366c9b6a Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 30 Jan 2019 15:53:41 -0800 Subject: LoopFusion: insert the source loop nest slice at a depth in the destination loop nest which preserves dependences (above any loop carried or other dependences). This is accomplished by updating the maximum destination loop depth based on dependence checks between source loop nest loads and stores which access the memref on which the source loop nest has a store op. In addition, prevent fusing in source loop nests which write to memrefs which escape or are live out. PiperOrigin-RevId: 231684492 --- mlir/lib/Transforms/LoopFusion.cpp | 113 ++++++++++++++++++++++++++------- mlir/test/Transforms/loop-fusion.mlir | 114 ++++++++++++++++++++++++++++------ 2 files changed, 184 insertions(+), 43 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index f33afba3806..0da53a5add3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -231,13 +231,9 @@ public: nodes.erase(id); } - // Returns true if node 'id' can be removed from the graph. Returns false - // otherwise. A node can be removed from the graph iff the following - // conditions are met: - // *) The node does not write to any memref which escapes (or is an argument - // to) the function/block. - // *) The node has no successors in the dependence graph. - bool canRemoveNode(unsigned id) { + // Returns true if node 'id' writes to any memref which escapes (or is an + // argument to) the function/block. Returns false otherwise. + bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast()->getMemRef(); @@ -245,15 +241,30 @@ public: auto *opInst = dyn_cast_or_null(inst); // Return false if 'memref' is a function argument. if (opInst == nullptr) - return false; + return true; // Return false if any use of 'memref' escapes the function. for (auto &use : memref->getUses()) { auto *user = dyn_cast(use.getOwner()); if (!user || !isMemRefDereferencingOp(*user)) - return false; + return true; } + } + return false; + } + + // Returns true if node 'id' can be removed from the graph. Returns false + // otherwise. A node can be removed from the graph iff the following + // conditions are met: + // *) The node does not write to any memref which escapes (or is a + // function/block argument). + // *) The node has no successors in the dependence graph. + bool canRemoveNode(unsigned id) { + if (writesToLiveInOrEscapingMemrefs(id)) + return false; + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { // Return false if there exist out edges from 'id' on 'memref'. 
- if (getOutEdgeCount(id, memref) > 0) + if (getOutEdgeCount(id, storeOpInst->cast()->getMemRef()) > 0) return false; } return true; @@ -758,6 +769,49 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { return loopDepth; } +// Returns the maximum loop depth at which no dependences between 'loadOpInsts' +// and 'storeOpInsts' are satisfied. +static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, + ArrayRef storeOpInsts) { + // Merge loads and stores into the same array. + SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); + ops.append(storeOpInsts.begin(), storeOpInsts.end()); + + // Compute the innermost common loop depth for loads and stores. + unsigned loopDepth = getInnermostCommonLoopDepth(ops); + + // Return common loop depth for loads if there are no store ops. + if (storeOpInsts.empty()) + return loopDepth; + + // Check dependences on all pairs of ops in 'ops' and store the minimum + // loop depth at which a dependence is satisfied. + for (unsigned i = 0, e = ops.size(); i < e; ++i) { + auto *srcOpInst = ops[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < e; ++j) { + auto *dstOpInst = ops[j]; + MemRefAccess dstAccess(dstOpInst); + + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + FlatAffineConstraints dependenceConstraints; + // TODO(andydavis) Cache dependence analysis results, check cache here. + if (checkMemrefAccessDependence(srcAccess, dstAccess, d, + &dependenceConstraints, + /*dependenceComponents=*/nullptr)) { + // Store minimum loop depth and break because we want the min 'd' at + // which there is a dependence. + loopDepth = std::min(loopDepth, d - 1); + break; + } + } + } + } + return loopDepth; +} + // Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB' // using a rectangular bounding box. // TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' @@ -926,7 +980,7 @@ static uint64_t getSliceIterationCount( } // Checks the profitability of fusing a backwards slice of the loop nest -// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. +// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // Returns true if it is profitable to fuse the candidate loop nests. Returns // false otherwise. `dstLoopDepth` is set to the most profitable depth at which // to materialize the source loop nest slice. @@ -943,7 +997,7 @@ static uint64_t getSliceIterationCount( // the largest computation slice at the maximal dst loop depth (closest to the // load) to minimize reuse distance and potentially enable subsequent // load/store forwarding. -// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for +// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for // the same memref as is written by 'srcOpInst', then the union of slice // loop bounds is used to compute the slice and associated slice cost. // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop @@ -956,7 +1010,8 @@ static uint64_t getSliceIterationCount( // loop nest computed in the previous step, and returns true if the latter // is lower.
static bool isFusionProfitable(OperationInst *srcOpInst, - ArrayRef dstOpInsts, + ArrayRef dstLoadOpInsts, + ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { LLVM_DEBUG({ @@ -964,7 +1019,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n"; - for (auto dstOpInst : dstOpInsts) { + for (auto dstOpInst : dstLoadOpInsts) { llvm::dbgs() << " "; dstOpInst->dump(); }; @@ -985,7 +1040,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // Compute cost of dst loop nest. SmallVector dstLoopIVs; - getLoopIVs(*dstOpInsts[0], &dstLoopIVs); + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); @@ -994,8 +1049,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst, if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; - // Compute the innermost common loop for ops in 'dstOpInst'. - unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts); + // Compute the maximum loop depth at which we can insert the src slice + // and still satisfy dest loop nest dependences. + unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts); if (maxDstLoopDepth == 0) return false; @@ -1030,11 +1086,11 @@ static bool isFusionProfitable(OperationInst *srcOpInst, MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. if (!mlir::getBackwardComputationSliceState( - srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1])) + srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1])) return false; - // Compute the union of slice bound of all ops in 'dstOpInsts'. - for (int j = 1, e = dstOpInsts.size(); j < e; ++j) { - MemRefAccess dstAccess(dstOpInsts[j]); + // Compute the union of slice bound of all ops in 'dstLoadOpInsts'. + for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) { + MemRefAccess dstAccess(dstLoadOpInsts[j]); ComputationSliceState tmpSliceState; if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i, &tmpSliceState)) @@ -1062,7 +1118,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; - for (auto *loadOp : dstOpInsts) { + for (auto *loadOp : dstLoadOpInsts) { if (auto *loadLoop = dyn_cast_or_null(loadOp->getParentInst())) computeCostMap[loadLoop] = -1; } @@ -1321,6 +1377,10 @@ public: if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) continue; + // Skip if 'srcNode' writes to any live in or escaping memrefs. + if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) + continue; + // Compute an instruction list insertion point for the fused loop // nest which preserves dependences. Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint( @@ -1330,10 +1390,17 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); + // Gather 'dstNode' store ops to 'memref'. + SmallVector dstStoreOpInsts; + for (auto *storeOpInst : dstNode->stores) + if (storeOpInst->cast()->getMemRef() == memref) + dstStoreOpInsts.push_back(storeOpInst); + unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; // Check if fusion would be profitable.
- if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, + if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, &bestDstLoopDepth)) continue; diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 3b37612fd27..bb8dd0db73e 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1252,10 +1252,8 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - -// CHECK-LABEL: func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { -func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { +// CHECK-LABEL: func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { +func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 for %i0 = 0 to 10 { @@ -1266,15 +1264,11 @@ func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0'. - // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1282,10 +1276,8 @@ func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - -// CHECK-LABEL: func @fusion_should_not_remove_escaping_memref() -func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { +// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32> +func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> for %i0 = 0 to 10 { @@ -1296,17 +1288,14 @@ func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> { } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
- // CHECK-DAG: %0 = alloc() : memref<1xf32> + // CHECK-DAG: %0 = alloc() : memref<10xf32> // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: return %1 : memref<10xf32> + // CHECK-NEXT: return %0 : memref<10xf32> return %m : memref<10xf32> } @@ -1578,3 +1567,88 @@ func @should_fuse_and_preserve_dep_on_constant() { // CHECK-NEXT: return return } + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d1) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2) +// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 * 16 - d1 + 15) +// CHECK: [[MAP3:#map[0-9]+]] = (d0, d1) -> (d0 * 16 + d1) + +// CHECK-LABEL: func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) { +func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) { + %out = alloc() : memref<64x4xf32> + %0 = constant 0.0 : f32 + for %i0 = 0 to 64 { + for %i1 = 0 to 4 { + store %0, %out[%i0, %i1] : memref<64x4xf32> + } + } + for %i2 = 0 to 4 { + for %i3 = 0 to 4 { + for %i4 = 0 to 16 { + %1 = affine_apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i3, %i4) + %2 = load %arg1[%1, %i2] : memref<64x4xf32> + "op0"(%2) : (f32) -> () + } + for %i5 = 0 to 4 { + for %i6 = 0 to 16 { + %3 = affine_apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i5, %i6) + %4 = load %arg0[%3, %i3] : memref<64x4xf32> + "op1"(%4) : (f32) -> () + } + for %i7 = 0 to 16 { + %5 = "op2"() : () -> (f32) + %6 = affine_apply (d0, d1) -> (d0 * 16 + d1)(%i5, %i7) + %7 = load %out[%6, %i2] : memref<64x4xf32> + %8 = addf %7, %5 : f32 + store %8, %out[%6, %i2] : memref<64x4xf32> + } + } + } + } + + // We can fuse source loop nest '%i0' into dst loop nest '%i2', but the + // depth at which we can insert the src loop nest slice into the dst loop + // nest must be decreased because of a loop carried dependence on loop '%i3'. + // As a result, the source loop nest is inserted at dst loop nest depth 1, + // just above the loop with the carried dependence. In addition, the source + // loop nest iteration bounds on its loop '%i1' are reduced to 1, so the + // memref size can be reduced to 64x1xf32.
+ + // CHECK: %0 = alloc() : memref<64x1xf32> + // CHECK: for %i0 = 0 to 4 { + // CHECK-NEXT: for %i1 = 0 to 64 { + // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i0, %i1, %i0) + // CHECK-NEXT: %2 = affine_apply [[MAP1]](%i0, %i1, %i0) + // CHECK-NEXT: store %cst, %0[%1, %2] : memref<64x1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i2 = 0 to 4 { + // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: %3 = affine_apply [[MAP2]](%i2, %i3) + // CHECK-NEXT: %4 = load %arg1[%3, %i0] : memref<64x4xf32> + // CHECK-NEXT: "op0"(%4) : (f32) -> () + // CHECK-NEXT: } + // CHECK-NEXT: for %i4 = 0 to 4 { + // CHECK-NEXT: for %i5 = 0 to 16 { + // CHECK-NEXT: %5 = affine_apply [[MAP2]](%i4, %i5) + // CHECK-NEXT: %6 = load %arg0[%5, %i2] : memref<64x4xf32> + // CHECK-NEXT: "op1"(%6) : (f32) -> () + // CHECK-NEXT: } + // CHECK-NEXT: for %i6 = 0 to 16 { + // CHECK-NEXT: %7 = "op2"() : () -> f32 + // CHECK-NEXT: %8 = affine_apply [[MAP3]](%i4, %i6) + // CHECK-NEXT: %9 = affine_apply [[MAP0]](%i0, %8, %i0) + // CHECK-NEXT: %10 = affine_apply [[MAP1]](%i0, %8, %i0) + // CHECK-NEXT: %11 = load %0[%9, %10] : memref<64x1xf32> + // CHECK-NEXT: %12 = addf %11, %7 : f32 + // CHECK-NEXT: %13 = affine_apply [[MAP0]](%i0, %8, %i0) + // CHECK-NEXT: %14 = affine_apply [[MAP1]](%i0, %8, %i0) + // CHECK-NEXT: store %12, %0[%13, %14] : memref<64x1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From 1e85191d07da6e4b35a2ec590a7dda9bc3961c13 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 30 Jan 2019 16:01:46 -0800 Subject: Fix ASAN issue: snapshot edge list before loop which can modify this list. PiperOrigin-RevId: 231686040 --- mlir/lib/Transforms/LoopFusion.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0da53a5add3..fa0e3b51de3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -166,6 +166,10 @@ public: // Edge represents a data dependence between nodes in the graph. struct Edge { // The id of the node at the other end of the edge. + // If this edge is stored in Edge = Node.inEdges[i], then + // 'Node.inEdges[i].id' is the identifier of the source node of the edge. + // If this edge is stored in Edge = Node.outEdges[i], then + // 'Node.outEdges[i].id' is the identifier of the dest node of the edge. unsigned id; // The SSA value on which this edge represents a dependence. // If the value is a memref, then the dependence is between graph nodes @@ -1355,13 +1359,21 @@ public: // Skip if no input edges along which to fuse. if (mdg->inEdges.count(dstId) == 0) continue; - // Iterate through in edges for 'dstId'. + // Iterate through in edges for 'dstId' and src node id for any + // edges on 'memref'. + SmallVector srcNodeIds; for (auto &srcEdge : mdg->inEdges[dstId]) { // Skip 'srcEdge' if not for 'memref'. if (srcEdge.value != memref) continue; - - auto *srcNode = mdg->getNode(srcEdge.id); + srcNodeIds.push_back(srcEdge.id); + } + for (unsigned srcId : srcNodeIds) { + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(srcId) == 0) + continue; + // Get 'srcNode' from which to attempt fusion into 'dstNode'. + auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest.
if (!isa(srcNode->inst)) continue; -- cgit v1.2.3 From 5052bd8582fbcfc0a4774c34141c2dd04b333613 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 1 Feb 2019 16:42:18 -0800 Subject: Define the AffineForOp and replace ForInst with it. This patch is largely mechanical, i.e. changing usages of ForInst to OpPointer. An important difference is that upon construction an AffineForOp no longer automatically creates the body and induction variable. To generate the body/iv, 'createBody' can be called on an AffineForOp with no body. PiperOrigin-RevId: 232060516 --- mlir/include/mlir/AffineOps/AffineOps.h | 219 ++++++++++- mlir/include/mlir/Analysis/AffineAnalysis.h | 6 +- mlir/include/mlir/Analysis/AffineStructures.h | 23 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 22 +- mlir/include/mlir/Analysis/NestedMatcher.h | 1 - mlir/include/mlir/Analysis/Utils.h | 17 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 10 +- mlir/include/mlir/IR/Builders.h | 14 - mlir/include/mlir/IR/InstVisitor.h | 24 +- mlir/include/mlir/IR/Instruction.h | 2 - mlir/include/mlir/IR/Instructions.h | 267 +------------- mlir/include/mlir/IR/OpDefinition.h | 14 + mlir/include/mlir/IR/OpImplementation.h | 14 +- mlir/include/mlir/IR/UseDefLists.h | 5 +- mlir/include/mlir/Transforms/LoopUtils.h | 34 +- mlir/include/mlir/Transforms/Passes.h | 10 +- mlir/include/mlir/Transforms/Utils.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 443 ++++++++++++++++++++++- mlir/lib/Analysis/AffineAnalysis.cpp | 28 +- mlir/lib/Analysis/AffineStructures.cpp | 22 +- mlir/lib/Analysis/LoopAnalysis.cpp | 75 ++-- mlir/lib/Analysis/NestedMatcher.cpp | 26 +- mlir/lib/Analysis/SliceAnalysis.cpp | 22 +- mlir/lib/Analysis/Utils.cpp | 76 ++-- mlir/lib/Analysis/VectorAnalysis.cpp | 37 +- mlir/lib/Analysis/Verifier.cpp | 14 - mlir/lib/EDSC/MLIREmitter.cpp | 10 +- mlir/lib/IR/AsmPrinter.cpp | 116 +----- mlir/lib/IR/Builders.cpp | 16 - mlir/lib/IR/Instruction.cpp | 337 +++-------------- mlir/lib/IR/Value.cpp | 2 - mlir/lib/Parser/Parser.cpp | 346 +++++------------- mlir/lib/Transforms/CSE.cpp | 5 - mlir/lib/Transforms/ConstantFold.cpp | 13 +- mlir/lib/Transforms/DmaGeneration.cpp | 47 ++- mlir/lib/Transforms/LoopFusion.cpp | 200 +++++----- mlir/lib/Transforms/LoopTiling.cpp | 86 ++--- mlir/lib/Transforms/LoopUnroll.cpp | 67 ++-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 86 +++-- mlir/lib/Transforms/LowerAffine.cpp | 68 ++-- mlir/lib/Transforms/MaterializeVectors.cpp | 20 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 70 ++-- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 1 + mlir/lib/Transforms/Utils/LoopUtils.cpp | 170 ++++----- mlir/lib/Transforms/Utils/Utils.cpp | 5 +- mlir/lib/Transforms/Vectorize.cpp | 63 ++-- mlir/test/IR/invalid.mlir | 89 +---- mlir/test/IR/locations.mlir | 6 +- mlir/test/IR/parser.mlir | 24 +- mlir/test/IR/pretty-locations.mlir | 6 +- mlir/test/Transforms/strip-debuginfo.mlir | 6 +- mlir/test/Transforms/unroll.mlir | 11 +- 52 files changed, 1569 insertions(+), 1730 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index d511f628c3c..b9def6cb24f 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -28,13 +28,230 @@ #include "mlir/IR/StandardTypes.h" namespace mlir { +class AffineBound; class AffineOpsDialect : public Dialect { public: AffineOpsDialect(MLIRContext *context); }; -/// The "if" operation represents an if–then–else construct for conditionally +/// The "for" 
instruction represents an affine loop nest, defining an SSA value +/// for its induction variable. The induction variable is represented as a +/// BlockArgument to the entry block of the body. The body and induction +/// variable can be created automatically for new "for" ops with 'createBody'. +/// This SSA value always has type index, which is the size of the machine word. +/// The stride, represented by step, is a positive constant integer which +/// defaults to "1" if not present. The lower and upper bounds specify a +/// half-open range: the range includes the lower bound but does not include the +/// upper bound. +/// +/// The lower and upper bounds of a for operation are represented as an +/// application of an affine mapping to a list of SSA values passed to the map. +/// The same restrictions hold for these SSA values as for all bindings of SSA +/// values to dimensions and symbols. The affine mappings for the bounds may +/// return multiple results, in which case the max/min keywords are required +/// (for the lower/upper bound respectively), and the bound is the +/// maximum/minimum of the returned values. +/// +/// Example: +/// +/// for %i = 1 to 10 { +/// ... +/// } +/// +class AffineForOp + : public Op { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step = 1); + static void build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step = 1); + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + + static StringRef getOperationName() { return "for"; } + static StringRef getStepAttrName() { return "step"; } + static StringRef getLowerBoundAttrName() { return "lower_bound"; } + static StringRef getUpperBoundAttrName() { return "upper_bound"; } + + /// Generate a body block for this AffineForOp. The operation must not already + /// have a body. The operation must contain a parent function. + Block *createBody(); + + /// Get the body of the AffineForOp. + Block *getBody() { return &getBlockList().front(); } + const Block *getBody() const { return &getBlockList().front(); } + + /// Get the blocklist containing the body. + BlockList &getBlockList() { return getInstruction()->getBlockList(0); } + const BlockList &getBlockList() const { + return getInstruction()->getBlockList(0); + } + + /// Returns the induction variable for this loop. + Value *getInductionVar(); + const Value *getInductionVar() const { + return const_cast(this)->getInductionVar(); + } + + //===--------------------------------------------------------------------===// + // Bounds and step + //===--------------------------------------------------------------------===// + + using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + // TODO: provide iterators for the lower and upper bound operands + // if the current access via getLowerBound(), getUpperBound() is too slow. + + /// Returns operands for the lower bound map. + operand_range getLowerBoundOperands(); + const_operand_range getLowerBoundOperands() const; + + /// Returns operands for the upper bound map. + operand_range getUpperBoundOperands(); + const_operand_range getUpperBoundOperands() const; + + /// Returns information about the lower bound as a single object. 
+ const AffineBound getLowerBound() const; + + /// Returns information about the upper bound as a single object. + const AffineBound getUpperBound() const; + + /// Returns loop step. + int64_t getStep() const { + return getAttr(getStepAttrName()).cast().getInt(); + } + + /// Returns affine map for the lower bound. + AffineMap getLowerBoundMap() const { + return getAttr(getLowerBoundAttrName()).cast().getValue(); + } + /// Returns affine map for the upper bound. The upper bound is exclusive. + AffineMap getUpperBoundMap() const { + return getAttr(getUpperBoundAttrName()).cast().getValue(); + } + + /// Set lower bound. The new bound must have the same number of operands as + /// the current bound map. Otherwise, 'replaceForLowerBound' should be used. + void setLowerBound(ArrayRef operands, AffineMap map); + /// Set upper bound. The new bound must not have more operands than the + /// current bound map. Otherwise, 'replaceForUpperBound' should be used. + void setUpperBound(ArrayRef operands, AffineMap map); + + /// Set the lower bound map without changing operands. + void setLowerBoundMap(AffineMap map); + + /// Set the upper bound map without changing operands. + void setUpperBoundMap(AffineMap map); + + /// Set loop step. + void setStep(int64_t step) { + assert(step > 0 && "step has to be a positive integer constant"); + auto *context = getLowerBoundMap().getContext(); + setAttr(Identifier::get(getStepAttrName(), context), + IntegerAttr::get(IndexType::get(context), step)); + } + + /// Returns true if the lower bound is constant. + bool hasConstantLowerBound() const; + /// Returns true if the upper bound is constant. + bool hasConstantUpperBound() const; + /// Returns true if both bounds are constant. + bool hasConstantBounds() const { + return hasConstantLowerBound() && hasConstantUpperBound(); + } + /// Returns the value of the constant lower bound. + /// Fails assertion if the bound is non-constant. + int64_t getConstantLowerBound() const; + /// Returns the value of the constant upper bound. The upper bound is + /// exclusive. Fails assertion if the bound is non-constant. + int64_t getConstantUpperBound() const; + /// Sets the lower bound to the given constant value. + void setConstantLowerBound(int64_t value); + /// Sets the upper bound to the given constant value. + void setConstantUpperBound(int64_t value); + + /// Returns true if both the lower and upper bound have the same operand lists + /// (same operands in the same order). + bool matchingBoundOperandList() const; + + /// Walk the operation instructions in the 'for' instruction in preorder, + /// calling the callback for each operation. + void walkOps(std::function callback); + + /// Walk the operation instructions in the 'for' instruction in postorder, + /// calling the callback for each operation. + void walkOpsPostOrder(std::function callback); + +private: + friend class OperationInst; + explicit AffineForOp(const OperationInst *state) : Op(state) {} +}; + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool isForInductionVar(const Value *val); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer getForInductionVarOwner(Value *val); +ConstOpPointer getForInductionVarOwner(const Value *val); + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. 
+SmallVector +extractForInductionVars(MutableArrayRef> forInsts); + +/// AffineBound represents a lower or upper bound in the for instruction. +/// This class does not own the underlying operands. Instead, it refers +/// to the operands stored in the AffineForOp. Its life span should not exceed +/// that of the for instruction it refers to. +class AffineBound { +public: + ConstOpPointer getAffineForOp() const { return inst; } + AffineMap getMap() const { return map; } + + unsigned getNumOperands() const { return opEnd - opStart; } + const Value *getOperand(unsigned idx) const { + return inst->getInstruction()->getOperand(opStart + idx); + } + + using operand_iterator = AffineForOp::operand_iterator; + using operand_range = AffineForOp::operand_range; + + operand_iterator operand_begin() const { + return const_cast(inst->getInstruction()) + ->operand_begin() + + opStart; + } + operand_iterator operand_end() const { + return const_cast(inst->getInstruction()) + ->operand_begin() + + opEnd; + } + operand_range getOperands() const { return {operand_begin(), operand_end()}; } + +private: + // 'for' instruction that contains this bound. + ConstOpPointer inst; + // Start and end positions of this affine bound operands in the list of + // the containing 'for' instruction operands. + unsigned opStart, opEnd; + // Affine map for this bound. + AffineMap map; + + AffineBound(ConstOpPointer inst, unsigned opStart, + unsigned opEnd, AffineMap map) + : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} + + friend class AffineForOp; +}; + +/// The "if" operation represents an if-then-else construct for conditionally /// executing two regions of code. The operands to an if operation are an /// IntegerSet condition and a set of symbol/dimension operands to the /// condition set. The operation produces no results. For example: diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 30576b587a0..3ee35eea2ff 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -32,10 +32,10 @@ namespace mlir { class AffineApplyOp; class AffineExpr; +class AffineForOp; class AffineMap; class AffineValueMap; class FlatAffineConstraints; -class ForInst; class FuncBuilder; class Instruction; class IntegerSet; @@ -108,12 +108,12 @@ bool getFlattenedAffineExprs( FlatAffineConstraints *cst = nullptr); /// Builds a system of constraints with dimensional identifiers corresponding to -/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are +/// the loop IVs of the forOps appearing in that order. Bounds of the loop are /// used to add appropriate inequalities. Any symbols founds in the bound /// operands are added as symbols in the system. Returns false for the yet /// unimplemented cases. // TODO(bondhugula): handle non-unit strides. -bool getIndexSet(llvm::ArrayRef forInsts, +bool getIndexSet(llvm::MutableArrayRef> forOps, FlatAffineConstraints *domain); /// Encapsulates a memref load or store access information. 
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 49202eb7cc5..e8b4ee623c0 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -28,9 +28,10 @@ namespace mlir { class AffineApplyOp; class AffineBound; +class AffineForOp; class AffineCondition; class AffineMap; -class ForInst; +template class ConstOpPointer; class IntegerSet; class MLIRContext; class Value; @@ -113,13 +114,12 @@ private: /// results, and its map can themselves change as a result of /// substitutions, simplifications, and other analysis. // An affine value map can readily be constructed from an AffineApplyOp, or an -// AffineBound of a ForInst. It can be further transformed, substituted into, -// or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed -// during analysis. Only the AffineMap expressions that are pointed by them are -// unique'd. -// An affine value map, and the operations on it, maintain the invariant that -// operands are always positionally aligned with the AffineDimExpr and -// AffineSymbolExpr in the underlying AffineMap. +// AffineBound of a AffineForOp. It can be further transformed, substituted +// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and +// destroyed during analysis. Only the AffineMap expressions that are pointed by +// them are unique'd. An affine value map, and the operations on it, maintain +// the invariant that operands are always positionally aligned with the +// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. // TODO(bondhugula): Some of these classes could go into separate files. class AffineValueMap { public: @@ -173,9 +173,6 @@ private: // Both, the integer set being pointed to and the operands can change during // analysis, simplification, and transformation. class IntegerValueSet { - // Constructs an integer value set map from an IntegerSet and operands. - explicit IntegerValueSet(const AffineCondition &cond); - /// Constructs an integer value set from an affine value map. // This will lead to a single equality in 'set'. explicit IntegerValueSet(const AffineValueMap &avm); @@ -403,7 +400,7 @@ public: /// Adds constraints (lower and upper bounds) for the specified 'for' /// instruction's Value using IR information stored in its bound maps. The - /// right identifier is first looked up using forInst's Value. Returns + /// right identifier is first looked up using forOp's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the /// information is succesfully added. Asserts if the Value corresponding to /// the 'for' instruction isn't found in the constraint system. Any new @@ -411,7 +408,7 @@ public: /// are added as trailing identifiers (either dimensional or symbolic /// depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. - bool addForInstDomain(const ForInst &forInst); + bool addAffineForOpDomain(ConstOpPointer forOp); /// Adds a constant lower bound constraint for the specified expression. 
void addConstantLowerBound(ArrayRef expr, int64_t lb); diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 1b3d0ce9675..16c1c967385 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -29,8 +29,9 @@ namespace mlir { class AffineExpr; +class AffineForOp; class AffineMap; -class ForInst; +template class ConstOpPointer; class MemRefType; class OperationInst; class Value; @@ -38,19 +39,20 @@ class Value; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr getTripCountExpr(const ForInst &forInst); +AffineExpr getTripCountExpr(ConstOpPointer forOp); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count /// in non-trivial cases. -llvm::Optional getConstantTripCount(const ForInst &forInst); +llvm::Optional +getConstantTripCount(ConstOpPointer forOp); /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t getLargestDivisorOfTripCount(const ForInst &forInst); +uint64_t getLargestDivisorOfTripCount(ConstOpPointer forOp); -/// Given an induction variable `iv` of type ForInst and an `index` of type +/// Given an induction variable `iv` of type AffineForOp and an `index` of type /// IndexType, returns `true` if `index` is independent of `iv` and false /// otherwise. /// The determination supports composition with at most one AffineApplyOp. @@ -67,7 +69,7 @@ uint64_t getLargestDivisorOfTripCount(const ForInst &forInst); /// conservative. bool isAccessInvariant(const Value &iv, const Value &index); -/// Given an induction variable `iv` of type ForInst and `indices` of type +/// Given an induction variable `iv` of type AffineForOp and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. /// /// Prerequisites (inherited from `isAccessInvariant` above): @@ -85,21 +87,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); /// 3. all nested load/stores are to scalar MemRefs. /// TODO(ntv): implement dependence semantics /// TODO(ntv): relax the no-conditionals restriction -bool isVectorizableLoop(const ForInst &loop); +bool isVectorizableLoop(ConstOpPointer loop); /// Checks whether the loop is structurally vectorizable and that all the LoadOp /// and StoreOp matched have access indexing functions that are are either: /// 1. invariant along the loop induction variable created by 'loop'; /// 2. varying along the 'fastestVaryingDim' memory dimension. -bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop, - unsigned fastestVaryingDim); +bool isVectorizableLoopAlongFastestVaryingMemRefDim( + ConstOpPointer loop, unsigned fastestVaryingDim); /// Checks where SSA dominance would be violated if a for inst's body /// instructions are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. 
-bool isInstwiseShiftValid(const ForInst &forInst, +bool isInstwiseShiftValid(ConstOpPointer forOp, llvm::ArrayRef shifts); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 2a1c469348d..0e41058f777 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -127,7 +127,6 @@ private: struct State : public InstWalker { State(NestedPattern &pattern, SmallVectorImpl *matches) : pattern(pattern), matches(matches) {} - void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); } void visitOperationInst(OperationInst *opInst) { pattern.matchOne(opInst, matches); } diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index e9de4a8d259..bb81df604cf 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -33,10 +33,12 @@ namespace mlir { +class AffineForOp; +template class ConstOpPointer; class FlatAffineConstraints; -class ForInst; class MemRefAccess; class OperationInst; +template class OpPointer; class Instruction; class Value; @@ -49,7 +51,8 @@ bool properlyDominates(const Instruction &a, const Instruction &b); /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. -void getLoopIVs(const Instruction &inst, SmallVectorImpl *loops); +void getLoopIVs(const Instruction &inst, + SmallVectorImpl> *loops); /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. @@ -191,12 +194,12 @@ bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. -ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst, - OperationInst *dstOpInst, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); +OpPointer +insertBackwardComputationSlice(OperationInst *srcOpInst, + OperationInst *dstOpInst, unsigned dstLoopDepth, + ComputationSliceState *sliceState); -Optional getMemoryFootprintBytes(const ForInst &forInst, +Optional getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace = -1); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index dfb1164750a..89f49fdfe77 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -25,8 +25,8 @@ namespace mlir { class AffineApplyOp; +class AffineForOp; class AffineMap; -class ForInst; class FuncBuilder; class Instruction; class Location; @@ -71,7 +71,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// loop information is extracted. /// /// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at -/// most one invariant index along each ForInst of `loopToVectorDim`). +/// most one invariant index along each AffineForOp of `loopToVectorDim`). /// /// Example 1: /// The following MLIR snippet: @@ -122,9 +122,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// Meaning that vector_transfer_read will be responsible of reading the slice /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. 
/// -AffineMap -makePermutationMap(OperationInst *opInst, - const llvm::DenseMap &loopToVectorDim); +AffineMap makePermutationMap( + OperationInst *opInst, + const llvm::DenseMap &loopToVectorDim); namespace matcher { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3271c12afde..29a9fb0281b 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -239,11 +239,6 @@ public: /// current function. Block *createBlock(Block *insertBefore = nullptr); - /// Returns a builder for the body of a 'for' instruction. - static FuncBuilder getForInstBodyBuilder(ForInst *forInst) { - return FuncBuilder(forInst->getBody(), forInst->getBody()->end()); - } - /// Returns the current block of the builder. Block *getBlock() const { return block; } @@ -277,15 +272,6 @@ public: return cloneInst; } - // Creates a for instruction. When step is not specified, it is set to 1. - ForInst *createFor(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step = 1); - - // Creates a for instruction with known (constant) lower and upper bounds. - // Default step is 1. - ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 78810da909d..0ed8599ff33 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -83,8 +83,6 @@ public: "Must pass the derived type to this template!"); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->visitForInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -101,7 +99,6 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. - void visitForInst(ForInst *forInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -147,22 +144,11 @@ public: void walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - static_cast(this)->walk(block.begin(), block.end()); + static_cast(this)->walkPostOrder(block.begin(), + block.end()); static_cast(this)->visitOperationInst(opInst); } - void walkForInst(ForInst *forInst) { - static_cast(this)->visitForInst(forInst); - auto *body = forInst->getBody(); - static_cast(this)->walk(body->begin(), body->end()); - } - - void walkForInstPostOrder(ForInst *forInst) { - auto *body = forInst->getBody(); - static_cast(this)->walkPostOrder(body->begin(), body->end()); - static_cast(this)->visitForInst(forInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -171,8 +157,6 @@ public: static_cast(this)->visitInstruction(s); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->walkForInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -185,9 +169,6 @@ public: static_cast(this)->visitInstruction(s); switch (s->getKind()) { - case Instruction::Kind::For: - return static_cast(this)->walkForInstPostOrder( - cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -205,7 +186,6 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. 
When using RetTy, all of these // need to be overridden. - void visitForInst(ForInst *forInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 3dc1e76dd20..3789fefc639 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -32,7 +32,6 @@ namespace mlir { class Block; class BlockAndValueMapping; class Location; -class ForInst; class MLIRContext; /// Terminator operations can have Block operands to represent successors. @@ -74,7 +73,6 @@ class Instruction : public IROperandOwner, public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, - For = (int)IROperandOwner::Kind::ForInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 724c4dd7039..c6fde0e0aee 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -26,15 +26,11 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Block.h" #include "mlir/IR/Instruction.h" -#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OperationSupport.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { -class AffineBound; -class IntegerSet; -class AffineCondition; class AttributeListStorage; template class ConstOpPointer; template class OpPointer; @@ -219,6 +215,13 @@ public: return getOperandStorage().isResizable(); } + /// Replace the current operands of this operation with the ones provided in + /// 'operands'. If the operands list is not resizable, the size of 'operands' + /// must be less than or equal to the current number of operands. + void setOperands(ArrayRef operands) { + getOperandStorage().setOperands(this, operands); + } + unsigned getNumOperands() const { return getOperandStorage().size(); } Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } @@ -697,262 +700,6 @@ inline auto OperationInst::getResultTypes() const return {result_type_begin(), result_type_end()}; } -/// For instruction represents an affine loop nest. -class ForInst final - : public Instruction, - private llvm::TrailingObjects { -public: - static ForInst *create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step); - - /// Resolve base class ambiguity. - using Instruction::getFunction; - - /// Operand iterators. - using operand_iterator = OperandIterator; - using const_operand_iterator = OperandIterator; - - /// Operand iterator range. - using operand_range = llvm::iterator_range; - using const_operand_range = llvm::iterator_range; - - /// Get the body of the ForInst. - Block *getBody() { return &body.front(); } - - /// Get the body of the ForInst. - const Block *getBody() const { return &body.front(); } - - //===--------------------------------------------------------------------===// - // Bounds and step - //===--------------------------------------------------------------------===// - - /// Returns information about the lower bound as a single object. - const AffineBound getLowerBound() const; - - /// Returns information about the upper bound as a single object. - const AffineBound getUpperBound() const; - - /// Returns loop step. - int64_t getStep() const { return step; } - - /// Returns affine map for the lower bound. - AffineMap getLowerBoundMap() const { return lbMap; } - /// Returns affine map for the upper bound. 
The upper bound is exclusive. - AffineMap getUpperBoundMap() const { return ubMap; } - - /// Set lower bound. - void setLowerBound(ArrayRef operands, AffineMap map); - /// Set upper bound. - void setUpperBound(ArrayRef operands, AffineMap map); - - /// Set the lower bound map without changing operands. - void setLowerBoundMap(AffineMap map); - - /// Set the upper bound map without changing operands. - void setUpperBoundMap(AffineMap map); - - /// Set loop step. - void setStep(int64_t step) { - assert(step > 0 && "step has to be a positive integer constant"); - this->step = step; - } - - /// Returns true if the lower bound is constant. - bool hasConstantLowerBound() const; - /// Returns true if the upper bound is constant. - bool hasConstantUpperBound() const; - /// Returns true if both bounds are constant. - bool hasConstantBounds() const { - return hasConstantLowerBound() && hasConstantUpperBound(); - } - /// Returns the value of the constant lower bound. - /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound() const; - /// Returns the value of the constant upper bound. The upper bound is - /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound() const; - /// Sets the lower bound to the given constant value. - void setConstantLowerBound(int64_t value); - /// Sets the upper bound to the given constant value. - void setConstantUpperBound(int64_t value); - - /// Returns true if both the lower and upper bound have the same operand lists - /// (same operands in the same order). - bool matchingBoundOperandList() const; - - /// Walk the operation instructions in the 'for' instruction in preorder, - /// calling the callback for each operation. - void walkOps(std::function callback); - - /// Walk the operation instructions in the 'for' instruction in postorder, - /// calling the callback for each operation. - void walkOpsPostOrder(std::function callback); - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - unsigned getNumOperands() const { return getOperandStorage().size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef getInstOperands() const { - return getOperandStorage().getInstOperands(); - } - MutableArrayRef getInstOperands() { - return getOperandStorage().getInstOperands(); - } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - // TODO: provide iterators for the lower and upper bound operands - // if the current access via getLowerBound(), getUpperBound() is too slow. - - /// Returns operands for the lower bound map. - operand_range getLowerBoundOperands(); - const_operand_range getLowerBoundOperands() const; - - /// Returns operands for the upper bound map. 
- operand_range getUpperBoundOperands(); - const_operand_range getUpperBoundOperands() const; - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - /// Return the context this operation is associated with. - MLIRContext *getContext() const { - return getInductionVar()->getType().getContext(); - } - - using Instruction::dump; - using Instruction::print; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::ForInst; - } - - /// Returns the induction variable for this loop. - Value *getInductionVar(); - const Value *getInductionVar() const { - return const_cast(this)->getInductionVar(); - } - - void destroy(); - -private: - // The Block for the body. By construction, this list always contains exactly - // one block. - BlockList body; - - // Affine map for the lower bound. - AffineMap lbMap; - // Affine map for the upper bound. The upper bound is exclusive. - AffineMap ubMap; - // Positive constant step. Since index is stored as an int64_t, we restrict - // step to the set of positive integers that int64_t can represent. - int64_t step; - - explicit ForInst(Location location, AffineMap lbMap, AffineMap ubMap, - int64_t step); - ~ForInst(); - - /// Returns the operand storage object. - detail::OperandStorage &getOperandStorage() { - return *getTrailingObjects(); - } - const detail::OperandStorage &getOperandStorage() const { - return *getTrailingObjects(); - } - - // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects; -}; - -/// Returns if the provided value is the induction variable of a ForInst. -bool isForInductionVar(const Value *val); - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForInst *getForInductionVarOwner(Value *val); -const ForInst *getForInductionVarOwner(const Value *val); - -/// Extracts the induction variables from a list of ForInsts and returns them. -SmallVector extractForInductionVars(ArrayRef forInsts); - -/// AffineBound represents a lower or upper bound in the for instruction. -/// This class does not own the underlying operands. Instead, it refers -/// to the operands stored in the ForInst. Its life span should not exceed -/// that of the for instruction it refers to. -class AffineBound { -public: - const ForInst *getForInst() const { return &inst; } - AffineMap getMap() const { return map; } - - unsigned getNumOperands() const { return opEnd - opStart; } - const Value *getOperand(unsigned idx) const { - return inst.getOperand(opStart + idx); - } - const InstOperand &getInstOperand(unsigned idx) const { - return inst.getInstOperand(opStart + idx); - } - - using operand_iterator = ForInst::operand_iterator; - using operand_range = ForInst::operand_range; - - operand_iterator operand_begin() const { - // These are iterators over Value *. Not casting away const'ness would - // require the caller to use const Value *. - return operand_iterator(const_cast(&inst), opStart); - } - operand_iterator operand_end() const { - return operand_iterator(const_cast(&inst), opEnd); - } - - /// Returns an iterator on the underlying Value's (Value *). 
- operand_range getOperands() const { return {operand_begin(), operand_end()}; } - ArrayRef getInstOperands() const { - auto ops = inst.getInstOperands(); - return ArrayRef(ops.begin() + opStart, ops.begin() + opEnd); - } - -private: - // 'for' instruction that contains this bound. - const ForInst &inst; - // Start and end positions of this affine bound operands in the list of - // the containing 'for' instruction operands. - unsigned opStart, opEnd; - // Affine map for this bound. - AffineMap map; - - AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd, - AffineMap map) - : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} - - friend class ForInst; -}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c9ef9bf7cd6..2c62816b924 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -68,6 +68,11 @@ public: operator bool() const { return value.getInstruction(); } + bool operator==(OpPointer rhs) const { + return value.getInstruction() == rhs.value.getInstruction(); + } + bool operator!=(OpPointer rhs) const { return !(*this == rhs); } + /// OpPointer can be implicitly converted to OpType*. /// Return `nullptr` if there is no associated OperationInst*. operator OpType *() { @@ -87,6 +92,9 @@ public: private: OpType value; + + // Allow access to value to enable constructing an empty ConstOpPointer. + friend class ConstOpPointer; }; /// This pointer represents a notional "const OperationInst*" but where the @@ -96,6 +104,7 @@ class ConstOpPointer { public: explicit ConstOpPointer() : value(OperationInst::getNull().value) {} explicit ConstOpPointer(OpType value) : value(value) {} + ConstOpPointer(OpPointer pointer) : value(pointer.value) {} const OpType &operator*() const { return value; } @@ -104,6 +113,11 @@ public: /// Return true if non-null. operator bool() const { return value.getInstruction(); } + bool operator==(ConstOpPointer rhs) const { + return value.getInstruction() == rhs.value.getInstruction(); + } + bool operator!=(ConstOpPointer rhs) const { return !(*this == rhs); } + /// ConstOpPointer can always be implicitly converted to const OpType*. /// Return `nullptr` if there is no associated OperationInst*. operator const OpType *() const { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d3a5d35427f..4e7596498e7 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -90,7 +90,8 @@ public: virtual void printGenericOp(const OperationInst *op) = 0; /// Prints a block list. - virtual void printBlockList(const BlockList &blocks) = 0; + virtual void printBlockList(const BlockList &blocks, + bool printEntryBlockArgs = true) = 0; private: OpAsmPrinter(const OpAsmPrinter &) = delete; @@ -170,6 +171,9 @@ public: /// This parses... a comma! virtual bool parseComma() = 0; + /// This parses an equal(=) token! + virtual bool parseEqual() = 0; + /// Parse a type. virtual bool parseType(Type &result) = 0; @@ -203,9 +207,9 @@ public: } /// Parse a keyword. - bool parseKeyword(const char *keyword) { + bool parseKeyword(const char *keyword, const Twine &msg = "") { if (parseOptionalKeyword(keyword)) - return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); + return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg); return false; } @@ -315,6 +319,10 @@ public: /// operation's block lists after the operation is created. 
virtual bool parseBlockList() = 0; + /// Parses an argument for the entry block of the next block list to be + /// parsed. + virtual bool parseBlockListEntryBlockArgument(Type argType) = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 80cd21362ce..871e78bbb24 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -75,15 +75,14 @@ private: }; /// Subclasses of IROperandOwner can be the owner of an IROperand. In practice -/// this is the common base between Instruction and Instruction. +/// this is the common base between Instructions. class IROperandOwner { public: enum class Kind { OperationInst, - ForInst, /// These enums define ranges used for classof implementations. - INST_LAST = ForInst, + INST_LAST = OperationInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index e0cf3039f07..f3d9b9fe9fd 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -27,11 +27,12 @@ #include "mlir/Support/LLVM.h" namespace mlir { - class AffineMap; -class ForInst; +class AffineForOp; +template class ConstOpPointer; class Function; class FuncBuilder; +template class OpPointer; // Values that can be used to signal success/failure. This can be implicitly // converted to/from boolean values, with false representing success and true @@ -44,51 +45,54 @@ struct LLVM_NODISCARD UtilResult { /// Unrolls this for instruction completely if the trip count is known to be /// constant. Returns false otherwise. -bool loopUnrollFull(ForInst *forInst); +bool loopUnrollFull(OpPointer forOp); /// Unrolls this for instruction by the specified unroll factor. Returns false /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. -bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor); +bool loopUnrollByFactor(OpPointer forOp, uint64_t unrollFactor); /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. -bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor); +bool loopUnrollUpToFactor(OpPointer forOp, uint64_t unrollFactor); /// Unrolls and jams this loop by the specified factor. Returns true if the loop /// is successfully unroll-jammed. -bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor); +bool loopUnrollJamByFactor(OpPointer forOp, + uint64_t unrollJamFactor); /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant), whichever is lower. -bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor); +bool loopUnrollJamUpToFactor(OpPointer forOp, + uint64_t unrollJamFactor); -/// Promotes the loop body of a ForInst to its containing block if the ForInst -/// was known to have a single iteration. Returns false otherwise. -bool promoteIfSingleIteration(ForInst *forInst); +/// Promotes the loop body of a AffineForOp to its containing block if the +/// AffineForOp was known to have a single iteration. Returns false otherwise. 
+bool promoteIfSingleIteration(OpPointer forOp); -/// Promotes all single iteration ForInst's in the Function, i.e., moves +/// Promotes all single iteration AffineForOp's in the Function, i.e., moves /// their body into the containing Block. void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop /// with the specified unroll factor. -AffineMap getCleanupLoopLowerBound(const ForInst &forInst, +AffineMap getCleanupLoopLowerBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); /// Returns the upper bound of an unrolled loop when unrolling with /// the specified trip count, stride, and unroll factor. -AffineMap getUnrolledLoopUpperBound(const ForInst &forInst, +AffineMap getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); /// Skew the instructions in the body of a 'for' instruction with the specified /// instruction-wise shifts. The shifts are with respect to the original /// execution order, and are multiplied by the loop 'step' before being applied. -UtilResult instBodySkew(ForInst *forInst, ArrayRef shifts, +UtilResult instBodySkew(OpPointer forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. -UtilResult tileCodeGen(ArrayRef band, ArrayRef tileSizes); +UtilResult tileCodeGen(MutableArrayRef> band, + ArrayRef tileSizes); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 714086f22a7..3269ac1fdc5 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -27,7 +27,8 @@ namespace mlir { -class ForInst; +class AffineForOp; +template class ConstOpPointer; class FunctionPass; class ModulePass; @@ -57,9 +58,10 @@ FunctionPass *createMaterializeVectorsPass(); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -FunctionPass *createLoopUnrollPass( - int unrollFactor = -1, int unrollFull = -1, - const std::function &getUnrollFactor = nullptr); +FunctionPass * +createLoopUnrollPass(int unrollFactor = -1, int unrollFull = -1, + const std::function)> + &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 5c7260c9a58..169633cc106 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -32,7 +32,7 @@ namespace mlir { -class ForInst; +class AffineForOp; class FuncBuilder; class Location; class Module; @@ -115,7 +115,7 @@ void createAffineComputationSlice( /// Folds the lower and upper bounds of a 'for' inst to constants if possible. /// Returns false if the folding happens for at least one bound, true otherwise. -bool constantFoldBounds(ForInst *forInst); +bool constantFoldBounds(OpPointer forInst); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". 
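Taken together, the declarations above move the loop utilities from ForInst * to OpPointer<AffineForOp> / ConstOpPointer<AffineForOp>. A minimal usage sketch (illustrative only; `opInst` is an assumed OperationInst *):

  if (auto forOp = opInst->dyn_cast<AffineForOp>()) {
    // OpPointer converts implicitly to ConstOpPointer (added in
    // OpDefinition.h above), so the const analysis entry points accept it.
    auto tripCount = getConstantTripCount(forOp);
    if (tripCount && *tripCount <= 4)
      loopUnrollFull(forOp);
  }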
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 5b29467fc44..f1693c8e449 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -17,7 +17,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" using namespace mlir; @@ -27,7 +30,445 @@ using namespace mlir; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations(); + addOperations(); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step) { + assert((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Reserve a block list for the body. + result->reserveBlockLists(/*numReserved=*/1); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +bool AffineForOp::verify() const { + const auto &bodyBlockList = getInstruction()->getBlockList(0); + + // The body block list must contain a single basic block. + if (bodyBlockList.empty() || + std::next(bodyBlockList.begin()) != bodyBlockList.end()) + return emitOpError("expected body block list to have a single block"); + + // Check that the body defines as single block argument for the induction + // variable. + const auto *body = getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return emitOpError("expected body to have a single index argument for the " + "induction variable"); + + // TODO: check that loop bounds are properly formed. + return false; +} + +/// Parse a for operation loop bounds. +static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. 
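+  // For example, a bare SSA bound as in `for %i = %M to %N` is stored as the
+  // single-symbol identity map ()[s0] -> (s0) with that value as its only
+  // operand (illustrative; see the getSymbolIdentityMap() call below).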
+ SmallVector boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return true; + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return true; + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. + AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return false; + } + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(), + result->attributes)) + return true; + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return true; + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. + if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(p->getNameLoc(), + "lower loop bound affine map with multiple results " + "requires 'max' prefix"); + } + return p->emitError(p->getNameLoc(), + "upper loop bound affine map with multiple results " + "requires 'min' prefix"); + } + return false; + } + + // Parse custom assembly form. + if (auto integerAttr = boundAttr.dyn_cast()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return false; + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + // Parse the induction variable followed by '='. + if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) || + parser->parseEqual()) + return true; + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return true; + + // Parse the optional loop step, we default to 1 if one is not present. 
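+  // For example, `for %i = 1 to %N step 2 { ... }` yields a step attribute of
+  // 2, while `for %i = 1 to %N { ... }` (no `step` keyword) defaults to 1.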
+ if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc; + IntegerAttr stepAttr; + if (parser->getCurrentLocation(&stepLoc) || + parser->parseAttribute(stepAttr, builder.getIndexType(), + getStepAttrName().data(), result->attributes)) + return true; + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body block list. + result->reserveBlockLists(/*numReserved=*/1); + if (parser->parseBlockList()) + return true; + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); + return false; +} + +static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { + AffineMap map = bound.getMap(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast()) { + p->printOperand(bound.getOperand(0)); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + *p << prefix << ' '; + } + + // Print the map and its operands. + p->printAffineMap(map); + printDimAndSymbolList(bound.operand_begin(), bound.operand_end(), + map.getNumDims(), p); +} + +void AffineForOp::print(OpAsmPrinter *p) const { + *p << "for "; + p->printOperand(getBody()->getArgument(0)); + *p << " = "; + printBound(getLowerBound(), "max", p); + *p << " to "; + printBound(getUpperBound(), "min", p); + + if (getStep() != 1) + *p << " step " << getStep(); + p->printBlockList(getInstruction()->getBlockList(0), + /*printEntryBlockArgs=*/false); +} + +Block *AffineForOp::createBody() { + auto &bodyBlockList = getBlockList(); + assert(bodyBlockList.empty() && "expected no existing body blocks"); + + // Create a new block for the body, and add an argument for the induction + // variable. 
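+  // This argument is the loop induction variable: getInductionVar() returns
+  // it, and getForInductionVarOwner() maps it back to this AffineForOp.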
+ Block *body = new Block(); + body->addArgument(IndexType::get(getInstruction()->getContext())); + bodyBlockList.push_back(body); + return body; +} + +const AffineBound AffineForOp::getLowerBound() const { + auto lbMap = getLowerBoundMap(); + return AffineBound(ConstOpPointer(*this), 0, + lbMap.getNumInputs(), lbMap); +} + +const AffineBound AffineForOp::getUpperBound() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + return AffineBound(ConstOpPointer(*this), lbMap.getNumInputs(), + getNumOperands(), ubMap); +} + +void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(lbOperands.begin(), lbOperands.end()); + + auto ubOperands = getUpperBoundOperands(); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBound(ArrayRef ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(getLowerBoundOperands()); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setLowerBoundMap(AffineMap map) { + auto lbMap = getLowerBoundMap(); + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)lbMap; + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBoundMap(AffineMap map) { + auto ubMap = getUpperBoundMap(); + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)ubMap; + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +bool AffineForOp::hasConstantLowerBound() const { + return getLowerBoundMap().isSingleConstant(); +} + +bool AffineForOp::hasConstantUpperBound() const { + return getUpperBoundMap().isSingleConstant(); +} + +int64_t AffineForOp::getConstantLowerBound() const { + return getLowerBoundMap().getSingleConstantResult(); +} + +int64_t AffineForOp::getConstantUpperBound() const { + return getUpperBoundMap().getSingleConstantResult(); +} + +void AffineForOp::setConstantLowerBound(int64_t value) { + setLowerBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +void AffineForOp::setConstantUpperBound(int64_t value) { + setUpperBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +AffineForOp::const_operand_range 
AffineForOp::getUpperBoundOperands() const { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool AffineForOp::matchingBoundOperandList() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. + if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +void AffineForOp::walkOps(std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker w(callback); + w.walk(getInstruction()); +} + +void AffineForOp::walkOpsPostOrder( + std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker v(callback); + v.walkPostOrder(getInstruction()); +} + +/// Returns the induction variable for this loop. +Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast(val); + if (!ivArg || !ivArg->getOwner()) + return OpPointer(); + auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); + if (!containingInst) + return OpPointer(); + return cast(containingInst)->dyn_cast(); +} +ConstOpPointer mlir::getForInductionVarOwner(const Value *val) { + auto nonConstOwner = getForInductionVarOwner(const_cast(val)); + return ConstOpPointer(nonConstOwner); +} + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. 
+SmallVector mlir::extractForInductionVars( + MutableArrayRef> forInsts) { + SmallVector results; + for (auto forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 0153546a4c6..d2366f1ce81 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -21,12 +21,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" @@ -519,7 +521,7 @@ void mlir::getReachableAffineApplyOps( State &state = worklist.back(); auto *opInst = state.value->getDefiningInst(); // Note: getDefiningInst will return nullptr if the operand is not an - // OperationInst (i.e. ForInst), which is a terminator for the search. + // OperationInst (i.e. AffineForOp), which is a terminator for the search. if (opInst == nullptr || !opInst->isa()) { worklist.pop_back(); continue; @@ -546,21 +548,21 @@ void mlir::getReachableAffineApplyOps( } // Builds a system of constraints with dimensional identifiers corresponding to -// the loop IVs of the forInsts appearing in that order. Any symbols founds in +// the loop IVs of the forOps appearing in that order. Any symbols founds in // the bound operands are added as symbols in the system. Returns false for the // yet unimplemented cases. // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or // stride information in FlatAffineConstraints. (For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef expr, int64_t stride) -bool mlir::getIndexSet(ArrayRef forInsts, +bool mlir::getIndexSet(MutableArrayRef> forOps, FlatAffineConstraints *domain) { - auto indices = extractForInductionVars(forInsts); + auto indices = extractForInductionVars(forOps); // Reset while associated Values in 'indices' to the domain. - domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); - for (auto *forInst : forInsts) { - // Add constraints from forInst's bounds. - if (!domain->addForInstDomain(*forInst)) + domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); + for (auto forOp : forOps) { + // Add constraints from forOp's bounds. + if (!domain->addAffineForOpDomain(forOp)) return false; } return true; @@ -576,7 +578,7 @@ static bool getInstIndexSet(const Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. 
- SmallVector loops; + SmallVector, 4> loops; getLoopIVs(*inst, &loops); return getIndexSet(loops, indexSet); } @@ -998,9 +1000,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess, return block; } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); - auto *forInst = getForInductionVarOwner(commonForValue); - assert(forInst && "commonForValue was not an induction variable"); - return forInst->getBody(); + auto forOp = getForInductionVarOwner(commonForValue); + assert(forOp && "commonForValue was not an induction variable"); + return forOp->getBody(); } // Returns true if the ancestor operation instruction of 'srcAccess' appears @@ -1195,7 +1197,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // until operands of the AffineValueMap are loop IVs or symbols. // *) Build iteration domain constraints for each access. Iteration domain // constraints are pairs of inequality contraints representing the -// upper/lower loop bounds for each ForInst in the loop nest associated +// upper/lower loop bounds for each AffineForOp in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map // Values from access functions and iteration domains to their position diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 5e7f8e3243c..c794899d3e1 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineStructures.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -1247,22 +1248,23 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { numSymbols = newSymbolCount; } -bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { +bool FlatAffineConstraints::addAffineForOpDomain( + ConstOpPointer forOp) { unsigned pos; // Pre-condition for this method. - if (!findId(*forInst.getInductionVar(), &pos)) { + if (!findId(*forOp->getInductionVar(), &pos)) { assert(0 && "Value not found"); return false; } - if (forInst.getStep() != 1) + if (forOp->getStep() != 1) LLVM_DEBUG(llvm::dbgs() << "Domain conservative: non-unit stride not handled\n"); // Adds a lower or upper bound when the bounds aren't constant. auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = lower ? forInst.getLowerBoundOperands() - : forInst.getUpperBoundOperands(); + auto operands = + lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); for (const auto &operand : operands) { unsigned loc; if (!findId(*operand, &loc)) { @@ -1291,7 +1293,7 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { } auto boundMap = - lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap(); + lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); FlatAffineConstraints localVarCst; std::vector> flatExprs; @@ -1321,16 +1323,16 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { return true; }; - if (forInst.hasConstantLowerBound()) { - addConstantLowerBound(pos, forInst.getConstantLowerBound()); + if (forOp->hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp->getConstantLowerBound()); } else { // Non-constant lower bound case. 
if (!addLowerOrUpperBound(/*lower=*/true)) return false; } - if (forInst.hasConstantUpperBound()) { - addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1); + if (forOp->hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); return true; } // Non-constant upper bound case. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 7d88a3d9b9f..249776d42c9 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -43,27 +43,27 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { +AffineExpr mlir::getTripCountExpr(ConstOpPointer forOp) { // upper_bound - lower_bound int64_t loopSpan; - int64_t step = forInst.getStep(); - auto *context = forInst.getContext(); + int64_t step = forOp->getStep(); + auto *context = forOp->getInstruction()->getContext(); - if (forInst.hasConstantBounds()) { - int64_t lb = forInst.getConstantLowerBound(); - int64_t ub = forInst.getConstantUpperBound(); + if (forOp->hasConstantBounds()) { + int64_t lb = forOp->getConstantLowerBound(); + int64_t ub = forOp->getConstantUpperBound(); loopSpan = ub - lb; } else { - auto lbMap = forInst.getLowerBoundMap(); - auto ubMap = forInst.getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // TODO(bondhugula): handle max/min of multiple expressions. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return nullptr; // TODO(bondhugula): handle bounds with different operands. // Bounds have different operands, unhandled for now. - if (!forInst.matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return nullptr; // ub_expr - lb_expr @@ -89,8 +89,9 @@ AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { /// Returns the trip count of the loop if it's a constant, None otherwise. This /// method uses affine expression analysis (in turn using getTripCount) and is /// able to determine constant trip count in non-trivial cases. -llvm::Optional mlir::getConstantTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +llvm::Optional +mlir::getConstantTripCount(ConstOpPointer forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return None; @@ -104,8 +105,8 @@ llvm::Optional mlir::getConstantTripCount(const ForInst &forInst) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. 
-uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return 1; @@ -126,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isForInductionVar(&iv) && "iv must be a ForInst"); + assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); @@ -163,7 +164,7 @@ mlir::getInvariantAccesses(const Value &iv, } /// Given: -/// 1. an induction variable `iv` of type ForInst; +/// 1. an induction variable `iv` of type AffineForOp; /// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; /// 3. the index of the `fastestVaryingDim` along which to check; /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access @@ -231,17 +232,18 @@ static bool isVectorTransferReadOrWrite(const Instruction &inst) { } using VectorizableInstFun = - std::function; + std::function, const OperationInst &)>; -static bool isVectorizableLoopWithCond(const ForInst &loop, +static bool isVectorizableLoopWithCond(ConstOpPointer loop, VectorizableInstFun isVectorizableInst) { - if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) { + auto *forInst = const_cast(loop->getInstruction()); + if (!matcher::isParallelLoop(*forInst) && + !matcher::isReductionLoop(*forInst)) { return false; } // No vectorization across conditionals for now. auto conditionals = matcher::If(); - auto *forInst = const_cast(&loop); SmallVector conditionalsMatched; conditionals.match(forInst, &conditionalsMatched); if (!conditionalsMatched.empty()) { @@ -251,7 +253,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, // No vectorization across unknown regions. auto regions = matcher::Op([](const Instruction &inst) -> bool { auto &opInst = cast(inst); - return opInst.getNumBlockLists() != 0 && !opInst.isa(); + return opInst.getNumBlockLists() != 0 && + !(opInst.isa() || opInst.isa()); }); SmallVector regionsMatched; regions.match(forInst, ®ionsMatched); @@ -288,23 +291,25 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - const ForInst &loop, unsigned fastestVaryingDim) { - VectorizableInstFun fun( - [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); - return load ? isContiguousAccess(*loop.getInductionVar(), *load, - fastestVaryingDim) - : isContiguousAccess(*loop.getInductionVar(), *store, - fastestVaryingDim); - }); + ConstOpPointer loop, unsigned fastestVaryingDim) { + VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer loop, + const OperationInst &op) { + auto load = op.dyn_cast(); + auto store = op.dyn_cast(); + return load ? 
isContiguousAccess(*loop->getInductionVar(), *load, + fastestVaryingDim) + : isContiguousAccess(*loop->getInductionVar(), *store, + fastestVaryingDim); + }); return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(const ForInst &loop) { +bool mlir::isVectorizableLoop(ConstOpPointer loop) { VectorizableInstFun fun( // TODO: implement me - [](const ForInst &loop, const OperationInst &op) { return true; }); + [](ConstOpPointer loop, const OperationInst &op) { + return true; + }); return isVectorizableLoopWithCond(loop, fun); } @@ -313,9 +318,9 @@ bool mlir::isVectorizableLoop(const ForInst &loop) { /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isInstwiseShiftValid(const ForInst &forInst, +bool mlir::isInstwiseShiftValid(ConstOpPointer forOp, ArrayRef shifts) { - auto *forBody = forInst.getBody(); + auto *forBody = forOp->getBody(); assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &inst : *forBody) { @@ -325,7 +330,7 @@ bool mlir::isInstwiseShiftValid(const ForInst &forInst, for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { const Value *result = opInst->getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor instruction doesn't lie in the block of forInst, + // If an ancestor instruction doesn't lie in the block of forOp, // there is no shift to check. This is a naive way. If performance // becomes an issue, a map can be used to store 'shifts' - to look up // the shift for a instruction in constant time. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 46bf5ad0b97..214b4ce403c 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst, } } +static bool isAffineForOp(const Instruction &inst) { + return cast(inst).isa(); +} + static bool isAffineIfOp(const Instruction &inst) { return isa(inst) && cast(inst).isa(); @@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef nested) { } NestedPattern For(NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } NestedPattern For(ArrayRef nested) { - return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::For, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { - const auto *loop = cast(&inst); - return (void *)loop || true; // loop->isParallel(); + auto loop = cast(inst).cast(); + return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. 
bool isReductionLoop(const Instruction &inst) { - const auto *loop = cast(&inst); - return (void *)loop || true; // loop->isReduction(); + auto loop = cast(inst).cast(); + return loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index d16a7fcb1b3..4025af936f3 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" @@ -52,7 +53,16 @@ void mlir::getForwardSlice(Instruction *inst, return; } - if (auto *opInst = dyn_cast(inst)) { + auto *opInst = cast(inst); + if (auto forOp = opInst->dyn_cast()) { + for (auto &u : forOp->getInductionVar()->getUses()) { + auto *ownerInst = u.getOwner(); + if (forwardSlice->count(ownerInst) == 0) { + getForwardSlice(ownerInst, forwardSlice, filter, + /*topLevel=*/false); + } + } + } else { assert(opInst->getNumResults() <= 1 && "NYI: multiple results"); if (opInst->getNumResults() > 0) { for (auto &u : opInst->getResult(0)->getUses()) { @@ -63,16 +73,6 @@ void mlir::getForwardSlice(Instruction *inst, } } } - } else if (auto *forInst = dyn_cast(inst)) { - for (auto &u : forInst->getInductionVar()->getUses()) { - auto *ownerInst = u.getOwner(); - if (forwardSlice->count(ownerInst) == 0) { - getForwardSlice(ownerInst, forwardSlice, filter, - /*topLevel=*/false); - } - } - } else { - assert(false && "NYI slicing case"); } // At the top level we reverse to get back the actual topological order. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e77d4d9084..4b8afd9a620 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -38,15 +38,17 @@ using namespace mlir; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, - SmallVectorImpl *loops) { + SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); - ForInst *currForInst; + OpPointer currAffineForOp; // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. - while (currInst && ((currForInst = dyn_cast(currInst)) || - cast(currInst)->isa())) { - if (currForInst) - loops->push_back(currForInst); + while (currInst && + ((currAffineForOp = + cast(currInst)->dyn_cast()) || + cast(currInst)->isa())) { + if (currAffineForOp) + loops->push_back(currAffineForOp); currInst = currInst->getParentInst(); } std::reverse(loops->begin(), loops->end()); @@ -148,7 +150,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, if (rank == 0) { // A rank 0 memref has a 0-d region. - SmallVector ivs; + SmallVector, 4> ivs; getLoopIVs(*opInst, &ivs); SmallVector regionSymbols = extractForInductionVars(ivs); @@ -174,12 +176,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. 
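      // (Editorial note, not part of the original patch.) Concretely, for a
      // loop with constant bounds such as `for %i = 0 to 128`,
      // addAffineForOpDomain contributes two rows on %i's column of
      // 'regionCst':
      //   %i >= 0           from addConstantLowerBound(pos, 0)
      //   -%i + 127 >= 0    from addConstantUpperBound(pos, 128 - 1), since
      //                     the loop's upper bound is exclusive
      // Non-constant bounds instead go through the addLowerOrUpperBound
      // lambda shown earlier, which flattens the bound map's expressions into
      // inequality rows over the bound operands.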
for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { + if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { // Note that regionCst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!regionCst->addForInstDomain(*loop)) + if (!regionCst->addAffineForOpDomain(loop)) return false; } else { // Has to be a valid symbol. @@ -203,14 +205,14 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. - SmallVector outerIVs; + SmallVector, 4> outerIVs; getLoopIVs(*opInst, &outerIVs); assert(loopDepth <= outerIVs.size() && "invalid loop depth"); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { - ForInst *iv; + OpPointer iv; if ((iv = getForInductionVarOwner(operand)) && - std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) { + llvm::is_contained(outerIVs, iv) == false) { regionCst->projectOut(operand); } } @@ -357,8 +359,10 @@ static Instruction *getInstAtPosition(ArrayRef positions, } if (level == positions.size() - 1) return &inst; - if (auto *childForInst = dyn_cast(&inst)) - return getInstAtPosition(positions, level + 1, childForInst->getBody()); + if (auto childAffineForOp = + cast(inst).dyn_cast()) + return getInstAtPosition(positions, level + 1, + childAffineForOp->getBody()); for (auto &blockList : cast(&inst)->getBlockLists()) { for (auto &b : blockList) @@ -385,12 +389,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, return false; } // Get loop nest surrounding src operation. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcAccess.opInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstAccess.opInst, &dstLoopIVs); unsigned numDstLoopIVs = dstLoopIVs.size(); if (dstLoopDepth > numDstLoopIVs) { @@ -437,38 +441,41 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // solution. // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project // out loop IVs we don't care about and produce smaller slice. -ForInst *mlir::insertBackwardComputationSlice( +OpPointer mlir::insertBackwardComputationSlice( OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstOpInst, &dstLoopIVs); unsigned dstLoopIVsSize = dstLoopIVs.size(); if (dstLoopDepth > dstLoopIVsSize) { dstOpInst->emitError("invalid destination loop depth"); - return nullptr; + return OpPointer(); } // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. 
- findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions); + findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(), + &positions); // Clone src loop nest and insert it a the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. - auto *dstForInst = dstLoopIVs[dstLoopDepth - 1]; - FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); - auto *sliceLoopNest = cast(b.clone(*srcLoopIVs[0])); + auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; + FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin()); + auto sliceLoopNest = + cast(b.clone(*srcLoopIVs[0]->getInstruction())) + ->cast(); Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceInst'. - SmallVector sliceSurroundingLoops; + SmallVector, 4> sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); // Sanity check. @@ -481,11 +488,11 @@ ForInst *mlir::insertBackwardComputationSlice( // Update loop bounds for loops in 'sliceLoopNest'. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; + auto forOp = sliceSurroundingLoops[dstLoopDepth + i]; if (AffineMap lbMap = sliceState->lbs[i]) - forInst->setLowerBound(sliceState->lbOperands[i], lbMap); + forOp->setLowerBound(sliceState->lbOperands[i], lbMap); if (AffineMap ubMap = sliceState->ubs[i]) - forInst->setUpperBound(sliceState->ubOperands[i], ubMap); + forOp->setUpperBound(sliceState->ubOperands[i], ubMap); } return sliceLoopNest; } @@ -520,7 +527,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { const Instruction *currInst = &stmt; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { - if (isa(currInst)) + if (cast(currInst)->isa()) depth++; } return depth; @@ -530,14 +537,14 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { /// where each lists loops from outer-most to inner-most in loop nest. unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, const Instruction &B) { - SmallVector loopsA, loopsB; + SmallVector, 4> loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i] != loopsB[i]) + if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction()) break; ++numCommonLoops; } @@ -571,13 +578,14 @@ static Optional getRegionSize(const MemRefRegion ®ion) { return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); } -Optional mlir::getMemoryFootprintBytes(const ForInst &forInst, - int memorySpace) { +Optional +mlir::getMemoryFootprintBytes(ConstOpPointer forOp, + int memorySpace) { std::vector> regions; // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast(&forInst)->walkOps([&](OperationInst *opInst) { + const_cast(*forOp).walkOps([&](OperationInst *opInst) { if (!opInst->isa() && !opInst->isa()) { // Neither load nor a store op. 
return; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 125020e92a3..4865cb03bb4 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -16,10 +16,12 @@ // ============================================================================= #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -105,7 +107,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, static AffineMap makePermutationMap( MLIRContext *context, llvm::iterator_range indices, - const DenseMap &enclosingLoopToVectorDim) { + const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster(), indices); @@ -113,8 +115,9 @@ static AffineMap makePermutationMap( getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); - auto invariants = - getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices); + auto invariants = getInvariantAccesses( + *cast(kvp.first)->cast()->getInductionVar(), + unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { @@ -139,30 +142,30 @@ static AffineMap makePermutationMap( /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. template -static SetVector getParentsOfType(Instruction *inst) { - SetVector res; +static SetVector getParentsOfType(Instruction *inst) { + SetVector res; auto *current = inst; while (auto *parent = current->getParentInst()) { - auto *typedParent = dyn_cast(parent); - if (typedParent) { - assert(res.count(typedParent) == 0 && "Already inserted"); - res.insert(typedParent); + if (auto typedParent = + cast(parent)->template dyn_cast()) { + assert(res.count(cast(parent)) == 0 && "Already inserted"); + res.insert(cast(parent)); } current = parent; } return res; } -/// Returns the enclosing ForInst, from closest to farthest. -static SetVector getEnclosingforInsts(Instruction *inst) { - return getParentsOfType(inst); +/// Returns the enclosing AffineForOp, from closest to farthest. 
+static SetVector getEnclosingforOps(Instruction *inst) { + return getParentsOfType(inst); } -AffineMap -mlir::makePermutationMap(OperationInst *opInst, - const DenseMap &loopToVectorDim) { - DenseMap enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingforInsts(opInst); +AffineMap mlir::makePermutationMap( + OperationInst *opInst, + const DenseMap &loopToVectorDim) { + DenseMap enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforOps(opInst); for (auto *forInst : enclosingLoops) { auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 474eeb2a28e..a69831053ad 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -72,7 +72,6 @@ public: bool verify(); bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); - bool verifyForInst(const ForInst &forInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -175,10 +174,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyOperation(cast(inst))) return true; break; - case Instruction::Kind::For: - if (verifyForInst(cast(inst))) - return true; - break; } } @@ -240,11 +235,6 @@ bool FuncVerifier::verifyOperation(const OperationInst &op) { return false; } -bool FuncVerifier::verifyForInst(const ForInst &forInst) { - // TODO: check that loop bounds are properly formed. - return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -262,10 +252,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { return true; break; } - case Instruction::Kind::For: - if (verifyDominance(*cast(inst).getBody())) - return true; - break; } } return false; diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index dc85c5ed682..f4d5d36d25b 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -21,12 +21,14 @@ #include "llvm/Support/raw_ostream.h" #include "mlir-c/Core.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Value.h" #include "mlir/StandardOps/StandardOps.h" @@ -133,8 +135,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { inst->print(os); return; } - if (auto *forInst = getForInductionVarOwner(&v)) { - forInst->print(os); + if (auto forInst = getForInductionVarOwner(&v)) { + forInst->getInstruction()->print(os); } else { os << "unknown_ssa_value"; } @@ -300,7 +302,9 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { exprs[1]->getDefiningInst()->cast()->getValue(); auto step = exprs[2]->getDefiningInst()->cast()->getValue(); - res = builder->createFor(location, lb, ub, step)->getInductionVar(); + auto forOp = builder->create(location, lb, ub, step); + forOp->createBody(); + res = forOp->getInductionVar(); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edce..0fb18fa0004 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -130,21 +130,8 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } - // Return true if 
this map could be printed using the custom assembly form. - static bool hasCustomForm(AffineMap boundMap) { - if (boundMap.isSingleConstant()) - return true; - - // Check if the affine map is single dim id or single symbol identity - - // (i)->(i) or ()[s]->(i) - return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 && - (boundMap.getResult(0).isa() || - boundMap.getResult(0).isa()); - } - // Visit functions. void visitInstruction(const Instruction *inst); - void visitForInst(const ForInst *forInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,16 +183,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitForInst(const ForInst *forInst) { - AffineMap lbMap = forInst->getLowerBoundMap(); - if (!hasCustomForm(lbMap)) - recordAffineMapReference(lbMap); - - AffineMap ubMap = forInst->getUpperBoundMap(); - if (!hasCustomForm(ubMap)) - recordAffineMapReference(ubMap); -} - void ModuleState::visitOperationInst(const OperationInst *op) { // Visit all the types used in the operation. for (auto *operand : op->getOperands()) @@ -220,8 +197,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::For: - return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: return visitOperationInst(cast(inst)); } @@ -1069,7 +1044,6 @@ public: // Methods to print instructions. void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,10 +1091,8 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } - void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { + void printBlockList(const BlockList &blocks, + bool printEntryBlockArgs) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); @@ -1132,10 +1104,6 @@ public: os.indent(currentIndent) << "}"; } - // Print if and loop bounds. - void printDimAndSymbolList(ArrayRef ops, unsigned numDims); - void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; @@ -1205,10 +1173,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { numberValuesInBlock(block); break; } - case Instruction::Kind::For: - // Recursively number the stuff in the body. 
- numberValuesInBlock(*cast(&inst)->getBody()); - break; } } } @@ -1404,8 +1368,6 @@ void FunctionPrinter::print(const Instruction *inst) { switch (inst->getKind()) { case Instruction::Kind::OperationInst: return print(cast(inst)); - case Instruction::Kind::For: - return print(cast(inst)); } } @@ -1415,24 +1377,6 @@ void FunctionPrinter::print(const OperationInst *inst) { printTrailingLocation(inst->getLoc()); } -void FunctionPrinter::print(const ForInst *inst) { - os.indent(currentIndent) << "for "; - printOperand(inst->getInductionVar()); - os << " = "; - printBound(inst->getLowerBound(), "max"); - os << " to "; - printBound(inst->getUpperBound(), "min"); - - if (inst->getStep() != 1) - os << " step " << inst->getStep(); - - printTrailingLocation(inst->getLoc()); - - os << " {\n"; - print(inst->getBody(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; @@ -1560,62 +1504,6 @@ void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, os << ')'; } -void FunctionPrinter::printDimAndSymbolList(ArrayRef ops, - unsigned numDims) { - auto printComma = [&]() { os << ", "; }; - os << '('; - interleave( - ops.begin(), ops.begin() + numDims, - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ')'; - - if (numDims < ops.size()) { - os << '['; - interleave( - ops.begin() + numDims, ops.end(), - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ']'; - } -} - -void FunctionPrinter::printBound(AffineBound bound, const char *prefix) { - AffineMap map = bound.getMap(); - - // Check if this bound should be printed using custom assembly form. - // The decision to restrict printing custom assembly form to trivial cases - // comes from the will to roundtrip MLIR binary -> text -> binary in a - // lossless way. - // Therefore, custom assembly form parsing and printing is only supported for - // zero-operand constant maps and single symbol operand identity maps. - if (map.getNumResults() == 1) { - AffineExpr expr = map.getResult(0); - - // Print constant bound. - if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { - if (auto constExpr = expr.dyn_cast()) { - os << constExpr.getValue(); - return; - } - } - - // Print bound that consists of a single SSA symbol if the map is over a - // single symbol. - if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { - if (auto symExpr = expr.dyn_cast()) { - printOperand(bound.getOperand(0)); - return; - } - } - } else { - // Map has multiple results. Print 'min' or 'max' prefix. - os << prefix << ' '; - } - - // Print the map and its operands. - printAffineMapReference(map); - printDimAndSymbolList(bound.getInstOperands(), map.getNumDims()); -} - // Prints function with initialized module state. 
void ModulePrinter::print(const Function *fn) { FunctionPrinter(fn, *this).print(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index ffeb4e0317f..68fbef2d27a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -312,19 +312,3 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { block->getInstructions().insert(insertPoint, op); return op; } - -ForInst *FuncBuilder::createFor(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { - auto *inst = - ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getInstructions().insert(insertPoint, inst); - return inst; -} - -ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, - int64_t step) { - auto lbMap = AffineMap::getConstantMap(lb, context); - auto ubMap = AffineMap::getConstantMap(ub, context); - return createFor(location, {}, lbMap, {}, ubMap, step); -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 8d43e3a783d..03f1a2702c9 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -143,9 +143,6 @@ void Instruction::destroy() { case Kind::OperationInst: cast(this)->destroy(); break; - case Kind::For: - cast(this)->destroy(); - break; } } @@ -209,8 +206,6 @@ unsigned Instruction::getNumOperands() const { switch (getKind()) { case Kind::OperationInst: return cast(this)->getNumOperands(); - case Kind::For: - return cast(this)->getNumOperands(); } } @@ -218,8 +213,6 @@ MutableArrayRef Instruction::getInstOperands() { switch (getKind()) { case Kind::OperationInst: return cast(this)->getInstOperands(); - case Kind::For: - return cast(this)->getInstOperands(); } } @@ -349,10 +342,6 @@ void Instruction::dropAllReferences() { op.drop(); switch (getKind()) { - case Kind::For: - // Make sure to drop references held by instructions within the body. - cast(this)->getBody()->dropAllReferences(); - break; case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -655,217 +644,6 @@ bool OperationInst::emitOpError(const Twine &message) const { return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); } -//===----------------------------------------------------------------------===// -// ForInst -//===----------------------------------------------------------------------===// - -ForInst *ForInst::create(Location location, ArrayRef lbOperands, - AffineMap lbMap, ArrayRef ubOperands, - AffineMap ubMap, int64_t step) { - assert((!lbMap && lbOperands.empty()) || - lbOperands.size() == lbMap.getNumInputs() && - "lower bound operand count does not match the affine map"); - assert((!ubMap && ubOperands.empty()) || - ubOperands.size() == ubMap.getNumInputs() && - "upper bound operand count does not match the affine map"); - assert(step > 0 && "step has to be a positive integer constant"); - - // Compute the byte size for the instruction and the operand storage. - unsigned numOperands = lbOperands.size() + ubOperands.size(); - auto byteSize = totalSizeToAlloc( - /*detail::OperandStorage*/ 1); - byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( - numOperands, /*resizable=*/true), - alignof(ForInst)); - void *rawMem = malloc(byteSize); - - // Initialize the OperationInst part of the instruction. 
- ForInst *inst = ::new (rawMem) ForInst(location, lbMap, ubMap, step); - new (&inst->getOperandStorage()) - detail::OperandStorage(numOperands, /*resizable=*/true); - - auto operands = inst->getInstOperands(); - unsigned i = 0; - for (unsigned e = lbOperands.size(); i != e; ++i) - new (&operands[i]) InstOperand(inst, lbOperands[i]); - - for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) - new (&operands[i]) InstOperand(inst, ubOperands[j]); - - return inst; -} - -ForInst::ForInst(Location location, AffineMap lbMap, AffineMap ubMap, - int64_t step) - : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap), - ubMap(ubMap), step(step) { - - // The body of a for inst always has one block. - auto *bodyEntry = new Block(); - body.push_back(bodyEntry); - - // Add an argument to the block for the induction variable. - bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext())); -} - -ForInst::~ForInst() { getOperandStorage().~OperandStorage(); } - -const AffineBound ForInst::getLowerBound() const { - return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); -} - -const AffineBound ForInst::getUpperBound() const { - return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); -} - -void ForInst::setLowerBound(ArrayRef lbOperands, AffineMap map) { - assert(lbOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(lbOperands.begin(), lbOperands.end()); - - auto ubOperands = getUpperBoundOperands(); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->lbMap = map; -} - -void ForInst::setUpperBound(ArrayRef ubOperands, AffineMap map) { - assert(ubOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(getLowerBoundOperands()); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->ubMap = map; -} - -void ForInst::setLowerBoundMap(AffineMap map) { - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->lbMap = map; -} - -void ForInst::setUpperBoundMap(AffineMap map) { - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->ubMap = map; -} - -bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } - -bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } - -int64_t ForInst::getConstantLowerBound() const { - return lbMap.getSingleConstantResult(); -} - -int64_t ForInst::getConstantUpperBound() const { - return ubMap.getSingleConstantResult(); -} - -void ForInst::setConstantLowerBound(int64_t value) { - setLowerBound({}, AffineMap::getConstantMap(value, getContext())); -} - -void ForInst::setConstantUpperBound(int64_t value) { - setUpperBound({}, AffineMap::getConstantMap(value, getContext())); -} - -ForInst::operand_range ForInst::getLowerBoundOperands() { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::const_operand_range ForInst::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::operand_range 
ForInst::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -ForInst::const_operand_range ForInst::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool ForInst::matchingBoundOperandList() const { - if (lbMap.getNumDims() != ubMap.getNumDims() || - lbMap.getNumSymbols() != ubMap.getNumSymbols()) - return false; - - unsigned numOperands = lbMap.getNumInputs(); - for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. - if (getOperand(i) != getOperand(numOperands + i)) - return false; - } - return true; -} - -void ForInst::walkOps(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(this); -} - -void ForInst::walkOpsPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(this); -} - -/// Returns the induction variable for this loop. -Value *ForInst::getInductionVar() { return getBody()->getArgument(0); } - -void ForInst::destroy() { - this->~ForInst(); - free(this); -} - -/// Returns if the provided value is the induction variable of a ForInst. -bool mlir::isForInductionVar(const Value *val) { - return getForInductionVarOwner(val) != nullptr; -} - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForInst *mlir::getForInductionVarOwner(Value *val) { - const BlockArgument *ivArg = dyn_cast(val); - if (!ivArg || !ivArg->getOwner()) - return nullptr; - return dyn_cast_or_null( - ivArg->getOwner()->getParent()->getContainingInst()); -} -const ForInst *mlir::getForInductionVarOwner(const Value *val) { - return getForInductionVarOwner(const_cast(val)); -} - -/// Extracts the induction variables from a list of ForInsts and returns them. -SmallVector -mlir::extractForInductionVars(ArrayRef forInsts) { - SmallVector results; - for (auto *forInst : forInsts) - results.push_back(forInst->getInductionVar()); - return results; -} //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -879,84 +657,59 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, MLIRContext *context) const { SmallVector operands; SmallVector successors; - if (auto *opInst = dyn_cast(this)) { - operands.reserve(getNumOperands() + opInst->getNumSuccessors()); - if (!opInst->isTerminator()) { - // Non-terminators just add all the operands. - for (auto *opValue : getOperands()) + auto *opInst = cast(this); + operands.reserve(getNumOperands() + opInst->getNumSuccessors()); + + if (!opInst->isTerminator()) { + // Non-terminators just add all the operands. + for (auto *opValue : getOperands()) + operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = opInst->getNumSuccessors() + ? 
opInst->getSuccessorOperandIndex(0) + : opInst->getNumOperands(); + auto InstOperands = opInst->getInstOperands(); + + unsigned i = 0; + for (; i != firstSuccOperand; ++i) + operands.push_back( + mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); + + successors.reserve(opInst->getNumSuccessors()); + for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) { + successors.push_back(mapper.lookupOrDefault( + const_cast(opInst->getSuccessor(succ)))); + + // Add sentinel to delineate successor operands. + operands.push_back(nullptr); + + // Remap the successors operands. + for (auto *operand : opInst->getSuccessorOperands(succ)) operands.push_back( - mapper.lookupOrDefault(const_cast(opValue))); - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = opInst->getNumSuccessors() - ? opInst->getSuccessorOperandIndex(0) - : opInst->getNumOperands(); - auto InstOperands = opInst->getInstOperands(); - - unsigned i = 0; - for (; i != firstSuccOperand; ++i) - operands.push_back( - mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); - - successors.reserve(opInst->getNumSuccessors()); - for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; - ++succ) { - successors.push_back(mapper.lookupOrDefault( - const_cast(opInst->getSuccessor(succ)))); - - // Add sentinel to delineate successor operands. - operands.push_back(nullptr); - - // Remap the successors operands. - for (auto *operand : opInst->getSuccessorOperands(succ)) - operands.push_back( - mapper.lookupOrDefault(const_cast(operand))); - } + mapper.lookupOrDefault(const_cast(operand))); } - - SmallVector resultTypes; - resultTypes.reserve(opInst->getNumResults()); - for (auto *result : opInst->getResults()) - resultTypes.push_back(result->getType()); - - unsigned numBlockLists = opInst->getNumBlockLists(); - auto *newOp = OperationInst::create( - getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), - successors, numBlockLists, opInst->hasResizableOperandsList(), context); - - // Clone the block lists. - for (unsigned i = 0; i != numBlockLists; ++i) - opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, - context); - - // Remember the mapping of any results. - for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) - mapper.map(opInst->getResult(i), newOp->getResult(i)); - return newOp; } - operands.reserve(getNumOperands()); - for (auto *opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + SmallVector resultTypes; + resultTypes.reserve(opInst->getNumResults()); + for (auto *result : opInst->getResults()) + resultTypes.push_back(result->getType()); - // Otherwise, this must be a ForInst. - auto *forInst = cast(this); - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + unsigned numBlockLists = opInst->getNumBlockLists(); + auto *newOp = OperationInst::create( + getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), + successors, numBlockLists, opInst->hasResizableOperandsList(), context); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, - forInst->getStep()); - - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); + // Clone the block lists. 
+ for (unsigned i = 0; i != numBlockLists; ++i) + opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context); - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - return newFor; + // Remember the mapping of any results. + for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) + mapper.map(opInst->getResult(i), newOp->getResult(i)); + return newOp; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7103eeb7389..a9c046dc7b1 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -64,8 +64,6 @@ MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { case Kind::OperationInst: return cast(this)->getContext(); - case Kind::ForInst: - return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f0c140166ed..a9c62767734 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2128,23 +2128,6 @@ public: parseSuccessors(SmallVectorImpl &destinations, SmallVectorImpl> &operands); - ParseResult - parseOptionalBlockArgList(SmallVectorImpl &results, - Block *owner); - - ParseResult parseOperationBlockList(SmallVectorImpl &results); - ParseResult parseBlockListBody(SmallVectorImpl &results); - ParseResult parseBlock(Block *&block); - ParseResult parseBlockBody(Block *block); - - /// Cleans up the memory for allocated blocks when a parser error occurs. - void cleanupInvalidBlocks(ArrayRef invalidBlocks) { - // Add the referenced blocks to the function so that they can be properly - // cleaned up when the function is destroyed. - for (auto *block : invalidBlocks) - function->push_back(block); - } - /// After the function is finished parsing, this function checks to see if /// there are any remaining issues. ParseResult finalizeFunction(SMLoc loc); @@ -2187,6 +2170,25 @@ public: // Block references. + ParseResult + parseOperationBlockList(SmallVectorImpl &results, + ArrayRef> entryArguments); + ParseResult parseBlockListBody(SmallVectorImpl &results); + ParseResult parseBlock(Block *&block); + ParseResult parseBlockBody(Block *block); + + ParseResult + parseOptionalBlockArgList(SmallVectorImpl &results, + Block *owner); + + /// Cleans up the memory for allocated blocks when a parser error occurs. + void cleanupInvalidBlocks(ArrayRef invalidBlocks) { + // Add the referenced blocks to the function so that they can be properly + // cleaned up when the function is destroyed. + for (auto *block : invalidBlocks) + function->push_back(block); + } + /// Get the block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. 
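(Editorial aside, not part of the patch.) Before the remaining parser hunks, a minimal sketch of what loop-nest duplication looks like after the Instruction::clone simplification above: cloning an affine for loop is now ordinary OperationInst cloning, with the body block list and the induction-variable block argument cloned along with the op, exactly as insertBackwardComputationSlice does earlier in this change. The helper name below is illustrative; the accessors are the ones used in the diffs above.

// Clones the loop nest rooted at 'root' at the start of 'destBody' and
// returns an op pointer to the cloned root loop.
static OpPointer<AffineForOp> cloneLoopNestAtStart(OpPointer<AffineForOp> root,
                                                   Block *destBody) {
  FuncBuilder b(destBody, destBody->begin());
  // The clone carries its block lists (and hence the induction variable).
  return cast<OperationInst>(b.clone(*root->getInstruction()))
      ->cast<AffineForOp>();
}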
@@ -2201,13 +2203,6 @@ public: OperationInst *parseGenericOperation(); OperationInst *parseCustomOperation(); - ParseResult parseForInst(); - ParseResult parseIntConstant(int64_t &val); - ParseResult parseDimAndSymbolList(SmallVectorImpl &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName); - ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, - bool isLower); ParseResult parseInstructions(Block *block); private: @@ -2287,25 +2282,43 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) { /// /// block-list ::= '{' block-list-body /// -ParseResult -FunctionParser::parseOperationBlockList(SmallVectorImpl &results) { +ParseResult FunctionParser::parseOperationBlockList( + SmallVectorImpl &results, + ArrayRef> entryArguments) { // Parse the '{'. if (parseToken(Token::l_brace, "expected '{' to begin block list")) return ParseFailure; + // Check for an empty block list. - if (consumeIf(Token::r_brace)) + if (entryArguments.empty() && consumeIf(Token::r_brace)) return ParseSuccess; Block *currentBlock = builder.getInsertionBlock(); // Parse the first block directly to allow for it to be unnamed. Block *block = new Block(); + + // Add arguments to the entry block. + for (auto &placeholderArgPair : entryArguments) + if (addDefinition(placeholderArgPair.first, + block->addArgument(placeholderArgPair.second))) { + delete block; + return ParseFailure; + } + if (parseBlock(block)) { - cleanupInvalidBlocks(block); + delete block; return ParseFailure; } - results.push_back(block); + + // Verify that no other arguments were parsed. + if (!entryArguments.empty() && + block->getNumArguments() > entryArguments.size()) { + delete block; + return emitError("entry block arguments were already defined"); + } // Parse the rest of the block list. + results.push_back(block); if (parseBlockListBody(results)) return ParseFailure; @@ -2385,10 +2398,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseOperation()) return ParseFailure; break; - case Token::kw_for: - if (parseForInst()) - return ParseFailure; - break; } } @@ -2859,7 +2868,7 @@ OperationInst *FunctionParser::parseGenericOperation() { std::vector> blocks; while (getToken().is(Token::l_brace)) { SmallVector newBlocks; - if (parseOperationBlockList(newBlocks)) { + if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) { for (auto &blockList : blocks) cleanupInvalidBlocks(blockList); return nullptr; @@ -2884,6 +2893,27 @@ public: CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser) : nameLoc(nameLoc), opName(opName), parser(parser) {} + bool parseOperation(const AbstractOperation *opDefinition, + OperationState *opState) { + if (opDefinition->parseAssembly(this, opState)) + return true; + + // Check that enough block lists were reserved for those that were parsed. + if (parsedBlockLists.size() > opState->numBlockLists) { + return emitError( + nameLoc, + "parsed more block lists than those reserved in the operation state"); + } + + // Check there were no dangling entry block arguments. + if (!parsedBlockListEntryArguments.empty()) { + return emitError( + nameLoc, + "no block list was attached to parsed entry block arguments"); + } + return false; + } + //===--------------------------------------------------------------------===// // High level parsing methods. 
//===--------------------------------------------------------------------===// @@ -2895,6 +2925,9 @@ public: bool parseComma() override { return parser.parseToken(Token::comma, "expected ','"); } + bool parseEqual() override { + return parser.parseToken(Token::equal, "expected '='"); + } bool parseType(Type &result) override { return !(result = parser.parseType()); @@ -3083,13 +3116,35 @@ public: /// Parses a list of blocks. bool parseBlockList() override { + // Parse the block list. SmallVector results; - if (parser.parseOperationBlockList(results)) + if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments)) return true; + + parsedBlockListEntryArguments.clear(); parsedBlockLists.emplace_back(results); return false; } + /// Parses an argument for the entry block of the next block list to be + /// parsed. + bool parseBlockListEntryBlockArgument(Type argType) override { + SmallVector argValues; + OperandType operand; + if (parseOperand(operand)) + return true; + + // Create a place holder for this argument. + FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, + operand.location}; + if (auto *value = parser.resolveSSAUse(operandInfo, argType)) { + parsedBlockListEntryArguments.emplace_back(operandInfo, argType); + return false; + } + + return true; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3130,6 +3185,8 @@ public: private: std::vector> parsedBlockLists; + SmallVector, 2> + parsedBlockListEntryArguments; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3161,26 +3218,18 @@ OperationInst *FunctionParser::parseCustomOperation() { // Have the op implementation take a crack and parsing this. OperationState opState(builder.getContext(), srcLocation, opName); - if (opDefinition->parseAssembly(&opAsmParser, &opState)) + if (opAsmParser.parseOperation(opDefinition, &opState)) return nullptr; // If it emitted an error, we failed. if (opAsmParser.didEmitError()) return nullptr; - // Check that enough block lists were reserved for those that were parsed. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - if (parsedBlockLists.size() > opState.numBlockLists) { - opAsmParser.emitError( - opLoc, - "parsed more block lists than those reserved in the operation state"); - return nullptr; - } - // Otherwise, we succeeded. Use the state it parsed as our op information. auto *opInst = builder.createOperation(opState); // Resolve any parsed block lists. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { auto &opBlockList = opInst->getBlockList(i).getBlocks(); opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), @@ -3189,213 +3238,6 @@ OperationInst *FunctionParser::parseCustomOperation() { return opInst; } -/// For instruction. -/// -/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound -/// (`step` integer-literal)? trailing-location? `{` inst* `}` -/// -ParseResult FunctionParser::parseForInst() { - consumeToken(Token::kw_for); - - // Parse induction variable. 
- if (getToken().isNot(Token::percent_identifier)) - return emitError("expected SSA identifier for the loop variable"); - - auto loc = getToken().getLoc(); - StringRef inductionVariableName = getTokenSpelling(); - consumeToken(Token::percent_identifier); - - if (parseToken(Token::equal, "expected '='")) - return ParseFailure; - - // Parse lower bound. - SmallVector lbOperands; - AffineMap lbMap; - if (parseBound(lbOperands, lbMap, /*isLower*/ true)) - return ParseFailure; - - if (parseToken(Token::kw_to, "expected 'to' between bounds")) - return ParseFailure; - - // Parse upper bound. - SmallVector ubOperands; - AffineMap ubMap; - if (parseBound(ubOperands, ubMap, /*isLower*/ false)) - return ParseFailure; - - // Parse step. - int64_t step = 1; - if (consumeIf(Token::kw_step) && parseIntConstant(step)) - return ParseFailure; - - // The loop step is a positive integer constant. Since index is stored as an - // int64_t type, we restrict step to be in the set of positive integers that - // int64_t can represent. - if (step < 1) { - return emitError("step has to be a positive integer"); - } - - // Create for instruction. - ForInst *forInst = - builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap, - ubOperands, ubMap, step); - - // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, - forInst->getInductionVar())) - return ParseFailure; - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(forInst)) - return ParseFailure; - - // If parsing of the for instruction body fails, - // MLIR contains for instruction with those nested instructions that have been - // successfully parsed. - auto *forBody = forInst->getBody(); - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(forBody) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - // Reset insertion point to the current block. - builder.setInsertionPointToEnd(forInst->getBlock()); - - return ParseSuccess; -} - -/// Parse integer constant as affine constant expression. -ParseResult FunctionParser::parseIntConstant(int64_t &val) { - bool negate = consumeIf(Token::minus); - - if (getToken().isNot(Token::integer)) - return emitError("expected integer"); - - auto uval = getToken().getUInt64IntegerValue(); - - if (!uval.hasValue() || (int64_t)uval.getValue() < 0) { - return emitError("bound or step is too large for index"); - } - - val = (int64_t)uval.getValue(); - if (negate) - val = -val; - consumeToken(); - - return ParseSuccess; -} - -/// Dimensions and symbol use list. -/// -/// dim-use-list ::= `(` ssa-use-list? `)` -/// symbol-use-list ::= `[` ssa-use-list? `]` -/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list? 
-/// -ParseResult -FunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName) { - if (parseToken(Token::l_paren, "expected '('")) - return ParseFailure; - - SmallVector opInfo; - parseOptionalSSAUseList(opInfo); - - if (parseToken(Token::r_paren, "expected ')'")) - return ParseFailure; - - if (numDims != opInfo.size()) - return emitError("dim operand count and " + Twine(affineStructName) + - " dim count must match"); - - if (consumeIf(Token::l_square)) { - parseOptionalSSAUseList(opInfo); - if (parseToken(Token::r_square, "expected ']'")) - return ParseFailure; - } - - if (numOperands != opInfo.size()) - return emitError("symbol operand count and " + Twine(affineStructName) + - " symbol count must match"); - - // Resolve SSA uses. - Type indexType = builder.getIndexType(); - for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { - Value *sval = resolveSSAUse(opInfo[i], indexType); - if (!sval) - return ParseFailure; - - if (i < numDims && !sval->isValidDim()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a dimension id"); - if (i >= numDims && !sval->isValidSymbol()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a symbol"); - operands.push_back(sval); - } - - return ParseSuccess; -} - -// Loop bound. -/// -/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list | -/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list -/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal -/// -ParseResult FunctionParser::parseBound(SmallVectorImpl &operands, - AffineMap &map, bool isLower) { - // 'min' / 'max' prefixes are syntactic sugar. Ignore them. - if (isLower) - consumeIf(Token::kw_max); - else - consumeIf(Token::kw_min); - - // Parse full form - affine map followed by dim and symbol list. - if (getToken().isAny(Token::hash_identifier, Token::l_paren)) { - map = parseAffineMapReference(); - if (!map) - return ParseFailure; - - if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(), - "affine map")) - return ParseFailure; - return ParseSuccess; - } - - // Parse custom assembly form. - if (getToken().isAny(Token::minus, Token::integer)) { - int64_t val; - if (!parseIntConstant(val)) { - map = builder.getConstantAffineMap(val); - return ParseSuccess; - } - return ParseFailure; - } - - // Parse ssa-id as identity map. - SSAUseInfo opInfo; - if (parseSSAUse(opInfo)) - return ParseFailure; - - // TODO: improve error message when SSA value is not an affine integer. - // Currently it is 'use of value ... expects different type than prior uses' - if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) - operands.push_back(value); - else - return ParseFailure; - - // Create an identity map using dim id for an induction variable and - // symbol otherwise. This representation is optimized for storage. - // Analysis passes may expand it into a multi-dimensional map if desired. - if (isForInductionVar(operands[0])) - map = builder.getDimIdentityMap(); - else - map = builder.getSymbolIdentityMap(); - - return ParseSuccess; -} - /// Parse an affine constraint. 
/// affine-constraint ::= affine-expr `>=` `0` /// | affine-expr `==` `0` diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b79..e471b6792c5 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -183,11 +183,6 @@ void CSE::simplifyBlock(Block *bb) { } break; } - case Instruction::Kind::For: { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(cast(i).getBody()); - break; - } } } } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index f9d02f7a47a..9c20e79180a 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/InstVisitor.h" @@ -37,7 +38,6 @@ struct ConstantFold : public FunctionPass, InstWalker { bool foldOperation(OperationInst *op, SmallVectorImpl &existingConstants); void visitOperationInst(OperationInst *inst); - void visitForInst(ForInst *inst); PassResult runOnFunction(Function *f) override; static char passID; @@ -50,6 +50,12 @@ char ConstantFold::passID = 0; /// constants are found, we keep track of them in the existingConstants list. /// void ConstantFold::visitOperationInst(OperationInst *op) { + // If this operation is an AffineForOp, then fold the bounds. + if (auto forOp = op->dyn_cast()) { + constantFoldBounds(forOp); + return; + } + // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. if (auto constant = op->dyn_cast()) { @@ -98,11 +104,6 @@ void ConstantFold::visitOperationInst(OperationInst *op) { opInstsToErase.push_back(op); } -// Override the walker's 'for' instruction visit for constant folding. -void ConstantFold::visitForInst(ForInst *forInst) { - constantFoldBounds(forInst); -} - // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5c3a66208ec..83ec726ec2a 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -21,6 +21,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" @@ -71,9 +72,9 @@ struct DmaGeneration : public FunctionPass { } PassResult runOnFunction(Function *f) override; - void runOnForInst(ForInst *forInst); + void runOnAffineForOp(OpPointer forOp); - bool generateDma(const MemRefRegion ®ion, ForInst *forInst, + bool generateDma(const MemRefRegion ®ion, OpPointer forOp, uint64_t *sizeInBytes); // List of memory regions to DMA for. We need a map vector to have a @@ -174,7 +175,7 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // Just get the first numSymbols IVs, which the memref region is parametric // on. - SmallVector ivs; + SmallVector, 4> ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); SmallVector symbols = extractForInductionVars(ivs); @@ -195,8 +196,10 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // generates a DMA from the lower memory space to this one, and replaces all // loads to load from that buffer. 
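The ConstantFold change above moves bound folding into the generic operation visitor: when the visited op is an AffineForOp, its bound maps are folded and nothing else is done. The effect of "folding the bounds" can be modeled in isolation as: if every operand feeding a bound map is a known constant, the bound collapses to a single constant. This sketch uses a plain callable as a stand-in for an affine map and is not the MLIR constantFoldBounds helper.

```cpp
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Stand-in for a single-result affine bound map evaluated on its operands.
using BoundMap = std::function<int64_t(const std::vector<int64_t> &)>;

// Returns the folded constant bound if all operands are known, else nullopt.
std::optional<int64_t>
foldBound(const BoundMap &map,
          const std::vector<std::optional<int64_t>> &operands) {
  std::vector<int64_t> vals;
  vals.reserve(operands.size());
  for (const auto &v : operands) {
    if (!v)
      return std::nullopt; // a non-constant operand blocks folding
    vals.push_back(*v);
  }
  return map(vals);
}
```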
Returns false if DMAs could not be generated // due to yet unimplemented cases. -bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, +bool DmaGeneration::generateDma(const MemRefRegion ®ion, + OpPointer forOp, uint64_t *sizeInBytes) { + auto *forInst = forOp->getInstruction(); // DMAs for read regions are going to be inserted just before the for loop. FuncBuilder prologue(forInst); @@ -386,39 +389,43 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forInst' are replaced. + // *Only* those uses within the body of 'forOp' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); return true; } // TODO(bondhugula): make this run on a Block instead of a 'for' inst. -void DmaGeneration::runOnForInst(ForInst *forInst) { +void DmaGeneration::runOnAffineForOp(OpPointer forOp) { // For now (for testing purposes), we'll run this on the outermost among 'for' // inst's with unit stride, i.e., right at the top of the tile if tiling has // been done. In the future, the DMA generation has to be done at a level // where the generated data fits in a higher level of the memory hierarchy; so // the pass has to be instantiated with additional information that we aren't // provided with at the moment. - if (forInst->getStep() != 1) { - if (auto *innerFor = dyn_cast(&*forInst->getBody()->begin())) { - runOnForInst(innerFor); + if (forOp->getStep() != 1) { + auto *forBody = forOp->getBody(); + if (forBody->empty()) + return; + if (auto innerFor = + cast(forBody->front()).dyn_cast()) { + runOnAffineForOp(innerFor); } return; } // DMAs will be generated for this depth, i.e., for all data accessed by this // loop. - unsigned dmaDepth = getNestingDepth(*forInst); + unsigned dmaDepth = getNestingDepth(*forOp->getInstruction()); readRegions.clear(); writeRegions.clear(); fastBufferMap.clear(); // Walk this 'for' instruction to gather all memory regions. - forInst->walkOps([&](OperationInst *opInst) { + forOp->walkOps([&](OperationInst *opInst) { // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. if (auto loadOp = opInst->dyn_cast()) { @@ -443,7 +450,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG( - forInst->emitError("Non-constant memref sizes not yet supported")); + forOp->emitError("Non-constant memref sizes not yet supported")); return; } } @@ -472,10 +479,10 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { // Perform a union with the existing region. 
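The runOnAffineForOp logic above descends past non-unit-stride loops and performs DMA generation at the first unit-stride loop it reaches. The descent itself can be illustrated with a simplified loop-nest model (the Loop struct below is an assumption, not MLIR's AffineForOp):

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

struct Loop {
  int64_t step = 1;
  std::vector<std::unique_ptr<Loop>> body; // nested loops only, for brevity
};

// Skip non-unit-stride loops by recursing into the first nested loop, and
// run DMA generation at the first unit-stride loop of the nest.
void runOnLoop(const Loop &loop,
               const std::function<void(const Loop &)> &generateDmasAt) {
  if (loop.step != 1) {
    if (!loop.body.empty())
      runOnLoop(*loop.body.front(), generateDmasAt);
    return;
  }
  generateDmasAt(loop); // DMAs are generated for all data accessed here
}
```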
if (!(*it).second->unionBoundingBox(*region)) { LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed; " + << "Memory region bounding box failed" "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { - LLVM_DEBUG(forInst->emitError( + LLVM_DEBUG(forOp->emitError( "Non-constant memref sizes not yet supported")); } } @@ -501,7 +508,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { ®ions) { for (const auto ®ionEntry : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*regionEntry.second, forInst, &sizeInBytes); + bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes); if (iRet) totalSizeInBytes += sizeInBytes; ret = ret & iRet; @@ -510,7 +517,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { processRegions(readRegions); processRegions(writeRegions); if (!ret) { - forInst->emitError("DMA generation failed for one or more memref's\n"); + forOp->emitError("DMA generation failed for one or more memref's\n"); return; } LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024)) @@ -519,7 +526,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) { // TODO(bondhugula): selecting the DMA depth so that the result DMA buffers // fit in fast memory is a TODO - not complex. - forInst->emitError( + forOp->emitError( "Total size of all DMA buffers' exceeds memory capacity\n"); } } @@ -531,8 +538,8 @@ PassResult DmaGeneration::runOnFunction(Function *f) { for (auto &block : *f) { for (auto &inst : block) { - if (auto *forInst = dyn_cast(&inst)) { - runOnForInst(forInst); + if (auto forOp = cast(inst).dyn_cast()) { + runOnAffineForOp(forOp); } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index fa0e3b51de3..7d4ff03e306 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -97,15 +97,15 @@ namespace { // operations, and whether or not an IfInst was encountered in the loop nest. class LoopNestStateCollector : public InstWalker { public: - SmallVector forInsts; + SmallVector, 4> forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) + if (opInst->isa()) + forOps.push_back(opInst->cast()); + else if (opInst->getNumBlockLists() != 0) hasNonForRegion = true; else if (opInst->isa()) loadOpInsts.push_back(opInst); @@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) { if (f->getBlocks().size() != 1) return false; - DenseMap forToNodeMap; + DenseMap forToNodeMap; for (auto &inst : f->front()) { - if (auto *forInst = dyn_cast(&inst)) { - // Create graph node 'id' to represent top-level 'forInst' and record + if (auto forOp = cast(&inst)->dyn_cast()) { + // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). + collector.walk(&inst); + // Return false if a non 'for' region was found (not currently supported). 
if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); @@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); } - forToNodeMap[forInst] = node.id; + forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } - if (auto *opInst = dyn_cast(&inst)) { + } else if (auto *opInst = dyn_cast(&inst)) { if (auto loadOp = opInst->dyn_cast()) { // Create graph node for top-level load op. Node node(nextNodeId++, &inst); @@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) { for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { auto *userOpInst = cast(use.getOwner()); - SmallVector loops; + SmallVector, 4> loops; getLoopIVs(*userOpInst, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0]) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0]]; + assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()]; addEdge(node.id, userLoopNestId, value); } } @@ -587,12 +586,12 @@ namespace { // LoopNestStats aggregates various per-loop statistics (eg. loop trip count // and operation count) for a loop nest up until the innermost loop body. struct LoopNestStats { - // Map from ForInst to immediate child ForInsts in its loop body. - DenseMap> loopMap; - // Map from ForInst to count of operations in its loop body. - DenseMap opCountMap; - // Map from ForInst to its constant trip count. - DenseMap tripCountMap; + // Map from AffineForOp to immediate child AffineForOps in its loop body. + DenseMap, 2>> loopMap; + // Map from AffineForOp to count of operations in its loop body. + DenseMap opCountMap; + // Map from AffineForOp to its constant trip count. + DenseMap tripCountMap; }; // LoopNestStatsCollector walks a single loop nest and gathers per-loop @@ -604,23 +603,31 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitForInst(ForInst *forInst) { - auto *parentInst = forInst->getParentInst(); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast(); + if (!forOp) + return; + + auto *forInst = forOp->getInstruction(); + auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(isa(parentInst) && "Expected parent ForInst"); - // Add mapping to 'forInst' from its parent ForInst. - stats->loopMap[cast(parentInst)].push_back(forInst); + assert(cast(parentInst)->isa() && + "Expected parent AffineForOp"); + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentInst].push_back(forOp); } - // Record the number of op instructions in the body of 'forInst'. + + // Record the number of op instructions in the body of 'forOp'. unsigned count = 0; stats->opCountMap[forInst] = 0; - for (auto &inst : *forInst->getBody()) { - if (isa(&inst)) + for (auto &inst : *forOp->getBody()) { + if (!(cast(inst).isa() || + cast(inst).isa())) ++count; } stats->opCountMap[forInst] = count; - // Record trip count for 'forInst'. Set flag if trip count is not constant. - Optional maybeConstTripCount = getConstantTripCount(*forInst); + // Record trip count for 'forOp'. Set flag if trip count is not constant. + Optional maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount.hasValue()) { hasLoopWithNonConstTripCount = true; return; @@ -629,7 +636,7 @@ public: } }; -// Computes the total cost of the loop nest rooted at 'forInst'. 
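The MemRefDependenceGraph bookkeeping above records, for each memref, the set of graph nodes that access it, and later adds edges between nodes that touch the same memref. A deliberately simplified standalone sketch of that bookkeeping follows (it connects all node pairs sharing a memref, ignoring the load/store direction the real graph tracks; the types are stand-ins, not MLIR classes):

```cpp
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using NodeId = unsigned;
using MemRef = std::string; // stand-in for a Value* of memref type

struct DependenceGraphSketch {
  std::map<MemRef, std::set<NodeId>> memrefAccesses;
  std::vector<std::pair<NodeId, NodeId>> edges;

  void recordAccess(NodeId node, const MemRef &memref) {
    memrefAccesses[memref].insert(node);
  }

  // Connect every pair of distinct nodes accessing the same memref.
  void addMemRefEdges() {
    for (const auto &entry : memrefAccesses)
      for (NodeId a : entry.second)
        for (NodeId b : entry.second)
          if (a < b)
            edges.emplace_back(a, b);
  }
};
```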
+// Computes the total cost of the loop nest rooted at 'forOp'. // Currently, the total cost is computed by counting the total operation // instance count (i.e. total number of operations in the loop bodyloop // operation count * loop trip count) for the entire loop nest. @@ -637,7 +644,7 @@ public: // specified in the map when computing the total op instance count. // NOTE: this is used to compute the cost of computation slices, which are // sliced along the iteration dimension, and thus reduce the trip count. -// If 'computeCostMap' is non-null, the total op count for forInsts specified +// If 'computeCostMap' is non-null, the total op count for forOps specified // in the map is increased (not overridden) by adding the op count from the // map to the existing op count for the for loop. This is done before // multiplying by the loop's trip count, and is used to model the cost of @@ -645,15 +652,15 @@ public: // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. static int64_t getComputeCost( - ForInst *forInst, LoopNestStats *stats, - llvm::SmallDenseMap *tripCountOverrideMap, - DenseMap *computeCostMap) { - // 'opCount' is the total number operations in one iteration of 'forInst' body + Instruction *forInst, LoopNestStats *stats, + llvm::SmallDenseMap *tripCountOverrideMap, + DenseMap *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forOp' body int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { - for (auto *childForInst : stats->loopMap[forInst]) { - opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, - computeCostMap); + for (auto childForOp : stats->loopMap[forInst]) { + opCount += getComputeCost(childForOp->getInstruction(), stats, + tripCountOverrideMap, computeCostMap); } } // Add in additional op instances from slice (if specified in map). @@ -694,18 +701,18 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { return cExpr.getValue(); } -// Builds a map 'tripCountMap' from ForInst to constant trip count for loop +// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'. // Returns true on success, false otherwise (if a non-constant trip count // was encountered). // TODO(andydavis) Make this work with non-unit step loops. static bool buildSliceTripCountMap( OperationInst *srcOpInst, ComputationSliceState *sliceState, - llvm::SmallDenseMap *tripCountMap) { - SmallVector srcLoopIVs; + llvm::SmallDenseMap *tripCountMap) { + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); - // Populate map from ForInst -> trip count + // Populate map from AffineForOp -> trip count for (unsigned i = 0; i < numSrcLoopIVs; ++i) { AffineMap lbMap = sliceState->lbs[i]; AffineMap ubMap = sliceState->ubs[i]; @@ -713,7 +720,7 @@ static bool buildSliceTripCountMap( // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. 
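The recursive cost computation described here multiplies a loop's op count (its own body plus the cost of nested loops) by its trip count, optionally overriding trip counts per loop to model a slice. A compact standalone model of that recursion, with an assumed LoopStats struct in place of the pass's maps:

```cpp
#include <cstdint>
#include <map>
#include <vector>

struct LoopStats {
  int64_t opCount = 0;               // ops directly in this loop's body
  int64_t tripCount = 1;             // constant trip count
  std::vector<LoopStats *> children; // immediately nested loops
};

int64_t getComputeCost(const LoopStats &loop,
                       const std::map<const LoopStats *, int64_t> *tripOverride) {
  int64_t opCount = loop.opCount;
  for (const LoopStats *child : loop.children)
    opCount += getComputeCost(*child, tripOverride);

  int64_t tripCount = loop.tripCount;
  if (tripOverride) {
    auto it = tripOverride->find(&loop);
    if (it != tripOverride->end())
      tripCount = it->second; // model the sliced iteration space
  }
  return opCount * tripCount;
}
```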
if (srcLoopIVs[i]->hasConstantLowerBound() && srcLoopIVs[i]->hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i]] = + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = srcLoopIVs[i]->getConstantUpperBound() - srcLoopIVs[i]->getConstantLowerBound(); continue; @@ -723,7 +730,7 @@ static bool buildSliceTripCountMap( Optional tripCount = getConstDifference(lbMap, ubMap); if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue(); + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue(); } return true; } @@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); - std::vector> loops(numOps); + std::vector, 4>> loops(numOps); unsigned loopDepthLimit = std::numeric_limits::max(); for (unsigned i = 0; i < numOps; ++i) { getLoopIVs(*ops[i], &loops[i]); @@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { for (unsigned d = 0; d < loopDepthLimit; ++d) { unsigned i; for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) { + if (loops[i - 1][d] != loops[i][d]) break; - } } if (i != numOps) break; @@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, } // Creates and returns a private (single-user) memref for fused loop rooted -// at 'forInst', with (potentially reduced) memref size based on the +// at 'forOp', with (potentially reduced) memref size based on the // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. -static Value *createPrivateMemRef(ForInst *forInst, +static Value *createPrivateMemRef(OpPointer forOp, OperationInst *srcStoreOpInst, unsigned dstLoopDepth) { - // Create builder to insert alloc op just before 'forInst'. + auto *forInst = forOp->getInstruction(); + + // Create builder to insert alloc op just before 'forOp'. FuncBuilder b(forInst); // Builder to create constants at the top level. FuncBuilder top(forInst->getFunction()); @@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst, for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) allocOperands.push_back( - top.create(forInst->getLoc(), oldMemRef, dynamicDimCount++)); + top.create(forOp->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create new private memref for fused loop 'forInst'. + // Create new private memref for fused loop 'forOp'. // TODO(andydavis) Create/move alloc ops for private memrefs closer to their // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the function, because loop nests can be reordered // during the fusion pass. Value *newMemRef = - top.create(forInst->getLoc(), newMemRefType, allocOperands); + top.create(forOp->getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector remapExprs; @@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst, bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); assert(ret && "replaceAllMemrefUsesWith should always succeed here"); (void)ret; return newMemRef; @@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst, // Does the slice have a single iteration? 
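For a unit-step loop, the trip count recorded in the map above is simply the constant upper bound minus the constant lower bound; a non-constant bound makes the whole slice map construction fail. A minimal helper illustrating that rule (not the MLIR getConstDifference utility):

```cpp
#include <cstdint>
#include <optional>

// Assumes ub >= lb and a unit step, as in the surrounding code.
std::optional<uint64_t> constantTripCount(std::optional<int64_t> constLb,
                                          std::optional<int64_t> constUb) {
  if (!constLb || !constUb)
    return std::nullopt; // non-constant bound: caller must bail out
  return static_cast<uint64_t>(*constUb - *constLb);
}
```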
static uint64_t getSliceIterationCount( - const llvm::SmallDenseMap &sliceTripCountMap) { + const llvm::SmallDenseMap &sliceTripCountMap) { uint64_t iterCount = 1; for (const auto &count : sliceTripCountMap) { iterCount *= count.second; @@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, }); // Compute cost of sliced and unsliced src loop nest. - SmallVector srcLoopIVs; + SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.walk(srcLoopIVs[0]); + srcStatsCollector.walk(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) return false; // Compute cost of dst loop nest. - SmallVector dstLoopIVs; + SmallVector, 4> dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.walk(dstLoopIVs[0]); + dstStatsCollector.walk(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst, Optional bestDstLoopDepth = None; // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t srcLoopNestCost = + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t dstLoopNestCost = + getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); - llvm::SmallDenseMap sliceTripCountMap; - DenseMap computeCostMap; + llvm::SmallDenseMap sliceTripCountMap; + DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // The store and loads to this memref will disappear. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. - computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { - if (auto *loadLoop = dyn_cast_or_null(loadOp->getParentInst())) - computeCostMap[loadLoop] = -1; + auto *parentInst = loadOp->getParentInst(); + if (parentInst && cast(parentInst)->isa()) + computeCostMap[parentInst] = -1; } } // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, /*tripCountOverrideMap=*/&sliceTripCountMap, /*computeCostMap=*/&computeCostMap); // Compute cost of fusion for this depth. 
- computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; + computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost; int64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); double additionalComputeFraction = @@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst, << "\n fused loop nest compute cost: " << minFusedLoopNestComputeCost << "\n"); - auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]); - auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]); + auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); Optional storageReduction = None; @@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: -// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate -// destination ForInst into which fusion will be attempted. -// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'. +// *) Pop a AffineForOp of the worklist. This 'dstAffineForOp' will be a +// candidate destination AffineForOp into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: // *) Lookup dependent loop nests at earlier positions in the Function // which have a single store op to the same memref. @@ -1342,7 +1353,7 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!isa(dstNode->inst)) + if (!cast(dstNode->inst)->isa()) continue; SmallVector loads = dstNode->loads; @@ -1375,7 +1386,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!isa(srcNode->inst)) + if (!cast(srcNode->inst)->isa()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1417,25 +1428,26 @@ public: continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto *sliceLoopNest = mlir::insertBackwardComputationSlice( + auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { - // Move 'dstForInst' before 'insertPointInst' if needed. - auto *dstForInst = cast(dstNode->inst); - if (insertPointInst != dstForInst) { - dstForInst->moveBefore(insertPointInst); + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + auto dstAffineForOp = + cast(dstNode->inst)->cast(); + if (insertPointInst != dstAffineForOp->getInstruction()) { + dstAffineForOp->getInstruction()->moveBefore(insertPointInst); } // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, memref); // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.walkForInst(sliceLoopNest); + sliceCollector.walk(sliceLoopNest->getInstruction()); // Promote single iteration slice loops to single IV value. - for (auto *forInst : sliceCollector.forInsts) { - promoteIfSingleIteration(forInst); + for (auto forOp : sliceCollector.forOps) { + promoteIfSingleIteration(forOp); } - // Create private memref for 'memref' in 'dstForInst'. 
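The profitability metric being computed at this point compares the fused nest's cost against the sum of the unfused source and destination nest costs. The exact closed form below is an assumption inferred from the surrounding variable names, shown only to make the comparison concrete:

```cpp
#include <cstdint>

// Fraction of additional compute introduced by fusing at a given depth,
// relative to leaving the two nests unfused. 0.0 means no extra compute.
double additionalComputeFraction(int64_t fusedLoopNestComputeCost,
                                 int64_t srcLoopNestCost,
                                 int64_t dstLoopNestCost) {
  double unfusedCost = static_cast<double>(srcLoopNestCost) +
                       static_cast<double>(dstLoopNestCost);
  return static_cast<double>(fusedLoopNestComputeCost) / unfusedCost - 1.0;
}
```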
+ // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast()->getMemRef() == memref) @@ -1443,7 +1455,7 @@ public: } assert(storesForMemref.size() == 1); auto *newMemRef = createPrivateMemRef( - dstForInst, storesForMemref[0], bestDstLoopDepth); + dstAffineForOp, storesForMemref[0], bestDstLoopDepth); visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = @@ -1453,7 +1465,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.walkForInst(dstForInst); + dstLoopCollector.walk(dstAffineForOp->getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. @@ -1472,7 +1484,7 @@ public: // function. if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); - cast(srcNode->inst)->erase(); + srcNode->inst->erase(); } } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 396fc8eb658..f1ee7fd1853 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -60,16 +61,17 @@ char LoopTiling::passID = 0; /// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForInst 'src' from 'src' into the specified location in -// destination's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest, +// Move the loop body of AffineForOp 'src' from 'src' into the specified +// location in destination's body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForInst 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest) { +// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's +// body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -78,13 +80,14 @@ static inline void moveLoopBody(ForInst *src, ForInst *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef origLoops, - ArrayRef newLoops, - ArrayRef tileSizes) { +static void constructTiledIndexSetHyperRect( + MutableArrayRef> origLoops, + MutableArrayRef> newLoops, + ArrayRef tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0]); + FuncBuilder b(origLoops[0]->getInstruction()); unsigned width = origLoops.size(); // Bounds for tile space loops. @@ -99,8 +102,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, } // Bounds for intra-tile loops. 
for (unsigned i = 0; i < width; i++) { - int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]); - auto mayBeConstantCount = getConstantTripCount(*origLoops[i]); + int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]); + auto mayBeConstantCount = getConstantTripCount(origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); newLoops[width + i]->setLowerBound( @@ -144,38 +147,40 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef band, +UtilResult mlir::tileCodeGen(MutableArrayRef> band, ArrayRef tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentInst() == band[i - 1]); + assert(band[i]->getInstruction()->getParentInst() == + band[i - 1]->getInstruction()); } auto origLoops = band; - ForInst *rootForInst = origLoops[0]; - auto loc = rootForInst->getLoc(); + OpPointer rootAffineForOp = origLoops[0]; + auto loc = rootAffineForOp->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector newLoops(2 * width); - ForInst *innermostPointLoop; + SmallVector, 12> newLoops(2 * width); + OpPointer innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootForInst; + auto *topLoop = rootAffineForOp->getInstruction(); // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *pointLoop = b.createFor(loc, 0, 0); + auto pointLoop = b.create(loc, 0, 0); + pointLoop->createBody(); pointLoop->getBody()->getInstructions().splice( pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; - topLoop = pointLoop; + topLoop = pointLoop->getInstruction(); if (i == 0) innermostPointLoop = pointLoop; } @@ -184,12 +189,13 @@ UtilResult mlir::tileCodeGen(ArrayRef band, for (unsigned i = width; i < 2 * width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *tileSpaceLoop = b.createFor(loc, 0, 0); + auto tileSpaceLoop = b.create(loc, 0, 0); + tileSpaceLoop->createBody(); tileSpaceLoop->getBody()->getInstructions().splice( tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; - topLoop = tileSpaceLoop; + topLoop = tileSpaceLoop->getInstruction(); } // Move the loop body of the original nest to the new one. @@ -201,8 +207,8 @@ UtilResult mlir::tileCodeGen(ArrayRef band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForInst->emitError("tiled code generation unimplemented for the" - "non-hyperrectangular case"); + rootAffineForOp->emitError("tiled code generation unimplemented for the" + "non-hyperrectangular case"); return UtilResult::Failure; } @@ -213,7 +219,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, } // Erase the old loop nest. - rootForInst->erase(); + rootAffineForOp->erase(); return UtilResult::Success; } @@ -221,38 +227,36 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // Identify valid and profitable bands of loops to tile. 
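The tiled index set being constructed above has a simple shape for hyper-rectangular bands: each original loop becomes an outer tile-space loop stepping by the tile size plus an inner point loop covering one tile. A plain-C++ rendering of that structure for a two-loop band (illustration only, not the MLIR builders used in the pass):

```cpp
#include <algorithm>
#include <cstdint>

void tiledNest2d(int64_t N, int64_t M, int64_t tileSize,
                 void (*body)(int64_t, int64_t)) {
  for (int64_t it = 0; it < N; it += tileSize)     // tile-space loop
    for (int64_t jt = 0; jt < M; jt += tileSize)   // tile-space loop
      for (int64_t i = it; i < std::min(it + tileSize, N); ++i)   // point loop
        for (int64_t j = jt; j < std::min(jt + tileSize, M); ++j) // point loop
          body(i, j);
}
```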
This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function *f, - std::vector> *bands) { +static void +getTileableBands(Function *f, + std::vector, 6>> *bands) { // Get maximal perfect nest of 'for' insts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](ForInst *root) { - SmallVector band; - ForInst *currInst = root; + auto getMaximalPerfectLoopNest = [&](OpPointer root) { + SmallVector, 6> band; + OpPointer currInst = root; do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = dyn_cast(&currInst->getBody()->front()))); + (currInst = cast(currInst->getBody()->front()) + .dyn_cast())); bands->push_back(band); }; - for (auto &block : *f) { - for (auto &inst : block) { - auto *forInst = dyn_cast(&inst); - if (!forInst) - continue; - getMaximalPerfectLoopNest(forInst); - } - } + for (auto &block : *f) + for (auto &inst : block) + if (auto forOp = cast(inst).dyn_cast()) + getMaximalPerfectLoopNest(forOp); } PassResult LoopTiling::runOnFunction(Function *f) { - std::vector> bands; + std::vector, 6>> bands; getTileableBands(f, &bands); // Temporary tile sizes. unsigned tileSize = clTileSize.getNumOccurrences() > 0 ? clTileSize : kDefaultTileSize; - for (const auto &band : bands) { + for (auto &band : bands) { SmallVector tileSizes(band.size(), tileSize); if (tileCodeGen(band, tileSizes)) { return failure(); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d..86e913bd71f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass { const Optional unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function getUnrollFactor; + const std::function)> getUnrollFactor; - explicit LoopUnroll( - Optional unrollFactor = None, Optional unrollFull = None, - const std::function &getUnrollFactor = nullptr) + explicit LoopUnroll(Optional unrollFactor = None, + Optional unrollFull = None, + const std::function)> + &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnFunction(Function *f) override; /// Unroll this for inst. Returns false if nothing was done. - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer forOp); static const unsigned kDefaultUnrollFactor = 4; @@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class InnermostLoopGatherer : public InstWalker { public: // Store innermost loops as we walk. - std::vector loops; + std::vector> loops; // This method specialized to encode custom return logic. 
using InstListType = llvm::iplist; @@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkForInstPostOrder(ForInst *forInst) { - bool hasInnerLoops = - walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); - if (!hasInnerLoops) - loops.push_back(forInst); - return true; - } - bool walkOpInstPostOrder(OperationInst *opInst) { + bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - if (walkPostOrder(block.begin(), block.end())) - return true; - return false; + hasInnerLoops |= walkPostOrder(block.begin(), block.end()); + if (opInst->isa()) { + if (!hasInnerLoops) + loops.push_back(opInst->cast()); + return true; + } + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class ShortLoopGatherer : public InstWalker { public: // Store short loops as we walk. - std::vector loops; + std::vector> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForInst(ForInst *forInst) { - Optional tripCount = getConstantTripCount(*forInst); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast(); + if (!forOp) + return; + Optional tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forInst); + loops.push_back(forOp); } }; @@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forInst : loops) - loopUnrollFull(forInst); + for (auto forOp : loops) + loopUnrollFull(forOp); return success(); } @@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forInst : loops) - unrolled |= runOnForInst(forInst); + for (auto forOp : loops) + unrolled |= runOnAffineForOp(forOp); if (!unrolled) // Break out if nothing was unrolled. break; @@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) { /// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForInst(ForInst *forInst) { +bool LoopUnroll::runOnAffineForOp(OpPointer forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); + return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forInst, unrollFactor.getValue()); + return loopUnrollByFactor(forOp, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forInst, clUnrollFactor); + return loopUnrollByFactor(forOp, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forInst); + return loopUnrollFull(forOp); // Unroll by four otherwise. 
- return loopUnrollByFactor(forInst, kDefaultUnrollFactor); + return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function &getUnrollFactor) { + const std::function)> + &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7deaf850362..7327a37ee3a 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -43,6 +43,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnFunction(Function *f) override; - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer forOp); static char passID; }; @@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForInst can be called on any - // for Inst. + // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on + // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto *forInst = dyn_cast(&entryBlock.front())) - runOnForInst(forInst); + if (auto forOp = + cast(entryBlock.front()).dyn_cast()) + runOnAffineForOp(forOp); return success(); } /// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { +bool LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forInst, clUnrollJamFactor); + return loopUnrollJamByFactor(forOp, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollJamUpToFactor(OpPointer forOp, + uint64_t unrollJamFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forInst, unrollJamFactor); + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. 
-bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { +bool mlir::loopUnrollJamByFactor(OpPointer forOp, + uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). class JamBlockGatherer : public InstWalker { public: using InstListType = llvm::iplist; + using InstWalker::walk; // Store iterators to the first and last inst of each sub-block found. std::vector> subBlocks; @@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa(it)) + while (it != End && !cast(it)->isa()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && isa(it)) - walkForInst(cast(it++)); + while (it != End && cast(it)->isa()) + walk(&*it++); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forInst->getBody()->empty()) + if (unrollJamFactor == 1 || forOp->getBody()->empty()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { mayBeConstantTripCount.getValue() < unrollJamFactor) return false; + auto *forInst = forOp->getInstruction(); + // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForInst(forInst); + jbg.walkOpInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - // Insert the cleanup loop right after 'forInst'. + // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto *cleanupForInst = cast(builder.clone(*forInst)); - cleanupForInst->setLowerBoundMap( - getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); + auto cleanupAffineForOp = + cast(builder.clone(*forInst))->cast(); + cleanupAffineForOp->setLowerBoundMap( + getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. 
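When the trip count is not a multiple of the unroll-jam factor, the code above clones the loop as a cleanup loop and trims the main loop to the largest multiple of the factor. The iteration split itself is just arithmetic, sketched here on plain integers:

```cpp
#include <cstdint>

struct UnrollSplit {
  int64_t mainTripCount;    // iterations run by the unroll-jammed loop
  int64_t cleanupTripCount; // iterations left for the cleanup clone
};

UnrollSplit splitTripCount(int64_t tripCount, int64_t unrollJamFactor) {
  UnrollSplit s;
  s.mainTripCount = (tripCount / unrollJamFactor) * unrollJamFactor;
  s.cleanupTripCount = tripCount - s.mainTripCount;
  return s;
}
```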
- forInst->setUpperBoundMap( - getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); + forOp->setUpperBoundMap( + getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForInst); + promoteIfSingleIteration(cleanupAffineForOp); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollJamFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollJamFactor); - auto *forInstIV = forInst->getInductionVar(); + auto *forOpIV = forOp->getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. @@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); - auto ivUnroll = builder.create(forInst->getLoc(), - bumpMap, forInstIV); - operandMapping.map(forInstIV, ivUnroll); + auto ivUnroll = + builder.create(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f519..24ca4e95082 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -246,7 +247,7 @@ public: LowerAffinePass() : FunctionPass(&passID) {} PassResult runOnFunction(Function *function) override; - bool lowerForInst(ForInst *forInst); + bool lowerAffineFor(OpPointer forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); @@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // a nested loop). Induction variable modification is appended to the body SESE // region that always loops back to the condition block. 
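The induction-variable remapping above rewrites uses of the IV in unrolled copy i to iv + i * step, which is exactly what the affine.apply with map (d0 -> d0 + i * step) computes. A standalone illustration of the values each copy sees:

```cpp
#include <cstdint>
#include <vector>

std::vector<int64_t> unrolledIvValues(int64_t iv, int64_t step,
                                      uint64_t unrollJamFactor) {
  std::vector<int64_t> ivs;
  ivs.push_back(iv); // copy 0 keeps the original induction variable
  for (uint64_t i = 1; i < unrollJamFactor; ++i)
    ivs.push_back(iv + static_cast<int64_t>(i) * step); // iv' = iv + i * step
  return ivs;
}
```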
// -// +--------------------------------+ -// | | -// | | -// | br cond(%iv) | -// +--------------------------------+ +// +---------------------------------+ +// | | +// | | +// | br cond(%iv) | +// +---------------------------------+ // | // -------| | // | v v @@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // v // +--------------------------------+ // | end: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerForInst(ForInst *forInst) { - auto loc = forInst->getLoc(); +bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { + auto loc = forOp->getLoc(); + auto *forInst = forOp->getInstruction(); // Start by splitting the block containing the 'for' into two parts. The part // before will get the init code, the part after will be the end point. @@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { conditionBlock->insertBefore(endBlock); auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext())); - // Create the body block, moving the body of the forInst over to it. + // Create the body block, moving the body of the forOp over to it. auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); - auto *oldBody = forInst->getBody(); + auto *oldBody = forOp->getBody(); bodyBlock->getInstructions().splice(bodyBlock->begin(), oldBody->getInstructions(), oldBody->begin(), oldBody->end()); - // The code in the body of the forInst now uses 'iv' as its indvar. - forInst->getInductionVar()->replaceAllUsesWith(iv); + // The code in the body of the forOp now uses 'iv' as its indvar. + forOp->getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. FuncBuilder builder(bodyBlock); - auto affStep = builder.getAffineConstantExpr(forInst->getStep()); + auto affStep = builder.getAffineConstantExpr(forOp->getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); if (!stepped) @@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { builder.setInsertionPointToEnd(initBlock); // Compute loop bounds. - SmallVector operands(forInst->getLowerBoundOperands()); + SmallVector operands(forOp->getLowerBoundOperands()); auto lbValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getLowerBoundMap(), operands); + forOp->getLowerBoundMap(), operands); if (!lbValues) return true; Value *lowerBound = buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder); - operands.assign(forInst->getUpperBoundOperands().begin(), - forInst->getUpperBoundOperands().end()); + operands.assign(forOp->getUpperBoundOperands().begin(), + forOp->getUpperBoundOperands().end()); auto ubValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getUpperBoundMap(), operands); + forOp->getUpperBoundMap(), operands); if (!ubValues) return true; Value *upperBound = @@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { endBlock, ArrayRef()); // Ok, we're done! - forInst->erase(); + forOp->erase(); return false; } @@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. 
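The block structure drawn in the comment above (init, condition with the IV as a block argument, body, stepping, end) is equivalent to an ordinary counted loop once lowered. A plain-C++ rendering of the control flow the lowering produces, for orientation only:

```cpp
#include <cstdint>

void loweredAffineFor(int64_t lowerBound, int64_t upperBound, int64_t step,
                      void (*body)(int64_t)) {
  int64_t iv = lowerBound;  // init block: materialize bounds, seed the IV
  while (iv < upperBound) { // condition block: cmpi "slt" + cond_br
    body(iv);               // body block: the original loop body
    iv += step;             // stepping logic appended to the body
  }
  // end block: execution continues after the loop
}
```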
function->walkInsts([&](Instruction *inst) { - if (isa(inst)) - instsToRewrite.push_back(inst); - auto op = dyn_cast(inst); - if (op && (op->isa() || op->isa())) + auto op = cast(inst); + if (op->isa() || op->isa() || + op->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. - for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast(inst)) { - if (lowerForInst(forInst)) + for (auto *inst : instsToRewrite) { + auto op = cast(inst); + if (auto ifOp = op->dyn_cast()) { + if (lowerAffineIf(ifOp)) return failure(); - } else { - auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast())) { + } else if (auto forOp = op->dyn_cast()) { + if (lowerAffineFor(forOp)) return failure(); - } + } else if (lowerAffineApply(op->cast())) { + return failure(); } + } return success(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 432ad1f39b8..f2dae11112b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -75,7 +75,7 @@ /// Implementation details /// ====================== /// The current decisions made by the super-vectorization pass guarantee that -/// use-def chains do not escape an enclosing vectorized ForInst. In other +/// use-def chains do not escape an enclosing vectorized AffineForOp. In other /// words, this pass operates on a scoped program slice. Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top @@ -285,13 +285,12 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// /// The general problem this function solves is as follows: /// Assume a vector_transfer operation at the super-vector granularity that has -/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates -/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of -/// rank `h`. -/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the -/// super-vector is vector<3x32xf32> and the hardware vector is vector<8xf32>. -/// Assume the following MLIR snippet after super-vectorization has been -/// applied: +/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation +/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware +/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2, +/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector +/// is vector<8xf32>. Assume the following MLIR snippet after +/// super-vectorization has been applied: /// /// ```mlir /// for %i0 = 0 to %M { @@ -351,7 +350,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector affineExprs; // TODO(ntv): support a concrete map and composition. unsigned i = 0; - // The first numMemRefIndices correspond to ForInst that have not been + // The first numMemRefIndices correspond to AffineForOp that have not been // vectorized, the transformation is the identity on those. 
for (i = 0; i < numMemRefIndices; ++i) { auto d_i = b->getAffineDimExpr(i); @@ -554,9 +553,6 @@ static bool instantiateMaterialization(Instruction *inst, MaterializationState *state) { LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); - if (isa(inst)) - return inst->emitError("NYI path ForInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 811741d08d1..2e083bbfd79 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -38,15 +38,12 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass, - InstWalker { +struct PipelineDataTransfer : public FunctionPass { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnFunction(Function *f) override; - PassResult runOnForInst(ForInst *forInst); + PassResult runOnAffineForOp(OpPointer forOp); - // Collect all 'for' instructions. - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - std::vector forInsts; + std::vector> forOps; static char passID; }; @@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' instruction modulo 2. Returns false if /// such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { - auto *forBody = forInst->getBody(); +static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { + auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. + auto *forInst = forOp->getInstruction(); FuncBuilder bOuter(forInst); SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = bInner.create(forInst->getLoc(), modTwoMap, - forInst->getInductionVar()); + auto ivModTwoOp = bInner.create(forOp->getLoc(), modTwoMap, + forOp->getInductionVar()); - // replaceAllMemRefUsesWith will always succeed unless the forInst body has + // replaceAllMemRefUsesWith will always succeed unless the forOp body has // non-deferencing uses of the memref. 
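The double-buffering transformation below gives the memref a leading dimension of size two and indexes it with (iv floordiv step) mod 2, so consecutive iterations alternate buffers. The index computation in isolation (assuming a non-negative IV, so integer division matches floordiv):

```cpp
#include <cstdint>

// Matches the affine map d0.floorDiv(step) % 2 used for the leading dimension.
int64_t doubleBufferIndex(int64_t iv, int64_t step) {
  return (iv / step) % 2;
}
```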
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), - {}, &*forInst->getBody()->begin())) { + {}, &*forOp->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forInsts.clear(); - walkPostOrder(f); + forOps.clear(); + f->walkOpsPostOrder([&](OperationInst *opInst) { + if (auto forOp = opInst->dyn_cast()) + forOps.push_back(forOp); + }); bool ret = false; - for (auto *forInst : forInsts) { - ret = ret | runOnForInst(forInst); + for (auto forOp : forOps) { + ret = ret | runOnAffineForOp(forOp); } return ret ? failure() : success(); } @@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( - ForInst *forInst, + OpPointer forOp, SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts( } SmallVector dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { - auto mayBeConstTripCount = getConstantTripCount(*forInst); +PassResult +PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector, 4> startWaitPairs; - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( dmaStartInst->cast()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forInst)) { + if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. 
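For context on the end state this pass targets: with the doubled buffer, the transfer for tile i+1 can proceed while tile i is being computed on. A rough standalone sketch of that overlap, with memcpy standing in for the DMA start/wait pair (illustrative only; the pass itself expresses this by shifting the DMA and compute instructions):

```cpp
#include <cstdio>
#include <cstring>
#include <vector>

// Stand-ins for the DMA start/wait pair: the "transfer" is a memcpy into one
// half of the doubled fast-memory buffer, and waiting is a no-op here.
void dmaStart(const float *src, float *dst, size_t n) {
  std::memcpy(dst, src, n * sizeof(float));
}
void dmaWait() {}

int main() {
  const size_t tile = 4, numTiles = 3;
  std::vector<float> slow(tile * numTiles, 1.0f);
  float fast[2][4]; // doubled fast-memory buffer

  // Prologue: bring in the first tile.
  dmaStart(&slow[0], fast[0], tile);

  // Steady state: while computing on tile i (slot i % 2), the transfer for
  // tile i + 1 fills the other slot.
  float sum = 0;
  for (size_t i = 0; i < numTiles; ++i) {
    dmaWait();
    if (i + 1 < numTiles)
      dmaStart(&slow[(i + 1) * tile], fast[(i + 1) % 2], tile);
    for (size_t j = 0; j < tile; ++j) // the "compute" part
      sum += fast[i % 2][j];
  }
  std::printf("sum = %f\n", sum);
  return 0;
}
```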
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); @@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaFinishInst = pair.second; Value *oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); - if (!doubleBuffer(oldTagMemRef, forInst)) { + if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } @@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { // Double buffering would have invalidated all the old DMA start/wait insts. startWaitPairs.clear(); - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. DenseMap instShiftMap; @@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &inst : *forInst->getBody()) { + for (const auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector shifts(forInst->getBody()->getInstructions().size()); + std::vector shifts(forOp->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( @@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { }); } - if (!isInstwiseShiftValid(*forInst, shifts)) { + if (!isInstwiseShiftValid(forOp, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (instBodySkew(forInst, shifts)) { + if (instBodySkew(forOp, shifts)) { LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c700..ae003b3e495 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 59da2b0a56e..ce16656243d 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/LoopUtils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -39,22 +40,22 @@ using namespace mlir; /// Returns the upper bound of an unrolled loop with lower bound 'lb' and with /// the specified trip count, stride, and unroll factor. Returns nullptr when /// the trip count can't be expressed as an affine expression. -AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, +AffineMap mlir::getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. 
if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes, the trip count cannot be expressed as an affine expression. - auto tripCount = getTripCountExpr(forInst); + auto tripCount = getTripCountExpr(forOp); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), @@ -65,50 +66,51 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, /// bound 'lb' and with the specified trip count, stride, and unroll factor. /// Returns an AffinMap with nullptr storage (that evaluates to false) /// when the trip count can't be expressed as an affine expression. -AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, +AffineMap mlir::getCleanupLoopLowerBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes the trip count cannot be expressed as an affine expression. - AffineExpr tripCount(getTripCountExpr(forInst)); + AffineExpr tripCount(getTripCountExpr(forOp)); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), {newLb}, {}); } -/// Promotes the loop body of a forInst to its containing block if the forInst +/// Promotes the loop body of a forOp to its containing block if the forOp /// was known to have a single iteration. Returns false otherwise. // TODO(bondhugula): extend this for arbitrary affine bounds. -bool mlir::promoteIfSingleIteration(ForInst *forInst) { - Optional tripCount = getConstantTripCount(*forInst); +bool mlir::promoteIfSingleIteration(OpPointer forOp) { + Optional tripCount = getConstantTripCount(forOp); if (!tripCount.hasValue() || tripCount.getValue() != 1) return false; // TODO(mlir-team): there is no builder for a max. - if (forInst->getLowerBoundMap().getNumResults() != 1) + if (forOp->getLowerBoundMap().getNumResults() != 1) return false; // Replaces all IV uses to its single iteration value. - auto *iv = forInst->getInductionVar(); + auto *iv = forOp->getInductionVar(); + OperationInst *forInst = forOp->getInstruction(); if (!iv->use_empty()) { - if (forInst->hasConstantLowerBound()) { + if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( - forInst->getLoc(), forInst->getConstantLowerBound()); + forOp->getLoc(), forOp->getConstantLowerBound()); iv->replaceAllUsesWith(constOp); } else { - const AffineBound lb = forInst->getLowerBound(); + const AffineBound lb = forOp->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); if (lb.getMap() == builder.getDimIdentityMap()) { @@ -124,8 +126,8 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { // Move the loop body instructions to the loop's containing block. 
auto *block = forInst->getBlock(); block->getInstructions().splice(Block::iterator(forInst), - forInst->getBody()->getInstructions()); - forInst->erase(); + forOp->getBody()->getInstructions()); + forOp->erase(); return true; } @@ -133,13 +135,10 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class LoopBodyPromoter : public InstWalker { - public: - void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); } - }; - - LoopBodyPromoter fsw; - fsw.walkPostOrder(f); + f->walkOpsPostOrder([](OperationInst *inst) { + if (auto forOp = inst->dyn_cast()) + promoteIfSingleIteration(forOp); + }); } /// Generates a 'for' inst with the specified lower and upper bounds while @@ -149,19 +148,22 @@ void mlir::promoteSingleIterationLoops(Function *f) { /// the pair specifies the shift applied to that group of instructions; note /// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. -static ForInst * +static OpPointer generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, - unsigned offset, ForInst *srcForInst, FuncBuilder *b) { + unsigned offset, OpPointer srcForInst, + FuncBuilder *b) { SmallVector lbOperands(srcForInst->getLowerBoundOperands()); SmallVector ubOperands(srcForInst->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); - auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForInst->getStep()); + auto loopChunk = + b->create(srcForInst->getLoc(), lbOperands, lbMap, + ubOperands, ubMap, srcForInst->getStep()); + loopChunk->createBody(); auto *loopChunkIV = loopChunk->getInductionVar(); auto *srcIV = srcForInst->getInductionVar(); @@ -176,7 +178,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. if (!srcIV->use_empty() && shift != 0) { - auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); + FuncBuilder b(loopChunk->getBody()); auto ivRemap = b.create( srcForInst->getLoc(), b.getSingleDimShiftAffineMap( @@ -191,7 +193,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, } } if (promoteIfSingleIteration(loopChunk)) - return nullptr; + return OpPointer(); return loopChunk; } @@ -210,28 +212,29 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this // method. -UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, +UtilResult mlir::instBodySkew(OpPointer forOp, + ArrayRef shifts, bool unrollPrologueEpilogue) { - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and // conditional guards (or context information to prevent such versioning). The // better way to pipeline for such loops is to first tile them and extract // constant trip count "full tiles" before applying this. 
- auto mayBeConstTripCount = getConstantTripCount(*forInst); + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); return UtilResult::Success; } uint64_t tripCount = mayBeConstTripCount.getValue(); - assert(isInstwiseShiftValid(*forInst, shifts) && + assert(isInstwiseShiftValid(forOp, shifts) && "shifts will lead to an invalid transformation\n"); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); - unsigned numChildInsts = forInst->getBody()->getInstructions().size(); + unsigned numChildInsts = forOp->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -249,7 +252,7 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // body of the 'for' inst. std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto shift = shifts[pos++]; sortedInstGroups[shift].push_back(&inst); } @@ -259,17 +262,17 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first // loop generated as the prologue and the last as epilogue and unroll these // fully. - ForInst *prologue = nullptr; - ForInst *epilogue = nullptr; + OpPointer prologue; + OpPointer epilogue; // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block // of instructions is paired with its shift. std::vector>> instGroupQueue; - auto origLbMap = forInst->getLowerBoundMap(); + auto origLbMap = forOp->getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forInst); + FuncBuilder b(forOp->getInstruction()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -280,19 +283,19 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the // loop needs to have all instructions in instQueue in that order. - ForInst *res; + OpPointer res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), - instGroupQueue, 0, forInst, &b); + instGroupQueue, 0, forOp, &b); // Entire loop for the queued inst groups generated, empty it. instGroupQueue.clear(); lbShift += tripCount * step; } else { res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, d), instGroupQueue, - 0, forInst, &b); + 0, forOp, &b); lbShift = d * step; } if (!prologue && res) @@ -312,60 +315,63 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef shifts, uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, ubShift), - instGroupQueue, i, forInst, &b); + instGroupQueue, i, forOp, &b); lbShift = ubShift; if (!prologue) prologue = epilogue; } // Erase the original for inst. 
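A standalone simulation of the skew semantics the sweep above implements: a statement with shift d executes its iteration i at logical time i + d, and the times where only a subset of the statements is active become the prologue and epilogue loops. Shifts here are in units of iterations, i.e. already scaled by the step:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Two-statement body: statement 0 unshifted, statement 1 shifted by one
  // iteration -- the typical DMA-start vs. compute split.
  const std::vector<uint64_t> shifts = {0, 1};
  const uint64_t tripCount = 4;

  uint64_t maxShift = 0;
  for (uint64_t s : shifts)
    maxShift = std::max(maxShift, s);

  // Logical time t runs every statement s whose original iteration
  // t - shifts[s] lies in [0, tripCount). Times where only some statements
  // are active correspond to the prologue/epilogue loops generated above.
  for (uint64_t t = 0; t < tripCount + maxShift; ++t) {
    std::printf("time %llu:", (unsigned long long)t);
    for (size_t s = 0; s < shifts.size(); ++s)
      if (t >= shifts[s] && t - shifts[s] < tripCount)
        std::printf("  stmt%zu(iter %llu)", s,
                    (unsigned long long)(t - shifts[s]));
    std::printf("\n");
  }
  return 0;
}
```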
- forInst->erase(); + forOp->erase(); if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); - if (unrollPrologueEpilogue && !epilogue && epilogue != prologue) + if (unrollPrologueEpilogue && !epilogue && + epilogue->getInstruction() != prologue->getInstruction()) loopUnrollFull(epilogue); return UtilResult::Success; } /// Unrolls this loop completely. -bool mlir::loopUnrollFull(ForInst *forInst) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollFull(OpPointer forOp) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue()) { uint64_t tripCount = mayBeConstantTripCount.getValue(); if (tripCount == 1) { - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); } - return loopUnrollByFactor(forInst, tripCount); + return loopUnrollByFactor(forOp, tripCount); } return false; } /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant) whichever is lower. -bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollUpToFactor(OpPointer forOp, + uint64_t unrollFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) - return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollByFactor(forInst, unrollFactor); + return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forOp, unrollFactor); } /// Unrolls this loop by the specified factor. Returns true if the loop /// is successfully unrolled. -bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { +bool mlir::loopUnrollByFactor(OpPointer forOp, + uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); if (unrollFactor == 1) - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -376,10 +382,10 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different operand lists. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; - Optional mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); // If the trip count is lower than the unroll factor, no unrolled body. // TODO(bondhugula): option to specify cleanup loop unrolling. @@ -388,10 +394,12 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. 
- if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) { + OperationInst *forInst = forOp->getInstruction(); + if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto *cleanupForInst = cast(builder.clone(*forInst)); - auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder); + auto cleanupForInst = + cast(builder.clone(*forInst))->cast(); + auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " "always be determined"); @@ -401,50 +409,50 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Adjust upper bound. auto unrolledUbMap = - getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder); + getUnrolledLoopUpperBound(forOp, unrollFactor, &builder); assert(unrolledUbMap && "upper bound map can alwayys be determined for an unrolled loop " "with single result bounds"); - forInst->setUpperBoundMap(unrolledUbMap); + forOp->setUpperBoundMap(unrolledUbMap); } // Scale the step of loop being unrolled by unroll factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollFactor); // Builder to insert unrolled bodies right after the last instruction in the - // body of 'forInst'. - FuncBuilder builder(forInst->getBody(), forInst->getBody()->end()); + // body of 'forOp'. + FuncBuilder builder(forOp->getBody(), forOp->getBody()->end()); // Keep a pointer to the last instruction in the original block so that we // know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end()); - // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies). - auto *forInstIV = forInst->getInductionVar(); + // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). + auto *forOpIV = forOp->getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto ivUnroll = - builder.create(forInst->getLoc(), bumpMap, forInstIV); - operandMap.map(forInstIV, ivUnroll); + builder.create(forOp->getLoc(), bumpMap, forOpIV); + operandMap.map(forOpIV, ivUnroll); } - // Clone the original body of 'forInst'. - for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd); + // Clone the original body of 'forOp'. + for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd); it++) { builder.clone(*it, operandMap); } } // Promote the loop body up if this has turned into a single iteration loop. 
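The bookkeeping above boils down to: scale the step by the unroll factor, bump the IV by i * step for the i-th cloned copy, and give the cleanup loop the remaining tripCount mod factor iterations. A standalone sketch of that arithmetic, using exclusive upper bounds for simplicity (the pass itself phrases the bounds as affine maps):

```cpp
#include <cstdint>
#include <cstdio>

void body(int64_t iv) { std::printf("body(%lld)\n", (long long)iv); }

int main() {
  const int64_t lb = 0, step = 2, tripCount = 10, unrollFactor = 4;
  const int64_t ub = lb + tripCount * step;

  // Main loop: step scaled by the unroll factor, body cloned unrollFactor
  // times with the IV bumped by i * step for the i-th copy.
  const int64_t mainUb = lb + (tripCount - tripCount % unrollFactor) * step;
  for (int64_t iv = lb; iv < mainUb; iv += step * unrollFactor)
    for (int64_t i = 0; i < unrollFactor; ++i)
      body(iv + i * step);

  // Cleanup loop covering the remaining tripCount % unrollFactor iterations.
  for (int64_t iv = mainUb; iv < ub; iv += step)
    body(iv);
  return 0;
}
```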
- promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index d3689d056d6..819f1a59b6f 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" @@ -278,8 +279,8 @@ void mlir::createAffineComputationSlice( /// Folds the specified (lower or upper) bound to a constant if possible /// considering its operands. Returns false if the folding happens for any of /// the bounds, true otherwise. -bool mlir::constantFoldBounds(ForInst *forInst) { - auto foldLowerOrUpperBound = [forInst](bool lower) { +bool mlir::constantFoldBounds(OpPointer forInst) { + auto foldLowerOrUpperBound = [&forInst](bool lower) { // Check if the bound is already a constant. if (lower && forInst->hasConstantLowerBound()) return true; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ac551d7c20c..7f26161e520 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" @@ -252,9 +253,9 @@ using namespace mlir; /// ========== /// The algorithm proceeds in a few steps: /// 1. defining super-vectorization patterns and matching them on the tree of -/// ForInst. A super-vectorization pattern is defined as a recursive data -/// structures that matches and captures nested, imperfectly-nested loops -/// that have a. comformable loop annotations attached (e.g. parallel, +/// AffineForOp. A super-vectorization pattern is defined as a recursive +/// data structures that matches and captures nested, imperfectly-nested +/// loops that have a. comformable loop annotations attached (e.g. parallel, /// reduction, vectoriable, ...) as well as b. all contiguous load/store /// operations along a specified minor dimension (not necessarily the /// fastest varying) ; @@ -279,11 +280,11 @@ using namespace mlir; /// it by its vector form. Otherwise, if the scalar value is a constant, /// it is vectorized into a splat. In all other cases, vectorization for /// the pattern currently fails. -/// e. if everything under the root ForInst in the current pattern vectorizes -/// properly, we commit that loop to the IR. Otherwise we discard it and -/// restore a previously cloned version of the loop. Thanks to the -/// recursive scoping nature of matchers and captured patterns, this is -/// transparently achieved by a simple RAII implementation. +/// e. if everything under the root AffineForOp in the current pattern +/// vectorizes properly, we commit that loop to the IR. Otherwise we +/// discard it and restore a previously cloned version of the loop. Thanks +/// to the recursive scoping nature of matchers and captured patterns, +/// this is transparently achieved by a simple RAII implementation. /// f. vectorization is applied on the next pattern in the list. 
Because /// pattern interference avoidance is not yet implemented and that we do /// not support further vectorizing an already vector load we need to @@ -667,12 +668,13 @@ namespace { struct VectorizationStrategy { SmallVector vectorSizes; - DenseMap loopToVectorDim; + DenseMap loopToVectorDim; }; } // end anonymous namespace -static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, +static void vectorizeLoopIfProfitable(Instruction *loop, + unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { assert(patternDepth > depthInPattern && @@ -704,13 +706,13 @@ static bool analyzeProfitability(ArrayRef matches, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast(m.getMatchedInstruction()); bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, patternDepth, strategy); if (fail) { return fail; } - vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy); + vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern, + patternDepth, strategy); } return false; } @@ -855,8 +857,8 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, /// Coarsens the loops bounds and transforms all remaining load and store /// operations into the appropriate vector_transfer. -static bool vectorizeForInst(ForInst *loop, int64_t step, - VectorizationState *state) { +static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, + VectorizationState *state) { using namespace functional; loop->setStep(step); @@ -873,7 +875,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; - loadAndStores.match(loop, &loadAndStoresMatches); + loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = cast(ls.getMatchedInstruction()); auto load = opInst->dyn_cast(); @@ -898,7 +900,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { return [fastestVaryingMemRefDimension](const Instruction &forInst) { - const auto &loop = cast(forInst); + auto loop = cast(forInst).cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -912,7 +914,8 @@ static bool vectorizeNonRoot(ArrayRef matches, /// if all vectorizations in `childrenMatches` have already succeeded /// recursively in DFS post-order. static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { - ForInst *loop = cast(oneMatch.getMatchedInstruction()); + auto *loopInst = oneMatch.getMatchedInstruction(); + auto loop = cast(loopInst)->cast(); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -924,7 +927,7 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { // 2. This loop may have been omitted from vectorization for various reasons // (e.g. due to the performance model or pattern depth > vector size). - auto it = state->strategy->loopToVectorDim.find(loop); + auto it = state->strategy->loopToVectorDim.find(loopInst); if (it == state->strategy->loopToVectorDim.end()) { return false; } @@ -939,10 +942,10 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { // exploratory tradeoffs (see top of the file). 
Apply coarsening, i.e.: // | ub -> ub // | step -> step * vectorSize - LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize << " : "); - LLVM_DEBUG(loop->print(dbgs())); - return vectorizeForInst(loop, loop->getStep() * vectorSize, state); + LLVM_DEBUG(loopInst->print(dbgs())); + return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state); } /// Non-root pattern iterates over the matches at this level, calls doVectorize @@ -1186,7 +1189,8 @@ static bool vectorizeOperations(VectorizationState *state) { /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto *loop = cast(m.getMatchedInstruction()); + auto loop = + cast(m.getMatchedInstruction())->cast(); VectorizationState state; state.strategy = strategy; @@ -1197,17 +1201,20 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { // vectorizable. If a pattern is not vectorizable anymore, we just skip it. // TODO(ntv): implement a non-greedy profitability analysis that keeps only // non-intersecting patterns. - if (!isVectorizableLoop(*loop)) { + if (!isVectorizableLoop(loop)) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); return true; } - FuncBuilder builder(loop); // builder to insert in place of loop - ForInst *clonedLoop = cast(builder.clone(*loop)); + auto *loopInst = loop->getInstruction(); + FuncBuilder builder(loopInst); + auto clonedLoop = + cast(builder.clone(*loopInst))->cast(); + auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. This is how the root match /// maintains a clone for handling failure and restores the proper state via /// RAII. - ScopeGuard sg2([&fail, loop, clonedLoop]() { + ScopeGuard sg2([&fail, &loop, &clonedLoop]() { if (fail) { loop->getInductionVar()->replaceAllUsesWith( clonedLoop->getInductionVar()); @@ -1291,8 +1298,8 @@ PassResult Vectorize::runOnFunction(Function *f) { if (fail) { continue; } - auto *loop = cast(m.getMatchedInstruction()); - vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); + vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth, + &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. 
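The ScopeGuard used above is a small RAII helper: it runs a callback when the scope exits, which is how the root match ends up keeping either the transformed loop or the pristine clone. A generic standalone equivalent (not the vectorizer's own implementation):

```cpp
#include <cstdio>
#include <functional>
#include <utility>

// Minimal scope guard: the stored callback runs when the guard goes out of
// scope, on success and failure alike.
class ScopeGuard {
public:
  explicit ScopeGuard(std::function<void()> fn) : fn(std::move(fn)) {}
  ~ScopeGuard() { fn(); }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

private:
  std::function<void()> fn;
};

bool tryTransform() { return true; } // pretend the transformation failed

int main() {
  bool fail = false;
  {
    ScopeGuard guard([&fail] {
      // On scope exit, decide which copy of the loop survives.
      std::puts(fail ? "restore the clone, drop the transformed loop"
                     : "keep the transformed loop, drop the clone");
    });
    fail = tryTransform();
  }
  return 0;
}
```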
fail = vectorizeRootMatch(m, &strategy); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 595991c0109..e41f88c901b 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,7 +204,7 @@ func @illegaltype(i0) // expected-error {{invalid integer width}} // ----- func @malformed_for_percent() { - for i = 1 to 10 { // expected-error {{expected SSA identifier for the loop variable}} + for i = 1 to 10 { // expected-error {{expected SSA operand}} // ----- @@ -222,18 +222,18 @@ func @malformed_for_to() { func @incomplete_for() { for %i = 1 to 10 step 2 -} // expected-error {{expected '{' before instruction list}} +} // expected-error {{expected '{' to begin block list}} // ----- func @nonconstant_step(%1 : i32) { - for %2 = 1 to 5 step %1 { // expected-error {{expected integer}} + for %2 = 1 to 5 step %1 { // expected-error {{expected type}} // ----- func @for_negative_stride() { for %i = 1 to 10 step -1 -} // expected-error {{step has to be a positive integer}} +} // expected-error@-1 {{expected step to be representable as a positive signed integer}} // ----- @@ -510,7 +510,7 @@ func @undefined_function() { func @bound_symbol_mismatch(%N : index) { for %i = #map1(%N) to 100 { - // expected-error@-1 {{symbol operand count and affine map symbol count must match}} + // expected-error@-1 {{symbol operand count and integer set symbol count must match}} } return } @@ -521,78 +521,7 @@ func @bound_symbol_mismatch(%N : index) { func @bound_dim_mismatch(%N : index) { for %i = #map1(%N, %N)[%N] to 100 { - // expected-error@-1 {{dim operand count and affine map dim count must match}} - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_dim_nested(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%a)[%i] { - // expected-error@-1 {{value '%a' cannot be used as a dimension id}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_dim_affine_apply(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - %w = affine_apply (i)->(i+1) (%a) - for %j = 1 to #map1(%w)[%i] { - // expected-error@-1 {{value '%w' cannot be used as a dimension id}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_iv(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%N)[%i] { - // expected-error@-1 {{value '%i' cannot be used as a symbol}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_nested(%N : index) { - for %i = 1 to 100 { - %a = "foo"(%N) : (index)->(index) - for %j = 1 to #map1(%N)[%a] { - // expected-error@-1 {{value '%a' cannot be used as a symbol}} - } - } - return -} - -// ----- - -#map1 = (i)[j] -> (i+j) - -func @invalid_symbol_affine_apply(%N : index) { - for %i = 1 to 100 { - %w = affine_apply (i)->(i+1) (%i) - for %j = 1 to #map1(%i)[%w] { - // expected-error@-1 {{value '%w' cannot be used as a symbol}} - } + // expected-error@-1 {{dim operand count and integer set dim count must match}} } return } @@ -601,7 +530,7 @@ func @invalid_symbol_affine_apply(%N : index) { func @large_bound() { for %i = 1 to 9223372036854775810 { - // expected-error@-1 {{bound or step is too large for index}} + // expected-error@-1 {{integer constant out of range for attribute}} } return } @@ -609,7 +538,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - for %i = 1 to max (i)->(N, 100) { //expected-error {{expected SSA operand}} + for %i = 1 to max (i)->(N, 100) { 
//expected-error {{expected type}} } return } @@ -617,7 +546,7 @@ func @max_in_upper_bound(%N : index) { // ----- func @step_typo() { - for %i = 1 to 100 step -- 1 { //expected-error {{expected integer}} + for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} } return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 8a90d12bd03..7196e3a5c29 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -12,9 +12,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK: constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) %2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) - // CHECK: for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } loc(fused["foo", "mysource.cc":10:8]) + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) if #set0(%2) { diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 626f24569c6..bee886c0f34 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -230,7 +230,7 @@ func @complex_loops() { func @triang_loop(%arg0: index, %arg1: memref) { %c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32 for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 { - for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { + for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { store %c, %arg1[%i0, %i1] : memref // CHECK: store %c0_i32, %arg1[%i0, %i1] } // CHECK: } } // CHECK: } @@ -254,7 +254,7 @@ func @loop_bounds(%N : index) { // CHECK: for %i0 = %0 to %arg0 for %i = %s to %N { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0 - for %j = %i to 0 step 1 { + for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { // CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0] %w1 = affine_apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s] // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i1)[%0] @@ -764,23 +764,3 @@ func @verbose_if(%N: index) { } return } - -// CHECK-LABEL: func @verbose_for -func @verbose_for(%arg0 : index, %arg1 : index) { - // CHECK-NEXT: %0 = "for"() {lb: 1, ub: 10} : () -> index { - %a = "for"() {lb: 1, ub: 10 } : () -> index { - - // CHECK-NEXT: %1 = "for"() {lb: 1, step: 2, ub: 100} : () -> index { - %b = "for"() {lb: 1, ub: 100, step: 2 } : () -> index { - - // CHECK-NEXT: %2 = "for"(%arg0, %arg1) : (index, index) -> index { - %c = "for"(%arg0, %arg1) : (index, index) -> index { - - // CHECK-NEXT: %3 = "for"(%arg0) {ub: 100} : (index) -> index { - %d = "for"(%arg0) {ub: 100 } : (index) -> index { - } - } - } - } - return -} diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index 69dace45165..4668e7a832b 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -17,9 +17,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK-NEXT: at mysource3.cc:100:10 %3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10)))) - // CHECK: for %i0 = 0 to 8 ["foo", mysource.cc:10:8] - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } ["foo", mysource.cc:10:8] + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } <"myPass">["foo", "foo2"] if #set0(%2) { diff --git a/mlir/test/Transforms/strip-debuginfo.mlir b/mlir/test/Transforms/strip-debuginfo.mlir index 618cba83f13..5d157282071 
100644 --- a/mlir/test/Transforms/strip-debuginfo.mlir +++ b/mlir/test/Transforms/strip-debuginfo.mlir @@ -9,9 +9,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { // CHECK: "foo"() : () -> i32 loc(unknown) %1 = "foo"() : () -> i32 loc("foo") - // CHECK: for %i0 = 0 to 8 loc(unknown) - for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { - } + // CHECK: } loc(unknown) + for %i0 = 0 to 8 { + } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(unknown) %2 = constant 4 : index diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index 54c5233430c..09e55403b7d 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -40,6 +40,9 @@ // UNROLL-BY-4: [[MAP7:#map[0-9]+]] = (d0) -> (d0 + 5) // UNROLL-BY-4: [[MAP8:#map[0-9]+]] = (d0) -> (d0 + 10) // UNROLL-BY-4: [[MAP9:#map[0-9]+]] = (d0) -> (d0 + 15) +// UNROLL-BY-4: [[MAP10:#map[0-9]+]] = (d0) -> (0) +// UNROLL-BY-4: [[MAP11:#map[0-9]+]] = (d0) -> (d0) +// UNROLL-BY-4: [[MAP12:#map[0-9]+]] = ()[s0] -> (0) // CHECK-LABEL: func @loop_nest_simplest() { func @loop_nest_simplest() { @@ -432,7 +435,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4-LABEL: func @loop_nest_operand1() { func @loop_nest_operand1() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (0)(%i0) to #map{{[0-9]+}}(%i0) step 4 +// UNROLL-BY-4-NEXT: for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -452,7 +455,7 @@ func @loop_nest_operand1() { // UNROLL-BY-4-LABEL: func @loop_nest_operand2() { func @loop_nest_operand2() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 { +// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -474,7 +477,7 @@ func @loop_nest_operand2() { func @loop_nest_operand3() { // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { for %i = 0 to 100 step 2 { - // UNROLL-BY-4: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 { + // UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -492,7 +495,7 @@ func @loop_nest_operand3() { func @loop_nest_operand4(%N : index) { // UNROLL-BY-4: for %i0 = 0 to 100 { for %i = 0 to 100 { - // UNROLL-BY-4: for %i1 = ()[s0] -> (0)()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4: for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { // UNROLL-BY-4: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 -- cgit v1.2.3 From 8be262743677cc9f602efba9d28507c70f9bee42 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 1 Feb 2019 17:06:22 -0800 Subject: Promote local buffers created post fusion to higher memory space - fusion already includes the necessary analysis to create small/local buffers post fusion; allocate these buffers in a higher memory space if the necessary pass parameters are provided (threshold size, memory space id) - although there will be a separate utility at some point to directly detect and promote small local buffers to higher memory spaces, doing it 
while fusion when possible is much less expensive, comes free with fusion analysis, and covers a key common case. PiperOrigin-RevId: 232063894 --- mlir/lib/Transforms/LoopFusion.cpp | 62 +++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 8 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7d4ff03e306..5091e3ceb33 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -63,6 +63,17 @@ static llvm::cl::opt clFusionAddlComputeTolerance( " computation tolerated while fusing"), llvm::cl::cat(clOptionsCategory)); +static llvm::cl::opt clFusionFastMemorySpace( + "fusion-fast-mem-space", llvm::cl::Hidden, + llvm::cl::desc("Faster memory space number to promote fusion buffers to"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt clFusionLocalBufThreshold( + "fusion-local-buf-threshold", llvm::cl::Hidden, + llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast " + "memory space"), + llvm::cl::cat(clOptionsCategory)); + namespace { /// Loop fusion pass. This pass currently supports a greedy fusion policy, @@ -80,6 +91,11 @@ struct LoopFusion : public FunctionPass { PassResult runOnFunction(Function *f) override; static char passID; + // Any local buffers smaller than this size will be created in + // `fastMemorySpace` if provided. + unsigned localBufSizeThreshold = 1024; + Optional fastMemorySpace = None; + // The amount of additional computation that is tolerated while fusing // pair-wise as a fraction of the total computation. constexpr static double kComputeToleranceThreshold = 0.30f; @@ -876,6 +892,21 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, return true; } +// TODO(mlir-team): improve/complete this when we have target data. +unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + // Creates and returns a private (single-user) memref for fused loop rooted // at 'forOp', with (potentially reduced) memref size based on the // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. @@ -883,7 +914,9 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, // this one. static Value *createPrivateMemRef(OpPointer forOp, OperationInst *srcStoreOpInst, - unsigned dstLoopDepth) { + unsigned dstLoopDepth, + Optional fastMemorySpace, + unsigned localBufSizeThreshold) { auto *forInst = forOp->getInstruction(); // Create builder to insert alloc op just before 'forOp'. @@ -906,7 +939,8 @@ static Value *createPrivateMemRef(OpPointer forOp, // by 'srcStoreOpInst' at depth 'dstLoopDepth'. Optional numElements = region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); - assert(numElements.hasValue()); + assert(numElements.hasValue() && + "non-constant number of elts in local buffer"); const FlatAffineConstraints *cst = region.getConstraints(); // 'outerIVs' holds the values that this memory region is symbolic/paramteric @@ -933,9 +967,16 @@ static Value *createPrivateMemRef(OpPointer forOp, // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed // by 'srcStoreOpInst'. 
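The hunk below sizes the private buffer and redirects it to the faster memory space when it is small enough. A standalone sketch of that decision, with the element-size helper mirroring getMemRefEltSizeInBytes above (parameter names here are illustrative only):

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>

// Bits-to-bytes, rounding up, as in the element-size helper above.
uint64_t eltSizeInBytes(uint64_t sizeInBits) { return (sizeInBits + 7) / 8; }

// Pick the memory space for a private fused buffer: small buffers go to the
// fast space when one was provided; everything else stays where it was.
unsigned pickMemorySpace(uint64_t numElements, uint64_t eltBits,
                         uint64_t localBufSizeThreshold,
                         std::optional<unsigned> fastMemorySpace,
                         unsigned oldMemorySpace) {
  uint64_t bufSize = eltSizeInBytes(eltBits) * numElements;
  if (bufSize < localBufSizeThreshold && fastMemorySpace.has_value())
    return *fastMemorySpace;
  return oldMemorySpace;
}

int main() {
  // 64 x f32 = 256 bytes: below a 1 KiB threshold, so it gets promoted.
  std::printf("space = %u\n",
              pickMemorySpace(/*numElements=*/64, /*eltBits=*/32,
                              /*localBufSizeThreshold=*/1024,
                              /*fastMemorySpace=*/2u, /*oldMemorySpace=*/0));
  return 0;
}
```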
- auto newMemRefType = - top.getMemRefType(newShape, oldMemRefType.getElementType(), {}, - oldMemRefType.getMemorySpace()); + uint64_t bufSize = + getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue(); + unsigned newMemSpace; + if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) { + newMemSpace = fastMemorySpace.getValue(); + } else { + newMemSpace = oldMemRefType.getMemorySpace(); + } + auto newMemRefType = top.getMemRefType( + newShape, oldMemRefType.getElementType(), {}, newMemSpace); // Gather alloc operands for the dynamic dimensions of the memref. SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -1343,7 +1384,7 @@ public: std::iota(worklist.begin(), worklist.end(), 0); } - void run() { + void run(unsigned localBufSizeThreshold, Optional fastMemorySpace) { while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); @@ -1455,7 +1496,8 @@ public: } assert(storesForMemref.size() == 1); auto *newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth); + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = @@ -1510,9 +1552,13 @@ public: } // end anonymous namespace PassResult LoopFusion::runOnFunction(Function *f) { + if (clFusionFastMemorySpace.getNumOccurrences() > 0) { + fastMemorySpace = clFusionFastMemorySpace.getValue(); + } + MemRefDependenceGraph g; if (g.init(f)) - GreedyFusion(&g).run(); + GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); return success(); } -- cgit v1.2.3 From b26900dce55c93043e8f84580df4a1bec65408be Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 4 Feb 2019 07:58:42 -0800 Subject: Update dma-generate pass to (1) work on blocks of instructions (instead of just loops), (2) take into account fast memory space capacity and lower 'dmaDepth' to fit, (3) add location information for debug info / errors - change dma-generate pass to work on blocks of instructions (start/end iterators) instead of 'for' loops; complete TODOs - allows DMA generation for straightline blocks of operation instructions interspersed b/w loops - take into account fast memory capacity: check whether memory footprint fits in fastMemoryCapacity parameter, and recurse/lower the depth at which DMA generation is performed until it does fit in the provided memory - add location information to MemRefRegion; any insufficient fast memory capacity errors or debug info w.r.t dma generation shows location information - allow DMA generation pass to be instantiated with a fast memory capacity option (besides command line flag) - change getMemRefRegion to return unique_ptr's - change getMemRefFootprintBytes to work on a 'Block' instead of 'ForInst' - other helper methods; add postDomInstFilter option for replaceAllMemRefUsesWith; drop forInst->walkOps, add Block::walkOps methods Eg. 
output $ mlir-opt -dma-generate -dma-fast-mem-capacity=1 /tmp/single.mlir /tmp/single.mlir:9:13: error: Total size of all DMA buffers' for this block exceeds fast memory capacity for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { ^ $ mlir-opt -debug-only=dma-generate -dma-generate -dma-fast-mem-capacity=400 /tmp/single.mlir /tmp/single.mlir:9:13: note: 8 KiB of DMA buffers in fast memory space for this block for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { PiperOrigin-RevId: 232297044 --- mlir/include/mlir/Analysis/Utils.h | 21 +- mlir/include/mlir/IR/Block.h | 13 ++ mlir/include/mlir/Transforms/Passes.h | 8 +- mlir/include/mlir/Transforms/Utils.h | 34 +-- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 51 +++-- mlir/lib/IR/Block.cpp | 31 +++ mlir/lib/Transforms/DmaGeneration.cpp | 343 +++++++++++++++++++++--------- mlir/lib/Transforms/LoopFusion.cpp | 7 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 5 +- mlir/lib/Transforms/Utils/Utils.cpp | 15 +- mlir/test/Transforms/dma-generate.mlir | 74 ++++++- 12 files changed, 451 insertions(+), 153 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 65af6d7b1f2..54549fc8ef8 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -27,6 +27,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include @@ -35,8 +36,10 @@ namespace mlir { class AffineForOp; template class ConstOpPointer; +class Block; class FlatAffineConstraints; class Instruction; +class Location; class MemRefAccess; template class OpPointer; class Instruction; @@ -73,6 +76,9 @@ unsigned getNestingDepth(const Instruction &stmt); // The last field is a 2-d FlatAffineConstraints symbolic in %i. // struct MemRefRegion { + MemRefRegion(Value *memref, Location loc, bool write) + : memref(memref), write(write), loc(loc) {} + FlatAffineConstraints *getConstraints() { return &cst; } const FlatAffineConstraints *getConstraints() const { return &cst; } bool isWrite() const { return write; } @@ -108,10 +114,13 @@ struct MemRefRegion { /// Memref that this region corresponds to. Value *memref; -private: /// Read or write. bool write; + /// If there is more than one load/store op associated with the region, the + /// location information would correspond to one of those op's. + Location loc; + /// Region (data space) of the memref accessed. This set will thus have at /// least as many dimensional identifiers as the shape dimensionality of the /// memref, and these are the leading dimensions of the set appearing in that @@ -125,7 +134,7 @@ private: /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opInst. Returns false if this fails due to yet unimplemented +/// surrounding opInst. Returns nullptr if this fails due to yet unimplemented /// cases. The computed region's 'cst' field has exactly as many dimensional /// identifiers as the rank of the memref, and *potentially* additional symbolic /// identifiers which could include any of the loop IVs surrounding opInst up @@ -142,8 +151,8 @@ private: /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. 
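Once a region's constraints have been reduced to constant per-dimension bounds, its bounding size is just the product of the extents. A rough standalone illustration of that reduction (the real computation works on the FlatAffineConstraints system and also handles bounds symbolic in outer IVs):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct DimBound {
  int64_t lb; // inclusive
  int64_t ub; // exclusive
};

// Bounding-box size of a region: product of per-dimension extents.
int64_t numElements(const std::vector<DimBound> &bounds,
                    std::vector<int64_t> *shape) {
  int64_t n = 1;
  for (const DimBound &b : bounds) {
    int64_t extent = b.ub - b.lb;
    shape->push_back(extent);
    n *= extent;
  }
  return n;
}

int main() {
  // The doc example above: A[%i] .. A[%i + 7] accessed inside an inner loop
  // gives an extent of 8 along the memref's single dimension.
  std::vector<DimBound> bounds = {{0, 8}};
  std::vector<int64_t> shape;
  std::printf("elements = %lld\n", (long long)numElements(bounds, &shape));
  return 0;
}
```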
/// -bool getMemRefRegion(Instruction *opInst, unsigned loopDepth, - MemRefRegion *region); +std::unique_ptr getMemRefRegion(Instruction *opInst, + unsigned loopDepth); /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. @@ -199,8 +208,12 @@ insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState); +/// Gets the memory footprint of all data touched in the specified memory space +/// in bytes; if the memory space is unspecified, considers all memory spaces. Optional getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace = -1); +Optional getMemoryFootprintBytes(const Block &block, + int memorySpace = -1); } // end namespace mlir diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index d0982630a5a..479f15d1603 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -311,6 +311,19 @@ public: return &Block::instructions; } + /// Walk the operation instructions of this block in preorder, calling the + /// callback for each operation. + void walk(std::function callback); + + /// Walk the operation instructions in this block in postorder, calling the + /// callback for each operation. + void walkPostOrder(std::function callback); + + /// Walk the operation instructions in the specified [begin, end) range of + /// this block, calling the callback for each operation. + void walk(Block::iterator begin, Block::iterator end, + std::function callback); + void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 3269ac1fdc5..d4aa8a67600 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -24,6 +24,7 @@ #define MLIR_TRANSFORMS_PASSES_H #include "mlir/Support/LLVM.h" +#include namespace mlir { @@ -91,9 +92,10 @@ FunctionPass *createLoopTilingPass(); /// Promotes all accessed memref regions to the specified faster memory space /// while generating DMAs to move data. -FunctionPass *createDmaGenerationPass(unsigned lowMemorySpace, - unsigned highMemorySpace, - int minDmaTransferSize = 1024); +FunctionPass *createDmaGenerationPass( + unsigned slowMemorySpace, unsigned fastMemorySpace, + int minDmaTransferSize = 1024, + uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. FunctionPass *createLowerVectorTransfersPass(); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 784e68a5ab3..581c668a154 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -42,19 +42,24 @@ class Function; /// Replaces all uses of oldMemRef with newMemRef while optionally remapping the /// old memref's indices using the supplied affine map, 'indexRemap'. The new /// memref could be of a different shape or rank. 'extraIndices' provides -/// additional access indices to be added to the start. 'indexRemap' remaps -/// indices of the old memref access to a new set of indices that are used to -/// index the memref. Additional input operands to indexRemap can be optionally -/// provided, and they are added at the start of its input list. 'indexRemap' is -/// expected to have only dimensional inputs, and the number of its inputs equal -/// to extraOperands.size() plus rank of the memref. 
'extraOperands' is an -/// optional argument that corresponds to additional operands (inputs) for -/// indexRemap at the beginning of its input list. An additional optional -/// argument 'domInstFilter' restricts the replacement to only those operations -/// that are dominated by the former. Returns true on success and false if the -/// replacement is not possible (whenever a memref is used as an operand in a -/// non-deferencing scenario). See comments at function definition for an -/// example. +/// additional access indices to be added to the start. +/// +/// 'indexRemap' remaps indices of the old memref access to a new set of indices +/// that are used to index the memref. Additional input operands to indexRemap +/// can be optionally provided, and they are added at the start of its input +/// list. 'indexRemap' is expected to have only dimensional inputs, and the +/// number of its inputs equal to extraOperands.size() plus rank of the memref. +/// 'extraOperands' is an optional argument that corresponds to additional +/// operands (inputs) for indexRemap at the beginning of its input list. +/// +/// 'domInstFilter', if non-null, restricts the replacement to only those +/// operations that are dominated by the former; similarly, `postDomInstFilter` +/// restricts replacement to only those operations that are postdominated by it. +/// +/// Returns true on success and false if the replacement is not possible +/// (whenever a memref is used as an operand in a non-deferencing scenario). See +/// comments at function definition for an example. +// // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]: // The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and // index remap will perform (%i, %j) -> (%ii - %i, %j), i.e., indexRemap = (d0, @@ -66,7 +71,8 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - const Instruction *domInstFilter = nullptr); + const Instruction *domInstFilter = nullptr, + const Instruction *postDomInstFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 9d1f7481115..468e79b8545 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -2130,7 +2130,7 @@ bool FlatAffineConstraints::unionBoundingBox( // Identify max. auto uRes = compareBounds(ub, otherUb); - if (uRes == BoundCmpResult::Greater || res == BoundCmpResult::Equal) { + if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { maxUb = ub; } else if (uRes == BoundCmpResult::Less) { maxUb = otherUb; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 24361ac621f..652aaab0e1b 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -122,25 +122,26 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). 
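A standalone sketch of the index rewriting contract documented above for replaceAllMemRefUsesWith: the new access is extraIndices followed by indexRemap applied to extraOperands ++ old indices. Plain integers stand in for SSA values, and the remap shown corresponds to the A[i, j] -> Abuf[t mod 2, ii - i, j] example:

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

using Index = int64_t;
using IndexRemap =
    std::function<std::vector<Index>(const std::vector<Index> &)>;

// New access = extraIndices ++ indexRemap(extraOperands ++ oldIndices),
// following the contract documented above.
std::vector<Index> remapAccess(const std::vector<Index> &extraIndices,
                               const std::vector<Index> &extraOperands,
                               const std::vector<Index> &oldIndices,
                               const IndexRemap &indexRemap) {
  std::vector<Index> remapInputs(extraOperands);
  remapInputs.insert(remapInputs.end(), oldIndices.begin(), oldIndices.end());
  std::vector<Index> result(extraIndices);
  std::vector<Index> remapped = indexRemap(remapInputs);
  result.insert(result.end(), remapped.begin(), remapped.end());
  return result;
}

int main() {
  // A[i, j] -> Abuf[t mod 2, ii - i, j] with extraIndices = {t mod 2} and
  // extraOperands = {ii}.
  Index t = 5, ii = 32, i = 10, j = 3;
  auto newIndices = remapAccess(
      /*extraIndices=*/{t % 2}, /*extraOperands=*/{ii},
      /*oldIndices=*/{i, j}, [](const std::vector<Index> &d) {
        return std::vector<Index>{d[0] - d[1], d[2]}; // (ii, i, j)->(ii-i, j)
      });
  for (Index v : newIndices)
    std::printf("%lld ", (long long)v);
  std::printf("\n"); // prints: 1 22 3
  return 0;
}
```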
-bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, - MemRefRegion *region) { +std::unique_ptr mlir::getMemRefRegion(Instruction *opInst, + unsigned loopDepth) { unsigned rank; + std::unique_ptr region; SmallVector indices; if (auto loadOp = opInst->dyn_cast()) { rank = loadOp->getMemRefType().getRank(); indices.reserve(rank); indices.append(loadOp->getIndices().begin(), loadOp->getIndices().end()); - region->memref = loadOp->getMemRef(); - region->setWrite(false); + region = std::make_unique(loadOp->getMemRef(), + loadOp->getLoc(), false); } else if (auto storeOp = opInst->dyn_cast()) { rank = storeOp->getMemRefType().getRank(); indices.reserve(rank); indices.append(storeOp->getIndices().begin(), storeOp->getIndices().end()); - region->memref = storeOp->getMemRef(); - region->setWrite(true); + region = std::make_unique(storeOp->getMemRef(), + storeOp->getLoc(), true); } else { assert(false && "expected load or store op"); - return false; + return nullptr; } // Build the constraints for this region. @@ -153,13 +154,15 @@ bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, SmallVector regionSymbols = extractForInductionVars(ivs); regionCst->reset(0, loopDepth, 0, regionSymbols); - return true; + return region; } FuncBuilder b(opInst); auto idMap = b.getMultiDimIdentityMap(rank); // Initialize 'accessValueMap' and compose with reachable AffineApplyOps. fullyComposeAffineMapAndOperands(&idMap, &indices); + // Remove any duplicates. + canonicalizeMapAndOperands(&idMap, &indices); AffineValueMap accessValueMap(idMap, indices); AffineMap accessMap = accessValueMap.getAffineMap(); @@ -180,7 +183,7 @@ bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. if (!regionCst->addAffineForOpDomain(loop)) - return false; + return nullptr; } else { // Has to be a valid symbol. auto *symbol = accessValueMap.getOperand(i); @@ -198,7 +201,7 @@ bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, if (!regionCst->composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); LLVM_DEBUG(accessValueMap.getAffineMap().dump()); - return false; + return nullptr; } // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which @@ -233,7 +236,7 @@ bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); LLVM_DEBUG(region->getConstraints()->dump()); - return true; + return region; } // TODO(mlir-team): improve/complete this when we have target data. @@ -278,19 +281,20 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, "argument should be either a LoadOp or a StoreOp"); Instruction *opInst = loadOrStoreOp->getInstruction(); - MemRefRegion region; - if (!getMemRefRegion(opInst, /*loopDepth=*/0, ®ion)) + + auto region = getMemRefRegion(opInst, /*loopDepth=*/0); + if (!region) return false; LLVM_DEBUG(llvm::dbgs() << "Memory region"); - LLVM_DEBUG(region.getConstraints()->dump()); + LLVM_DEBUG(region->getConstraints()->dump()); bool outOfBounds = false; unsigned rank = loadOrStoreOp->getMemRefType().getRank(); // For each dimension, check for out of bounds. 
for (unsigned r = 0; r < rank; r++) { - FlatAffineConstraints ucst(*region.getConstraints()); + FlatAffineConstraints ucst(*region->getConstraints()); // Intersect memory region with constraint capturing out of bounds (both out // of upper and out of lower), and check if the constraint system is @@ -310,7 +314,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, } // Check for a negative index. - FlatAffineConstraints lcst(*region.getConstraints()); + FlatAffineConstraints lcst(*region->getConstraints()); std::fill(ineq.begin(), ineq.end(), 0); // d_i <= -1; lcst.addConstantUpperBound(r, -1); @@ -519,8 +523,8 @@ MemRefAccess::MemRefAccess(Instruction *loadOrStoreOpInst) { /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. -unsigned mlir::getNestingDepth(const Instruction &stmt) { - const Instruction *currInst = &stmt; +unsigned mlir::getNestingDepth(const Instruction &inst) { + const Instruction *currInst = &inst; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { if (currInst->isa()) @@ -577,11 +581,16 @@ static Optional getRegionSize(const MemRefRegion ®ion) { Optional mlir::getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace) { + return getMemoryFootprintBytes(*forOp->getBody(), memorySpace); +} + +Optional mlir::getMemoryFootprintBytes(const Block &block, + int memorySpace) { std::vector> regions; // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast(*forOp).walkOps([&](Instruction *opInst) { + const_cast(&block)->walk([&](Instruction *opInst) { if (!opInst->isa() && !opInst->isa()) { // Neither load nor a store op. return; @@ -591,8 +600,8 @@ mlir::getMemoryFootprintBytes(ConstOpPointer forOp, // all regions for a given memref instead of creating one region per // memory op. This way we would be allocating O(num of memref's) sets // instead of O(num of load/store op's). - auto region = std::make_unique(); - if (!getMemRefRegion(opInst, 0, region.get())) { + auto region = getMemRefRegion(opInst, 0); + if (!region) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); // TODO: stop the walk if an error occurred. 
error = true; diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 81e70e2b139..698494144ce 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -256,6 +256,37 @@ Block *Block::splitBlock(iterator splitBefore) { return newBB; } +void Block::walk(std::function callback) { + walk(begin(), end(), callback); +} + +void Block::walk(Block::iterator begin, Block::iterator end, + std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker w(callback); + w.walk(begin, end); +} + +void Block::walkPostOrder(std::function callback) { + struct Walker : public InstWalker { + std::function const &callback; + Walker(std::function const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker v(callback); + v.walkPostOrder(begin(), end()); +} + //===----------------------------------------------------------------------===// // BlockList //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 83ec726ec2a..2bbb32036c2 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -47,7 +47,7 @@ static llvm::cl::opt clFastMemorySpace( llvm::cl::desc("Set fast memory space id for DMA generation"), llvm::cl::cat(clOptionsCategory)); -static llvm::cl::opt clFastMemoryCapacity( +static llvm::cl::opt clFastMemoryCapacity( "dma-fast-mem-capacity", llvm::cl::Hidden, llvm::cl::desc("Set fast memory space capacity in KiB"), llvm::cl::cat(clOptionsCategory)); @@ -57,25 +57,28 @@ namespace { /// Generates DMAs for memref's living in 'slowMemorySpace' into newly created /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. -/// TODO(bondhugula): extend this to store op's. +// TODO(bondhugula): We currently can't generate DMAs correctly when stores are +// strided. Check for strided stores. +// TODO(mlir-team): we don't insert dealloc's for the DMA buffers; this is thus +// natural only for scoped allocations. 
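
For illustration only, a sketch of how the extended createDmaGenerationPass entry point declared in Passes.h above might be instantiated; the memory-space ids and the 32 KiB budget are example values chosen here, not values taken from this change:

  // Copy data accessed in memory space 0 into buffers in space 1, assuming a
  // minimum DMA transfer size of 1024 bytes and a 32 KiB fast-memory capacity.
  FunctionPass *pass = createDmaGenerationPass(
      /*slowMemorySpace=*/0, /*fastMemorySpace=*/1,
      /*minDmaTransferSize=*/1024,
      /*fastMemCapacityBytes=*/32 * 1024);
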
struct DmaGeneration : public FunctionPass { - explicit DmaGeneration(unsigned slowMemorySpace = 0, - unsigned fastMemorySpaceArg = 1, - int minDmaTransferSize = 1024) + explicit DmaGeneration( + unsigned slowMemorySpace = 0, unsigned fastMemorySpace = 1, + int minDmaTransferSize = 1024, + uint64_t fastMemCapacityBytes = std::numeric_limits::max()) : FunctionPass(&DmaGeneration::passID), slowMemorySpace(slowMemorySpace), - minDmaTransferSize(minDmaTransferSize) { - if (clFastMemorySpace.getNumOccurrences() > 0) { - fastMemorySpace = clFastMemorySpace; - } else { - fastMemorySpace = fastMemorySpaceArg; - } - } + fastMemorySpace(fastMemorySpace), + minDmaTransferSize(minDmaTransferSize), + fastMemCapacityBytes(fastMemCapacityBytes) {} PassResult runOnFunction(Function *f) override; - void runOnAffineForOp(OpPointer forOp); + bool runOnBlock(Block *block, uint64_t consumedCapacityBytes); + uint64_t runOnBlock(Block::iterator begin, Block::iterator end); - bool generateDma(const MemRefRegion ®ion, OpPointer forOp, - uint64_t *sizeInBytes); + bool generateDma(const MemRefRegion ®ion, Block *block, + Block::iterator begin, Block::iterator end, + uint64_t *sizeInBytes, Block::iterator *nBegin, + Block::iterator *nEnd); // List of memory regions to DMA for. We need a map vector to have a // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here @@ -93,6 +96,8 @@ struct DmaGeneration : public FunctionPass { unsigned fastMemorySpace; // Minimum DMA transfer size supported by the target in bytes. const int minDmaTransferSize; + // Capacity of the faster memory space. + uint64_t fastMemCapacityBytes; // Constant zero index to avoid too many duplicates. Value *zeroIndex = nullptr; @@ -110,9 +115,10 @@ char DmaGeneration::passID = 0; /// TODO(bondhugula): extend this to store op's. FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, unsigned fastMemorySpace, - int minDmaTransferSize) { - return new DmaGeneration(slowMemorySpace, fastMemorySpace, - minDmaTransferSize); + int minDmaTransferSize, + uint64_t fastMemCapacityBytes) { + return new DmaGeneration(slowMemorySpace, fastMemorySpace, minDmaTransferSize, + fastMemCapacityBytes); } // Info comprising stride and number of elements transferred every stride. @@ -192,26 +198,48 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, return true; } -// Creates a buffer in the faster memory space for the specified region; -// generates a DMA from the lower memory space to this one, and replaces all -// loads to load from that buffer. Returns false if DMAs could not be generated -// due to yet unimplemented cases. -bool DmaGeneration::generateDma(const MemRefRegion ®ion, - OpPointer forOp, - uint64_t *sizeInBytes) { - auto *forInst = forOp->getInstruction(); +static void emitNoteForBlock(const Block &block, const Twine &message) { + auto *inst = block.getContainingInst(); + if (!inst) { + block.getFunction()->emitNote(message); + } else { + inst->emitNote(message); + } +} + +/// Creates a buffer in the faster memory space for the specified region; +/// generates a DMA from the lower memory space to this one, and replaces all +/// loads to load from that buffer. Returns false if DMAs could not be generated +/// due to yet unimplemented cases. `begin` and `end` specify the insertion +/// points where the incoming DMAs and outgoing DMAs, respectively, should +/// be inserted (the insertion happens right before the insertion point). 
Since +/// `begin` can itself be invalidated due to the memref rewriting done from this +/// method, the output argument `nBegin` is set to its replacement (set +/// to `begin` if no invalidation happens). Since outgoing DMAs are inserted at +/// `end`, the output argument `nEnd` is set to the one following the original +/// end (since the latter could have been invalidated/replaced). `sizeInBytes` +/// is set to the size of the DMA buffer allocated. +bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, + Block::iterator begin, Block::iterator end, + uint64_t *sizeInBytes, Block::iterator *nBegin, + Block::iterator *nEnd) { + *nBegin = begin; + *nEnd = end; + + if (begin == end) + return true; // DMAs for read regions are going to be inserted just before the for loop. - FuncBuilder prologue(forInst); + FuncBuilder prologue(block, begin); // DMAs for write regions are going to be inserted just after the for loop. - FuncBuilder epilogue(forInst->getBlock(), - std::next(Block::iterator(forInst))); + FuncBuilder epilogue(block, end); FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. - FuncBuilder top(forInst->getFunction()); + auto *func = block->getFunction(); + FuncBuilder top(func); - auto loc = forInst->getLoc(); + auto loc = region.loc; auto *memref = region.memref; auto memRefType = memref->getType().cast(); @@ -310,21 +338,17 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, auto fastMemRefType = top.getMemRefType( fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace); - LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: "); - LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n"); - // Create the fast memory space buffer just before the 'for' instruction. fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); // Record it. fastBufferMap[memref] = fastMemRef; // fastMemRefType is a constant shaped memref. *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue(); - LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type "; + LLVM_DEBUG(emitNoteForBlock(*block, "Creating DMA buffer of type "); fastMemRefType.dump(); llvm::dbgs() - << " and size " << Twine(llvm::divideCeil(*sizeInBytes, 1024)) + << " and of size " << Twine(llvm::divideCeil(*sizeInBytes, 1024)) << " KiB\n";); - } else { // Reuse the one already created. fastMemRef = fastBufferMap[memref]; @@ -336,9 +360,6 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, auto numElementsSSA = top.create(loc, numElements.getValue()); - // TODO(bondhugula): check for transfer sizes not being a multiple of - // minDmaTransferSize and handle them appropriately. - SmallVector strideInfos; getMultiLevelStrides(region, fastBufferShape, &strideInfos); @@ -357,6 +378,12 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, top.create(loc, strideInfos[0].numEltPerStride); } + // Record the last instruction just before the point where we insert the + // outgoing DMAs. We later do the memref replacement later only in [begin, + // postDomFilter] so that the original memref's in the DMA ops themselves + // don't get replaced. + auto postDomFilter = std::prev(end); + if (!region.isWrite()) { // DMA non-blocking read from original buffer to fast buffer. b->create(loc, memref, memIndices, fastMemRef, bufIndices, @@ -364,9 +391,13 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, numEltPerStride); } else { // DMA non-blocking write from fast buffer to the original memref. 
- b->create(loc, fastMemRef, bufIndices, memref, memIndices, - numElementsSSA, tagMemRef, zeroIndex, stride, - numEltPerStride); + auto op = b->create(loc, fastMemRef, bufIndices, memref, + memIndices, numElementsSSA, tagMemRef, + zeroIndex, stride, numEltPerStride); + // Since new ops are being appended (for outgoing DMAs), adjust the end to + // mark end of range of the original. + if (*nEnd == end) + *nEnd = Block::iterator(op->getInstruction()); } // Matching DMA wait to block on completion; tag always has a 0 index. @@ -389,45 +420,151 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forOp' are replaced. + + // Record the begin since it may be invalidated by memref replacement. + Block::iterator prev; + bool wasAtStartOfBlock = (begin == block->begin()); + if (!wasAtStartOfBlock) + prev = std::prev(begin); + + // *Only* those uses within the range [begin, end) of 'block' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forOp->getBody()->begin()); + /*domInstFilter=*/&*begin, + /*postDomInstFilter=*/&*postDomFilter); + + *nBegin = wasAtStartOfBlock ? block->begin() : std::next(prev); + return true; } -// TODO(bondhugula): make this run on a Block instead of a 'for' inst. -void DmaGeneration::runOnAffineForOp(OpPointer forOp) { - // For now (for testing purposes), we'll run this on the outermost among 'for' - // inst's with unit stride, i.e., right at the top of the tile if tiling has - // been done. In the future, the DMA generation has to be done at a level - // where the generated data fits in a higher level of the memory hierarchy; so - // the pass has to be instantiated with additional information that we aren't - // provided with at the moment. - if (forOp->getStep() != 1) { - auto *forBody = forOp->getBody(); - if (forBody->empty()) - return; - if (auto innerFor = - cast(forBody->front()).dyn_cast()) { - runOnAffineForOp(innerFor); +/// Generate DMAs for this block. The block is partitioned into separate +/// `regions`; each region is either a sequence of one or more instructions +/// starting and ending with a load or store op, or just a loop (which could +/// have other loops nested within). Returns false on an error, true otherwise. +bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { + block->dump(); + if (block->empty()) + return true; + + uint64_t priorConsumedCapacityBytes = consumedCapacityBytes; + + // Every loop in the block starts and ends a region. A contiguous sequence of + // operation instructions starting and ending with a load/store op is also + // identified as a region. Straightline code (contiguous chunks of operation + // instructions) are always assumed to not exhaust memory. As a result, this + // approach is conservative in some cases at the moment, we do a check later + // and report an error with location info. + // TODO(bondhugula): An 'if' instruction is being treated similar to an + // operation instruction. 'if''s could have 'for's in them; treat them + // separately. + + // Get to the first load, store, or for op. 
+ auto curBegin = + std::find_if(block->begin(), block->end(), [&](const Instruction &inst) { + return inst.isa() || inst.isa() || + inst.isa(); + }); + + for (auto it = curBegin; it != block->end(); ++it) { + if (auto forOp = it->dyn_cast()) { + // We'll assume for now that loops with steps are tiled loops, and so DMAs + // are not performed for that depth, but only further inside. + // If the memory footprint of the 'for' loop is higher than fast memory + // capacity (when provided), we recurse to DMA at an inner level until + // we find a depth at which footprint fits in the capacity. If the + // footprint can't be calcuated, we assume for now it fits. + + // Returns true if the footprint is known to exceed capacity. + auto exceedsCapacity = [&](OpPointer forOp) { + Optional footprint; + return ((footprint = getMemoryFootprintBytes(forOp, 0)).hasValue() && + consumedCapacityBytes + + static_cast(footprint.getValue()) > + fastMemCapacityBytes); + }; + + if (forOp->getStep() != 1 || exceedsCapacity(forOp)) { + // We'll split and do the DMAs one or more levels inside for forInst + consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); + // Recurse onto the body of this loop. + runOnBlock(forOp->getBody(), consumedCapacityBytes); + // The next region starts right after the 'for' instruction. + curBegin = std::next(it); + } else { + // We have enough capacity, i.e., DMAs will be computed for the portion + // of the block until 'it', and for the 'for' loop. For the latter, they + // are placed just before this loop (for incoming DMAs) and right after + // (for outgoing ones). + consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); + + // Inner loop DMAs have their own scope - we don't thus update consumed + // capacity. The footprint check above guarantees this inner loop's + // footprint fits. + runOnBlock(/*begin=*/it, /*end=*/std::next(it)); + curBegin = std::next(it); + } + } else if (!it->isa() && !it->isa()) { + consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); + curBegin = std::next(it); } - return; } - // DMAs will be generated for this depth, i.e., for all data accessed by this - // loop. - unsigned dmaDepth = getNestingDepth(*forOp->getInstruction()); + // Generate the DMA for the final region. + if (curBegin != block->end()) { + // Can't be a terminator because it would have been skipped above. + assert(!curBegin->isTerminator() && "can't be a terminator"); + consumedCapacityBytes += + runOnBlock(/*begin=*/curBegin, /*end=*/block->end()); + } + + if (llvm::DebugFlag) { + uint64_t thisBlockDmaSizeBytes = + consumedCapacityBytes - priorConsumedCapacityBytes; + if (thisBlockDmaSizeBytes > 0) { + emitNoteForBlock( + *block, + Twine(llvm::divideCeil(thisBlockDmaSizeBytes, 1024)) + + " KiB of DMA buffers in fast memory space for this block\n"); + } + } + + if (consumedCapacityBytes > fastMemCapacityBytes) { + StringRef str = "Total size of all DMA buffers' for this block " + "exceeds fast memory capacity\n"; + if (auto *inst = block->getContainingInst()) + inst->emitError(str); + else + block->getFunction()->emitError(str); + return false; + } + + return true; +} + +/// Generates DMAs for a contiguous sequence of instructions in `block` in the +/// iterator range [begin, end). Returns the total size of the DMA buffers used. 
+uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { + if (begin == end) + return 0; + + assert(begin->getBlock() == std::prev(end)->getBlock() && + "Inconsistent args"); + + Block *block = begin->getBlock(); + + // DMAs will be generated for this depth, i.e., symbolic in all loops + // surrounding the region of this block. + unsigned dmaDepth = getNestingDepth(*begin); readRegions.clear(); writeRegions.clear(); fastBufferMap.clear(); - // Walk this 'for' instruction to gather all memory regions. - forOp->walkOps([&](OperationInst *opInst) { - // Gather regions to promote to buffers in faster memory space. - // TODO(bondhugula): handle store op's; only load's handled for now. + // Walk this range of instructions to gather all memory regions. + block->walk(begin, end, [&](OperationInst *opInst) { + // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = opInst->dyn_cast()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; @@ -439,18 +576,15 @@ void DmaGeneration::runOnAffineForOp(OpPointer forOp) { return; } - // TODO(bondhugula): eventually, we need to be performing a union across - // all regions for a given memref instead of creating one region per - // memory op. This way we would be allocating O(num of memref's) sets - // instead of O(num of load/store op's). - auto region = std::make_unique(); - if (!getMemRefRegion(opInst, dmaDepth, region.get())) { + // Compute the MemRefRegion accessed. + auto region = getMemRefRegion(opInst, dmaDepth); + if (!region) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region: semi-affine maps?\n"); LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG( - forOp->emitError("Non-constant memref sizes not yet supported")); + opInst->emitError("Non-constant memref sizes not yet supported")); return; } } @@ -477,12 +611,12 @@ void DmaGeneration::runOnAffineForOp(OpPointer forOp) { return false; // Perform a union with the existing region. 
- if (!(*it).second->unionBoundingBox(*region)) { + if (!it->second->unionBoundingBox(*region)) { LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed" + << "Memory region bounding box failed; " "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { - LLVM_DEBUG(forOp->emitError( + LLVM_DEBUG(opInst->emitError( "Non-constant memref sizes not yet supported")); } } @@ -500,48 +634,59 @@ void DmaGeneration::runOnAffineForOp(OpPointer forOp) { } }); - uint64_t totalSizeInBytes = 0; - + uint64_t totalDmaBuffersSizeInBytes = 0; bool ret = true; auto processRegions = [&](const SmallMapVector, 4> ®ions) { for (const auto ®ionEntry : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes); - if (iRet) - totalSizeInBytes += sizeInBytes; + Block::iterator nBegin, nEnd; + bool iRet = generateDma(*regionEntry.second, block, begin, end, + &sizeInBytes, &nBegin, &nEnd); + if (iRet) { + begin = nBegin; + end = nEnd; + totalDmaBuffersSizeInBytes += sizeInBytes; + } ret = ret & iRet; } }; processRegions(readRegions); processRegions(writeRegions); + if (!ret) { - forOp->emitError("DMA generation failed for one or more memref's\n"); - return; + begin->emitError( + "DMA generation failed for one or more memref's in this block\n"); + return totalDmaBuffersSizeInBytes; } - LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024)) - << " KiB of DMA buffers in fast memory space\n";); - - if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) { - // TODO(bondhugula): selecting the DMA depth so that the result DMA buffers - // fit in fast memory is a TODO - not complex. - forOp->emitError( - "Total size of all DMA buffers' exceeds memory capacity\n"); + + // For a range of operation instructions, a note will be emitted at the + // caller. + OpPointer forOp; + if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { + forOp->emitNote( + Twine(llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024)) + + " KiB of DMA buffers in fast memory space for this block\n"); } + + return totalDmaBuffersSizeInBytes; } PassResult DmaGeneration::runOnFunction(Function *f) { FuncBuilder topBuilder(f); - zeroIndex = topBuilder.create(f->getLoc(), 0); + if (clFastMemorySpace.getNumOccurrences() > 0) { + fastMemorySpace = clFastMemorySpace; + } + + if (clFastMemoryCapacity.getNumOccurrences() > 0) { + fastMemCapacityBytes = clFastMemoryCapacity * 1024; + } + for (auto &block : *f) { - for (auto &inst : block) { - if (auto forOp = cast(inst).dyn_cast()) { - runOnAffineForOp(forOp); - } - } + runOnBlock(&block, /*consumedCapacityBytes=*/0); } // This function never leaves the IR in an invalid state. return success(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 5091e3ceb33..162e0e3b7f6 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -929,8 +929,7 @@ static Value *createPrivateMemRef(OpPointer forOp, unsigned rank = oldMemRefType.getRank(); // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. - MemRefRegion region; - getMemRefRegion(srcStoreOpInst, dstLoopDepth, ®ion); + auto region = getMemRefRegion(srcStoreOpInst, dstLoopDepth); SmallVector newShape; std::vector> lbs; SmallVector lbDivisors; @@ -938,11 +937,11 @@ static Value *createPrivateMemRef(OpPointer forOp, // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 
Optional numElements = - region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); + region->getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); assert(numElements.hasValue() && "non-constant number of elts in local buffer"); - const FlatAffineConstraints *cst = region.getConstraints(); + const FlatAffineConstraints *cst = region->getConstraints(); // 'outerIVs' holds the values that this memory region is symbolic/paramteric // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 4191a9cc279..e6ce273b532 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -178,9 +178,8 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // is trivially loading from a single location at that depth; so there // isn't a need to call isRangeOneToOne. if (getNestingDepth(*storeOpInst) < loadOpDepth) { - MemRefRegion region; - getMemRefRegion(loadOpInst, nsLoops, ®ion); - if (!region.getConstraints()->isRangeOneToOne( + auto region = getMemRefRegion(loadOpInst, nsLoops); + if (!region->getConstraints()->isRangeOneToOne( /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) break; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 819f1a59b6f..732062a8b97 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -48,7 +48,8 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - const Instruction *domInstFilter) { + const Instruction *domInstFilter, + const Instruction *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -66,9 +67,14 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, newMemRef->getType().cast().getElementType()); std::unique_ptr domInfo; + std::unique_ptr postDomInfo; if (domInstFilter) domInfo = std::make_unique(domInstFilter->getFunction()); + if (postDomInstFilter) + postDomInfo = + std::make_unique(postDomInstFilter->getFunction()); + // The ops where memref replacement succeeds are replaced with new ones. SmallVector opsToErase; @@ -81,6 +87,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) continue; + // Skip this use if it's not post-dominated by postDomInstFilter. + if (postDomInstFilter && + !postDomInfo->postDominates(postDomInstFilter, opInst)) + continue; + // Check if the memref was used in a non-deferencing context. It is fine for // the memref to be used in a non-deferencing way outside of the region // where this replacement is happening. @@ -167,7 +178,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, res->replaceAllUsesWith(repOp->getResult(r++)); } // Collect and erase at the end since one of these op's could be - // domInstFilter! + // domInstFilter or postDomInstFilter as well! 
opsToErase.push_back(opInst); } diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 9096fe8b097..cdc7441b14e 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -262,7 +262,7 @@ func @dma_unknown_size(%arg0: memref) { // size -- not yet implemented. // CHECK: %2 = load %arg0[%i0, %i1] : memref load %arg0[%i, %j] : memref - // expected-error@-6 {{DMA generation failed for one or more memref's}} + // expected-error@-6 {{DMA generation failed for one or more memref's in this block}} } } return @@ -282,7 +282,7 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // not yet implemented. // CHECK: %5 = load %arg0[%2, %3, %4] : memref<1024x1024x1024xf32> %v = load %arg0[%idx, %idy, %idz] : memref<1024 x 1024 x 1024 x f32> - // expected-error@-10 {{DMA generation failed for one or more memref's}} + // expected-error@-10 {{DMA generation failed for one or more memref's in this block}} } } } @@ -359,3 +359,73 @@ func @multi_load_store_union() { // CHECK-NEXT: dma_wait %3[%c0], %c170372 : memref<1xi32> // CHECK-NEXT: return // CHECK-NEXT:} + +// ----- + +// CHECK-DAG: [[MAP_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) + +// CHECK-LABEL: func @dma_loop_straightline_interspersed() { +func @dma_loop_straightline_interspersed() { + %c0 = constant 0 : index + %c255 = constant 255 : index + %A = alloc() : memref<256 x f32> + %v = load %A[%c0] : memref<256 x f32> + for %i = 1 to 255 { + load %A[%i] : memref<256 x f32> + } + %l = load %A[%c255] : memref<256 x f32> + store %l, %A[%c0] : memref<256 x f32> + return +} +// There are three regions here - the 'load' preceding the loop, the loop +// itself, and the instructions appearing after the loop. +// CHECK: %0 = alloc() : memref<256xf32> +// CHECK-NEXT: %1 = alloc() : memref<1xf32, 1> +// CHECK-NEXT: %2 = alloc() : memref<1xi32> +// CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c1_1, %2[%c0] : memref<256xf32>, memref<1xf32, 1>, memref<1xi32> +// CHECK-NEXT: dma_wait %2[%c0], %c1_1 : memref<1xi32> +// CHECK-NEXT: %3 = load %1[%c0_2] : memref<1xf32, 1> +// CHECK-NEXT: %4 = alloc() : memref<254xf32, 1> +// CHECK-NEXT: %5 = alloc() : memref<1xi32> +// CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 1>, memref<1xi32> +// CHECK-NEXT: dma_wait %5[%c0], %c254 : memref<1xi32> +// CHECK-NEXT: for %i0 = 1 to 255 { +// CHECK-NEXT: %6 = affine_apply [[MAP_MINUS_ONE]](%i0) +// CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 1> +// CHECK-NEXT: } +// CHECK-NEXT: %8 = alloc() : memref<256xf32, 1> +// CHECK-NEXT: %9 = alloc() : memref<1xi32> +// CHECK-NEXT: dma_start %0[%c0], %8[%c0], %c256, %9[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xi32> +// CHECK-NEXT: dma_wait %9[%c0], %c256 : memref<1xi32> +// CHECK-NEXT: %10 = alloc() : memref<1xi32> +// CHECK-NEXT: %11 = load %8[%c255] : memref<256xf32, 1> +// CHECK-NEXT: store %11, %8[%c0_2] : memref<256xf32, 1> +// CHECK-NEXT: dma_start %8[%c0], %0[%c0], %c1, %10[%c0] : memref<256xf32, 1>, memref<256xf32>, memref<1xi32> +// CHECK-NEXT: dma_wait %10[%c0], %c1 : memref<1xi32> +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @dma_mixed_loop_blocks() { +func @dma_mixed_loop_blocks() { + %c0 = constant 0 : index + %A = alloc() : memref<256 x 256 x vector<8 x f32>> + for %i = 0 to 256 { + %v = load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> + "foo"(%v) : (vector<8 x f32>) -> () + for %j = 0 to 256 { + %w = load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> + "bar"(%w) : 
(vector<8 x f32>) -> () + } + } + return +} +// CHECK-DAG: [[MEM:%[0-9]+]] = alloc() : memref<256x256xvector<8xf32>> +// CHECK-DAG: [[BUF:%[0-9]+]] = alloc() : memref<256x256xvector<8xf32>, 1> +// CHECK-DAG: [[TAG:%[0-9]+]] = alloc() : memref<1xi32> +// CHECK: dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], %c65536, [[TAG]][%c0] : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 1>, memref<1xi32> +// CHECK-NEXT: dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> +// CHECK-NEXT: for %i0 = 0 to 256 { +// CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 1> +// CHECK: for %i1 = 0 to 256 { +// CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 1> -- cgit v1.2.3 From a3d9ccaecbea042b59073ec34c91a343c1903a5c Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 4 Feb 2019 10:30:45 -0800 Subject: Replace the walkOps/visitOperationInst variants from the InstWalkers with the Instruction variants. PiperOrigin-RevId: 232322030 --- mlir/include/mlir/AffineOps/AffineOps.h | 4 +- mlir/include/mlir/Analysis/NestedMatcher.h | 2 +- mlir/include/mlir/IR/Block.h | 8 ++-- mlir/include/mlir/IR/Function.h | 6 +-- mlir/include/mlir/IR/InstVisitor.h | 53 +++------------------- .../mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 9 ++-- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/Block.cpp | 34 ++++++-------- mlir/lib/IR/Function.cpp | 32 ++----------- mlir/lib/Transforms/ComposeAffineMaps.cpp | 4 +- mlir/lib/Transforms/ConstantFold.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 4 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- 25 files changed, 61 insertions(+), 133 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 12e4589405f..46bb91c1bca 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -182,11 +182,11 @@ public: /// Walk the operation instructions in the 'for' instruction in preorder, /// calling the callback for each operation. - void walkOps(std::function callback); + void walk(std::function callback); /// Walk the operation instructions in the 'for' instruction in postorder, /// calling the callback for each operation. 
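
For illustration, a sketch of client code after this rename, written against the unified walk entry point; the surrounding OpPointer<AffineForOp> value 'forOp' and the use of LoadOp are assumptions made for the example:

  // Count the load instructions nested anywhere under 'forOp'.
  unsigned numLoads = 0;
  forOp->walk([&](Instruction *inst) {
    if (inst->isa<LoadOp>())
      ++numLoads;
  });
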
- void walkOpsPostOrder(std::function callback); + void walkPostOrder(std::function callback); private: friend class Instruction; diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 5c040ecbe08..aba0e11ab91 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -127,7 +127,7 @@ private: struct State : public InstWalker { State(NestedPattern &pattern, SmallVectorImpl *matches) : pattern(pattern), matches(matches) {} - void visitOperationInst(Instruction *opInst) { + void visitInstruction(Instruction *opInst) { pattern.matchOne(opInst, matches); } diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 479f15d1603..a6a29fb84ea 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -311,12 +311,12 @@ public: return &Block::instructions; } - /// Walk the operation instructions of this block in preorder, calling the - /// callback for each operation. + /// Walk the instructions of this block in preorder, calling the callback for + /// each operation. void walk(std::function callback); - /// Walk the operation instructions in this block in postorder, calling the - /// callback for each operation. + /// Walk the instructions in this block in postorder, calling the callback for + /// each operation. void walkPostOrder(std::function callback); /// Walk the operation instructions in the specified [begin, end) range of diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 3876d750a28..f483ff46259 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -117,13 +117,11 @@ public: /// Walk the instructions in the function in preorder, calling the callback /// for each instruction or operation. - void walkInsts(std::function callback); - void walkOps(std::function callback); + void walk(std::function callback); /// Walk the instructions in the function in postorder, calling the callback /// for each instruction or operation. - void walkInstsPostOrder(std::function callback); - void walkOpsPostOrder(std::function callback); + void walkPostOrder(std::function callback); //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 7b74c69ceef..e11b7350894 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -67,34 +67,6 @@ #include "mlir/IR/Instruction.h" namespace mlir { - -/// Base class for instruction visitors. -template class InstVisitor { - //===--------------------------------------------------------------------===// - // Interface code - This is the public interface of the InstVisitor that you - // use to visit instructions. - -public: - // Function to visit a instruction. - RetTy visit(Instruction *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - return static_cast(this)->visitOperationInst(s); - } - - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular instruction type. - // The default behavior is to generalize the instruction type to its subtype - // and try visiting the subtype. All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. 
- // - - // When visiting a for inst, if inst, or an operation inst directly, these - // methods get called to indicate when transitioning into a new unit. - void visitOperationInst(Instruction *opInst) {} -}; - /// Base class for instruction walkers. A walker can traverse depth first in /// pre-order or post order. The walk methods without a suffix do a pre-order /// traversal while those that traverse in post order have a PostOrder suffix. @@ -127,36 +99,26 @@ public: static_cast(this)->walkPostOrder(it->begin(), it->end()); } - void walkOpInst(Instruction *opInst) { - static_cast(this)->visitOperationInst(opInst); - for (auto &blockList : opInst->getBlockLists()) - for (auto &block : blockList) - static_cast(this)->walk(block.begin(), block.end()); - } - - void walkOpInstPostOrder(Instruction *opInst) { - for (auto &blockList : opInst->getBlockLists()) - for (auto &block : blockList) - static_cast(this)->walkPostOrder(block.begin(), - block.end()); - static_cast(this)->visitOperationInst(opInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, "Must pass the derived type to this template!"); static_cast(this)->visitInstruction(s); - return static_cast(this)->walkOpInst(s); + for (auto &blockList : s->getBlockLists()) + for (auto &block : blockList) + static_cast(this)->walk(block.begin(), block.end()); } // Function to walk a instruction in post order DFS. RetTy walkPostOrder(Instruction *s) { static_assert(std::is_base_of::value, "Must pass the derived type to this template!"); + for (auto &blockList : s->getBlockLists()) + for (auto &block : blockList) + static_cast(this)->walkPostOrder(block.begin(), + block.end()); static_cast(this)->visitInstruction(s); - return static_cast(this)->walkOpInstPostOrder(s); } //===--------------------------------------------------------------------===// @@ -170,7 +132,6 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. When using RetTy, all of these // need to be overridden. 
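
For illustration, a minimal sketch of a walker that overrides the visitInstruction hook described above; the walker name and the op being counted are assumptions made for the example:

  // Counts affine_apply ops reachable from a function.
  struct ApplyCounter : public InstWalker<ApplyCounter> {
    unsigned count = 0;
    void visitInstruction(Instruction *inst) {
      if (inst->isa<AffineApplyOp>())
        ++count;
    }
  };
  // Usage: ApplyCounter counter; counter.walk(f); // 'f' is a Function *.
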
- void visitOperationInst(Instruction *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index c6f810a215c..c5be3322f43 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -144,7 +144,7 @@ PassResult MLPatternLoweringPass::runOnFunction(Function *f) { MLFuncLoweringRewriter rewriter(&builder); llvm::SmallVector ops; - f->walkOps([&ops](Instruction *inst) { ops.push_back(inst); }); + f->walk([&ops](Instruction *inst) { ops.push_back(inst); }); for (Instruction *inst : ops) { for (const auto &pattern : patterns) { diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 682a8e4f1ed..2e657cf7e17 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -410,27 +410,26 @@ bool AffineForOp::matchingBoundOperandList() const { return true; } -void AffineForOp::walkOps(std::function callback) { +void AffineForOp::walk(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(Instruction *opInst) { callback(opInst); } + void visitInstruction(Instruction *opInst) { callback(opInst); } }; Walker w(callback); w.walk(getInstruction()); } -void AffineForOp::walkOpsPostOrder( - std::function callback) { +void AffineForOp::walkPostOrder(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(Instruction *opInst) { callback(opInst); } + void visitInstruction(Instruction *opInst) { callback(opInst); } }; Walker v(callback); diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index b2549910a17..6ea47a20f60 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -46,7 +46,7 @@ struct MemRefDependenceCheck : public FunctionPass, PassResult runOnFunction(Function *f) override; - void visitOperationInst(Instruction *opInst) { + void visitInstruction(Instruction *opInst) { if (opInst->isa() || opInst->isa()) { loadsAndStores.push_back(opInst); } diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index e891be68fd3..1be6b90985f 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -45,7 +45,7 @@ char LowerEDSCTestPass::passID = 0; #include "mlir/EDSC/reference-impl.inc" PassResult LowerEDSCTestPass::runOnFunction(Function *f) { - f->walkOps([](OperationInst *op) { + f->walk([](OperationInst *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); if (!opName) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 36a4b8e3b5e..7b59321c815 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -263,7 +263,7 @@ void ModuleState::initialize(const Module *module) { for (auto &fn : *module) { visitType(fn.getType()); - const_cast(fn).walkInsts( + const_cast(fn).walk( [&](Instruction *op) { ModuleState::visitInstruction(op); }); } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 698494144ce..0e1d21ffed8 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -256,31 +256,31 @@ Block *Block::splitBlock(iterator splitBefore) { return newBB; 
} -void Block::walk(std::function callback) { +void Block::walk(std::function callback) { walk(begin(), end(), callback); } void Block::walk(Block::iterator begin, Block::iterator end, - std::function callback) { + std::function callback) { struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) + std::function const &callback; + Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opInst) { callback(opInst); } + void visitInstruction(Instruction *opInst) { callback(opInst); } }; Walker w(callback); w.walk(begin, end); } -void Block::walkPostOrder(std::function callback) { +void Block::walkPostOrder(std::function callback) { struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) + std::function const &callback; + Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opInst) { callback(opInst); } + void visitInstruction(Instruction *opInst) { callback(opInst); } }; Walker v(callback); @@ -338,19 +338,15 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper, BlockAndValueMapping &mapper; Walker(BlockAndValueMapping &mapper) : mapper(mapper) {} - /// Remap the instruction operands. - void visitInstruction(Instruction *inst) { + /// Remap the instruction and successor block operands. + void visitInstruction(OperationInst *inst) { for (auto &instOp : inst->getInstOperands()) if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) instOp.set(mappedOp); - } - // Remap the successor block operands. - void visitOperationInst(OperationInst *opInst) { - if (!opInst->isTerminator()) - return; - for (auto &succOp : opInst->getBlockOperands()) - if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) - succOp.set(mappedOp); + if (inst->isTerminator()) + for (auto &succOp : inst->getBlockOperands()) + if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) + succOp.set(mappedOp); } }; diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 35ac5459ad6..3a263fb13f9 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -214,7 +214,7 @@ void Function::addEntryBlock() { entry->addArguments(type.getInputs()); } -void Function::walkInsts(std::function callback) { +void Function::walk(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) @@ -227,39 +227,13 @@ void Function::walkInsts(std::function callback) { v.walk(this); } -void Function::walkOps(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walk(this); -} - -void Function::walkInstsPostOrder(std::function callback) { +void Function::walkPostOrder(std::function callback) { struct Walker : public InstWalker { std::function const &callback; Walker(std::function const &callback) : callback(callback) {} - void visitOperationInst(Instruction *inst) { callback(inst); } - }; - - Walker v(callback); - v.walkPostOrder(this); -} - -void Function::walkOpsPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } + void 
visitInstruction(Instruction *inst) { callback(inst); } }; Walker v(callback); diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index d7327d997c2..4f960ea73af 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -48,7 +48,7 @@ namespace { struct ComposeAffineMaps : public FunctionPass, InstWalker { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitOperationInst(OperationInst *opInst); + void visitInstruction(OperationInst *opInst); SmallVector, 8> affineApplyOps; @@ -68,7 +68,7 @@ static bool affineApplyOp(const Instruction &inst) { return opInst.isa(); } -void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) { +void ComposeAffineMaps::visitInstruction(OperationInst *opInst) { if (auto afOp = opInst->dyn_cast()) { affineApplyOps.push_back(afOp); } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 9c20e79180a..859d0012fac 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -37,7 +37,7 @@ struct ConstantFold : public FunctionPass, InstWalker { bool foldOperation(OperationInst *op, SmallVectorImpl &existingConstants); - void visitOperationInst(OperationInst *inst); + void visitInstruction(OperationInst *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +49,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitOperationInst(OperationInst *op) { +void ConstantFold::visitInstruction(OperationInst *op) { // If this operation is an AffineForOp, then fold the bounds. 
if (auto forOp = op->dyn_cast()) { constantFoldBounds(forOp); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 162e0e3b7f6..304331320ac 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -118,7 +118,7 @@ public: SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitOperationInst(OperationInst *opInst) { + void visitInstruction(OperationInst *opInst) { if (opInst->isa()) forOps.push_back(opInst->cast()); else if (opInst->getNumBlockLists() != 0) @@ -619,7 +619,7 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitOperationInst(OperationInst *opInst) { + void visitInstruction(OperationInst *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 86e913bd71f..9c9952d31ca 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkOpInstPostOrder(OperationInst *opInst) { + bool walkPostOrder(OperationInst *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) @@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitOperationInst(OperationInst *opInst) { + void visitInstruction(OperationInst *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7327a37ee3a..d87f9d5dc14 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -196,7 +196,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkOpInst(forInst); + jbg.walk(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 24ca4e95082..08c8188fada 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -615,7 +615,7 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. - function->walkInsts([&](Instruction *inst) { + function->walk([&](Instruction *inst) { auto op = cast(inst); if (op->isa() || op->isa() || op->isa()) diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index e6ce273b532..b9386c384dd 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -75,7 +75,7 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker { PassResult runOnFunction(Function *f) override; - void visitOperationInst(OperationInst *opInst); + void visitInstruction(OperationInst *opInst); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; @@ -100,7 +100,7 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. 
-void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { +void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { OperationInst *lastWriteStoreOp = nullptr; auto loadOp = opInst->dyn_cast(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 2e083bbfd79..8d13800160d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkOpsPostOrder([&](OperationInst *opInst) { + f->walkPostOrder([&](OperationInst *opInst) { if (auto forOp = opInst->dyn_cast()) forOps.push_back(forOp); }); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 80b2de130e7..a9fcfc5bd11 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { } PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkOps([&](OperationInst *opInst) { + f->walk([&](OperationInst *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 6e1d5ff2d11..c5e42b622ed 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -39,7 +39,7 @@ PassResult StripDebugInfo::runOnFunction(Function *f) { // Strip the debug info from the function and its instructions. f->setLoc(unknownLoc); - f->walkInsts([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + f->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); return success(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 019cbbae063..790f971bb58 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -38,7 +38,7 @@ public: worklist.reserve(64); // Add all operations to the worklist. - fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); }); + fn->walk([&](OperationInst *inst) { addToWorklist(inst); }); } /// Perform the rewrites. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 99079119dab..153557de04a 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkOpsPostOrder([](OperationInst *inst) { + f->walkPostOrder([](OperationInst *inst) { if (auto forOp = inst->dyn_cast()) promoteIfSingleIteration(forOp); }); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 732062a8b97..879a4f4b585 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -362,8 +362,8 @@ void mlir::remapFunctionAttrs( Function &fn, const DenseMap &remappingTable) { // Look at all instructions in a Function. 
- fn.walkOps( - [&](OperationInst *inst) { remapFunctionAttrs(*inst, remappingTable); }); + fn.walk( + [&](Instruction *inst) { remapFunctionAttrs(*inst, remappingTable); }); } void mlir::remapFunctionAttrs( -- cgit v1.2.3 From b499277fb648c44907443ce44ec6bcc6b7596039 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 4 Feb 2019 10:38:47 -0800 Subject: Remove remaining usages of OperationInst in lib/Transforms. PiperOrigin-RevId: 232323671 --- mlir/lib/Transforms/CSE.cpp | 43 +++---- mlir/lib/Transforms/ComposeAffineMaps.cpp | 10 +- mlir/lib/Transforms/ConstantFold.cpp | 8 +- mlir/lib/Transforms/DialectConversion.cpp | 31 +++-- mlir/lib/Transforms/DmaGeneration.cpp | 5 +- mlir/lib/Transforms/LoopFusion.cpp | 134 ++++++++++----------- mlir/lib/Transforms/LoopTiling.cpp | 5 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 10 +- mlir/lib/Transforms/LowerAffine.cpp | 12 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 5 +- mlir/lib/Transforms/MaterializeVectors.cpp | 44 ++++--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 30 +++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 43 +++---- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 38 +++--- mlir/lib/Transforms/Utils/LoopUtils.cpp | 9 +- mlir/lib/Transforms/Utils/Utils.cpp | 17 ++- .../Vectorization/VectorizerTestPass.cpp | 27 ++--- mlir/lib/Transforms/Vectorize.cpp | 91 +++++++------- 20 files changed, 251 insertions(+), 317 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index e471b6792c5..63a676d7b52 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -39,10 +39,10 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const OperationInst *op) { +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const Instruction *op) { // Hash the operations based upon their: - // - OperationInst Name + // - Instruction Name // - Attributes // - Result Types // - Operands @@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) { + static bool isEqual(const Instruction *lhs, const Instruction *rhs) { if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -89,8 +89,8 @@ struct CSE : public FunctionPass { /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; - using ScopedMapTy = llvm::ScopedHashTable>; + using ScopedMapTy = llvm::ScopedHashTable; /// Represents a single entry in the depth first traversal of a CFG. @@ -111,7 +111,7 @@ struct CSE : public FunctionPass { /// Attempt to eliminate a redundant operation. Returns true if the operation /// was marked for removal, false otherwise. - bool simplifyOperation(OperationInst *op); + bool simplifyOperation(Instruction *op); void simplifyBlock(Block *bb); @@ -122,14 +122,14 @@ private: ScopedMapTy knownValues; /// Operations marked as dead and to be erased. 
- std::vector opsToErase; + std::vector opsToErase; }; } // end anonymous namespace char CSE::passID = 0; /// Attempt to eliminate a redundant operation. -bool CSE::simplifyOperation(OperationInst *op) { +bool CSE::simplifyOperation(Instruction *op) { // TODO(riverriddle) We currently only eliminate non side-effecting // operations. if (!op->hasNoSideEffect()) @@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) { void CSE::simplifyBlock(Block *bb) { for (auto &i : *bb) { - switch (i.getKind()) { - case Instruction::Kind::OperationInst: { - auto *opInst = cast(&i); - - // If the operation is simplified, we don't process any held block lists. - if (simplifyOperation(opInst)) - continue; - - // Simplify any held blocks. - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(&b); - } + // If the operation is simplified, we don't process any held block lists. + if (simplifyOperation(&i)) + continue; + + // Simplify any held blocks. + for (auto &blockList : i.getBlockLists()) { + for (auto &b : blockList) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(&b); } - break; - } } } } diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 4f960ea73af..4a6430dc9be 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -48,7 +48,7 @@ namespace { struct ComposeAffineMaps : public FunctionPass, InstWalker { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); SmallVector, 8> affineApplyOps; @@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.isa(); + return inst.isa(); } -void ComposeAffineMaps::visitInstruction(OperationInst *opInst) { - if (auto afOp = opInst->dyn_cast()) { +void ComposeAffineMaps::visitInstruction(Instruction *opInst) { + if (auto afOp = opInst->dyn_cast()) affineApplyOps.push_back(afOp); - } } PassResult ComposeAffineMaps::runOnFunction(Function *f) { diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 859d0012fac..54486cdb293 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker { // All constants in the function post folding. SmallVector existingConstants; // Operations that were folded and that need to be erased. - std::vector opInstsToErase; + std::vector opInstsToErase; - bool foldOperation(OperationInst *op, + bool foldOperation(Instruction *op, SmallVectorImpl &existingConstants); - void visitInstruction(OperationInst *op); + void visitInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +49,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitInstruction(OperationInst *op) { +void ConstantFold::visitInstruction(Instruction *op) { // If this operation is an AffineForOp, then fold the bounds. 
if (auto forOp = op->dyn_cast()) { constantFoldBounds(forOp); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 443e7750947..996416d9271 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -50,7 +50,7 @@ private: // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. SmallVector - lookupValues(const llvm::iterator_range + lookupValues(const llvm::iterator_range &operands); // Converts the given function to the dialect using hooks defined in @@ -61,13 +61,13 @@ private: // from `valueRemapping` and the converted blocks from `blockRemapping`, and // passes them to `converter->rewriteTerminator` function defined in the // pattern, together with `builder`. - bool convertOpWithSuccessors(DialectOpConversion *converter, - OperationInst *op, FuncBuilder &builder); + bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op, + FuncBuilder &builder); // Converts an operation without successors. Extracts the converted operands // from `valueRemapping` and passes them to the `converter->rewrite` function // defined in the pattern, together with `builder`. - bool convertOp(DialectOpConversion *converter, OperationInst *op, + bool convertOp(DialectOpConversion *converter, Instruction *op, FuncBuilder &builder); // Converts a block by traversing its instructions sequentially, looking for @@ -104,8 +104,7 @@ private: } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues( - const llvm::iterator_range - &operands) { + const llvm::iterator_range &operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); for (const Value *operand : operands) { @@ -118,7 +117,7 @@ SmallVector impl::FunctionConversion::lookupValues( } bool impl::FunctionConversion::convertOpWithSuccessors( - DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) { + DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) { SmallVector destinations; destinations.reserve(op->getNumSuccessors()); SmallVector operands = lookupValues(op->getOperands()); @@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors( } bool impl::FunctionConversion::convertOp(DialectOpConversion *converter, - OperationInst *op, + Instruction *op, FuncBuilder &builder) { auto operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && @@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock( // Iterate over ops and convert them. for (Instruction &inst : *block) { - auto op = dyn_cast(&inst); - if (!op) { - inst.emitError("unsupported instruction (For/If)"); + if (inst.getNumBlockLists() != 0) { + inst.emitError("unsupported region instruction"); return true; } // Find the first matching conversion and apply it. 
bool converted = false; for (auto *conversion : conversions) { - if (!conversion->match(op)) + if (!conversion->match(&inst)) continue; - if (op->isTerminator() && op->getNumSuccessors() > 0) { - if (convertOpWithSuccessors(conversion, op, builder)) - return true; - } else { - if (convertOp(conversion, op, builder)) + if (inst.isTerminator() && inst.getNumSuccessors() > 0) { + if (convertOpWithSuccessors(conversion, &inst, builder)) return true; + } else if (convertOp(conversion, &inst, builder)) { + return true; } converted = true; break; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2bbb32036c2..92ae3767098 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of /// enclosing loop IVs of opInst (starting from the outermost) that the region /// is parametric on. -static bool getFullMemRefAsRegion(OperationInst *opInst, - unsigned numParamLoopIVs, +static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; if (auto loadOp = opInst->dyn_cast()) { @@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { fastBufferMap.clear(); // Walk this range of instructions to gather all memory regions. - block->walk(begin, end, [&](OperationInst *opInst) { + block->walk(begin, end, [&](Instruction *opInst) { // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = opInst->dyn_cast()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 304331320ac..d7d69e569e5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -114,11 +114,11 @@ namespace { class LoopNestStateCollector : public InstWalker { public: SmallVector, 4> forOps; - SmallVector loadOpInsts; - SmallVector storeOpInsts; + SmallVector loadOpInsts; + SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { if (opInst->isa()) forOps.push_back(opInst->cast()); else if (opInst->getNumBlockLists() != 0) @@ -131,7 +131,7 @@ public: }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -153,9 +153,9 @@ public: // The top-level statment which is (or contains) loads/stores. Instruction *inst; // List of load operations. - SmallVector loads; + SmallVector loads; // List of store op insts. - SmallVector stores; + SmallVector stores; Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} // Returns the load op count for 'memref'. @@ -258,16 +258,13 @@ public: for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast()->getMemRef(); auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null(inst); - // Return false if 'memref' is a function argument. - if (opInst == nullptr) + // Return false if 'memref' is a block argument. + if (!inst) return true; // Return false if any use of 'memref' escapes the function. 
- for (auto &use : memref->getUses()) { - auto *user = dyn_cast(use.getOwner()); - if (!user || !isMemRefDereferencingOp(*user)) + for (auto &use : memref->getUses()) + if (!isMemRefDereferencingOp(*use.getOwner())) return true; - } } return false; } @@ -461,8 +458,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. - void addToNode(unsigned id, const SmallVectorImpl &loads, - const SmallVectorImpl &stores) { + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { Node *node = getNode(id); for (auto *loadOpInst : loads) node->loads.push_back(loadOpInst); @@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) { DenseMap forToNodeMap; for (auto &inst : f->front()) { - if (auto forOp = cast(&inst)->dyn_cast()) { + if (auto forOp = inst.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) { } forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } else if (auto *opInst = dyn_cast(&inst)) { - if (auto loadOp = opInst->dyn_cast()) { - // Create graph node for top-level load op. - Node node(nextNodeId++, &inst); - node.loads.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast()) { - // Create graph node for top-level store op. - Node node(nextNodeId++, &inst); - node.stores.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; - } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) { - // Create graph node for top-level producer of SSA values, which - // could be used by loop nest nodes. - Node node(nextNodeId++, &inst); - nodes.insert({node.id, node}); - } + } else if (auto loadOp = inst.dyn_cast()) { + // Create graph node for top-level load op. + Node node(nextNodeId++, &inst); + node.loads.push_back(&inst); + auto *memref = inst.cast()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (auto storeOp = inst.dyn_cast()) { + // Create graph node for top-level store op. + Node node(nextNodeId++, &inst); + node.stores.push_back(&inst); + auto *memref = inst.cast()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (inst.getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; + } else if (inst.getNumResults() > 0 && !inst.use_empty()) { + // Create graph node for top-level producer of SSA values, which + // could be used by loop nest nodes. 
+ Node node(nextNodeId++, &inst); + nodes.insert({node.id, node}); } } @@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) { const Node &node = idAndNode.second; if (!node.loads.empty() || !node.stores.empty()) continue; - auto *opInst = cast(node.inst); + auto *opInst = node.inst; for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { - auto *userOpInst = cast(use.getOwner()); SmallVector, 4> loops; - getLoopIVs(*userOpInst, &loops); + getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); @@ -619,7 +613,7 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; @@ -627,8 +621,7 @@ public: auto *forInst = forOp->getInstruction(); auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(cast(parentInst)->isa() && - "Expected parent AffineForOp"); + assert(parentInst->isa() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. stats->loopMap[parentInst].push_back(forOp); } @@ -637,8 +630,7 @@ public: unsigned count = 0; stats->opCountMap[forInst] = 0; for (auto &inst : *forOp->getBody()) { - if (!(cast(inst).isa() || - cast(inst).isa())) + if (!(inst.isa() || inst.isa())) ++count; } stats->opCountMap[forInst] = count; @@ -723,7 +715,7 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { // was encountered). // TODO(andydavis) Make this work with non-unit step loops. static bool buildSliceTripCountMap( - OperationInst *srcOpInst, ComputationSliceState *sliceState, + Instruction *srcOpInst, ComputationSliceState *sliceState, llvm::SmallDenseMap *tripCountMap) { SmallVector, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -755,10 +747,10 @@ static bool buildSliceTripCountMap( // adds them to 'dstLoads'. static void moveLoadsAccessingMemrefTo(Value *memref, - SmallVectorImpl *srcLoads, - SmallVectorImpl *dstLoads) { + SmallVectorImpl *srcLoads, + SmallVectorImpl *dstLoads) { dstLoads->clear(); - SmallVector srcLoadsToKeep; + SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { if (load->cast()->getMemRef() == memref) dstLoads->push_back(load); @@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref, } // Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { +static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); @@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { // Returns the maximum loop depth at which no dependences between 'loadOpInsts' // and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, - ArrayRef storeOpInsts) { +static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, + ArrayRef storeOpInsts) { // Merge loads and stores into the same array. - SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); + SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); ops.append(storeOpInsts.begin(), storeOpInsts.end()); // Compute the innermost common loop depth for loads and stores. @@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. 
static Value *createPrivateMemRef(OpPointer forOp, - OperationInst *srcStoreOpInst, + Instruction *srcStoreOpInst, unsigned dstLoopDepth, Optional fastMemorySpace, unsigned localBufSizeThreshold) { @@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount( // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -static bool isFusionProfitable(OperationInst *srcOpInst, - ArrayRef dstLoadOpInsts, - ArrayRef dstStoreOpInsts, +static bool isFusionProfitable(Instruction *srcOpInst, + ArrayRef dstLoadOpInsts, + ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { LLVM_DEBUG({ @@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { auto *parentInst = loadOp->getParentInst(); - if (parentInst && cast(parentInst)->isa()) + if (parentInst && parentInst->isa()) computeCostMap[parentInst] = -1; } } @@ -1393,11 +1385,11 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!cast(dstNode->inst)->isa()) + if (!dstNode->inst->isa()) continue; - SmallVector loads = dstNode->loads; - SmallVector dstLoadOpInsts; + SmallVector loads = dstNode->loads; + SmallVector dstLoadOpInsts; DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. @@ -1426,7 +1418,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!cast(srcNode->inst)->isa()) + if (!srcNode->inst->isa()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1454,7 +1446,7 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; + SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) if (storeOpInst->cast()->getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); @@ -1472,8 +1464,7 @@ public: srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = - cast(dstNode->inst)->cast(); + auto dstAffineForOp = dstNode->inst->cast(); if (insertPointInst != dstAffineForOp->getInstruction()) { dstAffineForOp->getInstruction()->moveBefore(insertPointInst); } @@ -1488,7 +1479,7 @@ public: promoteIfSingleIteration(forOp); } // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector storesForMemref; + SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast()->getMemRef() == memref) storesForMemref.push_back(storeOpInst); @@ -1541,9 +1532,8 @@ public: continue; // Use list expected to match the dep graph info. 
auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null(inst); - if (opInst && opInst->isa()) - opInst->erase(); + if (inst && inst->isa()) + inst->erase(); } } }; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index f1ee7fd1853..8b368e5f182 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -237,14 +237,13 @@ getTileableBands(Function *f, do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = cast(currInst->getBody()->front()) - .dyn_cast())); + (currInst = currInst->getBody()->front().dyn_cast())); bands->push_back(band); }; for (auto &block : *f) for (auto &inst : block) - if (auto forOp = cast(inst).dyn_cast()) + if (auto forOp = inst.dyn_cast()) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 9c9952d31ca..b1e15ccb07b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkPostOrder(OperationInst *opInst) { + bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) @@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast(); if (!forOp) return; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index d87f9d5dc14..74c54fde047 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto forOp = - cast(entryBlock.front()).dyn_cast()) + if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); return success(); @@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !cast(it)->isa()) + while (it != End && !it->isa()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && cast(it)->isa()) + while (it != End && it->isa()) walk(&*it++); } } @@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto cleanupAffineForOp = - cast(builder.clone(*forInst))->cast(); + auto cleanupAffineForOp = builder.clone(*forInst)->cast(); cleanupAffineForOp->setLowerBoundMap( getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 08c8188fada..88ccc90c18b 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. 
// We do this as a prepass to avoid invalidating the walker with our rewrite. function->walk([&](Instruction *inst) { - auto op = cast(inst); - if (op->isa() || op->isa() || - op->isa()) + if (inst->isa() || inst->isa() || + inst->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) { - auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { + if (auto ifOp = inst->dyn_cast()) { if (lowerAffineIf(ifOp)) return failure(); - } else if (auto forOp = op->dyn_cast()) { + } else if (auto forOp = inst->dyn_cast()) { if (lowerAffineFor(forOp)) return failure(); - } else if (lowerAffineApply(op->cast())) { + } else if (lowerAffineApply(inst->cast())) { return failure(); } } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 7f1e9b157d8..63fb45db9c5 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -401,13 +401,12 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(OperationInst *op) const override { + PatternMatchResult match(Instruction *op) const override { if (m_Op().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpInst(OperationInst *op, - MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { VectorTransferRewriter( diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index f2dae11112b..f55c2154f08 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -246,8 +246,8 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector -materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { +materializeAttributes(Instruction *opInst, VectorType hwVectorType) { SmallVector res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast()) { @@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap *substitutionsMap) { assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); @@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. 
For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst, // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); - auto *opInst = cast(inst); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. - if (opInst->isa()) { + if (inst->isa()) { return false; } - if (opInst->getNumBlockLists() != 0) + if (inst->getNumBlockLists() != 0) return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast()) { + if (auto write = inst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = opInst->dyn_cast()) { + if (auto read = inst->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { @@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. - if (opInst->getNumResults() != 1) { + if (inst->getNumResults() != 1) { return inst->emitError("NYI: ops with != 1 results"); } - if (opInst->getResult(0)->getType() != state->superVectorType) { + if (inst->getResult(0)->getType() != state->superVectorType) { return inst->emitError("Op does not return a supervector."); } auto *clone = - instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap); + instantiate(&b, inst, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(opInst->getResult(0), clone->getResult(0))); + std::make_pair(inst->getResult(0), clone->getResult(0))); return false; } @@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(cast((*slice)[0])->getFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state, /// scope. /// TODO(ntv): please document return value. static bool materialize(Function *f, - const SetVector &terminators, + const SetVector &terminators, MaterializationState *state) { DenseSet seen; DominanceInfo domInfo(f); @@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. 
auto filter = [subVectorType](const Instruction &inst) { - const auto &opInst = cast(inst); - if (!opInst.isa()) { + if (!inst.isa()) { return false; } - return matcher::operatesOnSuperVectors(opInst, subVectorType); + return matcher::operatesOnSuperVectors(inst, subVectorType); }; auto pat = Op(filter); SmallVector matches; pat.match(f, &matches); - SetVector terminators; + SetVector terminators; for (auto m : matches) { - terminators.insert(cast(m.getMatchedInstruction())); + terminators.insert(m.getMatchedInstruction()); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index b9386c384dd..b2b69dc7b6d 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker { PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. - std::vector loadOpsToErase; + std::vector loadOpsToErase; DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; @@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { - OperationInst *lastWriteStoreOp = nullptr; +void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { + Instruction *lastWriteStoreOp = nullptr; auto loadOp = opInst->dyn_cast(); if (!loadOp) return; - OperationInst *loadOpInst = opInst; + Instruction *loadOpInst = opInst; // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. - SmallVector storeOps; + SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (InstOperand &use : loadOp->getMemRef()->getUses()) { - auto storeOp = cast(use.getOwner())->dyn_cast(); + auto storeOp = use.getOwner()->dyn_cast(); if (!storeOp) continue; auto *storeOpInst = storeOp->getInstruction(); @@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // and loadOp. // The list of store op candidates for forwarding - need to satisfy the // conditions listed at the top. - SmallVector fwdingCandidates; + SmallVector fwdingCandidates; // Store ops that have a dependence into the load (even if they aren't // forwarding candidates). Each forwarding candidate will be checked for a // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. - SmallVector depSrcStores; + SmallVector depSrcStores; for (auto *storeOpInst : storeOps) { MemRefAccess srcAccess(storeOpInst); MemRefAccess destAccess(loadOpInst); @@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // that postdominates all 'depSrcStores' (if such a store exists) is the // unique store providing the value to the load, i.e., provably the last // writer to that memref loc. 
- if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) { + if (llvm::all_of(depSrcStores, [&](Instruction *depStore) { return postDomInfo->postDominates(storeOpInst, depStore); })) { lastWriteStoreOp = storeOpInst; @@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { // to do this as well, but we'll do it here since we collected these anyway. for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - OperationInst *defInst = memref->getDefiningInst(); + Instruction *defInst = memref->getDefiningInst(); if (!defInst || !defInst->isa()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), [&](InstOperand &use) { - auto *ownerInst = cast(use.getOwner()); + auto *ownerInst = use.getOwner(); return (!ownerInst->isa() && !ownerInst->isa()); })) continue; // Erase all stores, the dealloc, and the alloc on the memref. - for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) { - auto &use = *(it++); - cast(use.getOwner())->erase(); - } + for (auto &use : llvm::make_early_inc_range(memref->getUses())) + use.getOwner()->erase(); defInst->erase(); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 8d13800160d..ba3be5e95f4 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaInst) { +static unsigned getTagMemRefPos(const Instruction &dmaInst) { assert(dmaInst.isa() || dmaInst.isa()); if (dmaInst.isa()) { // Second to last operand. @@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder([&](OperationInst *opInst) { + f->walkPostOrder([&](Instruction *opInst) { if (auto forOp = opInst->dyn_cast()) forOps.push_back(forOp); }); @@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( OpPointer forOp, - SmallVectorImpl> - &startWaitPairs) { + SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast(&inst); - if (!opInst) - continue; OpPointer dmaStartOp; - if ((dmaStartOp = opInst->dyn_cast()) && + if ((dmaStartOp = inst.dyn_cast()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector dmaStartInsts, dmaFinishInsts; + SmallVector dmaStartInsts, dmaFinishInsts; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast(&inst); - if (!opInst) - continue; // Collect DMA finish instructions. 
- if (opInst->isa()) { - dmaFinishInsts.push_back(opInst); + if (inst.isa()) { + dmaFinishInsts.push_back(&inst); continue; } OpPointer dmaStartOp; - if (!(dmaStartOp = opInst->dyn_cast())) + if (!(dmaStartOp = inst.dyn_cast())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts( } } if (!escapingUses) - dmaStartInsts.push_back(opInst); + dmaStartInsts.push_back(&inst); } // For each start instruction, we look for a matching finish instruction. @@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { return success(); } - SmallVector, 4> startWaitPairs; + SmallVector, 4> startWaitPairs; findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { @@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector affineApplyInsts; + SmallVector affineApplyInsts; SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); for (const auto *inst : affineApplyInsts) { @@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; - LLVM_DEBUG( - // Tagging instructions with shifts for debugging purposes. - if (auto *opInst = dyn_cast(&inst)) { - FuncBuilder b(opInst); - opInst->setAttr(b.getIdentifier("shift"), - b.getI64IntegerAttr(shifts[s - 1])); - }); + + // Tagging instructions with shifts for debugging purposes. + LLVM_DEBUG({ + FuncBuilder b(&inst); + inst.setAttr(b.getIdentifier("shift"), + b.getI64IntegerAttr(shifts[s - 1])); + }); } if (!isInstwiseShiftValid(forOp, shifts)) { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index a9fcfc5bd11..29509911e31 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { } PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walk([&](OperationInst *opInst) { + f->walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 790f971bb58..45c57e2f307 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -38,13 +38,13 @@ public: worklist.reserve(64); // Add all operations to the worklist. - fn->walk([&](OperationInst *inst) { addToWorklist(inst); }); + fn->walk([&](Instruction *inst) { addToWorklist(inst); }); } /// Perform the rewrites. void simplifyFunction(); - void addToWorklist(OperationInst *op) { + void addToWorklist(Instruction *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -53,7 +53,7 @@ public: worklist.push_back(op); } - OperationInst *popFromWorklist() { + Instruction *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -65,7 +65,7 @@ public: /// If the specified operation is in the worklist, remove it. 
If not, this is /// a no-op. - void removeFromWorklist(OperationInst *op) { + void removeFromWorklist(Instruction *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -77,7 +77,7 @@ public: protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - OperationInst *createOperation(const OperationState &state) override { + Instruction *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); addToWorklist(result); return result; @@ -85,20 +85,18 @@ protected: // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(OperationInst *op) override { + void notifyOperationRemoved(Instruction *op) override { removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(OperationInst *op) override { + void notifyRootReplaced(Instruction *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. - for (auto &user : result->getUses()) { - if (auto *op = dyn_cast(user.getOwner())) - addToWorklist(op); - } + for (auto &user : result->getUses()) + addToWorklist(user.getOwner()); // TODO: Walk the operand list dropping them as we go. If any of them // drop to zero uses, then add them to the worklist to allow them to be @@ -116,13 +114,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from /// the function, even if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; + std::vector worklist; + DenseMap worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap, OperationInst *> uniquedConstants; + DenseMap, Instruction *> uniquedConstants; }; }; // end anonymous namespace @@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { // revisit them. // // TODO: Add a result->getUsers() iterator. - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(cstValue); } @@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { if (res->use_empty()) // ignore dead uses. continue; - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(resultValues[i]); } } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 153557de04a..5bf17989bef 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { // Replaces all IV uses to its single iteration value. 
auto *iv = forOp->getInductionVar(); - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (!iv->use_empty()) { if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); @@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkPostOrder([](OperationInst *inst) { + f->walkPostOrder([](Instruction *inst) { if (auto forOp = inst->dyn_cast()) promoteIfSingleIteration(forOp); }); @@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer forOp, return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto cleanupForInst = - cast(builder.clone(*forInst))->cast(); + auto cleanupForInst = builder.clone(*forInst)->cast(); auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 879a4f4b585..524e8d542f5 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,7 +37,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, std::make_unique(postDomInstFilter->getFunction()); // The ops where memref replacement succeeds are replaced with new ones. - SmallVector opsToErase; + SmallVector opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. - for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { - InstOperand &use = *(it++); - auto *opInst = cast(use.getOwner()); + for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) { + auto *opInst = use.getOwner(); // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) @@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// uses besides this opInst; otherwise returns the list of affine_apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( - OperationInst *opInst, - SmallVectorImpl> *sliceOps) { + Instruction *opInst, SmallVectorImpl> *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); @@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice( } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. 
if (affineApplyOps.empty()) @@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer forInst) { } void mlir::remapFunctionAttrs( - OperationInst &op, - const DenseMap &remappingTable) { + Instruction &op, const DenseMap &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index a9b9752ef51..7d51637a6e1 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [subVectorType](const Instruction &inst) { - auto *opInst = dyn_cast(&inst); - if (!opInst) { - return false; - } assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) { + if (!matcher::operatesOnSuperVectors(inst, subVectorType)) { return false; } - if (opInst->getNumResults() != 1) { + if (inst.getNumResults() != 1) { return false; } return true; @@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { SmallVector matches; pat.match(f, &matches); for (auto m : matches) { - auto *opInst = cast(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() { using matcher::Op; // Match all OpInstructions with the kTestSlicingOpName name. 
auto filter = [](const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.getName().getStringRef() == kTestSlicingOpName; + return inst.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); } @@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) { } static bool customOpWithAffineMapAttribute(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.getName().getStringRef() == + return inst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) { SmallVector maps; maps.reserve(matches.size()); for (auto m : llvm::reverse(matches)) { - auto *opInst = cast(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast() .getValue(); @@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast(inst); - return opInst.isa(); + return inst.isa(); } static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) { - const auto &opInst = cast(inst); - auto app = opInst.dyn_cast(); + auto app = inst.dyn_cast(); return app && app->use_empty(); } @@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { SmallVector matches; pattern.match(f, &matches); for (auto m : matches) { - auto app = - cast(m.getMatchedInstruction())->cast(); + auto app = m.getMatchedInstruction()->cast(); FuncBuilder b(m.getMatchedInstruction()); SmallVector operands(app->getOperands()); makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 661861dcfd4..5a8d5d24661 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -723,22 +723,22 @@ namespace { struct VectorizationState { /// Adds an entry of pre/post vectorization instructions in the state. - void registerReplacement(OperationInst *key, OperationInst *value); + void registerReplacement(Instruction *key, Instruction *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original OperationInst that have been vectorized. + // In-order tracking of original Instruction that have been vectorized. // Erase in reverse order. - SmallVector toErase; - // Set of OperationInst that have been vectorized (the values in the + SmallVector toErase; + // Set of Instruction that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in // particular to filter the instructions that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet vectorizedSet; - // Map of old scalar OperationInst to new vectorized OperationInst. - DenseMap vectorizationMap; + DenseSet vectorizedSet; + // Map of old scalar Instruction to new vectorized Instruction. + DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -747,17 +747,17 @@ struct VectorizationState { // vectorizeOperations function. They consist of the subset of load operations // that have been vectorized. 
They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. - DenseSet roots; + DenseSet roots; // Terminator instructions for the worklist in the vectorizeOperations // function. They consist of the subset of store operations that have been // vectorized. They can be retrieved from `vectorizationMap` but it is // convenient to keep track of them in a separate data structure. Since they // do not necessarily belong to use-def chains starting from loads (e.g // storing a constant), we need to handle them in a post-pass. - DenseSet terminators; + DenseSet terminators; // Checks that the type of `inst` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationInst *inst); + void registerTerminator(Instruction *inst); private: void registerReplacement(const Value *key, Value *value); @@ -765,8 +765,8 @@ private: } // end namespace -void VectorizationState::registerReplacement(OperationInst *key, - OperationInst *value) { +void VectorizationState::registerReplacement(Instruction *key, + Instruction *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key, } } -void VectorizationState::registerTerminator(OperationInst *inst) { +void VectorizationState::registerTerminator(Instruction *inst) { assert(inst->isa() && "terminator must be a StoreOp"); assert(terminators.count(inst) == 0 && "terminator was already inserted previously"); @@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, if (!matcher::isLoadOrStore(inst)) { return false; } - auto *opInst = cast(&inst); - return state->vectorizationMap.count(opInst) == 0 && - state->vectorizedSet.count(opInst) == 0 && - state->roots.count(opInst) == 0 && - state->terminators.count(opInst) == 0; + return state->vectorizationMap.count(&inst) == 0 && + state->vectorizedSet.count(&inst) == 0 && + state->roots.count(&inst) == 0 && + state->terminators.count(&inst) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { - auto *opInst = cast(ls.getMatchedInstruction()); + auto *opInst = ls.getMatchedInstruction(); auto load = opInst->dyn_cast(); auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); @@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { return [fastestVaryingMemRefDimension](const Instruction &forInst) { - auto loop = cast(forInst).cast(); + auto loop = forInst.cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef matches, /// recursively in DFS post-order. static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { auto *loopInst = oneMatch.getMatchedInstruction(); - auto loop = cast(loopInst)->cast(); + auto loop = loopInst->cast(); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. 
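As an aside for readers following this migration, the net effect of the change on walker-style code can be sketched roughly as follows. This is illustrative only, not part of the patch: the helper name collectTopLevelLoops is invented, and the AffineForOp template argument is an assumption based on the surrounding hunks, since the rendered patch drops the contents of angle brackets.

static void collectTopLevelLoops(Function *f,
                                 SmallVectorImpl<OpPointer<AffineForOp>> &forOps) {
  // Walkers now traverse plain Instructions; there is no separate
  // OperationInst class to cast to before querying the op.
  f->walk([&](Instruction *inst) {
    if (auto forOp = inst->dyn_cast<AffineForOp>())
      forOps.push_back(forOp);
  });
}

The same shape applies to the InstWalker subclasses above: visitOperationInst(OperationInst *) callbacks become visitInstruction(Instruction *), with the body otherwise unchanged.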
@@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, Location loc = inst->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpInst = cast(constant.getInstruction()); + auto *constantOpInst = constant.getInstruction(); OperationState state( b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); - auto *splat = cast(b.createOperation(state)); - return splat->getResult(0); + return b.createOperation(state)->getResult(0); } /// Returns a uniqu'ed VectorType. @@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpInst = cast(v->getDefiningInst()); - if (state.vectorizedSet.count(definingOpInst) > 0) { + if (state.vectorizedSet.count(v->getDefiningInst()) > 0) { return v->getType().cast(); } return VectorType::get(state.strategy->vectorSizes, v->getType()); @@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingInstruction = cast(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. - if (state->vectorizedSet.count(definingInstruction) > 0) { + if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, return nullptr; }; -/// Encodes OperationInst-specific behavior for vectorization. In general we +/// Encodes Instruction-specific behavior for vectorization. In general we /// assume that all operands of an op must be vectorized but this is not always /// true. In the future, it would be nice to have a trait that describes how a /// particular operation vectorizes. For now we implement the case distinction @@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. -static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, - OperationInst *opInst, - VectorizationState *state) { +static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst, + VectorizationState *state) { // Sanity checks. assert(!opInst->isa() && "all loads must have already been fully vectorized independently"); @@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( opInst->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = cast(transfer->getInstruction()); + auto *res = transfer->getInstruction(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. opInst->erase(); @@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, // Create a clone of the op with the proper operands and return types. // TODO(ntv): The following assumes there is always an op with a fixed // name that works both in scalar mode and vector mode. 
- // TODO(ntv): Is it worth considering an OperationInst.clone operation - // which changes the type so we can promote an OperationInst with less + // TODO(ntv): Is it worth considering an Instruction.clone operation + // which changes the type so we can promote an Instruction with less // boilerplate? OperationState newOp(b->getContext(), opInst->getLoc(), opInst->getName().getStringRef(), operands, types, @@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, return b->createOperation(newOp); } -/// Iterates over the OperationInst in the loop and rewrites them using their +/// Iterates over the Instruction in the loop and rewrites them using their /// vectorized counterpart by: -/// 1. iteratively building a worklist of uses of the OperationInst vectorized +/// 1. iteratively building a worklist of uses of the Instruction vectorized /// so far by this pattern; -/// 2. for each OperationInst in the worklist, create the vector form of this +/// 2. for each Instruction in the worklist, create the vector form of this /// operation and replace all its uses by the vectorized form. For this step, /// the worklist must be traversed in order; /// 3. verify that all operands of the newly vectorized operation have been /// vectorized by this pattern. static bool vectorizeOperations(VectorizationState *state) { // 1. create initial worklist with the uses of the roots. - SetVector worklist; - auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { + SetVector worklist; + auto insertUsesOf = [&worklist, state](Instruction *vectorized) { for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *inst = cast(u.getOwner()); + auto *inst = u.getOwner(); // Don't propagate to terminals, a separate pass is needed for those. // TODO(ntv)[b/119759136]: use isa<> once Op is implemented. if (state->terminators.count(inst) > 0) { @@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 2. Create vectorized form of the instruction. // Insert it just before inst, on success register inst as replaced. FuncBuilder b(inst); - auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state); + auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state); if (!vectorizedInst) { return true; } @@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 4. Augment the worklist with uses of the instruction we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef{inst}); + apply(insertUsesOf, ArrayRef{inst}); } return false; } @@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) { /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto loop = - cast(m.getMatchedInstruction())->cast(); + auto loop = m.getMatchedInstruction()->cast(); VectorizationState state; state.strategy = strategy; @@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { } auto *loopInst = loop->getInstruction(); FuncBuilder builder(loopInst); - auto clonedLoop = - cast(builder.clone(*loopInst))->cast(); + auto clonedLoop = builder.clone(*loopInst)->cast(); auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. 
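Taken together, the helpers above implement the per-root flow outlined below (a rough sketch, not verbatim from the pass; the error-handling guard that erases the unwanted loop copy is elided):

    // vectorizeRootMatch(m, strategy), in outline:
    //   1. state.strategy = strategy; clone the matched loop so one copy can
    //      be discarded depending on whether vectorization succeeds.
    //   2. doVectorize(m, &state): DFS post-order over the nested matches,
    //      vectorizing loads/stores and registering them in `state`.
    //   3. vectorizeOperations(&state): worklist seeded with the uses of the
    //      roots; each entry is rewritten with vectorizeOneInstruction, and
    //      the SetVector keeps insertion order so defs are handled before
    //      their uses.
    //   4. Vectorize the terminators (StoreOps) in a post-pass; any failure
    //      marks the whole root as failed.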
This is how the root match @@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { } // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { + auto vectorizeOrFail = [&fail, &state](Instruction *inst) { if (fail) { return; } FuncBuilder b(inst); - auto *res = vectorizeOneOperationInst(&b, inst, &state); + auto *res = vectorizeOneInstruction(&b, inst, &state); if (res == nullptr) { fail = true; } -- cgit v1.2.3 From 0f50414fa4553b1277684cb1dded84b334b35d51 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 4 Feb 2019 13:48:44 -0800 Subject: Refactor common code getting memref access in getMemRefRegion - NFC - use getAccessMap() instead of repeating it - fold getMemRefRegion into MemRefRegion ctor (more natural, avoid heap allocation and unique_ptr where possible) - change extractForInductionVars - MutableArrayRef -> ArrayRef for the arguments. Since the method is just returning copies of 'Value *', the client can't mutate the pointers themselves; it's fine to mutate the 'Value''s themselves, but that doesn't mutate the pointers to those. - change the way extractForInductionVars returns (see b/123437690) PiperOrigin-RevId: 232359277 --- mlir/include/mlir/AffineOps/AffineOps.h | 8 +- mlir/include/mlir/Analysis/AffineAnalysis.h | 9 ++- mlir/include/mlir/Analysis/Utils.h | 48 ++++++------ mlir/lib/AffineOps/AffineOps.cpp | 9 +-- mlir/lib/Analysis/AffineAnalysis.cpp | 3 +- mlir/lib/Analysis/Utils.cpp | 115 ++++++++++++---------------- mlir/lib/Transforms/DmaGeneration.cpp | 7 +- mlir/lib/Transforms/LoopFusion.cpp | 7 +- mlir/lib/Transforms/LoopTiling.cpp | 3 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 5 +- 10 files changed, 104 insertions(+), 110 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 46bb91c1bca..0e71221871b 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -201,10 +201,10 @@ bool isForInductionVar(const Value *val); OpPointer getForInductionVarOwner(Value *val); ConstOpPointer getForInductionVarOwner(const Value *val); -/// Extracts the induction variables from a list of AffineForOps and returns -/// them. -SmallVector -extractForInductionVars(MutableArrayRef> forInsts); +/// Extracts the induction variables from a list of AffineForOps and places them +/// in the output argument `ivs`. +void extractForInductionVars(ArrayRef> forInsts, + SmallVectorImpl *ivs); /// AffineBound represents a lower or upper bound in the for instruction. /// This class does not own the underlying operands. Instead, it refers diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index ca420bab7e1..dd3e676eb17 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -117,8 +117,8 @@ bool getIndexSet(llvm::MutableArrayRef> forOps, /// Encapsulates a memref load or store access information. struct MemRefAccess { - const Value *memref; - const Instruction *opInst; + Value *memref; + Instruction *opInst; llvm::SmallVector indices; /// Constructs a MemRefAccess from a load or store operation instruction. @@ -126,6 +126,11 @@ struct MemRefAccess { // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess. 
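The MemRefAccess constructor declared next, together with the getRank()/isStore() helpers added here, gives call sites a uniform way to inspect a load or store. A minimal sketch (`opInst` is assumed to be a LoadOp or StoreOp instruction):

    MemRefAccess access(opInst);
    unsigned rank = access.getRank();   // rank of the accessed memref
    bool isWrite = access.isStore();    // true for StoreOp, false for LoadOp
    AffineValueMap accessMap;
    access.getAccessMap(&accessMap);    // composed affine access function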
explicit MemRefAccess(Instruction *opInst); + // Returns the rank of the memref associated with this access. + unsigned getRank() const; + // Returns true if this access is of a store op. + bool isStore() const; + /// Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. void getAccessMap(AffineValueMap *accessMap) const; diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 54549fc8ef8..1dcbd759154 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -76,8 +76,30 @@ unsigned getNestingDepth(const Instruction &stmt); // The last field is a 2-d FlatAffineConstraints symbolic in %i. // struct MemRefRegion { - MemRefRegion(Value *memref, Location loc, bool write) - : memref(memref), write(write), loc(loc) {} + explicit MemRefRegion(Location loc) : loc(loc) {} + + /// Computes the memory region accessed by this memref with the region + /// represented as constraints symbolic/parameteric in 'loopDepth' loops + /// surrounding opInst. Returns false if this fails due to yet unimplemented + /// cases. The computed region's 'cst' field has exactly as many dimensional + /// identifiers as the rank of the memref, and *potentially* additional + /// symbolic identifiers which could include any of the loop IVs surrounding + /// opInst up until 'loopDepth' and another additional Function symbols + /// involved with the access (for eg., those appear in affine_apply's, loop + /// bounds, etc.). + /// For example, the memref region for this operation at loopDepth = 1 will + /// be: + /// + /// for %i = 0 to 32 { + /// for %ii = %i to (d0) -> (d0 + 8) (%i) { + /// load %A[%ii] + /// } + /// } + /// + /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } + /// The last field is a 2-d FlatAffineConstraints symbolic in %i. + /// + bool compute(Instruction *inst, unsigned loopDepth); FlatAffineConstraints *getConstraints() { return &cst; } const FlatAffineConstraints *getConstraints() const { return &cst; } @@ -132,28 +154,6 @@ struct MemRefRegion { FlatAffineConstraints cst; }; -/// Computes the memory region accessed by this memref with the region -/// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opInst. Returns nullptr if this fails due to yet unimplemented -/// cases. The computed region's 'cst' field has exactly as many dimensional -/// identifiers as the rank of the memref, and *potentially* additional symbolic -/// identifiers which could include any of the loop IVs surrounding opInst up -/// until 'loopDepth' and another additional Function symbols involved with -/// the access (for eg., those appear in affine_apply's, loop bounds, etc.). -/// For example, the memref region for this operation at loopDepth = 1 will be: -/// -/// for %i = 0 to 32 { -/// for %ii = %i to (d0) -> (d0 + 8) (%i) { -/// load %A[%ii] -/// } -/// } -/// -/// {memref = %A, write = false, {%i <= m0 <= %i + 7} } -/// The last field is a 2-d FlatAffineConstraints symbolic in %i. -/// -std::unique_ptr getMemRefRegion(Instruction *opInst, - unsigned loopDepth); - /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. 
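A minimal usage sketch of the stack-based MemRefRegion declared above, mirroring the call sites updated later in this change (`analyzeAccess` is a hypothetical helper and `depth` a loop depth chosen by the caller):

    static bool analyzeAccess(Instruction *opInst, unsigned depth) {
      // Regions now live on the stack; compute() fills in the constraint
      // system and returns false for cases it cannot handle yet (e.g.
      // semi-affine maps), letting callers over-approximate or bail out.
      MemRefRegion region(opInst->getLoc());
      if (!region.compute(opInst, depth))
        return false;
      // After a successful compute(), the leading identifiers of the
      // constraints are the memref's dimensions; the rest are symbols
      // (outer loop IVs up to `depth` plus Function symbols).
      const FlatAffineConstraints *cst = region.getConstraints();
      return cst->getNumDimIds() == MemRefAccess(opInst).getRank();
    }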
Optional getMemRefSizeInBytes(MemRefType memRefType); diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 2e657cf7e17..2ef96aa3d14 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -462,12 +462,11 @@ ConstOpPointer mlir::getForInductionVarOwner(const Value *val) { /// Extracts the induction variables from a list of AffineForOps and returns /// them. -SmallVector mlir::extractForInductionVars( - MutableArrayRef> forInsts) { - SmallVector results; +void mlir::extractForInductionVars(ArrayRef> forInsts, + SmallVectorImpl *ivs) { + ivs->reserve(forInsts.size()); for (auto forInst : forInsts) - results.push_back(forInst->getInductionVar()); - return results; + ivs->push_back(forInst->getInductionVar()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index fafa5126939..0a2b283738c 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -557,7 +557,8 @@ void mlir::getReachableAffineApplyOps( // setExprStride(ArrayRef expr, int64_t stride) bool mlir::getIndexSet(MutableArrayRef> forOps, FlatAffineConstraints *domain) { - auto indices = extractForInductionVars(forOps); + SmallVector indices; + extractForInductionVars(forOps, &indices); // Reset while associated Values in 'indices' to the domain. domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto forOp : forOps) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 652aaab0e1b..d27715333c8 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -122,55 +122,36 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -std::unique_ptr mlir::getMemRefRegion(Instruction *opInst, - unsigned loopDepth) { - unsigned rank; - std::unique_ptr region; - SmallVector indices; - if (auto loadOp = opInst->dyn_cast()) { - rank = loadOp->getMemRefType().getRank(); - indices.reserve(rank); - indices.append(loadOp->getIndices().begin(), loadOp->getIndices().end()); - region = std::make_unique(loadOp->getMemRef(), - loadOp->getLoc(), false); - } else if (auto storeOp = opInst->dyn_cast()) { - rank = storeOp->getMemRefType().getRank(); - indices.reserve(rank); - indices.append(storeOp->getIndices().begin(), storeOp->getIndices().end()); - region = std::make_unique(storeOp->getMemRef(), - storeOp->getLoc(), true); - } else { - assert(false && "expected load or store op"); - return nullptr; - } +bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { + assert((inst->isa() || inst->isa()) && + "load/store op expected"); - // Build the constraints for this region. - FlatAffineConstraints *regionCst = region->getConstraints(); + MemRefAccess access(inst); + memref = access.memref; + write = access.isStore(); + + unsigned rank = access.getRank(); if (rank == 0) { - // A rank 0 memref has a 0-d region. SmallVector, 4> ivs; - getLoopIVs(*opInst, &ivs); - - SmallVector regionSymbols = extractForInductionVars(ivs); - regionCst->reset(0, loopDepth, 0, regionSymbols); - return region; + getLoopIVs(*inst, &ivs); + SmallVector regionSymbols; + extractForInductionVars(ivs, ®ionSymbols); + // A rank 0 memref has a 0-d region. 
+ cst.reset(rank, loopDepth, 0, regionSymbols); + return true; } - FuncBuilder b(opInst); - auto idMap = b.getMultiDimIdentityMap(rank); - // Initialize 'accessValueMap' and compose with reachable AffineApplyOps. - fullyComposeAffineMapAndOperands(&idMap, &indices); - // Remove any duplicates. - canonicalizeMapAndOperands(&idMap, &indices); - AffineValueMap accessValueMap(idMap, indices); + // Build the constraints for this region. + AffineValueMap accessValueMap; + access.getAccessMap(&accessValueMap); AffineMap accessMap = accessValueMap.getAffineMap(); // We'll first associate the dims and symbols of the access map to the dims - // and symbols resp. of regionCst. This will change below once regionCst is + // and symbols resp. of cst. This will change below once cst is // fully constructed out. - regionCst->reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0, - accessValueMap.getOperands()); + cst.reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0, + accessValueMap.getOperands()); // Add equality constraints. unsigned numDims = accessMap.getNumDims(); @@ -178,65 +159,63 @@ std::unique_ptr mlir::getMemRefRegion(Instruction *opInst, // Add inequalties for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { - // Note that regionCst can now have more dimensions than accessMap if the + // Note that cst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!regionCst->addAffineForOpDomain(loop)) - return nullptr; + if (!cst.addAffineForOpDomain(loop)) + return false; } else { // Has to be a valid symbol. auto *symbol = accessValueMap.getOperand(i); assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opInst = symbol->getDefiningInst()) { - if (auto constOp = opInst->dyn_cast()) { - regionCst->setIdToConstant(*symbol, constOp->getValue()); + if (auto *inst = symbol->getDefiningInst()) { + if (auto constOp = inst->dyn_cast()) { + cst.setIdToConstant(*symbol, constOp->getValue()); } } } } // Add access function equalities to connect loop IVs to data dimensions. - if (!regionCst->composeMap(&accessValueMap)) { + if (!cst.composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); LLVM_DEBUG(accessValueMap.getAffineMap().dump()); - return nullptr; + return false; } // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. SmallVector, 4> outerIVs; - getLoopIVs(*opInst, &outerIVs); + getLoopIVs(*inst, &outerIVs); assert(loopDepth <= outerIVs.size() && "invalid loop depth"); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { OpPointer iv; if ((iv = getForInductionVarOwner(operand)) && llvm::is_contained(outerIVs, iv) == false) { - regionCst->projectOut(operand); + cst.projectOut(operand); } } // Project out any local variables (these would have been added for any // mod/divs). - regionCst->projectOut(regionCst->getNumDimAndSymbolIds(), - regionCst->getNumLocalIds()); + cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds()); // Set all identifiers appearing after the first 'rank' identifiers as // symbolic identifiers - so that the ones correspoding to the memref // dimensions are the dimensional identifiers for the memref region. 
- regionCst->setDimSymbolSeparation(regionCst->getNumDimAndSymbolIds() - rank); + cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank); // Constant fold any symbolic identifiers. - regionCst->constantFoldIdRange(/*pos=*/regionCst->getNumDimIds(), - /*num=*/regionCst->getNumSymbolIds()); + cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(), + /*num=*/cst.getNumSymbolIds()); - assert(regionCst->getNumDimIds() == rank && "unexpected MemRefRegion format"); + assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format"); LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); - LLVM_DEBUG(region->getConstraints()->dump()); - - return region; + LLVM_DEBUG(cst.dump()); + return true; } // TODO(mlir-team): improve/complete this when we have target data. @@ -282,19 +261,19 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, Instruction *opInst = loadOrStoreOp->getInstruction(); - auto region = getMemRefRegion(opInst, /*loopDepth=*/0); - if (!region) + MemRefRegion region(opInst->getLoc()); + if (!region.compute(opInst, /*loopDepth=*/0)) return false; LLVM_DEBUG(llvm::dbgs() << "Memory region"); - LLVM_DEBUG(region->getConstraints()->dump()); + LLVM_DEBUG(region.getConstraints()->dump()); bool outOfBounds = false; unsigned rank = loadOrStoreOp->getMemRefType().getRank(); // For each dimension, check for out of bounds. for (unsigned r = 0; r < rank; r++) { - FlatAffineConstraints ucst(*region->getConstraints()); + FlatAffineConstraints ucst(*region.getConstraints()); // Intersect memory region with constraint capturing out of bounds (both out // of upper and out of lower), and check if the constraint system is @@ -314,7 +293,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, } // Check for a negative index. - FlatAffineConstraints lcst(*region->getConstraints()); + FlatAffineConstraints lcst(*region.getConstraints()); std::fill(ineq.begin(), ineq.end(), 0); // d_i <= -1; lcst.addConstantUpperBound(r, -1); @@ -521,6 +500,12 @@ MemRefAccess::MemRefAccess(Instruction *loadOrStoreOpInst) { } } +unsigned MemRefAccess::getRank() const { + return memref->getType().cast().getRank(); +} + +bool MemRefAccess::isStore() const { return opInst->isa(); } + /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. unsigned mlir::getNestingDepth(const Instruction &inst) { @@ -600,8 +585,8 @@ Optional mlir::getMemoryFootprintBytes(const Block &block, // all regions for a given memref instead of creating one region per // memory op. This way we would be allocating O(num of memref's) sets // instead of O(num of load/store op's). - auto region = getMemRefRegion(opInst, 0); - if (!region) { + auto region = std::make_unique(opInst->getLoc()); + if (!region->compute(opInst, /*loopDepth=*/0)) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); // TODO: stop the walk if an error occurred. 
error = true; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 92ae3767098..40d90c3e27b 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -183,7 +183,8 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, SmallVector, 4> ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); - SmallVector symbols = extractForInductionVars(ivs); + SmallVector symbols; + extractForInductionVars(ivs, &symbols); regionCst->reset(rank, numParamLoopIVs, 0); regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols); @@ -576,8 +577,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } // Compute the MemRefRegion accessed. - auto region = getMemRefRegion(opInst, dmaDepth); - if (!region) { + auto region = std::make_unique(opInst->getLoc()); + if (!region->compute(opInst, dmaDepth)) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region: semi-affine maps?\n"); LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d7d69e569e5..7a002168528 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -921,7 +921,8 @@ static Value *createPrivateMemRef(OpPointer forOp, unsigned rank = oldMemRefType.getRank(); // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. - auto region = getMemRefRegion(srcStoreOpInst, dstLoopDepth); + MemRefRegion region(srcStoreOpInst->getLoc()); + region.compute(srcStoreOpInst, dstLoopDepth); SmallVector newShape; std::vector> lbs; SmallVector lbDivisors; @@ -929,11 +930,11 @@ static Value *createPrivateMemRef(OpPointer forOp, // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed // by 'srcStoreOpInst' at depth 'dstLoopDepth'. Optional numElements = - region->getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); + region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); assert(numElements.hasValue() && "non-constant number of elts in local buffer"); - const FlatAffineConstraints *cst = region->getConstraints(); + const FlatAffineConstraints *cst = region.getConstraints(); // 'outerIVs' holds the values that this memory region is symbolic/paramteric // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 8b368e5f182..758d434d25e 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -201,7 +201,8 @@ UtilResult mlir::tileCodeGen(MutableArrayRef> band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector origLoopIVs = extractForInductionVars(band); + SmallVector origLoopIVs; + extractForInductionVars(band, &origLoopIVs); SmallVector, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index b2b69dc7b6d..2d06a327315 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -178,8 +178,9 @@ void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { // is trivially loading from a single location at that depth; so there // isn't a need to call isRangeOneToOne. 
if (getNestingDepth(*storeOpInst) < loadOpDepth) { - auto region = getMemRefRegion(loadOpInst, nsLoops); - if (!region->getConstraints()->isRangeOneToOne( + MemRefRegion region(loadOpInst->getLoc()); + region.compute(loadOpInst, nsLoops); + if (!region.getConstraints()->isRangeOneToOne( /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) break; } -- cgit v1.2.3 From bf9c381d1dbf4381659597109422e543d62a49d7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 4 Feb 2019 16:24:44 -0800 Subject: Remove InstWalker and move all instruction walking to the api facilities on Function/Block/Instruction. PiperOrigin-RevId: 232388113 --- mlir/include/mlir/AffineOps/AffineOps.h | 8 -- mlir/include/mlir/Analysis/NestedMatcher.h | 37 +++---- mlir/include/mlir/IR/Block.h | 35 ++++--- mlir/include/mlir/IR/Function.h | 32 +++++- mlir/include/mlir/IR/InstVisitor.h | 140 --------------------------- mlir/include/mlir/IR/Instruction.h | 30 ++++++ mlir/lib/AffineOps/AffineOps.cpp | 27 ------ mlir/lib/Analysis/MemRefBoundCheck.cpp | 24 ++--- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 16 ++- mlir/lib/Analysis/OpStats.cpp | 17 ++-- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 1 - mlir/lib/IR/AsmPrinter.cpp | 1 - mlir/lib/IR/Block.cpp | 85 +++++++--------- mlir/lib/IR/Function.cpp | 32 ++---- mlir/lib/IR/Instruction.cpp | 30 +++++- mlir/lib/Parser/Parser.cpp | 1 - mlir/lib/Transforms/CSE.cpp | 1 - mlir/lib/Transforms/ComposeAffineMaps.cpp | 15 +-- mlir/lib/Transforms/ConstantFold.cpp | 11 +-- mlir/lib/Transforms/LoopFusion.cpp | 90 +++++++++-------- mlir/lib/Transforms/LoopUnroll.cpp | 49 +++------- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 20 ++-- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 17 ++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 6 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 7 +- 25 files changed, 277 insertions(+), 455 deletions(-) delete mode 100644 mlir/include/mlir/IR/InstVisitor.h (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 43ad0354b01..5ec536f47e7 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -229,14 +229,6 @@ public: /// (same operands in the same order). bool matchingBoundOperandList() const; - /// Walk the operation instructions in the 'for' instruction in preorder, - /// calling the callback for each operation. - void walk(std::function callback); - - /// Walk the operation instructions in the 'for' instruction in postorder, - /// calling the callback for each operation. - void walkPostOrder(std::function callback); - private: friend class Instruction; explicit AffineForOp(const Instruction *state) : Op(state) {} diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index aba0e11ab91..44fe4c0558a 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -18,7 +18,7 @@ #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ -#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Function.h" #include "llvm/Support/Allocator.h" namespace mlir { @@ -76,7 +76,7 @@ private: ArrayRef matchedChildren; }; -/// A NestedPattern is a nested InstWalker that: +/// A NestedPattern is a nested instruction walker that: /// 1. recursively matches a substructure in the tree; /// 2. 
uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); @@ -92,8 +92,8 @@ private: /// /// The NestedMatches captured in the IR can grow large, especially after /// aggressive unrolling. As experience has shown, it is generally better to use -/// a plain InstWalker to match flat patterns but the current implementation is -/// competitive nonetheless. +/// a plain walk over instructions to match flat patterns but the current +/// implementation is competitive nonetheless. using FilterFunctionType = std::function; static bool defaultFilterFunction(const Instruction &) { return true; }; struct NestedPattern { @@ -102,16 +102,14 @@ struct NestedPattern { NestedPattern(const NestedPattern &) = default; NestedPattern &operator=(const NestedPattern &) = default; - /// Returns all the top-level matches in `function`. - void match(Function *function, SmallVectorImpl *matches) { - State state(*this, matches); - state.walkPostOrder(function); + /// Returns all the top-level matches in `func`. + void match(Function *func, SmallVectorImpl *matches) { + func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); }); } /// Returns all the top-level matches in `inst`. void match(Instruction *inst, SmallVectorImpl *matches) { - State state(*this, matches); - state.walkPostOrder(inst); + inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); }); } /// Returns the depth of the pattern. @@ -120,22 +118,8 @@ struct NestedPattern { private: friend class NestedPatternContext; friend class NestedMatch; - friend class InstWalker; friend struct State; - /// Helper state that temporarily holds matches for the next level of nesting. - struct State : public InstWalker { - State(NestedPattern &pattern, SmallVectorImpl *matches) - : pattern(pattern), matches(matches) {} - void visitInstruction(Instruction *opInst) { - pattern.matchOne(opInst, matches); - } - - private: - NestedPattern &pattern; - SmallVectorImpl *matches; - }; - /// Underlying global bump allocator managed by a NestedPatternContext. static llvm::BumpPtrAllocator *&allocator(); @@ -153,8 +137,9 @@ private: /// without switching on the type of the Instruction. The idea is that a /// NestedPattern first checks if it matches locally and then recursively /// applies its nested matchers to its elem->nested. Since we want to rely on - /// the InstWalker impl rather than duplicate its the logic, we allow an - /// off-by-one traversal to account for the fact that we write: + /// the existing instruction walking functionality rather than duplicate + /// it, we allow an off-by-one traversal to account for the fact that we + /// write: /// /// void match(Instruction *elem) { /// for (auto &c : getNestedPatterns()) { diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f3a2218d0f9..6e44770282b 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -287,6 +287,28 @@ public: succ_iterator succ_end(); llvm::iterator_range getSuccessors(); + //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + + /// Walk the instructions of this block in preorder, calling the callback for + /// each operation. + void walk(const std::function &callback); + + /// Walk the instructions in the specified [begin, end) range of + /// this block, calling the callback for each operation. 
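For example, the [begin, end) form declared next lets a pass visit just part of a block, the way DmaGeneration's runOnBlock is parameterized on a block range. A sketch, assuming the surrounding pass supplies `block`, `begin`, `end` and `loadsAndStores`:

    block->walk(begin, end, [&](Instruction *inst) {
      if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
        loadsAndStores.push_back(inst);
    });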
+ void walk(Block::iterator begin, Block::iterator end, + const std::function &callback); + + /// Walk the instructions in this block in postorder, calling the callback for + /// each operation. + void walkPostOrder(const std::function &callback); + + /// Walk the instructions in the specified [begin, end) range of this block + /// in postorder, calling the callback for each operation. + void walkPostOrder(Block::iterator begin, Block::iterator end, + const std::function &callback); + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// @@ -311,19 +333,6 @@ public: return &Block::instructions; } - /// Walk the instructions of this block in preorder, calling the callback for - /// each operation. - void walk(std::function callback); - - /// Walk the instructions in this block in postorder, calling the callback for - /// each operation. - void walkPostOrder(std::function callback); - - /// Walk the instructions in the specified [begin, end) range of - /// this block, calling the callback for each operation. - void walk(Block::iterator begin, Block::iterator end, - std::function callback); - void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index f483ff46259..3afb021c8ec 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -27,6 +27,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/Instruction.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" @@ -39,6 +40,7 @@ class FunctionType; class MLIRContext; class Module; template class ArgumentIterator; +template class OpPointer; /// NamedAttribute is used for function attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -115,13 +117,35 @@ public: Block &front() { return blocks.front(); } const Block &front() const { return const_cast(this)->front(); } + //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + /// Walk the instructions in the function in preorder, calling the callback - /// for each instruction or operation. - void walk(std::function callback); + /// for each instruction. + void walk(const std::function &callback); + + /// Specialization of walk to only visit operations of 'OpTy'. + template + void walk(std::function)> callback) { + walk([&](Instruction *inst) { + if (auto op = inst->dyn_cast()) + callback(op); + }); + } /// Walk the instructions in the function in postorder, calling the callback - /// for each instruction or operation. - void walkPostOrder(std::function callback); + /// for each instruction. + void walkPostOrder(const std::function &callback); + + /// Specialization of walkPostOrder to only visit operations of 'OpTy'. 
+ template + void walkPostOrder(std::function)> callback) { + walkPostOrder([&](Instruction *inst) { + if (auto op = inst->dyn_cast()) + callback(op); + }); + } //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h deleted file mode 100644 index e11b7350894..00000000000 --- a/mlir/include/mlir/IR/InstVisitor.h +++ /dev/null @@ -1,140 +0,0 @@ -//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the base classes for Function's instruction visitors and -// walkers. A visit is a O(1) operation that visits just the node in question. A -// walk visits the node it's called on as well as the node's descendants. -// -// Instruction visitors/walkers are used when you want to perform different -// actions for different kinds of instructions without having to use lots of -// casts and a big switch instruction. -// -// To define your own visitor/walker, inherit from these classes, specifying -// your new type for the 'SubClass' template parameter, and "override" visitXXX -// functions in your class. This class is defined in terms of statically -// resolved overloading, not virtual functions. -// -// For example, here is a walker that counts the number of for loops in an -// Function. -// -// /// Declare the class. Note that we derive from InstWalker instantiated -// /// with _our new subclasses_ type. -// struct LoopCounter : public InstWalker { -// unsigned numLoops; -// LoopCounter() : numLoops(0) {} -// void visitForInst(ForInst &fs) { ++numLoops; } -// }; -// -// And this class would be used like this: -// LoopCounter lc; -// lc.walk(function); -// numLoops = lc.numLoops; -// -// There are 'visit' methods for Instruction and Function, which recursively -// process all contained instructions. -// -// Note that if you don't implement visitXXX for some instruction type, -// the visitXXX method for Instruction superclass will be invoked. -// -// The optional second template argument specifies the type that instruction -// visitation functions should return. If you specify this, you *MUST* provide -// an implementation of every visit<#Instruction>(InstType *). -// -// Note that these classes are specifically designed as a template to avoid -// virtual function call overhead. - -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_INSTVISITOR_H -#define MLIR_IR_INSTVISITOR_H - -#include "mlir/IR/Function.h" -#include "mlir/IR/Instruction.h" - -namespace mlir { -/// Base class for instruction walkers. A walker can traverse depth first in -/// pre-order or post order. The walk methods without a suffix do a pre-order -/// traversal while those that traverse in post order have a PostOrder suffix. 
-template class InstWalker { - //===--------------------------------------------------------------------===// - // Interface code - This is the public interface of the InstWalker used to - // walk instructions. - -public: - // Generic walk method - allow walk to all instructions in a range. - template void walk(Iterator Start, Iterator End) { - while (Start != End) { - walk(&(*Start++)); - } - } - template void walkPostOrder(Iterator Start, Iterator End) { - while (Start != End) { - walkPostOrder(&(*Start++)); - } - } - - // Define walkers for Function and all Function instruction kinds. - void walk(Function *f) { - for (auto &block : *f) - static_cast(this)->walk(block.begin(), block.end()); - } - - void walkPostOrder(Function *f) { - for (auto it = f->rbegin(), e = f->rend(); it != e; ++it) - static_cast(this)->walkPostOrder(it->begin(), it->end()); - } - - // Function to walk a instruction. - RetTy walk(Instruction *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - - static_cast(this)->visitInstruction(s); - for (auto &blockList : s->getBlockLists()) - for (auto &block : blockList) - static_cast(this)->walk(block.begin(), block.end()); - } - - // Function to walk a instruction in post order DFS. - RetTy walkPostOrder(Instruction *s) { - static_assert(std::is_base_of::value, - "Must pass the derived type to this template!"); - for (auto &blockList : s->getBlockLists()) - for (auto &block : blockList) - static_cast(this)->walkPostOrder(block.begin(), - block.end()); - static_cast(this)->visitInstruction(s); - } - - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular instruction type. - // The default behavior is to generalize the instruction type to its subtype - // and try visiting the subtype. All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. - - // When visiting a specific inst directly during a walk, these methods get - // called. These are typically O(1) complexity and shouldn't be recursively - // processing their descendants in some way. When using RetTy, all of these - // need to be overridden. - void visitInstruction(Instruction *inst) {} -}; - -} // end namespace mlir - -#endif // MLIR_IR_INSTVISITOR_H diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index c8a1dc8a7bb..bbd0ba10d65 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -613,6 +613,36 @@ public: return OpClass::isClassFor(this); } + //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + + /// Walk the instructions held by this instruction in preorder, calling the + /// callback for each instruction. + void walk(const std::function &callback); + + /// Specialization of walk to only visit operations of 'OpTy'. + template + void walk(std::function)> callback) { + walk([&](Instruction *inst) { + if (auto op = inst->dyn_cast()) + callback(op); + }); + } + + /// Walk the instructions held by this function in postorder, calling the + /// callback for each instruction. + void walkPostOrder(const std::function &callback); + + /// Specialization of walkPostOrder to only visit operations of 'OpTy'. 
+ template + void walkPostOrder(std::function)> callback) { + walkPostOrder([&](Instruction *inst) { + if (auto op = inst->dyn_cast()) + callback(op); + }); + } + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 39345d7fc7a..c3adf5fb7c3 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -646,32 +645,6 @@ bool AffineForOp::matchingBoundOperandList() const { return true; } -void AffineForOp::walk(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(getInstruction()); -} - -void AffineForOp::walkPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(getInstruction()); -} - /// Returns the induction variable for this loop. Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index ab22f261a3b..3376cd7d512 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,13 +37,11 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass, InstWalker { +struct MemRefBoundCheck : public FunctionPass { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); - static char passID; }; @@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitInstruction(Instruction *opInst) { - if (auto loadOp = opInst->dyn_cast()) { - boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast()) { - boundCheckLoadOrStoreOp(storeOp); - } - // TODO(bondhugula): do this for DMA ops as well. -} - PassResult MemRefBoundCheck::runOnFunction(Function *f) { - return walk(f), success(); + f->walk([](Instruction *opInst) { + if (auto loadOp = opInst->dyn_cast()) { + boundCheckLoadOrStoreOp(loadOp); + } else if (auto storeOp = opInst->dyn_cast()) { + boundCheckLoadOrStoreOp(storeOp); + } + // TODO(bondhugula): do this for DMA ops as well. 
+ }); + return success(); } static PassRegistration diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 6ea47a20f60..9ec1c95f213 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,19 +37,13 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. -struct MemRefDependenceCheck : public FunctionPass, - InstWalker { +struct MemRefDependenceCheck : public FunctionPass { SmallVector loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst) { - if (opInst->isa() || opInst->isa()) { - loadsAndStores.push_back(opInst); - } - } static char passID; }; @@ -120,8 +113,13 @@ static void checkDependences(ArrayRef loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. PassResult MemRefDependenceCheck::runOnFunction(Function *f) { + // Collect the loads and stores within the function. loadsAndStores.clear(); - walk(f); + f->walk([&](Instruction *inst) { + if (inst->isa() || inst->isa()) + loadsAndStores.push_back(inst); + }); + checkDependences(loadsAndStores); return success(); } diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 742c0baa96b..f05f8737b16 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -15,7 +15,6 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" @@ -27,16 +26,13 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass, InstWalker { +struct PrintOpStatsPass : public ModulePass { explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : ModulePass(&PrintOpStatsPass::passID), os(os) {} // Prints the resultant operation statistics post iterating over the module. PassResult runOnModule(Module *m) override; - // Updates the operation statistics for the given instruction. - void visitInstruction(Instruction *inst); - // Print summary of op stats. void printSummary(); @@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker { private: llvm::StringMap opCount; - llvm::raw_ostream &os; }; } // namespace @@ -52,16 +47,16 @@ private: char PrintOpStatsPass::passID = 0; PassResult PrintOpStatsPass::runOnModule(Module *m) { + opCount.clear(); + + // Compute the operation statistics for each function in the module. 
for (auto &fn : *m) - walk(&fn); + fn.walk( + [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); return success(); } -void PrintOpStatsPass::visitInstruction(Instruction *inst) { - ++opCount[inst->getName().getStringRef()]; -} - void PrintOpStatsPass::printSummary() { os << "Operations encountered:\n"; os << "-----------------------\n"; diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 1703a16c2b8..cea99121dc9 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -19,7 +19,6 @@ #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a69920cbd86..ffc863d76d0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index f18ce8e33a8..e6dfc4c2145 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -18,7 +18,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" using namespace mlir; @@ -226,6 +225,34 @@ Block *Block::getSinglePredecessor() { return it == pred_end() ? firstPred : nullptr; } +//===----------------------------------------------------------------------===// +// Instruction Walkers +//===----------------------------------------------------------------------===// + +void Block::walk(const std::function &callback) { + walk(begin(), end(), callback); +} + +void Block::walk(Block::iterator begin, Block::iterator end, + const std::function &callback) { + // Walk the instructions within this block. + for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) + inst.walk(callback); +} + +void Block::walkPostOrder(const std::function &callback) { + walkPostOrder(begin(), end(), callback); +} + +/// Walk the instructions in the specified [begin, end) range of this block +/// in postorder, calling the callback for each operation. +void Block::walkPostOrder(Block::iterator begin, Block::iterator end, + const std::function &callback) { + // Walk the instructions within this block. 
+ for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) + inst.walkPostOrder(callback); +} + //===----------------------------------------------------------------------===// // Other //===----------------------------------------------------------------------===// @@ -253,37 +280,6 @@ Block *Block::splitBlock(iterator splitBefore) { return newBB; } -void Block::walk(std::function callback) { - walk(begin(), end(), callback); -} - -void Block::walk(Block::iterator begin, Block::iterator end, - std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(begin, end); -} - -void Block::walkPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(begin(), end()); -} - //===----------------------------------------------------------------------===// // BlockList //===----------------------------------------------------------------------===// @@ -331,25 +327,18 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper, // Now that each of the blocks have been cloned, go through and remap the // operands of each of the instructions. - struct Walker : public InstWalker { - BlockAndValueMapping &mapper; - Walker(BlockAndValueMapping &mapper) : mapper(mapper) {} - - /// Remap the instruction and successor block operands. - void visitInstruction(Instruction *inst) { - for (auto &instOp : inst->getInstOperands()) - if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) - instOp.set(mappedOp); - if (inst->isTerminator()) - for (auto &succOp : inst->getBlockOperands()) - if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) - succOp.set(mappedOp); - } + auto remapOperands = [&](Instruction *inst) { + for (auto &instOp : inst->getInstOperands()) + if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) + instOp.set(mappedOp); + if (inst->isTerminator()) + for (auto &succOp : inst->getBlockOperands()) + if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) + succOp.set(mappedOp); }; - Walker v(mapper); for (auto it = std::next(lastOldBlock), e = dest->end(); it != e; ++it) - v.walk(it->begin(), it->end()); + it->walk(remapOperands); } BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() { diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 3a263fb13f9..ba781500c4f 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -19,7 +19,6 @@ #include "AttributeListStorage.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" @@ -214,28 +213,15 @@ void Function::addEntryBlock() { entry->addArguments(type.getInputs()); } -void Function::walk(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *inst) { callback(inst); } - }; - - Walker v(callback); - v.walk(this); +void Function::walk(const std::function &callback) { + // Walk each of the blocks within the function. 
+ for (auto &block : getBlocks()) + block.walk(callback); } -void Function::walkPostOrder(std::function callback) { - struct Walker : public InstWalker { - std::function const &callback; - Walker(std::function const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *inst) { callback(inst); } - }; - - Walker v(callback); - v.walkPostOrder(this); +void Function::walkPostOrder( + const std::function &callback) { + // Walk each of the blocks within the function. + for (auto &block : llvm::reverse(getBlocks())) + block.walkPostOrder(callback); } diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6720969ac0f..062f13a3282 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -22,7 +22,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/DenseMap.h" @@ -300,6 +299,35 @@ Function *Instruction::getFunction() const { return block ? block->getFunction() : nullptr; } +//===----------------------------------------------------------------------===// +// Instruction Walkers +//===----------------------------------------------------------------------===// + +void Instruction::walk(const std::function &callback) { + // Visit the current instruction. + callback(this); + + // Visit any internal instructions. + for (auto &blockList : getBlockLists()) + for (auto &block : blockList) + block.walk(callback); +} + +void Instruction::walkPostOrder( + const std::function &callback) { + // Visit any internal instructions. + for (auto &blockList : llvm::reverse(getBlockLists())) + for (auto &block : llvm::reverse(blockList)) + block.walkPostOrder(callback); + + // Visit the current instruction. + callback(this); +} + +//===----------------------------------------------------------------------===// +// Other +//===----------------------------------------------------------------------===// + /// Emit a note about this instruction, reporting up to any diagnostic /// handlers that may be listening. 
void Instruction::emitNote(const Twine &message) const { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index ae77d66b183..b7e4fb147cb 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -26,7 +26,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 63a676d7b52..de10fe8a461 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -24,7 +24,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 289b00d3b51..796477c64f2 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -27,7 +27,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -46,10 +45,9 @@ namespace { // result of any AffineApplyOp). After this composition, AffineApplyOps with no // remaining uses are erased. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. -struct ComposeAffineMaps : public FunctionPass, InstWalker { +struct ComposeAffineMaps : public FunctionPass { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); SmallVector, 8> affineApplyOps; @@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) { return inst.isa(); } -void ComposeAffineMaps::visitInstruction(Instruction *opInst) { - if (auto afOp = opInst->dyn_cast()) - affineApplyOps.push_back(afOp); -} - PassResult ComposeAffineMaps::runOnFunction(Function *f) { // If needed for future efficiency, reserve space based on a pre-walk. affineApplyOps.clear(); - walk(f); + f->walk( + [&](OpPointer afOp) { affineApplyOps.push_back(afOp); }); for (auto afOp : affineApplyOps) { SmallVector operands(afOp->getOperands()); FuncBuilder b(afOp->getInstruction()); @@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) { // Erase dead affine apply ops. affineApplyOps.clear(); - walk(f); + f->walk( + [&](OpPointer afOp) { affineApplyOps.push_back(afOp); }); for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) { if ((*it)->use_empty()) { (*it)->erase(); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 54486cdb293..e41ac0ad329 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -18,7 +18,6 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -27,7 +26,7 @@ using namespace mlir; namespace { /// Simple constant folding pass. 
-struct ConstantFold : public FunctionPass, InstWalker { +struct ConstantFold : public FunctionPass { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. @@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker { // Operations that were folded and that need to be erased. std::vector opInstsToErase; - bool foldOperation(Instruction *op, - SmallVectorImpl &existingConstants); - void visitInstruction(Instruction *op); + void foldInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +46,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitInstruction(Instruction *op) { +void ConstantFold::foldInstruction(Instruction *op) { // If this operation is an AffineForOp, then fold the bounds. if (auto forOp = op->dyn_cast()) { constantFoldBounds(forOp); @@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) { existingConstants.clear(); opInstsToErase.clear(); - walk(f); + f->walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7a002168528..77e5a6aa04f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -111,22 +110,23 @@ namespace { // LoopNestStateCollector walks loop nests and collects load and store // operations, and whether or not an IfInst was encountered in the loop nest. -class LoopNestStateCollector : public InstWalker { -public: +struct LoopNestStateCollector { SmallVector, 4> forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; bool hasNonForRegion = false; - void visitInstruction(Instruction *opInst) { - if (opInst->isa()) - forOps.push_back(opInst->cast()); - else if (opInst->getNumBlockLists() != 0) - hasNonForRegion = true; - else if (opInst->isa()) - loadOpInsts.push_back(opInst); - else if (opInst->isa()) - storeOpInsts.push_back(opInst); + void collect(Instruction *instToWalk) { + instToWalk->walk([&](Instruction *opInst) { + if (opInst->isa()) + forOps.push_back(opInst->cast()); + else if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa()) + loadOpInsts.push_back(opInst); + else if (opInst->isa()) + storeOpInsts.push_back(opInst); + }); } }; @@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walk(&inst); + collector.collect(&inst); // Return false if a non 'for' region was found (not currently supported). if (collector.hasNonForRegion) return false; @@ -606,41 +606,39 @@ struct LoopNestStats { // LoopNestStatsCollector walks a single loop nest and gathers per-loop // trip count and operation count statistics and records them in 'stats'. 
-class LoopNestStatsCollector : public InstWalker { -public: +struct LoopNestStatsCollector { LoopNestStats *stats; bool hasLoopWithNonConstTripCount = false; LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitInstruction(Instruction *opInst) { - auto forOp = opInst->dyn_cast(); - if (!forOp) - return; - - auto *forInst = forOp->getInstruction(); - auto *parentInst = forOp->getInstruction()->getParentInst(); - if (parentInst != nullptr) { - assert(parentInst->isa() && "Expected parent AffineForOp"); - // Add mapping to 'forOp' from its parent AffineForOp. - stats->loopMap[parentInst].push_back(forOp); - } + void collect(Instruction *inst) { + inst->walk([&](OpPointer forOp) { + auto *forInst = forOp->getInstruction(); + auto *parentInst = forOp->getInstruction()->getParentInst(); + if (parentInst != nullptr) { + assert(parentInst->isa() && "Expected parent AffineForOp"); + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentInst].push_back(forOp); + } - // Record the number of op instructions in the body of 'forOp'. - unsigned count = 0; - stats->opCountMap[forInst] = 0; - for (auto &inst : *forOp->getBody()) { - if (!(inst.isa() || inst.isa())) - ++count; - } - stats->opCountMap[forInst] = count; - // Record trip count for 'forOp'. Set flag if trip count is not constant. - Optional maybeConstTripCount = getConstantTripCount(forOp); - if (!maybeConstTripCount.hasValue()) { - hasLoopWithNonConstTripCount = true; - return; - } - stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); + // Record the number of op instructions in the body of 'forOp'. + unsigned count = 0; + stats->opCountMap[forInst] = 0; + for (auto &inst : *forOp->getBody()) { + if (!(inst.isa() || inst.isa())) + ++count; + } + stats->opCountMap[forInst] = count; + // Record trip count for 'forOp'. Set flag if trip count is not + // constant. + Optional maybeConstTripCount = getConstantTripCount(forOp); + if (!maybeConstTripCount.hasValue()) { + hasLoopWithNonConstTripCount = true; + return; + } + stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); + }); } }; @@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.walk(srcLoopIVs[0]->getInstruction()); + srcStatsCollector.collect(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.walk(dstLoopIVs[0]->getInstruction()); + dstStatsCollector.collect(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1474,7 +1472,7 @@ public: // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.walk(sliceLoopNest->getInstruction()); + sliceCollector.collect(sliceLoopNest->getInstruction()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); @@ -1498,7 +1496,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. 
LoopNestStateCollector dstLoopCollector; - dstLoopCollector.walk(dstAffineForOp->getInstruction()); + dstLoopCollector.collect(dstAffineForOp->getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index b1e15ccb07b..3a7cfb85e08 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -27,7 +27,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -95,15 +94,16 @@ char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class InnermostLoopGatherer : public InstWalker { - public: + struct InnermostLoopGatherer { // Store innermost loops as we walk. std::vector> loops; - // This method specialized to encode custom return logic. - using InstListType = llvm::iplist; - bool walkPostOrder(InstListType::iterator Start, - InstListType::iterator End) { + void walkPostOrder(Function *f) { + for (auto &b : *f) + walkPostOrder(b.begin(), b.end()); + } + + bool walkPostOrder(Block::iterator Start, Block::iterator End) { bool hasInnerLoops = false; // We need to walk all elements since all innermost loops need to be // gathered as opposed to determining whether this list has any inner @@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { hasInnerLoops |= walkPostOrder(&(*Start++)); return hasInnerLoops; } - bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) @@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) { } return hasInnerLoops; } - - // FIXME: can't use base class method for this because that in turn would - // need to use the derived class method above. CRTP doesn't allow it, and - // the compiler error resulting from it is also misleading. - using InstWalker::walkPostOrder; }; - // Gathers all loops with trip count <= minTripCount. - class ShortLoopGatherer : public InstWalker { - public: + if (clUnrollFull.getNumOccurrences() > 0 && + clUnrollFullThreshold.getNumOccurrences() > 0) { // Store short loops as we walk. std::vector> loops; - const unsigned minTripCount; - ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitInstruction(Instruction *opInst) { - auto forOp = opInst->dyn_cast(); - if (!forOp) - return; + // Gathers all loops with trip count <= minTripCount. Do a post order walk + // so that loops are gathered from innermost to outermost (or else unrolling + // an outer one may delete gathered inner ones). + f->walkPostOrder([&](OpPointer forOp) { Optional tripCount = getConstantTripCount(forOp); - if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) + if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); - } - }; - - if (clUnrollFull.getNumOccurrences() > 0 && - clUnrollFullThreshold.getNumOccurrences() > 0) { - ShortLoopGatherer slg(clUnrollFullThreshold); - // Do a post order walk so that loops are gathered from innermost to - // outermost (or else unrolling an outer one may delete gathered inner - // ones). 
- slg.walkPostOrder(f); - auto &loops = slg.loops; + }); for (auto forOp : loops) loopUnrollFull(forOp); return success(); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 74c54fde047..b2aed7d9d7f 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -50,7 +50,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer forOp, // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). - class JamBlockGatherer : public InstWalker { - public: - using InstListType = llvm::iplist; - using InstWalker::walk; - + struct JamBlockGatherer { // Store iterators to the first and last inst of each sub-block found. std::vector> subBlocks; // This is a linear time walk. - void walk(InstListType::iterator Start, InstListType::iterator End) { - for (auto it = Start; it != End;) { + void walk(Instruction *inst) { + for (auto &blockList : inst->getBlockLists()) + for (auto &block : blockList) + walk(block); + } + void walk(Block &block) { + for (auto it = block.begin(), e = block.end(); it != e;) { auto subBlockStart = it; - while (it != End && !it->isa()) + while (it != e && !it->isa()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && it->isa()) + while (it != e && it->isa()) walk(&*it++); } } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 2d06a327315..9c9db30d163 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -70,12 +69,12 @@ namespace { // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. // -struct MemRefDataFlowOpt : public FunctionPass, InstWalker { +struct MemRefDataFlowOpt : public FunctionPass { explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); + void forwardStoreToLoad(OpPointer loadOp); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; @@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. 
-void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { +void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { Instruction *lastWriteStoreOp = nullptr; - - auto loadOp = opInst->dyn_cast(); - if (!loadOp) - return; - - Instruction *loadOpInst = opInst; + Instruction *loadOpInst = loadOp->getInstruction(); // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across @@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - walk(f); + f->walk( + [&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index ba3be5e95f4..4ca48a53485 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder([&](Instruction *opInst) { - if (auto forOp = opInst->dyn_cast()) - forOps.push_back(forOp); - }); + f->walkPostOrder( + [&](OpPointer forOp) { forOps.push_back(forOp); }); bool ret = false; for (auto forOp : forOps) { ret = ret | runOnAffineForOp(forOp); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 5bf17989bef..95875adca6e 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" @@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkPostOrder([](Instruction *inst) { - if (auto forOp = inst->dyn_cast()) - promoteIfSingleIteration(forOp); - }); + f->walkPostOrder( + [](OpPointer forOp) { promoteIfSingleIteration(forOp); }); } /// Generates a 'for' inst with the specified lower and upper bounds while -- cgit v1.2.3 From a78edcda5bb5ba6d89d2efd3004becb7e3a9fc95 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Tue, 5 Feb 2019 06:57:08 -0800 Subject: Loop fusion improvements: *) After a private memref buffer is created for a fused loop nest, dependences on the old memref are reduced, which can open up fusion opportunities. In these cases, users of the old memref are added back to the worklist to be reconsidered for fusion. *) Fixed a bug in fusion insertion point dependence check where the memref being privatized was being skipped from the check. 
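The worklist bookkeeping amounts to the pattern sketched below (a simplified, standalone illustration with hypothetical edge and memref types, not the pass's actual MemRefDependenceGraph code): after the source nest is fused behind a private buffer, the remaining users of the old memref are pushed back onto the worklist unless they are already queued.

    #include <set>
    #include <vector>

    // Hypothetical stand-ins for the graph edge and memref id types.
    struct OutEdge { unsigned id; int value; };

    void requeueOldMemRefUsers(const std::vector<OutEdge> &srcOutEdges,
                               int oldMemRef,
                               std::vector<unsigned> &worklist,
                               std::set<unsigned> &worklistSet) {
      for (const OutEdge &edge : srcOutEdges) {
        // Only nodes that still touch the old memref can have gained new
        // fusion opportunities; skip nodes that are already scheduled.
        if (edge.value == oldMemRef && worklistSet.insert(edge.id).second)
          worklist.push_back(edge.id);
      }
    }
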
PiperOrigin-RevId: 232477853 --- mlir/lib/Transforms/LoopFusion.cpp | 35 +++++++++++++---- mlir/test/Transforms/loop-fusion.mlir | 73 ++++++++++++++++++++++++++++------- 2 files changed, 86 insertions(+), 22 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 77e5a6aa04f..d7e1b610022 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -365,21 +365,20 @@ public: // Computes and returns an insertion point instruction, before which the // the fused loop nest can be inserted while preserving // dependences. Returns nullptr if no such insertion point is found. - Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId, - Value *memrefToSkip) { + Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { if (outEdges.count(srcId) == 0) return getNode(dstId)->inst; // Build set of insts in range (srcId, dstId) which depend on 'srcId'. SmallPtrSet srcDepInsts; for (auto &outEdge : outEdges[srcId]) - if (outEdge.id != dstId && outEdge.value != memrefToSkip) + if (outEdge.id != dstId) srcDepInsts.insert(getNode(outEdge.id)->inst); // Build set of insts in range (srcId, dstId) on which 'dstId' depends. SmallPtrSet dstDepInsts; for (auto &inEdge : inEdges[dstId]) - if (inEdge.id != srcId && inEdge.value != memrefToSkip) + if (inEdge.id != srcId) dstDepInsts.insert(getNode(inEdge.id)->inst); Instruction *srcNodeInst = getNode(srcId)->inst; @@ -1366,18 +1365,24 @@ static bool isFusionProfitable(Instruction *srcOpInst, struct GreedyFusion { public: MemRefDependenceGraph *mdg; - SmallVector worklist; + SmallVector worklist; + llvm::SmallDenseSet worklistSet; GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) { // Initialize worklist with nodes from 'mdg'. + // TODO(andydavis) Add a priority queue for prioritizing nodes by different + // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). worklist.resize(mdg->nodes.size()); std::iota(worklist.begin(), worklist.end(), 0); + worklistSet.insert(worklist.begin(), worklist.end()); } void run(unsigned localBufSizeThreshold, Optional fastMemorySpace) { while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); + worklistSet.erase(dstId); + // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) continue; @@ -1437,8 +1442,8 @@ public: // Compute an instruction list insertion point for the fused loop // nest which preserves dependences. - Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint( - srcNode->id, dstNode->id, memref); + Instruction *insertPointInst = + mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); if (insertPointInst == nullptr) continue; @@ -1516,6 +1521,22 @@ public: if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); srcNode->inst->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created + // new fusion opportunities. 
+ if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); + } + } + } } } } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index bb8dd0db73e..57c1ec4ceed 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -340,10 +340,9 @@ func @should_not_fuse_would_create_cycle() { } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) -// CHECK-LABEL: func @should_fuse_across_waw_dep_with_private_memref() { -func @should_fuse_across_waw_dep_with_private_memref() { +// CHECK-LABEL: func @should_not_fuse_across_waw_dep() { +func @should_not_fuse_across_waw_dep() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -358,16 +357,13 @@ func @should_fuse_across_waw_dep_with_private_memref() { } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i2 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2, %i2) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i2, %i2) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1234,17 +1230,17 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // by loops %i1 and %i2. 
// CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 17 { + // CHECK: for %i0 = 0 to 82 { // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 82 { + // CHECK-NEXT: for %i1 = 0 to 17 { // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %1[%5] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %7 = load %1[%6] : memref<1xf32> + // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1652,3 +1648,50 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // CHECK-NEXT: return return } + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) + +// CHECK-LABEL: func @should_fuse_after_private_memref_creation() { +func @should_fuse_after_private_memref_creation() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %a[%i1] : memref<10xf32> + store %v0, %b[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %a[%i2] : memref<10xf32> + store %v1, %b[%i2] : memref<10xf32> + } + + // On the first visit to '%i2', the fusion algorithm can not fuse loop nest + // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on + // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a + // private memref, the dependence between '%i0' and '%i1' on memref '%a' no + // longer exists, so '%i0' can now be fused into '%i2'. + + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> + // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> + // CHECK-NEXT: %7 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> + // CHECK-NEXT: store %8, %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From 10237de8eb41f7343dd3c20cb21adc3cf2b1fee5 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 5 Feb 2019 17:00:13 -0800 Subject: Refactor the affine analysis by moving some functionality to IR and some to AffineOps. This is important for allowing the affine dialect to define canonicalizations directly on the operations instead of relying on transformation passes, e.g. ComposeAffineMaps. A summary of the refactoring: * AffineStructures has moved to IR. * simplifyAffineExpr/simplifyAffineMap/getFlattenedAffineExpr have moved to IR. * makeComposedAffineApply/fullyComposeAffineMapAndOperands have moved to AffineOps. * ComposeAffineMaps is replaced by AffineApplyOp::canonicalize and deleted. 
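For intuition on what the canonicalization folds: composing a chain of affine applies is just composition of affine functions. The standalone toy below (plain C++, an illustration only and not MLIR's implementation) shows the 1-D case, e.g. d0 -> d0 + 1 followed by d0 -> 2 * d0 folds to d0 -> 2 * d0 + 2.

    #include <cstdint>
    #include <iostream>

    // A 1-D affine function f(x) = a*x + b (illustration only).
    struct Affine1D { int64_t a, b; };

    // g(f(x)) = g.a*(f.a*x + f.b) + g.b is again affine; this is the
    // arithmetic a chained affine_apply canonicalizes to.
    Affine1D compose(Affine1D g, Affine1D f) {
      return {g.a * f.a, g.a * f.b + g.b};
    }

    int main() {
      Affine1D f{1, 1};            // d0 -> d0 + 1
      Affine1D g{2, 0};            // d0 -> d0 * 2
      Affine1D gf = compose(g, f); // d0 -> 2*d0 + 2
      std::cout << gf.a << " * d0 + " << gf.b << "\n";
      return 0;
    }
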
PiperOrigin-RevId: 232586468 --- mlir/include/mlir/AffineOps/AffineOps.h | 39 + mlir/include/mlir/Analysis/AffineAnalysis.h | 61 - mlir/include/mlir/Analysis/AffineStructures.h | 697 ------- mlir/include/mlir/Analysis/Utils.h | 2 +- mlir/include/mlir/IR/AffineExpr.h | 36 + mlir/include/mlir/IR/AffineMap.h | 4 + mlir/include/mlir/IR/AffineStructures.h | 682 +++++++ mlir/include/mlir/Transforms/Passes.h | 3 - mlir/lib/AffineOps/AffineOps.cpp | 317 +++- mlir/lib/Analysis/AffineAnalysis.cpp | 665 +------ mlir/lib/Analysis/AffineStructures.cpp | 2167 ---------------------- mlir/lib/Analysis/LoopAnalysis.cpp | 5 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/IR/AffineExpr.cpp | 445 +++++ mlir/lib/IR/AffineMap.cpp | 13 + mlir/lib/IR/AffineStructures.cpp | 2063 ++++++++++++++++++++ mlir/lib/Transforms/ComposeAffineMaps.cpp | 96 - mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- mlir/test/AffineOps/canonicalize.mlir | 263 +++ mlir/test/Transforms/canonicalize.mlir | 45 - mlir/test/Transforms/compose-affine-maps.mlir | 255 --- 27 files changed, 3874 insertions(+), 4002 deletions(-) delete mode 100644 mlir/include/mlir/Analysis/AffineStructures.h create mode 100644 mlir/include/mlir/IR/AffineStructures.h delete mode 100644 mlir/lib/Analysis/AffineStructures.cpp create mode 100644 mlir/lib/IR/AffineStructures.cpp delete mode 100644 mlir/lib/Transforms/ComposeAffineMaps.cpp create mode 100644 mlir/test/AffineOps/canonicalize.mlir delete mode 100644 mlir/test/Transforms/compose-affine-maps.mlir (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 5ec536f47e7..52601370207 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -29,6 +29,9 @@ namespace mlir { class AffineBound; +class AffineValueMap; +class FlatAffineConstraints; +class FuncBuilder; class AffineOpsDialect : public Dialect { public: @@ -61,6 +64,9 @@ public: return getAttrOfType("map").getValue(); } + /// Returns an AffineValueMap representing this affine apply. + AffineValueMap getAsAffineValueMap(); + /// Returns true if the result of this operation can be used as dimension id. bool isValidDim() const; @@ -247,6 +253,19 @@ ConstOpPointer getForInductionVarOwner(const Value *val); void extractForInductionVars(ArrayRef> forInsts, SmallVectorImpl *ivs); +/// Adds constraints (lower and upper bounds) for the specified 'for' +/// instruction's Value using IR information stored in its bound maps. The +/// right identifier is first looked up using forOp's Value. Returns +/// false for the yet unimplemented/unsupported cases, and true if the +/// information is successfully added. Asserts if the Value corresponding to +/// the 'for' instruction isn't found in the constraint system. Any new +/// identifiers that are found in the bound operands of the 'for' instruction +/// are added as trailing identifiers (either dimensional or symbolic +/// depending on whether the operand is a valid ML Function symbol). +// TODO(bondhugula): add support for non-unit strides. +bool addAffineForOpDomain(ConstOpPointer forOp, + FlatAffineConstraints *constraints); + /// AffineBound represents a lower or upper bound in the for instruction. 
/// This class does not own the underlying operands. Instead, it refers /// to the operands stored in the AffineForOp. Its life span should not exceed @@ -256,6 +275,9 @@ public: ConstOpPointer getAffineForOp() const { return inst; } AffineMap getMap() const { return map; } + /// Returns an AffineValueMap representing this bound. + AffineValueMap getAsAffineValueMap(); + unsigned getNumOperands() const { return opEnd - opStart; } const Value *getOperand(unsigned idx) const { return inst->getInstruction()->getOperand(opStart + idx); @@ -354,6 +376,23 @@ bool isValidSymbol(const Value *value); void canonicalizeMapAndOperands(AffineMap *map, llvm::SmallVectorImpl *operands); +/// Returns a composed AffineApplyOp by composing `map` and `operands` with +/// other AffineApplyOps supplying those operands. The operands of the resulting +/// AffineApplyOp do not change the length of AffineApplyOp chains. +OpPointer +makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, + llvm::ArrayRef operands); + +/// Given an affine map `map` and its input `operands`, this method composes +/// into `map`, maps of AffineApplyOps whose results are the values in +/// `operands`, iteratively until no more of `operands` are the result of an +/// AffineApplyOp. When this function returns, `map` becomes the composed affine +/// map, and each Value in `operands` is guaranteed to be either a loop IV or a +/// terminal symbol, i.e., a symbol defined at the top level or a block/function +/// argument. +void fullyComposeAffineMapAndOperands(AffineMap *map, + llvm::SmallVectorImpl *operands); + } // end namespace mlir #endif diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index dd3e676eb17..9d3887ddb70 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -31,47 +31,13 @@ namespace mlir { class AffineApplyOp; -class AffineExpr; class AffineForOp; -class AffineMap; class AffineValueMap; class FlatAffineConstraints; -class FuncBuilder; class Instruction; -class IntegerSet; -class Location; -class MLIRContext; template class OpPointer; class Value; -/// Simplify an affine expression by flattening and some amount of -/// simple analysis. This has complexity linear in the number of nodes in -/// 'expr'. Returns the simplified expression, which is the same as the input -/// expression if it can't be simplified. -AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols); - -/// Simplify an affine map by simplifying its underlying AffineExpr results and -/// sizes. -AffineMap simplifyAffineMap(AffineMap map); - -/// Returns a composed AffineApplyOp by composing `map` and `operands` with -/// other AffineApplyOps supplying those operands. The operands of the resulting -/// AffineApplyOp do not change the length of AffineApplyOp chains. -OpPointer -makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, - llvm::ArrayRef operands); - -/// Given an affine map `map` and its input `operands`, this method composes -/// into `map`, maps of AffineApplyOps whose results are the values in -/// `operands`, iteratively until no more of `operands` are the result of an -/// AffineApplyOp. When this function returns, `map` becomes the composed affine -/// map, and each Value in `operands` is guaranteed to be either a loop IV or a -/// terminal symbol, i.e., a symbol defined at the top level or a block/function -/// argument. 
-void fullyComposeAffineMapAndOperands(AffineMap *map, - llvm::SmallVectorImpl *operands); - /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp /// Instructions that are reachable via a search starting from `operands` and /// ending at those operands that are not the result of an AffineApplyOp. @@ -79,33 +45,6 @@ void getReachableAffineApplyOps( llvm::ArrayRef operands, llvm::SmallVectorImpl &affineApplyOps); -/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false -/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). -/// 'cst' contains constraints that connect newly introduced local identifiers -/// to existing dimensional and / symbolic identifiers. See documentation for -/// AffineExprFlattener on how mod's and div's are flattened. -bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *cst = nullptr); - -/// Flattens the result expressions of the map to their corresponding flattened -/// forms and set in 'flattenedExprs'. Returns true on success or false -/// if any expression in the map could not be flattened (i.e., semi-affine is -/// not yet handled). 'cst' contains constraints that connect newly introduced -/// local identifiers to existing dimensional and / symbolic identifiers. See -/// documentation for AffineExprFlattener on how mod's and div's are flattened. -/// For all affine expressions that share the same operands (like those of an -/// affine map), this method should be used instead of repeatedly calling -/// getFlattenedAffineExpr since local variables added to deal with div's and -/// mod's will be reused across expressions. -bool getFlattenedAffineExprs( - AffineMap map, std::vector> *flattenedExprs, - FlatAffineConstraints *cst = nullptr); -bool getFlattenedAffineExprs( - IntegerSet set, std::vector> *flattenedExprs, - FlatAffineConstraints *cst = nullptr); - /// Builds a system of constraints with dimensional identifiers corresponding to /// the loop IVs of the forOps appearing in that order. Bounds of the loop are /// used to add appropriate inequalities. Any symbols founds in the bound diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h deleted file mode 100644 index e8b4ee623c0..00000000000 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ /dev/null @@ -1,697 +0,0 @@ -//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Structures for affine/polyhedral analysis of ML functions. 
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_ANALYSIS_AFFINE_STRUCTURES_H -#define MLIR_ANALYSIS_AFFINE_STRUCTURES_H - -#include "mlir/IR/AffineExpr.h" - -namespace mlir { - -class AffineApplyOp; -class AffineBound; -class AffineForOp; -class AffineCondition; -class AffineMap; -template class ConstOpPointer; -class IntegerSet; -class MLIRContext; -class Value; -class HyperRectangularSet; -class MemRefType; - -/// A mutable affine map. Its affine expressions are however unique. -struct MutableAffineMap { -public: - MutableAffineMap() {} - MutableAffineMap(AffineMap map); - - ArrayRef getResults() const { return results; } - AffineExpr getResult(unsigned idx) const { return results[idx]; } - void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } - unsigned getNumResults() const { return results.size(); } - unsigned getNumDims() const { return numDims; } - void setNumDims(unsigned d) { numDims = d; } - unsigned getNumSymbols() const { return numSymbols; } - void setNumSymbols(unsigned d) { numSymbols = d; } - MLIRContext *getContext() const { return context; } - - /// Returns true if the idx'th result expression is a multiple of factor. - bool isMultipleOf(unsigned idx, int64_t factor) const; - - /// Resets this MutableAffineMap with 'map'. - void reset(AffineMap map); - - /// Simplify the (result) expressions in this map using analysis (used by - //-simplify-affine-expr pass). - void simplify(); - /// Get the AffineMap corresponding to this MutableAffineMap. Note that an - /// AffineMap will be uniqued and stored in context, while a mutable one - /// isn't. - AffineMap getAffineMap() const; - -private: - // Same meaning as AffineMap's fields. - SmallVector results; - SmallVector rangeSizes; - unsigned numDims; - unsigned numSymbols; - /// A pointer to the IR's context to store all newly created - /// AffineExprStorage's. - MLIRContext *context; -}; - -/// A mutable integer set. Its affine expressions are however unique. -struct MutableIntegerSet { -public: - MutableIntegerSet(IntegerSet set, MLIRContext *context); - - /// Create a universal set (no constraints). - MutableIntegerSet(unsigned numDims, unsigned numSymbols, - MLIRContext *context); - - unsigned getNumDims() const { return numDims; } - unsigned getNumSymbols() const { return numSymbols; } - unsigned getNumConstraints() const { return constraints.size(); } - - void clear() { - constraints.clear(); - eqFlags.clear(); - } - -private: - unsigned numDims; - unsigned numSymbols; - - SmallVector constraints; - SmallVector eqFlags; - /// A pointer to the IR's context to store all newly created - /// AffineExprStorage's. - MLIRContext *context; -}; - -/// An AffineValueMap is an affine map plus its ML value operands and -/// results for analysis purposes. The structure is still a tree form that is -/// same as that of an affine map or an AffineApplyOp. However, its operands, -/// results, and its map can themselves change as a result of -/// substitutions, simplifications, and other analysis. -// An affine value map can readily be constructed from an AffineApplyOp, or an -// AffineBound of a AffineForOp. It can be further transformed, substituted -// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and -// destroyed during analysis. Only the AffineMap expressions that are pointed by -// them are unique'd. 
An affine value map, and the operations on it, maintain -// the invariant that operands are always positionally aligned with the -// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. -// TODO(bondhugula): Some of these classes could go into separate files. -class AffineValueMap { -public: - // Creates an empty AffineValueMap (users should call 'reset' to reset map - // and operands). - AffineValueMap() {} - AffineValueMap(const AffineApplyOp &op); - AffineValueMap(const AffineBound &bound); - AffineValueMap(AffineMap map); - AffineValueMap(AffineMap map, ArrayRef operands); - - ~AffineValueMap(); - - // Resets this AffineValueMap with 'map' and 'operands'. - void reset(AffineMap map, ArrayRef operands); - /// Return true if the idx^th result can be proved to be a multiple of - /// 'factor', false otherwise. - inline bool isMultipleOf(unsigned idx, int64_t factor) const; - - /// Return true if the idx^th result depends on 'value', false otherwise. - bool isFunctionOf(unsigned idx, Value *value) const; - - /// Return true if the result at 'idx' is a constant, false - /// otherwise. - bool isConstant(unsigned idx) const; - - /// Return true if this is an identity map. - bool isIdentity() const; - - inline unsigned getNumOperands() const { return operands.size(); } - inline unsigned getNumDims() const { return map.getNumDims(); } - inline unsigned getNumSymbols() const { return map.getNumSymbols(); } - inline unsigned getNumResults() const { return map.getNumResults(); } - - Value *getOperand(unsigned i) const; - ArrayRef getOperands() const; - AffineMap getAffineMap() const; - -private: - // A mutable affine map. - MutableAffineMap map; - - // TODO: make these trailing objects? - /// The SSA operands binding to the dim's and symbols of 'map'. - SmallVector operands; - /// The SSA results binding to the results of 'map'. - SmallVector results; -}; - -/// An IntegerValueSet is an integer set plus its operands. -// Both, the integer set being pointed to and the operands can change during -// analysis, simplification, and transformation. -class IntegerValueSet { - /// Constructs an integer value set from an affine value map. - // This will lead to a single equality in 'set'. - explicit IntegerValueSet(const AffineValueMap &avm); - - /// Returns true if this integer set is determined to be empty. Emptiness is - /// checked by by eliminating identifiers successively (through either - /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial - /// invalid constraint check. Returns 'true' if the constaint system is found - /// to be empty; false otherwise. This method is exact for rational spaces but - /// not integer spaces - thus, if it returns true, the set is provably integer - /// empty as well, but if it returns false, it doesn't necessarily mean an - /// integer point exists in it. This method also returns false where an - /// explosion of constraints is detected - due to the super-exponential - /// worse-case complexity of Fourier-Motzkin elimination (rare for realistic - /// problem cases but possible for artificial adversarial or improperly - // constructed ones), this method returns false conservatively. - bool isEmpty() const; - - bool getNumDims() const { return set.getNumDims(); } - bool getNumSymbols() const { return set.getNumSymbols(); } - -private: - // The set pointed to may itself change unlike in IR structures like - // 'AffineCondition'. - MutableIntegerSet set; - /// The SSA operands binding to the dim's and symbols of 'set'. 
- SmallVector operands; -}; - -/// A flat list of affine equalities and inequalities in the form. -/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0 -/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0 -/// -/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer -/// for equalities and one for inequalities). The size of each buffer is -/// numReservedCols * number of inequalities (or equalities). The reserved size -/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A -/// coefficient (r, c) lives at the location numReservedCols * r + c in the -/// buffer. The extra space between getNumCols() and numReservedCols exists to -/// prevent frequent movement of data when adding columns, especially at the -/// end. -/// -/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, -/// symbolic identifiers, and local identifiers. The local identifiers -/// correspond to local/internal variables created when converting from -/// AffineExpr's containing mod's and div's; they are thus needed to increase -/// representational power. Each local identifier is always (by construction) a -/// floordiv of a pure add/mul affine function of dimensional, symbolic, and -/// other local identifiers, in a non-mutually recursive way. Hence, every local -/// identifier can ultimately always be recovered as an affine function of -/// dimensional and symbolic identifiers (involving floordiv's); note however -/// that some floordiv combinations are converted to mod's by AffineExpr -/// construction. -/// -class FlatAffineConstraints { -public: - enum IdKind { Dimension, Symbol, Local }; - - /// Constructs a constraint system reserving memory for the specified number - /// of constraints and identifiers.. - FlatAffineConstraints(unsigned numReservedInequalities, - unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims = 0, - unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) - : numReservedCols(numReservedCols), numDims(numDims), - numSymbols(numSymbols) { - assert(numReservedCols >= numDims + numSymbols + 1); - assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); - equalities.reserve(numReservedCols * numReservedEqualities); - inequalities.reserve(numReservedCols * numReservedInequalities); - numIds = numDims + numSymbols + numLocals; - ids.reserve(numReservedCols); - if (idArgs.empty()) - ids.resize(numIds, None); - else - ids.append(idArgs.begin(), idArgs.end()); - } - - /// Constructs a constraint system with the specified number of - /// dimensions and symbols. - FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, - ArrayRef> idArgs = {}) - : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), - numSymbols(numSymbols) { - assert(numReservedCols >= numDims + numSymbols + 1); - assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); - numIds = numDims + numSymbols + numLocals; - ids.reserve(numIds); - if (idArgs.empty()) - ids.resize(numIds, None); - else - ids.append(idArgs.begin(), idArgs.end()); - } - - explicit FlatAffineConstraints(const HyperRectangularSet &set); - - /// Create a flat affine constraint system from an AffineValueMap or a list of - /// these. The constructed system will only include equalities. 
- // TODO(bondhugula) - explicit FlatAffineConstraints(const AffineValueMap &avm); - explicit FlatAffineConstraints(ArrayRef avmRef); - - /// Creates an affine constraint system from an IntegerSet. - explicit FlatAffineConstraints(IntegerSet set); - - /// Create an affine constraint system from an IntegerValueSet. - // TODO(bondhugula) - explicit FlatAffineConstraints(const IntegerValueSet &set); - - FlatAffineConstraints(const FlatAffineConstraints &other); - - FlatAffineConstraints(ArrayRef avmRef, - IntegerSet set); - - FlatAffineConstraints(const MutableAffineMap &map); - - ~FlatAffineConstraints() {} - - // Clears any existing data and reserves memory for the specified constraints. - void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0, ArrayRef idArgs = {}); - - void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, ArrayRef idArgs = {}); - - /// Appends constraints from 'other' into this. This is equivalent to an - /// intersection with no simplification of any sort attempted. - void append(const FlatAffineConstraints &other); - - // Checks for emptiness by performing variable elimination on all identifiers, - // running the GCD test on each equality constraint, and checking for invalid - // constraints. - // Returns true if the GCD test fails for any equality, or if any invalid - // constraints are discovered on any row. Returns false otherwise. - bool isEmpty() const; - - // Runs the GCD test on all equality constraints. Returns 'true' if this test - // fails on any equality. Returns 'false' otherwise. - // This test can be used to disprove the existence of a solution. If it - // returns true, no integer solution to the equality constraints can exist. - bool isEmptyByGCDTest() const; - - // Clones this object. - std::unique_ptr clone() const; - - /// Returns the value at the specified equality row and column. - inline int64_t atEq(unsigned i, unsigned j) const { - return equalities[i * numReservedCols + j]; - } - inline int64_t &atEq(unsigned i, unsigned j) { - return equalities[i * numReservedCols + j]; - } - - inline int64_t atIneq(unsigned i, unsigned j) const { - return inequalities[i * numReservedCols + j]; - } - - inline int64_t &atIneq(unsigned i, unsigned j) { - return inequalities[i * numReservedCols + j]; - } - - /// Returns the number of columns in the constraint system. 
- inline unsigned getNumCols() const { return numIds + 1; } - - inline unsigned getNumEqualities() const { - assert(equalities.size() % numReservedCols == 0 && - "inconsistent equality buffer size"); - return equalities.size() / numReservedCols; - } - - inline unsigned getNumInequalities() const { - assert(inequalities.size() % numReservedCols == 0 && - "inconsistent inequality buffer size"); - return inequalities.size() / numReservedCols; - } - - inline unsigned getNumReservedEqualities() const { - return equalities.capacity() / numReservedCols; - } - - inline unsigned getNumReservedInequalities() const { - return inequalities.capacity() / numReservedCols; - } - - inline ArrayRef getEquality(unsigned idx) const { - return ArrayRef(&equalities[idx * numReservedCols], getNumCols()); - } - - inline ArrayRef getInequality(unsigned idx) const { - return ArrayRef(&inequalities[idx * numReservedCols], - getNumCols()); - } - - AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); - - /// Computes the lower and upper bounds of the first 'num' dimensional - /// identifiers as an affine map of the remaining identifiers (dimensional and - /// symbolic). This method is able to detect identifiers as floordiv's - /// and mod's of affine expressions of other identifiers with respect to - /// (positive) constants. Sets bound map to a null AffineMap if such a bound - /// can't be found (or yet unimplemented). - void getSliceBounds(unsigned num, MLIRContext *context, - SmallVectorImpl *lbMaps, - SmallVectorImpl *ubMaps); - - // Adds an inequality (>= 0) from the coefficients specified in inEq. - void addInequality(ArrayRef inEq); - // Adds an equality from the coefficients specified in eq. - void addEquality(ArrayRef eq); - - /// Adds a constant lower bound constraint for the specified identifier. - void addConstantLowerBound(unsigned pos, int64_t lb); - /// Adds a constant upper bound constraint for the specified identifier. - void addConstantUpperBound(unsigned pos, int64_t ub); - - /// Adds a new local identifier as the floordiv of an affine function of other - /// identifiers, the coefficients of which are provided in 'dividend' and with - /// respect to a positive constant 'divisor'. Two constraints are added to the - /// system to capture equivalence with the floordiv: - /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. - void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); - - /// Adds constraints (lower and upper bounds) for the specified 'for' - /// instruction's Value using IR information stored in its bound maps. The - /// right identifier is first looked up using forOp's Value. Returns - /// false for the yet unimplemented/unsupported cases, and true if the - /// information is succesfully added. Asserts if the Value corresponding to - /// the 'for' instruction isn't found in the constraint system. Any new - /// identifiers that are found in the bound operands of the 'for' instruction - /// are added as trailing identifiers (either dimensional or symbolic - /// depending on whether the operand is a valid ML Function symbol). - // TODO(bondhugula): add support for non-unit strides. - bool addAffineForOpDomain(ConstOpPointer forOp); - - /// Adds a constant lower bound constraint for the specified expression. - void addConstantLowerBound(ArrayRef expr, int64_t lb); - /// Adds a constant upper bound constraint for the specified expression. - void addConstantUpperBound(ArrayRef expr, int64_t ub); - - /// Sets the identifier at the specified position to a constant. 
- void setIdToConstant(unsigned pos, int64_t val); - - /// Sets the identifier corresponding to the specified Value id to a - /// constant. Asserts if the 'id' is not found. - void setIdToConstant(const Value &id, int64_t val); - - /// Looks up the identifier with the specified Value. Returns false if not - /// found, true if found. pos is set to the (column) position of the - /// identifier. - bool findId(const Value &id, unsigned *pos) const; - - // Add identifiers of the specified kind - specified positions are relative to - // the kind of identifier. 'id' is the Value corresponding to the - // identifier that can optionally be provided. - void addDimId(unsigned pos, Value *id = nullptr); - void addSymbolId(unsigned pos, Value *id = nullptr); - void addLocalId(unsigned pos); - void addId(IdKind kind, unsigned pos, Value *id = nullptr); - - /// Composes the affine value map with this FlatAffineConstrains, adding the - /// results of the map as dimensions at the front [0, vMap->getNumResults()) - /// and with the dimensions set to the equalities specified by the value map. - /// Returns false if the composition fails (when vMap is a semi-affine map). - /// The vMap's operand Value's are used to look up the right positions in - /// the FlatAffineConstraints with which to associate. The dimensional and - /// symbolic operands of vMap should match 1:1 (in the same order) with those - /// of this constraint system, but the latter could have additional trailing - /// operands. - bool composeMap(AffineValueMap *vMap); - - /// Projects out (aka eliminates) 'num' identifiers starting at position - /// 'pos'. The resulting constraint system is the shadow along the dimensions - /// that still exist. This method may not always be integer exact. - // TODO(bondhugula): deal with integer exactness when necessary - can return a - // value to mark exactness for example. - void projectOut(unsigned pos, unsigned num); - inline void projectOut(unsigned pos) { return projectOut(pos, 1); } - - /// Projects out the identifier that is associate with Value *. - void projectOut(Value *id); - - void removeId(IdKind idKind, unsigned pos); - void removeId(unsigned pos); - - void removeDim(unsigned pos); - - void removeEquality(unsigned pos); - void removeInequality(unsigned pos); - - /// Changes the partition between dimensions and symbols. Depending on the new - /// symbol count, either a chunk of trailing dimensional identifiers becomes - /// symbols, or some of the leading symbols become dimensions. - void setDimSymbolSeparation(unsigned newSymbolCount); - - /// Sets the specified identifier to a constant and removes it. - void setAndEliminate(unsigned pos, int64_t constVal); - - /// Tries to fold the specified identifer to a constant using a trivial - /// equality detection; if successful, the constant is substituted for the - /// identifier everywhere in the constraint system and then removed from the - /// system. Returns true if the folding happens, false otherwise. - bool constantFoldId(unsigned pos); - - /// This method calls constantFoldId for the specified range of identifiers, - /// 'num' identifiers starting at position 'pos'. - void constantFoldIdRange(unsigned pos, unsigned num); - - /// Returns true if all the identifiers in the specified range [start, limit) - /// can only take a single value each if the remaining identifiers are treated - /// as symbols/parameters, i.e., for given values of the latter, there only - /// exists a unique value for each of the dimensions in the specified range. 
- bool isRangeOneToOne(unsigned start, unsigned limit) const; - - /// Updates the constraints to be the smallest bounding (enclosing) box that - /// contains the points of 'this' set and that of 'other', with the symbols - /// being treated specially. For each of the dimensions, the min of the lower - /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed - /// to determine such a bounding box. - /// - /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the - /// output is {0 <= d0 <= 192}. - /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 + - /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. - /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 - /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. - bool unionBoundingBox(const FlatAffineConstraints &other); - - unsigned getNumConstraints() const { - return getNumInequalities() + getNumEqualities(); - } - inline unsigned getNumIds() const { return numIds; } - inline unsigned getNumDimIds() const { return numDims; } - inline unsigned getNumSymbolIds() const { return numSymbols; } - inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } - inline unsigned getNumLocalIds() const { - return numIds - numDims - numSymbols; - } - - inline ArrayRef> getIds() const { - return {ids.data(), ids.size()}; - } - - /// Returns the Value associated with the pos^th identifier. Asserts if - /// no Value identifier was associated. - inline Value *getIdValue(unsigned pos) const { - assert(ids[pos].hasValue() && "identifier's Value not set"); - return ids[pos].getValue(); - } - - /// Returns the Values associated with identifiers in range [start, end). - /// Asserts if no Value was associated with one of these identifiers. - void getIdValues(unsigned start, unsigned end, - SmallVectorImpl *values) const { - assert((start < numIds || start == end) && "invalid start position"); - assert(end <= numIds && "invalid end position"); - values->clear(); - values->reserve(end - start); - for (unsigned i = start; i < end; i++) { - values->push_back(getIdValue(i)); - } - } - inline void getAllIdValues(SmallVectorImpl *values) const { - getIdValues(0, numIds, values); - } - - /// Sets Value associated with the pos^th identifier. - inline void setIdValue(unsigned pos, Value *val) { - assert(pos < numIds && "invalid id position"); - ids[pos] = val; - } - /// Sets Values associated with identifiers in the range [start, end). - void setIdValues(unsigned start, unsigned end, ArrayRef values) { - assert((start < numIds || end == start) && "invalid start position"); - assert(end <= numIds && "invalid end position"); - assert(values.size() == end - start); - for (unsigned i = start; i < end; ++i) - ids[i] = values[i - start]; - } - - /// Clears this list of constraints and copies other into it. - void clearAndCopyFrom(const FlatAffineConstraints &other); - - /// Returns the smallest known constant bound for the extent of the specified - /// identifier (pos^th), i.e., the smallest known constant that is greater - /// than or equal to 'exclusive upper bound' - 'lower bound' of the - /// identifier. Returns None if it's not a constant. This method employs - /// trivial (low complexity / cost) checks and detection. Symbolic identifiers - /// are treated specially, i.e., it looks for constant differences between - /// affine expressions involving only the symbolic identifiers. See comments - /// at function definition for examples. 
'lb' and 'lbDivisor', if provided, - /// are used to express the lower bound associated with the constant - /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., - /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three - /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. - Optional - getConstantBoundOnDimSize(unsigned pos, - SmallVectorImpl *lb = nullptr, - int64_t *lbDivisor = nullptr) const; - - /// Returns the constant lower bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantLowerBound(unsigned pos) const; - - /// Returns the constant upper bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantUpperBound(unsigned pos) const; - - /// Returns true if the set can be trivially detected as being - /// hyper-rectangular on the specified contiguous set of identifiers. - bool isHyperRectangular(unsigned pos, unsigned num) const; - - /// Removes duplicates and trivially true constraints: a constraint of the - /// form >= 0 is considered a trivially true - /// constraint. - void removeTrivialRedundancy(); - - // Removes all equalities and inequalities. - void clearConstraints(); - - void print(raw_ostream &os) const; - void dump() const; - -private: - /// Returns false if the fields corresponding to various identifier counts, or - /// equality/inequality buffer sizes aren't consistent; true otherwise. This - /// is meant to be used within an assert internally. - bool hasConsistentState() const; - - /// Checks all rows of equality/inequality constraints for trivial - /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced - /// after elimination. Returns 'true' if an invalid constraint is found; - /// 'false'otherwise. - bool hasInvalidConstraint() const; - - /// Returns the constant lower bound bound if isLower is true, and the upper - /// bound if isLower is false. - template - Optional getConstantLowerOrUpperBound(unsigned pos) const; - - // Eliminates a single identifier at 'position' from equality and inequality - // constraints. Returns 'true' if the identifier was eliminated, and false - // otherwise. - inline bool gaussianEliminateId(unsigned position) { - return gaussianEliminateIds(position, position + 1) == 1; - } - - // Eliminates identifiers from equality and inequality constraints - // in column range [posStart, posLimit). - // Returns the number of variables eliminated. - unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit); - - /// Eliminates identifier at the specified position using Fourier-Motzkin - /// variable elimination, but uses Gaussian elimination if there is an - /// equality involving that identifier. If the result of the elimination is - /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is - /// set to true, a potential under approximation (subset) of the rational - /// shadow / exact integer shadow is computed. - // See implementation comments for more details. - void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false, - bool *isResultIntegerExact = nullptr); - - /// Tightens inequalities given that we are dealing with integer spaces. This - /// is similar to the GCD test but applied to inequalities. The constant term - /// can be reduced to the preceding multiple of the GCD of the coefficients, - /// i.e., - /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a - /// fast method (linear in the number of coefficients). 
- void GCDTightenInequalities(); - - /// Normalized each constraints by the GCD of its coefficients. - void normalizeConstraintsByGCD(); - - /// Removes identifiers in column range [idStart, idLimit), and copies any - /// remaining valid data into place, updates member variables, and resizes - /// arrays as needed. - void removeIdRange(unsigned idStart, unsigned idLimit); - - /// Coefficients of affine equalities (in == 0 form). - SmallVector equalities; - - /// Coefficients of affine inequalities (in >= 0 form). - SmallVector inequalities; - - /// Number of columns reserved. Actual ones in used are returned by - /// getNumCols(). - unsigned numReservedCols; - - /// Total number of identifiers. - unsigned numIds; - - /// Number of identifiers corresponding to real dimensions. - unsigned numDims; - - /// Number of identifiers corresponding to symbols (unknown but constant for - /// analysis). - unsigned numSymbols; - - /// Values corresponding to the (column) identifiers of this constraint - /// system appearing in the order the identifiers correspond to columns. - /// Temporary ones or those that aren't associated to any Value are to be - /// set to None. - SmallVector, 8> ids; - - /// A parameter that controls detection of an unrealistic number of - /// constraints. If the number of constraints is this many times the number of - /// variables, we consider such a system out of line with the intended use - /// case of FlatAffineConstraints. - // The rationale for 32 is that in the typical simplest of cases, an - // identifier is expected to have one lower bound and one upper bound - // constraint. With a level of tiling or a connection to another identifier - // through a div or mod, an extra pair of bounds gets added. As a limit, we - // don't expect an identifier to have more than 32 lower/upper/equality - // constraints. This is conservatively set low and can be raised if needed. - constexpr static unsigned kExplosionFactor = 32; -}; - -} // end namespace mlir. - -#endif // MLIR_ANALYSIS_AFFINE_STRUCTURES_H diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 1dcbd759154..6cd688045c5 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -25,8 +25,8 @@ #ifndef MLIR_ANALYSIS_UTILS_H #define MLIR_ANALYSIS_UTILS_H -#include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 596d47096f5..d7eab0f1312 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -32,6 +32,8 @@ namespace mlir { class MLIRContext; class AffineMap; +class IntegerSet; +class FlatAffineConstraints; namespace detail { @@ -247,6 +249,40 @@ template U AffineExpr::cast() const { return U(expr); } +/// Simplify an affine expression by flattening and some amount of +/// simple analysis. This has complexity linear in the number of nodes in +/// 'expr'. Returns the simplified expression, which is the same as the input +/// expression if it can't be simplified. +AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols); + +/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false +/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). 
+/// 'cst' contains constraints that connect newly introduced local identifiers +/// to existing dimensional and / symbolic identifiers. See documentation for +/// AffineExprFlattener on how mod's and div's are flattened. +bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *cst = nullptr); + +/// Flattens the result expressions of the map to their corresponding flattened +/// forms and set in 'flattenedExprs'. Returns true on success or false +/// if any expression in the map could not be flattened (i.e., semi-affine is +/// not yet handled). 'cst' contains constraints that connect newly introduced +/// local identifiers to existing dimensional and / symbolic identifiers. See +/// documentation for AffineExprFlattener on how mod's and div's are flattened. +/// For all affine expressions that share the same operands (like those of an +/// affine map), this method should be used instead of repeatedly calling +/// getFlattenedAffineExpr since local variables added to deal with div's and +/// mod's will be reused across expressions. +bool getFlattenedAffineExprs( + AffineMap map, std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); +bool getFlattenedAffineExprs( + IntegerSet set, std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); + } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index cb2c315acea..87411cf3d21 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -150,6 +150,10 @@ inline ::llvm::hash_code hash_value(AffineMap arg) { return ::llvm::hash_value(arg.map); } +/// Simplify an affine map by simplifying its underlying AffineExpr results and +/// sizes. +AffineMap simplifyAffineMap(AffineMap map); + } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/AffineStructures.h b/mlir/include/mlir/IR/AffineStructures.h new file mode 100644 index 00000000000..9c88436dcb0 --- /dev/null +++ b/mlir/include/mlir/IR/AffineStructures.h @@ -0,0 +1,682 @@ +//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Structures for affine/polyhedral analysis of ML functions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_AFFINE_STRUCTURES_H +#define MLIR_IR_AFFINE_STRUCTURES_H + +#include "mlir/IR/AffineExpr.h" + +namespace mlir { + +class AffineCondition; +class AffineMap; +class IntegerSet; +class MLIRContext; +class Value; +class HyperRectangularSet; +class MemRefType; + +/// A mutable affine map. Its affine expressions are however unique. 
+struct MutableAffineMap { +public: + MutableAffineMap() {} + MutableAffineMap(AffineMap map); + + ArrayRef getResults() const { return results; } + AffineExpr getResult(unsigned idx) const { return results[idx]; } + void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } + unsigned getNumResults() const { return results.size(); } + unsigned getNumDims() const { return numDims; } + void setNumDims(unsigned d) { numDims = d; } + unsigned getNumSymbols() const { return numSymbols; } + void setNumSymbols(unsigned d) { numSymbols = d; } + MLIRContext *getContext() const { return context; } + + /// Returns true if the idx'th result expression is a multiple of factor. + bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Resets this MutableAffineMap with 'map'. + void reset(AffineMap map); + + /// Simplify the (result) expressions in this map using analysis (used by + //-simplify-affine-expr pass). + void simplify(); + /// Get the AffineMap corresponding to this MutableAffineMap. Note that an + /// AffineMap will be uniqued and stored in context, while a mutable one + /// isn't. + AffineMap getAffineMap() const; + +private: + // Same meaning as AffineMap's fields. + SmallVector results; + SmallVector rangeSizes; + unsigned numDims; + unsigned numSymbols; + /// A pointer to the IR's context to store all newly created + /// AffineExprStorage's. + MLIRContext *context; +}; + +/// A mutable integer set. Its affine expressions are however unique. +struct MutableIntegerSet { +public: + MutableIntegerSet(IntegerSet set, MLIRContext *context); + + /// Create a universal set (no constraints). + MutableIntegerSet(unsigned numDims, unsigned numSymbols, + MLIRContext *context); + + unsigned getNumDims() const { return numDims; } + unsigned getNumSymbols() const { return numSymbols; } + unsigned getNumConstraints() const { return constraints.size(); } + + void clear() { + constraints.clear(); + eqFlags.clear(); + } + +private: + unsigned numDims; + unsigned numSymbols; + + SmallVector constraints; + SmallVector eqFlags; + /// A pointer to the IR's context to store all newly created + /// AffineExprStorage's. + MLIRContext *context; +}; + +/// An AffineValueMap is an affine map plus its ML value operands and +/// results for analysis purposes. The structure is still a tree form that is +/// same as that of an affine map or an AffineApplyOp. However, its operands, +/// results, and its map can themselves change as a result of +/// substitutions, simplifications, and other analysis. +// An affine value map can readily be constructed from an AffineApplyOp, or an +// AffineBound of a AffineForOp. It can be further transformed, substituted +// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and +// destroyed during analysis. Only the AffineMap expressions that are pointed by +// them are unique'd. An affine value map, and the operations on it, maintain +// the invariant that operands are always positionally aligned with the +// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. +// TODO(bondhugula): Some of these classes could go into separate files. +class AffineValueMap { +public: + // Creates an empty AffineValueMap (users should call 'reset' to reset map + // and operands). + AffineValueMap() {} + AffineValueMap(AffineMap map); + AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + ~AffineValueMap(); + + // Resets this AffineValueMap with 'map', 'operands', and 'results'. 
+ void reset(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + /// Return true if the idx^th result can be proved to be a multiple of + /// 'factor', false otherwise. + inline bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Return true if the idx^th result depends on 'value', false otherwise. + bool isFunctionOf(unsigned idx, Value *value) const; + + /// Return true if the result at 'idx' is a constant, false + /// otherwise. + bool isConstant(unsigned idx) const; + + /// Return true if this is an identity map. + bool isIdentity() const; + + inline unsigned getNumOperands() const { return operands.size(); } + inline unsigned getNumDims() const { return map.getNumDims(); } + inline unsigned getNumSymbols() const { return map.getNumSymbols(); } + inline unsigned getNumResults() const { return map.getNumResults(); } + + Value *getOperand(unsigned i) const; + ArrayRef getOperands() const; + AffineMap getAffineMap() const; + +private: + // A mutable affine map. + MutableAffineMap map; + + // TODO: make these trailing objects? + /// The SSA operands binding to the dim's and symbols of 'map'. + SmallVector operands; + /// The SSA results binding to the results of 'map'. + SmallVector results; +}; + +/// An IntegerValueSet is an integer set plus its operands. +// Both, the integer set being pointed to and the operands can change during +// analysis, simplification, and transformation. +class IntegerValueSet { + /// Constructs an integer value set from an affine value map. + // This will lead to a single equality in 'set'. + explicit IntegerValueSet(const AffineValueMap &avm); + + /// Returns true if this integer set is determined to be empty. Emptiness is + /// checked by by eliminating identifiers successively (through either + /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial + /// invalid constraint check. Returns 'true' if the constaint system is found + /// to be empty; false otherwise. This method is exact for rational spaces but + /// not integer spaces - thus, if it returns true, the set is provably integer + /// empty as well, but if it returns false, it doesn't necessarily mean an + /// integer point exists in it. This method also returns false where an + /// explosion of constraints is detected - due to the super-exponential + /// worse-case complexity of Fourier-Motzkin elimination (rare for realistic + /// problem cases but possible for artificial adversarial or improperly + // constructed ones), this method returns false conservatively. + bool isEmpty() const; + + bool getNumDims() const { return set.getNumDims(); } + bool getNumSymbols() const { return set.getNumSymbols(); } + +private: + // The set pointed to may itself change unlike in IR structures like + // 'AffineCondition'. + MutableIntegerSet set; + /// The SSA operands binding to the dim's and symbols of 'set'. + SmallVector operands; +}; + +/// A flat list of affine equalities and inequalities in the form. +/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0 +/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0 +/// +/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer +/// for equalities and one for inequalities). The size of each buffer is +/// numReservedCols * number of inequalities (or equalities). The reserved size +/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A +/// coefficient (r, c) lives at the location numReservedCols * r + c in the +/// buffer. 
The extra space between getNumCols() and numReservedCols exists to +/// prevent frequent movement of data when adding columns, especially at the +/// end. +/// +/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, +/// symbolic identifiers, and local identifiers. The local identifiers +/// correspond to local/internal variables created when converting from +/// AffineExpr's containing mod's and div's; they are thus needed to increase +/// representational power. Each local identifier is always (by construction) a +/// floordiv of a pure add/mul affine function of dimensional, symbolic, and +/// other local identifiers, in a non-mutually recursive way. Hence, every local +/// identifier can ultimately always be recovered as an affine function of +/// dimensional and symbolic identifiers (involving floordiv's); note however +/// that some floordiv combinations are converted to mod's by AffineExpr +/// construction. +/// +class FlatAffineConstraints { +public: + enum IdKind { Dimension, Symbol, Local }; + + /// Constructs a constraint system reserving memory for the specified number + /// of constraints and identifiers.. + FlatAffineConstraints(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims = 0, + unsigned numSymbols = 0, unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numReservedCols), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + equalities.reserve(numReservedCols * numReservedEqualities); + inequalities.reserve(numReservedCols * numReservedInequalities); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numReservedCols); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + /// Constructs a constraint system with the specified number of + /// dimensions and symbols. + FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numIds); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + explicit FlatAffineConstraints(const HyperRectangularSet &set); + + /// Create a flat affine constraint system from an AffineValueMap or a list of + /// these. The constructed system will only include equalities. + // TODO(bondhugula) + explicit FlatAffineConstraints(const AffineValueMap &avm); + explicit FlatAffineConstraints(ArrayRef avmRef); + + /// Creates an affine constraint system from an IntegerSet. + explicit FlatAffineConstraints(IntegerSet set); + + /// Create an affine constraint system from an IntegerValueSet. + // TODO(bondhugula) + explicit FlatAffineConstraints(const IntegerValueSet &set); + + FlatAffineConstraints(const FlatAffineConstraints &other); + + FlatAffineConstraints(ArrayRef avmRef, + IntegerSet set); + + FlatAffineConstraints(const MutableAffineMap &map); + + ~FlatAffineConstraints() {} + + // Clears any existing data and reserves memory for the specified constraints. 
+ void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims, unsigned numSymbols, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + void reset(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + /// Appends constraints from 'other' into this. This is equivalent to an + /// intersection with no simplification of any sort attempted. + void append(const FlatAffineConstraints &other); + + // Checks for emptiness by performing variable elimination on all identifiers, + // running the GCD test on each equality constraint, and checking for invalid + // constraints. + // Returns true if the GCD test fails for any equality, or if any invalid + // constraints are discovered on any row. Returns false otherwise. + bool isEmpty() const; + + // Runs the GCD test on all equality constraints. Returns 'true' if this test + // fails on any equality. Returns 'false' otherwise. + // This test can be used to disprove the existence of a solution. If it + // returns true, no integer solution to the equality constraints can exist. + bool isEmptyByGCDTest() const; + + // Clones this object. + std::unique_ptr clone() const; + + /// Returns the value at the specified equality row and column. + inline int64_t atEq(unsigned i, unsigned j) const { + return equalities[i * numReservedCols + j]; + } + inline int64_t &atEq(unsigned i, unsigned j) { + return equalities[i * numReservedCols + j]; + } + + inline int64_t atIneq(unsigned i, unsigned j) const { + return inequalities[i * numReservedCols + j]; + } + + inline int64_t &atIneq(unsigned i, unsigned j) { + return inequalities[i * numReservedCols + j]; + } + + /// Returns the number of columns in the constraint system. + inline unsigned getNumCols() const { return numIds + 1; } + + inline unsigned getNumEqualities() const { + assert(equalities.size() % numReservedCols == 0 && + "inconsistent equality buffer size"); + return equalities.size() / numReservedCols; + } + + inline unsigned getNumInequalities() const { + assert(inequalities.size() % numReservedCols == 0 && + "inconsistent inequality buffer size"); + return inequalities.size() / numReservedCols; + } + + inline unsigned getNumReservedEqualities() const { + return equalities.capacity() / numReservedCols; + } + + inline unsigned getNumReservedInequalities() const { + return inequalities.capacity() / numReservedCols; + } + + inline ArrayRef getEquality(unsigned idx) const { + return ArrayRef(&equalities[idx * numReservedCols], getNumCols()); + } + + inline ArrayRef getInequality(unsigned idx) const { + return ArrayRef(&inequalities[idx * numReservedCols], + getNumCols()); + } + + AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); + + /// Computes the lower and upper bounds of the first 'num' dimensional + /// identifiers as an affine map of the remaining identifiers (dimensional and + /// symbolic). This method is able to detect identifiers as floordiv's + /// and mod's of affine expressions of other identifiers with respect to + /// (positive) constants. Sets bound map to a null AffineMap if such a bound + /// can't be found (or yet unimplemented). + void getSliceBounds(unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps); + + // Adds an inequality (>= 0) from the coefficients specified in inEq. + void addInequality(ArrayRef inEq); + // Adds an equality from the coefficients specified in eq. 
+ void addEquality(ArrayRef eq); + + /// Adds a constant lower bound constraint for the specified identifier. + void addConstantLowerBound(unsigned pos, int64_t lb); + /// Adds a constant upper bound constraint for the specified identifier. + void addConstantUpperBound(unsigned pos, int64_t ub); + + /// Adds a new local identifier as the floordiv of an affine function of other + /// identifiers, the coefficients of which are provided in 'dividend' and with + /// respect to a positive constant 'divisor'. Two constraints are added to the + /// system to capture equivalence with the floordiv: + /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. + void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); + + /// Adds a constant lower bound constraint for the specified expression. + void addConstantLowerBound(ArrayRef expr, int64_t lb); + /// Adds a constant upper bound constraint for the specified expression. + void addConstantUpperBound(ArrayRef expr, int64_t ub); + + /// Sets the identifier at the specified position to a constant. + void setIdToConstant(unsigned pos, int64_t val); + + /// Sets the identifier corresponding to the specified Value id to a + /// constant. Asserts if the 'id' is not found. + void setIdToConstant(const Value &id, int64_t val); + + /// Looks up the identifier with the specified Value. Returns false if not + /// found, true if found. pos is set to the (column) position of the + /// identifier. + bool findId(const Value &id, unsigned *pos) const; + + // Add identifiers of the specified kind - specified positions are relative to + // the kind of identifier. 'id' is the Value corresponding to the + // identifier that can optionally be provided. + void addDimId(unsigned pos, Value *id = nullptr); + void addSymbolId(unsigned pos, Value *id = nullptr); + void addLocalId(unsigned pos); + void addId(IdKind kind, unsigned pos, Value *id = nullptr); + + /// Composes the affine value map with this FlatAffineConstrains, adding the + /// results of the map as dimensions at the front [0, vMap->getNumResults()) + /// and with the dimensions set to the equalities specified by the value map. + /// Returns false if the composition fails (when vMap is a semi-affine map). + /// The vMap's operand Value's are used to look up the right positions in + /// the FlatAffineConstraints with which to associate. The dimensional and + /// symbolic operands of vMap should match 1:1 (in the same order) with those + /// of this constraint system, but the latter could have additional trailing + /// operands. + bool composeMap(AffineValueMap *vMap); + + /// Projects out (aka eliminates) 'num' identifiers starting at position + /// 'pos'. The resulting constraint system is the shadow along the dimensions + /// that still exist. This method may not always be integer exact. + // TODO(bondhugula): deal with integer exactness when necessary - can return a + // value to mark exactness for example. + void projectOut(unsigned pos, unsigned num); + inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + + /// Projects out the identifier that is associate with Value *. + void projectOut(Value *id); + + void removeId(IdKind idKind, unsigned pos); + void removeId(unsigned pos); + + void removeDim(unsigned pos); + + void removeEquality(unsigned pos); + void removeInequality(unsigned pos); + + /// Changes the partition between dimensions and symbols. 
Depending on the new + /// symbol count, either a chunk of trailing dimensional identifiers becomes + /// symbols, or some of the leading symbols become dimensions. + void setDimSymbolSeparation(unsigned newSymbolCount); + + /// Sets the specified identifier to a constant and removes it. + void setAndEliminate(unsigned pos, int64_t constVal); + + /// Tries to fold the specified identifer to a constant using a trivial + /// equality detection; if successful, the constant is substituted for the + /// identifier everywhere in the constraint system and then removed from the + /// system. Returns true if the folding happens, false otherwise. + bool constantFoldId(unsigned pos); + + /// This method calls constantFoldId for the specified range of identifiers, + /// 'num' identifiers starting at position 'pos'. + void constantFoldIdRange(unsigned pos, unsigned num); + + /// Returns true if all the identifiers in the specified range [start, limit) + /// can only take a single value each if the remaining identifiers are treated + /// as symbols/parameters, i.e., for given values of the latter, there only + /// exists a unique value for each of the dimensions in the specified range. + bool isRangeOneToOne(unsigned start, unsigned limit) const; + + /// Updates the constraints to be the smallest bounding (enclosing) box that + /// contains the points of 'this' set and that of 'other', with the symbols + /// being treated specially. For each of the dimensions, the min of the lower + /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed + /// to determine such a bounding box. + /// + /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the + /// output is {0 <= d0 <= 192}. + /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 + + /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. + /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 + /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. + bool unionBoundingBox(const FlatAffineConstraints &other); + + unsigned getNumConstraints() const { + return getNumInequalities() + getNumEqualities(); + } + inline unsigned getNumIds() const { return numIds; } + inline unsigned getNumDimIds() const { return numDims; } + inline unsigned getNumSymbolIds() const { return numSymbols; } + inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } + inline unsigned getNumLocalIds() const { + return numIds - numDims - numSymbols; + } + + inline ArrayRef> getIds() const { + return {ids.data(), ids.size()}; + } + + /// Returns the Value associated with the pos^th identifier. Asserts if + /// no Value identifier was associated. + inline Value *getIdValue(unsigned pos) const { + assert(ids[pos].hasValue() && "identifier's Value not set"); + return ids[pos].getValue(); + } + + /// Returns the Values associated with identifiers in range [start, end). + /// Asserts if no Value was associated with one of these identifiers. + void getIdValues(unsigned start, unsigned end, + SmallVectorImpl *values) const { + assert((start < numIds || start == end) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + values->clear(); + values->reserve(end - start); + for (unsigned i = start; i < end; i++) { + values->push_back(getIdValue(i)); + } + } + inline void getAllIdValues(SmallVectorImpl *values) const { + getIdValues(0, numIds, values); + } + + /// Sets Value associated with the pos^th identifier. 
+ inline void setIdValue(unsigned pos, Value *val) { + assert(pos < numIds && "invalid id position"); + ids[pos] = val; + } + /// Sets Values associated with identifiers in the range [start, end). + void setIdValues(unsigned start, unsigned end, ArrayRef values) { + assert((start < numIds || end == start) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + assert(values.size() == end - start); + for (unsigned i = start; i < end; ++i) + ids[i] = values[i - start]; + } + + /// Clears this list of constraints and copies other into it. + void clearAndCopyFrom(const FlatAffineConstraints &other); + + /// Returns the smallest known constant bound for the extent of the specified + /// identifier (pos^th), i.e., the smallest known constant that is greater + /// than or equal to 'exclusive upper bound' - 'lower bound' of the + /// identifier. Returns None if it's not a constant. This method employs + /// trivial (low complexity / cost) checks and detection. Symbolic identifiers + /// are treated specially, i.e., it looks for constant differences between + /// affine expressions involving only the symbolic identifiers. See comments + /// at function definition for examples. 'lb' and 'lbDivisor', if provided, + /// are used to express the lower bound associated with the constant + /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., + /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three + /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. + Optional + getConstantBoundOnDimSize(unsigned pos, + SmallVectorImpl *lb = nullptr, + int64_t *lbDivisor = nullptr) const; + + /// Returns the constant lower bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantLowerBound(unsigned pos) const; + + /// Returns the constant upper bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantUpperBound(unsigned pos) const; + + /// Returns true if the set can be trivially detected as being + /// hyper-rectangular on the specified contiguous set of identifiers. + bool isHyperRectangular(unsigned pos, unsigned num) const; + + /// Removes duplicates and trivially true constraints: a constraint of the + /// form >= 0 is considered a trivially true + /// constraint. + void removeTrivialRedundancy(); + + // Removes all equalities and inequalities. + void clearConstraints(); + + void print(raw_ostream &os) const; + void dump() const; + +private: + /// Returns false if the fields corresponding to various identifier counts, or + /// equality/inequality buffer sizes aren't consistent; true otherwise. This + /// is meant to be used within an assert internally. + bool hasConsistentState() const; + + /// Checks all rows of equality/inequality constraints for trivial + /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced + /// after elimination. Returns 'true' if an invalid constraint is found; + /// 'false'otherwise. + bool hasInvalidConstraint() const; + + /// Returns the constant lower bound bound if isLower is true, and the upper + /// bound if isLower is false. + template + Optional getConstantLowerOrUpperBound(unsigned pos) const; + + // Eliminates a single identifier at 'position' from equality and inequality + // constraints. Returns 'true' if the identifier was eliminated, and false + // otherwise. 
+ inline bool gaussianEliminateId(unsigned position) { + return gaussianEliminateIds(position, position + 1) == 1; + } + + // Eliminates identifiers from equality and inequality constraints + // in column range [posStart, posLimit). + // Returns the number of variables eliminated. + unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit); + + /// Eliminates identifier at the specified position using Fourier-Motzkin + /// variable elimination, but uses Gaussian elimination if there is an + /// equality involving that identifier. If the result of the elimination is + /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is + /// set to true, a potential under approximation (subset) of the rational + /// shadow / exact integer shadow is computed. + // See implementation comments for more details. + void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false, + bool *isResultIntegerExact = nullptr); + + /// Tightens inequalities given that we are dealing with integer spaces. This + /// is similar to the GCD test but applied to inequalities. The constant term + /// can be reduced to the preceding multiple of the GCD of the coefficients, + /// i.e., + /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a + /// fast method (linear in the number of coefficients). + void GCDTightenInequalities(); + + /// Normalized each constraints by the GCD of its coefficients. + void normalizeConstraintsByGCD(); + + /// Removes identifiers in column range [idStart, idLimit), and copies any + /// remaining valid data into place, updates member variables, and resizes + /// arrays as needed. + void removeIdRange(unsigned idStart, unsigned idLimit); + + /// Coefficients of affine equalities (in == 0 form). + SmallVector equalities; + + /// Coefficients of affine inequalities (in >= 0 form). + SmallVector inequalities; + + /// Number of columns reserved. Actual ones in used are returned by + /// getNumCols(). + unsigned numReservedCols; + + /// Total number of identifiers. + unsigned numIds; + + /// Number of identifiers corresponding to real dimensions. + unsigned numDims; + + /// Number of identifiers corresponding to symbols (unknown but constant for + /// analysis). + unsigned numSymbols; + + /// Values corresponding to the (column) identifiers of this constraint + /// system appearing in the order the identifiers correspond to columns. + /// Temporary ones or those that aren't associated to any Value are to be + /// set to None. + SmallVector, 8> ids; + + /// A parameter that controls detection of an unrealistic number of + /// constraints. If the number of constraints is this many times the number of + /// variables, we consider such a system out of line with the intended use + /// case of FlatAffineConstraints. + // The rationale for 32 is that in the typical simplest of cases, an + // identifier is expected to have one lower bound and one upper bound + // constraint. With a level of tiling or a connection to another identifier + // through a div or mod, an extra pair of bounds gets added. As a limit, we + // don't expect an identifier to have more than 32 lower/upper/equality + // constraints. This is conservatively set low and can be raised if needed. + constexpr static unsigned kExplosionFactor = 32; +}; + +} // end namespace mlir. 
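As a quick illustration of the constraint layout documented above (columns ordered as [dims, symbols, locals, constant], with int64_t coefficients per the atEq/atIneq accessors), here is a minimal usage sketch. It is not part of this patch, and the variable names are hypothetical:

  // Encode the domain {0 <= d0 <= 127, d0 <= s0 - 1} with one dim and one symbol.
  FlatAffineConstraints cst(/*numDims=*/1, /*numSymbols=*/1);
  cst.addConstantLowerBound(/*pos=*/0, /*lb=*/0);    // d0 >= 0
  cst.addConstantUpperBound(/*pos=*/0, /*ub=*/127);  // d0 <= 127 (inclusive)
  cst.addInequality({-1, 1, -1});                    // -d0 + s0 - 1 >= 0, i.e. d0 <= s0 - 1
  // Introduce a local q = d0 floordiv 4 (dividend coefficients over
  // [d0, s0, constant]); this appends the pair of constraints
  // 4*q <= d0 <= 4*q + 3 described in addLocalFloorDiv's comment.
  cst.addLocalFloorDiv(/*dividend=*/{1, 0, 0}, /*divisor=*/4);
  bool provablyEmpty = cst.isEmpty();  // GCD test + Fourier-Motzkin based check.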
+ +#endif // MLIR_IR_AFFINE_STRUCTURES_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index d4aa8a67600..e0fc934a620 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -79,9 +79,6 @@ FunctionPass *createLoopFusionPass(); /// memory hierarchy. FunctionPass *createPipelineDataTransferPass(); -/// Creates a pass which composes all affine maps applied to loads and stores. -FunctionPass *createComposeAffineMapsPass(); - /// Lowers affine control flow instructions (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index c3adf5fb7c3..b77f62c514a 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -23,7 +24,11 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/Debug.h" using namespace mlir; +using llvm::dbgs; + +#define DEBUG_TYPE "affine-analysis" //===----------------------------------------------------------------------===// // AffineOpsDialect @@ -130,6 +135,12 @@ bool AffineApplyOp::verify() const { return false; } +/// Returns an AffineValueMap representing this affine apply. +AffineValueMap AffineApplyOp::getAsAffineValueMap() { + SmallVector operands(getOperands()); + return AffineValueMap(getAffineMap(), operands, getResult()); +} + // The result of the affine apply operation can be used as a dimension id if it // is a CFG value or if it is an Value, and all the operands are valid // dimension ids. @@ -168,6 +179,77 @@ struct SimplifyAffineApply : public RewritePattern { } // end anonymous namespace. namespace { +/// An `AffineApplyNormalizer` is a helper class that is not visible to the user +/// and supports renumbering operands of AffineApplyOp. This acts as a +/// reindexing map of Value* to positional dims or symbols and allows +/// simplifications such as: +/// +/// ```mlir +/// %1 = affine_apply (d0, d1) -> (d0 - d1) (%0, %0) +/// ``` +/// +/// into: +/// +/// ```mlir +/// %1 = affine_apply () -> (0) +/// ``` +struct AffineApplyNormalizer { + AffineApplyNormalizer(AffineMap map, ArrayRef operands); + + /// Returns the AffineMap resulting from normalization. + AffineMap getAffineMap() { return affineMap; } + + SmallVector getOperands() { + SmallVector res(reorderedDims); + res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); + return res; + } + +private: + /// Helper function to insert `v` into the coordinate system of the current + /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding + /// renumbered position. + AffineDimExpr applyOneDim(Value *v); + + /// Given an `other` normalizer, this rewrites `other.affineMap` in the + /// coordinate system of the current AffineApplyNormalizer. + /// Returns the rewritten AffineMap and updates the dims and symbols of + /// `this`. + AffineMap renumber(const AffineApplyNormalizer &other); + + /// Given an `app`, rewrites `app.getAffineMap()` in the coordinate system of + /// the current AffineApplyNormalizer. 
+ /// Returns the rewritten AffineMap and updates the dims and symbols of + /// `this`. + AffineMap renumber(const AffineApplyOp &app); + + /// Maps of Value* to position in `affineMap`. + DenseMap dimValueToPosition; + + /// Ordered dims and symbols matching positional dims and symbols in + /// `affineMap`. + SmallVector reorderedDims; + SmallVector concatenatedSymbols; + + AffineMap affineMap; + + /// Used with RAII to control the depth at which AffineApply are composed + /// recursively. Only accepts depth 1 for now. + /// Note that if one wishes to compose all AffineApply in the program and + /// follows program order, maxdepth 1 is sufficient. This is as much as this + /// abstraction is willing to support for now. + static unsigned &affineApplyDepth() { + static thread_local unsigned depth = 0; + return depth; + } + static constexpr unsigned kMaxAffineApplyDepth = 1; + + AffineApplyNormalizer() { affineApplyDepth()++; } + +public: + ~AffineApplyNormalizer() { affineApplyDepth()--; } +}; + /// FIXME: this is massive overkill for simple obviously always matching /// canonicalizations. Fix the pattern rewriter to make this easy. struct SimplifyAffineApplyState : public PatternState { @@ -181,6 +263,136 @@ struct SimplifyAffineApplyState : public PatternState { } // end anonymous namespace. +AffineDimExpr AffineApplyNormalizer::applyOneDim(Value *v) { + DenseMap::iterator iterPos; + bool inserted = false; + std::tie(iterPos, inserted) = + dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); + if (inserted) { + reorderedDims.push_back(v); + } + return getAffineDimExpr(iterPos->second, v->getFunction()->getContext()) + .cast(); +} + +AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { + SmallVector dimRemapping; + for (auto *v : other.reorderedDims) { + auto kvp = other.dimValueToPosition.find(v); + if (dimRemapping.size() <= kvp->second) + dimRemapping.resize(kvp->second + 1); + dimRemapping[kvp->second] = applyOneDim(kvp->first); + } + unsigned numSymbols = concatenatedSymbols.size(); + unsigned numOtherSymbols = other.concatenatedSymbols.size(); + SmallVector symRemapping(numOtherSymbols); + for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { + symRemapping[idx] = + getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); + } + concatenatedSymbols.insert(concatenatedSymbols.end(), + other.concatenatedSymbols.begin(), + other.concatenatedSymbols.end()); + auto map = other.affineMap; + return map.replaceDimsAndSymbols(dimRemapping, symRemapping, + dimRemapping.size(), symRemapping.size()); +} + +AffineMap AffineApplyNormalizer::renumber(const AffineApplyOp &app) { + assert(app.getAffineMap().getRangeSizes().empty() && "Non-empty range sizes"); + + // Create the AffineApplyNormalizer for the operands of this + // AffineApplyOp and combine it with the current AffineApplyNormalizer. 
+ SmallVector operands( + const_cast(app).getOperands().begin(), + const_cast(app).getOperands().end()); + AffineApplyNormalizer normalizer(app.getAffineMap(), operands); + return renumber(normalizer); +} + +AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, + ArrayRef operands) + : AffineApplyNormalizer() { + assert(map.getRangeSizes().empty() && "Unbounded map expected"); + assert(map.getNumInputs() == operands.size() && + "number of operands does not match the number of map inputs"); + + SmallVector exprs; + for (auto en : llvm::enumerate(operands)) { + auto *t = en.value(); + assert(t->getType().isIndex()); + bool operandNotFromAffineApply = + !t->getDefiningInst() || !t->getDefiningInst()->isa(); + if (operandNotFromAffineApply || + affineApplyDepth() > kMaxAffineApplyDepth) { + if (en.index() < map.getNumDims()) { + exprs.push_back(applyOneDim(t)); + } else { + // Composition of mathematical symbols must occur by concatenation. + // A subsequent canonicalization will drop duplicates. Duplicates are + // not dropped here because it would just amount to code duplication. + concatenatedSymbols.push_back(t); + } + } else { + auto *inst = t->getDefiningInst(); + auto app = inst->dyn_cast(); + auto tmpMap = renumber(*app); + exprs.push_back(tmpMap.getResult(0)); + } + } + + // Map is already composed. + if (exprs.empty()) { + affineMap = map; + return; + } + + auto numDims = dimValueToPosition.size(); + auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); + auto exprsMap = AffineMap::get(numDims, numSymbols, exprs, {}); + LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); + LLVM_DEBUG(exprsMap.print(dbgs() << "\nWith map: ")); + LLVM_DEBUG(map.compose(exprsMap).print(dbgs() << "\nResult: ")); + + affineMap = simplifyAffineMap(map.compose(exprsMap)); + LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); + LLVM_DEBUG(dbgs() << "\n"); +} + +/// Implements `map` and `operands` composition and simplification to support +/// `makeComposedAffineApply`. This can be called to achieve the same effects +/// on `map` and `operands` without creating an AffineApplyOp that needs to be +/// immediately deleted. 
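// As a usage sketch (not part of this patch; the builder and location names
// below are hypothetical), a caller that wants the fully composed form of a
// map and its operands, without materializing intermediate affine_apply ops
// that would be immediately deleted, might write:
//
//   AffineMap map = ...;                     // e.g. (d0) -> (d0 + 1)
//   SmallVector<Value *, 4> operands = ...;  // possibly results of affine_apply's
//   fullyComposeAffineMapAndOperands(&map, &operands);
//   // Or, to also create the composed op in one step:
//   auto apply = makeComposedAffineApply(&builder, loc, map, operands);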
+static void composeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl *operands) { + AffineApplyNormalizer normalizer(*map, *operands); + auto normalizedMap = normalizer.getAffineMap(); + auto normalizedOperands = normalizer.getOperands(); + canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); + *map = normalizedMap; + *operands = normalizedOperands; + assert(*map); +} + +void mlir::fullyComposeAffineMapAndOperands( + AffineMap *map, SmallVectorImpl *operands) { + while (llvm::any_of(*operands, [](Value *v) { + return v->getDefiningInst() && v->getDefiningInst()->isa(); + })) { + composeAffineMapAndOperands(map, operands); + } +} + +OpPointer +mlir::makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, + ArrayRef operands) { + AffineMap normalizedMap = map; + SmallVector normalizedOperands(operands.begin(), operands.end()); + composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); + assert(normalizedMap); + return b->create(loc, normalizedMap, normalizedOperands); +} + void mlir::canonicalizeMapAndOperands( AffineMap *map, llvm::SmallVectorImpl *operands) { if (!map || operands->empty()) @@ -245,9 +457,8 @@ PatternMatchResult SimplifyAffineApply::match(Instruction *op) const { auto map = apply->getAffineMap(); AffineMap oldMap = map; - SmallVector resultOperands(apply->getOperands().begin(), - apply->getOperands().end()); - canonicalizeMapAndOperands(&map, &resultOperands); + SmallVector resultOperands(apply->getOperands()); + composeAffineMapAndOperands(&map, &resultOperands); if (map != oldMap) return matchSuccess( std::make_unique(map, resultOperands)); @@ -678,6 +889,106 @@ void mlir::extractForInductionVars(ArrayRef> forInsts, ivs->push_back(forInst->getInductionVar()); } +bool mlir::addAffineForOpDomain(ConstOpPointer forOp, + FlatAffineConstraints *constraints) { + unsigned pos; + // Pre-condition for this method. + if (!constraints->findId(*forOp->getInductionVar(), &pos)) { + assert(0 && "Value not found"); + return false; + } + + if (forOp->getStep() != 1) + LLVM_DEBUG(llvm::dbgs() + << "Domain conservative: non-unit stride not handled\n"); + + // Adds a lower or upper bound when the bounds aren't constant. + auto addLowerOrUpperBound = [&](bool lower) -> bool { + auto operands = + lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); + for (const auto &operand : operands) { + unsigned loc; + if (!constraints->findId(*operand, &loc)) { + if (isValidSymbol(operand)) { + constraints->addSymbolId(constraints->getNumSymbolIds(), + const_cast(operand)); + loc = + constraints->getNumDimIds() + constraints->getNumSymbolIds() - 1; + // Check if the symbol is a constant. + if (auto *opInst = operand->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { + constraints->setIdToConstant(*operand, constOp->getValue()); + } + } + } else { + constraints->addDimId(constraints->getNumDimIds(), + const_cast(operand)); + loc = constraints->getNumDimIds() - 1; + } + } + } + // Record positions of the operands in the constraint system. + SmallVector positions; + for (const auto &operand : operands) { + unsigned loc; + if (!constraints->findId(*operand, &loc)) + assert(0 && "expected to be found"); + positions.push_back(loc); + } + + auto boundMap = + lower ? 
forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); + + FlatAffineConstraints localVarCst; + std::vector> flatExprs; + if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { + LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); + return false; + } + if (localVarCst.getNumLocalIds() > 0) { + LLVM_DEBUG(llvm::dbgs() + << "loop bounds with mod/floordiv expr's not yet supported\n"); + return false; + } + + for (const auto &flatExpr : flatExprs) { + SmallVector ineq(constraints->getNumCols(), 0); + ineq[pos] = lower ? 1 : -1; + for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { + ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; + } + // Constant term. + ineq[constraints->getNumCols() - 1] = + lower ? -flatExpr[flatExpr.size() - 1] + // Upper bound in flattenedExpr is an exclusive one. + : flatExpr[flatExpr.size() - 1] - 1; + constraints->addInequality(ineq); + } + return true; + }; + + if (forOp->hasConstantLowerBound()) { + constraints->addConstantLowerBound(pos, forOp->getConstantLowerBound()); + } else { + // Non-constant lower bound case. + if (!addLowerOrUpperBound(/*lower=*/true)) + return false; + } + + if (forOp->hasConstantUpperBound()) { + constraints->addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); + return true; + } + // Non-constant upper bound case. + return addLowerOrUpperBound(/*lower=*/false); +} + +/// Returns an AffineValueMap representing this bound. +AffineValueMap AffineBound::getAsAffineValueMap() { + SmallVector operands(getOperands()); + return AffineValueMap(getMap(), operands); +} + //===----------------------------------------------------------------------===// // AffineIfOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 2323cb3ef71..0a5c4727eb1 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -22,9 +22,9 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/AffineOps/AffineOps.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" @@ -42,462 +42,6 @@ using namespace mlir; using llvm::dbgs; -/// Constructs an affine expression from a flat ArrayRef. If there are local -/// identifiers (neither dimensional nor symbolic) that appear in the sum of -/// products expression, 'localExprs' is expected to have the AffineExpr -/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the -/// format [dims, symbols, locals, constant term]. -// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here. -static AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, - unsigned numSymbols, - ArrayRef localExprs, - MLIRContext *context) { - // Assert expected numLocals = eq.size() - numDims - numSymbols - 1 - assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() && - "unexpected number of local expressions"); - - auto expr = getAffineConstantExpr(0, context); - // Dimensions and symbols. - for (unsigned j = 0; j < numDims + numSymbols; j++) { - if (eq[j] == 0) { - continue; - } - auto id = j < numDims ? getAffineDimExpr(j, context) - : getAffineSymbolExpr(j - numDims, context); - expr = expr + id * eq[j]; - } - - // Local identifiers. 
- for (unsigned j = numDims + numSymbols, e = eq.size() - 1; j < e; j++) { - if (eq[j] == 0) { - continue; - } - auto term = localExprs[j - numDims - numSymbols] * eq[j]; - expr = expr + term; - } - - // Constant term. - int64_t constTerm = eq[eq.size() - 1]; - if (constTerm != 0) - expr = expr + constTerm; - return expr; -} - -AffineMap mlir::simplifyAffineMap(AffineMap map) { - SmallVector exprs, sizes; - for (auto e : map.getResults()) { - exprs.push_back( - simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); - } - for (auto e : map.getRangeSizes()) { - sizes.push_back( - simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); - } - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, sizes); -} - -namespace { - -// This class is used to flatten a pure affine expression (AffineExpr, -// which is in a tree form) into a sum of products (w.r.t constants) when -// possible, and in that process simplifying the expression. For a modulo, -// floordiv, or a ceildiv expression, an additional identifier, called a local -// identifier, is introduced to rewrite the expression as a sum of product -// affine expression. Each local identifier is always and by construction a -// floordiv of a pure add/mul affine function of dimensional, symbolic, and -// other local identifiers, in a non-mutually recursive way. Hence, every local -// identifier can ultimately always be recovered as an affine function of -// dimensional and symbolic identifiers (involving floordiv's); note however -// that by AffineExpr construction, some floordiv combinations are converted to -// mod's. The result of the flattening is a flattened expression and a set of -// constraints involving just the local variables. -// -// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local -// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. -// -// The simplification performed includes the accumulation of contributions for -// each dimensional and symbolic identifier together, the simplification of -// floordiv/ceildiv/mod expressions and other simplifications that in turn -// happen as a result. A simplification that this flattening naturally performs -// is of simplifying the numerator and denominator of floordiv/ceildiv, and -// folding a modulo expression to a zero, if possible. Three examples are below: -// -// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 -// (d0 - d0 mod 4 + 4) mod 4 simplified to 0 -// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 -// -// The way the flattening works for the second example is as follows: d0 % 4 is -// replaced by d0 - 4*q with q being introduced: the expression then simplifies -// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to -// zero. Note that an affine expression may not always be expressible purely as -// a sum of products involving just the original dimensional and symbolic -// identifiers due to the presence of modulo/floordiv/ceildiv expressions that -// may not be eliminated after simplification; in such cases, the final -// expression can be reconstructed by replacing the local identifiers with their -// corresponding explicit form stored in 'localExprs' (note that each of the -// explicit forms itself would have been simplified). -// -// The expression walk method here performs a linear time post order walk that -// performs the above simplifications through visit methods, with partial -// results being stored in 'operandExprStack'. 
When a parent expr is visited, -// the flattened expressions corresponding to its two operands would already be -// on the stack - the parent expression looks at the two flattened expressions -// and combines the two. It pops off the operand expressions and pushes the -// combined result (although this is done in-place on its LHS operand expr). -// When the walk is completed, the flattened form of the top-level expression -// would be left on the stack. -// -// A flattener can be repeatedly used for multiple affine expressions that bind -// to the same operands, for example, for all result expressions of an -// AffineMap or AffineValueMap. In such cases, using it for multiple expressions -// is more efficient than creating a new flattener for each expression since -// common idenical div and mod expressions appearing across different -// expressions are mapped to the same local identifier (same column position in -// 'localVarCst'). -struct AffineExprFlattener : public AffineExprVisitor { -public: - // Flattend expression layout: [dims, symbols, locals, constant] - // Stack that holds the LHS and RHS operands while visiting a binary op expr. - // In future, consider adding a prepass to determine how big the SmallVector's - // will be, and linearize this to std::vector to prevent - // SmallVector moves on re-allocation. - std::vector> operandExprStack; - // Constraints connecting newly introduced local variables (for mod's and - // div's) to existing (dimensional and symbolic) ones. These are always - // inequalities. - FlatAffineConstraints localVarCst; - - unsigned numDims; - unsigned numSymbols; - // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv - // expressions that could not be simplified. - unsigned numLocals; - // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for - // which new identifiers were introduced; if the latter do not get canceled - // out, these expressions can be readily used to reconstruct the AffineExpr - // (tree) form. Note that these expressions themselves would have been - // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 - // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) - // ceildiv 2 would be the local expression stored for q. - SmallVector localExprs; - MLIRContext *context; - - AffineExprFlattener(unsigned numDims, unsigned numSymbols, - MLIRContext *context) - : numDims(numDims), numSymbols(numSymbols), numLocals(0), - context(context) { - operandExprStack.reserve(8); - localVarCst.reset(numDims, numSymbols, numLocals); - } - - void visitMulExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - // Get the RHS constant. - auto rhsConst = operandExprStack.back()[getConstantIndex()]; - operandExprStack.pop_back(); - // Update the LHS in place instead of pop and push. - auto &lhs = operandExprStack.back(); - for (unsigned i = 0, e = lhs.size(); i < e; i++) { - lhs[i] *= rhsConst; - } - } - - void visitAddExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - const auto &rhs = operandExprStack.back(); - auto &lhs = operandExprStack[operandExprStack.size() - 2]; - assert(lhs.size() == rhs.size()); - // Update the LHS in place. - for (unsigned i = 0, e = rhs.size(); i < e; i++) { - lhs[i] += rhs[i]; - } - // Pop off the RHS. 
- operandExprStack.pop_back(); - } - - // - // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 - // - // A mod expression "expr mod c" is thus flattened by introducing a new local - // variable q (= expr floordiv c), such that expr mod c is replaced with - // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. - void visitModExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - auto rhsConst = operandExprStack.back()[getConstantIndex()]; - operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); - // TODO(bondhugula): handle modulo by zero case when this issue is fixed - // at the other places in the IR. - assert(rhsConst > 0 && "RHS constant has to be positive"); - - // Check if the LHS expression is a multiple of modulo factor. - unsigned i, e; - for (i = 0, e = lhs.size(); i < e; i++) - if (lhs[i] % rhsConst != 0) - break; - // If yes, modulo expression here simplifies to zero. - if (i == lhs.size()) { - std::fill(lhs.begin(), lhs.end(), 0); - return; - } - - // Add a local variable for the quotient, i.e., expr % c is replaced by - // (expr - q * c) where q = expr floordiv c. Do this while canceling out - // the GCD of expr and c. - SmallVector floorDividend(lhs); - uint64_t gcd = rhsConst; - for (unsigned i = 0, e = lhs.size(); i < e; i++) - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); - // Simplify the numerator and the denominator. - if (gcd != 1) { - for (unsigned i = 0, e = floorDividend.size(); i < e; i++) - floorDividend[i] = floorDividend[i] / static_cast(gcd); - } - int64_t floorDivisor = rhsConst / static_cast(gcd); - - // Construct the AffineExpr form of the floordiv to store in localExprs. - auto dividendExpr = - toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context); - auto divisorExpr = getAffineConstantExpr(floorDivisor, context); - auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); - int loc; - if ((loc = findLocalId(floorDivExpr)) == -1) { - addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); - // Set result at top of stack to "lhs - rhsConst * q". - lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; - } else { - // Reuse the existing local id. 
- lhs[getLocalVarStartIndex() + loc] = -rhsConst; - } - } - - void visitCeilDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/true); - } - void visitFloorDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/false); - } - - void visitDimExpr(AffineDimExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - assert(expr.getPosition() < numDims && "Inconsistent number of dims"); - eq[getDimStartIndex() + expr.getPosition()] = 1; - } - - void visitSymbolExpr(AffineSymbolExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); - eq[getSymbolStartIndex() + expr.getPosition()] = 1; - } - - void visitConstantExpr(AffineConstantExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - eq[getConstantIndex()] = expr.getValue(); - } - -private: - // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 - // A floordiv is thus flattened by introducing a new local variable q, and - // replacing that expression with 'q' while adding the constraints - // c * q <= expr <= c * q + c - 1 to localVarCst (done by - // FlatAffineConstraints::addLocalFloorDiv). - // - // A ceildiv is similarly flattened: - // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c - void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { - assert(operandExprStack.size() >= 2); - assert(expr.getRHS().isa()); - - // This is a pure affine expr; the RHS is a positive constant. - int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; - // TODO(bondhugula): handle division by zero at the same time the issue is - // fixed at other places. - assert(rhsConst > 0 && "RHS constant has to be positive"); - operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); - - // Simplify the floordiv, ceildiv if possible by canceling out the greatest - // common divisors of the numerator and denominator. - uint64_t gcd = std::abs(rhsConst); - for (unsigned i = 0, e = lhs.size(); i < e; i++) - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); - // Simplify the numerator and the denominator. - if (gcd != 1) { - for (unsigned i = 0, e = lhs.size(); i < e; i++) - lhs[i] = lhs[i] / static_cast(gcd); - } - int64_t divisor = rhsConst / static_cast(gcd); - // If the divisor becomes 1, the updated LHS is the result. (The - // divisor can't be negative since rhsConst is positive). - if (divisor == 1) - return; - - // If the divisor cannot be simplified to one, we will have to retain - // the ceil/floor expr (simplified up until here). Add an existential - // quantifier to express its result, i.e., expr1 div expr2 is replaced - // by a new identifier, q. - auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context); - auto b = getAffineConstantExpr(divisor, context); - - int loc; - auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); - if ((loc = findLocalId(divExpr)) == -1) { - if (!isCeil) { - SmallVector dividend(lhs); - addLocalFloorDivId(dividend, divisor, divExpr); - } else { - // lhs ceildiv c <=> (lhs + c - 1) floordiv c - SmallVector dividend(lhs); - dividend.back() += divisor - 1; - addLocalFloorDivId(dividend, divisor, divExpr); - } - } - // Set the expression on stack to the local var introduced to capture the - // result of the division (floor or ceil). 
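-    // For instance (an illustrative case, not tied to any particular test):
-    // flattening (d0 + d1) floordiv 4 introduces a local 'q' with
-    // 4*q <= d0 + d1 <= 4*q + 3 recorded in localVarCst, and the flattened
-    // form of the floordiv itself is then simply a 1 in q's column.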
- std::fill(lhs.begin(), lhs.end(), 0); - if (loc == -1) - lhs[getLocalVarStartIndex() + numLocals - 1] = 1; - else - lhs[getLocalVarStartIndex() + loc] = 1; - } - - // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). - // The local identifier added is always a floordiv of a pure add/mul affine - // function of other identifiers, coefficients of which are specified in - // dividend and with respect to a positive constant divisor. localExpr is the - // simplified tree expression (AffineExpr) corresponding to the quantifier. - void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, - AffineExpr localExpr) { - assert(divisor > 0 && "positive constant divisor expected"); - for (auto &subExpr : operandExprStack) - subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); - localExprs.push_back(localExpr); - numLocals++; - // Update localVarCst. - localVarCst.addLocalFloorDiv(dividend, divisor); - } - - int findLocalId(AffineExpr localExpr) { - SmallVectorImpl::iterator it; - if ((it = std::find(localExprs.begin(), localExprs.end(), localExpr)) == - localExprs.end()) - return -1; - return it - localExprs.begin(); - } - - inline unsigned getNumCols() const { - return numDims + numSymbols + numLocals + 1; - } - inline unsigned getConstantIndex() const { return getNumCols() - 1; } - inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } - inline unsigned getSymbolStartIndex() const { return numDims; } - inline unsigned getDimStartIndex() const { return 0; } -}; - -} // end anonymous namespace - -/// Simplify the affine expression by flattening it and reconstructing it. -AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols) { - // TODO(bondhugula): only pure affine for now. The simplification here can - // be extended to semi-affine maps in the future. - if (!expr.isPureAffine()) - return expr; - - AffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); - flattener.walkPostOrder(expr); - ArrayRef flattenedExpr = flattener.operandExprStack.back(); - auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, - flattener.localExprs, expr.getContext()); - flattener.operandExprStack.pop_back(); - assert(flattener.operandExprStack.empty()); - - return simplifiedExpr; -} - -// Flattens the expressions in map. Returns true on success or false -// if 'expr' was unable to be flattened (i.e., semi-affine expressions not -// handled yet). -static bool getFlattenedAffineExprs( - ArrayRef exprs, unsigned numDims, unsigned numSymbols, - std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { - if (exprs.empty()) { - localVarCst->reset(numDims, numSymbols); - return true; - } - - flattenedExprs->clear(); - flattenedExprs->reserve(exprs.size()); - - AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); - // Use the same flattener to simplify each expression successively. This way - // local identifiers / expressions are shared. - for (auto expr : exprs) { - if (!expr.isPureAffine()) - return false; - - flattener.walkPostOrder(expr); - } - - assert(flattener.operandExprStack.size() == exprs.size()); - flattenedExprs->insert(flattenedExprs->end(), - flattener.operandExprStack.begin(), - flattener.operandExprStack.end()); - if (localVarCst) - localVarCst->clearAndCopyFrom(flattener.localVarCst); - - return true; -} - -// Flattens 'expr' into 'flattenedExpr'. 
Returns true on success or false -// if 'expr' was unable to be flattened (semi-affine expressions not handled -// yet). -bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *localVarCst) { - std::vector> flattenedExprs; - bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, - &flattenedExprs, localVarCst); - *flattenedExpr = flattenedExprs[0]; - return ret; -} - -/// Flattens the expressions in map. Returns true on success or false -/// if 'expr' was unable to be flattened (i.e., semi-affine expressions not -/// handled yet). -bool mlir::getFlattenedAffineExprs( - AffineMap map, std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { - if (map.getNumResults() == 0) { - localVarCst->reset(map.getNumDims(), map.getNumSymbols()); - return true; - } - return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), - map.getNumSymbols(), flattenedExprs, - localVarCst); -} - -bool mlir::getFlattenedAffineExprs( - IntegerSet set, std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { - if (set.getNumConstraints() == 0) { - localVarCst->reset(set.getNumDims(), set.getNumSymbols()); - return true; - } - return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), - set.getNumSymbols(), flattenedExprs, - localVarCst); -} - /// Returns the sequence of AffineApplyOp Instructions operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. @@ -563,7 +107,7 @@ bool mlir::getIndexSet(MutableArrayRef> forOps, domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto forOp : forOps) { // Add constraints from forOp's bounds. - if (!domain->addAffineForOpDomain(forOp)) + if (!addAffineForOpDomain(forOp, domain)) return false; } return true; @@ -1355,208 +899,3 @@ bool mlir::checkMemrefAccessDependence( LLVM_DEBUG(dependenceConstraints->dump()); return true; } - -namespace { - -/// An `AffineApplyNormalizer` is a helper class that is not visible to the user -/// and supports renumbering operands of AffineApplyOp. This acts as a -/// reindexing map of Value* to positional dims or symbols and allows -/// simplifications such as: -/// -/// ```mlir -/// %1 = affine_apply (d0, d1) -> (d0 - d1) (%0, %0) -/// ``` -/// -/// into: -/// -/// ```mlir -/// %1 = affine_apply () -> (0) -/// ``` -struct AffineApplyNormalizer { - AffineApplyNormalizer(AffineMap map, ArrayRef operands); - - /// Returns the AffineMap resulting from normalization. - AffineMap getAffineMap() { return affineMap; } - - SmallVector getOperands() { - SmallVector res(reorderedDims); - res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); - return res; - } - -private: - /// Helper function to insert `v` into the coordinate system of the current - /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding - /// renumbered position. - AffineDimExpr applyOneDim(Value *v); - - /// Given an `other` normalizer, this rewrites `other.affineMap` in the - /// coordinate system of the current AffineApplyNormalizer. - /// Returns the rewritten AffineMap and updates the dims and symbols of - /// `this`. - AffineMap renumber(const AffineApplyNormalizer &other); - - /// Given an `app`, rewrites `app.getAffineMap()` in the coordinate system of - /// the current AffineApplyNormalizer. 
- /// Returns the rewritten AffineMap and updates the dims and symbols of - /// `this`. - AffineMap renumber(const AffineApplyOp &app); - - /// Maps of Value* to position in `affineMap`. - DenseMap dimValueToPosition; - - /// Ordered dims and symbols matching positional dims and symbols in - /// `affineMap`. - SmallVector reorderedDims; - SmallVector concatenatedSymbols; - - AffineMap affineMap; - - /// Used with RAII to control the depth at which AffineApply are composed - /// recursively. Only accepts depth 1 for now. - /// Note that if one wishes to compose all AffineApply in the program and - /// follows program order, maxdepth 1 is sufficient. This is as much as this - /// abstraction is willing to support for now. - static unsigned &affineApplyDepth() { - static thread_local unsigned depth = 0; - return depth; - } - static constexpr unsigned kMaxAffineApplyDepth = 1; - - AffineApplyNormalizer() { affineApplyDepth()++; } - -public: - ~AffineApplyNormalizer() { affineApplyDepth()--; } -}; - -} // namespace - -AffineDimExpr AffineApplyNormalizer::applyOneDim(Value *v) { - DenseMap::iterator iterPos; - bool inserted = false; - std::tie(iterPos, inserted) = - dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); - if (inserted) { - reorderedDims.push_back(v); - } - return getAffineDimExpr(iterPos->second, v->getFunction()->getContext()) - .cast(); -} - -AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { - SmallVector dimRemapping; - for (auto *v : other.reorderedDims) { - auto kvp = other.dimValueToPosition.find(v); - if (dimRemapping.size() <= kvp->second) - dimRemapping.resize(kvp->second + 1); - dimRemapping[kvp->second] = applyOneDim(kvp->first); - } - unsigned numSymbols = concatenatedSymbols.size(); - unsigned numOtherSymbols = other.concatenatedSymbols.size(); - SmallVector symRemapping(numOtherSymbols); - for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { - symRemapping[idx] = - getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); - } - concatenatedSymbols.insert(concatenatedSymbols.end(), - other.concatenatedSymbols.begin(), - other.concatenatedSymbols.end()); - auto map = other.affineMap; - return map.replaceDimsAndSymbols(dimRemapping, symRemapping, - dimRemapping.size(), symRemapping.size()); -} - -AffineMap AffineApplyNormalizer::renumber(const AffineApplyOp &app) { - assert(app.getAffineMap().getRangeSizes().empty() && "Non-empty range sizes"); - - // Create the AffineApplyNormalizer for the operands of this - // AffineApplyOp and combine it with the current AffineApplyNormalizer. 
- SmallVector operands( - const_cast(app).getOperands().begin(), - const_cast(app).getOperands().end()); - AffineApplyNormalizer normalizer(app.getAffineMap(), operands); - return renumber(normalizer); -} - -AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, - ArrayRef operands) - : AffineApplyNormalizer() { - assert(map.getRangeSizes().empty() && "Unbounded map expected"); - assert(map.getNumInputs() == operands.size() && - "number of operands does not match the number of map inputs"); - - SmallVector exprs; - for (auto en : llvm::enumerate(operands)) { - auto *t = en.value(); - assert(t->getType().isIndex()); - bool operandNotFromAffineApply = - !t->getDefiningInst() || !t->getDefiningInst()->isa(); - if (operandNotFromAffineApply || - affineApplyDepth() > kMaxAffineApplyDepth) { - if (en.index() < map.getNumDims()) { - exprs.push_back(applyOneDim(t)); - } else { - // Composition of mathematical symbols must occur by concatenation. - // A subsequent canonicalization will drop duplicates. Duplicates are - // not dropped here because it would just amount to code duplication. - concatenatedSymbols.push_back(t); - } - } else { - auto *inst = t->getDefiningInst(); - auto app = inst->dyn_cast(); - auto tmpMap = renumber(*app); - exprs.push_back(tmpMap.getResult(0)); - } - } - - // Map is already composed. - if (exprs.empty()) { - affineMap = map; - return; - } - - auto numDims = dimValueToPosition.size(); - auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); - auto exprsMap = AffineMap::get(numDims, numSymbols, exprs, {}); - LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); - LLVM_DEBUG(exprsMap.print(dbgs() << "\nWith map: ")); - LLVM_DEBUG(map.compose(exprsMap).print(dbgs() << "\nResult: ")); - - affineMap = simplifyAffineMap(map.compose(exprsMap)); - LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); - LLVM_DEBUG(dbgs() << "\n"); -} - -/// Implements `map` and `operands` composition and simplification to support -/// `makeComposedAffineApply`. This can be called to achieve the same effects -/// on `map` and `operands` without creating an AffineApplyOp that needs to be -/// immediately deleted. -static void composeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { - AffineApplyNormalizer normalizer(*map, *operands); - auto normalizedMap = normalizer.getAffineMap(); - auto normalizedOperands = normalizer.getOperands(); - canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); - *map = normalizedMap; - *operands = normalizedOperands; - assert(*map); -} - -void mlir::fullyComposeAffineMapAndOperands( - AffineMap *map, SmallVectorImpl *operands) { - while (llvm::any_of(*operands, [](Value *v) { - return v->getDefiningInst() && v->getDefiningInst()->isa(); - })) { - composeAffineMapAndOperands(map, operands); - } -} - -OpPointer -mlir::makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, - ArrayRef operands) { - AffineMap normalizedMap = map; - SmallVector normalizedOperands(operands.begin(), operands.end()); - composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); - assert(normalizedMap); - return b->create(loc, normalizedMap, normalizedOperands); -} diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp deleted file mode 100644 index b764d69c298..00000000000 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ /dev/null @@ -1,2167 +0,0 @@ -//===- AffineStructures.cpp - MLIR Affine Structures Class-------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. 
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// Structures for affine/polyhedral analysis of MLIR functions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/AffineOps/AffineOps.h"
-#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/IR/AffineExprVisitor.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Instruction.h"
-#include "mlir/IR/IntegerSet.h"
-#include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
-
-#define DEBUG_TYPE "affine-structures"
-
-using namespace mlir;
-using namespace llvm;
-
-//===----------------------------------------------------------------------===//
-// MutableAffineMap.
-//===----------------------------------------------------------------------===//
-
-MutableAffineMap::MutableAffineMap(AffineMap map)
-    : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
-      // A map always has at least 1 result by construction
-      context(map.getResult(0).getContext()) {
-  for (auto result : map.getResults())
-    results.push_back(result);
-  for (auto rangeSize : map.getRangeSizes())
-    rangeSizes.push_back(rangeSize);
-}
-
-void MutableAffineMap::reset(AffineMap map) {
-  results.clear();
-  rangeSizes.clear();
-  numDims = map.getNumDims();
-  numSymbols = map.getNumSymbols();
-  // A map always has at least 1 result by construction
-  context = map.getResult(0).getContext();
-  for (auto result : map.getResults())
-    results.push_back(result);
-  for (auto rangeSize : map.getRangeSizes())
-    rangeSizes.push_back(rangeSize);
-}
-
-bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
-  if (results[idx].isMultipleOf(factor))
-    return true;
-
-  // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
-  // complete this (for a more powerful analysis).
-  return false;
-}
-
-// Simplifies the result affine expressions of this map. The expressions have
-// to be pure for the simplification implemented.
-void MutableAffineMap::simplify() {
-  // Simplify each of the results if possible.
-  // TODO(ntv): functional-style map
-  for (unsigned i = 0, e = getNumResults(); i < e; i++) {
-    results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
-  }
-}
-
-AffineMap MutableAffineMap::getAffineMap() const {
-  return AffineMap::get(numDims, numSymbols, results, rangeSizes);
-}
-
-MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
-    : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()),
-      context(context) {
-  // TODO(bondhugula)
-}
-
-// Universal set.
-MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, - MLIRContext *context) - : numDims(numDims), numSymbols(numSymbols), context(context) {} - -//===----------------------------------------------------------------------===// -// AffineValueMap. -//===----------------------------------------------------------------------===// - -AffineValueMap::AffineValueMap(const AffineApplyOp &op) - : map(op.getAffineMap()) { - for (auto *operand : op.getOperands()) - operands.push_back(const_cast(operand)); - results.push_back(const_cast(op.getResult())); -} - -AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands) - : map(map) { - for (Value *operand : operands) { - this->operands.push_back(operand); - } -} - -void AffineValueMap::reset(AffineMap map, ArrayRef operands) { - this->operands.clear(); - this->results.clear(); - this->map.reset(map); - for (Value *operand : operands) { - this->operands.push_back(operand); - } -} - -// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in -// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. -static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, - unsigned indexStart, unsigned *indexOfMatch) { - unsigned size = valuesToSearch.size(); - for (unsigned i = indexStart; i < size; ++i) { - if (valueToMatch == valuesToSearch[i]) { - *indexOfMatch = i; - return true; - } - } - return false; -} - -inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { - return map.isMultipleOf(idx, factor); -} - -/// This method uses the invariant that operands are always positionally aligned -/// with the AffineDimExpr in the underlying AffineMap. -bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { - unsigned index; - if (!findIndex(value, operands, /*indexStart=*/0, &index)) { - return false; - } - auto expr = const_cast(this)->getAffineMap().getResult(idx); - // TODO(ntv): this is better implemented on a flattened representation. - // At least for now it is conservative. - return expr.isFunctionOfDim(index); -} - -Value *AffineValueMap::getOperand(unsigned i) const { - return static_cast(operands[i]); -} - -ArrayRef AffineValueMap::getOperands() const { - return ArrayRef(operands); -} - -AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } - -AffineValueMap::~AffineValueMap() {} - -//===----------------------------------------------------------------------===// -// FlatAffineConstraints. -//===----------------------------------------------------------------------===// - -// Copy constructor. -FlatAffineConstraints::FlatAffineConstraints( - const FlatAffineConstraints &other) { - numReservedCols = other.numReservedCols; - numDims = other.getNumDimIds(); - numSymbols = other.getNumSymbolIds(); - numIds = other.getNumIds(); - - auto otherIds = other.getIds(); - ids.reserve(numReservedCols); - ids.append(otherIds.begin(), otherIds.end()); - - unsigned numReservedEqualities = other.getNumReservedEqualities(); - unsigned numReservedInequalities = other.getNumReservedInequalities(); - - equalities.reserve(numReservedEqualities * numReservedCols); - inequalities.reserve(numReservedInequalities * numReservedCols); - - for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { - addInequality(other.getInequality(r)); - } - for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { - addEquality(other.getEquality(r)); - } -} - -// Clones this object. 
-std::unique_ptr FlatAffineConstraints::clone() const { - return std::make_unique(*this); -} - -// Construct from an IntegerSet. -FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) - : numReservedCols(set.getNumOperands() + 1), - numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), - numSymbols(set.getNumSymbols()) { - equalities.reserve(set.getNumEqualities() * numReservedCols); - inequalities.reserve(set.getNumInequalities() * numReservedCols); - ids.resize(numIds, None); - - // Flatten expressions and add them to the constraint system. - std::vector> flatExprs; - FlatAffineConstraints localVarCst; - if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) { - assert(false && "flattening unimplemented for semi-affine integer sets"); - return; - } - assert(flatExprs.size() == set.getNumConstraints()); - for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { - addLocalId(getNumLocalIds()); - } - - for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { - const auto &flatExpr = flatExprs[i]; - assert(flatExpr.size() == getNumCols()); - if (set.getEqFlags()[i]) { - addEquality(flatExpr); - } else { - addInequality(flatExpr); - } - } - // Add the other constraints involving local id's from flattening. - append(localVarCst); -} - -void FlatAffineConstraints::reset(unsigned numReservedInequalities, - unsigned numReservedEqualities, - unsigned newNumReservedCols, - unsigned newNumDims, unsigned newNumSymbols, - unsigned newNumLocals, - ArrayRef idArgs) { - assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && - "minimum 1 column"); - numReservedCols = newNumReservedCols; - numDims = newNumDims; - numSymbols = newNumSymbols; - numIds = numDims + numSymbols + newNumLocals; - assert(idArgs.empty() || idArgs.size() == numIds); - - clearConstraints(); - if (numReservedEqualities >= 1) - equalities.reserve(newNumReservedCols * numReservedEqualities); - if (numReservedInequalities >= 1) - inequalities.reserve(newNumReservedCols * numReservedInequalities); - if (idArgs.empty()) { - ids.resize(numIds, None); - } else { - ids.assign(idArgs.begin(), idArgs.end()); - } -} - -void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, - unsigned newNumLocals, - ArrayRef idArgs) { - reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, - newNumSymbols, newNumLocals, idArgs); -} - -void FlatAffineConstraints::append(const FlatAffineConstraints &other) { - assert(other.getNumCols() == getNumCols()); - assert(other.getNumDimIds() == getNumDimIds()); - assert(other.getNumSymbolIds() == getNumSymbolIds()); - - inequalities.reserve(inequalities.size() + - other.getNumInequalities() * numReservedCols); - equalities.reserve(equalities.size() + - other.getNumEqualities() * numReservedCols); - - for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { - addInequality(other.getInequality(r)); - } - for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { - addEquality(other.getEquality(r)); - } -} - -void FlatAffineConstraints::addLocalId(unsigned pos) { - addId(IdKind::Local, pos); -} - -void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { - addId(IdKind::Dimension, pos, id); -} - -void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { - addId(IdKind::Symbol, pos, id); -} - -/// Adds a dimensional identifier. The added column is initialized to -/// zero. 
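-/// The same routine also handles symbol and local identifiers: 'kind' and
-/// 'pos' determine the absolute column at which the new zero-initialized
-/// coefficient column is inserted (dims come first, then symbols, then
-/// locals).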
-void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { - if (kind == IdKind::Dimension) { - assert(pos <= getNumDimIds()); - } else if (kind == IdKind::Symbol) { - assert(pos <= getNumSymbolIds()); - } else { - assert(pos <= getNumLocalIds()); - } - - unsigned oldNumReservedCols = numReservedCols; - - // Check if a resize is necessary. - if (getNumCols() + 1 > numReservedCols) { - equalities.resize(getNumEqualities() * (getNumCols() + 1)); - inequalities.resize(getNumInequalities() * (getNumCols() + 1)); - numReservedCols++; - } - - unsigned absolutePos; - - if (kind == IdKind::Dimension) { - absolutePos = pos; - numDims++; - } else if (kind == IdKind::Symbol) { - absolutePos = pos + getNumDimIds(); - numSymbols++; - } else { - absolutePos = pos + getNumDimIds() + getNumSymbolIds(); - } - numIds++; - - // Note that getNumCols() now will already return the new size, which will be - // at least one. - int numInequalities = static_cast(getNumInequalities()); - int numEqualities = static_cast(getNumEqualities()); - int numCols = static_cast(getNumCols()); - for (int r = numInequalities - 1; r >= 0; r--) { - for (int c = numCols - 2; c >= 0; c--) { - if (c < absolutePos) - atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; - else - atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; - } - atIneq(r, absolutePos) = 0; - } - - for (int r = numEqualities - 1; r >= 0; r--) { - for (int c = numCols - 2; c >= 0; c--) { - // All values in column absolutePositions < absolutePos have the same - // coordinates in the 2-d view of the coefficient buffer. - if (c < absolutePos) - atEq(r, c) = equalities[r * oldNumReservedCols + c]; - else - // Those at absolutePosition >= absolutePos, get a shifted - // absolutePosition. - atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; - } - // Initialize added dimension to zero. - atEq(r, absolutePos) = 0; - } - - // If an 'id' is provided, insert it; otherwise use None. - if (id) { - ids.insert(ids.begin() + absolutePos, id); - } else { - ids.insert(ids.begin() + absolutePos, None); - } - assert(ids.size() == getNumIds()); -} - -// This routine may add additional local variables if the flattened expression -// corresponding to the map has such variables due to the presence of -// mod's, ceildiv's, and floordiv's. -bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { - // Assert if the map and this constraint set aren't associated with the same - // identifiers in the same order. - assert(vMap->getNumDims() <= getNumDimIds()); - assert(vMap->getNumSymbols() <= getNumSymbolIds()); - for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) { - assert(ids[i].hasValue()); - assert(vMap->getOperand(i) == ids[i].getValue()); - } - for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) { - assert(ids[numDims + i].hasValue()); - assert(vMap->getOperand(vMap->getNumDims() + i) == - ids[numDims + i].getValue()); - } - - std::vector> flatExprs; - FlatAffineConstraints cst; - if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { - LLVM_DEBUG(llvm::dbgs() - << "composition unimplemented for semi-affine maps\n"); - return false; - } - assert(flatExprs.size() == vMap->getNumResults()); - - // Make the value map and the flat affine cst dimensions compatible. - // A lot of this code will be refactored/cleaned up. - // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs - // to be factored out into an FlatAffineConstraints::alignAndMerge(). 
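-  // In short, the loops below make the two systems structurally compatible:
-  // the flattening's local ids are added to 'this', one new dimension is
-  // created per result of 'vMap' in both systems, and 'cst' is padded with
-  // whatever dims/symbols/locals it is missing, so that 'cst' can then be
-  // appended column-for-column onto this constraint system.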
- for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { - addLocalId(0); - } - - for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { - // TODO: Consider using a batched version to add a range of IDs. - addDimId(0); - cst.addDimId(0); - } - - assert(cst.getNumDimIds() <= getNumDimIds()); - for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. - cst.addDimId(cst.getNumDimIds()); - } - assert(cst.getNumSymbolIds() <= getNumSymbolIds()); - for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e; - t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. - cst.addSymbolId(cst.getNumSymbolIds()); - } - assert(cst.getNumLocalIds() <= getNumLocalIds()); - for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; - t++) { - cst.addLocalId(cst.getNumLocalIds()); - } - /// Finally, append cst to this constraint set. - append(cst); - - // We add one equality for each result connecting the result dim of the map to - // the other identifiers. - // For eg: if the expression is 16*i0 + i1, and this is the r^th - // iteration/result of the value map, we are adding the equality: - // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we - // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. - for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { - const auto &flatExpr = flatExprs[r]; - // eqToAdd is the equality corresponding to the flattened affine expression. - SmallVector eqToAdd(getNumCols(), 0); - // Set the coefficient for this result to one. - eqToAdd[r] = 1; - - assert(flatExpr.size() >= vMap->getNumOperands() + 1); - - // Dims and symbols. - for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { - unsigned loc; - bool ret = findId(*vMap->getOperand(i), &loc); - assert(ret && "value map's id can't be found"); - (void)ret; - // We need to negate 'eq[r]' since the newly added dimension is going to - // be set to this one. - eqToAdd[loc] = -flatExpr[i]; - } - // Local vars common to eq and cst are at the beginning. - int j = getNumDimIds() + getNumSymbolIds(); - int end = flatExpr.size() - 1; - for (int i = vMap->getNumOperands(); i < end; i++, j++) { - eqToAdd[j] = -flatExpr[i]; - } - - // Constant term. - eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; - - // Add the equality connecting the result of the map to this constraint set. - addEquality(eqToAdd); - } - - return true; -} - -// Searches for a constraint with a non-zero coefficient at 'colIdx' in -// equality (isEq=true) or inequality (isEq=false) constraints. -// Returns true and sets row found in search in 'rowIdx'. -// Returns false otherwise. -static bool -findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, - unsigned colIdx, bool isEq, unsigned *rowIdx) { - auto at = [&](unsigned rowIdx) -> int64_t { - return isEq ? constraints.atEq(rowIdx, colIdx) - : constraints.atIneq(rowIdx, colIdx); - }; - unsigned e = - isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); - for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { - if (at(*rowIdx) != 0) { - return true; - } - } - return false; -} - -// Normalizes the coefficient values across all columns in 'rowIDx' by their -// GCD in equality or inequality contraints as specified by 'isEq'. 
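-// For example, the equality 4*d0 + 8*d1 - 12 = 0 would be rewritten as
-// d0 + 2*d1 - 3 = 0 after dividing through by the GCD (4); inequalities are
-// normalized the same way.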
-template -static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, - unsigned rowIdx) { - auto at = [&](unsigned colIdx) -> int64_t { - return isEq ? constraints->atEq(rowIdx, colIdx) - : constraints->atIneq(rowIdx, colIdx); - }; - uint64_t gcd = std::abs(at(0)); - for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); - } - if (gcd > 0 && gcd != 1) { - for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { - int64_t v = at(j) / static_cast(gcd); - isEq ? constraints->atEq(rowIdx, j) = v - : constraints->atIneq(rowIdx, j) = v; - } - } -} - -void FlatAffineConstraints::normalizeConstraintsByGCD() { - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - normalizeConstraintByGCD(this, i); - } - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - normalizeConstraintByGCD(this, i); - } -} - -bool FlatAffineConstraints::hasConsistentState() const { - if (inequalities.size() != getNumInequalities() * numReservedCols) - return false; - if (equalities.size() != getNumEqualities() * numReservedCols) - return false; - if (ids.size() != getNumIds()) - return false; - - // Catches errors where numDims, numSymbols, numIds aren't consistent. - if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) - return false; - - return true; -} - -/// Checks all rows of equality/inequality constraints for trivial -/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced -/// after elimination. Returns 'true' if an invalid constraint is found; -/// 'false' otherwise. -bool FlatAffineConstraints::hasInvalidConstraint() const { - assert(hasConsistentState()); - auto check = [&](bool isEq) -> bool { - unsigned numCols = getNumCols(); - unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); - for (unsigned i = 0, e = numRows; i < e; ++i) { - unsigned j; - for (j = 0; j < numCols - 1; ++j) { - int64_t v = isEq ? atEq(i, j) : atIneq(i, j); - // Skip rows with non-zero variable coefficients. - if (v != 0) - break; - } - if (j < numCols - 1) { - continue; - } - // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. - // Example invalid constraints include: '1 == 0' or '-1 >= 0' - int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); - if ((isEq && v != 0) || (!isEq && v < 0)) { - return true; - } - } - return false; - }; - if (check(/*isEq=*/true)) - return true; - return check(/*isEq=*/false); -} - -// Eliminate identifier from constraint at 'rowIdx' based on coefficient at -// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be -// updated as they have already been eliminated. -static void eliminateFromConstraint(FlatAffineConstraints *constraints, - unsigned rowIdx, unsigned pivotRow, - unsigned pivotCol, unsigned elimColStart, - bool isEq) { - // Skip if equality 'rowIdx' if same as 'pivotRow'. - if (isEq && rowIdx == pivotRow) - return; - auto at = [&](unsigned i, unsigned j) -> int64_t { - return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); - }; - int64_t leadCoeff = at(rowIdx, pivotCol); - // Skip if leading coefficient at 'rowIdx' is already zero. - if (leadCoeff == 0) - return; - int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); - int64_t sign = (leadCoeff * pivotCoeff > 0) ? 
-1 : 1;
-  int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
-  int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
-  int64_t rowMultiplier = lcm / std::abs(leadCoeff);
-
-  unsigned numCols = constraints->getNumCols();
-  for (unsigned j = 0; j < numCols; ++j) {
-    // Skip updating column 'j' if it was just eliminated.
-    if (j >= elimColStart && j < pivotCol)
-      continue;
-    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
-                rowMultiplier * at(rowIdx, j);
-    isEq ? constraints->atEq(rowIdx, j) = v
-         : constraints->atIneq(rowIdx, j) = v;
-  }
-}
-
-// Remove coefficients in column range [colStart, colLimit) in place.
-// This removes the data in the specified column range, and copies any
-// remaining valid data into place.
-static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
-                               unsigned colStart, unsigned colLimit,
-                               bool isEq) {
-  assert(colStart >= 0 && colLimit <= constraints->getNumIds());
-  if (colLimit <= colStart)
-    return;
-
-  unsigned numCols = constraints->getNumCols();
-  unsigned numRows = isEq ? constraints->getNumEqualities()
-                          : constraints->getNumInequalities();
-  unsigned numToEliminate = colLimit - colStart;
-  for (unsigned r = 0, e = numRows; r < e; ++r) {
-    for (unsigned c = colLimit; c < numCols; ++c) {
-      if (isEq) {
-        constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
-      } else {
-        constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
-      }
-    }
-  }
-}
-
-// Removes identifiers in column range [idStart, idLimit), copies any
-// remaining valid data into place, and updates member variables.
-void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
-  assert(idLimit < getNumCols() && "invalid id limit");
-
-  if (idStart >= idLimit)
-    return;
-
-  // We are going to be removing one or more identifiers from the range.
-  assert(idStart < numIds && "invalid idStart position");
-
-  // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
-  // Remove eliminated identifiers from equalities.
-  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
-
-  // Remove eliminated identifiers from inequalities.
-  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
-
-  // Update members numDims, numSymbols and numIds.
-  unsigned numDimsEliminated = 0;
-  unsigned numLocalsEliminated = 0;
-  unsigned numColsEliminated = idLimit - idStart;
-  if (idStart < numDims) {
-    numDimsEliminated = std::min(numDims, idLimit) - idStart;
-  }
-  // Check how many local ids were removed. Note that our identifier order is
-  // [dims, symbols, locals]. Local ids start at position numDims + numSymbols.
-  if (idLimit > numDims + numSymbols) {
-    numLocalsEliminated = std::min(
-        idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
-  }
-  unsigned numSymbolsEliminated =
-      numColsEliminated - numDimsEliminated - numLocalsEliminated;
-
-  numDims -= numDimsEliminated;
-  numSymbols -= numSymbolsEliminated;
-  numIds = numIds - numColsEliminated;
-
-  ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
-
-  // No resize necessary. numReservedCols remains the same.
-}
-
-/// Returns the position of the identifier that has the minimum number of lower
-/// bounds times the number of upper bounds, from the specified range of
-/// identifiers [start, end). It is often best to eliminate in the increasing
-/// order of these counts when doing Fourier-Motzkin elimination since FM adds
-/// that many new constraints.
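-/// For instance, an identifier with 2 lower bounds and 3 upper bounds can
-/// produce up to 2 * 3 = 6 new inequalities when eliminated, so identifiers
-/// with a smaller product are cheaper to eliminate first.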
-static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, - unsigned start, unsigned end) { - assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); - - auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { - unsigned numLb = 0; - unsigned numUb = 0; - for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { - if (cst.atIneq(r, pos) > 0) { - ++numLb; - } else if (cst.atIneq(r, pos) < 0) { - ++numUb; - } - } - return numLb * numUb; - }; - - unsigned minLoc = start; - unsigned min = getProductOfNumLowerUpperBounds(start); - for (unsigned c = start + 1; c < end; c++) { - unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); - if (numLbUbProduct < min) { - min = numLbUbProduct; - minLoc = c; - } - } - return minLoc; -} - -// Checks for emptiness of the set by eliminating identifiers successively and -// using the GCD test (on all equality constraints) and checking for trivially -// invalid constraints. Returns 'true' if the constraint system is found to be -// empty; false otherwise. -bool FlatAffineConstraints::isEmpty() const { - if (isEmptyByGCDTest() || hasInvalidConstraint()) - return true; - - // First, eliminate as many identifiers as possible using Gaussian - // elimination. - FlatAffineConstraints tmpCst(*this); - unsigned currentPos = 0; - while (currentPos < tmpCst.getNumIds()) { - tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); - ++currentPos; - // We check emptiness through trivial checks after eliminating each ID to - // detect emptiness early. Since the checks isEmptyByGCDTest() and - // hasInvalidConstraint() are linear time and single sweep on the constraint - // buffer, this appears reasonable - but can optimize in the future. - if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) - return true; - } - - // Eliminate the remaining using FM. - for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { - tmpCst.FourierMotzkinEliminate( - getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); - // Check for a constraint explosion. This rarely happens in practice, but - // this check exists as a safeguard against improperly constructed - // constraint systems or artifically created arbitrarily complex systems - // that aren't the intended use case for FlatAffineConstraints. This is - // needed since FM has a worst case exponential complexity in theory. - if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { - LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected"); - return false; - } - - // FM wouldn't have modified the equalities in any way. So no need to again - // run GCD test. Check for trivial invalid constraints. - if (tmpCst.hasInvalidConstraint()) - return true; - } - return false; -} - -// Runs the GCD test on all equality constraints. Returns 'true' if this test -// fails on any equality. Returns 'false' otherwise. -// This test can be used to disprove the existence of a solution. If it returns -// true, no integer solution to the equality constraints can exist. -// -// GCD test definition: -// -// The equality constraint: -// -// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 -// -// has an integer solution iff: -// -// GCD of c_1, c_2, ..., c_n divides c_0. 
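-//
-// For example, 2*x_1 + 4*x_2 = 9 has no integer solution, since
-// gcd(2, 4) = 2 does not divide 9; a system containing such an equality is
-// therefore empty.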
-// -bool FlatAffineConstraints::isEmptyByGCDTest() const { - assert(hasConsistentState()); - unsigned numCols = getNumCols(); - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - uint64_t gcd = std::abs(atEq(i, 0)); - for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); - } - int64_t v = std::abs(atEq(i, numCols - 1)); - if (gcd > 0 && (v % gcd != 0)) { - return true; - } - } - return false; -} - -/// Tightens inequalities given that we are dealing with integer spaces. This is -/// analogous to the GCD test but applied to inequalities. The constant term can -/// be reduced to the preceding multiple of the GCD of the coefficients, i.e., -/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a -/// fast method - linear in the number of coefficients. -// Example on how this affects practical cases: consider the scenario: -// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield -// j >= 100 instead of the tighter (exact) j >= 128. -void FlatAffineConstraints::GCDTightenInequalities() { - unsigned numCols = getNumCols(); - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - uint64_t gcd = std::abs(atIneq(i, 0)); - for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); - } - if (gcd > 0) { - int64_t gcdI = static_cast(gcd); - atIneq(i, numCols - 1) = - gcdI * mlir::floorDiv(atIneq(i, numCols - 1), gcdI); - } - } -} - -// Eliminates all identifer variables in column range [posStart, posLimit). -// Returns the number of variables eliminated. -unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, - unsigned posLimit) { - // Return if identifier positions to eliminate are out of range. - assert(posLimit <= numIds); - assert(hasConsistentState()); - - if (posStart >= posLimit) - return 0; - - LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", " - << posLimit << ")\n"); - - GCDTightenInequalities(); - - unsigned pivotCol = 0; - for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { - // Find a row which has a non-zero coefficient in column 'j'. - unsigned pivotRow; - if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, - &pivotRow)) { - // No pivot row in equalities with non-zero at 'pivotCol'. - if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, - &pivotRow)) { - // If inequalities are also non-zero in 'pivotCol', it can be - // eliminated. - continue; - } - break; - } - - // Eliminate identifier at 'pivotCol' from each equality row. - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, - /*isEq=*/true); - normalizeConstraintByGCD(this, i); - } - - // Eliminate identifier at 'pivotCol' from each inequality row. - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, - /*isEq=*/false); - normalizeConstraintByGCD(this, i); - } - removeEquality(pivotRow); - } - // Update position limit based on number eliminated. - posLimit = pivotCol; - // Remove eliminated columns from all constraints. - removeIdRange(posStart, posLimit); - return posLimit - posStart; -} - -// Detect the identifier at 'pos' (say id_r) as modulo of another identifier -// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) -// could be detected as the floordiv of n. 
For eg: -// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> -// id_r = id_n mod 4, id_q = id_n floordiv 4. -// lbConst and ubConst are the constant lower and upper bounds for 'pos' - -// pre-detected at the caller. -static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, - int64_t lbConst, int64_t ubConst, - SmallVectorImpl *memo) { - assert(pos < cst.getNumIds() && "invalid position"); - - // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to - // id_n - divisor * id_q. If these are true, then id_n becomes the dividend - // and id_q the quotient when dividing id_n by the divisor. - - if (lbConst != 0 || ubConst < 1) - return false; - - int64_t divisor = ubConst + 1; - - // Now check for: id_r = id_n - divisor * id_q. As an example, we - // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. - unsigned seenQuotient = 0, seenDividend = 0; - int quotientPos = -1, dividendPos = -1; - for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - // id_n should have coeff 1 or -1. - if (std::abs(cst.atEq(r, pos)) != 1) - continue; - for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { - // The coeff of the quotient should be -divisor if the coefficient of - // the pos^th identifier is -1, and divisor if the latter is -1. - if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) { - seenQuotient++; - quotientPos = c; - } else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) { - seenDividend++; - dividendPos = c; - } - } - // We are looking for exactly one identifier as part of the dividend. - // TODO(bondhugula): could be extended to cover multiple ones in the - // dividend to detect mod of an affine function of identifiers. - if (seenDividend == 1 && seenQuotient >= 1) { - if (!(*memo)[dividendPos]) - return false; - // Successfully detected a mod. - (*memo)[pos] = (*memo)[dividendPos] % divisor; - if (seenQuotient == 1 && !(*memo)[quotientPos]) - // Successfully detected a floordiv as well. - (*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor); - return true; - } - } - return false; -} - -// Check if the pos^th identifier can be expressed as a floordiv of an affine -// function of other identifiers (where the divisor is a positive constant). -// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. -bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, - SmallVectorImpl *memo, MLIRContext *context) { - assert(pos < cst.getNumIds() && "invalid position"); - SmallVector lbIndices, ubIndices; - - // Gather all lower bounds and upper bound constraints of this identifier. - // Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint - // is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { - if (cst.atIneq(r, pos) >= 1) - // Lower bound. - lbIndices.push_back(r); - else if (cst.atIneq(r, pos) <= -1) - // Upper bound. - ubIndices.push_back(r); - } - - // Check if any lower bound, upper bound pair is of the form: - // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' - // divisor * id <= expr <-- Upper bound for 'id' - // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). - // - // For example, if -32*k + 16*i + j >= 0 - // 32*k - 16*i - j + 31 >= 0 <=> - // k = ( 16*i + j ) floordiv 32 - unsigned seenDividends = 0; - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - // Check if lower bound's constant term is 'divisor - 1'. 
The 'divisor' - // here is cst.atIneq(lbPos, pos) and we already know that it's positive - // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. - if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) - continue; - // Check if upper bound's constant term is 0. - if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) - continue; - // For the remaining part, check if the lower bound expr's coeff's are - // negations of corresponding upper bound ones'. - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) - break; - if (c != pos && cst.atIneq(lbPos, c) != 0) - seenDividends++; - } - // Lb coeff's aren't negative of ub coeff's (for the non constant term - // part). - if (c < f) - continue; - if (seenDividends >= 1) { - // The divisor is the constant term of the lower bound expression. - // We already know that cst.atIneq(lbPos, pos) > 0. - int64_t divisor = cst.atIneq(lbPos, pos); - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(0, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - if (!(*memo)[c]) - break; - dividendExpr = dividendExpr + ubVal * (*memo)[c]; - } - // Expression can't be constructed as it depends on a yet unknown - // identifier. - // TODO(mlir-team): Visit/compute the identifiers in an order so that - // this doesn't happen. More complex but much more efficient. - if (c < f) - continue; - // Successfully detected the floordiv. - (*memo)[pos] = dividendExpr.floorDiv(divisor); - return true; - } - } - } - return false; -} - -/// Computes the lower and upper bounds of the first 'num' dimensional -/// identifiers as affine maps of the remaining identifiers (dimensional and -/// symbolic identifiers). Local identifiers are themselves explicitly computed -/// as affine functions of other identifiers in this process if needed. -void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, - SmallVectorImpl *lbMaps, - SmallVectorImpl *ubMaps) { - assert(num < getNumDimIds() && "invalid range"); - - // Basic simplification. - normalizeConstraintsByGCD(); - - LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n"); - LLVM_DEBUG(dump()); - - // Record computed/detected identifiers. - SmallVector memo(getNumIds(), AffineExpr::Null()); - // Initialize dimensional and symbolic identifiers. - for (unsigned i = num, e = getNumDimIds(); i < e; i++) - memo[i] = getAffineDimExpr(i - num, context); - for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) - memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); - - bool changed; - do { - changed = false; - // Identify yet unknown identifiers as constants or mod's / floordiv's of - // other identifiers if possible. - for (unsigned pos = 0; pos < getNumIds(); pos++) { - if (memo[pos]) - continue; - - auto lbConst = getConstantLowerBound(pos); - auto ubConst = getConstantUpperBound(pos); - if (lbConst.hasValue() && ubConst.hasValue()) { - // Detect equality to a constant. - if (lbConst.getValue() == ubConst.getValue()) { - memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); - changed = true; - continue; - } - - // Detect an identifier as modulo of another identifier w.r.t a - // constant. 
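-        // For example (illustrative identifiers): with i - 32*q - r = 0 and
-        // 0 <= r <= 31, 'r' is recovered here as i mod 32 (and 'q' as
-        // i floordiv 32).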
- if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), - &memo)) { - changed = true; - continue; - } - } - - // Detect an identifier as floordiv of another identifier w.r.t a - // constant. - if (detectAsFloorDiv(*this, pos, &memo, context)) { - changed = true; - continue; - } - - // Detect an identifier as an expression of other identifiers. - unsigned idx; - if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { - continue; - } - - // Build AffineExpr solving for identifier 'pos' in terms of all others. - auto expr = getAffineConstantExpr(0, context); - unsigned j, e; - for (j = 0, e = getNumIds(); j < e; ++j) { - if (j == pos) - continue; - int64_t c = atEq(idx, j); - if (c == 0) - continue; - // If any of the involved IDs hasn't been found yet, we can't proceed. - if (!memo[j]) - break; - expr = expr + memo[j] * c; - } - if (j < e) - // Can't construct expression as it depends on a yet uncomputed - // identifier. - continue; - - // Add constant term to AffineExpr. - expr = expr + atEq(idx, getNumIds()); - int64_t vPos = atEq(idx, pos); - assert(vPos != 0 && "expected non-zero here"); - if (vPos > 0) - expr = (-expr).floorDiv(vPos); - else - // vPos < 0. - expr = expr.floorDiv(-vPos); - // Successfully constructed expression. - memo[pos] = expr; - changed = true; - } - // This loop is guaranteed to reach a fixed point - since once an - // identifier's explicit form is computed (in memo[pos]), it's not updated - // again. - } while (changed); - - // Set the lower and upper bound maps for all the identifiers that were - // computed as affine expressions of the rest as the "detected expr" and - // "detected expr + 1" respectively; set the undetected ones to Null(). - for (unsigned pos = 0; pos < num; pos++) { - unsigned numMapDims = getNumDimIds() - num; - unsigned numMapSymbols = getNumSymbolIds(); - AffineExpr expr = memo[pos]; - if (expr) - expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); - - if (expr) { - (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); - (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); - } else { - // TODO(andydavis, bondhugula) Add support for computing slice bounds - // symbolic in the identifies [num, numIds). 
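// Until then, an identifier whose expression was not detected falls back to
// its constant bounds when available: e.g. if all that is known is
// 0 <= id <= 99, the lower bound map below is the constant 0 and the upper
// bound map the constant 100, keeping the upper bound exclusive just like
// the 'expr' / 'expr + 1' maps set above.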
- auto lbConst = getConstantLowerBound(pos); - auto ubConst = getConstantUpperBound(pos); - if (lbConst.hasValue() && ubConst.hasValue()) { - (*lbMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(lbConst.getValue(), context), {}); - (*ubMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(ubConst.getValue() + 1, context), {}); - } else { - (*lbMaps)[pos] = AffineMap(); - (*ubMaps)[pos] = AffineMap(); - } - } - LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG(expr.dump();); - } -} - -void FlatAffineConstraints::addEquality(ArrayRef eq) { - assert(eq.size() == getNumCols()); - unsigned offset = equalities.size(); - equalities.resize(equalities.size() + numReservedCols); - std::copy(eq.begin(), eq.end(), equalities.begin() + offset); -} - -void FlatAffineConstraints::addInequality(ArrayRef inEq) { - assert(inEq.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); -} - -void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { - assert(pos < getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - inequalities[offset + pos] = 1; - inequalities[offset + getNumCols() - 1] = -lb; -} - -void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { - assert(pos < getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - inequalities[offset + pos] = -1; - inequalities[offset + getNumCols() - 1] = ub; -} - -void FlatAffineConstraints::addConstantLowerBound(ArrayRef expr, - int64_t lb) { - assert(expr.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); - inequalities[offset + getNumCols() - 1] += -lb; -} - -void FlatAffineConstraints::addConstantUpperBound(ArrayRef expr, - int64_t ub) { - assert(expr.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - for (unsigned i = 0, e = getNumCols(); i < e; i++) { - inequalities[offset + i] = -expr[i]; - } - inequalities[offset + getNumCols() - 1] += ub; -} - -/// Adds a new local identifier as the floordiv of an affine function of other -/// identifiers, the coefficients of which are provided in 'dividend' and with -/// respect to a positive constant 'divisor'. Two constraints are added to the -/// system to capture equivalence with the floordiv. -/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. -void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, - int64_t divisor) { - assert(dividend.size() == getNumCols() && "incorrect dividend size"); - assert(divisor > 0 && "positive divisor expected"); - - addLocalId(getNumLocalIds()); - - // Add two constraints for this new identifier 'q'. 
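// For example, with dividend = d0 + d1 and divisor = 4, the two rows added
// below for the new local 'q' are
//   d0 + d1 - 4*q >= 0
//   -d0 - d1 + 4*q + 3 >= 0
// which together pin q to exactly (d0 + d1) floordiv 4.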
- SmallVector bound(dividend.size() + 1); - - // dividend - q * divisor >= 0 - std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, - bound.begin()); - bound.back() = dividend.back(); - bound[getNumIds() - 1] = -divisor; - addInequality(bound); - - // -dividend +qdivisor * q + divisor - 1 >= 0 - std::transform(bound.begin(), bound.end(), bound.begin(), - std::negate()); - bound[bound.size() - 1] += divisor - 1; - addInequality(bound); -} - -bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { - unsigned i = 0; - for (const auto &mayBeId : ids) { - if (mayBeId.hasValue() && mayBeId.getValue() == &id) { - *pos = i; - return true; - } - i++; - } - return false; -} - -void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { - assert(newSymbolCount <= numDims + numSymbols && - "invalid separation position"); - numDims = numDims + numSymbols - newSymbolCount; - numSymbols = newSymbolCount; -} - -bool FlatAffineConstraints::addAffineForOpDomain( - ConstOpPointer forOp) { - unsigned pos; - // Pre-condition for this method. - if (!findId(*forOp->getInductionVar(), &pos)) { - assert(0 && "Value not found"); - return false; - } - - if (forOp->getStep() != 1) - LLVM_DEBUG(llvm::dbgs() - << "Domain conservative: non-unit stride not handled\n"); - - // Adds a lower or upper bound when the bounds aren't constant. - auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = - lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); - for (const auto &operand : operands) { - unsigned loc; - if (!findId(*operand, &loc)) { - if (isValidSymbol(operand)) { - addSymbolId(getNumSymbolIds(), const_cast(operand)); - loc = getNumDimIds() + getNumSymbolIds() - 1; - // Check if the symbol is a constant. - if (auto *opInst = operand->getDefiningInst()) { - if (auto constOp = opInst->dyn_cast()) { - setIdToConstant(*operand, constOp->getValue()); - } - } - } else { - addDimId(getNumDimIds(), const_cast(operand)); - loc = getNumDimIds() - 1; - } - } - } - // Record positions of the operands in the constraint system. - SmallVector positions; - for (const auto &operand : operands) { - unsigned loc; - if (!findId(*operand, &loc)) - assert(0 && "expected to be found"); - positions.push_back(loc); - } - - auto boundMap = - lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); - - FlatAffineConstraints localVarCst; - std::vector> flatExprs; - if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { - LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); - return false; - } - if (localVarCst.getNumLocalIds() > 0) { - LLVM_DEBUG(llvm::dbgs() - << "loop bounds with mod/floordiv expr's not yet supported\n"); - return false; - } - - for (const auto &flatExpr : flatExprs) { - SmallVector ineq(getNumCols(), 0); - ineq[pos] = lower ? 1 : -1; - for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { - ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; - } - // Constant term. - ineq[getNumCols() - 1] = - lower ? -flatExpr[flatExpr.size() - 1] - // Upper bound in flattenedExpr is an exclusive one. - : flatExpr[flatExpr.size() - 1] - 1; - addInequality(ineq); - } - return true; - }; - - if (forOp->hasConstantLowerBound()) { - addConstantLowerBound(pos, forOp->getConstantLowerBound()); - } else { - // Non-constant lower bound case. 
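// For example, if both bounds are non-constant, say lower bound %N and
// (exclusive) upper bound %N + 128, the two addLowerOrUpperBound() calls
// below add the rows
//   i - N >= 0
//   -i + N + 127 >= 0
// where the -1 folded into the upper bound's constant term accounts for the
// bound map being exclusive.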
- if (!addLowerOrUpperBound(/*lower=*/true)) - return false; - } - - if (forOp->hasConstantUpperBound()) { - addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); - return true; - } - // Non-constant upper bound case. - return addLowerOrUpperBound(/*lower=*/false); -} - -/// Sets the specified identifer to a constant value. -void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { - unsigned offset = equalities.size(); - equalities.resize(equalities.size() + numReservedCols); - std::fill(equalities.begin() + offset, - equalities.begin() + offset + getNumCols(), 0); - equalities[offset + pos] = 1; - equalities[offset + getNumCols() - 1] = -val; -} - -/// Sets the specified identifer to a constant value; asserts if the id is not -/// found. -void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) { - unsigned pos; - if (!findId(id, &pos)) - // This is a pre-condition for this method. - assert(0 && "id not found"); - setIdToConstant(pos, val); -} - -void FlatAffineConstraints::removeEquality(unsigned pos) { - unsigned numEqualities = getNumEqualities(); - assert(pos < numEqualities); - unsigned outputIndex = pos * numReservedCols; - unsigned inputIndex = (pos + 1) * numReservedCols; - unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols; - std::copy(equalities.begin() + inputIndex, - equalities.begin() + inputIndex + numElemsToCopy, - equalities.begin() + outputIndex); - equalities.resize(equalities.size() - numReservedCols); -} - -/// Finds an equality that equates the specified identifier to a constant. -/// Returns the position of the equality row. If 'symbolic' is set to true, -/// symbols are also treated like a constant, i.e., an affine function of the -/// symbols is also treated like a constant. -static int findEqualityToConstant(const FlatAffineConstraints &cst, - unsigned pos, bool symbolic = false) { - assert(pos < cst.getNumIds() && "invalid position"); - for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - int64_t v = cst.atEq(r, pos); - if (v * v != 1) - continue; - unsigned c; - unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds(); - // This checks for zeros in all positions other than 'pos' in [0, f) - for (c = 0; c < f; c++) { - if (c == pos) - continue; - if (cst.atEq(r, c) != 0) { - // Dependent on another identifier. - break; - } - } - if (c == f) - // Equality is free of other identifiers. - return r; - } - return -1; -} - -void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { - assert(pos < getNumIds() && "invalid position"); - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; - } - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; - } - removeId(pos); -} - -bool FlatAffineConstraints::constantFoldId(unsigned pos) { - assert(pos < getNumIds() && "invalid position"); - int rowIdx; - if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) - return false; - - // atEq(rowIdx, pos) is either -1 or 1. 
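// For example, the equality row -x + 42 = 0 (or equally x - 42 = 0) folds x
// to 42: constVal = -42 / -1 = 42, after which setAndEliminate() substitutes
// the value into every constraint and removes the column for x.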
- assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); - int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); - setAndEliminate(pos, constVal); - return true; -} - -void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { - for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { - if (!constantFoldId(t)) - t++; - } -} - -/// Returns the extent (upper bound - lower bound) of the specified -/// identifier if it is found to be a constant; returns None if it's not a -/// constant. This methods treats symbolic identifiers specially, i.e., -/// it looks for constant differences between affine expressions involving -/// only the symbolic identifiers. See comments at function definition for -/// example. 'lb', if provided, is set to the lower bound associated with the -/// constant difference. Note that 'lb' is purely symbolic and thus will contain -/// the coefficients of the symbolic identifiers and the constant coefficient. -// Egs: 0 <= i <= 15, return 16. -// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) -// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. -// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = -// ceil(s0 - 7 / 8) = floor(s0 / 8)). -Optional FlatAffineConstraints::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor) const { - assert(pos < getNumDimIds() && "Invalid identifier position"); - assert(getNumLocalIds() == 0); - - // TODO(bondhugula): eliminate all remaining dimensional identifiers (other - // than the one at 'pos' to make this more powerful. Not needed for - // hyper-rectangular spaces. - - // Find an equality for 'pos'^th identifier that equates it to some function - // of the symbolic identifiers (+ constant). - int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); - if (eqRow != -1) { - // This identifier can only take a single value. - if (lb) { - // Set lb to the symbolic value. - lb->resize(getNumSymbolIds() + 1); - for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { - int64_t v = atEq(eqRow, pos); - // atEq(eqRow, pos) is either -1 or 1. - assert(v * v == 1); - (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v - : -atEq(eqRow, getNumDimIds() + c) / v; - } - assert(lbFloorDivisor && - "both lb and divisor or none should be provided"); - *lbFloorDivisor = 1; - } - return 1; - } - - // Check if the identifier appears at all in any of the inequalities. - unsigned r, e; - for (r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, pos) != 0) - break; - } - if (r == e) - // If it doesn't, there isn't a bound on it. - return None; - - // Positions of constraints that are lower/upper bounds on the variable. - SmallVector lbIndices, ubIndices; - - // Gather all symbolic lower bounds and upper bounds of the variable. Since - // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a - // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - unsigned c, f; - for (c = 0, f = getNumDimIds(); c < f; c++) { - if (c != pos && atIneq(r, c) != 0) - break; - } - if (c < getNumDimIds()) - continue; - if (atIneq(r, pos) >= 1) - // Lower bound. - lbIndices.push_back(r); - else if (atIneq(r, pos) <= -1) - // Upper bound. - ubIndices.push_back(r); - } - - // TODO(bondhugula): eliminate other dimensional identifiers to make this more - // powerful. Not needed for hyper-rectangular iteration spaces. 
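// For example, with a single dimension d0 and symbol s0 (rows laid out as
// [d0, s0, const]), the lower bound d0 - s0 >= 0 is the row [1, -1, 0] and
// the upper bound -d0 + s0 + 15 >= 0 is [-1, 1, 15]; their non-constant
// coefficients are exact negations, so the search below pairs them and
// returns diff = floordiv(15 + 0 + 1, 1) = 16 as the extent of d0, with
// lb = s0 and *lbFloorDivisor = 1.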
- - Optional minDiff = None; - unsigned minLbPosition; - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - // Look for a lower bound and an upper bound that only differ by a - // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. - // For example, if ii is the pos^th variable, we are looking for - // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The - // minimum among all such constant differences is kept since that's the - // constant bounding the extent of the pos^th variable. - unsigned j, e; - for (j = 0, e = getNumCols() - 1; j < e; j++) - if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { - break; - } - if (j < getNumCols() - 1) - continue; - int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) + - atIneq(lbPos, getNumCols() - 1) + 1, - atIneq(lbPos, pos)); - if (minDiff == None || diff < minDiff) { - minDiff = diff; - minLbPosition = lbPos; - } - } - } - if (lb && minDiff.hasValue()) { - // Set lb to the symbolic lower bound. - lb->resize(getNumSymbolIds() + 1); - // The lower bound is the ceildiv of the lb constraint over the coefficient - // of the variable at 'pos'. We express the ceildiv equivalently as a floor - // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + - // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). - *lbFloorDivisor = atIneq(minLbPosition, pos); - for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { - // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of - // 'atIneq(minLbPosition, pos) - 1'. - (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c) + - atIneq(minLbPosition, pos) - 1; - } - } - return minDiff; -} - -template -Optional -FlatAffineConstraints::getConstantLowerOrUpperBound(unsigned pos) const { - // Check if there's an equality equating the 'pos'^th identifier to a - // constant. - int eqRowIdx = findEqualityToConstant(*this, pos, /*symbolic=*/false); - if (eqRowIdx != -1) - // atEq(rowIdx, pos) is either -1 or 1. - return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, pos); - - // Check if the identifier appears at all in any of the inequalities. - unsigned r, e; - for (r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, pos) != 0) - break; - } - if (r == e) - // If it doesn't, there isn't a bound on it. - return None; - - Optional minOrMaxConst = None; - - // Take the max across all const lower bounds (or min across all constant - // upper bounds). - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (isLower) { - if (atIneq(r, pos) <= 0) - // Not a lower bound. - continue; - } else if (atIneq(r, pos) >= 0) { - // Not an upper bound. - continue; - } - unsigned c, f; - for (c = 0, f = getNumCols() - 1; c < f; c++) - if (c != pos && atIneq(r, c) != 0) - break; - if (c < getNumCols() - 1) - // Not a constant bound. - continue; - - int64_t boundConst = - isLower ? 
mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos)) - : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos)); - if (isLower) { - if (minOrMaxConst == None || boundConst > minOrMaxConst) - minOrMaxConst = boundConst; - } else { - if (minOrMaxConst == None || boundConst < minOrMaxConst) - minOrMaxConst = boundConst; - } - } - return minOrMaxConst; -} - -Optional -FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { - return getConstantLowerOrUpperBound(pos); -} - -Optional -FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { - return getConstantLowerOrUpperBound(pos); -} - -// A simple (naive and conservative) check for hyper-rectangularlity. -bool FlatAffineConstraints::isHyperRectangular(unsigned pos, - unsigned num) const { - assert(pos < getNumCols() - 1); - // Check for two non-zero coefficients in the range [pos, pos + sum). - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - unsigned sum = 0; - for (unsigned c = pos; c < pos + num; c++) { - if (atIneq(r, c) != 0) - sum++; - } - if (sum > 1) - return false; - } - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - unsigned sum = 0; - for (unsigned c = pos; c < pos + num; c++) { - if (atEq(r, c) != 0) - sum++; - } - if (sum > 1) - return false; - } - return true; -} - -void FlatAffineConstraints::print(raw_ostream &os) const { - assert(hasConsistentState()); - os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() - << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() - << " constraints)\n"; - os << "("; - for (unsigned i = 0, e = getNumIds(); i < e; i++) { - if (ids[i] == None) - os << "None "; - else - os << "Value "; - } - os << " const)\n"; - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - os << atEq(i, j) << " "; - } - os << "= 0\n"; - } - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - os << atIneq(i, j) << " "; - } - os << ">= 0\n"; - } - os << '\n'; -} - -void FlatAffineConstraints::dump() const { print(llvm::errs()); } - -/// Removes duplicate constraints and trivially true constraints: a constraint -/// of the form >= 0 is considered a trivially true -/// constraint. -// Uses a DenseSet to hash and detect duplicates followed by a linear scan to -// remove duplicates in place. -void FlatAffineConstraints::removeTrivialRedundancy() { - DenseSet> rowSet; - - // Check if constraint is of the form >= 0. - auto isTriviallyValid = [&](unsigned r) -> bool { - for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { - if (atIneq(r, c) != 0) - return false; - } - return atIneq(r, getNumCols() - 1) >= 0; - }; - - // Detect and mark redundant constraints. - std::vector redunIneq(getNumInequalities(), false); - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - int64_t *rowStart = inequalities.data() + numReservedCols * r; - auto row = ArrayRef(rowStart, getNumCols()); - if (isTriviallyValid(r) || !rowSet.insert(row).second) { - redunIneq[r] = true; - } - } - - auto copyRow = [&](unsigned src, unsigned dest) { - if (src == dest) - return; - for (unsigned c = 0, e = getNumCols(); c < e; c++) { - atIneq(dest, c) = atIneq(src, c); - } - }; - - // Scan to get rid of all rows marked redundant, in-place. 
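// For example, if rows 1 and 3 of four inequalities were marked redundant
// above (say a duplicate of row 0 and a trivially true '5 >= 0' row), rows 0
// and 2 are compacted into slots 0 and 1 and the storage is shrunk to two
// rows.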
- unsigned pos = 0; - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (!redunIneq[r]) - copyRow(r, pos++); - } - inequalities.resize(numReservedCols * pos); - - // TODO(bondhugula): consider doing this for equalities as well, but probably - // not worth the savings. -} - -void FlatAffineConstraints::clearAndCopyFrom( - const FlatAffineConstraints &other) { - FlatAffineConstraints copy(other); - std::swap(*this, copy); - assert(copy.getNumIds() == copy.getIds().size()); -} - -void FlatAffineConstraints::removeId(unsigned pos) { - removeIdRange(pos, pos + 1); -} - -static std::pair -getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { - unsigned numDims = cst.getNumDimIds(); - unsigned numSymbols = cst.getNumSymbolIds(); - unsigned newNumDims, newNumSymbols; - if (pos < numDims) { - newNumDims = numDims - 1; - newNumSymbols = numSymbols; - } else if (pos < numDims + numSymbols) { - assert(numSymbols >= 1); - newNumDims = numDims; - newNumSymbols = numSymbols - 1; - } else { - newNumDims = numDims; - newNumSymbols = numSymbols; - } - return {newNumDims, newNumSymbols}; -} - -#undef DEBUG_TYPE -#define DEBUG_TYPE "fm" - -/// Eliminates identifier at the specified position using Fourier-Motzkin -/// variable elimination. This technique is exact for rational spaces but -/// conservative (in "rare" cases) for integer spaces. The operation corresponds -/// to a projection operation yielding the (convex) set of integer points -/// contained in the rational shadow of the set. An emptiness test that relies -/// on this method will guarantee emptiness, i.e., it disproves the existence of -/// a solution if it says it's empty. -/// If a non-null isResultIntegerExact is passed, it is set to true if the -/// result is also integer exact. If it's set to false, the obtained solution -/// *may* not be exact, i.e., it may contain integer points that do not have an -/// integer pre-image in the original set. -/// -/// Eg: -/// j >= 0, j <= i + 1 -/// i >= 0, i <= N + 1 -/// Eliminating i yields, -/// j >= 0, 0 <= N + 1, j - 1 <= N + 1 -/// -/// If darkShadow = true, this method computes the dark shadow on elimination; -/// the dark shadow is a convex integer subset of the exact integer shadow. A -/// non-empty dark shadow proves the existence of an integer solution. The -/// elimination in such a case could however be an under-approximation, and thus -/// should not be used for scanning sets or used by itself for dependence -/// checking. -/// -/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. -/// ^ -/// | -/// | * * * * o o -/// i | * * o o o o -/// | o * * * * * -/// ---------------> -/// j -> -/// -/// Eliminating i from this system (projecting on the j dimension): -/// rational shadow / integer light shadow: 1 <= j <= 6 -/// dark shadow: 3 <= j <= 6 -/// exact integer shadow: j = 1 \union 3 <= j <= 6 -/// holes/splinters: j = 2 -/// -/// darkShadow = false, isResultIntegerExact = nullptr are default values. -// TODO(bondhugula): a slight modification to yield dark shadow version of FM -// (tightened), which can prove the existence of a solution if there is one. -void FlatAffineConstraints::FourierMotzkinEliminate( - unsigned pos, bool darkShadow, bool *isResultIntegerExact) { - LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); - LLVM_DEBUG(dump()); - assert(pos < getNumIds() && "invalid position"); - assert(hasConsistentState()); - - // Check if this identifier can be eliminated through a substitution. 
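// For example, if the system contains the equality x - y - 1 = 0 and x is
// being eliminated, no Fourier-Motzkin step is needed: the Gaussian
// elimination below substitutes x = y + 1 into the remaining constraints,
// which is both cheaper and exact.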
- for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - if (atEq(r, pos) != 0) { - // Use Gaussian elimination here (since we have an equality). - bool ret = gaussianEliminateId(pos); - (void)ret; - assert(ret && "Gaussian elimination guaranteed to succeed"); - LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); - LLVM_DEBUG(dump()); - return; - } - } - - // A fast linear time tightening. - GCDTightenInequalities(); - - // Check if the identifier appears at all in any of the inequalities. - unsigned r, e; - for (r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, pos) != 0) - break; - } - if (r == getNumInequalities()) { - // If it doesn't appear, just remove the column and return. - // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. - removeId(pos); - LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); - LLVM_DEBUG(dump()); - return; - } - - // Positions of constraints that are lower bounds on the variable. - SmallVector lbIndices; - // Positions of constraints that are lower bounds on the variable. - SmallVector ubIndices; - // Positions of constraints that do not involve the variable. - std::vector nbIndices; - nbIndices.reserve(getNumInequalities()); - - // Gather all lower bounds and upper bounds of the variable. Since the - // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower - // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, pos) == 0) { - // Id does not appear in bound. - nbIndices.push_back(r); - } else if (atIneq(r, pos) >= 1) { - // Lower bound. - lbIndices.push_back(r); - } else { - // Upper bound. - ubIndices.push_back(r); - } - } - - // Set the number of dimensions, symbols in the resulting system. - const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this); - unsigned newNumDims = dimsSymbols.first; - unsigned newNumSymbols = dimsSymbols.second; - - SmallVector, 8> newIds; - newIds.reserve(numIds - 1); - newIds.append(ids.begin(), ids.begin() + pos); - newIds.append(ids.begin() + pos + 1, ids.end()); - - /// Create the new system which has one identifier less. - FlatAffineConstraints newFac( - lbIndices.size() * ubIndices.size() + nbIndices.size(), - getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, - /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds); - - assert(newFac.getIds().size() == newFac.getNumIds()); - - // This will be used to check if the elimination was integer exact. - unsigned lcmProducts = 1; - - // Let x be the variable we are eliminating. - // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note - // that c_l, c_u >= 1) we have: - // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u - // We thus generate a constraint: - // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. - // Note if c_l = c_u = 1, all integer points captured by the resulting - // constraint correspond to integer points in the original system (i.e., they - // have integer pre-images). Hence, if the lcm's are all 1, the elimination is - // integer exact. - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - SmallVector ineq; - ineq.reserve(newFac.getNumCols()); - int64_t lbCoeff = atIneq(lbPos, pos); - // Note that in the comments above, ubCoeff is the negation of the - // coefficient in the canonical form as the view taken here is that of the - // term being moved to the other size of '>='. 
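// Worked example for the combination below: eliminating x from the lower
// bound 3*x - b >= 0 (c_l = 3) and the upper bound -2*x + a >= 0, i.e.
// 2*x <= a (c_u = 2), uses lcm(3, 2) = 6 and yields
//   3*a - 2*b >= 0,
// the shadow constraint 2*b <= 3*a. Because this lcm is not 1, the new
// constraint may admit integer points with no integer pre-image, which is
// exactly what 'lcmProducts' tracks.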
- int64_t ubCoeff = -atIneq(ubPos, pos); - // TODO(bondhugula): refactor this loop to avoid all branches inside. - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); - int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); - ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + - atIneq(lbPos, l) * (lcm / lbCoeff)); - lcmProducts *= lcm; - } - if (darkShadow) { - // The dark shadow is a convex subset of the exact integer shadow. If - // there is a point here, it proves the existence of a solution. - ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; - } - // TODO: we need to have a way to add inequalities in-place in - // FlatAffineConstraints instead of creating and copying over. - newFac.addInequality(ineq); - } - } - - if (lcmProducts == 1 && isResultIntegerExact) - *isResultIntegerExact = 1; - - // Copy over the constraints not involving this variable. - for (auto nbPos : nbIndices) { - SmallVector ineq; - ineq.reserve(getNumCols() - 1); - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - ineq.push_back(atIneq(nbPos, l)); - } - newFac.addInequality(ineq); - } - - assert(newFac.getNumConstraints() == - lbIndices.size() * ubIndices.size() + nbIndices.size()); - - // Copy over the equalities. - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - SmallVector eq; - eq.reserve(newFac.getNumCols()); - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - eq.push_back(atEq(r, l)); - } - newFac.addEquality(eq); - } - - newFac.removeTrivialRedundancy(); - clearAndCopyFrom(newFac); - LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); - LLVM_DEBUG(dump()); -} - -#undef DEBUG_TYPE -#define DEBUG_TYPE "affine-structures" - -void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { - if (num == 0) - return; - - // 'pos' can be at most getNumCols() - 2 if num > 0. - assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position"); - assert(pos + num < getNumCols() && "invalid range"); - - // Eliminate as many identifiers as possible using Gaussian elimination. - unsigned currentPos = pos; - unsigned numToEliminate = num; - unsigned numGaussianEliminated = 0; - - while (currentPos < getNumIds()) { - unsigned curNumEliminated = - gaussianEliminateIds(currentPos, currentPos + numToEliminate); - ++currentPos; - numToEliminate -= curNumEliminated + 1; - numGaussianEliminated += curNumEliminated; - } - - // Eliminate the remaining using Fourier-Motzkin. - for (unsigned i = 0; i < num - numGaussianEliminated; i++) { - unsigned numToEliminate = num - numGaussianEliminated - i; - FourierMotzkinEliminate( - getBestIdToEliminate(*this, pos, pos + numToEliminate)); - } - - // Fast/trivial simplifications. - GCDTightenInequalities(); - // Normalize constraints after tightening since the latter impacts this, but - // not the other way round. - normalizeConstraintsByGCD(); -} - -void FlatAffineConstraints::projectOut(Value *id) { - unsigned pos; - bool ret = findId(*id, &pos); - assert(ret); - (void)ret; - FourierMotzkinEliminate(pos); -} - -bool FlatAffineConstraints::isRangeOneToOne(unsigned start, - unsigned limit) const { - assert(start <= getNumIds() - 1 && "invalid start position"); - assert(limit > start && limit <= getNumIds() && "invalid limit"); - - FlatAffineConstraints tmpCst(*this); - - if (start != 0) { - // Move [start, limit) to the left. 
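// For example, with identifier columns (d0, d1, d2, d3) and start = 1,
// limit = 3, the loops below rearrange the columns to (d1, d2, d0, d3): the
// range being tested lands in the leading positions, and everything else is
// shifted after it so that it can be re-marked as symbolic further down.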
- for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { - for (unsigned c = 0, f = getNumCols(); c < f; ++c) { - if (c >= start && c < limit) - tmpCst.atIneq(r, c - start) = atIneq(r, c); - else if (c < start) - tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); - else - tmpCst.atIneq(r, c) = atIneq(r, c); - } - } - for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { - for (unsigned c = 0, f = getNumCols(); c < f; ++c) { - if (c >= start && c < limit) - tmpCst.atEq(r, c - start) = atEq(r, c); - else if (c < start) - tmpCst.atEq(r, c + limit - start) = atEq(r, c); - else - tmpCst.atEq(r, c) = atEq(r, c); - } - } - } - - // Mark everything to the right as symbols so that we can check the extents in - // a symbolic way below. - tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); - - // Check if the extents of all the specified dimensions are just one (when - // treating the rest as symbols). - for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { - auto extent = tmpCst.getConstantBoundOnDimSize(pos); - if (!extent.hasValue() || extent.getValue() != 1) - return false; - } - return true; -} - -void FlatAffineConstraints::clearConstraints() { - equalities.clear(); - inequalities.clear(); -} - -namespace { - -enum BoundCmpResult { Greater, Less, Equal, Unknown }; - -/// Compares two affine bounds whose coefficients are provided in 'first' and -/// 'second'. The last coefficient is the constant term. -static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { - assert(a.size() == b.size()); - - // For the bounds to be comparable, their corresponding identifier - // coefficients should be equal; the constant terms are then compared to - // determine less/greater/equal. - - if (!std::equal(a.begin(), a.end() - 1, b.begin())) - return Unknown; - - if (a.back() == b.back()) - return Equal; - - return a.back() < b.back() ? Less : Greater; -} -}; // namespace - -// Compute the bounding box with respect to 'other' by finding the min of the -// lower bounds and the max of the upper bounds along each of the dimensions. -bool FlatAffineConstraints::unionBoundingBox( - const FlatAffineConstraints &other) { - assert(other.getNumDimIds() == numDims); - assert(other.getNumSymbolIds() == getNumSymbolIds()); - assert(other.getNumLocalIds() == 0); - assert(getNumLocalIds() == 0); - std::vector> boundingLbs; - std::vector> boundingUbs; - boundingLbs.reserve(2 * getNumDimIds()); - boundingUbs.reserve(2 * getNumDimIds()); - - SmallVector lb, otherLb; - lb.reserve(getNumSymbolIds() + 1); - otherLb.reserve(getNumSymbolIds() + 1); - int64_t lbDivisor, otherLbDivisor; - for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { - lb.clear(); - auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor); - if (!extent.hasValue()) - // TODO(bondhugula): symbolic extents when necessary. - return false; - - otherLb.clear(); - auto otherExtent = - other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor); - if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor) - // TODO(bondhugula): symbolic extents when necessary. - return false; - - assert(lbDivisor > 0 && "divisor always expected to be positive"); - - // Compute min of lower bounds and max of upper bounds. - ArrayRef minLb, maxUb; - - auto res = compareBounds(lb, otherLb); - // Identify min. - if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { - minLb = lb; - } else if (res == BoundCmpResult::Greater) { - minLb = otherLb; - } else { - // Uncomparable. 
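// For example, lower bounds s0 + 2 and 2*s0 have differing symbol
// coefficients, so compareBounds() returns Unknown and the bounding-box
// computation conservatively gives up.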
- return false; - } - - // Do the same for ub's but max of upper bounds. - SmallVector ub(lb), otherUb(otherLb); - ub.back() += extent.getValue() - 1; - otherUb.back() += otherExtent.getValue() - 1; - - // Identify max. - auto uRes = compareBounds(ub, otherUb); - if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { - maxUb = ub; - } else if (uRes == BoundCmpResult::Less) { - maxUb = otherUb; - } else { - // Uncomparable. - return false; - } - - SmallVector newLb(getNumCols(), 0); - SmallVector newUb(getNumCols(), 0); - - // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, - // and so it's the divisor for newLb and newUb as well. - newLb[d] = lbDivisor; - newUb[d] = -lbDivisor; - // Copy over the symbolic part + constant term. - std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); - std::transform(newLb.begin() + getNumDimIds(), newLb.end(), - newLb.begin() + getNumDimIds(), std::negate()); - std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); - - boundingLbs.push_back(newLb); - boundingUbs.push_back(newUb); - } - - // Clear all constraints and add the lower/upper bounds for the bounding box. - clearConstraints(); - for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { - addInequality(boundingLbs[d]); - addInequality(boundingUbs[d]); - } - - return true; -} diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 4ded1bfc400..ce32806cb70 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -23,9 +23,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" @@ -147,7 +147,8 @@ bool mlir::isAccessInvariant(const Value &iv, const Value &index) { auto composeOp = affineApplyOps[0]->cast(); // We need yet another level of indirection because the `dim` index of the // access may not correspond to the `dim` index of composeOp. 
- return !AffineValueMap(*composeOp).isFunctionOf(0, const_cast(&iv)); + return !composeOp->getAsAffineValueMap().isFunctionOf( + 0, const_cast(&iv)); } llvm::DenseSet diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 3376cd7d512..3482f24dfcc 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -21,9 +21,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass.h" diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 9ec1c95f213..3f26ae5d2fe 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -20,9 +20,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 9f5fd65e774..892701863eb 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -24,7 +24,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/StandardOps/StandardOps.h" @@ -163,7 +163,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!cst.addAffineForOpDomain(loop)) + if (!addAffineForOpDomain(loop, &cst)) return false; } else { // Has to be a valid symbol. diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 39d24b339f4..c029ef3df52 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -19,6 +19,8 @@ #include "AffineExprDetail.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AffineStructures.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/STLExtras.h" @@ -293,3 +295,446 @@ raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) { expr.print(os); return os; } + +/// Constructs an affine expression from a flat ArrayRef. If there are local +/// identifiers (neither dimensional nor symbolic) that appear in the sum of +/// products expression, 'localExprs' is expected to have the AffineExpr +/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the +/// format [dims, symbols, locals, constant term]. +// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here. 
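// For example, with numDims = 1, numSymbols = 1, flat coefficients
// [2, 3, 1, 5] and localExprs = {d0 floordiv 4}, the reconstructed
// expression is 2*d0 + 3*s0 + (d0 floordiv 4) + 5.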
+static AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, + unsigned numSymbols, + ArrayRef localExprs, + MLIRContext *context) { + // Assert expected numLocals = eq.size() - numDims - numSymbols - 1 + assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() && + "unexpected number of local expressions"); + + auto expr = getAffineConstantExpr(0, context); + // Dimensions and symbols. + for (unsigned j = 0; j < numDims + numSymbols; j++) { + if (eq[j] == 0) { + continue; + } + auto id = j < numDims ? getAffineDimExpr(j, context) + : getAffineSymbolExpr(j - numDims, context); + expr = expr + id * eq[j]; + } + + // Local identifiers. + for (unsigned j = numDims + numSymbols, e = eq.size() - 1; j < e; j++) { + if (eq[j] == 0) { + continue; + } + auto term = localExprs[j - numDims - numSymbols] * eq[j]; + expr = expr + term; + } + + // Constant term. + int64_t constTerm = eq[eq.size() - 1]; + if (constTerm != 0) + expr = expr + constTerm; + return expr; +} + +namespace { + +// This class is used to flatten a pure affine expression (AffineExpr, +// which is in a tree form) into a sum of products (w.r.t constants) when +// possible, and in that process simplifying the expression. For a modulo, +// floordiv, or a ceildiv expression, an additional identifier, called a local +// identifier, is introduced to rewrite the expression as a sum of product +// affine expression. Each local identifier is always and by construction a +// floordiv of a pure add/mul affine function of dimensional, symbolic, and +// other local identifiers, in a non-mutually recursive way. Hence, every local +// identifier can ultimately always be recovered as an affine function of +// dimensional and symbolic identifiers (involving floordiv's); note however +// that by AffineExpr construction, some floordiv combinations are converted to +// mod's. The result of the flattening is a flattened expression and a set of +// constraints involving just the local variables. +// +// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local +// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. +// +// The simplification performed includes the accumulation of contributions for +// each dimensional and symbolic identifier together, the simplification of +// floordiv/ceildiv/mod expressions and other simplifications that in turn +// happen as a result. A simplification that this flattening naturally performs +// is of simplifying the numerator and denominator of floordiv/ceildiv, and +// folding a modulo expression to a zero, if possible. Three examples are below: +// +// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 +// (d0 - d0 mod 4 + 4) mod 4 simplified to 0 +// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 +// +// The way the flattening works for the second example is as follows: d0 % 4 is +// replaced by d0 - 4*q with q being introduced: the expression then simplifies +// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to +// zero. Note that an affine expression may not always be expressible purely as +// a sum of products involving just the original dimensional and symbolic +// identifiers due to the presence of modulo/floordiv/ceildiv expressions that +// may not be eliminated after simplification; in such cases, the final +// expression can be reconstructed by replacing the local identifiers with their +// corresponding explicit form stored in 'localExprs' (note that each of the +// explicit forms itself would have been simplified). 
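// As another example, flattening d0 + d0 mod 4 introduces a single local
// q = d0 floordiv 4 and rewrites the mod as d0 - 4*q, giving the flattened
// form 2*d0 - 4*q with localVarCst holding 4*q <= d0 <= 4*q + 3; the tree
// form can be recovered by substituting localExprs[0] = d0 floordiv 4 back
// in for q.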
+// +// The expression walk method here performs a linear time post order walk that +// performs the above simplifications through visit methods, with partial +// results being stored in 'operandExprStack'. When a parent expr is visited, +// the flattened expressions corresponding to its two operands would already be +// on the stack - the parent expression looks at the two flattened expressions +// and combines the two. It pops off the operand expressions and pushes the +// combined result (although this is done in-place on its LHS operand expr). +// When the walk is completed, the flattened form of the top-level expression +// would be left on the stack. +// +// A flattener can be repeatedly used for multiple affine expressions that bind +// to the same operands, for example, for all result expressions of an +// AffineMap or AffineValueMap. In such cases, using it for multiple expressions +// is more efficient than creating a new flattener for each expression since +// common idenical div and mod expressions appearing across different +// expressions are mapped to the same local identifier (same column position in +// 'localVarCst'). +struct AffineExprFlattener : public AffineExprVisitor { +public: + // Flattend expression layout: [dims, symbols, locals, constant] + // Stack that holds the LHS and RHS operands while visiting a binary op expr. + // In future, consider adding a prepass to determine how big the SmallVector's + // will be, and linearize this to std::vector to prevent + // SmallVector moves on re-allocation. + std::vector> operandExprStack; + // Constraints connecting newly introduced local variables (for mod's and + // div's) to existing (dimensional and symbolic) ones. These are always + // inequalities. + FlatAffineConstraints localVarCst; + + unsigned numDims; + unsigned numSymbols; + // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv + // expressions that could not be simplified. + unsigned numLocals; + // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for + // which new identifiers were introduced; if the latter do not get canceled + // out, these expressions can be readily used to reconstruct the AffineExpr + // (tree) form. Note that these expressions themselves would have been + // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 + // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) + // ceildiv 2 would be the local expression stored for q. + SmallVector localExprs; + MLIRContext *context; + + AffineExprFlattener(unsigned numDims, unsigned numSymbols, + MLIRContext *context) + : numDims(numDims), numSymbols(numSymbols), numLocals(0), + context(context) { + operandExprStack.reserve(8); + localVarCst.reset(numDims, numSymbols, numLocals); + } + + void visitMulExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + // This is a pure affine expr; the RHS will be a constant. + assert(expr.getRHS().isa()); + // Get the RHS constant. + auto rhsConst = operandExprStack.back()[getConstantIndex()]; + operandExprStack.pop_back(); + // Update the LHS in place instead of pop and push. + auto &lhs = operandExprStack.back(); + for (unsigned i = 0, e = lhs.size(); i < e; i++) { + lhs[i] *= rhsConst; + } + } + + void visitAddExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + const auto &rhs = operandExprStack.back(); + auto &lhs = operandExprStack[operandExprStack.size() - 2]; + assert(lhs.size() == rhs.size()); + // Update the LHS in place. 
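// E.g. adding the flattened operands [1, 0, 2] and [0, 3, -1] (coefficients
// over [d0, s0, constant]) leaves [1, 3, 1] on the stack.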
+ for (unsigned i = 0, e = rhs.size(); i < e; i++) { + lhs[i] += rhs[i]; + } + // Pop off the RHS. + operandExprStack.pop_back(); + } + + // + // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 + // + // A mod expression "expr mod c" is thus flattened by introducing a new local + // variable q (= expr floordiv c), such that expr mod c is replaced with + // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. + void visitModExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + // This is a pure affine expr; the RHS will be a constant. + assert(expr.getRHS().isa()); + auto rhsConst = operandExprStack.back()[getConstantIndex()]; + operandExprStack.pop_back(); + auto &lhs = operandExprStack.back(); + // TODO(bondhugula): handle modulo by zero case when this issue is fixed + // at the other places in the IR. + assert(rhsConst > 0 && "RHS constant has to be positive"); + + // Check if the LHS expression is a multiple of modulo factor. + unsigned i, e; + for (i = 0, e = lhs.size(); i < e; i++) + if (lhs[i] % rhsConst != 0) + break; + // If yes, modulo expression here simplifies to zero. + if (i == lhs.size()) { + std::fill(lhs.begin(), lhs.end(), 0); + return; + } + + // Add a local variable for the quotient, i.e., expr % c is replaced by + // (expr - q * c) where q = expr floordiv c. Do this while canceling out + // the GCD of expr and c. + SmallVector floorDividend(lhs); + uint64_t gcd = rhsConst; + for (unsigned i = 0, e = lhs.size(); i < e; i++) + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); + // Simplify the numerator and the denominator. + if (gcd != 1) { + for (unsigned i = 0, e = floorDividend.size(); i < e; i++) + floorDividend[i] = floorDividend[i] / static_cast(gcd); + } + int64_t floorDivisor = rhsConst / static_cast(gcd); + + // Construct the AffineExpr form of the floordiv to store in localExprs. + auto dividendExpr = + toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context); + auto divisorExpr = getAffineConstantExpr(floorDivisor, context); + auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); + int loc; + if ((loc = findLocalId(floorDivExpr)) == -1) { + addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); + // Set result at top of stack to "lhs - rhsConst * q". + lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; + } else { + // Reuse the existing local id. 
+ lhs[getLocalVarStartIndex() + loc] = -rhsConst; + } + } + + void visitCeilDivExpr(AffineBinaryOpExpr expr) { + visitDivExpr(expr, /*isCeil=*/true); + } + void visitFloorDivExpr(AffineBinaryOpExpr expr) { + visitDivExpr(expr, /*isCeil=*/false); + } + + void visitDimExpr(AffineDimExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + assert(expr.getPosition() < numDims && "Inconsistent number of dims"); + eq[getDimStartIndex() + expr.getPosition()] = 1; + } + + void visitSymbolExpr(AffineSymbolExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); + eq[getSymbolStartIndex() + expr.getPosition()] = 1; + } + + void visitConstantExpr(AffineConstantExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + eq[getConstantIndex()] = expr.getValue(); + } + +private: + // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 + // A floordiv is thus flattened by introducing a new local variable q, and + // replacing that expression with 'q' while adding the constraints + // c * q <= expr <= c * q + c - 1 to localVarCst (done by + // FlatAffineConstraints::addLocalFloorDiv). + // + // A ceildiv is similarly flattened: + // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c + void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { + assert(operandExprStack.size() >= 2); + assert(expr.getRHS().isa()); + + // This is a pure affine expr; the RHS is a positive constant. + int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; + // TODO(bondhugula): handle division by zero at the same time the issue is + // fixed at other places. + assert(rhsConst > 0 && "RHS constant has to be positive"); + operandExprStack.pop_back(); + auto &lhs = operandExprStack.back(); + + // Simplify the floordiv, ceildiv if possible by canceling out the greatest + // common divisors of the numerator and denominator. + uint64_t gcd = std::abs(rhsConst); + for (unsigned i = 0, e = lhs.size(); i < e; i++) + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); + // Simplify the numerator and the denominator. + if (gcd != 1) { + for (unsigned i = 0, e = lhs.size(); i < e; i++) + lhs[i] = lhs[i] / static_cast(gcd); + } + int64_t divisor = rhsConst / static_cast(gcd); + // If the divisor becomes 1, the updated LHS is the result. (The + // divisor can't be negative since rhsConst is positive). + if (divisor == 1) + return; + + // If the divisor cannot be simplified to one, we will have to retain + // the ceil/floor expr (simplified up until here). Add an existential + // quantifier to express its result, i.e., expr1 div expr2 is replaced + // by a new identifier, q. + auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context); + auto b = getAffineConstantExpr(divisor, context); + + int loc; + auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); + if ((loc = findLocalId(divExpr)) == -1) { + if (!isCeil) { + SmallVector dividend(lhs); + addLocalFloorDivId(dividend, divisor, divExpr); + } else { + // lhs ceildiv c <=> (lhs + c - 1) floordiv c + SmallVector dividend(lhs); + dividend.back() += divisor - 1; + addLocalFloorDivId(dividend, divisor, divExpr); + } + } + // Set the expression on stack to the local var introduced to capture the + // result of the division (floor or ceil). 
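// For example, (2*d0 + 4*d1) floordiv 8 is first reduced by the gcd 2 to
// (d0 + 2*d1) floordiv 4, and (d0 + d1) ceildiv 4 is handled as
// (d0 + d1 + 3) floordiv 4; in either case the stack entry is then reset to
// select just the local identifier that stands for the division's result.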
+ std::fill(lhs.begin(), lhs.end(), 0); + if (loc == -1) + lhs[getLocalVarStartIndex() + numLocals - 1] = 1; + else + lhs[getLocalVarStartIndex() + loc] = 1; + } + + // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). + // The local identifier added is always a floordiv of a pure add/mul affine + // function of other identifiers, coefficients of which are specified in + // dividend and with respect to a positive constant divisor. localExpr is the + // simplified tree expression (AffineExpr) corresponding to the quantifier. + void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, + AffineExpr localExpr) { + assert(divisor > 0 && "positive constant divisor expected"); + for (auto &subExpr : operandExprStack) + subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); + localExprs.push_back(localExpr); + numLocals++; + // Update localVarCst. + localVarCst.addLocalFloorDiv(dividend, divisor); + } + + int findLocalId(AffineExpr localExpr) { + SmallVectorImpl::iterator it; + if ((it = std::find(localExprs.begin(), localExprs.end(), localExpr)) == + localExprs.end()) + return -1; + return it - localExprs.begin(); + } + + inline unsigned getNumCols() const { + return numDims + numSymbols + numLocals + 1; + } + inline unsigned getConstantIndex() const { return getNumCols() - 1; } + inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } + inline unsigned getSymbolStartIndex() const { return numDims; } + inline unsigned getDimStartIndex() const { return 0; } +}; + +} // end anonymous namespace + +/// Simplify the affine expression by flattening it and reconstructing it. +AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols) { + // TODO(bondhugula): only pure affine for now. The simplification here can + // be extended to semi-affine maps in the future. + if (!expr.isPureAffine()) + return expr; + + AffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); + flattener.walkPostOrder(expr); + ArrayRef flattenedExpr = flattener.operandExprStack.back(); + auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, + flattener.localExprs, expr.getContext()); + flattener.operandExprStack.pop_back(); + assert(flattener.operandExprStack.empty()); + + return simplifiedExpr; +} + +// Flattens the expressions in map. Returns true on success or false +// if 'expr' was unable to be flattened (i.e., semi-affine expressions not +// handled yet). +static bool getFlattenedAffineExprs( + ArrayRef exprs, unsigned numDims, unsigned numSymbols, + std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (exprs.empty()) { + localVarCst->reset(numDims, numSymbols); + return true; + } + + flattenedExprs->clear(); + flattenedExprs->reserve(exprs.size()); + + AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); + // Use the same flattener to simplify each expression successively. This way + // local identifiers / expressions are shared. + for (auto expr : exprs) { + if (!expr.isPureAffine()) + return false; + + flattener.walkPostOrder(expr); + } + + assert(flattener.operandExprStack.size() == exprs.size()); + flattenedExprs->insert(flattenedExprs->end(), + flattener.operandExprStack.begin(), + flattener.operandExprStack.end()); + if (localVarCst) + localVarCst->clearAndCopyFrom(flattener.localVarCst); + + return true; +} + +// Flattens 'expr' into 'flattenedExpr'. 
Returns true on success or false +// if 'expr' was unable to be flattened (semi-affine expressions not handled +// yet). +bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *localVarCst) { + std::vector> flattenedExprs; + bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, + &flattenedExprs, localVarCst); + *flattenedExpr = flattenedExprs[0]; + return ret; +} + +/// Flattens the expressions in map. Returns true on success or false +/// if 'expr' was unable to be flattened (i.e., semi-affine expressions not +/// handled yet). +bool mlir::getFlattenedAffineExprs( + AffineMap map, std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (map.getNumResults() == 0) { + localVarCst->reset(map.getNumDims(), map.getNumSymbols()); + return true; + } + return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), + map.getNumSymbols(), flattenedExprs, + localVarCst); +} + +bool mlir::getFlattenedAffineExprs( + IntegerSet set, std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (set.getNumConstraints() == 0) { + localVarCst->reset(set.getNumDims(), set.getNumSymbols()); + return true; + } + return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), + set.getNumSymbols(), flattenedExprs, + localVarCst); +} diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index bc6c0e6fe3f..aaf805dc4b6 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -246,3 +246,16 @@ AffineMap AffineMap::compose(AffineMap map) { exprs.push_back(expr.compose(newMap)); return AffineMap::get(numDims, numSymbols, exprs, {}); } + +AffineMap mlir::simplifyAffineMap(AffineMap map) { + SmallVector exprs, sizes; + for (auto e : map.getResults()) { + exprs.push_back( + simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); + } + for (auto e : map.getRangeSizes()) { + sizes.push_back( + simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); + } + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, sizes); +} diff --git a/mlir/lib/IR/AffineStructures.cpp b/mlir/lib/IR/AffineStructures.cpp new file mode 100644 index 00000000000..1306199dc79 --- /dev/null +++ b/mlir/lib/IR/AffineStructures.cpp @@ -0,0 +1,2063 @@ +//===- AffineStructures.cpp - MLIR Affine Structures Class-------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Structures for affine/polyhedral analysis of MLIR functions. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineStructures.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instruction.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "affine-structures"
+
+using namespace mlir;
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// MutableAffineMap.
+//===----------------------------------------------------------------------===//
+
+MutableAffineMap::MutableAffineMap(AffineMap map)
+    : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
+      // A map always has at least 1 result by construction
+      context(map.getResult(0).getContext()) {
+  for (auto result : map.getResults())
+    results.push_back(result);
+  for (auto rangeSize : map.getRangeSizes())
+    rangeSizes.push_back(rangeSize);
+}
+
+void MutableAffineMap::reset(AffineMap map) {
+  results.clear();
+  rangeSizes.clear();
+  numDims = map.getNumDims();
+  numSymbols = map.getNumSymbols();
+  // A map always has at least 1 result by construction
+  context = map.getResult(0).getContext();
+  for (auto result : map.getResults())
+    results.push_back(result);
+  for (auto rangeSize : map.getRangeSizes())
+    rangeSizes.push_back(rangeSize);
+}
+
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+  if (results[idx].isMultipleOf(factor))
+    return true;
+
+  // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
+  // complete this (for a more powerful analysis).
+  return false;
+}
+
+// Simplifies the result affine expressions of this map. The expressions have
+// to be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+  // Simplify each of the results if possible.
+  // TODO(ntv): functional-style map
+  for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+    results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
+  }
+}
+
+AffineMap MutableAffineMap::getAffineMap() const {
+  return AffineMap::get(numDims, numSymbols, results, rangeSizes);
+}
+
+MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
+    : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()),
+      context(context) {
+  // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+                                     MLIRContext *context)
+    : numDims(numDims), numSymbols(numSymbols), context(context) {}
+
+//===----------------------------------------------------------------------===//
+// AffineValueMap.
+//===----------------------------------------------------------------------===//
+
+AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands,
+                               ArrayRef results)
+    : map(map), operands(operands.begin(), operands.end()),
+      results(results.begin(), results.end()) {}
+
+void AffineValueMap::reset(AffineMap map, ArrayRef operands,
+                           ArrayRef results) {
+  this->map.reset(map);
+  this->operands.assign(operands.begin(), operands.end());
+  this->results.assign(results.begin(), results.end());
+}
+
+// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
+// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
+static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, + unsigned indexStart, unsigned *indexOfMatch) { + unsigned size = valuesToSearch.size(); + for (unsigned i = indexStart; i < size; ++i) { + if (valueToMatch == valuesToSearch[i]) { + *indexOfMatch = i; + return true; + } + } + return false; +} + +inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { + return map.isMultipleOf(idx, factor); +} + +/// This method uses the invariant that operands are always positionally aligned +/// with the AffineDimExpr in the underlying AffineMap. +bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { + unsigned index; + if (!findIndex(value, operands, /*indexStart=*/0, &index)) { + return false; + } + auto expr = const_cast(this)->getAffineMap().getResult(idx); + // TODO(ntv): this is better implemented on a flattened representation. + // At least for now it is conservative. + return expr.isFunctionOfDim(index); +} + +Value *AffineValueMap::getOperand(unsigned i) const { + return static_cast(operands[i]); +} + +ArrayRef AffineValueMap::getOperands() const { + return ArrayRef(operands); +} + +AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } + +AffineValueMap::~AffineValueMap() {} + +//===----------------------------------------------------------------------===// +// FlatAffineConstraints. +//===----------------------------------------------------------------------===// + +// Copy constructor. +FlatAffineConstraints::FlatAffineConstraints( + const FlatAffineConstraints &other) { + numReservedCols = other.numReservedCols; + numDims = other.getNumDimIds(); + numSymbols = other.getNumSymbolIds(); + numIds = other.getNumIds(); + + auto otherIds = other.getIds(); + ids.reserve(numReservedCols); + ids.append(otherIds.begin(), otherIds.end()); + + unsigned numReservedEqualities = other.getNumReservedEqualities(); + unsigned numReservedInequalities = other.getNumReservedInequalities(); + + equalities.reserve(numReservedEqualities * numReservedCols); + inequalities.reserve(numReservedInequalities * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +// Clones this object. +std::unique_ptr FlatAffineConstraints::clone() const { + return std::make_unique(*this); +} + +// Construct from an IntegerSet. +FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) + : numReservedCols(set.getNumOperands() + 1), + numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), + numSymbols(set.getNumSymbols()) { + equalities.reserve(set.getNumEqualities() * numReservedCols); + inequalities.reserve(set.getNumInequalities() * numReservedCols); + ids.resize(numIds, None); + + // Flatten expressions and add them to the constraint system. 
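+  // Each flattened expression has getNumCols() coefficients, laid out in the
+  // column order used by this constraint system:
+  // [dims, symbols, locals, constant].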
+ std::vector> flatExprs; + FlatAffineConstraints localVarCst; + if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) { + assert(false && "flattening unimplemented for semi-affine integer sets"); + return; + } + assert(flatExprs.size() == set.getNumConstraints()); + for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { + addLocalId(getNumLocalIds()); + } + + for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { + const auto &flatExpr = flatExprs[i]; + assert(flatExpr.size() == getNumCols()); + if (set.getEqFlags()[i]) { + addEquality(flatExpr); + } else { + addInequality(flatExpr); + } + } + // Add the other constraints involving local id's from flattening. + append(localVarCst); +} + +void FlatAffineConstraints::reset(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned newNumReservedCols, + unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef idArgs) { + assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && + "minimum 1 column"); + numReservedCols = newNumReservedCols; + numDims = newNumDims; + numSymbols = newNumSymbols; + numIds = numDims + numSymbols + newNumLocals; + assert(idArgs.empty() || idArgs.size() == numIds); + + clearConstraints(); + if (numReservedEqualities >= 1) + equalities.reserve(newNumReservedCols * numReservedEqualities); + if (numReservedInequalities >= 1) + inequalities.reserve(newNumReservedCols * numReservedInequalities); + if (idArgs.empty()) { + ids.resize(numIds, None); + } else { + ids.assign(idArgs.begin(), idArgs.end()); + } +} + +void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef idArgs) { + reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, + newNumSymbols, newNumLocals, idArgs); +} + +void FlatAffineConstraints::append(const FlatAffineConstraints &other) { + assert(other.getNumCols() == getNumCols()); + assert(other.getNumDimIds() == getNumDimIds()); + assert(other.getNumSymbolIds() == getNumSymbolIds()); + + inequalities.reserve(inequalities.size() + + other.getNumInequalities() * numReservedCols); + equalities.reserve(equalities.size() + + other.getNumEqualities() * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +void FlatAffineConstraints::addLocalId(unsigned pos) { + addId(IdKind::Local, pos); +} + +void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { + addId(IdKind::Dimension, pos, id); +} + +void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { + addId(IdKind::Symbol, pos, id); +} + +/// Adds a dimensional identifier. The added column is initialized to +/// zero. +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { + if (kind == IdKind::Dimension) { + assert(pos <= getNumDimIds()); + } else if (kind == IdKind::Symbol) { + assert(pos <= getNumSymbolIds()); + } else { + assert(pos <= getNumLocalIds()); + } + + unsigned oldNumReservedCols = numReservedCols; + + // Check if a resize is necessary. 
+ if (getNumCols() + 1 > numReservedCols) { + equalities.resize(getNumEqualities() * (getNumCols() + 1)); + inequalities.resize(getNumInequalities() * (getNumCols() + 1)); + numReservedCols++; + } + + unsigned absolutePos; + + if (kind == IdKind::Dimension) { + absolutePos = pos; + numDims++; + } else if (kind == IdKind::Symbol) { + absolutePos = pos + getNumDimIds(); + numSymbols++; + } else { + absolutePos = pos + getNumDimIds() + getNumSymbolIds(); + } + numIds++; + + // Note that getNumCols() now will already return the new size, which will be + // at least one. + int numInequalities = static_cast(getNumInequalities()); + int numEqualities = static_cast(getNumEqualities()); + int numCols = static_cast(getNumCols()); + for (int r = numInequalities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + if (c < absolutePos) + atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; + else + atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; + } + atIneq(r, absolutePos) = 0; + } + + for (int r = numEqualities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + // All values in column absolutePositions < absolutePos have the same + // coordinates in the 2-d view of the coefficient buffer. + if (c < absolutePos) + atEq(r, c) = equalities[r * oldNumReservedCols + c]; + else + // Those at absolutePosition >= absolutePos, get a shifted + // absolutePosition. + atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; + } + // Initialize added dimension to zero. + atEq(r, absolutePos) = 0; + } + + // If an 'id' is provided, insert it; otherwise use None. + if (id) { + ids.insert(ids.begin() + absolutePos, id); + } else { + ids.insert(ids.begin() + absolutePos, None); + } + assert(ids.size() == getNumIds()); +} + +// This routine may add additional local variables if the flattened expression +// corresponding to the map has such variables due to the presence of +// mod's, ceildiv's, and floordiv's. +bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { + // Assert if the map and this constraint set aren't associated with the same + // identifiers in the same order. + assert(vMap->getNumDims() <= getNumDimIds()); + assert(vMap->getNumSymbols() <= getNumSymbolIds()); + for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) { + assert(ids[i].hasValue()); + assert(vMap->getOperand(i) == ids[i].getValue()); + } + for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) { + assert(ids[numDims + i].hasValue()); + assert(vMap->getOperand(vMap->getNumDims() + i) == + ids[numDims + i].getValue()); + } + + std::vector> flatExprs; + FlatAffineConstraints cst; + if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { + LLVM_DEBUG(llvm::dbgs() + << "composition unimplemented for semi-affine maps\n"); + return false; + } + assert(flatExprs.size() == vMap->getNumResults()); + + // Make the value map and the flat affine cst dimensions compatible. + // A lot of this code will be refactored/cleaned up. + // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs + // to be factored out into an FlatAffineConstraints::alignAndMerge(). + for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { + addLocalId(0); + } + + for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { + // TODO: Consider using a batched version to add a range of IDs. 
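+    // Each result of the value map gets a new leading dimension in both
+    // 'this' and 'cst'; these dimensions play the role of d_r in the
+    // equalities d_r - <flattened expr r> = 0 added further below.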
+ addDimId(0); + cst.addDimId(0); + } + + assert(cst.getNumDimIds() <= getNumDimIds()); + for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { + // Dimensions that are in 'this' but not in vMap/cst are added at the end. + cst.addDimId(cst.getNumDimIds()); + } + assert(cst.getNumSymbolIds() <= getNumSymbolIds()); + for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e; + t++) { + // Dimensions that are in 'this' but not in vMap/cst are added at the end. + cst.addSymbolId(cst.getNumSymbolIds()); + } + assert(cst.getNumLocalIds() <= getNumLocalIds()); + for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; + t++) { + cst.addLocalId(cst.getNumLocalIds()); + } + /// Finally, append cst to this constraint set. + append(cst); + + // We add one equality for each result connecting the result dim of the map to + // the other identifiers. + // For eg: if the expression is 16*i0 + i1, and this is the r^th + // iteration/result of the value map, we are adding the equality: + // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we + // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. + for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { + const auto &flatExpr = flatExprs[r]; + // eqToAdd is the equality corresponding to the flattened affine expression. + SmallVector eqToAdd(getNumCols(), 0); + // Set the coefficient for this result to one. + eqToAdd[r] = 1; + + assert(flatExpr.size() >= vMap->getNumOperands() + 1); + + // Dims and symbols. + for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { + unsigned loc; + bool ret = findId(*vMap->getOperand(i), &loc); + assert(ret && "value map's id can't be found"); + (void)ret; + // We need to negate 'eq[r]' since the newly added dimension is going to + // be set to this one. + eqToAdd[loc] = -flatExpr[i]; + } + // Local vars common to eq and cst are at the beginning. + int j = getNumDimIds() + getNumSymbolIds(); + int end = flatExpr.size() - 1; + for (int i = vMap->getNumOperands(); i < end; i++, j++) { + eqToAdd[j] = -flatExpr[i]; + } + + // Constant term. + eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; + + // Add the equality connecting the result of the map to this constraint set. + addEquality(eqToAdd); + } + + return true; +} + +// Searches for a constraint with a non-zero coefficient at 'colIdx' in +// equality (isEq=true) or inequality (isEq=false) constraints. +// Returns true and sets row found in search in 'rowIdx'. +// Returns false otherwise. +static bool +findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, + unsigned colIdx, bool isEq, unsigned *rowIdx) { + auto at = [&](unsigned rowIdx) -> int64_t { + return isEq ? constraints.atEq(rowIdx, colIdx) + : constraints.atIneq(rowIdx, colIdx); + }; + unsigned e = + isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); + for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { + if (at(*rowIdx) != 0) { + return true; + } + } + return false; +} + +// Normalizes the coefficient values across all columns in 'rowIDx' by their +// GCD in equality or inequality contraints as specified by 'isEq'. +template +static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, + unsigned rowIdx) { + auto at = [&](unsigned colIdx) -> int64_t { + return isEq ? 
constraints->atEq(rowIdx, colIdx) + : constraints->atIneq(rowIdx, colIdx); + }; + uint64_t gcd = std::abs(at(0)); + for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); + } + if (gcd > 0 && gcd != 1) { + for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { + int64_t v = at(j) / static_cast(gcd); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } + } +} + +void FlatAffineConstraints::normalizeConstraintsByGCD() { + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + normalizeConstraintByGCD(this, i); + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + normalizeConstraintByGCD(this, i); + } +} + +bool FlatAffineConstraints::hasConsistentState() const { + if (inequalities.size() != getNumInequalities() * numReservedCols) + return false; + if (equalities.size() != getNumEqualities() * numReservedCols) + return false; + if (ids.size() != getNumIds()) + return false; + + // Catches errors where numDims, numSymbols, numIds aren't consistent. + if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) + return false; + + return true; +} + +/// Checks all rows of equality/inequality constraints for trivial +/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced +/// after elimination. Returns 'true' if an invalid constraint is found; +/// 'false' otherwise. +bool FlatAffineConstraints::hasInvalidConstraint() const { + assert(hasConsistentState()); + auto check = [&](bool isEq) -> bool { + unsigned numCols = getNumCols(); + unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); + for (unsigned i = 0, e = numRows; i < e; ++i) { + unsigned j; + for (j = 0; j < numCols - 1; ++j) { + int64_t v = isEq ? atEq(i, j) : atIneq(i, j); + // Skip rows with non-zero variable coefficients. + if (v != 0) + break; + } + if (j < numCols - 1) { + continue; + } + // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. + // Example invalid constraints include: '1 == 0' or '-1 >= 0' + int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); + if ((isEq && v != 0) || (!isEq && v < 0)) { + return true; + } + } + return false; + }; + if (check(/*isEq=*/true)) + return true; + return check(/*isEq=*/false); +} + +// Eliminate identifier from constraint at 'rowIdx' based on coefficient at +// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be +// updated as they have already been eliminated. +static void eliminateFromConstraint(FlatAffineConstraints *constraints, + unsigned rowIdx, unsigned pivotRow, + unsigned pivotCol, unsigned elimColStart, + bool isEq) { + // Skip if equality 'rowIdx' if same as 'pivotRow'. + if (isEq && rowIdx == pivotRow) + return; + auto at = [&](unsigned i, unsigned j) -> int64_t { + return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); + }; + int64_t leadCoeff = at(rowIdx, pivotCol); + // Skip if leading coefficient at 'rowIdx' is already zero. + if (leadCoeff == 0) + return; + int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); + int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; + int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); + int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); + int64_t rowMultiplier = lcm / std::abs(leadCoeff); + + unsigned numCols = constraints->getNumCols(); + for (unsigned j = 0; j < numCols; ++j) { + // Skip updating column 'j' if it was just eliminated. 
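+    // The updated entry is pivotMultiplier times the pivot row's coefficient
+    // plus rowMultiplier times this row's coefficient; by construction the
+    // multipliers cancel the coefficient at 'pivotCol'.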
+    if (j >= elimColStart && j < pivotCol)
+      continue;
+    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
+                rowMultiplier * at(rowIdx, j);
+    isEq ? constraints->atEq(rowIdx, j) = v
+         : constraints->atIneq(rowIdx, j) = v;
+  }
+}
+
+// Remove coefficients in column range [colStart, colLimit) in place.
+// This removes the data in the specified column range, and copies any
+// remaining valid data into place.
+static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
+                               unsigned colStart, unsigned colLimit,
+                               bool isEq) {
+  assert(colStart >= 0 && colLimit <= constraints->getNumIds());
+  if (colLimit <= colStart)
+    return;
+
+  unsigned numCols = constraints->getNumCols();
+  unsigned numRows = isEq ? constraints->getNumEqualities()
+                          : constraints->getNumInequalities();
+  unsigned numToEliminate = colLimit - colStart;
+  for (unsigned r = 0, e = numRows; r < e; ++r) {
+    for (unsigned c = colLimit; c < numCols; ++c) {
+      if (isEq) {
+        constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
+      } else {
+        constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
+      }
+    }
+  }
+}
+
+// Removes identifiers in column range [idStart, idLimit), copies any
+// remaining valid data into place, and updates member variables.
+void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
+  assert(idLimit < getNumCols() && "invalid id limit");
+
+  if (idStart >= idLimit)
+    return;
+
+  // We are going to be removing one or more identifiers from the range.
+  assert(idStart < numIds && "invalid idStart position");
+
+  // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
+  // Remove eliminated identifiers from equalities.
+  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
+
+  // Remove eliminated identifiers from inequalities.
+  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
+
+  // Update members numDims, numSymbols and numIds.
+  unsigned numDimsEliminated = 0;
+  unsigned numLocalsEliminated = 0;
+  unsigned numColsEliminated = idLimit - idStart;
+  if (idStart < numDims) {
+    numDimsEliminated = std::min(numDims, idLimit) - idStart;
+  }
+  // Check how many local id's were removed. Note that our identifier order is
+  // [dims, symbols, locals]. Local ids start at position numDims + numSymbols.
+  if (idLimit > numDims + numSymbols) {
+    numLocalsEliminated = std::min(
+        idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
+  }
+  unsigned numSymbolsEliminated =
+      numColsEliminated - numDimsEliminated - numLocalsEliminated;
+
+  numDims -= numDimsEliminated;
+  numSymbols -= numSymbolsEliminated;
+  numIds = numIds - numColsEliminated;
+
+  ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
+
+  // No resize necessary. numReservedCols remains the same.
+}
+
+/// Returns the position of the identifier from the specified range
+/// [start, end) that has the minimum product of the number of its lower
+/// bounds and the number of its upper bounds. It is often best to eliminate
+/// in the increasing order of these products when doing Fourier-Motzkin
+/// elimination since FM adds that many new constraints.
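+/// For example, eliminating an identifier with 2 lower bounds and 3 upper
+/// bounds produces at most 2 * 3 = 6 new inequalities while dropping the 5
+/// original ones that involve it.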
+static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, + unsigned start, unsigned end) { + assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); + + auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { + unsigned numLb = 0; + unsigned numUb = 0; + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) > 0) { + ++numLb; + } else if (cst.atIneq(r, pos) < 0) { + ++numUb; + } + } + return numLb * numUb; + }; + + unsigned minLoc = start; + unsigned min = getProductOfNumLowerUpperBounds(start); + for (unsigned c = start + 1; c < end; c++) { + unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); + if (numLbUbProduct < min) { + min = numLbUbProduct; + minLoc = c; + } + } + return minLoc; +} + +// Checks for emptiness of the set by eliminating identifiers successively and +// using the GCD test (on all equality constraints) and checking for trivially +// invalid constraints. Returns 'true' if the constraint system is found to be +// empty; false otherwise. +bool FlatAffineConstraints::isEmpty() const { + if (isEmptyByGCDTest() || hasInvalidConstraint()) + return true; + + // First, eliminate as many identifiers as possible using Gaussian + // elimination. + FlatAffineConstraints tmpCst(*this); + unsigned currentPos = 0; + while (currentPos < tmpCst.getNumIds()) { + tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); + ++currentPos; + // We check emptiness through trivial checks after eliminating each ID to + // detect emptiness early. Since the checks isEmptyByGCDTest() and + // hasInvalidConstraint() are linear time and single sweep on the constraint + // buffer, this appears reasonable - but can optimize in the future. + if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) + return true; + } + + // Eliminate the remaining using FM. + for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { + tmpCst.FourierMotzkinEliminate( + getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); + // Check for a constraint explosion. This rarely happens in practice, but + // this check exists as a safeguard against improperly constructed + // constraint systems or artifically created arbitrarily complex systems + // that aren't the intended use case for FlatAffineConstraints. This is + // needed since FM has a worst case exponential complexity in theory. + if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { + LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected"); + return false; + } + + // FM wouldn't have modified the equalities in any way. So no need to again + // run GCD test. Check for trivial invalid constraints. + if (tmpCst.hasInvalidConstraint()) + return true; + } + return false; +} + +// Runs the GCD test on all equality constraints. Returns 'true' if this test +// fails on any equality. Returns 'false' otherwise. +// This test can be used to disprove the existence of a solution. If it returns +// true, no integer solution to the equality constraints can exist. +// +// GCD test definition: +// +// The equality constraint: +// +// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 +// +// has an integer solution iff: +// +// GCD of c_1, c_2, ..., c_n divides c_0. 
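+// For example, 2*x_1 + 4*x_2 = 7 has no integer solution since
+// gcd(2, 4) = 2 does not divide 7.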
+// +bool FlatAffineConstraints::isEmptyByGCDTest() const { + assert(hasConsistentState()); + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + uint64_t gcd = std::abs(atEq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); + } + int64_t v = std::abs(atEq(i, numCols - 1)); + if (gcd > 0 && (v % gcd != 0)) { + return true; + } + } + return false; +} + +/// Tightens inequalities given that we are dealing with integer spaces. This is +/// analogous to the GCD test but applied to inequalities. The constant term can +/// be reduced to the preceding multiple of the GCD of the coefficients, i.e., +/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a +/// fast method - linear in the number of coefficients. +// Example on how this affects practical cases: consider the scenario: +// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield +// j >= 100 instead of the tighter (exact) j >= 128. +void FlatAffineConstraints::GCDTightenInequalities() { + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + uint64_t gcd = std::abs(atIneq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); + } + if (gcd > 0) { + int64_t gcdI = static_cast(gcd); + atIneq(i, numCols - 1) = + gcdI * mlir::floorDiv(atIneq(i, numCols - 1), gcdI); + } + } +} + +// Eliminates all identifer variables in column range [posStart, posLimit). +// Returns the number of variables eliminated. +unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, + unsigned posLimit) { + // Return if identifier positions to eliminate are out of range. + assert(posLimit <= numIds); + assert(hasConsistentState()); + + if (posStart >= posLimit) + return 0; + + LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", " + << posLimit << ")\n"); + + GCDTightenInequalities(); + + unsigned pivotCol = 0; + for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { + // Find a row which has a non-zero coefficient in column 'j'. + unsigned pivotRow; + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, + &pivotRow)) { + // No pivot row in equalities with non-zero at 'pivotCol'. + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, + &pivotRow)) { + // If inequalities are also non-zero in 'pivotCol', it can be + // eliminated. + continue; + } + break; + } + + // Eliminate identifier at 'pivotCol' from each equality row. + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/true); + normalizeConstraintByGCD(this, i); + } + + // Eliminate identifier at 'pivotCol' from each inequality row. + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/false); + normalizeConstraintByGCD(this, i); + } + removeEquality(pivotRow); + } + // Update position limit based on number eliminated. + posLimit = pivotCol; + // Remove eliminated columns from all constraints. + removeIdRange(posStart, posLimit); + return posLimit - posStart; +} + +// Detect the identifier at 'pos' (say id_r) as modulo of another identifier +// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) +// could be detected as the floordiv of n. 
For eg: +// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> +// id_r = id_n mod 4, id_q = id_n floordiv 4. +// lbConst and ubConst are the constant lower and upper bounds for 'pos' - +// pre-detected at the caller. +static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, + int64_t lbConst, int64_t ubConst, + SmallVectorImpl *memo) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to + // id_n - divisor * id_q. If these are true, then id_n becomes the dividend + // and id_q the quotient when dividing id_n by the divisor. + + if (lbConst != 0 || ubConst < 1) + return false; + + int64_t divisor = ubConst + 1; + + // Now check for: id_r = id_n - divisor * id_q. As an example, we + // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. + unsigned seenQuotient = 0, seenDividend = 0; + int quotientPos = -1, dividendPos = -1; + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + // id_n should have coeff 1 or -1. + if (std::abs(cst.atEq(r, pos)) != 1) + continue; + for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { + // The coeff of the quotient should be -divisor if the coefficient of + // the pos^th identifier is -1, and divisor if the latter is -1. + if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) { + seenQuotient++; + quotientPos = c; + } else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) { + seenDividend++; + dividendPos = c; + } + } + // We are looking for exactly one identifier as part of the dividend. + // TODO(bondhugula): could be extended to cover multiple ones in the + // dividend to detect mod of an affine function of identifiers. + if (seenDividend == 1 && seenQuotient >= 1) { + if (!(*memo)[dividendPos]) + return false; + // Successfully detected a mod. + (*memo)[pos] = (*memo)[dividendPos] % divisor; + if (seenQuotient == 1 && !(*memo)[quotientPos]) + // Successfully detected a floordiv as well. + (*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor); + return true; + } + } + return false; +} + +// Check if the pos^th identifier can be expressed as a floordiv of an affine +// function of other identifiers (where the divisor is a positive constant). +// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. +bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, + SmallVectorImpl *memo, MLIRContext *context) { + assert(pos < cst.getNumIds() && "invalid position"); + SmallVector lbIndices, ubIndices; + + // Gather all lower bounds and upper bound constraints of this identifier. + // Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint + // is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) >= 1) + // Lower bound. + lbIndices.push_back(r); + else if (cst.atIneq(r, pos) <= -1) + // Upper bound. + ubIndices.push_back(r); + } + + // Check if any lower bound, upper bound pair is of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). + // + // For example, if -32*k + 16*i + j >= 0 + // 32*k - 16*i - j + 31 >= 0 <=> + // k = ( 16*i + j ) floordiv 32 + unsigned seenDividends = 0; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Check if lower bound's constant term is 'divisor - 1'. 
The 'divisor' + // here is cst.atIneq(lbPos, pos) and we already know that it's positive + // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. + if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) + continue; + // Check if upper bound's constant term is 0. + if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) + continue; + // For the remaining part, check if the lower bound expr's coeff's are + // negations of corresponding upper bound ones'. + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) + break; + if (c != pos && cst.atIneq(lbPos, c) != 0) + seenDividends++; + } + // Lb coeff's aren't negative of ub coeff's (for the non constant term + // part). + if (c < f) + continue; + if (seenDividends >= 1) { + // The divisor is the constant term of the lower bound expression. + // We already know that cst.atIneq(lbPos, pos) > 0. + int64_t divisor = cst.atIneq(lbPos, pos); + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(0, context); + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (c == pos) + continue; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + if (!(*memo)[c]) + break; + dividendExpr = dividendExpr + ubVal * (*memo)[c]; + } + // Expression can't be constructed as it depends on a yet unknown + // identifier. + // TODO(mlir-team): Visit/compute the identifiers in an order so that + // this doesn't happen. More complex but much more efficient. + if (c < f) + continue; + // Successfully detected the floordiv. + (*memo)[pos] = dividendExpr.floorDiv(divisor); + return true; + } + } + } + return false; +} + +/// Computes the lower and upper bounds of the first 'num' dimensional +/// identifiers as affine maps of the remaining identifiers (dimensional and +/// symbolic identifiers). Local identifiers are themselves explicitly computed +/// as affine functions of other identifiers in this process if needed. +void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps) { + assert(num < getNumDimIds() && "invalid range"); + + // Basic simplification. + normalizeConstraintsByGCD(); + + LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n"); + LLVM_DEBUG(dump()); + + // Record computed/detected identifiers. + SmallVector memo(getNumIds(), AffineExpr::Null()); + // Initialize dimensional and symbolic identifiers. + for (unsigned i = num, e = getNumDimIds(); i < e; i++) + memo[i] = getAffineDimExpr(i - num, context); + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) + memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); + + bool changed; + do { + changed = false; + // Identify yet unknown identifiers as constants or mod's / floordiv's of + // other identifiers if possible. + for (unsigned pos = 0; pos < getNumIds(); pos++) { + if (memo[pos]) + continue; + + auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + // Detect equality to a constant. + if (lbConst.getValue() == ubConst.getValue()) { + memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); + changed = true; + continue; + } + + // Detect an identifier as modulo of another identifier w.r.t a + // constant. 
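+        // For example, i = j mod 4 is detected from 0 <= i <= 3 together
+        // with an equality of the form j - 4*q - i = 0.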
+ if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), + &memo)) { + changed = true; + continue; + } + } + + // Detect an identifier as floordiv of another identifier w.r.t a + // constant. + if (detectAsFloorDiv(*this, pos, &memo, context)) { + changed = true; + continue; + } + + // Detect an identifier as an expression of other identifiers. + unsigned idx; + if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { + continue; + } + + // Build AffineExpr solving for identifier 'pos' in terms of all others. + auto expr = getAffineConstantExpr(0, context); + unsigned j, e; + for (j = 0, e = getNumIds(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = atEq(idx, j); + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!memo[j]) + break; + expr = expr + memo[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + continue; + + // Add constant term to AffineExpr. + expr = expr + atEq(idx, getNumIds()); + int64_t vPos = atEq(idx, pos); + assert(vPos != 0 && "expected non-zero here"); + if (vPos > 0) + expr = (-expr).floorDiv(vPos); + else + // vPos < 0. + expr = expr.floorDiv(-vPos); + // Successfully constructed expression. + memo[pos] = expr; + changed = true; + } + // This loop is guaranteed to reach a fixed point - since once an + // identifier's explicit form is computed (in memo[pos]), it's not updated + // again. + } while (changed); + + // Set the lower and upper bound maps for all the identifiers that were + // computed as affine expressions of the rest as the "detected expr" and + // "detected expr + 1" respectively; set the undetected ones to Null(). + for (unsigned pos = 0; pos < num; pos++) { + unsigned numMapDims = getNumDimIds() - num; + unsigned numMapSymbols = getNumSymbolIds(); + AffineExpr expr = memo[pos]; + if (expr) + expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); + + if (expr) { + (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); + (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); + } else { + // TODO(andydavis, bondhugula) Add support for computing slice bounds + // symbolic in the identifies [num, numIds). 
+ auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + (*lbMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(lbConst.getValue(), context), {}); + (*ubMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(ubConst.getValue() + 1, context), {}); + } else { + (*lbMaps)[pos] = AffineMap(); + (*ubMaps)[pos] = AffineMap(); + } + } + LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); + LLVM_DEBUG(expr.dump();); + } +} + +void FlatAffineConstraints::addEquality(ArrayRef eq) { + assert(eq.size() == getNumCols()); + unsigned offset = equalities.size(); + equalities.resize(equalities.size() + numReservedCols); + std::copy(eq.begin(), eq.end(), equalities.begin() + offset); +} + +void FlatAffineConstraints::addInequality(ArrayRef inEq) { + assert(inEq.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); +} + +void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = 1; + inequalities[offset + getNumCols() - 1] = -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = -1; + inequalities[offset + getNumCols() - 1] = ub; +} + +void FlatAffineConstraints::addConstantLowerBound(ArrayRef expr, + int64_t lb) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); + inequalities[offset + getNumCols() - 1] += -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(ArrayRef expr, + int64_t ub) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + for (unsigned i = 0, e = getNumCols(); i < e; i++) { + inequalities[offset + i] = -expr[i]; + } + inequalities[offset + getNumCols() - 1] += ub; +} + +/// Adds a new local identifier as the floordiv of an affine function of other +/// identifiers, the coefficients of which are provided in 'dividend' and with +/// respect to a positive constant 'divisor'. Two constraints are added to the +/// system to capture equivalence with the floordiv. +/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. +void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, + int64_t divisor) { + assert(dividend.size() == getNumCols() && "incorrect dividend size"); + assert(divisor > 0 && "positive divisor expected"); + + addLocalId(getNumLocalIds()); + + // Add two constraints for this new identifier 'q'. 
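+  // For example, for q = (d0 + d1) floordiv 4, the two inequalities are
+  // d0 + d1 - 4*q >= 0 and -d0 - d1 + 4*q + 3 >= 0.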
+ SmallVector bound(dividend.size() + 1); + + // dividend - q * divisor >= 0 + std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, + bound.begin()); + bound.back() = dividend.back(); + bound[getNumIds() - 1] = -divisor; + addInequality(bound); + + // -dividend +qdivisor * q + divisor - 1 >= 0 + std::transform(bound.begin(), bound.end(), bound.begin(), + std::negate()); + bound[bound.size() - 1] += divisor - 1; + addInequality(bound); +} + +bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { + unsigned i = 0; + for (const auto &mayBeId : ids) { + if (mayBeId.hasValue() && mayBeId.getValue() == &id) { + *pos = i; + return true; + } + i++; + } + return false; +} + +void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { + assert(newSymbolCount <= numDims + numSymbols && + "invalid separation position"); + numDims = numDims + numSymbols - newSymbolCount; + numSymbols = newSymbolCount; +} + +/// Sets the specified identifer to a constant value. +void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { + unsigned offset = equalities.size(); + equalities.resize(equalities.size() + numReservedCols); + std::fill(equalities.begin() + offset, + equalities.begin() + offset + getNumCols(), 0); + equalities[offset + pos] = 1; + equalities[offset + getNumCols() - 1] = -val; +} + +/// Sets the specified identifer to a constant value; asserts if the id is not +/// found. +void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) { + unsigned pos; + if (!findId(id, &pos)) + // This is a pre-condition for this method. + assert(0 && "id not found"); + setIdToConstant(pos, val); +} + +void FlatAffineConstraints::removeEquality(unsigned pos) { + unsigned numEqualities = getNumEqualities(); + assert(pos < numEqualities); + unsigned outputIndex = pos * numReservedCols; + unsigned inputIndex = (pos + 1) * numReservedCols; + unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols; + std::copy(equalities.begin() + inputIndex, + equalities.begin() + inputIndex + numElemsToCopy, + equalities.begin() + outputIndex); + equalities.resize(equalities.size() - numReservedCols); +} + +/// Finds an equality that equates the specified identifier to a constant. +/// Returns the position of the equality row. If 'symbolic' is set to true, +/// symbols are also treated like a constant, i.e., an affine function of the +/// symbols is also treated like a constant. +static int findEqualityToConstant(const FlatAffineConstraints &cst, + unsigned pos, bool symbolic = false) { + assert(pos < cst.getNumIds() && "invalid position"); + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + int64_t v = cst.atEq(r, pos); + if (v * v != 1) + continue; + unsigned c; + unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds(); + // This checks for zeros in all positions other than 'pos' in [0, f) + for (c = 0; c < f; c++) { + if (c == pos) + continue; + if (cst.atEq(r, c) != 0) { + // Dependent on another identifier. + break; + } + } + if (c == f) + // Equality is free of other identifiers. 
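+      // (i.e., this row equates the identifier at 'pos' to a constant, or,
+      // when 'symbolic' is true, to an affine function of the symbols that is
+      // treated like a constant here.)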
+ return r; + } + return -1; +} + +void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { + assert(pos < getNumIds() && "invalid position"); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; + } + removeId(pos); +} + +bool FlatAffineConstraints::constantFoldId(unsigned pos) { + assert(pos < getNumIds() && "invalid position"); + int rowIdx; + if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) + return false; + + // atEq(rowIdx, pos) is either -1 or 1. + assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); + int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); + setAndEliminate(pos, constVal); + return true; +} + +void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { + for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { + if (!constantFoldId(t)) + t++; + } +} + +/// Returns the extent (upper bound - lower bound) of the specified +/// identifier if it is found to be a constant; returns None if it's not a +/// constant. This methods treats symbolic identifiers specially, i.e., +/// it looks for constant differences between affine expressions involving +/// only the symbolic identifiers. See comments at function definition for +/// example. 'lb', if provided, is set to the lower bound associated with the +/// constant difference. Note that 'lb' is purely symbolic and thus will contain +/// the coefficients of the symbolic identifiers and the constant coefficient. +// Egs: 0 <= i <= 15, return 16. +// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) +// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. +// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = +// ceil(s0 - 7 / 8) = floor(s0 / 8)). +Optional FlatAffineConstraints::getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor) const { + assert(pos < getNumDimIds() && "Invalid identifier position"); + assert(getNumLocalIds() == 0); + + // TODO(bondhugula): eliminate all remaining dimensional identifiers (other + // than the one at 'pos' to make this more powerful. Not needed for + // hyper-rectangular spaces. + + // Find an equality for 'pos'^th identifier that equates it to some function + // of the symbolic identifiers (+ constant). + int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); + if (eqRow != -1) { + // This identifier can only take a single value. + if (lb) { + // Set lb to the symbolic value. + lb->resize(getNumSymbolIds() + 1); + for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { + int64_t v = atEq(eqRow, pos); + // atEq(eqRow, pos) is either -1 or 1. + assert(v * v == 1); + (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v + : -atEq(eqRow, getNumDimIds() + c) / v; + } + assert(lbFloorDivisor && + "both lb and divisor or none should be provided"); + *lbFloorDivisor = 1; + } + return 1; + } + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + // Positions of constraints that are lower/upper bounds on the variable. + SmallVector lbIndices, ubIndices; + + // Gather all symbolic lower bounds and upper bounds of the variable. 
Since + // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a + // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned c, f; + for (c = 0, f = getNumDimIds(); c < f; c++) { + if (c != pos && atIneq(r, c) != 0) + break; + } + if (c < getNumDimIds()) + continue; + if (atIneq(r, pos) >= 1) + // Lower bound. + lbIndices.push_back(r); + else if (atIneq(r, pos) <= -1) + // Upper bound. + ubIndices.push_back(r); + } + + // TODO(bondhugula): eliminate other dimensional identifiers to make this more + // powerful. Not needed for hyper-rectangular iteration spaces. + + Optional minDiff = None; + unsigned minLbPosition; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Look for a lower bound and an upper bound that only differ by a + // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. + // For example, if ii is the pos^th variable, we are looking for + // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The + // minimum among all such constant differences is kept since that's the + // constant bounding the extent of the pos^th variable. + unsigned j, e; + for (j = 0, e = getNumCols() - 1; j < e; j++) + if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { + break; + } + if (j < getNumCols() - 1) + continue; + int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); + if (minDiff == None || diff < minDiff) { + minDiff = diff; + minLbPosition = lbPos; + } + } + } + if (lb && minDiff.hasValue()) { + // Set lb to the symbolic lower bound. + lb->resize(getNumSymbolIds() + 1); + // The lower bound is the ceildiv of the lb constraint over the coefficient + // of the variable at 'pos'. We express the ceildiv equivalently as a floor + // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + + // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). + *lbFloorDivisor = atIneq(minLbPosition, pos); + for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { + // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of + // 'atIneq(minLbPosition, pos) - 1'. + (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c) + + atIneq(minLbPosition, pos) - 1; + } + } + return minDiff; +} + +template +Optional +FlatAffineConstraints::getConstantLowerOrUpperBound(unsigned pos) const { + // Check if there's an equality equating the 'pos'^th identifier to a + // constant. + int eqRowIdx = findEqualityToConstant(*this, pos, /*symbolic=*/false); + if (eqRowIdx != -1) + // atEq(rowIdx, pos) is either -1 or 1. + return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, pos); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + Optional minOrMaxConst = None; + + // Take the max across all const lower bounds (or min across all constant + // upper bounds). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (isLower) { + if (atIneq(r, pos) <= 0) + // Not a lower bound. + continue; + } else if (atIneq(r, pos) >= 0) { + // Not an upper bound. + continue; + } + unsigned c, f; + for (c = 0, f = getNumCols() - 1; c < f; c++) + if (c != pos && atIneq(r, c) != 0) + break; + if (c < getNumCols() - 1) + // Not a constant bound. 
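+      // (some other identifier still has a non-zero coefficient in this row).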
+ continue; + + int64_t boundConst = + isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos)) + : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos)); + if (isLower) { + if (minOrMaxConst == None || boundConst > minOrMaxConst) + minOrMaxConst = boundConst; + } else { + if (minOrMaxConst == None || boundConst < minOrMaxConst) + minOrMaxConst = boundConst; + } + } + return minOrMaxConst; +} + +Optional +FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { + return getConstantLowerOrUpperBound(pos); +} + +Optional +FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { + return getConstantLowerOrUpperBound(pos); +} + +// A simple (naive and conservative) check for hyper-rectangularlity. +bool FlatAffineConstraints::isHyperRectangular(unsigned pos, + unsigned num) const { + assert(pos < getNumCols() - 1); + // Check for two non-zero coefficients in the range [pos, pos + sum). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atIneq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atEq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + return true; +} + +void FlatAffineConstraints::print(raw_ostream &os) const { + assert(hasConsistentState()); + os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() + << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() + << " constraints)\n"; + os << "("; + for (unsigned i = 0, e = getNumIds(); i < e; i++) { + if (ids[i] == None) + os << "None "; + else + os << "Value "; + } + os << " const)\n"; + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atEq(i, j) << " "; + } + os << "= 0\n"; + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atIneq(i, j) << " "; + } + os << ">= 0\n"; + } + os << '\n'; +} + +void FlatAffineConstraints::dump() const { print(llvm::errs()); } + +/// Removes duplicate constraints and trivially true constraints: a constraint +/// of the form >= 0 is considered a trivially true +/// constraint. +// Uses a DenseSet to hash and detect duplicates followed by a linear scan to +// remove duplicates in place. +void FlatAffineConstraints::removeTrivialRedundancy() { + DenseSet> rowSet; + + // Check if constraint is of the form >= 0. + auto isTriviallyValid = [&](unsigned r) -> bool { + for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { + if (atIneq(r, c) != 0) + return false; + } + return atIneq(r, getNumCols() - 1) >= 0; + }; + + // Detect and mark redundant constraints. + std::vector redunIneq(getNumInequalities(), false); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + int64_t *rowStart = inequalities.data() + numReservedCols * r; + auto row = ArrayRef(rowStart, getNumCols()); + if (isTriviallyValid(r) || !rowSet.insert(row).second) { + redunIneq[r] = true; + } + } + + auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + + // Scan to get rid of all rows marked redundant, in-place. 
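+  // 'pos' tracks the next free row slot: surviving rows are compacted to the
+  // front, and the buffer is then truncated to 'pos' rows.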
+ unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redunIneq[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); + + // TODO(bondhugula): consider doing this for equalities as well, but probably + // not worth the savings. +} + +void FlatAffineConstraints::clearAndCopyFrom( + const FlatAffineConstraints &other) { + FlatAffineConstraints copy(other); + std::swap(*this, copy); + assert(copy.getNumIds() == copy.getIds().size()); +} + +void FlatAffineConstraints::removeId(unsigned pos) { + removeIdRange(pos, pos + 1); +} + +static std::pair +getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { + unsigned numDims = cst.getNumDimIds(); + unsigned numSymbols = cst.getNumSymbolIds(); + unsigned newNumDims, newNumSymbols; + if (pos < numDims) { + newNumDims = numDims - 1; + newNumSymbols = numSymbols; + } else if (pos < numDims + numSymbols) { + assert(numSymbols >= 1); + newNumDims = numDims; + newNumSymbols = numSymbols - 1; + } else { + newNumDims = numDims; + newNumSymbols = numSymbols; + } + return {newNumDims, newNumSymbols}; +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "fm" + +/// Eliminates identifier at the specified position using Fourier-Motzkin +/// variable elimination. This technique is exact for rational spaces but +/// conservative (in "rare" cases) for integer spaces. The operation corresponds +/// to a projection operation yielding the (convex) set of integer points +/// contained in the rational shadow of the set. An emptiness test that relies +/// on this method will guarantee emptiness, i.e., it disproves the existence of +/// a solution if it says it's empty. +/// If a non-null isResultIntegerExact is passed, it is set to true if the +/// result is also integer exact. If it's set to false, the obtained solution +/// *may* not be exact, i.e., it may contain integer points that do not have an +/// integer pre-image in the original set. +/// +/// Eg: +/// j >= 0, j <= i + 1 +/// i >= 0, i <= N + 1 +/// Eliminating i yields, +/// j >= 0, 0 <= N + 1, j - 1 <= N + 1 +/// +/// If darkShadow = true, this method computes the dark shadow on elimination; +/// the dark shadow is a convex integer subset of the exact integer shadow. A +/// non-empty dark shadow proves the existence of an integer solution. The +/// elimination in such a case could however be an under-approximation, and thus +/// should not be used for scanning sets or used by itself for dependence +/// checking. +/// +/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. +/// ^ +/// | +/// | * * * * o o +/// i | * * o o o o +/// | o * * * * * +/// ---------------> +/// j -> +/// +/// Eliminating i from this system (projecting on the j dimension): +/// rational shadow / integer light shadow: 1 <= j <= 6 +/// dark shadow: 3 <= j <= 6 +/// exact integer shadow: j = 1 \union 3 <= j <= 6 +/// holes/splinters: j = 2 +/// +/// darkShadow = false, isResultIntegerExact = nullptr are default values. +// TODO(bondhugula): a slight modification to yield dark shadow version of FM +// (tightened), which can prove the existence of a solution if there is one. +void FlatAffineConstraints::FourierMotzkinEliminate( + unsigned pos, bool darkShadow, bool *isResultIntegerExact) { + LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); + LLVM_DEBUG(dump()); + assert(pos < getNumIds() && "invalid position"); + assert(hasConsistentState()); + + // Check if this identifier can be eliminated through a substitution. 
+ for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + if (atEq(r, pos) != 0) { + // Use Gaussian elimination here (since we have an equality). + bool ret = gaussianEliminateId(pos); + (void)ret; + assert(ret && "Gaussian elimination guaranteed to succeed"); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); + return; + } + } + + // A fast linear time tightening. + GCDTightenInequalities(); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == getNumInequalities()) { + // If it doesn't appear, just remove the column and return. + // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. + removeId(pos); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); + return; + } + + // Positions of constraints that are lower bounds on the variable. + SmallVector lbIndices; + // Positions of constraints that are lower bounds on the variable. + SmallVector ubIndices; + // Positions of constraints that do not involve the variable. + std::vector nbIndices; + nbIndices.reserve(getNumInequalities()); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) == 0) { + // Id does not appear in bound. + nbIndices.push_back(r); + } else if (atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices.push_back(r); + } else { + // Upper bound. + ubIndices.push_back(r); + } + } + + // Set the number of dimensions, symbols in the resulting system. + const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this); + unsigned newNumDims = dimsSymbols.first; + unsigned newNumSymbols = dimsSymbols.second; + + SmallVector, 8> newIds; + newIds.reserve(numIds - 1); + newIds.append(ids.begin(), ids.begin() + pos); + newIds.append(ids.begin() + pos + 1, ids.end()); + + /// Create the new system which has one identifier less. + FlatAffineConstraints newFac( + lbIndices.size() * ubIndices.size() + nbIndices.size(), + getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, + /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds); + + assert(newFac.getIds().size() == newFac.getNumIds()); + + // This will be used to check if the elimination was integer exact. + unsigned lcmProducts = 1; + + // Let x be the variable we are eliminating. + // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note + // that c_l, c_u >= 1) we have: + // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u + // We thus generate a constraint: + // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. + // Note if c_l = c_u = 1, all integer points captured by the resulting + // constraint correspond to integer points in the original system (i.e., they + // have integer pre-images). Hence, if the lcm's are all 1, the elimination is + // integer exact. + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + SmallVector ineq; + ineq.reserve(newFac.getNumCols()); + int64_t lbCoeff = atIneq(lbPos, pos); + // Note that in the comments above, ubCoeff is the negation of the + // coefficient in the canonical form as the view taken here is that of the + // term being moved to the other size of '>='. 
+ int64_t ubCoeff = -atIneq(ubPos, pos); + // TODO(bondhugula): refactor this loop to avoid all branches inside. + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); + int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); + ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + + atIneq(lbPos, l) * (lcm / lbCoeff)); + lcmProducts *= lcm; + } + if (darkShadow) { + // The dark shadow is a convex subset of the exact integer shadow. If + // there is a point here, it proves the existence of a solution. + ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; + } + // TODO: we need to have a way to add inequalities in-place in + // FlatAffineConstraints instead of creating and copying over. + newFac.addInequality(ineq); + } + } + + if (lcmProducts == 1 && isResultIntegerExact) + *isResultIntegerExact = 1; + + // Copy over the constraints not involving this variable. + for (auto nbPos : nbIndices) { + SmallVector ineq; + ineq.reserve(getNumCols() - 1); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + ineq.push_back(atIneq(nbPos, l)); + } + newFac.addInequality(ineq); + } + + assert(newFac.getNumConstraints() == + lbIndices.size() * ubIndices.size() + nbIndices.size()); + + // Copy over the equalities. + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + SmallVector eq; + eq.reserve(newFac.getNumCols()); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + eq.push_back(atEq(r, l)); + } + newFac.addEquality(eq); + } + + newFac.removeTrivialRedundancy(); + clearAndCopyFrom(newFac); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "affine-structures" + +void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { + if (num == 0) + return; + + // 'pos' can be at most getNumCols() - 2 if num > 0. + assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position"); + assert(pos + num < getNumCols() && "invalid range"); + + // Eliminate as many identifiers as possible using Gaussian elimination. + unsigned currentPos = pos; + unsigned numToEliminate = num; + unsigned numGaussianEliminated = 0; + + while (currentPos < getNumIds()) { + unsigned curNumEliminated = + gaussianEliminateIds(currentPos, currentPos + numToEliminate); + ++currentPos; + numToEliminate -= curNumEliminated + 1; + numGaussianEliminated += curNumEliminated; + } + + // Eliminate the remaining using Fourier-Motzkin. + for (unsigned i = 0; i < num - numGaussianEliminated; i++) { + unsigned numToEliminate = num - numGaussianEliminated - i; + FourierMotzkinEliminate( + getBestIdToEliminate(*this, pos, pos + numToEliminate)); + } + + // Fast/trivial simplifications. + GCDTightenInequalities(); + // Normalize constraints after tightening since the latter impacts this, but + // not the other way round. + normalizeConstraintsByGCD(); +} + +void FlatAffineConstraints::projectOut(Value *id) { + unsigned pos; + bool ret = findId(*id, &pos); + assert(ret); + (void)ret; + FourierMotzkinEliminate(pos); +} + +bool FlatAffineConstraints::isRangeOneToOne(unsigned start, + unsigned limit) const { + assert(start <= getNumIds() - 1 && "invalid start position"); + assert(limit > start && limit <= getNumIds() && "invalid limit"); + + FlatAffineConstraints tmpCst(*this); + + if (start != 0) { + // Move [start, limit) to the left. 
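To make the combination step inside FourierMotzkinEliminate concrete, consider a hypothetical pair of constraints on x over the columns [x, i, N, const] (each row read as ">= 0"):

    lower bound:   2x - i  >= 0        (c_l = 2)
    upper bound:  -3x + N  >= 0        (c_u = 3)

Here lcm(2, 3) = 6, so the lower bound row is weighted by 6/2 = 3 and the upper bound row by 6/3 = 2; dropping the x column yields 2N - 3i >= 0, the rational-shadow constraint for this pair. Because the lcm is not 1, the elimination is not guaranteed to be integer exact for this pair.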
+ for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atIneq(r, c - start) = atIneq(r, c); + else if (c < start) + tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); + else + tmpCst.atIneq(r, c) = atIneq(r, c); + } + } + for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atEq(r, c - start) = atEq(r, c); + else if (c < start) + tmpCst.atEq(r, c + limit - start) = atEq(r, c); + else + tmpCst.atEq(r, c) = atEq(r, c); + } + } + } + + // Mark everything to the right as symbols so that we can check the extents in + // a symbolic way below. + tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); + + // Check if the extents of all the specified dimensions are just one (when + // treating the rest as symbols). + for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { + auto extent = tmpCst.getConstantBoundOnDimSize(pos); + if (!extent.hasValue() || extent.getValue() != 1) + return false; + } + return true; +} + +void FlatAffineConstraints::clearConstraints() { + equalities.clear(); + inequalities.clear(); +} + +namespace { + +enum BoundCmpResult { Greater, Less, Equal, Unknown }; + +/// Compares two affine bounds whose coefficients are provided in 'first' and +/// 'second'. The last coefficient is the constant term. +static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { + assert(a.size() == b.size()); + + // For the bounds to be comparable, their corresponding identifier + // coefficients should be equal; the constant terms are then compared to + // determine less/greater/equal. + + if (!std::equal(a.begin(), a.end() - 1, b.begin())) + return Unknown; + + if (a.back() == b.back()) + return Equal; + + return a.back() < b.back() ? Less : Greater; +} +}; // namespace + +// Compute the bounding box with respect to 'other' by finding the min of the +// lower bounds and the max of the upper bounds along each of the dimensions. +bool FlatAffineConstraints::unionBoundingBox( + const FlatAffineConstraints &other) { + assert(other.getNumDimIds() == numDims); + assert(other.getNumSymbolIds() == getNumSymbolIds()); + assert(other.getNumLocalIds() == 0); + assert(getNumLocalIds() == 0); + std::vector> boundingLbs; + std::vector> boundingUbs; + boundingLbs.reserve(2 * getNumDimIds()); + boundingUbs.reserve(2 * getNumDimIds()); + + SmallVector lb, otherLb; + lb.reserve(getNumSymbolIds() + 1); + otherLb.reserve(getNumSymbolIds() + 1); + int64_t lbDivisor, otherLbDivisor; + for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { + lb.clear(); + auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor); + if (!extent.hasValue()) + // TODO(bondhugula): symbolic extents when necessary. + return false; + + otherLb.clear(); + auto otherExtent = + other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor); + if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor) + // TODO(bondhugula): symbolic extents when necessary. + return false; + + assert(lbDivisor > 0 && "divisor always expected to be positive"); + + // Compute min of lower bounds and max of upper bounds. + ArrayRef minLb, maxUb; + + auto res = compareBounds(lb, otherLb); + // Identify min. + if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { + minLb = lb; + } else if (res == BoundCmpResult::Greater) { + minLb = otherLb; + } else { + // Uncomparable. 
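A hypothetical instance of the bounding-box computation for one dimension d (with divisor 1 in both systems): if this system yields the bound s0 <= d <= s0 + 15 (lower bound s0, extent 16) and 'other' yields s0 + 4 <= d <= s0 + 31 (lower bound s0 + 4, extent 28), then the lower bounds are comparable with minimum s0, the upper bounds are comparable with maximum s0 + 31, and the union keeps

    d - s0       >= 0
    -d + s0 + 31 >= 0

for this dimension.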
+ return false; + } + + // Do the same for ub's but max of upper bounds. + SmallVector ub(lb), otherUb(otherLb); + ub.back() += extent.getValue() - 1; + otherUb.back() += otherExtent.getValue() - 1; + + // Identify max. + auto uRes = compareBounds(ub, otherUb); + if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { + maxUb = ub; + } else if (uRes == BoundCmpResult::Less) { + maxUb = otherUb; + } else { + // Uncomparable. + return false; + } + + SmallVector newLb(getNumCols(), 0); + SmallVector newUb(getNumCols(), 0); + + // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, + // and so it's the divisor for newLb and newUb as well. + newLb[d] = lbDivisor; + newUb[d] = -lbDivisor; + // Copy over the symbolic part + constant term. + std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); + std::transform(newLb.begin() + getNumDimIds(), newLb.end(), + newLb.begin() + getNumDimIds(), std::negate()); + std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); + + boundingLbs.push_back(newLb); + boundingUbs.push_back(newUb); + } + + // Clear all constraints and add the lower/upper bounds for the bounding box. + clearConstraints(); + for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { + addInequality(boundingLbs[d]); + addInequality(boundingUbs[d]); + } + + return true; +} diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp deleted file mode 100644 index 796477c64f2..00000000000 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ /dev/null @@ -1,96 +0,0 @@ -//===- ComposeAffineMaps.cpp - MLIR Affine Transform Class-----*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements a testing pass which composes affine maps from -// AffineApplyOps in a Function, by forward subtituting results from an -// AffineApplyOp into any of its users which are also AffineApplyOps. -// -//===----------------------------------------------------------------------===// - -#include "mlir/AffineOps/AffineOps.h" -#include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" -#include "mlir/StandardOps/StandardOps.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/Utils.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -namespace { - -// ComposeAffineMaps walks all affine apply op's in a function, and for each -// such op, composes into it the results of any other AffineApplyOps - so -// that all operands of the composed AffineApplyOp are guaranteed to be either -// loop IVs or terminal symbols, (i.e., Values that are themselves not the -// result of any AffineApplyOp). 
After this composition, AffineApplyOps with no -// remaining uses are erased. -// TODO(andydavis) Remove this when Chris adds instruction combiner pass. -struct ComposeAffineMaps : public FunctionPass { - explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} - PassResult runOnFunction(Function *f) override; - - SmallVector, 8> affineApplyOps; - - static char passID; -}; - -} // end anonymous namespace - -char ComposeAffineMaps::passID = 0; - -FunctionPass *mlir::createComposeAffineMapsPass() { - return new ComposeAffineMaps(); -} - -static bool affineApplyOp(const Instruction &inst) { - return inst.isa(); -} - -PassResult ComposeAffineMaps::runOnFunction(Function *f) { - // If needed for future efficiency, reserve space based on a pre-walk. - affineApplyOps.clear(); - f->walk( - [&](OpPointer afOp) { affineApplyOps.push_back(afOp); }); - for (auto afOp : affineApplyOps) { - SmallVector operands(afOp->getOperands()); - FuncBuilder b(afOp->getInstruction()); - auto newAfOp = makeComposedAffineApply(&b, afOp->getLoc(), - afOp->getAffineMap(), operands); - afOp->replaceAllUsesWith(newAfOp); - } - - // Erase dead affine apply ops. - affineApplyOps.clear(); - f->walk( - [&](OpPointer afOp) { affineApplyOps.push_back(afOp); }); - for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) { - if ((*it)->use_empty()) { - (*it)->erase(); - } - } - - return success(); -} - -static PassRegistration pass("compose-affine-maps", - "Compose affine maps"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 6bc5260850b..7fd9128b358 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -22,8 +22,8 @@ //===----------------------------------------------------------------------===// #include "mlir/AffineOps/AffineOps.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d7e1b610022..3e0bf046ddf 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -21,11 +21,11 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass.h" diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 758d434d25e..368a1dac1df 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -21,8 +21,8 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 29509911e31..897498e8346 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -19,7 +19,7 @@ // //===----------------------------------------------------------------------===// 
-#include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 1e9ad25fbec..724411ad245 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -24,9 +24,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Module.h" #include "mlir/StandardOps/StandardOps.h" diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir new file mode 100644 index 00000000000..661dc732641 --- /dev/null +++ b/mlir/test/AffineOps/canonicalize.mlir @@ -0,0 +1,263 @@ +// RUN: mlir-opt %s -canonicalize | FileCheck %s + +// Affine maps for test case: compose_affine_maps_1dto2d_no_symbols +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 - 1) +// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0) -> (d0 + 1) + +// Affine maps for test case: compose_affine_maps_1dto2d_with_symbols +// CHECK-DAG: [[MAP4:#map[0-9]+]] = (d0)[s0] -> (d0 - s0) +// CHECK-DAG: [[MAP6:#map[0-9]+]] = (d0)[s0] -> (d0 * 2 - s0 + 1) +// CHECK-DAG: [[MAP7:#map[0-9]+]] = (d0)[s0, s1] -> (d0 * 2 + s0 - s1) + +// Affine map for test case: compose_affine_maps_d2_tile +// CHECK-DAG: [[MAP8:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 ceildiv s0) * s0 + d1 mod s0) + +// Affine maps for test case: compose_affine_maps_dependent_loads +// CHECK-DAG: [[MAP9:#map[0-9]+]] = (d0)[s0] -> (d0 + s0) +// CHECK-DAG: [[MAP10:#map[0-9]+]] = (d0)[s0] -> (d0 * s0) +// CHECK-DAG: [[MAP11:#map[0-9]+]] = (d0)[s0, s1] -> ((d0 + s1) ceildiv s0) +// CHECK-DAG: [[MAP12:#map[0-9]+]] = (d0)[s0] -> ((d0 - s0) * s0) + +// Affine maps for test case: compose_affine_maps_diamond_dependency +// CHECK-DAG: [[MAP13A:#map[0-9]+]] = (d0) -> ((d0 + 6) ceildiv 8) +// CHECK-DAG: [[MAP13B:#map[0-9]+]] = (d0) -> ((d0 * 4 - 4) floordiv 3) + +// Affine maps for test case: arg_used_as_dim_and_symbol +// CHECK-DAG: [[MAP14:#map[0-9]+]] = (d0, d1, d2)[s0, s1] -> (-d0 - d1 + d2 + s0 + s1) + +// Affine maps for test case: partial_fold_map +// CHECK-DAG: [[MAP15:#map[0-9]+]] = (d0, d1) -> (d0 - d1) + +// CHECK-LABEL: func @compose_affine_maps_1dto2d_no_symbols() { +func @compose_affine_maps_1dto2d_no_symbols() { + %0 = alloc() : memref<4x4xf32> + + for %i0 = 0 to 15 { + // Test load[%x, %x] + + %x0 = affine_apply (d0) -> (d0 - 1) (%i0) + %x1_0 = affine_apply (d0, d1) -> (d0) (%x0, %x0) + %x1_1 = affine_apply (d0, d1) -> (d1) (%x0, %x0) + + // CHECK: [[I0A:%[0-9]+]] = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: [[I0B:%[0-9]+]] = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0B]]{{\]}} + %v0 = load %0[%x1_0, %x1_1] : memref<4x4xf32> + + // Test load[%y, %y] + %y0 = affine_apply (d0) -> (d0 + 1) (%i0) + %y1_0 = affine_apply (d0, d1) -> (d0) (%y0, %y0) + %y1_1 = affine_apply (d0, d1) -> (d1) (%y0, %y0) + + // CHECK-NEXT: [[I1A:%[0-9]+]] = affine_apply [[MAP1]](%i0) + // CHECK-NEXT: [[I1B:%[0-9]+]] = affine_apply [[MAP1]](%i0) + // CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1B]]{{\]}} + %v1 = load %0[%y1_0, %y1_1] : memref<4x4xf32> + + // Test load[%x, %y] + %xy_0 = affine_apply (d0, d1) -> (d0) (%x0, %y0) + %xy_1 = affine_apply (d0, d1) -> (d1) (%x0, %y0) + + // CHECK-NEXT: [[I2A:%[0-9]+]] = affine_apply 
[[MAP0]](%i0) + // CHECK-NEXT: [[I2B:%[0-9]+]] = affine_apply [[MAP1]](%i0) + // CHECK-NEXT: load %0{{\[}}[[I2A]], [[I2B]]{{\]}} + %v2 = load %0[%xy_0, %xy_1] : memref<4x4xf32> + + // Test load[%y, %x] + %yx_0 = affine_apply (d0, d1) -> (d0) (%y0, %x0) + %yx_1 = affine_apply (d0, d1) -> (d1) (%y0, %x0) + // CHECK-NEXT: [[I3A:%[0-9]+]] = affine_apply [[MAP1]](%i0) + // CHECK-NEXT: [[I3B:%[0-9]+]] = affine_apply [[MAP0]](%i0) + // CHECK-NEXT: load %0{{\[}}[[I3A]], [[I3B]]{{\]}} + %v3 = load %0[%yx_0, %yx_1] : memref<4x4xf32> + } + return +} + +// CHECK-LABEL: func @compose_affine_maps_1dto2d_with_symbols() { +func @compose_affine_maps_1dto2d_with_symbols() { + %0 = alloc() : memref<4x4xf32> + + for %i0 = 0 to 15 { + // Test load[%x0, %x0] with symbol %c4 + %c4 = constant 4 : index + %x0 = affine_apply (d0)[s0] -> (d0 - s0) (%i0)[%c4] + + // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0)[%c4] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I0]]{{\]}} + %v0 = load %0[%x0, %x0] : memref<4x4xf32> + + // Test load[%x0, %x1] with symbol %c4 captured by '%x0' map. + %x1 = affine_apply (d0) -> (d0 + 1) (%i0) + %y1 = affine_apply (d0, d1) -> (d0+d1) (%x0, %x1) + // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0)[%c4] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]], [[I1]]{{\]}} + %v1 = load %0[%y1, %y1] : memref<4x4xf32> + + // Test load[%x1, %x0] with symbol %c4 captured by '%x0' map. + %y2 = affine_apply (d0, d1) -> (d0 + d1) (%x1, %x0) + // CHECK-NEXT: [[I2:%[0-9]+]] = affine_apply [[MAP6]](%i0)[%c4] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I2]], [[I2]]{{\]}} + %v2 = load %0[%y2, %y2] : memref<4x4xf32> + + // Test load[%x2, %x0] with symbol %c4 from '%x0' and %c5 from '%x2' + %c5 = constant 5 : index + %x2 = affine_apply (d0)[s0] -> (d0 + s0) (%i0)[%c5] + %y3 = affine_apply (d0, d1) -> (d0 + d1) (%x2, %x0) + // CHECK: [[I3:%[0-9]+]] = affine_apply [[MAP7]](%i0)[%c5, %c4] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I3]], [[I3]]{{\]}} + %v3 = load %0[%y3, %y3] : memref<4x4xf32> + } + return +} + +// CHECK-LABEL: func @compose_affine_maps_2d_tile() { +func @compose_affine_maps_2d_tile() { + %0 = alloc() : memref<16x32xf32> + %1 = alloc() : memref<16x32xf32> + + %c4 = constant 4 : index + %c8 = constant 8 : index + + for %i0 = 0 to 3 { + %x0 = affine_apply (d0)[s0] -> (d0 ceildiv s0) (%i0)[%c4] + for %i1 = 0 to 3 { + %x1 = affine_apply (d0)[s0] -> (d0 ceildiv s0) (%i1)[%c8] + for %i2 = 0 to 3 { + %x2 = affine_apply (d0)[s0] -> (d0 mod s0) (%i2)[%c4] + for %i3 = 0 to 3 { + %x3 = affine_apply (d0)[s0] -> (d0 mod s0) (%i3)[%c8] + + %x40 = affine_apply (d0, d1, d2, d3)[s0, s1] -> + ((d0 * s0) + d2) (%x0, %x1, %x2, %x3)[%c4, %c8] + %x41 = affine_apply (d0, d1, d2, d3)[s0, s1] -> + ((d1 * s1) + d3) (%x0, %x1, %x2, %x3)[%c4, %c8] + // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP8]](%i0, %i2)[%c4] + // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP8]](%i1, %i3)[%c8] + // CHECK-NEXT: [[L0:%[0-9]+]] = load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} + %v0 = load %0[%x40, %x41] : memref<16x32xf32> + + // CHECK-NEXT: store [[L0]], %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} + store %v0, %1[%x40, %x41] : memref<16x32xf32> + } + } + } + } + return +} + +// CHECK-LABEL: func @compose_affine_maps_dependent_loads() { +func @compose_affine_maps_dependent_loads() { + %0 = alloc() : memref<16x32xf32> + %1 = alloc() : memref<16x32xf32> + + for %i0 = 0 to 3 { + for %i1 = 0 to 3 { + for %i2 = 0 to 3 { + %c3 = constant 3 : index + %c7 = constant 7 : index + + %x00 = affine_apply (d0, d1, d2)[s0, s1] -> (d0 + s0) + (%i0, %i1, %i2)[%c3, 
%c7] + %x01 = affine_apply (d0, d1, d2)[s0, s1] -> (d1 - s1) + (%i0, %i1, %i2)[%c3, %c7] + %x02 = affine_apply (d0, d1, d2)[s0, s1] -> (d2 * s0) + (%i0, %i1, %i2)[%c3, %c7] + + // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP9]](%i0)[%c3] + // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP4]](%i1)[%c7] + // CHECK: [[I2:%[0-9]+]] = affine_apply [[MAP10]](%i2)[%c3] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} + %v0 = load %0[%x00, %x01] : memref<16x32xf32> + + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I2]]{{\]}} + %v1 = load %0[%x00, %x02] : memref<16x32xf32> + + // Swizzle %i0, %i1 + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]], [[I0]]{{\]}} + %v2 = load %0[%x01, %x00] : memref<16x32xf32> + + // Swizzle %x00, %x01 and %c3, %c7 + %x10 = affine_apply (d0, d1)[s0, s1] -> (d0 * s1) + (%x01, %x00)[%c7, %c3] + %x11 = affine_apply (d0, d1)[s0, s1] -> (d1 ceildiv s0) + (%x01, %x00)[%c7, %c3] + + // CHECK-NEXT: [[I2A:%[0-9]+]] = affine_apply [[MAP12]](%i1)[%c7] + // CHECK-NEXT: [[I2B:%[0-9]+]] = affine_apply [[MAP11]](%i0)[%c3, %c7] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I2A]], [[I2B]]{{\]}} + %v3 = load %0[%x10, %x11] : memref<16x32xf32> + } + } + } + return +} + +// CHECK-LABEL: func @compose_affine_maps_diamond_dependency() { +func @compose_affine_maps_diamond_dependency() { + %0 = alloc() : memref<4x4xf32> + + for %i0 = 0 to 15 { + %a = affine_apply (d0) -> (d0 - 1) (%i0) + %b = affine_apply (d0) -> (d0 + 7) (%a) + %c = affine_apply (d0) -> (d0 * 4) (%a) + %d0 = affine_apply (d0, d1) -> (d0 ceildiv 8) (%b, %c) + %d1 = affine_apply (d0, d1) -> (d1 floordiv 3) (%b, %c) + // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP13A]](%i0) + // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP13B]](%i0) + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} + %v = load %0[%d0, %d1] : memref<4x4xf32> + } + + return +} + +// CHECK-LABEL: func @arg_used_as_dim_and_symbol +func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { + %c9 = constant 9 : index + %1 = alloc() : memref<100x100xf32, 1> + %2 = alloc() : memref<1xi32> + for %i0 = 0 to 100 { + for %i1 = 0 to 100 { + %3 = affine_apply (d0, d1)[s0, s1] -> (d1 + s0 + s1) + (%i0, %i1)[%arg1, %c9] + %4 = affine_apply (d0, d1, d3) -> (d3 - (d0 + d1)) + (%arg1, %c9, %3) + // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP14]](%arg1, %c9, %i1)[%arg1, %c9] + // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], %arg1{{\]}} + %5 = load %1[%4, %arg1] : memref<100x100xf32, 1> + } + } + return +} + +// CHECK-LABEL: func @trivial_maps +func @trivial_maps() { + // CHECK-NOT: affine_apply + + %0 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cst = constant 0.000000e+00 : f32 + for %i1 = 0 to 10 { + %1 = affine_apply ()[s0] -> (s0)()[%c0] + store %cst, %0[%1] : memref<10xf32> + %2 = load %0[%c0] : memref<10xf32> + + %3 = affine_apply ()[] -> (0)()[] + store %cst, %0[%3] : memref<10xf32> + %4 = load %0[%c0] : memref<10xf32> + } + return +} + +// CHECK-LABEL: func @partial_fold_map +func @partial_fold_map(%arg0: memref, %arg1: index, %arg2: index) { + // TODO: Constant fold one index into affine_apply + %c42 = constant 42 : index + %2 = affine_apply (d0, d1) -> (d0 - d1) (%arg1, %c42) + store %2, %arg0[] : memref + // CHECK: [[X:%[0-9]+]] = affine_apply [[MAP15]](%arg1, %c42) + // CHECK-NEXT: store [[X]], %arg0 + + return +} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 7df61cda44d..dfc34cd66a5 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1,11 
+1,5 @@ // RUN: mlir-opt %s -canonicalize | FileCheck %s -// CHECK-DAG: [[D0M1:#map.*]] = (d0) -> (d0 - 1) -// CHECK-DAG: [[D0MD1:#map.*]] = (d0, d1) -> (d0 - d1) -// CHECK-DAG: [[D0PD0:#map.*]] = (d0) -> (d0 + d0) -// CHECK-DAG: [[D0P2:#map.*]] = (d0) -> (d0 + 2) -// CHECK-DAG: [[DEDUPMAP:#map.*]] = (d0)[s0, s1] -> (d0 - d0 + s0 + s1 + s0 + s1 - 1) - // CHECK-LABEL: func @test_subi_zero func @test_subi_zero(%arg0: i32) -> i32 { // CHECK-NEXT: %c0_i32 = constant 0 : i32 @@ -267,45 +261,6 @@ func @const_fold_propagate() -> memref { // CHECK: = alloc() : memref<64x32xf32> %Av = alloc(%VT_i_s, %VT_k_l) : memref return %Av : memref - } - - -// CHECK-LABEL: func @simplify_affine_apply -func @simplify_affine_apply(%arg0: memref, %arg1: index, %arg2: index) { - // Only uses d1, not d0. - %0 = affine_apply (d0, d1) -> (d1 - 1) (%arg1, %arg2) - store %0, %arg0[] : memref - // CHECK: [[X:%[0-9]+]] = affine_apply [[D0M1]](%arg2) - // CHECK-NEXT: store [[X]], %arg0 - - // TODO: Constant fold one index into affine_apply - %c42 = constant 42 : index - %2 = affine_apply (d0, d1) -> (d0 - d1) (%arg1, %c42) - store %2, %arg0[] : memref - // CHECK: [[X:%[0-9]+]] = affine_apply [[D0MD1]](%arg1, %c42) - // CHECK-NEXT: store [[X]], %arg0 - - %3 = affine_apply (d0, d1) -> (d0 + d1) (%arg1, %arg1) - store %3, %arg0[] : memref - // CHECK: [[X:%[0-9]+]] = affine_apply [[D0PD0]](%arg1) - // CHECK-NEXT: store [[X]], %arg0 - - // TODO: Compose affine maps. - %x0 = affine_apply (d0) -> (d0 - 1) (%arg1) - %x1 = affine_apply (d0) -> (d0+2) (%x0) - store %x1, %arg0[] : memref - - // CHECK: [[X:%[0-9]+]] = affine_apply [[D0M1]](%arg1) - // CHECK-NEXT: [[Y:%[0-9]+]] = affine_apply [[D0P2]]([[X]]) - // CHECK-NEXT: store [[Y]], %arg0 - - // Drop redundant exprs and symbols. - %dedup = affine_apply (d0, d1) [s0, s1, s2, s3] -> (d0 - d1 - 1 + s0 + s1 + s2 + s3) (%arg1, %arg1)[%arg2, %arg1, %arg2, %arg1] - store %dedup, %arg0[] : memref - // CHECK: [[DEDUP:%.+]] = affine_apply [[DEDUPMAP]](%arg1)[%arg2, %arg1] - // CHECK-NEXT: store [[DEDUP]], %arg0 - - return } // CHECK-LABEL: func @cond_br_folding diff --git a/mlir/test/Transforms/compose-affine-maps.mlir b/mlir/test/Transforms/compose-affine-maps.mlir deleted file mode 100644 index 47e8f6ac34d..00000000000 --- a/mlir/test/Transforms/compose-affine-maps.mlir +++ /dev/null @@ -1,255 +0,0 @@ -// RUN: mlir-opt %s -compose-affine-maps | FileCheck %s - -// Affine maps for test case: compose_affine_maps_1dto2d_no_symbols -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 - 1) -// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0) -> (d0 + 1) - -// Affine maps for test case: compose_affine_maps_1dto2d_with_symbols -// CHECK-DAG: [[MAP4:#map[0-9]+]] = (d0)[s0] -> (d0 - s0) -// CHECK-DAG: [[MAP6:#map[0-9]+]] = (d0)[s0] -> (d0 * 2 - s0 + 1) -// CHECK-DAG: [[MAP7:#map[0-9]+]] = (d0)[s0, s1] -> (d0 * 2 + s0 - s1) - -// Affine map for test case: compose_affine_maps_d2_tile -// CHECK-DAG: [[MAP8:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 ceildiv s0) * s0 + d1 mod s0) - -// Affine maps for test case: compose_affine_maps_dependent_loads -// CHECK-DAG: [[MAP9:#map[0-9]+]] = (d0)[s0] -> (d0 + s0) -// CHECK-DAG: [[MAP10:#map[0-9]+]] = (d0)[s0] -> (d0 * s0) -// CHECK-DAG: [[MAP12A:#map[0-9]+]] = (d0)[s0, s1] -> ((d0 - s1) * s0) -// CHECK-DAG: [[MAP12B:#map[0-9]+]] = (d0)[s0, s1] -> ((d0 + s1) ceildiv s0) - -// Affine maps for test case: compose_affine_maps_diamond_dependency -// CHECK-DAG: [[MAP13A:#map[0-9]+]] = (d0) -> ((d0 + 6) ceildiv 8) -// CHECK-DAG: [[MAP13B:#map[0-9]+]] = (d0) -> ((d0 * 4 - 4) floordiv 3) - 
-// Affine maps for test case: arg_used_as_dim_and_symbol -// CHECK-DAG: [[MAP14:#map[0-9]+]] = (d0, d1, d2)[s0, s1] -> (-d0 - d1 + d2 + s0 + s1) - -// Affine maps for test case: zero_map -// CHECK-DAG: [[MAP15:#map[0-9]+]] = ()[s0] -> (s0) - -// Affine maps for test case: zero_map -// CHECK-DAG: [[MAP16:#map[0-9]+]] = () -> (0) - -// CHECK-LABEL: func @compose_affine_maps_1dto2d_no_symbols() { -func @compose_affine_maps_1dto2d_no_symbols() { - %0 = alloc() : memref<4x4xf32> - - for %i0 = 0 to 15 { - // Test load[%x, %x] - - %x0 = affine_apply (d0) -> (d0 - 1) (%i0) - %x1_0 = affine_apply (d0, d1) -> (d0) (%x0, %x0) - %x1_1 = affine_apply (d0, d1) -> (d1) (%x0, %x0) - - // CHECK: [[I0A:%[0-9]+]] = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: [[I0B:%[0-9]+]] = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0B]]{{\]}} - %v0 = load %0[%x1_0, %x1_1] : memref<4x4xf32> - - // Test load[%y, %y] - %y0 = affine_apply (d0) -> (d0 + 1) (%i0) - %y1_0 = affine_apply (d0, d1) -> (d0) (%y0, %y0) - %y1_1 = affine_apply (d0, d1) -> (d1) (%y0, %y0) - - // CHECK-NEXT: [[I1A:%[0-9]+]] = affine_apply [[MAP1]](%i0) - // CHECK-NEXT: [[I1B:%[0-9]+]] = affine_apply [[MAP1]](%i0) - // CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1B]]{{\]}} - %v1 = load %0[%y1_0, %y1_1] : memref<4x4xf32> - - // Test load[%x, %y] - %xy_0 = affine_apply (d0, d1) -> (d0) (%x0, %y0) - %xy_1 = affine_apply (d0, d1) -> (d1) (%x0, %y0) - - // CHECK-NEXT: [[I2A:%[0-9]+]] = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: [[I2B:%[0-9]+]] = affine_apply [[MAP1]](%i0) - // CHECK-NEXT: load %0{{\[}}[[I2A]], [[I2B]]{{\]}} - %v2 = load %0[%xy_0, %xy_1] : memref<4x4xf32> - - // Test load[%y, %x] - %yx_0 = affine_apply (d0, d1) -> (d0) (%y0, %x0) - %yx_1 = affine_apply (d0, d1) -> (d1) (%y0, %x0) - // CHECK-NEXT: [[I3A:%[0-9]+]] = affine_apply [[MAP1]](%i0) - // CHECK-NEXT: [[I3B:%[0-9]+]] = affine_apply [[MAP0]](%i0) - // CHECK-NEXT: load %0{{\[}}[[I3A]], [[I3B]]{{\]}} - %v3 = load %0[%yx_0, %yx_1] : memref<4x4xf32> - } - return -} - -// CHECK-LABEL: func @compose_affine_maps_1dto2d_with_symbols() { -func @compose_affine_maps_1dto2d_with_symbols() { - %0 = alloc() : memref<4x4xf32> - - for %i0 = 0 to 15 { - // Test load[%x0, %x0] with symbol %c4 - %c4 = constant 4 : index - %x0 = affine_apply (d0)[s0] -> (d0 - s0) (%i0)[%c4] - - // CHECK: constant 4 - // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0)[%c4] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I0]]{{\]}} - %v0 = load %0[%x0, %x0] : memref<4x4xf32> - - // Test load[%x0, %x1] with symbol %c4 captured by '%x0' map. - %x1 = affine_apply (d0) -> (d0 + 1) (%i0) - %y1 = affine_apply (d0, d1) -> (d0+d1) (%x0, %x1) - // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0)[%c4] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]], [[I1]]{{\]}} - %v1 = load %0[%y1, %y1] : memref<4x4xf32> - - // Test load[%x1, %x0] with symbol %c4 captured by '%x0' map. 
- %y2 = affine_apply (d0, d1) -> (d0 + d1) (%x1, %x0) - // CHECK-NEXT: [[I2:%[0-9]+]] = affine_apply [[MAP6]](%i0)[%c4] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I2]], [[I2]]{{\]}} - %v2 = load %0[%y2, %y2] : memref<4x4xf32> - - // Test load[%x2, %x0] with symbol %c4 from '%x0' and %c5 from '%x2' - %c5 = constant 5 : index - %x2 = affine_apply (d0)[s0] -> (d0 + s0) (%i0)[%c5] - %y3 = affine_apply (d0, d1) -> (d0 + d1) (%x2, %x0) - // CHECK: [[I3:%[0-9]+]] = affine_apply [[MAP7]](%i0)[%c5, %c4] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I3]], [[I3]]{{\]}} - %v3 = load %0[%y3, %y3] : memref<4x4xf32> - } - return -} - -// CHECK-LABEL: func @compose_affine_maps_2d_tile() { -func @compose_affine_maps_2d_tile() { - %0 = alloc() : memref<16x32xf32> - %1 = alloc() : memref<16x32xf32> - - %c4 = constant 4 : index - %c8 = constant 8 : index - - for %i0 = 0 to 3 { - %x0 = affine_apply (d0)[s0] -> (d0 ceildiv s0) (%i0)[%c4] - for %i1 = 0 to 3 { - %x1 = affine_apply (d0)[s0] -> (d0 ceildiv s0) (%i1)[%c8] - for %i2 = 0 to 3 { - %x2 = affine_apply (d0)[s0] -> (d0 mod s0) (%i2)[%c4] - for %i3 = 0 to 3 { - %x3 = affine_apply (d0)[s0] -> (d0 mod s0) (%i3)[%c8] - - %x40 = affine_apply (d0, d1, d2, d3)[s0, s1] -> - ((d0 * s0) + d2) (%x0, %x1, %x2, %x3)[%c4, %c8] - %x41 = affine_apply (d0, d1, d2, d3)[s0, s1] -> - ((d1 * s1) + d3) (%x0, %x1, %x2, %x3)[%c4, %c8] - // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP8]](%i0, %i2)[%c4] - // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP8]](%i1, %i3)[%c8] - // CHECK-NEXT: [[L0:%[0-9]+]] = load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} - %v0 = load %0[%x40, %x41] : memref<16x32xf32> - - // CHECK-NEXT: store [[L0]], %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} - store %v0, %1[%x40, %x41] : memref<16x32xf32> - } - } - } - } - return -} - -// CHECK-LABEL: func @compose_affine_maps_dependent_loads() { -func @compose_affine_maps_dependent_loads() { - %0 = alloc() : memref<16x32xf32> - %1 = alloc() : memref<16x32xf32> - - for %i0 = 0 to 3 { - for %i1 = 0 to 3 { - for %i2 = 0 to 3 { - %c3 = constant 3 : index - %c7 = constant 7 : index - - %x00 = affine_apply (d0, d1, d2)[s0, s1] -> (d0 + s0) - (%i0, %i1, %i2)[%c3, %c7] - %x01 = affine_apply (d0, d1, d2)[s0, s1] -> (d1 - s1) - (%i0, %i1, %i2)[%c3, %c7] - %x02 = affine_apply (d0, d1, d2)[s0, s1] -> (d2 * s0) - (%i0, %i1, %i2)[%c3, %c7] - - // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP9]](%i0)[%c3] - // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP4]](%i1)[%c7] - // CHECK: [[I2:%[0-9]+]] = affine_apply [[MAP10]](%i2)[%c3] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} - %v0 = load %0[%x00, %x01] : memref<16x32xf32> - - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I2]]{{\]}} - %v1 = load %0[%x00, %x02] : memref<16x32xf32> - - // Swizzle %i0, %i1 - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]], [[I0]]{{\]}} - %v2 = load %0[%x01, %x00] : memref<16x32xf32> - - // Swizzle %x00, %x01 and %c3, %c7 - %x10 = affine_apply (d0, d1)[s0, s1] -> (d0 * s1) - (%x01, %x00)[%c7, %c3] - %x11 = affine_apply (d0, d1)[s0, s1] -> (d1 ceildiv s0) - (%x01, %x00)[%c7, %c3] - - // CHECK-NEXT: [[I2A:%[0-9]+]] = affine_apply [[MAP12A]](%i1)[%c3, %c7] - // CHECK-NEXT: [[I2B:%[0-9]+]] = affine_apply [[MAP12B]](%i0)[%c7, %c3] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I2A]], [[I2B]]{{\]}} - %v3 = load %0[%x10, %x11] : memref<16x32xf32> - } - } - } - return -} - -// CHECK-LABEL: func @compose_affine_maps_diamond_dependency() { -func @compose_affine_maps_diamond_dependency() { - %0 = alloc() : memref<4x4xf32> - - for %i0 = 0 to 15 { - %a = affine_apply (d0) -> (d0 - 1) 
(%i0) - %b = affine_apply (d0) -> (d0 + 7) (%a) - %c = affine_apply (d0) -> (d0 * 4) (%a) - %d0 = affine_apply (d0, d1) -> (d0 ceildiv 8) (%b, %c) - %d1 = affine_apply (d0, d1) -> (d1 floordiv 3) (%b, %c) - // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP13A]](%i0) - // CHECK: [[I1:%[0-9]+]] = affine_apply [[MAP13B]](%i0) - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], [[I1]]{{\]}} - %v = load %0[%d0, %d1] : memref<4x4xf32> - } - - return -} - -// CHECK-LABEL: func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { -func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { - %c9 = constant 9 : index - %1 = alloc() : memref<100x100xf32, 1> - %2 = alloc() : memref<1xi32> - for %i0 = 0 to 100 { - for %i1 = 0 to 100 { - %3 = affine_apply (d0, d1)[s0, s1] -> (d1 + s0 + s1) - (%i0, %i1)[%arg1, %c9] - %4 = affine_apply (d0, d1, d3) -> (d3 - (d0 + d1)) - (%arg1, %c9, %3) - // CHECK: [[I0:%[0-9]+]] = affine_apply [[MAP14]](%arg1, %c9, %i1)[%arg1, %c9] - // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], %arg1{{\]}} - %5 = load %1[%4, %arg1] : memref<100x100xf32, 1> - } - } - return -} - -// CHECK-LABEL: func @trivial_maps -func @trivial_maps() { - %0 = alloc() : memref<10xf32> - %c0 = constant 0 : index - %cst = constant 0.000000e+00 : f32 - for %i1 = 0 to 10 { - %1 = affine_apply ()[s0] -> (s0)()[%c0] - // CHECK: {{.*}} = affine_apply [[MAP15]]()[%c0] - store %cst, %0[%1] : memref<10xf32> - %2 = load %0[%c0] : memref<10xf32> - - %3 = affine_apply ()[] -> (0)()[] - // CHECK: {{.*}} = affine_apply [[MAP16]]() - store %cst, %0[%3] : memref<10xf32> - %4 = load %0[%c0] : memref<10xf32> - } - return -} -- cgit v1.2.3 From b9dde91ea6ec449ecd4203fb06fdf726001a0c37 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 6 Feb 2019 11:01:10 -0800 Subject: Adds the ability to compute the MemRefRegion of a sliced loop nest. Utilizes this feature during loop fusion cost computation, to compute what the write region of a fusion candidate loop nest slice would be (without having to materialize the slice or change the IR). *) Adds parameter to public API of MemRefRegion::compute for passing in the slice loop bounds to compute the memref region of the loop nest slice. *) Exposes public method MemRefRegion::getRegionSize for computing the size of the memref region in bytes. PiperOrigin-RevId: 232706165 --- mlir/include/mlir/Analysis/Utils.h | 97 +++++++++++++++------------ mlir/include/mlir/IR/AffineStructures.h | 11 +++ mlir/lib/Analysis/Utils.cpp | 114 +++++++++++++++++++++----------- mlir/lib/IR/AffineStructures.cpp | 60 +++++++++++++++++ mlir/lib/Transforms/LoopFusion.cpp | 38 ++++++++--- 5 files changed, 230 insertions(+), 90 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 6cd688045c5..d6a67f65c7c 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -61,6 +61,45 @@ void getLoopIVs(const Instruction &inst, /// surrounding this instruction. unsigned getNestingDepth(const Instruction &stmt); +/// ComputationSliceState aggregates loop bound AffineMaps and their associated +/// operands for a set of loops within a loop nest (typically the set of loops +/// surrounding a store operation). Loop bound AffineMaps which are non-null +/// represent slices of that loop's iteration space. +struct ComputationSliceState { + // List of lower bound AffineMaps. + SmallVector lbs; + // List of upper bound AffineMaps. 
+ SmallVector ubs; + // List of lower bound operands (lbOperands[i] are used by 'lbs[i]'). + std::vector> lbOperands; + // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). + std::vector> ubOperands; +}; + +/// Computes computation slice loop bounds for the loop nest surrounding +/// 'srcAccess', where the returned loop bound AffineMaps are functions of +/// loop IVs from the loop nest surrounding 'dstAccess'. +/// Returns true on success, false otherwise. +bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); + +/// Creates a clone of the computation contained in the loop nest surrounding +/// 'srcOpInst', slices the iteration space of src loop based on slice bounds +/// in 'sliceState', and inserts the computation slice at the beginning of the +/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding +/// 'dstOpInst'. Returns the top-level loop of the computation slice on +/// success, returns nullptr otherwise. +// Loop depth is a crucial optimization choice that determines where to +// materialize the results of the backward slice - presenting a trade-off b/w +// storage and redundant computation in several cases. +// TODO(andydavis) Support computation slices with common surrounding loops. +OpPointer +insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); + /// A region of a memref's data space; this is typically constructed by /// analyzing load/store op's on this memref and the index space of loops /// surrounding such op's. @@ -86,7 +125,17 @@ struct MemRefRegion { /// symbolic identifiers which could include any of the loop IVs surrounding /// opInst up until 'loopDepth' and another additional Function symbols /// involved with the access (for eg., those appear in affine_apply's, loop - /// bounds, etc.). + /// bounds, etc.). If 'sliceState' is non-null, operands from 'sliceState' + /// are added as symbols, and the following constraints are added to the + /// system: + /// *) Inequality constraints which represent loop bounds for 'sliceState' + /// operands which are loop IVS (these represent the destination loop IVs + /// of the slice, and are added as symbols to MemRefRegion's constraint + /// system). + /// *) Inequality constraints for the slice bounds in 'sliceState', which + /// represent the bounds on the loop IVs in this constraint system w.r.t + /// to slice operands (which correspond to symbols). + /// /// For example, the memref region for this operation at loopDepth = 1 will /// be: /// @@ -99,7 +148,8 @@ struct MemRefRegion { /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. /// - bool compute(Instruction *inst, unsigned loopDepth); + bool compute(Instruction *inst, unsigned loopDepth, + ComputationSliceState *sliceState = nullptr); FlatAffineConstraints *getConstraints() { return &cst; } const FlatAffineConstraints *getConstraints() const { return &cst; } @@ -128,6 +178,9 @@ struct MemRefRegion { return cst.getConstantBoundOnDimSize(pos, lb); } + /// Returns the size of this MemRefRegion in bytes. + Optional getRegionSize(); + bool unionBoundingBox(const MemRefRegion &other); /// Returns the rank of the memref that this region corresponds to. 
@@ -169,52 +222,12 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, unsigned getNumCommonSurroundingLoops(const Instruction &A, const Instruction &B); -/// ComputationSliceState aggregates loop bound AffineMaps and their associated -/// operands for a set of loops within a loop nest (typically the set of loops -/// surrounding a store operation). Loop bound AffineMaps which are non-null -/// represent slices of that loop's iteration space. -struct ComputationSliceState { - // List of lower bound AffineMaps. - SmallVector lbs; - // List of upper bound AffineMaps. - SmallVector ubs; - // List of lower bound operands (lbOperands[i] are used by 'lbs[i]'). - std::vector> lbOperands; - // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). - std::vector> ubOperands; -}; - -/// Computes computation slice loop bounds for the loop nest surrounding -/// 'srcAccess', where the returned loop bound AffineMaps are functions of -/// loop IVs from the loop nest surrounding 'dstAccess'. -/// Returns true on success, false otherwise. -bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); - -/// Creates a clone of the computation contained in the loop nest surrounding -/// 'srcOpInst', slices the iteration space of src loop based on slice bounds -/// in 'sliceState', and inserts the computation slice at the beginning of the -/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding -/// 'dstOpInst'. Returns the top-level loop of the computation slice on -/// success, returns nullptr otherwise. -// Loop depth is a crucial optimization choice that determines where to -// materialize the results of the backward slice - presenting a trade-off b/w -// storage and redundant computation in several cases. -// TODO(andydavis) Support computation slices with common surrounding loops. -OpPointer -insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); - /// Gets the memory footprint of all data touched in the specified memory space /// in bytes; if the memory space is unspecified, considers all memory spaces. Optional getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace = -1); Optional getMemoryFootprintBytes(const Block &block, int memorySpace = -1); - } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/include/mlir/IR/AffineStructures.h b/mlir/include/mlir/IR/AffineStructures.h index 9c88436dcb0..2acee9a7f39 100644 --- a/mlir/include/mlir/IR/AffineStructures.h +++ b/mlir/include/mlir/IR/AffineStructures.h @@ -378,6 +378,17 @@ public: SmallVectorImpl *lbMaps, SmallVectorImpl *ubMaps); + /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper + /// bounds in 'ubMaps' to the constraint system. Note that both lower/upper + /// bounds share the same operand list 'operands'. + /// This function assumes that position 'lbMaps.size' == 'ubMaps.size', + /// and that positions [0, lbMaps.size) represent dimensional identifiers + /// which correspond to the loop IVs whose iteration bounds are being sliced. + /// Note that both lower/upper bounds use operands from 'operands'. + /// Returns true on success, returns false for unimplemented cases. + bool addSliceBounds(ArrayRef lbMaps, ArrayRef ubMaps, + ArrayRef operands); + // Adds an inequality (>= 0) from the coefficients specified in inEq. 
void addInequality(ArrayRef inEq); // Adds an equality from the coefficients specified in eq. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2e753f8d10a..bdc5d19d0be 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -122,7 +122,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { +bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, + ComputationSliceState *sliceState) { assert((inst->isa() || inst->isa()) && "load/store op expected"); @@ -147,18 +148,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { access.getAccessMap(&accessValueMap); AffineMap accessMap = accessValueMap.getAffineMap(); + unsigned numDims = accessMap.getNumDims(); + unsigned numSymbols = accessMap.getNumSymbols(); + unsigned numOperands = accessValueMap.getNumOperands(); + // Merge operands with slice operands. + SmallVector operands; + operands.resize(numOperands); + for (unsigned i = 0; i < numOperands; ++i) + operands[i] = accessValueMap.getOperand(i); + + if (sliceState != nullptr) { + // Append slice operands to 'operands' as symbols. + operands.append(sliceState->lbOperands[0].begin(), + sliceState->lbOperands[0].end()); + // Update 'numSymbols' by operands from 'sliceState'. + numSymbols += sliceState->lbOperands[0].size(); + } + // We'll first associate the dims and symbols of the access map to the dims // and symbols resp. of cst. This will change below once cst is // fully constructed out. - cst.reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0, - accessValueMap.getOperands()); + cst.reset(numDims, numSymbols, 0, operands); // Add equality constraints. - unsigned numDims = accessMap.getNumDims(); - unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { + auto *operand = operands[i]; + if (auto loop = getForInductionVarOwner(operand)) { // Note that cst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way @@ -167,7 +183,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { return false; } else { // Has to be a valid symbol. - auto *symbol = accessValueMap.getOperand(i); + auto *symbol = operand; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *inst = symbol->getDefiningInst()) { @@ -178,6 +194,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { } } + // Add lower/upper bounds on loop IVs using bounds from 'sliceState'. + if (sliceState != nullptr) { + // Add dim and symbol slice operands. + for (const auto &operand : sliceState->lbOperands[0]) { + unsigned loc; + if (!cst.findId(*operand, &loc)) { + if (isValidSymbol(operand)) { + cst.addSymbolId(cst.getNumSymbolIds(), const_cast(operand)); + loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1; + // Check if the symbol is a constant. 
+ if (auto *opInst = operand->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { + cst.setIdToConstant(*operand, constOp->getValue()); + } + } + } else { + cst.addDimId(cst.getNumDimIds(), const_cast(operand)); + loc = cst.getNumDimIds() - 1; + } + } + } + // Add upper/lower bounds from 'sliceState' to 'cst'. + if (!cst.addSliceBounds(sliceState->lbs, sliceState->ubs, + sliceState->lbOperands[0])) + return false; + } + // Add access function equalities to connect loop IVs to data dimensions. if (!cst.composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); @@ -233,6 +276,32 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { return llvm::divideCeil(sizeInBits, 8); } +// Returns the size of the region. +Optional MemRefRegion::getRegionSize() { + auto memRefType = memref->getType().cast(); + + auto layoutMaps = memRefType.getAffineMaps(); + if (layoutMaps.size() > 1 || + (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + return false; + } + + // Indices to use for the DmaStart op. + // Indices for the original memref being DMAed from/to. + SmallVector memIndices; + // Indices for the faster buffer being DMAed into/from. + SmallVector bufIndices; + + // Compute the extents of the buffer. + Optional numElements = getConstantBoundingSizeAndShape(); + if (!numElements.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + return None; + } + return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); +} + /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. If the element of the memref has vector type, takes into account /// size of the vector as well. @@ -420,8 +489,6 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // entire destination index set. Subtract out the dependent destination // iterations from destination index set and check for emptiness --- this is one // solution. -// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project -// out loop IVs we don't care about and produce smaller slice. OpPointer mlir::insertBackwardComputationSlice( Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { @@ -537,33 +604,6 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, return numCommonLoops; } -// Returns the size of the region. -static Optional getRegionSize(const MemRefRegion ®ion) { - auto *memref = region.memref; - auto memRefType = memref->getType().cast(); - - auto layoutMaps = memRefType.getAffineMaps(); - if (layoutMaps.size() > 1 || - (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); - return false; - } - - // Indices to use for the DmaStart op. - // Indices for the original memref being DMAed from/to. - SmallVector memIndices; - // Indices for the faster buffer being DMAed into/from. - SmallVector bufIndices; - - // Compute the extents of the buffer. 
- Optional numElements = region.getConstantBoundingSizeAndShape(); - if (!numElements.hasValue()) { - LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); - return None; - } - return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); -} - Optional mlir::getMemoryFootprintBytes(ConstOpPointer forOp, int memorySpace) { @@ -601,7 +641,7 @@ Optional mlir::getMemoryFootprintBytes(const Block &block, int64_t totalSizeInBytes = 0; for (const auto ®ion : regions) { - auto size = getRegionSize(*region); + auto size = region->getRegionSize(); if (!size.hasValue()) return None; totalSizeInBytes += size.getValue(); diff --git a/mlir/lib/IR/AffineStructures.cpp b/mlir/lib/IR/AffineStructures.cpp index 1306199dc79..d92b1b60182 100644 --- a/mlir/lib/IR/AffineStructures.cpp +++ b/mlir/lib/IR/AffineStructures.cpp @@ -1129,6 +1129,66 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, } } +// Adds slice lower/upper bounds from 'lbMaps'/'upMaps' to the constraint +// system. This function assumes that position 'lbMaps.size' == 'ubMaps.size', +// and that positions [0, lbMaps.size) represent dimensional identifiers which +// correspond to the loop IVs whose iteration bounds are being sliced. +// Note that both lower/upper bounds use operands from 'operands'. +// Returns true on success. Returns false for unimplemented cases such as +// semi-affine expressions or expressions with mod/floordiv. +bool FlatAffineConstraints::addSliceBounds(ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands) { + assert(lbMaps.size() == ubMaps.size()); + // Record positions of the operands in the constraint system. + SmallVector positions; + for (const auto &operand : operands) { + unsigned loc; + if (!findId(*operand, &loc)) + assert(0 && "expected to be found"); + positions.push_back(loc); + } + + auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap, + bool lower) -> bool { + FlatAffineConstraints localVarCst; + std::vector> flatExprs; + if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { + LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); + return false; + } + if (localVarCst.getNumLocalIds() > 0) { + LLVM_DEBUG(llvm::dbgs() + << "loop bounds with mod/floordiv expr's not yet supported\n"); + return false; + } + + for (const auto &flatExpr : flatExprs) { + SmallVector ineq(getNumCols(), 0); + ineq[pos] = lower ? 1 : -1; + for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { + ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; + } + // Constant term. + ineq[getNumCols() - 1] = + lower ? -flatExpr[flatExpr.size() - 1] + // Upper bound in flattenedExpr is an exclusive one. + : flatExpr[flatExpr.size() - 1] - 1; + addInequality(ineq); + } + return true; + }; + + for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { + if (!addLowerOrUpperBound(i, lbMaps[i], /*lower=*/true)) + return false; + if (!addLowerOrUpperBound(i, ubMaps[i], /*lower=*/false)) + return false; + } + + return true; +} + void FlatAffineConstraints::addEquality(ArrayRef eq) { assert(eq.size() == getNumCols()); unsigned offset = equalities.size(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 3e0bf046ddf..8d5f51059bf 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1118,12 +1118,23 @@ static bool isFusionProfitable(Instruction *srcOpInst, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); + // Compute src loop nest write region size. 
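A minimal standalone sketch of the inequality row built per flattened bound expression in the `addSliceBounds` hunk above: +1 or -1 on the sliced IV's column, negated or plain coefficients on the operand columns, and a constant term that makes the upper bound exclusive. The function name and the flat column layout are assumptions for illustration, not the MLIR API.

```c++
#include <cstdint>
#include <vector>

// Standalone sketch of the inequality row addSliceBounds() builds for one
// flattened bound expression. Columns are [identifiers..., constant]; 'pos' is
// the sliced IV's column, 'operandPos' gives the column of each bound operand,
// 'coeffs' the matching flattened coefficients, and 'constTerm' the bound's
// constant term. Each row encodes "expr >= 0", as addInequality expects.
std::vector<int64_t> makeBoundRow(unsigned numCols, unsigned pos,
                                  const std::vector<unsigned> &operandPos,
                                  const std::vector<int64_t> &coeffs,
                                  int64_t constTerm, bool lower) {
  std::vector<int64_t> row(numCols, 0);
  // Lower bound "iv >= bound":  iv - bound >= 0.
  // Upper bound "iv < bound":   bound - 1 - iv >= 0 (the bound is exclusive).
  row[pos] = lower ? 1 : -1;
  for (unsigned j = 0, e = coeffs.size(); j < e; ++j)
    row[operandPos[j]] = lower ? -coeffs[j] : coeffs[j];
  row[numCols - 1] = lower ? -constTerm : constTerm - 1;
  return row;
}
```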
+ MemRefRegion srcWriteRegion(srcOpInst->getLoc()); + srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0); + Optional maybeSrcWriteRegionSizeBytes = + srcWriteRegion.getRegionSize(); + if (!maybeSrcWriteRegionSizeBytes.hasValue()) + return false; + int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); + // Compute op instance count for the src loop nest. uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); + // Evaluate all depth choices for materializing the slice in the destination + // loop nest. llvm::SmallDenseMap sliceTripCountMap; DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { @@ -1187,11 +1198,21 @@ static bool isFusionProfitable(Instruction *srcOpInst, (static_cast(srcLoopNestCost) + dstLoopNestCost) - 1; - // TODO(bondhugula): This is an ugly approximation. Fix this by finding a - // good way to calculate the footprint of the memref in the slice and - // divide it by the total memory footprint of the fused computation. - double storageReduction = - static_cast(srcLoopNestCost) / sliceIterationCount; + // Compute what the slice write MemRefRegion would be, if the src loop + // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop + // nest at loop depth 'i' + MemRefRegion sliceWriteRegion(srcOpInst->getLoc()); + sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]); + Optional maybeSliceWriteRegionSizeBytes = + sliceWriteRegion.getRegionSize(); + if (!maybeSliceWriteRegionSizeBytes.hasValue() || + maybeSliceWriteRegionSizeBytes.getValue() == 0) + continue; + int64_t sliceWriteRegionSizeBytes = + maybeSliceWriteRegionSizeBytes.getValue(); + + double storageReduction = static_cast(srcWriteRegionSizeBytes) / + static_cast(sliceWriteRegionSizeBytes); LLVM_DEBUG({ std::stringstream msg; @@ -1219,12 +1240,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, maxStorageReduction = storageReduction; bestDstLoopDepth = i; minFusedLoopNestComputeCost = fusedLoopNestComputeCost; - // TODO(bondhugula,andydavis): find a good way to compute the memory - // footprint of the materialized slice. - // Approximating this to the compute cost of the slice. This could be an - // under-approximation or an overapproximation, but in many cases - // accurate. - sliceMemEstimate = sliceIterationCount; + sliceMemEstimate = sliceWriteRegionSizeBytes; } } -- cgit v1.2.3 From 90d10b4e00cc6397a03ddc981b7be8bab43a9f38 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 6 Feb 2019 11:58:03 -0800 Subject: NFC: Rename the 'for' operation in the AffineOps dialect to 'affine.for'. The is the second step to adding a namespace to the AffineOps dialect. 
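A minimal standalone sketch of the storage-reduction estimate introduced in the `isFusionProfitable` hunk of the previous patch: the source nest's write-region footprint divided by the footprint of the slice materialized at the candidate depth, skipping depths whose slice footprint is unknown or zero. The names below are illustrative.

```c++
#include <cstdint>
#include <optional>

// Standalone sketch of the storage-reduction estimate from the
// isFusionProfitable hunk of the previous patch: the source loop nest's
// write-region footprint over the footprint of the slice that would be
// materialized at a candidate destination depth. Names are illustrative.
std::optional<double>
storageReduction(int64_t srcWriteRegionSizeBytes,
                 std::optional<int64_t> sliceWriteRegionSizeBytes) {
  // As in the pass, skip depths whose slice footprint is unknown or zero.
  if (!sliceWriteRegionSizeBytes || *sliceWriteRegionSizeBytes == 0)
    return std::nullopt;
  return static_cast<double>(srcWriteRegionSizeBytes) /
         static_cast<double>(*sliceWriteRegionSizeBytes);
}
```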
PiperOrigin-RevId: 232717775 --- mlir/g3doc/Dialects/Affine.md | 26 +- mlir/g3doc/Dialects/SuperVector.md | 28 +- mlir/g3doc/LangRef.md | 16 +- mlir/g3doc/Passes.md | 4 +- mlir/g3doc/Rationale.md | 54 +-- mlir/g3doc/RationaleSimplifiedPolyhedralForm.md | 14 +- mlir/include/mlir/AffineOps/AffineOps.h | 34 +- mlir/include/mlir/Analysis/Utils.h | 10 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 16 +- mlir/include/mlir/Transforms/LoopUtils.h | 7 +- mlir/include/mlir/Transforms/Utils.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 2 +- mlir/lib/Analysis/AffineAnalysis.cpp | 8 +- mlir/lib/Analysis/Utils.cpp | 12 +- mlir/lib/IR/Block.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 19 +- mlir/lib/Transforms/LoopFusion.cpp | 3 +- mlir/lib/Transforms/LoopTiling.cpp | 3 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 8 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 10 +- mlir/lib/Transforms/MaterializeVectors.cpp | 24 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 8 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 18 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 14 +- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 50 +- mlir/test/AffineOps/canonicalize.mlir | 38 +- mlir/test/IR/invalid.mlir | 60 +-- mlir/test/IR/locations.mlir | 2 +- mlir/test/IR/parser.mlir | 76 +-- mlir/test/IR/pretty-locations.mlir | 2 +- .../Vectorize/lower_vector_transfers.mlir | 58 +-- mlir/test/Transforms/Vectorize/materialize.mlir | 16 +- .../Vectorize/materialize_vectors_1d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_2d.mlir | 24 +- mlir/test/Transforms/Vectorize/normalize_maps.mlir | 24 +- mlir/test/Transforms/Vectorize/vectorize_1d.mlir | 62 +-- mlir/test/Transforms/Vectorize/vectorize_2d.mlir | 30 +- mlir/test/Transforms/Vectorize/vectorize_3d.mlir | 20 +- .../Vectorize/vectorize_outer_loop_2d.mlir | 18 +- .../vectorize_outer_loop_transpose_2d.mlir | 42 +- .../Vectorize/vectorize_transpose_2d.mlir | 42 +- mlir/test/Transforms/canonicalize.mlir | 12 +- mlir/test/Transforms/constant-fold.mlir | 4 +- mlir/test/Transforms/cse.mlir | 8 +- mlir/test/Transforms/dma-generate.mlir | 94 ++-- mlir/test/Transforms/loop-fusion.mlir | 516 ++++++++++----------- mlir/test/Transforms/loop-tiling.mlir | 36 +- mlir/test/Transforms/lower-affine.mlir | 28 +- mlir/test/Transforms/memref-bound-check.mlir | 32 +- mlir/test/Transforms/memref-dataflow-opt.mlir | 62 +-- mlir/test/Transforms/memref-dependence-check.mlir | 86 ++-- mlir/test/Transforms/pipeline-data-transfer.mlir | 50 +- .../Transforms/simplify-affine-structures.mlir | 38 +- mlir/test/Transforms/strip-debuginfo.mlir | 2 +- mlir/test/Transforms/unroll-jam.mlir | 30 +- mlir/test/Transforms/unroll.mlir | 136 +++--- mlir/utils/emacs/mlir-mode.el | 2 +- mlir/utils/vim/mlir.vim | 4 +- 62 files changed, 1055 insertions(+), 1049 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/Dialects/Affine.md b/mlir/g3doc/Dialects/Affine.md index 6af24668d99..59295cf7489 100644 --- a/mlir/g3doc/Dialects/Affine.md +++ b/mlir/g3doc/Dialects/Affine.md @@ -15,7 +15,7 @@ loops and if instructions), the result of a [`affine.apply` operation](#'affine.apply'-operation) that recursively takes as arguments any symbolic identifiers. 
Dimensions may be bound not only to anything that a symbol is bound to, but also to induction variables of enclosing -[for instructions](#'for'-operation), and the result of an +['affine.for' operations](#'affine.for'-operation), and the result of an [`affine.apply` operation](#'affine.apply'-operation) (which recursively may use other dimensions and symbols). @@ -47,12 +47,12 @@ Example: %2 = affine.apply (i)[s0] -> (i+s0) (%42)[%n] ``` -#### 'for' operation {#'for'-operation} +#### 'affine.for' operation {#'affine.for'-operation} Syntax: ``` {.ebnf} -operation ::= `for` ssa-id `=` lower-bound `to` upper-bound +operation ::= `affine.for` ssa-id `=` lower-bound `to` upper-bound (`step` integer-literal)? `{` inst* `}` lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound @@ -60,17 +60,17 @@ upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal ``` -The `for` operation represents an affine loop nest, defining an SSA value for -its induction variable. This SSA value always has type +The `affine.for` operation represents an affine loop nest, defining an SSA value +for its induction variable. This SSA value always has type [`index`](LangRef.md#index-type), which is the size of the machine word. -The `for` operation executes its body a number of times iterating from a lower -bound to an upper bound by a stride. The stride, represented by `step`, is a -positive constant integer which defaults to "1" if not present. The lower and +The `affine.for` operation executes its body a number of times iterating from a +lower bound to an upper bound by a stride. The stride, represented by `step`, is +a positive constant integer which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. -The lower and upper bounds of a `for` operation are represented as an +The lower and upper bounds of a `affine.for` operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA values as for all bindings of SSA values to dimensions and symbols. 
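As a concrete reading of the half-open-range semantics above, a small standalone sketch, assuming constant bounds and a positive step, of the loop's trip count; the helper is hypothetical, not an MLIR API.

```c++
#include <cstdint>

// Standalone sketch, assuming constant bounds, of the iteration count implied
// by the semantics above: 'affine.for' walks the half-open range [lb, ub) with
// a positive step (default 1), so the trip count is ceildiv(ub - lb, step),
// clamped at zero. The helper is hypothetical, not an MLIR API.
int64_t affineForTripCount(int64_t lb, int64_t ub, int64_t step = 1) {
  if (ub <= lb)
    return 0;                          // empty half-open range
  return (ub - lb + step - 1) / step;  // ceildiv for a positive step
}

// e.g. affineForTripCount(0, 10) == 10 and affineForTripCount(0, 10, 3) == 4
// (iterations at 0, 3, 6, 9).
```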
@@ -94,8 +94,8 @@ Example showing reverse iteration of the inner loop: func @simple_example(%A: memref, %B: memref) { %N = dim %A, 0 : memref - for %i = 0 to %N step 1 { - for %j = 0 to %N { // implicitly steps by 1 + affine.for %i = 0 to %N step 1 { + affine.for %j = 0 to %N { // implicitly steps by 1 %0 = affine.apply #map57(%j)[%N] %tmp = call @F1(%A, %i, %0) : (memref, index, index)->(f32) call @F2(%tmp, %B, %i, %0) : (f32, memref, index, index)->() @@ -130,8 +130,8 @@ Example: #set = (d0, d1)[s0]: (d0 - 10 >= 0, s0 - d0 - 9 >= 0, d1 - 10 >= 0, s0 - d1 - 9 >= 0) func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { - for %i = 0 to %N { - for %j = 0 to %N { + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { %0 = affine.apply #map42(%j) %tmp = call @S1(%X, %i, %0) if #set(%i, %j)[%N] { diff --git a/mlir/g3doc/Dialects/SuperVector.md b/mlir/g3doc/Dialects/SuperVector.md index 09beb950e37..cd540335a52 100644 --- a/mlir/g3doc/Dialects/SuperVector.md +++ b/mlir/g3doc/Dialects/SuperVector.md @@ -22,9 +22,9 @@ Examples: // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> and // pad with %f0 to handle the boundary case: %f0 = constant 0.0f : f32 -for %i0 = 0 to %0 { - for %i1 = 0 to %1 step 256 { - for %i2 = 0 to %2 step 32 { +affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 step 256 { + affine.for %i2 = 0 to %2 step 32 { %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 {permutation_map: (d0, d1, d2) -> (d2, d1)} : (memref, index, index, f32) -> vector<32x256xf32> @@ -33,8 +33,8 @@ for %i0 = 0 to %0 { // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into // vector<128xf32>. The underlying implementation will require a 1-D vector // broadcast: -for %i0 = 0 to %0 { - for %i1 = 0 to %1 { +affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { %3 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (0)} : (memref, index, index) -> vector<128xf32> @@ -80,9 +80,9 @@ A notional lowering of vector_transfer_read could generate code resembling: // %expr1, %expr2, %expr3, %expr4 defined before this point %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> -for %i = 0 to 3 { - for %j = 0 to 4 { - for %k = 0 to 5 { +affine.for %i = 0 to 3 { + affine.for %j = 0 to 4 { + affine.for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, %j, %k] : vector<3x4x5xf32> }}} @@ -101,8 +101,8 @@ lowered code would resemble: // %expr1, %expr2, %expr3, %expr4 defined before this point %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> -for %i = 0 to 3 { - for %k = 0 to 5 { +affine.for %i = 0 to 3 { + affine.for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, 0, %k] : vector<3x4x5xf32> }} @@ -129,10 +129,10 @@ Examples: ```mlir {.mlir} // write vector<16x32x64xf32> into the slice `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`: -for %i0 = 0 to %0 { - for %i1 = 0 to %1 step 32 { - for %i2 = 0 to %2 step 64 { - for %i3 = 0 to %3 step 16 { +affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 step 32 { + affine.for %i2 = 0 to %2 step 64 { + affine.for %i3 = 0 to %3 step 16 { %val = `ssa-value` : vector<16x32x64xf32> vector_transfer_write %val, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 1858b7515eb..74fe885e4e6 100644 --- a/mlir/g3doc/LangRef.md +++ 
b/mlir/g3doc/LangRef.md @@ -40,7 +40,7 @@ which means that values are defined before use and have scope defined by their dominance relations. Operations may produce zero or more results, and each is a distinct SSA value with its own type defined by the [type system](#type-system). -MLIR incorporates polyhedral compiler concepts, including `for` and `if` +MLIR incorporates polyhedral compiler concepts, including `affine.for` and `if` operations defined by the [affine dialect](Dialects/Affine.md), which model affine loops and affine conditionals. It also includes affine maps integrated into the type system - they are key to the representation of data and @@ -99,10 +99,10 @@ func @multiply(%A: memref<100x?xf32>, %B: memref) %C = alloc memref<100x50xf32>() // Multiplication loop nest. - for %i = 0 to 100 { - for %j = 0 to 50 { + affine.for %i = 0 to 100 { + affine.for %j = 0 to 50 { store 0 to %C[%i, %j] : memref<100x50xf32> - for %k = 0 to %n { + affine.for %k = 0 to %n { %a_v = load %A[%i, %k] : memref<100x?xf32> %b_v = load %B[%k, %j] : memref %prod = mulf %a_v, %b_v : f32 @@ -1434,8 +1434,8 @@ The arity of indices is the rank of the memref (i.e., if the memref loaded from is of rank 3, then 3 indices are required for the load following the memref identifier). -In an `if` or `for` body, the indices of a load are restricted to SSA values -bound to surrounding loop induction variables, +In an `if` or `affine.for` body, the indices of a load are restricted to SSA +values bound to surrounding loop induction variables, [symbols](#dimensions-and-symbols), results of a [`constant` operation](#'constant'-operation), or the result of an `affine.apply` operation that can in turn take as arguments all of the @@ -1456,7 +1456,7 @@ Example: **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in affine `if` and -`for` instructions) the compiler can follow use-def chains (e.g. through +`affine.for` instructions) the compiler can follow use-def chains (e.g. through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to precisely analyze references at compile-time using polyhedral techniques. This is possible because of the @@ -1492,7 +1492,7 @@ store %100, %A[%1, 1023] : memref<4x?xf32, #layout, hbm> **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in polyhedral `if` and -`for` instructions) the compiler can follow use-def chains (e.g. through +`affine.for` instructions) the compiler can follow use-def chains (e.g. through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to precisely analyze references at compile-time using polyhedral techniques. This is possible because of the diff --git a/mlir/g3doc/Passes.md b/mlir/g3doc/Passes.md index 18f6b7b1c23..894d8b5fd2b 100644 --- a/mlir/g3doc/Passes.md +++ b/mlir/g3doc/Passes.md @@ -39,8 +39,8 @@ These restrictions may be lifted in the future. ### Output IR -Functions with `for` and `if` instructions eliminated. These functions may -contain operations from the Standard dialect in addition to those already +Functions with `affine.for` and `if` instructions eliminated. These functions +may contain operations from the Standard dialect in addition to those already present before the pass. 
### Invariants diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 58ad1670d15..bafd029ce0f 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -150,8 +150,8 @@ func bar(%A : memref<8x?xf32, #lmap>) { // dynamically using dim instruction. %N = dim %A, 1 : memref<8x?xf32, #lmap> - for %i = 0 to 8 { - for %j = 0 to %N { + affine.for %i = 0 to 8 { + affine.for %j = 0 to %N { // A[i,j] += 1 %s1 = load %A [%i, %j] : memref<8x?xf32, #lmap> %s2 = add %s1, 1 @@ -534,7 +534,7 @@ nested in an outer function that using affine loops. func @search(memref %S, i32 %key) { %ni = dim %A, 0 : memref // This loop can be parallelized - for %i = 0 to %ni { + affine.for %i = 0 to %ni { call @search_body (%A, %S, %i) : (memref, memref, i32) } return @@ -568,10 +568,10 @@ func @search_body(%A: memref, %S: memref, %key: i32) { As per the [MLIR spec](LangRef.md), the restrictions on dimensions and symbol identifiers to be used with the affine.apply instruction only apply to accesses -inside `for` and `if` instructions. However, an analysis of accesses inside the -called function (`@search_body`) is necessary to determine if the `%i` loop -could be parallelized: such function access analysis is calling context -sensitive. +inside `affine.for` and `if` instructions. However, an analysis of accesses +inside the called function (`@search_body`) is necessary to determine if the +`%i` loop could be parallelized: such function access analysis is calling +context sensitive. ### Non-affine loop bounds {#non-affine-loop-bounds} @@ -590,8 +590,8 @@ for (i=0; i i32 { - for %k = 0 to %m { - for %l = 0 to %n { + affine.for %k = 0 to %m { + affine.for %l = 0 to %n { ... } } @@ -649,13 +649,13 @@ in a dilated convolution. func @conv2d(memref<16x1024x1024x3xf32, #lm0, vmem> %input, memref<5x5x3x32xf32, #lm0, vmem> %kernel, memref<16x512x512x32xf32, #lm0, vmem> %output) { - for %b = 0 to %batch { - for %oh = 0 to %output_height { - for %ow = 0 to %output_width { - for %of = 0 to %output_feature { - for %kh = 0 to %kernel_height { - for %kw = 0 to %kernel_width { - for %if = 0 to %input_feature { + affine.for %b = 0 to %batch { + affine.for %oh = 0 to %output_height { + affine.for %ow = 0 to %output_width { + affine.for %of = 0 to %output_feature { + affine.for %kh = 0 to %kernel_height { + affine.for %kw = 0 to %kernel_width { + affine.for %if = 0 to %input_feature { // Calculate input indices. %1_0 = affine.apply #map1_0 (%0#1, %0#2, %0#4, %0#5) [%h_stride, %w_stride, %h_kernel_dilation, %w_kernel_dilation, @@ -899,14 +899,14 @@ func @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a, representation. 2(b) requires no change, but impacts how cost models look at index and layout maps. -### `if` and `for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} +### `if` and `affine.for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} We considered providing a representation for SSA values that are live out of -`if/else` conditional bodies and loop carried in `for` loops. We ultimately -abandoned this approach due to its complexity. In the current design of MLIR, -scalar variables cannot escape for loops or if instructions. In situations, -where escaping is necessary, we use zero-dimensional tensors and memrefs instead -of scalars. +`if/else` conditional bodies and loop carried in `affine.for` loops. We +ultimately abandoned this approach due to its complexity. In the current design +of MLIR, scalar variables cannot escape for loops or if instructions. 
In +situations, where escaping is necessary, we use zero-dimensional tensors and +memrefs instead of scalars. **TODO**: This whole section is obsolete and should be updated to use block arguments and a yield like terminator in for/if instructions. @@ -919,7 +919,7 @@ Syntax: ``` {.ebnf} [ =] -for % = ... step +affine.for % = ... step [with ] { } ``` @@ -934,7 +934,7 @@ Example: // Return sum of elements in 1-dimensional mref A func int32 @sum(%A : memref, %N : i32) -> (i32) { %init = 0 - %result = for %i = 0 to N with %tmp(%init) { + %result = affine.for %i = 0 to N with %tmp(%init) { %value = load %A[%i] %sum = %value + %tmp yield %sum @@ -964,7 +964,7 @@ Example: // Compute sum of half of the array func int32 @sum_half(%A, %N) { %s0 = 0 - %s1 = for %i = 1 ... N step 1 with %s2 (%s0) { + %s1 = affine.for %i = 1 ... N step 1 with %s2 (%s0) { %s3 = if (%i >= %N / 2) { %v0 = load %A[%i] %s4 = %s2 + %v0 diff --git a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md index f51eff45633..b40f6708d0d 100644 --- a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md +++ b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md @@ -184,8 +184,8 @@ Our simple example above would be represented as: ```mlir mlfunc @simple_example(... %N) { - for %i = 0 ... %N step 1 { - for %j = 0 ... %N step 1 { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affine.apply #57(%i, %j) @@ -203,8 +203,8 @@ The example with the reduced domain would be represented with an if instruction: ```mlir mlfunc @reduced_domain_example(... %N) { - for %i = 0 ... %N step 1 { - for %j = 0 ... %N step 1 { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affinecall #57(%i, %j) @@ -233,8 +233,8 @@ that transformations call into): ```mlir mlfunc @skewed_domain_example(... %N) { - for %t1 = 0 ... 2*N-2 step 1 { - for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { + affine.for %t1 = 0 ... 2*N-2 step 1 { + affine.for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { (%i, %j) = (%t1-%t2, %t2) ... } @@ -373,7 +373,7 @@ mlfunc's (if we support them) will also have to have domains. ### Lack of redundancy in IR The traditional form has multiple encodings for the same sorts of behavior: you -end up having bits on `for` loops to specify whether codegen should use +end up having bits on `affine.for` loops to specify whether codegen should use "atomic/separate" policies, unroll loops, etc. Instructions can be split or can generate multiple copies of their instruction because of overlapping domains, etc. diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 0ae43426db0..c448b795579 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -90,15 +90,15 @@ private: explicit AffineApplyOp(const Instruction *state) : Op(state) {} }; -/// The "for" instruction represents an affine loop nest, defining an SSA value -/// for its induction variable. The induction variable is represented as a +/// The "affine.for" instruction represents an affine loop nest, defining an SSA +/// value for its induction variable. The induction variable is represented as a /// BlockArgument to the entry block of the body. The body and induction -/// variable can be created automatically for new "for" ops with 'createBody'. 
-/// This SSA value always has type index, which is the size of the machine word. -/// The stride, represented by step, is a positive constant integer which -/// defaults to "1" if not present. The lower and upper bounds specify a -/// half-open range: the range includes the lower bound but does not include the -/// upper bound. +/// variable can be created automatically for new "affine.for" ops with +/// 'createBody'. This SSA value always has type index, which is the size of the +/// machine word. The stride, represented by step, is a positive constant +/// integer which defaults to "1" if not present. The lower and upper bounds +/// specify a half-open range: the range includes the lower bound but does not +/// include the upper bound. /// /// The lower and upper bounds of a for operation are represented as an /// application of an affine mapping to a list of SSA values passed to the map. @@ -110,7 +110,7 @@ private: /// /// Example: /// -/// for %i = 1 to 10 { +/// affine.for %i = 1 to 10 { /// ... /// } /// @@ -131,7 +131,7 @@ public: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - static StringRef getOperationName() { return "for"; } + static StringRef getOperationName() { return "affine.for"; } static StringRef getStepAttrName() { return "step"; } static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getUpperBoundAttrName() { return "upper_bound"; } @@ -253,15 +253,15 @@ ConstOpPointer getForInductionVarOwner(const Value *val); void extractForInductionVars(ArrayRef> forInsts, SmallVectorImpl *ivs); -/// Adds constraints (lower and upper bounds) for the specified 'for' +/// Adds constraints (lower and upper bounds) for the specified 'affine.for' /// instruction's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the /// information is successfully added. Asserts if the Value corresponding to -/// the 'for' instruction isn't found in the constraint system. Any new -/// identifiers that are found in the bound operands of the 'for' instruction -/// are added as trailing identifiers (either dimensional or symbolic -/// depending on whether the operand is a valid ML Function symbol). +/// the 'affine.for' instruction isn't found in the constraint system. Any new +/// identifiers that are found in the bound operands of the 'affine.for' +/// instruction are added as trailing identifiers (either dimensional or +/// symbolic depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. bool addAffineForOpDomain(ConstOpPointer forOp, FlatAffineConstraints *constraints); @@ -297,10 +297,10 @@ public: operand_range getOperands() const { return {operand_begin(), operand_end()}; } private: - // 'for' instruction that contains this bound. + // 'affine.for' instruction that contains this bound. ConstOpPointer inst; // Start and end positions of this affine bound operands in the list of - // the containing 'for' instruction operands. + // the containing 'affine.for' instruction operands. unsigned opStart, opEnd; // Affine map for this bound. 
AffineMap map; diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 3371daf99fd..81b896c0102 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -52,7 +52,7 @@ bool dominates(const Instruction &a, const Instruction &b); bool properlyDominates(const Instruction &a, const Instruction &b); /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'for' instruction to the innermost one. +/// the outermost 'affine.for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. void getLoopIVs(const Instruction &inst, SmallVectorImpl> *loops); @@ -105,8 +105,8 @@ insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, /// surrounding such op's. // For example, the memref region for a load operation at loop depth = 1: // -// for %i = 0 to 32 { -// for %ii = %i to (d0) -> (d0 + 8) (%i) { +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -139,8 +139,8 @@ struct MemRefRegion { /// For example, the memref region for this operation at loopDepth = 1 will /// be: /// - /// for %i = 0 to 32 { - /// for %ii = %i to (d0) -> (d0 + 8) (%i) { + /// affine.for %i = 0 to 32 { + /// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { /// load %A[%ii] /// } /// } diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 4982481bf6c..b3196e14097 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -76,9 +76,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// The following MLIR snippet: /// /// ```mlir -/// for %i3 = 0 to %0 { -/// for %i4 = 0 to %1 { -/// for %i5 = 0 to %2 { +/// affine.for %i3 = 0 to %0 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 { /// %a5 = load %arg0[%i4, %i5, %i3] : memref /// }}} /// ``` @@ -86,9 +86,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into: /// /// ```mlir -/// for %i3 = 0 to %0 step 32 { -/// for %i4 = 0 to %1 { -/// for %i5 = 0 to %2 step 256 { +/// affine.for %i3 = 0 to %0 step 32 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 step 256 { /// %4 = vector_transfer_read %arg0, %i4, %i5, %i3 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index) -> vector<32x256xf32> @@ -103,7 +103,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// /// ```mlir /// %cst0 = constant 0 : index -/// for %i0 = 0 to %0 { +/// affine.for %i0 = 0 to %0 { /// %a0 = load %arg0[%cst0, %cst0] : memref /// } /// ``` @@ -111,7 +111,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0) -> (0)} into: /// /// ```mlir -/// for %i0 = 0 to %0 step 128 { +/// affine.for %i0 = 0 to %0 step 128 { /// %3 = vector_transfer_read %arg0, %c0_0, %c0_0 /// {permutation_map: (d0, d1) -> (0)} : /// (memref, index, index) -> vector<128xf32> diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index f3d9b9fe9fd..d543b520565 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -83,9 +83,10 @@ AffineMap getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); -/// Skew the instructions in the body of a 'for' instruction with the specified 
-/// instruction-wise shifts. The shifts are with respect to the original -/// execution order, and are multiplied by the loop 'step' before being applied. +/// Skew the instructions in the body of a 'affine.for' instruction with the +/// specified instruction-wise shifts. The shifts are with respect to the +/// original execution order, and are multiplied by the loop 'step' before being +/// applied. UtilResult instBodySkew(OpPointer forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 3b828db6ae9..eb7f725576a 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -94,14 +94,14 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// /// Before /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %v = "compute"(%idx, ...) /// /// After /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 249b09f41cd..be5a2f14628 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -716,7 +716,7 @@ static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { } void AffineForOp::print(OpAsmPrinter *p) const { - *p << "for "; + *p << "affine.for "; p->printOperand(getBody()->getArgument(0)); *p << " = "; printBound(getLowerBound(), "max", p); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 9d2ea691bdd..3a086ba512d 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -756,8 +756,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // For example, given the following MLIR code with with "source" and // "destination" accesses to the same memref labled, and symbols %M, %N, %K: // -// for %i0 = 0 to 100 { -// for %i1 = 0 to 50 { +// affine.for %i0 = 0 to 100 { +// affine.for %i1 = 0 to 50 { // %a0 = affine.apply // (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N] // // Source memref access. @@ -765,8 +765,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // } // } // -// for %i2 = 0 to 100 { -// for %i3 = 0 to 50 { +// affine.for %i2 = 0 to 100 { +// affine.for %i3 = 0 to 50 { // %a1 = affine.apply // (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M] // // Destination memref access. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index ae48e644a68..a48f39c2aac 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -36,13 +36,13 @@ using namespace mlir; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'for' instruction to the innermost one. +/// the outermost 'affine.for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); OpPointer currAffineForOp; - // Traverse up the hierarchy collecing all 'for' instruction while skipping - // over 'if' instructions. + // Traverse up the hierarchy collecing all 'affine.for' instruction while + // skipping over 'if' instructions. 
while (currInst && ((currAffineForOp = currInst->dyn_cast()) || currInst->isa())) { if (currAffineForOp) @@ -111,8 +111,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // For example, the memref region for this load operation at loopDepth = 1 will // be as below: // -// for %i = 0 to 32 { -// for %ii = %i to (d0) -> (d0 + 8) (%i) { +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -614,7 +614,7 @@ Optional mlir::getMemoryFootprintBytes(const Block &block, int memorySpace) { std::vector> regions; - // Walk this 'for' instruction to gather all memory regions. + // Walk this 'affine.for' instruction to gather all memory regions. bool error = false; const_cast(&block)->walk([&](Instruction *opInst) { if (!opInst->isa() && !opInst->isa()) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e0c76e9efad..83e15097942 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -189,7 +189,7 @@ unsigned Block::getNumSuccessors() const { return terminator->getNumSuccessors(); } assert(getParent() && "top-level block with no terminator"); - // Blocks inside 'for'/'if' instructions don't have successors. + // Blocks inside 'affine.for'/'if' instructions don't have successors. return 0; } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 855ff37f60f..631ebf939ea 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -338,7 +338,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, auto fastMemRefType = top.getMemRefType( fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace); - // Create the fast memory space buffer just before the 'for' instruction. + // Create the fast memory space buffer just before the 'affine.for' + // instruction. fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); // Record it. fastBufferMap[memref] = fastMemRef; @@ -456,7 +457,7 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { // approach is conservative in some cases at the moment, we do a check later // and report an error with location info. // TODO(bondhugula): An 'if' instruction is being treated similar to an - // operation instruction. 'if''s could have 'for's in them; treat them + // operation instruction. 'if''s could have 'affine.for's in them; treat them // separately. // Get to the first load, store, or for op. @@ -470,9 +471,9 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { if (auto forOp = it->dyn_cast()) { // We'll assume for now that loops with steps are tiled loops, and so DMAs // are not performed for that depth, but only further inside. - // If the memory footprint of the 'for' loop is higher than fast memory - // capacity (when provided), we recurse to DMA at an inner level until - // we find a depth at which footprint fits in the capacity. If the + // If the memory footprint of the 'affine.for' loop is higher than fast + // memory capacity (when provided), we recurse to DMA at an inner level + // until we find a depth at which footprint fits in the capacity. If the // footprint can't be calcuated, we assume for now it fits. // Returns true if the footprint is known to exceed capacity. @@ -489,13 +490,13 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); // Recurse onto the body of this loop. 
runOnBlock(forOp->getBody(), consumedCapacityBytes); - // The next region starts right after the 'for' instruction. + // The next region starts right after the 'affine.for' instruction. curBegin = std::next(it); } else { // We have enough capacity, i.e., DMAs will be computed for the portion - // of the block until 'it', and for the 'for' loop. For the latter, they - // are placed just before this loop (for incoming DMAs) and right after - // (for outgoing ones). + // of the block until 'it', and for the 'affine.for' loop. For the + // latter, they are placed just before this loop (for incoming DMAs) and + // right after (for outgoing ones). consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); // Inner loop DMAs have their own scope - we don't thus update consumed diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8d5f51059bf..9e96b0800b3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -510,7 +510,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&inst); - // Return false if a non 'for' region was found (not currently supported). + // Return false if a non 'affine.for' region was found (not currently + // supported). if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 368a1dac1df..f00c2e767e6 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -231,7 +231,8 @@ UtilResult mlir::tileCodeGen(MutableArrayRef> band, static void getTileableBands(Function *f, std::vector, 6>> *bands) { - // Get maximal perfect nest of 'for' insts starting from root (inclusive). + // Get maximal perfect nest of 'affine.for' insts starting from root + // (inclusive). auto getMaximalPerfectLoopNest = [&](OpPointer root) { SmallVector, 6> band; OpPointer currInst = root; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 3a7cfb85e08..025a86891df 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -164,7 +164,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return success(); } -/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false +/// Unrolls a 'affine.for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. bool LoopUnroll::runOnAffineForOp(OpPointer forOp) { // Use the function callback if one was provided. diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index b2aed7d9d7f..2f0249824dd 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -105,7 +105,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) { return success(); } -/// Unroll and jam a 'for' inst. Default unroll jam factor is +/// Unroll and jam a 'affine.for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. bool LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { // Unroll and jam by the factor that was passed if any. 
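A minimal standalone sketch of the "maximal perfect nest" collection described for `getTileableBands` above: descend from the root while the loop body consists of exactly one statement and that statement is itself a loop. The `Loop` struct is a stand-in for `AffineForOp`, purely for illustration.

```c++
#include <vector>

// Standalone sketch of the "maximal perfect nest" collection described for
// getTileableBands() above: starting from a root loop, keep descending while
// the body holds exactly one statement and that statement is itself a loop.
// The Loop struct is a stand-in for AffineForOp, purely for illustration.
struct Loop {
  std::vector<Loop> bodyLoops; // loops directly nested in the body
  unsigned numBodyInsts = 0;   // total statements in the body
};

std::vector<const Loop *> maximalPerfectNest(const Loop &root) {
  std::vector<const Loop *> band;
  const Loop *curr = &root;
  while (true) {
    band.push_back(curr);
    // Perfect nesting: the body must be exactly the one nested loop.
    if (curr->numBodyInsts != 1 || curr->bodyLoops.size() != 1)
      break;
    curr = &curr->bodyLoops.front();
  }
  return band;
}
```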
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 0d8eb8a4761..ef45891c26f 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -283,7 +283,8 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, return value; } -// Convert a "for" loop to a flow of blocks. Return `false` on success. +// Convert a "affine.for" loop to a flow of blocks. Return `false` on +// success. // // Create an SESE region for the loop (including its body) and append it to the // end of the current region. The loop region consists of the initialization @@ -330,8 +331,9 @@ bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { auto loc = forOp->getLoc(); auto *forInst = forOp->getInstruction(); - // Start by splitting the block containing the 'for' into two parts. The part - // before will get the init code, the part after will be the end point. + // Start by splitting the block containing the 'affine.for' into two parts. + // The part before will get the init code, the part after will be the end + // point. auto *initBlock = forInst->getBlock(); auto *endBlock = initBlock->splitBlock(forInst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 63fb45db9c5..e63d3c8111c 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -126,9 +126,9 @@ private: /// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into /// // vector<32x256xf32> and pad with %f0 to handle the boundary case: /// %f0 = constant 0.0f : f32 -/// for %i0 = 0 to %0 { -/// for %i1 = 0 to %1 step 256 { -/// for %i2 = 0 to %2 step 32 { +/// affine.for %i0 = 0 to %0 { +/// affine.for %i1 = 0 to %1 step 256 { +/// affine.for %i2 = 0 to %2 step 32 { /// %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index, f32) -> vector<32x256xf32> @@ -139,8 +139,8 @@ private: /// MLIR resembling: /// /// ```mlir -/// for %d1 = 0 to 256 { -/// for %d2 = 0 to 32 { +/// affine.for %d1 = 0 to 256 { +/// affine.for %d2 = 0 to 32 { /// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 /// %tmp[%d2, %d1] = %s /// } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index be5a03bc416..4434ab5322e 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -101,10 +101,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : -/// vector<4x4x4xf32> for %i0 = 0 to %M step 4 { -/// for %i1 = 0 to %N step 4 { -/// for %i2 = 0 to %O { -/// for %i3 = 0 to %P step 4 { +/// vector<4x4x4xf32> affine.for %i0 = 0 to %M step 4 { +/// affine.for %i1 = 0 to %N step 4 { +/// affine.for %i2 = 0 to %O { +/// affine.for %i3 = 0 to %P step 4 { /// vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 /// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : /// vector<4x4x4xf32>, memref, @@ -120,10 +120,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> -/// for %i0 = 0 to %arg0 step 4 { -/// for %i1 = 0 to %arg1 step 4 { -/// for %i2 = 0 to %arg2 { -/// for %i3 = 0 to %arg3 step 4 { +/// affine.for %i0 = 0 to %arg0 step 4 { +/// affine.for %i1 = 0 to %arg1 step 4 { +/// affine.for %i2 = 0 to 
%arg2 { +/// affine.for %i3 = 0 to %arg3 step 4 { /// %1 = affine.apply (d0, d1, d2, d3) -> (d0, d1, d2, d3) /// (%i0, %i1, %i2, %i3) /// vector_transfer_write f1, %0, %1#0, %1#1, %1#2, %1#3 @@ -293,10 +293,10 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// super-vectorization has been applied: /// /// ```mlir -/// for %i0 = 0 to %M { -/// for %i1 = 0 to %N step 3 { -/// for %i2 = 0 to %O { -/// for %i3 = 0 to %P step 32 { +/// affine.for %i0 = 0 to %M { +/// affine.for %i1 = 0 to %N step 3 { +/// affine.for %i2 = 0 to %O { +/// affine.for %i3 = 0 to %P step 32 { /// %r = vector_transfer_read(%A, map(%i..)#0, map(%i..)#1, map(%i..)#2) /// -> vector<3x32xf32> /// ... diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index d9f940a01f3..3141d748750 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -19,7 +19,7 @@ // potentially getting rid of intermediate memref's entirely. // TODO(mlir-team): In the future, similar techniques could be used to eliminate // dead memref store's and perform more complex forwarding when support for -// SSA scalars live out of 'for'/'if' statements is available. +// SSA scalars live out of 'affine.for'/'if' statements is available. //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" @@ -55,7 +55,7 @@ namespace { // // (* A dependence being satisfied at a block: a dependence that is satisfied by // virtue of the destination instruction appearing textually / lexically after -// the source instruction within the body of a 'for' instruction; thus, a +// the source instruction within the body of a 'affine.for' instruction; thus, a // dependence is always either satisfied by a loop or by a block). // // The above conditions are simple to check, sufficient, and powerful for most @@ -145,8 +145,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { // Check if this store is a candidate for forwarding; we only forward if // the dependence from the store is carried by the *body* of innermost // common surrounding loop. As an example this filters out cases like: - // for %i0 - // for %i1 + // affine.for %i0 + // affine.for %i1 // %idx = affine.apply (d0) -> (d0 + 1) (%i0) // store %A[%idx] // load %A[%i0] diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index cfa045f2279..84c8cd830dc 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -71,11 +71,11 @@ static unsigned getTagMemRefPos(const Instruction &dmaInst) { return 0; } -/// Doubles the buffer of the supplied memref on the specified 'for' instruction -/// by adding a leading dimension of size two to the memref. Replaces all uses -/// of the old memref by the new one while indexing the newly added dimension by -/// the loop IV of the specified 'for' instruction modulo 2. Returns false if -/// such a replacement cannot be performed. +/// Doubles the buffer of the supplied memref on the specified 'affine.for' +/// instruction by adding a leading dimension of size two to the memref. +/// Replaces all uses of the old memref by the new one while indexing the newly +/// added dimension by the loop IV of the specified 'affine.for' instruction +/// modulo 2. Returns false if such a replacement cannot be performed. 
static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); @@ -108,7 +108,7 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { dynamicDimCount++)); } - // Create and place the alloc right before the 'for' instruction. + // Create and place the alloc right before the 'affine.for' instruction. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. Value *newMemRef = @@ -137,9 +137,9 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { /// Returns success if the IR is in a valid state. PassResult PipelineDataTransfer::runOnFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'for' instructions nested within would otherwise become - // invalid (erased) when the outer loop is pipelined (the pipelined one gets - // deleted and replaced by a prologue, a new steady-state loop and an + // necessary since 'affine.for' instructions nested within would otherwise + // become invalid (erased) when the outer loop is pipelined (the pipelined one + // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); f->walkPostOrder( diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a1903ace026..110949f43d5 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -138,8 +138,8 @@ void mlir::promoteSingleIterationLoops(Function *f) { [](OpPointer forOp) { promoteIfSingleIteration(forOp); }); } -/// Generates a 'for' inst with the specified lower and upper bounds while -/// generating the right IV remappings for the shifted instructions. The +/// Generates a 'affine.for' inst with the specified lower and upper bounds +/// while generating the right IV remappings for the shifted instructions. The /// instruction blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of /// the pair specifies the shift applied to that group of instructions; note @@ -194,10 +194,10 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the instructions in the body of a 'for' instruction with the specified -/// instruction-wise shifts. The shifts are with respect to the original -/// execution order, and are multiplied by the loop 'step' before being applied. -/// A shift of zero for each instruction will lead to no change. +/// Skew the instructions in the body of a 'affine.for' instruction with the +/// specified instruction-wise shifts. The shifts are with respect to the +/// original execution order, and are multiplied by the loop 'step' before being +/// applied. A shift of zero for each instruction will lead to no change. // The skewing of instructions with respect to one another can be used for // example to allow overlap of asynchronous operations (such as DMA // communication) with computation, or just relative shifting of instructions @@ -246,7 +246,7 @@ UtilResult mlir::instBodySkew(OpPointer forOp, // An array of instruction groups sorted by shift amount; each group has all // instructions with the same shift in the order in which they appear in the - // body of the 'for' inst. + // body of the 'affine.for' inst. 
std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; for (auto &inst : *forOp->getBody()) { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 41689be52fc..90d28bf34df 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -194,14 +194,14 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// /// Before /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// "compute"(%idx) /// /// After /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5a8d5d24661..1f4c7b9fcc8 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -113,7 +113,7 @@ using namespace mlir; /// /// At a high level, a vectorized load in a loop will resemble: /// ```mlir -/// for %i = ? to ? step ? { +/// affine.for %i = ? to ? step ? { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -309,7 +309,7 @@ using namespace mlir; /// ```mlir /// mlfunc @fill(%A : memref<128xf32>) -> () { /// %f1 = constant 1.0 : f32 -/// for %i0 = 0 to 32 { +/// affine.for %i0 = 0 to 32 { /// store %f1, %A[%i0] : memref<128xf32, 0> /// } /// return @@ -322,7 +322,7 @@ using namespace mlir; /// is still subject to exploratory tradeoffs. In particular, say we want to /// vectorize by a factor 128, we want to transform the following input: /// ```mlir -/// for %i = %M to %N { +/// affine.for %i = %M to %N { /// %a = load A[%i] : memref /// } /// ``` @@ -331,8 +331,8 @@ using namespace mlir; /// memory promotion etc) say after stripmining (and potentially unrolling in /// the case of LLVM's SLP vectorizer): /// ```mlir -/// for %i = floor(%M, 128) to ceil(%N, 128) { -/// for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { +/// affine.for %i = floor(%M, 128) to ceil(%N, 128) { +/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { /// %a = load A[%ii] : memref /// } /// } @@ -341,7 +341,7 @@ using namespace mlir; /// Instead, we seek to vectorize early and freeze vector types before /// scheduling, so we want to generate a pattern that resembles: /// ```mlir -/// for %i = ? to ? step ? { +/// affine.for %i = ? to ? step ? 
{ /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -362,7 +362,7 @@ using namespace mlir; /// For the simple strawman example above, vectorizing for a 1-D vector /// abstraction of size 128 returns code similar to: /// ```mlir -/// for %i = %M to %N step 128 { +/// affine.for %i = %M to %N step 128 { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -391,20 +391,20 @@ using namespace mlir; /// %C = alloc (%M, %N) : memref /// %f1 = constant 1.0 : f32 /// %f2 = constant 2.0 : f32 -/// for %i0 = 0 to %M { -/// for %i1 = 0 to %N { +/// affine.for %i0 = 0 to %M { +/// affine.for %i1 = 0 to %N { /// // non-scoped %f1 /// store %f1, %A[%i0, %i1] : memref /// } /// } -/// for %i2 = 0 to %M { -/// for %i3 = 0 to %N { +/// affine.for %i2 = 0 to %M { +/// affine.for %i3 = 0 to %N { /// // non-scoped %f2 /// store %f2, %B[%i2, %i3] : memref /// } /// } -/// for %i4 = 0 to %M { -/// for %i5 = 0 to %N { +/// affine.for %i4 = 0 to %M { +/// affine.for %i5 = 0 to %N { /// %a5 = load %A[%i4, %i5] : memref /// %b5 = load %B[%i4, %i5] : memref /// %s5 = addf %a5, %b5 : f32 @@ -438,24 +438,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// for %i0 = 0 to %arg0 { -/// for %i1 = 0 to %arg1 step 256 { +/// affine.for %i0 = 0 to %arg0 { +/// affine.for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// for %i2 = 0 to %arg0 { -/// for %i3 = 0 to %arg1 step 256 { +/// affine.for %i2 = 0 to %arg0 { +/// affine.for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// for %i4 = 0 to %arg0 { -/// for %i5 = 0 to %arg1 step 256 { +/// affine.for %i4 = 0 to %arg0 { +/// affine.for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : @@ -494,24 +494,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// for %i0 = 0 to %arg0 step 32 { -/// for %i1 = 0 to %arg1 step 256 { +/// affine.for %i0 = 0 to %arg0 step 32 { +/// affine.for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// for %i2 = 0 to %arg0 step 32 { -/// for %i3 = 0 to %arg1 step 256 { +/// affine.for %i2 = 0 to %arg0 step 32 { +/// affine.for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// for %i4 = 0 to %arg0 step 32 { -/// for %i5 = 0 to %arg1 step 256 { +/// affine.for %i4 = 0 to %arg0 step 32 { +/// affine.for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<32x256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir index ad6f39f3496..163cfbe0985 100644 --- a/mlir/test/AffineOps/canonicalize.mlir +++ b/mlir/test/AffineOps/canonicalize.mlir 
@@ -32,7 +32,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { // Test load[%x, %x] %x0 = affine.apply (d0) -> (d0 - 1) (%i0) @@ -78,7 +78,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { func @compose_affine_maps_1dto2d_with_symbols() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { // Test load[%x0, %x0] with symbol %c4 %c4 = constant 4 : index %x0 = affine.apply (d0)[s0] -> (d0 - s0) (%i0)[%c4] @@ -119,13 +119,13 @@ func @compose_affine_maps_2d_tile() { %c4 = constant 4 : index %c8 = constant 8 : index - for %i0 = 0 to 3 { + affine.for %i0 = 0 to 3 { %x0 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i0)[%c4] - for %i1 = 0 to 3 { + affine.for %i1 = 0 to 3 { %x1 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i1)[%c8] - for %i2 = 0 to 3 { + affine.for %i2 = 0 to 3 { %x2 = affine.apply (d0)[s0] -> (d0 mod s0) (%i2)[%c4] - for %i3 = 0 to 3 { + affine.for %i3 = 0 to 3 { %x3 = affine.apply (d0)[s0] -> (d0 mod s0) (%i3)[%c8] %x40 = affine.apply (d0, d1, d2, d3)[s0, s1] -> @@ -151,9 +151,9 @@ func @compose_affine_maps_dependent_loads() { %0 = alloc() : memref<16x32xf32> %1 = alloc() : memref<16x32xf32> - for %i0 = 0 to 3 { - for %i1 = 0 to 3 { - for %i2 = 0 to 3 { + affine.for %i0 = 0 to 3 { + affine.for %i1 = 0 to 3 { + affine.for %i2 = 0 to 3 { %c3 = constant 3 : index %c7 = constant 7 : index @@ -197,7 +197,7 @@ func @compose_affine_maps_dependent_loads() { func @compose_affine_maps_diamond_dependency() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { %a = affine.apply (d0) -> (d0 - 1) (%i0) %b = affine.apply (d0) -> (d0 + 7) (%a) %c = affine.apply (d0) -> (d0 * 4) (%a) @@ -217,8 +217,8 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { %c9 = constant 9 : index %1 = alloc() : memref<100x100xf32, 1> %2 = alloc() : memref<1xi32> - for %i0 = 0 to 100 { - for %i1 = 0 to 100 { + affine.for %i0 = 0 to 100 { + affine.for %i1 = 0 to 100 { %3 = affine.apply (d0, d1)[s0, s1] -> (d1 + s0 + s1) (%i0, %i1)[%arg1, %c9] %4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1)) @@ -238,7 +238,7 @@ func @trivial_maps() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = affine.apply ()[s0] -> (s0)()[%c0] store %cst, %0[%1] : memref<10xf32> %2 = load %0[%c0] : memref<10xf32> @@ -277,20 +277,20 @@ func @constant_fold_bounds(%N : index) { %c3 = affine.apply (d0, d1) -> (d0 + d1) (%c1, %c2) %l = "foo"() : () -> index - // CHECK: for %i0 = 5 to 7 { - for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { + // CHECK: affine.for %i0 = 5 to 7 { + affine.for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { "foo"(%i, %c3) : (index, index) -> () } // Bound takes a non-constant argument but can still be folded. - // CHECK: for %i1 = 1 to 7 { - for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { + // CHECK: affine.for %i1 = 1 to 7 { + affine.for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { "foo"(%j, %c3) : (index, index) -> () } // None of the bounds can be folded. 
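Editorial note on the @constant_fold_bounds cases above, assuming the constants defined earlier in that test are %c1 = 1, %c2 = 2, and %c9 = 9 (an assumption, but consistent with the CHECK output): %c3 folds to 1 + 2 = 3, the lower bound folds to max(0, 2 + 3) = 5, and the upper bound folds to min(9 - 2, 32 * 1) = 7, which is why the first loop canonicalizes to `affine.for %i0 = 5 to 7`. The last loop keeps its max/min maps because %l is produced by an opaque "foo" op and %N is a function argument, so neither bound is a compile-time constant.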
- // CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { - for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { + // CHECK: affine.for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { + affine.for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { "foo"(%k, %c3) : (index, index) -> () } return diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 30fae330787..330407272c8 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,35 +204,35 @@ func @illegaltype(i0) // expected-error {{invalid integer width}} // ----- func @malformed_for_percent() { - for i = 1 to 10 { // expected-error {{expected SSA operand}} + affine.for i = 1 to 10 { // expected-error {{expected SSA operand}} // ----- func @malformed_for_equal() { - for %i 1 to 10 { // expected-error {{expected '='}} + affine.for %i 1 to 10 { // expected-error {{expected '='}} // ----- func @malformed_for_to() { - for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} + affine.for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} } } // ----- func @incomplete_for() { - for %i = 1 to 10 step 2 + affine.for %i = 1 to 10 step 2 } // expected-error {{expected '{' to begin block list}} // ----- func @nonconstant_step(%1 : i32) { - for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} + affine.for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} // ----- func @for_negative_stride() { - for %i = 1 to 10 step -1 + affine.for %i = 1 to 10 step -1 } // expected-error@-1 {{expected step to be representable as a positive signed integer}} // ----- @@ -244,7 +244,7 @@ func @non_instruction() { // ----- func @invalid_if_conditional2() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -252,7 +252,7 @@ func @invalid_if_conditional2() { // ----- func @invalid_if_conditional3() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i)[N] : (i == 1) // expected-error {{expected '0' after '=='}} } } @@ -260,7 +260,7 @@ func @invalid_if_conditional3() { // ----- func @invalid_if_conditional4() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i)[N] : (i >= 2) // expected-error {{expected '0' after '>='}} } } @@ -268,7 +268,7 @@ func @invalid_if_conditional4() { // ----- func @invalid_if_conditional5() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i)[N] : (i <= 0 ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -276,7 +276,7 @@ func @invalid_if_conditional5() { // ----- func @invalid_if_conditional6() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i) : (i) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -284,7 +284,7 @@ func @invalid_if_conditional6() { // ----- // TODO (support if (1)? 
func @invalid_if_conditional7() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if (i) : (1) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -438,8 +438,8 @@ func @undef() { // ----- func @duplicate_induction_var() { - for %i = 1 to 10 { // expected-error {{previously defined here}} - for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} + affine.for %i = 1 to 10 { // expected-error {{previously defined here}} + affine.for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} } } return @@ -448,7 +448,7 @@ func @duplicate_induction_var() { // ----- func @dominance_failure() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { } "xxx"(%i) : (index)->() // expected-error {{operand #0 does not dominate this use}} return @@ -475,7 +475,7 @@ func @return_type_mismatch() -> i32 { // ----- func @return_inside_loop() -> i8 { - for %i = 1 to 100 { + affine.for %i = 1 to 100 { %a = "foo"() : ()->i8 return %a : i8 // expected-error@-1 {{'return' op may only be at the top level of a function}} @@ -521,7 +521,7 @@ func @referer() { #map1 = (i)[j] -> (i+j) func @bound_symbol_mismatch(%N : index) { - for %i = #map1(%N) to 100 { + affine.for %i = #map1(%N) to 100 { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} } return @@ -532,7 +532,7 @@ func @bound_symbol_mismatch(%N : index) { #map1 = (i)[j] -> (i+j) func @bound_dim_mismatch(%N : index) { - for %i = #map1(%N, %N)[%N] to 100 { + affine.for %i = #map1(%N, %N)[%N] to 100 { // expected-error@-1 {{dim operand count and integer set dim count must match}} } return @@ -541,7 +541,7 @@ func @bound_dim_mismatch(%N : index) { // ----- func @large_bound() { - for %i = 1 to 9223372036854775810 { + affine.for %i = 1 to 9223372036854775810 { // expected-error@-1 {{integer constant out of range for attribute}} } return @@ -550,7 +550,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} + affine.for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} } return } @@ -558,7 +558,7 @@ func @max_in_upper_bound(%N : index) { // ----- func @step_typo() { - for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} + affine.for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} } return } @@ -566,7 +566,7 @@ func @step_typo() { // ----- func @invalid_bound_map(%N : i32) { - for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} + affine.for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} } return } @@ -579,7 +579,7 @@ func @invalid_bound_map(%N : i32) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands1(%N : index) { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if #set0(%i) { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} @@ -587,7 +587,7 @@ func @invalid_if_operands1(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands2(%N : index) { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if #set0()[%N] { // expected-error@-1 {{dim operand count and integer set dim count must match}} @@ -595,7 +595,7 @@ func @invalid_if_operands2(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands3(%N : index) { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { if #set0(%i)[%i] { // expected-error@-1 {{operand cannot be used as a symbol}} } @@ -736,11 +736,11 @@ func 
@f(f32) { // ----- func @f(%m : memref) { - for %i0 = 0 to 42 { + affine.for %i0 = 0 to 42 { // expected-error@+1 {{operand #2 does not dominate this use}} %x = load %m[%i0, %i1] : memref } - for %i1 = 0 to 42 { + affine.for %i1 = 0 to 42 { } return } @@ -790,7 +790,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t // Check ill-formed opaque tensor. func @complex_loops() { - for %i1 = 1 to 100 { + affine.for %i1 = 1 to 100 { // expected-error @+1 {{expected '"' in string literal}} "opaqueIntTensor"(){bar: opaque, "0x686]>} : () -> () @@ -824,7 +824,7 @@ func @invalid_affine_structure() { func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{lower loop bound affine map with multiple results requires 'max' prefix}} - for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { + affine.for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { } return } @@ -833,7 +833,7 @@ func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{upper loop bound affine map with multiple results requires 'min' prefix}} - for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { + affine.for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { } return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 7196e3a5c29..935d2e32186 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -13,7 +13,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) // CHECK: } loc(fused["foo", "mysource.cc":10:8]) - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 1ec62b9a77d..01b02765de7 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -208,8 +208,8 @@ func @identity_functor(%a : () -> ()) -> (() -> ()) { func @func_ops_in_loop() { // CHECK: %0 = "foo"() : () -> i64 %a = "foo"() : ()->i64 - // CHECK: for %i0 = 1 to 10 { - for %i = 1 to 10 { + // CHECK: affine.for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK: %1 = "doo"() : () -> f32 %b = "doo"() : ()->f32 // CHECK: "bar"(%0, %1) : (i64, f32) -> () @@ -224,10 +224,10 @@ func @func_ops_in_loop() { // CHECK-LABEL: func @loops() { func @loops() { - // CHECK: for %i0 = 1 to 100 step 2 { - for %i = 1 to 100 step 2 { - // CHECK: for %i1 = 1 to 200 { - for %j = 1 to 200 { + // CHECK: affine.for %i0 = 1 to 100 step 2 { + affine.for %i = 1 to 100 step 2 { + // CHECK: affine.for %i1 = 1 to 200 { + affine.for %j = 1 to 200 { } // CHECK: } } // CHECK: } return // CHECK: return @@ -235,14 +235,14 @@ func @loops() { // CHECK-LABEL: func @complex_loops() { func @complex_loops() { - for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 { - for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 { + affine.for %i1 = 1 to 100 { // CHECK: affine.for %i0 = 1 to 100 { + affine.for %j1 = 1 to 100 { // CHECK: affine.for %i1 = 1 to 100 { // CHECK: "foo"(%i0, %i1) : (index, index) -> () "foo"(%i1, %j1) : (index,index) -> () } // CHECK: } "boo"() : () -> () // CHECK: "boo"() : () -> () - for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 { - for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 { + affine.for %j2 = 1 to 10 { // CHECK: affine.for %i2 = 1 to 10 { + affine.for %k2 = 1 to 10 { // CHECK: affine.for %i3 = 1 to 10 { "goo"() : () -> () // CHECK: "goo"() : () -> 
() } // CHECK: } } // CHECK: } @@ -253,8 +253,8 @@ func @complex_loops() { // CHECK: func @triang_loop(%arg0: index, %arg1: memref) { func @triang_loop(%arg0: index, %arg1: memref) { %c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32 - for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 { - for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { + affine.for %i0 = 1 to %arg0 { // CHECK: affine.for %i0 = 1 to %arg0 { + affine.for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { store %c, %arg1[%i0, %i1] : memref // CHECK: store %c0_i32, %arg1[%i0, %i1] } // CHECK: } } // CHECK: } @@ -263,8 +263,8 @@ func @triang_loop(%arg0: index, %arg1: memref) { // CHECK: func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { - // CHECK: for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { - for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { + // CHECK: affine.for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { + affine.for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { // CHECK: "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () } // CHECK: } @@ -275,24 +275,24 @@ func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @loop_bounds(%N : index) { // CHECK: %0 = "foo"(%arg0) : (index) -> index %s = "foo"(%N) : (index) -> index - // CHECK: for %i0 = %0 to %arg0 - for %i = %s to %N { - // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0 - for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { + // CHECK: affine.for %i0 = %0 to %arg0 + affine.for %i = %s to %N { + // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to 0 + affine.for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { // CHECK: %1 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w1 = affine.apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s] // CHECK: %2 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w2 = affine.apply(d0, d1)[s0] -> (s0+1) (%i, %j) [%s] - // CHECK: for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { - for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { + // CHECK: affine.for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { + affine.for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { // CHECK: "foo"(%i0, %i1, %i2) : (index, index, index) -> () "foo"(%i, %j, %k) : (index, index, index)->() // CHECK: %c30 = constant 30 : index %c = constant 30 : index // CHECK: %3 = affine.apply #map{{.*}}(%arg0, %c30) %u = affine.apply (d0, d1)->(d0+d1) (%N, %c) - // CHECK: for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { - for %l = max #bound_map2(%i)[%u] to min #bound_map2(%k)[%c] { + // CHECK: affine.for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { + affine.for %l = max #bound_map2(%i)[%u] to min #bound_map2(%k)[%c] { // CHECK: "bar"(%i3) : (index) -> () "bar"(%l) : (index) -> () } // CHECK: } @@ -305,7 +305,7 @@ func @loop_bounds(%N : index) { // CHECK-LABEL: func @ifinst(%arg0: index) { func @ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -328,7 +328,7 @@ func @ifinst(%N: index) { // CHECK-LABEL: func 
@simple_ifinst(%arg0: index) { func @simple_ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -544,18 +544,18 @@ func @funcattrwithblock() -> () #map_non_simple2 = ()[s0, s1] -> (s0 + s1) #map_non_simple3 = ()[s0] -> (s0 + 3) func @funcsimplemap(%arg0: index, %arg1: index) -> () { - for %i0 = 0 to #map_simple0()[] { - // CHECK: for %i0 = 0 to 10 { - for %i1 = 0 to #map_simple1()[%arg1] { - // CHECK: for %i1 = 0 to %arg1 { - for %i2 = 0 to #map_non_simple0(%i0)[] { - // CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { - for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { - // CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { - for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { - // CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { - for %i5 = 0 to #map_non_simple3()[%arg0] { - // CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { + affine.for %i0 = 0 to #map_simple0()[] { + // CHECK: affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to #map_simple1()[%arg1] { + // CHECK: affine.for %i1 = 0 to %arg1 { + affine.for %i2 = 0 to #map_non_simple0(%i0)[] { + // CHECK: affine.for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { + affine.for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { + // CHECK: affine.for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { + affine.for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { + // CHECK: affine.for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { + affine.for %i5 = 0 to #map_non_simple3()[%arg0] { + // CHECK: affine.for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { %c42_i32 = constant 42 : i32 } } @@ -749,9 +749,9 @@ func @sparsevectorattr() -> () { // CHECK-LABEL: func @loops_with_blockids() { func @loops_with_blockids() { ^block0: - for %i = 1 to 100 step 2 { + affine.for %i = 1 to 100 step 2 { ^block1: - for %j = 1 to 200 { + affine.for %j = 1 to 200 { ^block2: } } diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index 4668e7a832b..3f13d6c2368 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -18,7 +18,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10)))) // CHECK: } ["foo", mysource.cc:10:8] - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } <"myPass">["foo", "foo2"] diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index b82ac08fe59..e896e0588d3 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -6,8 +6,8 @@ // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { %A = alloc () : memref<7x42xf32> - for %i0 = 0 to 7 step 4 { - for %i1 = 0 to 42 step 4 { + affine.for %i0 = 0 to 7 step 4 { + affine.for %i1 = 0 to 42 step 4 { %f1 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> %ip1 = affine.apply (d0) -> (d0 + 1) (%i1) %f2 = vector_transfer_read %A, %i0, %ip1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> @@ -29,11 +29,11 @@ func @materialize_read_1d() 
{ // CHECK-LABEL: func @materialize_read_1d_partially_specialized func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %dyn4 : index) { %A = alloc (%dyn1, %dyn2, %dyn4) : memref<7x?x?x42x?xf32> - for %i0 = 0 to 7 { - for %i1 = 0 to %dyn1 { - for %i2 = 0 to %dyn2 { - for %i3 = 0 to 42 step 2 { - for %i4 = 0 to %dyn4 { + affine.for %i0 = 0 to 7 { + affine.for %i1 = 0 to %dyn1 { + affine.for %i2 = 0 to %dyn2 { + affine.for %i3 = 0 to 42 step 2 { + affine.for %i4 = 0 to %dyn4 { %f1 = vector_transfer_read %A, %i0, %i1, %i2, %i3, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> %i3p1 = affine.apply (d0) -> (d0 + 1) (%i3) %f2 = vector_transfer_read %A, %i0, %i1, %i2, %i3p1, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> @@ -54,10 +54,10 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d // CHECK-LABEL: func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref - // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 { - // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { - // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 { + // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK: {{.*}} = dim %0, 0 : memref @@ -66,9 +66,9 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: {{.*}} = dim %0, 3 : memref // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast %[[ALLOC]] : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] @@ -109,10 +109,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: return // CHECK-NEXT:} %A = alloc (%M, %N, %O, %P) : memref - for %i0 = 0 to %M step 3 { - for %i1 = 0 to %N { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 5 { + affine.for %i0 = 0 to %M step 3 { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 5 { %f = vector_transfer_read %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, 0, d0)} : (memref, index, index, index, index) -> vector<5x4x3xf32> } } @@ -125,10 +125,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: %cst = constant splat, 1.000000e+00> : vector<5x4x3xf32> - // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 step 4 { - // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { - 
// CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 { + // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK: {{.*}} = dim %0, 0 : memref @@ -138,9 +138,9 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast {{.*}} : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> // CHECK-NEXT: store %cst, {{.*}}[%[[C0]]] : memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32> // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index @@ -184,10 +184,10 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT:} %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<5x4x3xf32> - for %i0 = 0 to %M step 3 { - for %i1 = 0 to %N step 4 { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 5 { + affine.for %i0 = 0 to %M step 3 { + affine.for %i1 = 0 to %N step 4 { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 5 { vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : vector<5x4x3xf32>, memref, index, index, index, index } } diff --git a/mlir/test/Transforms/Vectorize/materialize.mlir b/mlir/test/Transforms/Vectorize/materialize.mlir index 80458c75333..ce445ec75bb 100644 --- a/mlir/test/Transforms/Vectorize/materialize.mlir +++ b/mlir/test/Transforms/Vectorize/materialize.mlir @@ -10,10 +10,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> - // CHECK: for %i0 = 0 to %arg0 step 4 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 4 { - // CHECK-NEXT: for %i2 = 0 to %arg2 { - // CHECK-NEXT: for %i3 = 0 to %arg3 step 4 { + // CHECK: affine.for %i0 = 0 to %arg0 step 4 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 4 { + // CHECK-NEXT: affine.for %i2 = 0 to %arg2 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg3 step 4 { // CHECK-NEXT: %[[a:[0-9]+]] = {{.*}}[[ID1]](%i0) // CHECK-NEXT: %[[b:[0-9]+]] = {{.*}}[[ID1]](%i1) // CHECK-NEXT: %[[c:[0-9]+]] = {{.*}}[[ID1]](%i2) @@ -25,10 +25,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b2]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index // CHECK: %[[b3:[0-9]+]] = {{.*}}[[D0P3]](%i1) // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b3]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index - for %i0 = 0 to %M step 4 { - for %i1 = 0 to %N step 4 { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 4 { + affine.for %i0 = 0 to %M step 4 { + affine.for %i1 = 0 to %N step 4 { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 4 { "vector_transfer_write"(%f1, %A, %i0, 
%i1, %i2, %i3) {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : (vector<4x4x4xf32>, memref, index, index, index, index) -> () } } diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir index b5f771d7e62..71c442b965e 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 4x unroll (jammed by construction). - // CHECK: for %i0 = 0 to %arg0 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { + // CHECK: affine.for %i0 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> @@ -34,15 +34,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 4x unroll (jammed by construction). - // CHECK: for %i2 = 0 to %arg0 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { + // CHECK: affine.for %i2 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> @@ -60,15 +60,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 4x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { + // CHECK: affine.for %i4 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -110,8 +110,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir index 92df49fa8fa..62149c323b6 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // (3x2)x unroll (jammed by construction). 
- // CHECK: for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 16 { + // CHECK: affine.for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> @@ -41,26 +41,26 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL50:%.*]] = affine.apply [[D0P2]](%i0) // CHECK-NEXT: [[VAL51:%.*]] = affine.apply [[D0P8]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL50]], [[VAL51]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 16 { + // CHECK: affine.for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 16 { // ..... - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 // CHECK does (3x4)x unrolling. store %f2, %B[%i2, %i3] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 16 { + // CHECK: affine.for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -122,8 +122,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir index 36ec96e30b4..59705eca69e 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir @@ -13,8 +13,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 2x unroll (jammed by construction). - // CHECK: for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { + // CHECK: affine.for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i0) @@ -24,15 +24,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 2x unroll (jammed by construction). 
- // CHECK: for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { + // CHECK: affine.for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i2) @@ -42,15 +42,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i3) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 2x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { + // CHECK: affine.for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -72,8 +72,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/normalize_maps.mlir b/mlir/test/Transforms/Vectorize/normalize_maps.mlir index 9569dbe07fe..076d2c75633 100644 --- a/mlir/test/Transforms/Vectorize/normalize_maps.mlir +++ b/mlir/test/Transforms/Vectorize/normalize_maps.mlir @@ -9,19 +9,19 @@ // CHECK-LABEL: func @simple() func @simple() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { %0 = affine.apply (d0) -> (d0) (%i0) %1 = affine.apply (d0) -> (d0) (%0) %2 = affine.apply (d0, d1) -> (d0 + d1) (%0, %0) %3 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) } - // CHECK-NEXT: for %i0 = 0 to 7 + // CHECK-NEXT: affine.for %i0 = 0 to 7 // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[D0TIMES2]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[ZERO]]() - for %i1 = 0 to 7 { - for %i2 = 0 to 42 { + affine.for %i1 = 0 to 7 { + affine.for %i2 = 0 to 42 { %20 = affine.apply (d0, d1) -> (d1) (%i1, %i2) %21 = affine.apply (d0, d1) -> (d0) (%i1, %i2) %22 = affine.apply (d0, d1) -> (d0 + d1) (%20, %21) @@ -29,15 +29,15 @@ func @simple() { %24 = affine.apply (d0, d1) -> (-d0 + d1) (%20, %21) } } - // CHECK: for %i1 = 0 to 7 - // CHECK-NEXT: for %i2 = 0 to 42 + // CHECK: affine.for %i1 = 0 to 7 + // CHECK-NEXT: affine.for %i2 = 0 to 42 // CHECK-NEXT: {{.*}} affine.apply #[[D0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[MINSD0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[D0MINUSD1]](%i1, %i2) - for %i3 = 0 to 16 { - for %i4 = 0 to 47 step 2 { - for %i5 = 0 to 78 step 16 { + affine.for %i3 = 0 to 16 { + affine.for %i4 = 0 to 47 step 2 { + affine.for %i5 = 0 to 78 step 16 { %50 = affine.apply (d0) -> (d0) (%i3) %51 = affine.apply (d0) -> (d0) (%i4) %52 = affine.apply (d0) -> (d0) (%i5) @@ -47,9 +47,9 @@ func @simple() { } } } - // CHECK: for %i3 = 0 to 16 - // CHECK-NEXT: for %i4 = 0 to 47 step 2 - // CHECK-NEXT: for %i5 = 0 to 78 step 16 + // CHECK: affine.for %i3 = 0 to 16 + // CHECK-NEXT: affine.for %i4 = 0 to 47 step 2 + // CHECK-NEXT: affine.for %i5 = 0 to 78 step 16 // CHECK-NEXT: 
{{.*}} affine.apply #[[ID1]](%i3) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i4) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i5) diff --git a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir index da69e8dd26d..6d3f3a54e99 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir @@ -23,17 +23,17 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for {{.*}} step 128 // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : (memref, index, index) -> vector<128xf32> - for %i0 = 0 to %M { // vectorized due to scalar -> vector + affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector %a0 = load %A[%cst0, %cst0] : memref } // // CHECK:for {{.*}} [[ARG_M]] { - for %i1 = 0 to %M { // not vectorized + affine.for %i1 = 0 to %M { // not vectorized %a1 = load %A[%i1, %i1] : memref } // -// CHECK: for %i{{[0-9]*}} = 0 to [[ARG_M]] { - for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 +// CHECK: affine.for %i{{[0-9]*}} = 0 to [[ARG_M]] { + affine.for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 %r2 = affine.apply (d0) -> (d0) (%i2) %a2 = load %A[%r2#0, %cst0] : memref } @@ -41,7 +41,7 @@ func @vec1d(%A : memref, %B : memref) { // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK-NEXT: [[APP3:%[a-zA-Z0-9]+]] = affine.apply {{.*}}[[IV3]] // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i3 = 0 to %M { // vectorized + affine.for %i3 = 0 to %M { // vectorized %r3 = affine.apply (d0) -> (d0) (%i3) %a3 = load %A[%cst0, %r3#0] : memref } @@ -51,8 +51,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP50:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: [[APP51:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP50]], [[APP51]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i4 = 0 to %M { // vectorized - for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 + affine.for %i4 = 0 to %M { // vectorized + affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 %r50 = affine.apply (d0, d1) -> (d1) (%i4, %i5) %r51 = affine.apply (d0, d1) -> (d0) (%i4, %i5) %a5 = load %A[%r50, %r51] : memref @@ -61,8 +61,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV6:%[i0-9]*]] = 0 to [[ARG_M]] { // CHECK-NEXT: for [[IV7:%[i0-9]*]] = 0 to [[ARG_N]] { - for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 - for %i7 = 0 to %N { // not vectorized, can never vectorize + affine.for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 + affine.for %i7 = 0 to %N { // not vectorized, can never vectorize %r70 = affine.apply (d0, d1) -> (d1 + d0) (%i6, %i7) %r71 = affine.apply (d0, d1) -> (d0) (%i6, %i7) %a7 = load %A[%r70, %r71] : memref @@ -74,8 +74,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP9_0:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: [[APP9_1:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP9_0]], [[APP9_1]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i8 = 0 to %M { // vectorized - for %i9 = 0 to 
%N { + affine.for %i8 = 0 to %M { // vectorized + affine.for %i9 = 0 to %N { %r90 = affine.apply (d0, d1) -> (d1) (%i8, %i9) %r91 = affine.apply (d0, d1) -> (d0 + d1) (%i8, %i9) %a9 = load %A[%r90, %r91] : memref @@ -84,8 +84,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} { - for %i10 = 0 to %M { // not vectorized, need per load transposes - for %i11 = 0 to %N { // not vectorized, need per load transposes + affine.for %i10 = 0 to %M { // not vectorized, need per load transposes + affine.for %i11 = 0 to %N { // not vectorized, need per load transposes %r11_0 = affine.apply (d0, d1) -> (d0) (%i10, %i11) %r11_1 = affine.apply (d0, d1) -> (d1) (%i10, %i11) %a11 = load %A[%r11_0, %r11_1] : memref @@ -98,9 +98,9 @@ func @vec1d(%A : memref, %B : memref) { // CHECK: for [[IV12:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV13:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV14:%[i0-9]+]] = 0 to [[ARG_P]] step 128 - for %i12 = 0 to %M { // not vectorized, can never vectorize - for %i13 = 0 to %N { // not vectorized, can never vectorize - for %i14 = 0 to %P { // vectorized + affine.for %i12 = 0 to %M { // not vectorized, can never vectorize + affine.for %i13 = 0 to %N { // not vectorized, can never vectorize + affine.for %i14 = 0 to %P { // vectorized %r14_0 = affine.apply (d0, d1, d2) -> (d1) (%i12, %i13, %i14) %r14_1 = affine.apply (d0, d1, d2) -> (d0 + d1) (%i12, %i13, %i14) %r14_2 = affine.apply (d0, d1, d2) -> (d0 + d2) (%i12, %i13, %i14) @@ -109,24 +109,24 @@ func @vec1d(%A : memref, %B : memref) { } } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - for %i15 = 0 to %M { // not vectorized due to condition below +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + affine.for %i15 = 0 to %M { // not vectorized due to condition below if #set0(%i15) { %a15 = load %A[%cst0, %cst0] : memref } } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + affine.for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load %a16 = alloc(%M) : memref> %l16 = load %a16[%i16] : memref> } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : {{.*}} -> vector<128xf32> - for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 - for %i18 = 0 to %M { // vectorized due to scalar -> vector + affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 + affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector %a18 = load %A[%cst0, %cst0] : memref } } @@ -139,24 +139,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to 
%M { + affine.for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<128xf32> @@ -188,10 +188,10 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-LABEL: @vec_rejected func @vec_rejected(%A : memref, %C : memref) { %N = dim %A, 0 : memref - for %i = 0 to %N { + affine.for %i = 0 to %N { // CHECK-NOT: vector %a = load %A[%i, %i] : memref // not vectorized - for %j = 0 to %N { + affine.for %j = 0 to %N { %b = load %A[%i, %j] : memref // may be vectorized // CHECK-NOT: vector %c = addf %a, %b : f32 // not vectorized because %a wasn't diff --git a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir index d847f6bb5ce..59c7483749b 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir @@ -11,13 +11,13 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %1 step 32 // CHECK: for {{.*}} = 0 to %2 step 256 // Example: - // for %i0 = 0 to %0 { - // for %i1 = 0 to %1 step 32 { - // for %i2 = 0 to %2 step 256 { + // affine.for %i0 = 0 to %0 { + // affine.for %i1 = 0 to %1 step 32 { + // affine.for %i2 = 0 to %2 step 256 { // %3 = "vector_transfer_read"(%arg0, %i0, %i1, %i2) : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -27,9 +27,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=1 --test-fastest-varying=0 no // vectorization happens because of loop nesting order . 
- for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -43,24 +43,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32> diff --git a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir index 1a6bee585ee..08ca27dbeee 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir @@ -7,17 +7,17 @@ func @vec3d(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 { - // CHECK: for %i1 = 0 to %0 { - // CHECK: for %i2 = 0 to %0 step 32 { - // CHECK: for %i3 = 0 to %1 step 64 { - // CHECK: for %i4 = 0 to %2 step 256 { + // CHECK: affine.for %i0 = 0 to %0 { + // CHECK: affine.for %i1 = 0 to %0 { + // CHECK: affine.for %i2 = 0 to %0 step 32 { + // CHECK: affine.for %i3 = 0 to %1 step 64 { + // CHECK: affine.for %i4 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i3, %i4 {permutation_map: #[[map_proj_d0d1d2_d0d1d2]]} : (memref, index, index, index) -> vector<32x64x256xf32> - for %t0 = 0 to %0 { - for %t1 = 0 to %0 { - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %t0 = 0 to %0 { + affine.for %t1 = 0 to %0 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i0, %i1, %i2] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir index 4654ab810df..d00b99f1716 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir @@ -7,13 +7,13 @@ func @vec2d(%A : memref) { %M = dim %A, 0 : memref %N = dim %A, 1 : memref %P = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 - // CHECK: for %i1 = 0 to %1 { - // CHECK: for %i2 = 0 to %2 step 256 + // CHECK: affine.for %i0 = 0 to %0 step 32 + // CHECK: affine.for %i1 = 0 to %1 { + // CHECK: affine.for %i2 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i0, %i1, %i2 {permutation_map: #[[map_proj_d0d1d2_d0d2]]} : (memref, 
index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -23,9 +23,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=2 --test-fastest-varying=0 no // vectorization happens because of loop nesting order - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir index 0eebf816535..a8a8d5d7790 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=2 no // vectorization happens because of loop nesting order. - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: for %i3 = 0 to %0 step 32 - // CHECK: for %i4 = 0 to %1 step 256 - // CHECK: for %i5 = 0 to %2 { + // CHECK: affine.for %i3 = 0 to %0 step 32 + // CHECK: affine.for %i4 = 0 to %1 step 256 + // CHECK: affine.for %i5 = 0 to %2 { // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 { - // CHECK: for %i1 = 0 to %1 { - // CHECK: for %i2 = 0 to %2 step 256 { + // CHECK: affine.for %i0 = 0 to %0 step 32 { + // CHECK: affine.for %i1 = 0 to %1 { + // CHECK: affine.for %i2 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i3 = 0 to %1 step 256 { - // CHECK: for %i4 = 0 to %2 { + // CHECK: affine.for %i3 = 0 to %1 step 256 { + // CHECK: affine.for %i4 = 0 to %2 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i5 = 0 to %2 { + // CHECK: affine.for %i5 = 0 to %2 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - for %i3 = 0 to %1 { - for %i4 = 0 to %2 { + affine.for %i3 = 0 to %1 { + affine.for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - for %i5 = 0 to %2 { + affine.for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git 
a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir index 1ba563b3442..b8e4e075890 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=1 no // vectorization happens because of loop nesting order. - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: for %i3 = 0 to %0 step 32 - // CHECK: for %i4 = 0 to %1 { - // CHECK: for %i5 = 0 to %2 step 256 + // CHECK: affine.for %i3 = 0 to %0 step 32 + // CHECK: affine.for %i4 = 0 to %1 { + // CHECK: affine.for %i5 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 { - // CHECK: for %i1 = 0 to %1 step 256 { - // CHECK: for %i2 = 0 to %2 { + // CHECK: affine.for %i0 = 0 to %0 step 32 { + // CHECK: affine.for %i1 = 0 to %1 step 256 { + // CHECK: affine.for %i2 = 0 to %2 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i3 = 0 to %1 { - // CHECK: for %i4 = 0 to %2 step 256 { + // CHECK: affine.for %i3 = 0 to %1 { + // CHECK: affine.for %i4 = 0 to %2 step 256 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i5 = 0 to %2 step 256 { + // CHECK: affine.for %i5 = 0 to %2 step 256 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - for %i3 = 0 to %1 { - for %i4 = 0 to %2 { + affine.for %i3 = 0 to %1 { + affine.for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - for %i5 = 0 to %2 { + affine.for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 29accf4ffc1..cc295751748 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -213,10 +213,10 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref, memref %c = alloc(%K, %N) : memref - // CHECK: for %i0 = - for %i = 0 to %L { - // CHECK-NEXT: for %i1 = - for %j = 0 to 10 { + // CHECK: affine.for %i0 = + affine.for %i = 0 to %L { + // CHECK-NEXT: affine.for %i1 = + affine.for %j = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref // CHECK-NEXT: store %4, %1[%c0, %c0, %i0, %i1, %c0] : memref<4x1024x8x512x?xf32> %v = load %a[%i, %j] : memref @@ -242,8 +242,8 @@ func 
@merge_constants() -> (index, index) { // CHECK-LABEL: func @hoist_constant func @hoist_constant(%arg0: memref<8xi32>) { // CHECK-NEXT: %c42_i32 = constant 42 : i32 - // CHECK-NEXT: for %i0 = 0 to 8 { - for %i0 = 0 to 8 { + // CHECK-NEXT: affine.for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { // CHECK-NEXT: store %c42_i32, %arg0[%i0] %c42_i32 = constant 42 : i32 store %c42_i32, %arg0[%i0] : memref<8xi32> diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 6043e478c5a..1c23914d7a2 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: @test(%arg0: memref) { func @test(%p : memref) { - for %i0 = 0 to 128 { - for %i1 = 0 to 8 { // CHECK: for %i1 = 0 to 8 { + affine.for %i0 = 0 to 128 { + affine.for %i1 = 0 to 8 { // CHECK: affine.for %i1 = 0 to 8 { %0 = constant 4.5 : f32 %1 = constant 1.5 : f32 diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index c4c0da7053e..31a7e13b73e 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -123,8 +123,8 @@ func @down_propagate_for_ml() { // CHECK: %c1_i32 = constant 1 : i32 %0 = constant 1 : i32 - // CHECK-NEXT: for %i0 = 0 to 4 { - for %i = 0 to 4 { + // CHECK-NEXT: affine.for %i0 = 0 to 4 { + affine.for %i = 0 to 4 { // CHECK-NEXT: "foo"(%c1_i32, %c1_i32) : (i32, i32) -> () %1 = constant 1 : i32 "foo"(%0, %1) : (i32, i32) -> () @@ -155,8 +155,8 @@ func @down_propagate_cfg() -> i32 { /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_ml func @up_propagate_ml() -> i32 { - // CHECK: for %i0 = 0 to 4 { - for %i = 0 to 4 { + // CHECK: affine.for %i0 = 0 to 4 { + affine.for %i = 0 to 4 { // CHECK-NEXT: %c1_i32 = constant 1 : i32 // CHECK-NEXT: "foo"(%c1_i32) : (i32) -> () %0 = constant 1 : i32 diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index a954bdb96a1..864a61d3abd 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -32,7 +32,7 @@ func @loop_nest_1d() { // Second DMA transfer. // CHECK: dma_start %1[%c256], %5[%c0], %c256_0, %6[%c0] : memref<512xf32>, memref<256xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %6[%c0], %c256_0 : memref<1xi32> - // CHECK: for %i0 = 0 to 256 { + // CHECK: affine.for %i0 = 0 to 256 { // CHECK-NEXT: %7 = load %3[%i0] : memref<256xf32, 1> // CHECK: %8 = affine.apply [[MAP_PLUS_256]](%i0) // CHECK: %9 = affine.apply [[MAP_MINUS_256]](%8) @@ -41,7 +41,7 @@ func @loop_nest_1d() { // CHECK: %11 = load %2[%i0] : memref<256xf32, 1> // CHECK-NEXT: } // CHECK-NEXT: return - for %i = 0 to 256 { + affine.for %i = 0 to 256 { load %A[%i] : memref<256 x f32> %idx = affine.apply (d0) -> (d0 + 256)(%i) load %B[%idx] : memref<512 x f32> @@ -68,20 +68,20 @@ func @loop_nest_1d() { // INCOMING DMA for C. 
// CHECK-DAG: dma_start %arg2[%c0, %c0], [[BUFC]][%c0, %c0], %c16384_0, [[TAGC]][%c0] : memref<512x32xf32>, memref<512x32xf32, 1>, memref<1xi32> // CHECK-DAG: dma_wait [[TAGC]][%c0], %c16384_0 : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 32 { -// CHECK-NEXT: for %i1 = 0 to 32 { -// CHECK-NEXT: for %i2 = 0 to 32 { -// CHECK-NEXT: for %i3 = 0 to 16 { +// CHECK-NEXT: affine.for %i0 = 0 to 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 32 { +// CHECK-NEXT: affine.for %i2 = 0 to 32 { +// CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %7 = affine.apply #map{{[0-9]+}}(%i1, %i3) // CHECK-NEXT: %8 = load [[BUFB]][%7, %i0] : memref<512x32xf32, 1> // CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: for %i4 = 0 to 16 { +// CHECK-NEXT: affine.for %i4 = 0 to 16 { // CHECK-NEXT: %9 = affine.apply #map{{[0-9]+}}(%i2, %i4) // CHECK-NEXT: %10 = load [[BUFA]][%9, %i1] : memref<512x32xf32, 1> // CHECK-NEXT: "bar"(%10) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: for %i5 = 0 to 16 { +// CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %11 = "abc_compute"() : () -> f32 // CHECK-NEXT: %12 = affine.apply #map{{[0-9]+}}(%i2, %i5) // CHECK-NEXT: %13 = load [[BUFC]][%12, %i0] : memref<512x32xf32, 1> @@ -102,20 +102,20 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // DMAs will be performed at this level (jT is the first loop without a stride). // A and B are read, while C is both read and written. A total of three new buffers // are allocated and existing load's/store's are replaced by accesses to those buffers. - for %jT = 0 to 32 { - for %kT = 0 to 32 { - for %iT = 0 to 32 { - for %kk = 0 to 16 { // k intratile + affine.for %jT = 0 to 32 { + affine.for %kT = 0 to 32 { + affine.for %iT = 0 to 32 { + affine.for %kk = 0 to 16 { // k intratile %k = affine.apply (d0, d1) -> (16*d0 + d1) (%kT, %kk) %v0 = load %B[%k, %jT] : memref<512 x 32 x f32> "foo"(%v0) : (f32) -> () } - for %ii = 0 to 16 { // i intratile. + affine.for %ii = 0 to 16 { // i intratile. %i = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii) %v1 = load %A[%i, %kT] : memref<512 x 32 x f32> "bar"(%v1) : (f32) -> () } - for %ii_ = 0 to 16 { // i intratile. + affine.for %ii_ = 0 to 16 { // i intratile. %v2 = "abc_compute"() : () -> f32 %i_ = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii_) %v3 = load %C[%i_, %jT] : memref<512 x 32 x f32> @@ -134,13 +134,13 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // // CHECK-LABEL: func @loop_nest_modulo() { // CHECK: %0 = alloc() : memref<256x8xf32> -// CHECK-NEXT: for %i0 = 0 to 32 step 4 { +// CHECK-NEXT: affine.for %i0 = 0 to 32 step 4 { // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // CHECK-NEXT: %2 = alloc() : memref<1x2xf32, 1> // CHECK-NEXT: %3 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%1, %c0], %2[%c0, %c0], %c2, %3[%c0] : memref<256x8xf32>, memref<1x2xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %3[%c0], %c2 : memref<1xi32> -// CHECK-NEXT: for %i1 = 0 to 8 { +// CHECK-NEXT: affine.for %i1 = 0 to 8 { // ... // ... // CHECK: } @@ -148,9 +148,9 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // CHECK-NEXT: return func @loop_nest_modulo() { %A = alloc() : memref<256 x 8 x f32> - for %i = 0 to 32 step 4 { + affine.for %i = 0 to 32 step 4 { // DMAs will be performed at this level (%j is the first unit stride loop) - for %j = 0 to 8 { + affine.for %j = 0 to 8 { %idx = affine.apply (d0) -> (d0 mod 2) (%j) // A buffer of size 32 x 2 will be allocated (original buffer was 256 x 8). 
%v = load %A[%i, %idx] : memref<256 x 8 x f32> @@ -164,17 +164,17 @@ func @loop_nest_modulo() { // CHECK-LABEL: func @loop_nest_tiled() -> memref<256x1024xf32> { func @loop_nest_tiled() -> memref<256x1024xf32> { %0 = alloc() : memref<256x1024xf32> - for %i0 = 0 to 256 step 32 { - for %i1 = 0 to 1024 step 32 { + affine.for %i0 = 0 to 256 step 32 { + affine.for %i1 = 0 to 1024 step 32 { // CHECK: %3 = alloc() : memref<32x32xf32, 1> // CHECK-NEXT: %4 = alloc() : memref<1xi32> // Strided DMA here: 32 x 32 tile in a 256 x 1024 memref. // CHECK-NEXT: dma_start %0[%1, %2], %3[%c0, %c0], %c1024, %4[%c0], %c1024_0, %c32 : memref<256x1024xf32>, memref<32x32xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait -// CHECK-NEXT: for %i2 = #map -// CHECK-NEXT: for %i3 = #map - for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { - for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { +// CHECK-NEXT: affine.for %i2 = #map +// CHECK-NEXT: affine.for %i3 = #map + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { + affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { // CHECK-NEXT: %5 = affine.apply [[MAP_INDEX_DIFF_EVEN]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP_INDEX_DIFF_ODD]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %7 = load %3[%5, %6] : memref<32x32xf32, 1> @@ -196,8 +196,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // No strided DMA needed here. // CHECK: dma_start %arg0[%c1, %c0], %0[%c0, %c0], %c100, %1[%c0] : memref<100x100xf32>, memref<1x100xf32, 1>, // CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32> - for %i = 0 to 100 { - for %j = 0 to ()[s0] -> (s0) ()[%N] { + affine.for %i = 0 to 100 { + affine.for %j = 0 to ()[s0] -> (s0) ()[%N] { // CHECK: %2 = affine.apply [[MAP_D0_MINUS_ONE]](%c1_0, %i1) // CHECK: %3 = affine.apply [[MAP_D1]](%c1_0, %i1) // CHECK-NEXT: %4 = load %0[%2, %3] : memref<1x100xf32, 1> @@ -210,8 +210,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // CHECK-LABEL: func @dma_with_symbolic_accesses func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { %N = constant 9 : index - for %i = 0 to 100 { - for %j = 0 to 100 { + affine.for %i = 0 to 100 { + affine.for %j = 0 to 100 { %idy = affine.apply (d0, d1) [s0, s1] -> (d1 + s0 + s1)(%i, %j)[%M, %N] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -221,8 +221,8 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { // CHECK-NEXT: %2 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %0], %1[%c0, %c0], %c10000, %2[%c0] // CHECK-NEXT: dma_wait %2[%c0], %c10000 -// CHECK-NEXT: for %i0 = 0 to 100 { -// CHECK-NEXT: for %i1 = 0 to 100 { +// CHECK-NEXT: affine.for %i0 = 0 to 100 { +// CHECK-NEXT: affine.for %i1 = 0 to 100 { // CHECK-NEXT: %3 = affine.apply [[MAP_SYM_SHIFT]](%i0, %i1)[%arg1, %c9] // CHECK-NEXT: %4 = affine.apply [[MAP_3D_D1]](%arg1, %i0, %3) // CHECK-NEXT: %5 = affine.apply [[MAP_SUB_OFFSET]](%arg1, %i0, %3) @@ -241,8 +241,8 @@ func @dma_with_symbolic_loop_bounds(%A : memref<100x100xf32>, %M : index, %N: in // CHECK-NEXT: %1 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %c0], %0[%c0, %c0], %c10000, %1[%c0] : memref<100x100xf32>, memref<100x100xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %1[%c0], %c10000 : memref<1xi32> - for %i = 0 to 100 { - for %j = %M to %N { + affine.for %i = 0 to 100 { + affine.for %j = %M to %N { %idy = affine.apply (d1) [s0] -> (d1 + s0)(%j)[%K] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -256,8 +256,8 @@ func @dma_with_symbolic_loop_bounds(%A : 
memref<100x100xf32>, %M : index, %N: in func @dma_unknown_size(%arg0: memref) { %M = dim %arg0, 0 : memref %N = dim %arg0, 0 : memref - for %i = 0 to %M { - for %j = 0 to %N { + affine.for %i = 0 to %M { + affine.for %j = 0 to %N { // If this loop nest isn't tiled, the access requires a non-constant DMA // size -- not yet implemented. // CHECK: %2 = load %arg0[%i0, %i1] : memref @@ -272,9 +272,9 @@ func @dma_unknown_size(%arg0: memref) { // CHECK-LABEL: func @dma_memref_3d func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { - for %i = 0 to 1024 { - for %j = 0 to 1024 { - for %k = 0 to 1024 { + affine.for %i = 0 to 1024 { + affine.for %j = 0 to 1024 { + affine.for %k = 0 to 1024 { %idx = affine.apply (d0) -> (d0 mod 128)(%i) %idy = affine.apply (d0) -> (d0 mod 128)(%j) %idz = affine.apply (d0) -> (d0 mod 128)(%k) @@ -308,8 +308,8 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // CHECK-LABEL: func @multi_load_store_union() { func @multi_load_store_union() { %A = alloc() : memref<512 x 512 x f32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx = affine.apply (d0) -> (d0 + 64)(%i) %idy = affine.apply (d0) -> (d0 + 128)(%j) %ishift = affine.apply (d0) -> (d0 + 2)(%i) @@ -333,8 +333,8 @@ func @multi_load_store_union() { // CHECK-NEXT: dma_start %0[%c2_1, %c2_2], %1[%c0, %c0], %c170372_3, %2[%c0], %c512_4, %c446_5 : memref<512x512xf32>, memref<382x446xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %2[%c0], %c170372_3 : memref<1xi32> // CHECK-NEXT: %3 = alloc() : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 256 { -// CHECK-NEXT: for %i1 = 0 to 256 { +// CHECK-NEXT: affine.for %i0 = 0 to 256 { +// CHECK-NEXT: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %4 = affine.apply [[MAP_PLUS_64]](%i0) // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_128]](%i1) // CHECK-NEXT: %6 = affine.apply [[MAP_PLUS_2]](%i0) @@ -370,7 +370,7 @@ func @dma_loop_straightline_interspersed() { %c255 = constant 255 : index %A = alloc() : memref<256 x f32> %v = load %A[%c0] : memref<256 x f32> - for %i = 1 to 255 { + affine.for %i = 1 to 255 { load %A[%i] : memref<256 x f32> } %l = load %A[%c255] : memref<256 x f32> @@ -389,7 +389,7 @@ func @dma_loop_straightline_interspersed() { // CHECK-NEXT: %5 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %5[%c0], %c254 : memref<1xi32> -// CHECK-NEXT: for %i0 = 1 to 255 { +// CHECK-NEXT: affine.for %i0 = 1 to 255 { // CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0) // CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 1> // CHECK-NEXT: } @@ -410,10 +410,10 @@ func @dma_loop_straightline_interspersed() { func @dma_mixed_loop_blocks() { %c0 = constant 0 : index %A = alloc() : memref<256 x 256 x vector<8 x f32>> - for %i = 0 to 256 { + affine.for %i = 0 to 256 { %v = load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> "foo"(%v) : (vector<8 x f32>) -> () - for %j = 0 to 256 { + affine.for %j = 0 to 256 { %w = load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> "bar"(%w) : (vector<8 x f32>) -> () } @@ -425,7 +425,7 @@ func @dma_mixed_loop_blocks() { // CHECK-DAG: [[TAG:%[0-9]+]] = alloc() : memref<1xi32> // CHECK: dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], %c65536, [[TAG]][%c0] : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 1>, memref<1xi32> // CHECK-NEXT: dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 256 { +// CHECK-NEXT: affine.for %i0 = 0 to 256 { // 
CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 1> -// CHECK: for %i1 = 0 to 256 { +// CHECK: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 1> diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 2071c6023e8..b14a72eb5ee 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -16,13 +16,13 @@ func @should_fuse_raw_dep_for_locality() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -44,23 +44,23 @@ func @should_fuse_reduction_to_pointwise() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x10xf32> %v3 = addf %v0, %v1 : f32 store %v3, %b[%i0] : memref<10xf32> } } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v4 = load %b[%i2] : memref<10xf32> store %v4, %c[%i2] : memref<10xf32> } // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. - // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: %5 = load %1[%i0, %i1] : memref<10x10xf32> @@ -88,15 +88,15 @@ func @should_fuse_loop_nests_with_shifts() { %a = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 9 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 9 { + affine.for %i1 = 0 to 9 { %idx = affine.apply (d0) -> (d0 + 1) (%i0) %idy = affine.apply (d0) -> (d0 + 1) (%i1) store %cf7, %a[%idx, %idy] : memref<10x10xf32> } } - for %i2 = 1 to 10 { - for %i3 = 1 to 10 { + affine.for %i2 = 1 to 10 { + affine.for %i3 = 1 to 10 { %v0 = load %a[%i2, %i3] : memref<10x10xf32> } } @@ -109,8 +109,8 @@ func @should_fuse_loop_nests_with_shifts() { // *) Fifth affine apply shifts the loads access function by '-1', because // of the offset induced by reducing the memref shape from 10x10 to 9x9. // NOTE: Should create a private memref with reduced shape 9x9xf32. 
- // CHECK: for %i0 = 1 to 10 { - // CHECK-NEXT: for %i1 = 1 to 10 { + // CHECK: affine.for %i0 = 1 to 10 { + // CHECK-NEXT: affine.for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) // CHECK-NEXT: %3 = affine.apply [[MAP_SHIFT_BY_ONE]](%1) @@ -138,27 +138,27 @@ func @should_fuse_loop_nest() { %b = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i0, %i1] : memref<10x10xf32> } } - for %i2 = 0 to 10 { - for %i3 = 0 to 10 { + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 10 { %v0 = load %a[%i3, %i2] : memref<10x10xf32> store %v0, %b[%i2, %i3] : memref<10x10xf32> } } - for %i4 = 0 to 10 { - for %i5 = 0 to 10 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 10 { %v1 = load %b[%i4, %i5] : memref<10x10xf32> } } // Expecting private memref for '%a' first, then private memref for '%b'. // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> - // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<1x1xf32> @@ -189,23 +189,23 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Should fuse first loop (past second loop with no dependences) into third. // Note that fusion creates a private memref '%2' for the fused loop nest. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> @@ -227,13 +227,13 @@ func @should_fuse_all_loops() { %cf7 = constant 7.0 : f32 // Set up flow dependences from first and second loops to third. - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v0 = load %a[%i2] : memref<10xf32> %v1 = load %b[%i2] : memref<10xf32> } @@ -242,7 +242,7 @@ func @should_fuse_all_loops() { // Expecting private memref for '%a' first, then private memref for '%b'. 
// CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) @@ -268,27 +268,27 @@ func @should_fuse_first_and_second_loops() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %c[%i2] : memref<10xf32> } // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %6 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -310,28 +310,28 @@ func @should_not_fuse_would_create_cycle() { // 1) loop0 -> loop1 on memref '%a' // 2) loop0 -> loop2 on memref '%b' // 3) loop1 -> loop2 on memref '%c' - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %cf7, %c[%i2] : memref<10xf32> } // Should not fuse: fusing loop first loop into last would create a cycle. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { + // CHECK: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -346,23 +346,23 @@ func @should_not_fuse_across_waw_dep() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %m[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %m[%i2] : memref<10xf32> } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { + // CHECK: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -379,27 +379,27 @@ func @should_fuse_and_move_to_preserve_war_dep() { %b = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Loops '%i1' and '%i2' have no dependences. We can fuse a slice of '%i0' // into '%i2' if we move the fused loop nest before '%i1', which preserves // the WAR dependence from load '%a' in '%i0' to the store '%a' in loop '%i1'. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %2 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %2, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -415,20 +415,20 @@ func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index %v1 = load %m[%c0] : memref<10xf32> // Top-level load to '%m' should prevent fusion. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) @@ -446,13 +446,13 @@ func @should_fuse_no_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -471,20 +471,20 @@ func @should_not_fuse_if_inst_at_top_level() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index if #set0(%c0) { } // Top-level IfOp should prevent fusion. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } return @@ -500,20 +500,20 @@ func @should_not_fuse_if_inst_in_loop_nest() { %cf7 = constant 7.0 : f32 %c4 = constant 4 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { if #set0(%c4) { } %v0 = load %m[%i1] : memref<10xf32> } // IfOp in ForInst should prevent fusion. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: if #set0(%c4) { // CHECK-NEXT: } // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> @@ -532,24 +532,24 @@ func @permute_and_fuse() { %m = alloc() : memref<10x20x30xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 20 { - for %i2 = 0 to 30 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 20 { + affine.for %i2 = 0 to 30 { store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> } } } - for %i3 = 0 to 30 { - for %i4 = 0 to 10 { - for %i5 = 0 to 20 { + affine.for %i3 = 0 to 30 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 20 { %v0 = load %m[%i4, %i5, %i3] : memref<10x20x30xf32> "foo"(%v0) : (f32) -> () } } } -// CHECK: for %i0 = 0 to 30 { -// CHECK-NEXT: for %i1 = 0 to 10 { -// CHECK-NEXT: for %i2 = 0 to 20 { +// CHECK: affine.for %i0 = 0 to 30 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { +// CHECK-NEXT: affine.for %i2 = 0 to 20 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) @@ -578,22 +578,22 @@ func @permute_and_fuse() { func @fuse_reshape_64_16_4(%in : memref<64xf32>) { %out = alloc() : memref<16x4xf32> - for %i0 = 0 to 64 { + affine.for %i0 = 0 to 64 { %v = load %in[%i0] : memref<64xf32> %idx = affine.apply (d0) -> (d0 floordiv 4) (%i0) %idy = affine.apply (d0) -> (d0 mod 4) (%i0) store %v, %out[%idx, %idy] : memref<16x4xf32> } - for %i1 = 0 to 16 { - for %i2 = 0 to 4 { + affine.for %i1 = 0 to 16 { + affine.for %i2 = 0 to 4 { %w = load %out[%i1, %i2] : memref<16x4xf32> "foo"(%w) : (f32) -> () } } return - // CHECK: for %i0 = - // CHECK-NEXT: for %i1 = + // CHECK: affine.for %i0 = + // CHECK-NEXT: affine.for %i1 = // CHECK-NOT: for // CHECK: } // CHECK-NEXT: } @@ -612,19 +612,19 @@ func @fuse_reshape_16_4_64() { %in = alloc() : memref<16x4xf32> %out = alloc() : memref<64xf32> - for %i0 = 0 to 16 { - for %i1 = 0 to 4 { + affine.for %i0 = 0 to 16 { + affine.for %i1 = 0 to 4 { %v = load %in[%i0, %i1] : memref<16x4xf32> %idx = affine.apply (d0, d1) -> (4*d0 + d1) (%i0, %i1) store %v, %out[%idx] : memref<64xf32> } } - for %i2 = 0 to 64 { + affine.for %i2 = 0 to 64 { %w = load %out[%i2] : memref<64xf32> "foo"(%w) : (f32) -> () } -// CHECK: for %i0 = 0 to 64 { +// CHECK: affine.for %i0 = 0 to 64 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0) // CHECK-NEXT: %3 = affine.apply [[MAP1]](%i0) // CHECK-NEXT: %4 = load %1[%2, %3] : memref<16x4xf32> @@ -650,12 +650,12 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { %live_out = alloc() : memref<64x9xi32> // Initialize input. - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 1 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 1 { %val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32 store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> } @@ -665,8 +665,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { // Convert output coordinates to linear index. 
%a0 = affine.apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj) %0 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 16 * 1))(%a0) @@ -680,8 +680,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - for %i = 0 to 64 { - for %j = 0 to 9 { + affine.for %i = 0 to 64 { + affine.for %j = 0 to 9 { %a = load %out[%i, %j] : memref<64x9xi32> %b = muli %a, %a : i32 store %b, %live_out[%i, %j] : memref<64x9xi32> @@ -717,8 +717,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK: %0 = alloc() : memref<1x2x3x3x16x1xi32> // CHECK: %1 = alloc() : memref<1x1xi32> // CHECK: %2 = alloc() : memref<64x9xi32> -// CHECK-NEXT: for %i0 = 0 to 64 { -// CHECK-NEXT: for %i1 = 0 to 9 { +// CHECK-NEXT: affine.for %i0 = 0 to 64 { +// CHECK-NEXT: affine.for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1) @@ -768,14 +768,14 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { %c0 = constant 0.0 : f32 %s = constant 5 : index - for %i0 = 0 to %M { - for %i1 = 0 to (d0) -> (d0 + 5) (%N) { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to (d0) -> (d0 + 5) (%N) { store %c0, %m[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { %idy = affine.apply (d0)[s0] -> (d0 + s0) (%i3)[%s] %v = load %m[%i2, %idy] : memref } @@ -792,16 +792,16 @@ func @should_fuse_reduction_at_depth1() { %a = alloc() : memref<10x100xf32> %b = alloc() : memref<10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 100 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 100 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x100xf32> %v2 = "maxf"(%v0, %v1) : (f32, f32) -> f32 store %v2, %b[%i0] : memref<10xf32> } } - for %i2 = 0 to 10 { - for %i3 = 0 to 100 { + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 100 { %v3 = load %b[%i2] : memref<10xf32> %v4 = load %a[%i2, %i3] : memref<10x100xf32> %v5 = subf %v4, %v3 : f32 @@ -812,8 +812,8 @@ func @should_fuse_reduction_at_depth1() { // loop nest, which improves locality and enables subsequence passes to // decrease the reduction memref size and possibly place it in a faster // memory space. 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 100 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 100 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: %4 = load %1[%i0, %i1] : memref<10x100xf32> @@ -821,7 +821,7 @@ func @should_fuse_reduction_at_depth1() { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %5, %0[%6] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 100 { + // CHECK-NEXT: affine.for %i2 = 0 to 100 { // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> // CHECK-NEXT: %9 = load %1[%i0, %i2] : memref<10x100xf32> @@ -843,19 +843,19 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { %a = alloc() : memref<100x16xf32> %b = alloc() : memref<100x16xf32> - for %i0 = 0 to 100 { - for %i1 = 0 to 16 { + affine.for %i0 = 0 to 100 { + affine.for %i1 = 0 to 16 { %v0 = load %a[%i0, %i1] : memref<100x16xf32> "op0"(%v0) : (f32) -> () } - for %i2 = 0 to 16 { + affine.for %i2 = 0 to 16 { %v1 = "op1"() : () -> (f32) store %v1, %b[%i0, %i2] : memref<100x16xf32> } } - for %i3 = 0 to 100 { - for %i4 = 0 to 16 { + affine.for %i3 = 0 to 100 { + affine.for %i4 = 0 to 16 { %v2 = load %b[%i3, %i4] : memref<100x16xf32> "op2"(%v2) : (f32) -> () } @@ -865,18 +865,18 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // destination loop nest at depth2 causes extra computation. Instead, // the fusion algorithm should detect that the source loop should be sliced // at depth 1 and the slice should be inserted at depth 1. - // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: for %i1 = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 100 { + // CHECK-NEXT: affine.for %i1 = 0 to 16 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<100x16xf32> // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %3, %0[%4, %5] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %7 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %8 = load %0[%6, %7] : memref<1x16xf32> @@ -896,20 +896,20 @@ func @should_fuse_src_depth1_at_dst_depth2() { %a = alloc() : memref<100xf32> %c0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %c0, %a[%i0] : memref<100xf32> } - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { %a0 = affine.apply (d0, d1) -> (d0 * 10 + d1) (%i1, %i2) %v0 = load %a[%a0] : memref<100xf32> } } // The source loop nest slice loop bound is a function of both destination // loop IVs, so we should slice at depth 1 and insert the slice at depth 2. 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> @@ -930,10 +930,10 @@ func @fusion_at_depth0_not_currently_supported() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = load %0[%c0] : memref<10xf32> } // NOTE: Should shrink memref size to 1 element access by load in dst loop @@ -966,18 +966,18 @@ func @should_fuse_deep_loop_nests() { %c1 = constant 1 : index %c1_0 = constant 1 : index %cst = constant 0.000000e+00 : f32 - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 10 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 10 { %3 = load %0[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> } } - for %i6 = 0 to 16 { - for %i7 = 0 to 10 { + affine.for %i6 = 0 to 16 { + affine.for %i7 = 0 to 10 { store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> } @@ -986,22 +986,22 @@ func @should_fuse_deep_loop_nests() { } } } - for %i8 = 0 to 3 { - for %i9 = 0 to 3 { - for %i10 = 0 to 2 { - for %i11 = 0 to 2 { - for %i12 = 0 to 3 { - for %i13 = 0 to 3 { - for %i14 = 0 to 2 { - for %i15 = 0 to 2 { - for %i16 = 0 to 16 { - for %i17 = 0 to 10 { + affine.for %i8 = 0 to 3 { + affine.for %i9 = 0 to 3 { + affine.for %i10 = 0 to 2 { + affine.for %i11 = 0 to 2 { + affine.for %i12 = 0 to 3 { + affine.for %i13 = 0 to 3 { + affine.for %i14 = 0 to 2 { + affine.for %i15 = 0 to 2 { + affine.for %i16 = 0 to 16 { + affine.for %i17 = 0 to 10 { %5 = load %0[%i14, %i15, %i12, %i13, %i16, %i17] : memref<2x2x3x3x16x10xf32, 2> } } - for %i18 = 0 to 16 { - for %i19 = 0 to 10 { + affine.for %i18 = 0 to 16 { + affine.for %i19 = 0 to 10 { %6 = load %1[%i10, %i11, %i8, %i9, %i18, %i19] : memref<2x2x3x3x16x10xf32, 2> } @@ -1019,19 +1019,19 @@ func @should_fuse_deep_loop_nests() { // where the destination loops nests have been interchanged. 
// CHECK-DAG: %0 = alloc() : memref<1x1x1x1x16x10xf32, 2> -// CHECK: for %i0 = 0 to 3 { -// CHECK-NEXT: for %i1 = 0 to 3 { -// CHECK-NEXT: for %i2 = 0 to 2 { -// CHECK-NEXT: for %i3 = 0 to 2 { -// CHECK-NEXT: for %i4 = 0 to 3 { -// CHECK-NEXT: for %i5 = 0 to 3 { -// CHECK-NEXT: for %i6 = 0 to 16 { -// CHECK-NEXT: for %i7 = 0 to 10 { +// CHECK: affine.for %i0 = 0 to 3 { +// CHECK-NEXT: affine.for %i1 = 0 to 3 { +// CHECK-NEXT: affine.for %i2 = 0 to 2 { +// CHECK-NEXT: affine.for %i3 = 0 to 2 { +// CHECK-NEXT: affine.for %i4 = 0 to 3 { +// CHECK-NEXT: affine.for %i5 = 0 to 3 { +// CHECK-NEXT: affine.for %i6 = 0 to 16 { +// CHECK-NEXT: affine.for %i7 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 16 { -// CHECK-NEXT: for %i9 = 0 to 10 { +// CHECK-NEXT: affine.for %i8 = 0 to 16 { +// CHECK-NEXT: affine.for %i9 = 0 to 10 { // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %6 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) @@ -1041,15 +1041,15 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: store %cst, %0[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i10 = 0 to 2 { -// CHECK-NEXT: for %i11 = 0 to 2 { -// CHECK-NEXT: for %i12 = 0 to 16 { -// CHECK-NEXT: for %i13 = 0 to 10 { +// CHECK-NEXT: affine.for %i10 = 0 to 2 { +// CHECK-NEXT: affine.for %i11 = 0 to 2 { +// CHECK-NEXT: affine.for %i12 = 0 to 16 { +// CHECK-NEXT: affine.for %i13 = 0 to 10 { // CHECK-NEXT: %10 = load %1[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i14 = 0 to 16 { -// CHECK-NEXT: for %i15 = 0 to 10 { +// CHECK-NEXT: affine.for %i14 = 0 to 16 { +// CHECK-NEXT: affine.for %i15 = 0 to 10 { // CHECK-NEXT: %11 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %12 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %13 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) @@ -1083,17 +1083,17 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 4 { - for %i1 = 0 to 256 { + affine.for %i0 = 0 to 4 { + affine.for %i1 = 0 to 256 { %v0 = load %b[%i0, %i1] : memref<4x256xf32> } - for %i2 = 0 to 256 { + affine.for %i2 = 0 to 256 { store %cf0, %a[%i0, %i2] : memref<4x256xf32> } } - for %d0 = 0 to 4 { - for %d1 = 0 to 16 { + affine.for %d0 = 0 to 4 { + affine.for %d1 = 0 to 16 { %v1 = load %a[%d0, %d1] : memref<4x256xf32> } } @@ -1107,16 +1107,16 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // is reduced from the original shape from 4x256 to 4x16 because of the // data accessed by the load. 
// CHECK-DAG: %0 = alloc() : memref<1x16xf32> - // CHECK: for %i0 = 0 to 4 { - // CHECK-NEXT: for %i1 = 0 to 256 { + // CHECK: affine.for %i0 = 0 to 4 { + // CHECK-NEXT: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %cst, %0[%3, %4] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %7 = load %0[%5, %6] : memref<1x16xf32> @@ -1134,31 +1134,31 @@ func @should_fuse_at_depth1_with_trip_count_20() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - for %i1 = 0 to 5 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 5 { + affine.for %i2 = 0 to 10 { %v0 = load %a[%i2]: memref<100xf32> } - for %i3 = 0 to 10 { - for %i4 = 0 to 20 { + affine.for %i3 = 0 to 10 { + affine.for %i4 = 0 to 20 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 20xf32 // CHECK-DAG: %0 = alloc() : memref<20xf32> - // CHECK: for %i0 = 0 to 5 { - // CHECK-NEXT: for %i1 = 0 to 20 { + // CHECK: affine.for %i0 = 0 to 5 { + // CHECK-NEXT: affine.for %i1 = 0 to 20 { // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 10 { - // CHECK-NEXT: for %i4 = 0 to 20 { + // CHECK-NEXT: affine.for %i3 = 0 to 10 { + // CHECK-NEXT: affine.for %i4 = 0 to 20 { // CHECK-NEXT: %2 = load %0[%i4] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1175,31 +1175,31 @@ func @should_fuse_at_depth1_with_trip_count_19() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - for %i1 = 0 to 5 { - for %i2 = 0 to 19 { + affine.for %i1 = 0 to 5 { + affine.for %i2 = 0 to 19 { %v0 = load %a[%i2]: memref<100xf32> } - for %i3 = 0 to 10 { - for %i4 = 0 to 10 { + affine.for %i3 = 0 to 10 { + affine.for %i4 = 0 to 10 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 19xf32 // CHECK-DAG: %0 = alloc() : memref<19xf32> - // CHECK: for %i0 = 0 to 5 { - // CHECK-NEXT: for %i1 = 0 to 19 { + // CHECK: affine.for %i0 = 0 to 5 { + // CHECK-NEXT: affine.for %i1 = 0 to 19 { // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 19 { + // CHECK-NEXT: affine.for %i2 = 0 to 19 { // CHECK-NEXT: %1 = load %0[%i2] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 10 { - // CHECK-NEXT: for %i4 = 0 to 10 { + // CHECK-NEXT: affine.for %i3 = 0 to 10 { + // CHECK-NEXT: affine.for %i4 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i4] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1217,26 +1217,26 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { %m = alloc() : memref<100xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf7, %m[%i0] : memref<100xf32> } - for %i1 = 0 to 17 { + affine.for %i1 = 0 to 17 { %v0 = 
load %m[%i1] : memref<100xf32> } - for %i2 = 0 to 82 { + affine.for %i2 = 0 to 82 { %v1 = load %m[%i2] : memref<100xf32> } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 82 { + // CHECK: affine.for %i0 = 0 to 82 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 17 { + // CHECK-NEXT: affine.for %i1 = 0 to 17 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) @@ -1252,18 +1252,18 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %arg0[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %arg0[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0'. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1276,19 +1276,19 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
// CHECK-DAG: %0 = alloc() : memref<10xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> @@ -1303,17 +1303,17 @@ func @R3_to_R2_reshape() { %c0 = constant 0 : index - for %i0 = 0 to 2 { - for %i1 = 0 to 3 { - for %i2 = 0 to 16 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 3 { + affine.for %i2 = 0 to 16 { %val = "foo"(%i0, %i1, %i2) : (index, index, index) -> i32 store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> } } } - for %ii = 0 to 32 { - for %jj = 0 to 3 { + affine.for %ii = 0 to 32 { + affine.for %jj = 0 to 3 { %a0 = affine.apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) %idx = affine.apply (d0) -> (d0 floordiv (3 * 16)) (%a0) %v = load %in[%idx, %jj, %c0] @@ -1332,8 +1332,8 @@ func @R3_to_R2_reshape() { // CHECK-LABEL: func @R3_to_R2_reshape() // CHECK-DAG: %0 = alloc() : memref<1x1x1xi32> -// CHECK: for %i0 = 0 to 32 { -// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK: affine.for %i0 = 0 to 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]]()[%c0] // CHECK-NEXT: %3 = "foo"(%1, %i1, %2) : (index, index, index) -> i32 @@ -1360,19 +1360,19 @@ func @should_not_fuse_multi_output_producer() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1389,30 +1389,30 @@ func @fusion_preventing_deps_on_middle_loop() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %v2, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%b', because of the WAR dep from '%i0' to '%i1' on memref '%a' and // because of the WAR dep from '%i1' to '%i2' on memref '%c'. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %3, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %5, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1432,17 +1432,17 @@ func @should_fuse_and_move_to_preserve_war_dep() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %v0, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 3 { + affine.for %i1 = 0 to 3 { %v2 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 5 { + affine.for %i2 = 0 to 5 { store %cf7, %b[%i2] : memref<10xf32> } - for %i3 = 0 to 10 { + affine.for %i3 = 0 to 10 { %v1 = load %a[%i3] : memref<10xf32> store %cf7, %c[%i3] : memref<10xf32> } @@ -1461,10 +1461,10 @@ func @should_fuse_and_move_to_preserve_war_dep() { // if the fused loop nest is inserted between loops '%i1' and '%i2'. // CHECK-DAG: %0 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 3 { + // CHECK: affine.for %i0 = 0 to 3 { // CHECK-NEXT: %3 = load %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %4 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %4, %0[%5] : memref<1xf32> @@ -1472,7 +1472,7 @@ func @should_fuse_and_move_to_preserve_war_dep() { // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> // CHECK-NEXT: store %cst, %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 5 { + // CHECK-NEXT: affine.for %i2 = 0 to 5 { // CHECK-NEXT: store %cst, %1[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1489,30 +1489,30 @@ func @fusion_preventing_dep_on_constant() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } %cf11 = constant 11.0 : f32 - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%a', because of the WAR dep from '%i0' to '%i1' on memref '%b' and // because of the SSA value dep from '%cf11' def to use in '%i2'. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: store %cst_0, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1532,14 +1532,14 @@ func @should_fuse_and_preserve_dep_on_constant() { %cf7 = constant 7.0 : f32 %cf11 = constant 11.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } @@ -1549,7 +1549,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // the SSA value dep from '%cf11' def to use in '%i2'. // CHECK: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%4] : memref<1xf32> @@ -1557,7 +1557,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> // CHECK-NEXT: store %cst_0, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1575,25 +1575,25 @@ func @should_fuse_and_preserve_dep_on_constant() { func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) { %out = alloc() : memref<64x4xf32> %0 = constant 0.0 : f32 - for %i0 = 0 to 64 { - for %i1 = 0 to 4 { + affine.for %i0 = 0 to 64 { + affine.for %i1 = 0 to 4 { store %0, %out[%i0, %i1] : memref<64x4xf32> } } - for %i2 = 0 to 4 { - for %i3 = 0 to 4 { - for %i4 = 0 to 16 { + affine.for %i2 = 0 to 4 { + affine.for %i3 = 0 to 4 { + affine.for %i4 = 0 to 16 { %1 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i3, %i4) %2 = load %arg1[%1, %i2] : memref<64x4xf32> "op0"(%2) : (f32) -> () } - for %i5 = 0 to 4 { - for %i6 = 0 to 16 { + affine.for %i5 = 0 to 4 { + affine.for %i6 = 0 to 16 { %3 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i5, %i6) %4 = load %arg0[%3, %i3] : memref<64x4xf32> "op1"(%4) : (f32) -> () } - for %i7 = 0 to 16 { + affine.for %i7 = 0 to 16 { %5 = "op2"() : () -> (f32) %6 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i5, %i7) %7 = load %out[%6, %i2] : memref<64x4xf32> @@ -1613,25 +1613,25 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // memref size can be reduced to 128x1xf32. 
// CHECK: %0 = alloc() : memref<64x1xf32> - // CHECK: for %i0 = 0 to 4 { - // CHECK-NEXT: for %i1 = 0 to 64 { + // CHECK: affine.for %i0 = 0 to 4 { + // CHECK-NEXT: affine.for %i1 = 0 to 64 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0) // CHECK-NEXT: store %cst, %0[%1, %2] : memref<64x1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 4 { - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 4 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i2, %i3) // CHECK-NEXT: %4 = load %arg1[%3, %i0] : memref<64x4xf32> // CHECK-NEXT: "op0"(%4) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i4 = 0 to 4 { - // CHECK-NEXT: for %i5 = 0 to 16 { + // CHECK-NEXT: affine.for %i4 = 0 to 4 { + // CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i4, %i5) // CHECK-NEXT: %6 = load %arg0[%5, %i2] : memref<64x4xf32> // CHECK-NEXT: "op1"(%6) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i6 = 0 to 16 { + // CHECK-NEXT: affine.for %i6 = 0 to 16 { // CHECK-NEXT: %7 = "op2"() : () -> f32 // CHECK-NEXT: %8 = affine.apply [[MAP3]](%i4, %i6) // CHECK-NEXT: %9 = affine.apply [[MAP0]](%i0, %8, %i0) @@ -1660,14 +1660,14 @@ func @should_fuse_after_private_memref_creation() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %v0, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %a[%i2] : memref<10xf32> store %v1, %b[%i2] : memref<10xf32> } @@ -1678,14 +1678,14 @@ func @should_fuse_after_private_memref_creation() { // private memref, the dependence between '%i0' and '%i1' on memref '%a' no // longer exists, so '%i0' can now be fused into '%i2'. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i1, %i1) diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index c2fdbd4f80f..a1f9d717fab 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -8,12 +8,12 @@ // CHECK-DAG: [[UB_INTRA_TILE:#map[0-9]+]] = (d0, d1, d2) -> (d2 + 32, s0, 4096 floordiv s1) // CHECK-LABEL: func @loop_tiling() -// CHECK-NEXT: for %i0 = 0 to 256 step 32 { -// CHECK-NEXT: for %i1 = 0 to 512 step 32 { -// CHECK-NEXT: for %i2 = 0 to 1024 step 32 { -// CHECK-NEXT: for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { -// CHECK-NEXT: for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { -// CHECK-NEXT: for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { +// CHECK-NEXT: affine.for %i0 = 0 to 256 step 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 512 step 32 { +// CHECK-NEXT: affine.for %i2 = 0 to 1024 step 32 { +// CHECK-NEXT: affine.for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { +// CHECK-NEXT: affine.for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { +// CHECK-NEXT: affine.for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { // CHECK-NEXT: "foo"(%i3, %i4, %i5) : (index, index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -21,32 +21,32 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i6 = 0 to 50 step 32 { -// CHECK-NEXT: for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { +// CHECK-NEXT: affine.for %i6 = 0 to 50 step 32 { +// CHECK-NEXT: affine.for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { // CHECK-NEXT: "bar"(%i7, %i7) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 21 step 32 { -// CHECK-NEXT: for %i9 = [[IDENTITY]](%i8) to 21 { +// CHECK-NEXT: affine.for %i8 = 0 to 21 step 32 { +// CHECK-NEXT: affine.for %i9 = [[IDENTITY]](%i8) to 21 { // CHECK-NEXT: "foobar"(%i9) : (index) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return func @loop_tiling() { - for %i = 0 to 256 { - for %j = 0 to 512 { - for %k = 0 to 1024 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 512 { + affine.for %k = 0 to 1024 { "foo"(%i, %j, %k) : (index, index, index) -> () } } } - for %x = 0 to 50 { + affine.for %x = 0 to 50 { "bar"(%x, %x) : (index, index) -> () } // Intra-tile loop won't need a min expression. 
- for %y = 0 to 21 { + affine.for %y = 0 to 21 { "foobar"(%y) : (index) -> () } @@ -58,12 +58,12 @@ func @loop_tiling() { // CHECK-LABEL: func @loop_max_min_bound(%arg0: memref, %arg1: index, %arg2: index) { func @loop_max_min_bound(%A : memref, %L : index, %U : index) { %M = dim %A, 0 : memref - for %iTT = max #lb()[%L] to min #ub()[%M, %U] { + affine.for %iTT = max #lb()[%L] to min #ub()[%M, %U] { %out = affine.apply (d0) -> (d0) (%iTT) } return -// CHECK: for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { -// CHECK-NEXT: for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { +// CHECK: affine.for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { +// CHECK-NEXT: affine.for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { // CHECK-NEXT: %1 = affine.apply [[IDENTITY]](%i1) // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir index 7804c9ca752..c24ef1038fc 100644 --- a/mlir/test/Transforms/lower-affine.mlir +++ b/mlir/test/Transforms/lower-affine.mlir @@ -24,7 +24,7 @@ func @body(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @simple_loop() { - for %i = 1 to 42 { + affine.for %i = 1 to 42 { call @body(%i) : (index) -> () } return @@ -65,9 +65,9 @@ func @post(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @imperfectly_nested_loops() { - for %i = 0 to 42 { + affine.for %i = 0 to 42 { call @pre(%i) : (index) -> () - for %j = 7 to 56 step 2 { + affine.for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @post(%i) : (index) -> () @@ -122,13 +122,13 @@ func @body3(index, index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @more_imperfectly_nested_loops() { - for %i = 0 to 42 { + affine.for %i = 0 to 42 { call @pre(%i) : (index) -> () - for %j = 7 to 56 step 2 { + affine.for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @mid(%i) : (index) -> () - for %k = 18 to 37 step 3 { + affine.for %k = 18 to 37 step 3 { call @body3(%i, %k) : (index, index) -> () } call @post(%i) : (index) -> () @@ -161,8 +161,8 @@ func @more_imperfectly_nested_loops() { // CHECK-NEXT: return // CHECK-NEXT: } func @affine_apply_loops_shorthand(%N : index) { - for %i = 0 to %N { - for %j = %i to 42 { + affine.for %i = 0 to %N { + affine.for %j = %i to 42 { call @body2(%i, %j) : (index, index) -> () } } @@ -360,7 +360,7 @@ func @if_for() { // CHECK-NEXT: [[outerEndBB]]: // CHECK-NEXT: br [[outerLoopInit:\^bb[0-9]+]] if #set1(%i) { - for %j = 0 to 42 { + affine.for %j = 0 to 42 { if #set2(%j) { call @body2(%i, %j) : (index, index) -> () } @@ -397,9 +397,9 @@ func @if_for() { // CHECK-NEXT: %c1_9 = constant 1 : index // CHECK-NEXT: %16 = addi %9, %c1_9 : index // CHECK-NEXT: br [[outerLoopCond]](%16 : index) - for %k = 0 to 42 { + affine.for %k = 0 to 42 { if #set2(%k) { - for %l = 0 to 42 { + affine.for %l = 0 to 42 { call @body3(%k, %l) : (index, index) -> () } } @@ -446,8 +446,8 @@ func @if_for() { // CHECK-NEXT: return // CHECK-NEXT: } func @loop_min_max(%N : index) { - for %i = 0 to 42 { - for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { + affine.for %i = 0 to 42 { + affine.for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { call @body2(%i, %j) : (index, index) -> () } } @@ -486,7 +486,7 @@ func @loop_min_max(%N : index) { // CHECK-NEXT: return // CHECK-NEXT: } func @min_reduction_tree(%v : index) { - for %i = 0 to min #map_7_values(%v)[] { + affine.for %i = 0 to min #map_7_values(%v)[] { call 
@body(%i) : (index) -> () } return diff --git a/mlir/test/Transforms/memref-bound-check.mlir b/mlir/test/Transforms/memref-bound-check.mlir index 2926bf1afbc..b3d5b23e70f 100644 --- a/mlir/test/Transforms/memref-bound-check.mlir +++ b/mlir/test/Transforms/memref-bound-check.mlir @@ -11,8 +11,8 @@ func @test() { %A = alloc() : memref<9 x 9 x i32> %B = alloc() : memref<111 x i32> - for %i = -1 to 10 { - for %j = -1 to 10 { + affine.for %i = -1 to 10 { + affine.for %j = -1 to 10 { %idx0 = affine.apply (d0, d1) -> (d0)(%i, %j) %idx1 = affine.apply (d0, d1) -> (d1)(%i, %j) // Out of bound access. @@ -27,7 +27,7 @@ func @test() { } } - for %k = 0 to 10 { + affine.for %k = 0 to 10 { // In bound. %u = load %B[%zero] : memref<111 x i32> // Out of bounds. @@ -43,8 +43,8 @@ func @test_mod_floordiv_ceildiv() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -69,8 +69,8 @@ func @test_no_out_of_bounds() { %C = alloc() : memref<257 x i32> %B = alloc() : memref<1 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { // All of these accesses are in bound; check that no errors are emitted. // CHECK: %3 = affine.apply {{#map.*}}(%i0, %i1) // CHECK-NEXT: %4 = load %0[%3, %c0] : memref<257x256xi32> @@ -93,8 +93,8 @@ func @mod_div() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -115,8 +115,8 @@ func @mod_div() { // CHECK-LABEL: func @mod_floordiv_nested() { func @mod_floordiv_nested() { %A = alloc() : memref<256 x 256 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1) -> ((d0 mod 1024) floordiv 4)(%i, %j) %idx1 = affine.apply (d0, d1) -> ((((d1 mod 128) mod 32) ceildiv 4) * 32)(%i, %j) load %A[%idx0, %idx1] : memref<256 x 256 x i32> // expected-error {{'load' op memref out of upper bound access along dimension #2}} @@ -128,7 +128,7 @@ func @mod_floordiv_nested() { // CHECK-LABEL: func @test_semi_affine_bailout func @test_semi_affine_bailout(%N : index) { %B = alloc() : memref<10 x i32> - for %i = 0 to 10 { + affine.for %i = 0 to 10 { %idx = affine.apply (d0)[s0] -> (d0 * s0)(%i)[%N] %y = load %B[%idx] : memref<10 x i32> } @@ -138,7 +138,7 @@ func @test_semi_affine_bailout(%N : index) { // CHECK-LABEL: func @multi_mod_floordiv func @multi_mod_floordiv() { %A = alloc() : memref<2x2xi32> - for %ii = 0 to 64 { + affine.for %ii = 0 to 64 { %idx0 = affine.apply (d0) -> ((d0 mod 147456) floordiv 1152) (%ii) %idx1 = affine.apply (d0) -> (((d0 mod 147456) mod 1152) floordiv 384) (%ii) %v = load %A[%idx0, %idx1] : memref<2x2xi32> @@ -153,8 +153,8 @@ func @delinearize_mod_floordiv() { %out = alloc() : memref<64x9xi32> // Reshape '%in' into '%out'. 
- for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { %a0 = affine.apply (d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) @@ -189,7 +189,7 @@ func @out_of_bounds() { %in = alloc() : memref<1xi32> %c9 = constant 9 : i32 - for %i0 = 10 to 11 { + affine.for %i0 = 10 to 11 { %idy = affine.apply (d0) -> (100 * d0 floordiv 1000) (%i0) store %c9, %in[%idy] : memref<1xi32> // expected-error {{'store' op memref out of upper bound access along dimension #1}} } diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index 710d14c1cf9..ed39d71eefd 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -10,14 +10,14 @@ func @simple_store_load() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: return @@ -30,7 +30,7 @@ func @multi_store_load() { %cf8 = constant 8.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -45,7 +45,7 @@ func @multi_store_load() { // CHECK-NEXT: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %cst_1 = constant 9.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: %1 = mulf %cst_1, %cst_1 : f32 // CHECK-NEXT: } @@ -59,8 +59,8 @@ func @multi_store_load() { func @store_load_affine_apply() -> memref<10x10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10x10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %t0 = affine.apply (d0, d1) -> (d1 + 1)(%i0, %i1) %t1 = affine.apply (d0, d1) -> (d0)(%i0, %i1) %idx0 = affine.apply (d0, d1) -> (d1) (%t0, %t1) @@ -75,8 +75,8 @@ func @store_load_affine_apply() -> memref<10x10xf32> { return %m : memref<10x10xf32> // CHECK: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %0 = alloc() : memref<10x10xf32> -// CHECK-NEXT: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%1, %2) @@ -92,17 +92,17 @@ func @store_load_affine_apply() -> memref<10x10xf32> { func @store_load_nested(%N : index) { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: } @@ -117,12 +117,12 @@ 
func @multi_store_load_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf8 = constant 8.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - for %i2 = 0 to %N { + affine.for %i2 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -138,9 +138,9 @@ func @store_load_store_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -159,16 +159,16 @@ func @multi_store_load_nested_fwd(%N : index) { %cf9 = constant 9.0 : f32 %cf10 = constant 10.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - for %i2 = 0 to %N { + affine.for %i2 = 0 to %N { store %cf9, %m[%i2] : memref<10xf32> } store %cf10, %m[%i0] : memref<10xf32> - for %i3 = 0 to %N { + affine.for %i3 = 0 to %N { // CHECK-NOT: %{{[0-9]+}} = load %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -182,10 +182,10 @@ func @multi_store_load_nested_fwd(%N : index) { func @store_load_no_fwd() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { // CHECK: load %{{[0-9]+}} %v0 = load %m[%i2] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -202,9 +202,9 @@ func @store_load_fwd() { %c0 = constant 0 : index %m = alloc() : memref<10xf32> store %cf7, %m[%c0] : memref<10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { // CHECK-NOT: load %{{[0-9]}}+ %v0 = load %m[%c0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -223,9 +223,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %c0 = constant 0 : index %c1 = constant 1 : index %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 %idx = affine.apply (d0) -> (d0 + 1) (%i0) @@ -236,9 +236,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %v3 = load %m[%c1] : memref<10xf32> return %v3 : f32 // CHECK: %0 = alloc() : memref<10xf32> -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> -// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: %2 = affine.apply [[MAP4]](%i0) // CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32> diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 3ec840b1eb7..6e176f5d29b 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -13,14 +13,14 @@ func 
@store_may_execute_before_load() { // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -37,13 +37,13 @@ func @dependent_loops() { %cst = constant 7.000000e+00 : f32 // There is a dependence from 0 to 1 at depth 1 (common surrounding loops 0) // because the first loop with the store dominates the second loop. - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = load %0[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -231,7 +231,7 @@ func @store_range_load_after_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -254,7 +254,7 @@ func @store_load_func_symbol(%arg0: index, %arg1: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to %arg1 { + affine.for %i0 = 0 to %arg1 { %a0 = affine.apply (d0) -> (d0) (%arg0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, +inf]}} @@ -277,7 +277,7 @@ func @store_range_load_last_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) // For dependence from 0 to 1, we do not have a loop carried dependence // because only the final write in the loop accesses the same element as the @@ -305,7 +305,7 @@ func @store_range_load_before_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -328,7 +328,7 @@ func @store_range_load_first_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) // Dependence from 0 to 1 at depth 1 is a range because all loads at // constant index zero are reads after first store at index zero during @@ -353,7 +353,7 @@ func @store_range_load_first_in_range() { func @store_plus_3() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0 + 3) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = 
false}} @@ -375,7 +375,7 @@ func @store_plus_3() { func @load_minus_2() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 2 to 11 { + affine.for %i0 = 2 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -397,8 +397,8 @@ func @load_minus_2() { func @perfectly_nested_loops_loop_independent() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 11 { - for %i1 = 0 to 11 { + affine.for %i0 = 0 to 11 { + affine.for %i1 = 0 to 11 { // Dependence from access 0 to 1 is loop independent at depth = 3. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -428,8 +428,8 @@ func @perfectly_nested_loops_loop_independent() { func @perfectly_nested_loops_loop_carried_at_depth1() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 9 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 9 { + affine.for %i1 = 0 to 9 { // Dependence from access 0 to 1 is loop carried at depth 1. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -459,8 +459,8 @@ func @perfectly_nested_loops_loop_carried_at_depth1() { func @perfectly_nested_loops_loop_carried_at_depth2() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { // Dependence from access 0 to 1 is loop carried at depth 2. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -491,8 +491,8 @@ func @one_common_loop() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 // There is a loop-independent dependence from access 0 to 1 at depth 2. - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) store %c7, %m[%a00, %a01] : memref<10x10xf32> @@ -502,7 +502,7 @@ func @one_common_loop() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} // expected-note@-5 {{dependence from 0 to 1 at depth 2 = true}} } - for %i2 = 0 to 9 { + affine.for %i2 = 0 to 9 { %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i2) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i2) %v0 = load %m[%a10, %a11] : memref<10x10xf32> @@ -525,7 +525,7 @@ func @dependence_cycle() { // Dependences: // *) loop-independent dependence from access 1 to 2 at depth 2. // *) loop-carried dependence from access 3 to 0 at depth 1. 
- for %i0 = 0 to 9 { + affine.for %i0 = 0 to 9 { %a0 = affine.apply (d0) -> (d0) (%i0) %v0 = load %m.a[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -575,8 +575,8 @@ func @dependence_cycle() { func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to %arg0 { - for %i1 = 0 to %arg1 { + affine.for %i0 = 0 to %arg0 { + affine.for %i1 = 0 to %arg1 { %a00 = affine.apply (d0, d1) -> (d0 - 1) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1 + 1) (%i0, %i1) %v0 = load %m[%a00, %a01] : memref<10x10xf32> @@ -605,8 +605,8 @@ func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { func @war_raw_waw_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 + 1) (%i1) %v0 = load %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -633,7 +633,7 @@ func @war_raw_waw_deps() { func @mod_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 mod 2) (%i0) // Results are conservative here since we currently don't have a way to // represent strided sets in FlatAffineConstraints. @@ -658,8 +658,8 @@ func @loop_nest_depth() { %0 = alloc() : memref<100x100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 128 { - for %i1 = 0 to 8 { + affine.for %i0 = 0 to 128 { + affine.for %i1 = 0 to 8 { store %c7, %0[%i0, %i1] : memref<100x100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -667,10 +667,10 @@ func @loop_nest_depth() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = true}} } } - for %i2 = 0 to 8 { - for %i3 = 0 to 8 { - for %i4 = 0 to 8 { - for %i5 = 0 to 16 { + affine.for %i2 = 0 to 8 { + affine.for %i3 = 0 to 8 { + affine.for %i4 = 0 to 8 { + affine.for %i5 = 0 to 16 { %8 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i4, %i5) %9 = load %0[%8, %i3] : memref<100x100xf32> // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} @@ -693,9 +693,9 @@ func @loop_nest_depth() { func @mod_div_3d() { %M = alloc() : memref<2x2x2xi32> %c0 = constant 0 : i32 - for %i0 = 0 to 8 { - for %i1 = 0 to 8 { - for %i2 = 0 to 8 { + affine.for %i0 = 0 to 8 { + affine.for %i1 = 0 to 8 { + affine.for %i2 = 0 to 8 { %idx0 = affine.apply (d0, d1, d2) -> (d0 floordiv 4) (%i0, %i1, %i2) %idx1 = affine.apply (d0, d1, d2) -> (d1 mod 2) (%i0, %i1, %i2) %idx2 = affine.apply (d0, d1, d2) -> (d2 floordiv 4) (%i0, %i1, %i2) @@ -719,12 +719,12 @@ func @delinearize_mod_floordiv() { %in = alloc() : memref<2x2x3x3x16x1xi32> %out = alloc() : memref<64x9xi32> - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 1 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 1 { store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -742,8 +742,8 @@ func @delinearize_mod_floordiv() { } } - for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { %a0 = affine.apply 
(d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 30f98db2583..ede5c63fbac 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -16,13 +16,13 @@ func @loop_nest_dma() { %zero = constant 0 : index %num_elts = constant 128 : index - for %i = 0 to 8 { + affine.for %i = 0 to 8 { dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> dma_wait %tag[%zero], %num_elts : memref<1 x f32> %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> - for %j = 0 to 128 { + affine.for %j = 0 to 128 { "do_more_compute"(%i, %j) : (index, index) -> () } } @@ -34,7 +34,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %3 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: %4 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: dma_start %0[%c0], %1[%3, %c0], %c128, %2[%4, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> -// CHECK-NEXT: for %i0 = 1 to 8 { +// CHECK-NEXT: affine.for %i0 = 1 to 8 { // CHECK-NEXT: %5 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: %6 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: dma_start %0[%i0], %1[%5, %i0], %c128, %2[%6, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> @@ -45,7 +45,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1> // CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32 // CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1> -// CHECK-NEXT: for %i1 = 0 to 128 { +// CHECK-NEXT: affine.for %i1 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -56,7 +56,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1> // CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32 // CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1> -// CHECK-NEXT: for %i2 = 0 to 128 { +// CHECK-NEXT: affine.for %i2 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: return @@ -68,7 +68,7 @@ func @loop_step(%arg0: memref<512xf32>, %arg1: memref<512xf32>) { %c0 = constant 0 : index %c4 = constant 4 : index - for %i0 = 0 to 512 step 4 { + affine.for %i0 = 0 to 512 step 4 { %1 = alloc() : memref<4xf32, 1> %2 = alloc() : memref<1xi32> dma_start %arg0[%i0], %1[%c0], %c4, %2[%c0] @@ -82,7 +82,7 @@ func @loop_step(%arg0: memref<512xf32>, // CHECK: %2 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK: %3 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK-NEXT: dma_start %arg0[%c0], %0[%2, %c0_0], %c4, [[TAG]][%3, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> -// CHECK-NEXT: for %i0 = 4 to 512 step 4 { +// CHECK-NEXT: affine.for %i0 = 4 to 512 step 4 { // CHECK-NEXT: %4 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: %5 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: dma_start %arg0[%i0], %0[%4, %c0_0], %c4, [[TAG]][%5, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> @@ -114,8 +114,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // Prologue for DMA overlap on arg2. 
// CHECK:[[TAG_ARG2:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg2[ - // CHECK: for %i0 = 1 to 8 { - for %i0 = 0 to 8 { + // CHECK: affine.for %i0 = 1 to 8 { + affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -127,8 +127,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK-NEXT for %i1 = 1 to 8 { - for %i1 = 0 to 8 { + // CHECK-NEXT affine.for %i1 = 1 to 8 { + affine.for %i1 = 0 to 8 { %7 = affine.apply #map1(%i0, %i1) %8 = affine.apply #map2(%i1) dma_start %arg0[%7, %c0], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> @@ -140,8 +140,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0]] // CHECK: dma_wait [[TAG_ARG1]] - // CHECK-NEXT: for %i2 = 0 to 4 { - for %i2 = 0 to 4 { + // CHECK-NEXT: affine.for %i2 = 0 to 4 { + affine.for %i2 = 0 to 4 { "foo"() : () -> () } } @@ -155,16 +155,16 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1_NESTED:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK: for %i4 = 1 to 8 { + // CHECK: affine.for %i4 = 1 to 8 { // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: for %i5 = 0 to 4 { + // CHECK: affine.for %i5 = 0 to 4 { // CHECK: "foo"() : () -> () // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: for %i6 = 0 to 4 { + // CHECK: affine.for %i6 = 0 to 4 { } return // CHECK: } @@ -185,8 +185,8 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) { // The two DMAs below are dependent (incoming and outgoing on the same // memref) in the same iteration; so no pipelining here. // CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 8 { - for %i0 = 0 to 8 { + // CHECK: affine.for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -206,8 +206,8 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 16 { - for %kTT = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> @@ -230,14 +230,14 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 16 { - for %kTT = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } - // Use live out of 'for' inst; no DMA pipelining will be done. 
+ // Use live out of 'affine.for' inst; no DMA pipelining will be done. %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> return %v : f32 // CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> @@ -261,14 +261,14 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { // CHECK: %5 = affine.apply [[MOD_2]](%c0) // CHECK: %6 = affine.apply [[MOD_2]](%c0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%5, %c0_0, %c0_0], %c512, %4[%6, %c0_0] - for %kTT = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } return -// CHECK-NEXT: for %i0 = 1 to 16 { +// CHECK-NEXT: affine.for %i0 = 1 to 16 { // CHECK: %7 = affine.apply [[MOD_2]](%i0) // CHECK: %8 = affine.apply [[MOD_2]](%i0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%7, %c0_0, %c0_0], %c512, %4[%8, %c0_0] diff --git a/mlir/test/Transforms/simplify-affine-structures.mlir b/mlir/test/Transforms/simplify-affine-structures.mlir index c3f7077d628..d4c07042518 100644 --- a/mlir/test/Transforms/simplify-affine-structures.mlir +++ b/mlir/test/Transforms/simplify-affine-structures.mlir @@ -73,8 +73,8 @@ // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() { func @test_gaussian_elimination_empty_set0() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) if (d0, d1) : (2 == 0)(%i0, %i1) { } @@ -85,8 +85,8 @@ func @test_gaussian_elimination_empty_set0() { // CHECK-LABEL: func @test_gaussian_elimination_empty_set1() { func @test_gaussian_elimination_empty_set1() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) if (d0, d1) : (1 >= 0, -1 >= 0) (%i0, %i1) { } @@ -97,8 +97,8 @@ func @test_gaussian_elimination_empty_set1() { // CHECK-LABEL: func @test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_non_empty_set2() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set1(%i0, %i1) if #set2(%i0, %i1) { } @@ -111,8 +111,8 @@ func @test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_empty_set3() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] if #set3(%i0, %i1)[%c7, %c11] { } @@ -125,8 +125,8 @@ func @test_gaussian_elimination_empty_set3() { func @test_gaussian_elimination_non_empty_set4() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set3(%i0, %i1)[%c7, %c11] if #set4(%i0, %i1)[%c7, %c11] { } @@ -139,8 +139,8 @@ func @test_gaussian_elimination_non_empty_set4() { func @test_gaussian_elimination_empty_set5() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] if #set5(%i0, %i1)[%c7, %c11] { } @@ -151,8 +151,8 @@ func @test_gaussian_elimination_empty_set5() { // CHECK-LABEL: func @test_fuzz_explosion func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + 
affine.for %i1 = 1 to 100 { if #set_fuzz_virus(%i0, %i1, %arg0, %arg1, %arg2, %arg3) { } } @@ -163,8 +163,8 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i // CHECK-LABEL: func @test_empty_set(%arg0: index) { func @test_empty_set(%N : index) { - for %i = 0 to 10 { - for %j = 0 to 10 { + affine.for %i = 0 to 10 { + affine.for %j = 0 to 10 { // CHECK: if [[SET_EMPTY_2D]](%i0, %i1) if (d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)(%i, %j) { "foo"() : () -> () @@ -198,8 +198,8 @@ func @test_empty_set(%N : index) { } } // The tests below test GCDTightenInequalities(). - for %k = 0 to 10 { - for %l = 0 to 10 { + affine.for %k = 0 to 10 { + affine.for %l = 0 to 10 { // Empty because no multiple of 8 lies between 4 and 7. // CHECK: if [[SET_EMPTY_1D]](%i2) if (d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)(%k) { @@ -226,7 +226,7 @@ func @test_empty_set(%N : index) { } } - for %m = 0 to 10 { + affine.for %m = 0 to 10 { // CHECK: if [[SET_EMPTY_1D]](%i{{[0-9]+}}) if (d0) : (d0 mod 2 - 3 == 0) (%m) { "foo"() : () -> () diff --git a/mlir/test/Transforms/strip-debuginfo.mlir b/mlir/test/Transforms/strip-debuginfo.mlir index 5d157282071..776b30eb489 100644 --- a/mlir/test/Transforms/strip-debuginfo.mlir +++ b/mlir/test/Transforms/strip-debuginfo.mlir @@ -10,7 +10,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %1 = "foo"() : () -> i32 loc("foo") // CHECK: } loc(unknown) - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(unknown) diff --git a/mlir/test/Transforms/unroll-jam.mlir b/mlir/test/Transforms/unroll-jam.mlir index da4f965676f..98d284aeede 100644 --- a/mlir/test/Transforms/unroll-jam.mlir +++ b/mlir/test/Transforms/unroll-jam.mlir @@ -7,13 +7,13 @@ // CHECK-LABEL: func @unroll_jam_imperfect_nest() { func @unroll_jam_imperfect_nest() { // CHECK: %c100 = constant 100 : index - // CHECK-NEXT: for %i0 = 0 to 99 step 2 { - for %i = 0 to 101 { + // CHECK-NEXT: affine.for %i0 = 0 to 99 step 2 { + affine.for %i = 0 to 101 { // CHECK: %0 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %1 = affine.apply [[MAP_PLUS_1]](%i0) // CHECK-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 %x = "addi32"(%i, %i) : (index, index) -> i32 - for %j = 0 to 17 { + affine.for %j = 0 to 17 { // CHECK: %3 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32 // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_1]](%i0) @@ -29,7 +29,7 @@ func @unroll_jam_imperfect_nest() { } // CHECK } // cleanup loop (single iteration) // CHECK: %11 = "addi32"(%c100, %c100) : (index, index) -> i32 - // CHECK-NEXT: for %i2 = 0 to 17 { + // CHECK-NEXT: affine.for %i2 = 0 to 17 { // CHECK-NEXT: %12 = "addi32"(%c100, %c100) : (index, index) -> i32 // CHECK-NEXT: %13 = "addi32"(%12, %12) : (i32, i32) -> i32 // CHECK-NEXT: } @@ -39,8 +39,8 @@ func @unroll_jam_imperfect_nest() { // UNROLL-BY-4-LABEL: func @loop_nest_unknown_count_1(%arg0: index) { func @loop_nest_unknown_count_1(%N : index) { - // UNROLL-BY-4-NEXT: for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 { - // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 { + // UNROLL-BY-4-NEXT: affine.for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4-NEXT: affine.for %i1 = 1 to 100 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -48,14 +48,14 @@ func @loop_nest_unknown_count_1(%N : index) { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // A cleanup loop should be 
generated here. - // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { - // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 { + // UNROLL-BY-4-NEXT: affine.for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { + // UNROLL-BY-4-NEXT: affine.for %i3 = 1 to 100 { // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // UNROLL-BY-4_NEXT: } // Specify the lower bound in a form so that both lb and ub operands match. - for %i = ()[s0] -> (1)()[%N] to %N { - for %j = 1 to 100 { + affine.for %i = ()[s0] -> (1)()[%N] to %N { + affine.for %j = 1 to 100 { %x = "foo"() : () -> i32 } } @@ -64,8 +64,8 @@ func @loop_nest_unknown_count_1(%N : index) { // UNROLL-BY-4-LABEL: func @loop_nest_unknown_count_2(%arg0: index) { func @loop_nest_unknown_count_2(%arg : index) { - // UNROLL-BY-4-NEXT: for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 { - // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 { + // UNROLL-BY-4-NEXT: affine.for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4-NEXT: affine.for %i1 = 1 to 100 { // UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (index) -> i32 // UNROLL-BY-4-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (index) -> i32 @@ -77,12 +77,12 @@ func @loop_nest_unknown_count_2(%arg : index) { // UNROLL-BY-4-NEXT: } // The cleanup loop is a single iteration one and is promoted. // UNROLL-BY-4-NEXT: %7 = affine.apply [[M1:#map{{[0-9]+}}]]()[%arg0] - // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 { + // UNROLL-BY-4-NEXT: affine.for %i3 = 1 to 100 { // UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // Specify the lower bound in a form so that both lb and ub operands match. - for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] { - for %j = 1 to 100 { + affine.for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] { + affine.for %j = 1 to 100 { %x = "foo"(%i) : (index) -> i32 } } diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index c023561faa8..013f65367cb 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -46,13 +46,13 @@ // CHECK-LABEL: func @loop_nest_simplest() { func @loop_nest_simplest() { - // CHECK: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // CHECK: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // CHECK: %c1_i32 = constant 1 : i32 // CHECK-NEXT: %c1_i32_0 = constant 1 : i32 // CHECK-NEXT: %c1_i32_1 = constant 1 : i32 // CHECK-NEXT: %c1_i32_2 = constant 1 : i32 - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = constant 1 : i32 } } // CHECK: } @@ -62,8 +62,8 @@ func @loop_nest_simplest() { // CHECK-LABEL: func @loop_nest_simple_iv_use() { func @loop_nest_simple_iv_use() { // CHECK: %c0 = constant 0 : index - // CHECK-NEXT: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // CHECK-NEXT: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // CHECK: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // CHECK: %1 = affine.apply [[MAP0]](%c0) // CHECK-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -71,7 +71,7 @@ func @loop_nest_simple_iv_use() { // CHECK-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> i32 // CHECK: %5 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "addi32"(%j, %j) : (index, index) -> i32 } } // CHECK: } @@ -82,8 +82,8 @@ func @loop_nest_simple_iv_use() { // CHECK-LABEL: func @loop_nest_body_def_use() { func @loop_nest_body_def_use() { // CHECK: %c0 = constant 0 : 
index - // CHECK-NEXT: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // CHECK-NEXT: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // CHECK: %c0_0 = constant 0 : index %c0 = constant 0 : index // CHECK: %0 = affine.apply [[MAP0]](%c0) @@ -97,7 +97,7 @@ func @loop_nest_body_def_use() { // CHECK-NEXT: %8 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %9 = affine.apply [[MAP0]](%8) // CHECK-NEXT: %10 = "addi32"(%9, %c0_0) : (index, index) -> index - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %c0) : (index, index) -> index @@ -110,14 +110,14 @@ func @loop_nest_body_def_use() { func @loop_nest_strided() { // CHECK: %c2 = constant 2 : index // CHECK-NEXT: %c2_0 = constant 2 : index - // CHECK-NEXT: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // CHECK-NEXT: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // CHECK: %0 = affine.apply [[MAP0]](%c2_0) // CHECK-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // CHECK-NEXT: %2 = affine.apply [[MAP1]](%c2_0) // CHECK-NEXT: %3 = affine.apply [[MAP0]](%2) // CHECK-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> index - for %j = 2 to 6 step 2 { + affine.for %j = 2 to 6 step 2 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -130,7 +130,7 @@ func @loop_nest_strided() { // CHECK-NEXT: %10 = affine.apply [[MAP3]](%c2) // CHECK-NEXT: %11 = affine.apply [[MAP0]](%10) // CHECK-NEXT: %12 = "addi32"(%11, %11) : (index, index) -> index - for %k = 2 to 7 step 2 { + affine.for %k = 2 to 7 step 2 { %z = "affine.apply" (%k) { map: (d0) -> (d0 + 1) } : (index) -> (index) %w = "addi32"(%z, %z) : (index, index) -> index @@ -142,8 +142,8 @@ func @loop_nest_strided() { // CHECK-LABEL: func @loop_nest_multiple_results() { func @loop_nest_multiple_results() { // CHECK: %c0 = constant 0 : index - // CHECK-NEXT: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // CHECK-NEXT: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // CHECK: %0 = affine.apply [[MAP4]](%i0, %c0) // CHECK-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // CHECK-NEXT: %2 = affine.apply #map{{.*}}(%i0, %c0) @@ -153,7 +153,7 @@ func @loop_nest_multiple_results() { // CHECK-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> index // CHECK-NEXT: %7 = affine.apply #map{{.*}}(%i0, %4) // CHECK-NEXT: %8 = "fma"(%7, %5, %5) : (index, index, index) -> (index, index) - for %j = 0 to 2 step 1 { + affine.for %j = 0 to 2 step 1 { %x = affine.apply (d0, d1) -> (d0 + 1) (%i, %j) %y = "addi32"(%x, %x) : (index, index) -> index %z = affine.apply (d0, d1) -> (d0 + 3) (%i, %j) @@ -170,8 +170,8 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // CHECK: %c0 = constant 0 : index // CHECK-NEXT: %c128 = constant 128 : index %c128 = constant 128 : index - // CHECK: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // CHECK: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // CHECK: %0 = "vld"(%i0) : (index) -> i32 %ld = "vld"(%i) : (index) -> i32 // CHECK: %1 = affine.apply [[MAP0]](%c0) @@ -189,7 +189,7 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // CHECK-NEXT: %13 = affine.apply [[MAP0]](%12) // CHECK-NEXT: %14 = "vmulf"(%12, %13) : (index, index) -> index // CHECK-NEXT: %15 = "vaddf"(%14, %14) : (index, index) -> index - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "vmulf"(%j, %x) : (index, index) -> 
index @@ -218,7 +218,7 @@ func @loop_nest_seq_multiple() { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%c0_0) // CHECK-NEXT: %6 = affine.apply [[MAP0]](%5) // CHECK-NEXT: "mul"(%6, %6) : (index, index) -> () - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) "mul"(%x, %x) : (index, index) -> () @@ -226,8 +226,8 @@ func @loop_nest_seq_multiple() { // CHECK: %c99 = constant 99 : index %k = "constant"(){value: 99} : () -> index - // CHECK: for %i0 = 0 to 100 step 2 { - for %m = 0 to 100 step 2 { + // CHECK: affine.for %i0 = 0 to 100 step 2 { + affine.for %m = 0 to 100 step 2 { // CHECK: %7 = affine.apply [[MAP0]](%c0) // CHECK-NEXT: %8 = affine.apply [[MAP6]](%c0)[%c99] // CHECK-NEXT: %9 = affine.apply [[MAP0]](%c0) @@ -239,7 +239,7 @@ func @loop_nest_seq_multiple() { // CHECK-NEXT: %15 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %16 = affine.apply [[MAP0]](%15) // CHECK-NEXT: %17 = affine.apply [[MAP6]](%15)[%c99] - for %n = 0 to 4 { + affine.for %n = 0 to 4 { %y = "affine.apply" (%n) { map: (d0) -> (d0 + 1) } : (index) -> (index) %z = "affine.apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } : @@ -251,16 +251,16 @@ func @loop_nest_seq_multiple() { // SHORT-LABEL: func @loop_nest_outer_unroll() { func @loop_nest_outer_unroll() { - // SHORT: for %i0 = 0 to 4 { + // SHORT: affine.for %i0 = 0 to 4 { // SHORT-NEXT: %0 = affine.apply [[MAP0]](%i0) // SHORT-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // SHORT-NEXT: } - // SHORT-NEXT: for %i1 = 0 to 4 { + // SHORT-NEXT: affine.for %i1 = 0 to 4 { // SHORT-NEXT: %2 = affine.apply [[MAP0]](%i1) // SHORT-NEXT: %3 = "addi32"(%2, %2) : (index, index) -> index // SHORT-NEXT: } - for %i = 0 to 2 { - for %j = 0 to 4 { + affine.for %i = 0 to 2 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -284,28 +284,28 @@ func @loop_nest_seq_long() -> i32 { %zero_idx = constant 0 : index - for %n0 = 0 to 512 { - for %n1 = 0 to 8 { + affine.for %n0 = 0 to 512 { + affine.for %n1 = 0 to 8 { store %one, %A[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %two, %B[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %zero, %C[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> } } - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 8 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 8 { %b2 = "affine.apply" (%i1, %i2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %x = load %B[%i0, %b2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op1"(%x) : (i32) -> () } - for %j1 = 0 to 8 { - for %j2 = 0 to 8 { + affine.for %j1 = 0 to 8 { + affine.for %j2 = 0 to 8 { %a2 = "affine.apply" (%i1, %j2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %v203 = load %A[%j1, %a2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op2"(%v203) : (i32) -> () } - for %k2 = 0 to 8 { + affine.for %k2 = 0 to 8 { %s0 = "op3"() : () -> i32 %c2 = "affine.apply" (%i0, %k2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %s1 = load %C[%j1, %c2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> @@ -322,8 +322,8 @@ func @loop_nest_seq_long() -> i32 { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_no_cleanup() { func @unroll_unit_stride_no_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for 
[[L1:%i[0-9]+]] = 0 to 8 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -337,13 +337,13 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 0 to 8 { + affine.for %j = 0 to 8 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } // empty loop - // UNROLL-BY-4: for %i2 = 0 to 8 { - for %k = 0 to 8 { + // UNROLL-BY-4: affine.for %i2 = 0 to 8 { + affine.for %k = 0 to 8 { } } return @@ -351,8 +351,8 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_cleanup() { func @unroll_unit_stride_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 0 to 7 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -370,7 +370,7 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 0 to 10 { + affine.for %j = 0 to 10 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -380,8 +380,8 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_non_unit_stride_cleanup() { func @unroll_non_unit_stride_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 2 to 37 step 20 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -399,7 +399,7 @@ func @unroll_non_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 2 to 48 step 5 { + affine.for %j = 2 to 48 step 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -411,8 +411,8 @@ func @unroll_non_unit_stride_cleanup() { func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4: %c0 = constant 0 : index // UNROLL-BY-4: %c4 = constant 4 : index - // UNROLL-BY-4: for %i0 = 0 to %arg0 { - for %i = 0 to %N { + // UNROLL-BY-4: affine.for %i0 = 0 to %arg0 { + affine.for %i = 0 to %N { // UNROLL-BY-4: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = affine.apply [[MAP0]](%c0) // UNROLL-BY-4-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -422,7 +422,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %7 = "addi32"(%c4, %c4) : (index, index) -> i32 // UNROLL-BY-4-NOT: for - for %j = 0 to 5 { + affine.for %j = 0 to 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 } // UNROLL-BY-4-NOT: } } // UNROLL-BY-4: } @@ -434,8 +434,8 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // No cleanup will be generated here. 
// UNROLL-BY-4-LABEL: func @loop_nest_operand1() { func @loop_nest_operand1() { -// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 +// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: affine.for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -443,8 +443,8 @@ func @loop_nest_operand1() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - for %i = 0 to 100 step 2 { - for %j = (d0) -> (0) (%i) to (d0) -> (d0 - d0 mod 4) (%i) { + affine.for %i = 0 to 100 step 2 { + affine.for %j = (d0) -> (0) (%i) to (d0) -> (d0 - d0 mod 4) (%i) { %x = "foo"() : () -> i32 } } @@ -454,8 +454,8 @@ func @loop_nest_operand1() { // No cleanup will be generated here. // UNROLL-BY-4-LABEL: func @loop_nest_operand2() { func @loop_nest_operand2() { -// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { +// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -463,8 +463,8 @@ func @loop_nest_operand2() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - for %i = 0 to 100 step 2 { - for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { + affine.for %i = 0 to 100 step 2 { + affine.for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { %x = "foo"() : () -> i32 } } @@ -475,16 +475,16 @@ func @loop_nest_operand2() { // factor. The cleanup loop happens to be a single iteration one and is promoted. // UNROLL-BY-4-LABEL: func @loop_nest_operand3() { func @loop_nest_operand3() { - // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { - // UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { + // UNROLL-BY-4: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 - for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { + affine.for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { %x = "foo"() : () -> i32 } } // UNROLL-BY-4: } @@ -493,20 +493,20 @@ func @loop_nest_operand3() { // UNROLL-BY-4-LABEL: func @loop_nest_operand4(%arg0: index) { func @loop_nest_operand4(%N : index) { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { - // UNROLL-BY-4: for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { // UNROLL-BY-4: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // A cleanup loop will be be generated here. 
- // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { + // UNROLL-BY-4-NEXT: affine.for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // Specify the lower bound so that both lb and ub operands match. - for %j = ()[s0] -> (0)()[%N] to %N { + affine.for %j = ()[s0] -> (0)()[%N] to %N { %x = "foo"() : () -> i32 } } @@ -518,7 +518,7 @@ func @loop_nest_unroll_full() { // CHECK-NEXT: %0 = "foo"() : () -> i32 // CHECK-NEXT: %1 = "bar"() : () -> i32 // CHECK-NEXT: return - for %i = 0 to 1 { + affine.for %i = 0 to 1 { %x = "foo"() : () -> i32 %y = "bar"() : () -> i32 } @@ -527,7 +527,7 @@ func @loop_nest_unroll_full() { // UNROLL-BY-1-LABEL: func @unroll_by_one_should_promote_single_iteration_loop() func @unroll_by_one_should_promote_single_iteration_loop() { - for %i = 0 to 1 { + affine.for %i = 0 to 1 { %x = "foo"(%i) : (index) -> i32 } return diff --git a/mlir/utils/emacs/mlir-mode.el b/mlir/utils/emacs/mlir-mode.el index f3c12795f5d..16c3f69ca3b 100644 --- a/mlir/utils/emacs/mlir-mode.el +++ b/mlir/utils/emacs/mlir-mode.el @@ -42,7 +42,7 @@ ;; Keywords `(,(regexp-opt '(;; Toplevel entities - "br" "ceildiv" "cfgfunc" "cond_br" "else" "extfunc" "false" "floordiv" "for" "if" "mlfunc" "mod" "return" "size" "step" "to" "true" "??" ) 'symbols) . font-lock-keyword-face)) + "br" "ceildiv" "func" "cond_br" "else" "extfunc" "false" "floordiv" "affine.for" "if" "mod" "return" "size" "step" "to" "true" "??" ) 'symbols) . font-lock-keyword-face)) "Syntax highlighting for MLIR.") ;; Emacs 23 compatibility. diff --git a/mlir/utils/vim/mlir.vim b/mlir/utils/vim/mlir.vim index 27526c6c54d..93291a719ae 100644 --- a/mlir/utils/vim/mlir.vim +++ b/mlir/utils/vim/mlir.vim @@ -10,9 +10,9 @@ syn keyword mlirType index i1 i2 i4 i8 i13 i16 i32 i64 \ f16 f32 tf_control syn keyword mlirType memref tensor vector -syntax keyword mlirKeywords extfunc cfgfunc mlfunc for to step return +syntax keyword mlirKeywords extfunc func to step return syntax keyword mlirConditional if else -syntax keyword mlirCoreOps dim addf addi subf subi mulf muli cmpi select constant affine.apply call call_indirect extract_element getTensor memref_cast tensor_cast load store alloc dealloc dma_start dma_wait +syntax keyword mlirCoreOps dim addf addi subf subi mulf muli cmpi select constant affine.apply affine.for call call_indirect extract_element getTensor memref_cast tensor_cast load store alloc dealloc dma_start dma_wait syn match mlirInt "-\=\<\d\+\>" syn match mlirFloat "-\=\<\d\+\.\d\+\>" -- cgit v1.2.3 From 4ba8c9147d04d82d629dde4730e1dd5d4ae4123d Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 6 Feb 2019 21:54:18 -0800 Subject: Automated rollback of changelist 232717775. 
PiperOrigin-RevId: 232807986 --- mlir/g3doc/Dialects/Affine.md | 26 +- mlir/g3doc/Dialects/SuperVector.md | 28 +- mlir/g3doc/LangRef.md | 14 +- mlir/g3doc/Passes.md | 2 +- mlir/g3doc/Rationale.md | 40 +- mlir/g3doc/RationaleSimplifiedPolyhedralForm.md | 14 +- mlir/include/mlir/AffineOps/AffineOps.h | 34 +- mlir/include/mlir/Analysis/Utils.h | 10 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 16 +- mlir/include/mlir/Transforms/LoopUtils.h | 7 +- mlir/include/mlir/Transforms/Utils.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 4 +- mlir/lib/Analysis/AffineAnalysis.cpp | 8 +- mlir/lib/Analysis/Utils.cpp | 10 +- mlir/lib/IR/Block.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 10 +- mlir/lib/Transforms/LoopFusion.cpp | 3 +- mlir/lib/Transforms/LoopTiling.cpp | 3 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 8 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 10 +- mlir/lib/Transforms/MaterializeVectors.cpp | 24 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 8 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 18 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 14 +- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 50 +- mlir/test/AffineOps/canonicalize.mlir | 38 +- mlir/test/IR/invalid.mlir | 60 +-- mlir/test/IR/locations.mlir | 2 +- mlir/test/IR/parser.mlir | 76 +-- mlir/test/IR/pretty-locations.mlir | 2 +- .../Vectorize/lower_vector_transfers.mlir | 58 +-- mlir/test/Transforms/Vectorize/materialize.mlir | 16 +- .../Vectorize/materialize_vectors_1d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_2d.mlir | 24 +- mlir/test/Transforms/Vectorize/normalize_maps.mlir | 24 +- mlir/test/Transforms/Vectorize/vectorize_1d.mlir | 62 +-- mlir/test/Transforms/Vectorize/vectorize_2d.mlir | 30 +- mlir/test/Transforms/Vectorize/vectorize_3d.mlir | 20 +- .../Vectorize/vectorize_outer_loop_2d.mlir | 18 +- .../vectorize_outer_loop_transpose_2d.mlir | 42 +- .../Vectorize/vectorize_transpose_2d.mlir | 42 +- mlir/test/Transforms/canonicalize.mlir | 12 +- mlir/test/Transforms/constant-fold.mlir | 4 +- mlir/test/Transforms/cse.mlir | 8 +- mlir/test/Transforms/dma-generate.mlir | 94 ++-- mlir/test/Transforms/loop-fusion.mlir | 516 ++++++++++----------- mlir/test/Transforms/loop-tiling.mlir | 36 +- mlir/test/Transforms/lower-affine.mlir | 28 +- mlir/test/Transforms/memref-bound-check.mlir | 32 +- mlir/test/Transforms/memref-dataflow-opt.mlir | 62 +-- mlir/test/Transforms/memref-dependence-check.mlir | 86 ++-- mlir/test/Transforms/pipeline-data-transfer.mlir | 50 +- .../Transforms/simplify-affine-structures.mlir | 38 +- mlir/test/Transforms/strip-debuginfo.mlir | 2 +- mlir/test/Transforms/unroll-jam.mlir | 30 +- mlir/test/Transforms/unroll.mlir | 136 +++--- mlir/utils/emacs/mlir-mode.el | 2 +- mlir/utils/vim/mlir.vim | 4 +- 62 files changed, 1036 insertions(+), 1041 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/Dialects/Affine.md b/mlir/g3doc/Dialects/Affine.md index 0c69c60cbe9..55d26f0d956 100644 --- a/mlir/g3doc/Dialects/Affine.md +++ b/mlir/g3doc/Dialects/Affine.md @@ -15,7 +15,7 @@ loops and if instructions), the result of a [`affine.apply` operation](#'affine.apply'-operation) that recursively takes as arguments any symbolic identifiers. 
Dimensions may be bound not only to anything that a symbol is bound to, but also to induction variables of enclosing -[`affine.for` operations](#'affine.for'-operation), and the result of an +[`for` operations](#'for'-operation), and the result of an [`affine.apply` operation](#'affine.apply'-operation) (which recursively may use other dimensions and symbols). @@ -47,12 +47,12 @@ Example: %2 = affine.apply (i)[s0] -> (i+s0) (%42)[%n] ``` -#### 'affine.for' operation {#'affine.for'-operation} +#### 'for' operation {#'for'-operation} Syntax: ``` {.ebnf} -operation ::= `affine.for` ssa-id `=` lower-bound `to` upper-bound +operation ::= `for` ssa-id `=` lower-bound `to` upper-bound (`step` integer-literal)? `{` inst* `}` lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound @@ -60,17 +60,17 @@ upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal ``` -The `affine.for` operation represents an affine loop nest, defining an SSA value -for its induction variable. This SSA value always has type +The `for` operation represents an affine loop nest, defining an SSA value for +its induction variable. This SSA value always has type [`index`](LangRef.md#index-type), which is the size of the machine word. -The `affine.for` operation executes its body a number of times iterating from a -lower bound to an upper bound by a stride. The stride, represented by `step`, is -a positive constant integer which defaults to "1" if not present. The lower and +The `for` operation executes its body a number of times iterating from a lower +bound to an upper bound by a stride. The stride, represented by `step`, is a +positive constant integer which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. -The lower and upper bounds of a `affine.for` operation are represented as an +The lower and upper bounds of a `for` operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA values as for all bindings of SSA values to dimensions and symbols. 
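For instance, a loop whose bounds each come from a multi-result map can be written as below; this is only an illustrative sketch, with made-up bound maps and a made-up symbol `%N`:

```mlir
// Iterates over the half-open range [max(0, %N - 10), min(%N, 100)), two at a time.
for %i = max ()[s0] -> (0, s0 - 10) ()[%N] to min ()[s0] -> (s0, 100) ()[%N] step 2 {
  "work"(%i) : (index) -> ()
}
```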
@@ -94,8 +94,8 @@ Example showing reverse iteration of the inner loop: func @simple_example(%A: memref, %B: memref) { %N = dim %A, 0 : memref - affine.for %i = 0 to %N step 1 { - affine.for %j = 0 to %N { // implicitly steps by 1 + for %i = 0 to %N step 1 { + for %j = 0 to %N { // implicitly steps by 1 %0 = affine.apply #map57(%j)[%N] %tmp = call @F1(%A, %i, %0) : (memref, index, index)->(f32) call @F2(%tmp, %B, %i, %0) : (f32, memref, index, index)->() @@ -130,8 +130,8 @@ Example: #set = (d0, d1)[s0]: (d0 - 10 >= 0, s0 - d0 - 9 >= 0, d1 - 10 >= 0, s0 - d1 - 9 >= 0) func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { - affine.for %i = 0 to %N { - affine.for %j = 0 to %N { + for %i = 0 to %N { + for %j = 0 to %N { %0 = affine.apply #map42(%j) %tmp = call @S1(%X, %i, %0) affine.if #set(%i, %j)[%N] { diff --git a/mlir/g3doc/Dialects/SuperVector.md b/mlir/g3doc/Dialects/SuperVector.md index cd540335a52..09beb950e37 100644 --- a/mlir/g3doc/Dialects/SuperVector.md +++ b/mlir/g3doc/Dialects/SuperVector.md @@ -22,9 +22,9 @@ Examples: // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> and // pad with %f0 to handle the boundary case: %f0 = constant 0.0f : f32 -affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 step 256 { - affine.for %i2 = 0 to %2 step 32 { +for %i0 = 0 to %0 { + for %i1 = 0 to %1 step 256 { + for %i2 = 0 to %2 step 32 { %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 {permutation_map: (d0, d1, d2) -> (d2, d1)} : (memref, index, index, f32) -> vector<32x256xf32> @@ -33,8 +33,8 @@ affine.for %i0 = 0 to %0 { // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into // vector<128xf32>. The underlying implementation will require a 1-D vector // broadcast: -affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 { +for %i0 = 0 to %0 { + for %i1 = 0 to %1 { %3 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (0)} : (memref, index, index) -> vector<128xf32> @@ -80,9 +80,9 @@ A notional lowering of vector_transfer_read could generate code resembling: // %expr1, %expr2, %expr3, %expr4 defined before this point %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> -affine.for %i = 0 to 3 { - affine.for %j = 0 to 4 { - affine.for %k = 0 to 5 { +for %i = 0 to 3 { + for %j = 0 to 4 { + for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, %j, %k] : vector<3x4x5xf32> }}} @@ -101,8 +101,8 @@ lowered code would resemble: // %expr1, %expr2, %expr3, %expr4 defined before this point %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> -affine.for %i = 0 to 3 { - affine.for %k = 0 to 5 { +for %i = 0 to 3 { + for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, 0, %k] : vector<3x4x5xf32> }} @@ -129,10 +129,10 @@ Examples: ```mlir {.mlir} // write vector<16x32x64xf32> into the slice `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`: -affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 step 32 { - affine.for %i2 = 0 to %2 step 64 { - affine.for %i3 = 0 to %3 step 16 { +for %i0 = 0 to %0 { + for %i1 = 0 to %1 step 32 { + for %i2 = 0 to %2 step 64 { + for %i3 = 0 to %3 step 16 { %val = `ssa-value` : vector<16x32x64xf32> vector_transfer_write %val, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index fdfc43ea39d..3448927d214 100644 --- 
a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -40,7 +40,7 @@ which means that values are defined before use and have scope defined by their dominance relations. Operations may produce zero or more results, and each is a distinct SSA value with its own type defined by the [type system](#type-system). -MLIR incorporates polyhedral compiler concepts, including `affine.for` and +MLIR incorporates polyhedral compiler concepts, including `for` and `affine.if` operations defined by the [affine dialect](Dialects/Affine.md), which model affine loops and affine conditionals. It also includes affine maps integrated into the type system - they are key to the representation of data and @@ -99,10 +99,10 @@ func @multiply(%A: memref<100x?xf32>, %B: memref) %C = alloc memref<100x50xf32>() // Multiplication loop nest. - affine.for %i = 0 to 100 { - affine.for %j = 0 to 50 { + for %i = 0 to 100 { + for %j = 0 to 50 { store 0 to %C[%i, %j] : memref<100x50xf32> - affine.for %k = 0 to %n { + for %k = 0 to %n { %a_v = load %A[%i, %k] : memref<100x?xf32> %b_v = load %B[%k, %j] : memref %prod = mulf %a_v, %b_v : f32 @@ -1434,7 +1434,7 @@ The arity of indices is the rank of the memref (i.e., if the memref loaded from is of rank 3, then 3 indices are required for the load following the memref identifier). -In an `affine.if` or `affine.for` body, the indices of a load are restricted to +In an `affine.if` or `for` body, the indices of a load are restricted to SSA values bound to surrounding loop induction variables, [symbols](#dimensions-and-symbols), results of a [`constant` operation](#'constant'-operation), or the result of an @@ -1456,7 +1456,7 @@ Example: **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in affine `affine.if` -and `affine.for` instructions) the compiler can follow use-def chains (e.g. +and `for` instructions) the compiler can follow use-def chains (e.g. through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to precisely analyze references at compile-time using polyhedral techniques. This is possible because of the @@ -1492,7 +1492,7 @@ store %100, %A[%1, 1023] : memref<4x?xf32, #layout, hbm> **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in polyhedral -`affine.if` and `affine.for` instructions) the compiler can follow use-def +`affine.if` and `for` instructions) the compiler can follow use-def chains (e.g. through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to precisely analyze references at compile-time using polyhedral techniques. This diff --git a/mlir/g3doc/Passes.md b/mlir/g3doc/Passes.md index bb15cec22a4..dc46b97f7b1 100644 --- a/mlir/g3doc/Passes.md +++ b/mlir/g3doc/Passes.md @@ -39,7 +39,7 @@ These restrictions may be lifted in the future. ### Output IR -Functions with `affine.for` and `affine.if` instructions eliminated. These +Functions with `for` and `affine.if` instructions eliminated. These functions may contain operations from the Standard dialect in addition to those already present before the pass. diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 8b22e93598c..949f405d5f6 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -150,8 +150,8 @@ func bar(%A : memref<8x?xf32, #lmap>) { // dynamically using dim instruction. 
%N = dim %A, 1 : memref<8x?xf32, #lmap> - affine.for %i = 0 to 8 { - affine.for %j = 0 to %N { + for %i = 0 to 8 { + for %j = 0 to %N { // A[i,j] += 1 %s1 = load %A [%i, %j] : memref<8x?xf32, #lmap> %s2 = add %s1, 1 @@ -534,7 +534,7 @@ nested in an outer function that using affine loops. func @search(memref %S, i32 %key) { %ni = dim %A, 0 : memref // This loop can be parallelized - affine.for %i = 0 to %ni { + for %i = 0 to %ni { call @search_body (%A, %S, %i) : (memref, memref, i32) } return @@ -568,7 +568,7 @@ func @search_body(%A: memref, %S: memref, %key: i32) { As per the [MLIR spec](LangRef.md), the restrictions on dimensions and symbol identifiers to be used with the affine.apply instruction only apply to accesses -inside `affine.for` and `affine.if` instructions. However, an analysis of +inside `for` and `affine.if` instructions. However, an analysis of accesses inside the called function (`@search_body`) is necessary to determine if the `%i` loop could be parallelized: such function access analysis is calling context sensitive. @@ -590,8 +590,8 @@ for (i=0; i i32 { - affine.for %k = 0 to %m { - affine.for %l = 0 to %n { + for %k = 0 to %m { + for %l = 0 to %n { ... } } @@ -649,13 +649,13 @@ in a dilated convolution. func @conv2d(memref<16x1024x1024x3xf32, #lm0, vmem> %input, memref<5x5x3x32xf32, #lm0, vmem> %kernel, memref<16x512x512x32xf32, #lm0, vmem> %output) { - affine.for %b = 0 to %batch { - affine.for %oh = 0 to %output_height { - affine.for %ow = 0 to %output_width { - affine.for %of = 0 to %output_feature { - affine.for %kh = 0 to %kernel_height { - affine.for %kw = 0 to %kernel_width { - affine.for %if = 0 to %input_feature { + for %b = 0 to %batch { + for %oh = 0 to %output_height { + for %ow = 0 to %output_width { + for %of = 0 to %output_feature { + for %kh = 0 to %kernel_height { + for %kw = 0 to %kernel_width { + for %if = 0 to %input_feature { // Calculate input indices. %1_0 = affine.apply #map1_0 (%0#1, %0#2, %0#4, %0#5) [%h_stride, %w_stride, %h_kernel_dilation, %w_kernel_dilation, @@ -899,10 +899,10 @@ func @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a, representation. 2(b) requires no change, but impacts how cost models look at index and layout maps. -### `affine.if` and `affine.for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} +### `affine.if` and `for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} We considered providing a representation for SSA values that are live out of -`affine.if/else` conditional bodies and loop carried in `affine.for` loops. We +`affine.if/else` conditional bodies and loop carried in `for` loops. We ultimately abandoned this approach due to its complexity. In the current design of MLIR, scalar variables cannot escape for loops or if instructions. In situations, where escaping is necessary, we use zero-dimensional tensors and @@ -919,7 +919,7 @@ Syntax: ``` {.ebnf} [ =] -affine.for % = ... step +for % = ... step [with ] { } ``` @@ -934,7 +934,7 @@ Example: // Return sum of elements in 1-dimensional mref A func int32 @sum(%A : memref, %N : i32) -> (i32) { %init = 0 - %result = affine.for %i = 0 to N with %tmp(%init) { + %result = for %i = 0 to N with %tmp(%init) { %value = load %A[%i] %sum = %value + %tmp yield %sum @@ -964,7 +964,7 @@ Example: // Compute sum of half of the array func int32 @sum_half(%A, %N) { %s0 = 0 - %s1 = affine.for %i = 1 ... N step 1 with %s2 (%s0) { + %s1 = for %i = 1 ... 
N step 1 with %s2 (%s0) { %s3 = affine.if (%i >= %N / 2) { %v0 = load %A[%i] %s4 = %s2 + %v0 diff --git a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md index a1830f0b4ab..6fe05a4d8c4 100644 --- a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md +++ b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md @@ -184,8 +184,8 @@ Our simple example above would be represented as: ```mlir mlfunc @simple_example(... %N) { - affine.for %i = 0 ... %N step 1 { - affine.for %j = 0 ... %N step 1 { + for %i = 0 ... %N step 1 { + for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affine.apply #57(%i, %j) @@ -203,8 +203,8 @@ The example with the reduced domain would be represented with an if instruction: ```mlir mlfunc @reduced_domain_example(... %N) { - affine.for %i = 0 ... %N step 1 { - affine.for %j = 0 ... %N step 1 { + for %i = 0 ... %N step 1 { + for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affinecall #57(%i, %j) @@ -233,8 +233,8 @@ that transformations call into): ```mlir mlfunc @skewed_domain_example(... %N) { - affine.for %t1 = 0 ... 2*N-2 step 1 { - affine.for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { + for %t1 = 0 ... 2*N-2 step 1 { + for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { (%i, %j) = (%t1-%t2, %t2) ... } @@ -373,7 +373,7 @@ mlfunc's (if we support them) will also have to have domains. ### Lack of redundancy in IR The traditional form has multiple encodings for the same sorts of behavior: you -end up having bits on `affine.for` loops to specify whether codegen should use +end up having bits on `for` loops to specify whether codegen should use "atomic/separate" policies, unroll loops, etc. Instructions can be split or can generate multiple copies of their instruction because of overlapping domains, etc. diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index b7a5c0c8326..caa7e16cda8 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -90,15 +90,15 @@ private: explicit AffineApplyOp(const Instruction *state) : Op(state) {} }; -/// The "affine.for" instruction represents an affine loop nest, defining an SSA -/// value for its induction variable. The induction variable is represented as a +/// The "for" instruction represents an affine loop nest, defining an SSA value +/// for its induction variable. The induction variable is represented as a /// BlockArgument to the entry block of the body. The body and induction -/// variable can be created automatically for new "affine.for" ops with -/// 'createBody'. This SSA value always has type index, which is the size of the -/// machine word. The stride, represented by step, is a positive constant -/// integer which defaults to "1" if not present. The lower and upper bounds -/// specify a half-open range: the range includes the lower bound but does not -/// include the upper bound. +/// variable can be created automatically for new "for" ops with 'createBody'. +/// This SSA value always has type index, which is the size of the machine word. +/// The stride, represented by step, is a positive constant integer which +/// defaults to "1" if not present. The lower and upper bounds specify a +/// half-open range: the range includes the lower bound but does not include the +/// upper bound. 
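To make the half-open convention concrete (an illustrative loop, not taken from the surrounding sources):

```mlir
for %i = 0 to 128 {
  "work"(%i) : (index) -> ()  // executes 128 times; %i takes values 0..127, never 128
}
```

Equivalently, the iteration domain contributed by such a loop is `%i >= 0` and `%i <= 127`.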
/// /// The lower and upper bounds of a for operation are represented as an /// application of an affine mapping to a list of SSA values passed to the map. @@ -110,7 +110,7 @@ private: /// /// Example: /// -/// affine.for %i = 1 to 10 { +/// for %i = 1 to 10 { /// ... /// } /// @@ -131,7 +131,7 @@ public: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - static StringRef getOperationName() { return "affine.for"; } + static StringRef getOperationName() { return "for"; } static StringRef getStepAttrName() { return "step"; } static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getUpperBoundAttrName() { return "upper_bound"; } @@ -253,15 +253,15 @@ ConstOpPointer getForInductionVarOwner(const Value *val); void extractForInductionVars(ArrayRef> forInsts, SmallVectorImpl *ivs); -/// Adds constraints (lower and upper bounds) for the specified 'affine.for' +/// Adds constraints (lower and upper bounds) for the specified 'for' /// instruction's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Returns /// false for the yet unimplemented/unsupported cases, and true if the /// information is successfully added. Asserts if the Value corresponding to -/// the 'affine.for' instruction isn't found in the constraint system. Any new -/// identifiers that are found in the bound operands of the 'affine.for' -/// instruction are added as trailing identifiers (either dimensional or -/// symbolic depending on whether the operand is a valid ML Function symbol). +/// the 'for' instruction isn't found in the constraint system. Any new +/// identifiers that are found in the bound operands of the 'for' instruction +/// are added as trailing identifiers (either dimensional or symbolic +/// depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. bool addAffineForOpDomain(ConstOpPointer forOp, FlatAffineConstraints *constraints); @@ -297,10 +297,10 @@ public: operand_range getOperands() const { return {operand_begin(), operand_end()}; } private: - // 'affine.for' instruction that contains this bound. + // 'for' instruction that contains this bound. ConstOpPointer inst; // Start and end positions of this affine bound operands in the list of - // the containing 'affine.for' instruction operands. + // the containing 'for' instruction operands. unsigned opStart, opEnd; // Affine map for this bound. AffineMap map; diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 4e01bc962ed..ee72ac26a0d 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -52,7 +52,7 @@ bool dominates(const Instruction &a, const Instruction &b); bool properlyDominates(const Instruction &a, const Instruction &b); /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'affine.for' instruction to the innermost one. +/// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'affine.if' inst's. void getLoopIVs(const Instruction &inst, SmallVectorImpl> *loops); @@ -105,8 +105,8 @@ insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, /// surrounding such op's. 
// For example, the memref region for a load operation at loop depth = 1: // -// affine.for %i = 0 to 32 { -// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { +// for %i = 0 to 32 { +// for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -139,8 +139,8 @@ struct MemRefRegion { /// For example, the memref region for this operation at loopDepth = 1 will /// be: /// - /// affine.for %i = 0 to 32 { - /// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { + /// for %i = 0 to 32 { + /// for %ii = %i to (d0) -> (d0 + 8) (%i) { /// load %A[%ii] /// } /// } diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index b3196e14097..4982481bf6c 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -76,9 +76,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// The following MLIR snippet: /// /// ```mlir -/// affine.for %i3 = 0 to %0 { -/// affine.for %i4 = 0 to %1 { -/// affine.for %i5 = 0 to %2 { +/// for %i3 = 0 to %0 { +/// for %i4 = 0 to %1 { +/// for %i5 = 0 to %2 { /// %a5 = load %arg0[%i4, %i5, %i3] : memref /// }}} /// ``` @@ -86,9 +86,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into: /// /// ```mlir -/// affine.for %i3 = 0 to %0 step 32 { -/// affine.for %i4 = 0 to %1 { -/// affine.for %i5 = 0 to %2 step 256 { +/// for %i3 = 0 to %0 step 32 { +/// for %i4 = 0 to %1 { +/// for %i5 = 0 to %2 step 256 { /// %4 = vector_transfer_read %arg0, %i4, %i5, %i3 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index) -> vector<32x256xf32> @@ -103,7 +103,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// /// ```mlir /// %cst0 = constant 0 : index -/// affine.for %i0 = 0 to %0 { +/// for %i0 = 0 to %0 { /// %a0 = load %arg0[%cst0, %cst0] : memref /// } /// ``` @@ -111,7 +111,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0) -> (0)} into: /// /// ```mlir -/// affine.for %i0 = 0 to %0 step 128 { +/// for %i0 = 0 to %0 step 128 { /// %3 = vector_transfer_read %arg0, %c0_0, %c0_0 /// {permutation_map: (d0, d1) -> (0)} : /// (memref, index, index) -> vector<128xf32> diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index d543b520565..f3d9b9fe9fd 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -83,10 +83,9 @@ AffineMap getUnrolledLoopUpperBound(ConstOpPointer forOp, unsigned unrollFactor, FuncBuilder *builder); -/// Skew the instructions in the body of a 'affine.for' instruction with the -/// specified instruction-wise shifts. The shifts are with respect to the -/// original execution order, and are multiplied by the loop 'step' before being -/// applied. +/// Skew the instructions in the body of a 'for' instruction with the specified +/// instruction-wise shifts. The shifts are with respect to the original +/// execution order, and are multiplied by the loop 'step' before being applied. 
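As a rough sketch of the effect (hypothetical ops, shifts of `{0, 1}` on a unit-step loop; the actual prologue/epilogue may be materialized as single-iteration loops rather than promoted instructions):

```mlir
// Original: for %i = 0 to 8 { "produce"(%i); "consume"(%i) }, with "consume" shifted by 1.
%c0 = constant 0 : index
%c7 = constant 7 : index
"produce"(%c0) : (index) -> ()          // prologue: produce for iteration 0
for %i = 1 to 8 {
  "produce"(%i) : (index) -> ()
  %im1 = affine.apply (d0) -> (d0 - 1) (%i)
  "consume"(%im1) : (index) -> ()       // consume lags one iteration behind produce
}
"consume"(%c7) : (index) -> ()          // epilogue: consume for iteration 7
```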
UtilResult instBodySkew(OpPointer forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index eb7f725576a..3b828db6ae9 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -94,14 +94,14 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// /// Before /// -/// affine.for %i = 0 to #map(%N) +/// for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %v = "compute"(%idx, ...) /// /// After /// -/// affine.for %i = 0 to #map(%N) +/// for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 858b8bd791d..9da155f09d8 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -35,7 +35,7 @@ using llvm::dbgs; //===----------------------------------------------------------------------===// AffineOpsDialect::AffineOpsDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"affine", context) { + : Dialect(/*namePrefix=*/"", context) { addOperations(); } @@ -716,7 +716,7 @@ static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { } void AffineForOp::print(OpAsmPrinter *p) const { - *p << "affine.for "; + *p << "for "; p->printOperand(getBody()->getArgument(0)); *p << " = "; printBound(getLowerBound(), "max", p); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 3a086ba512d..9d2ea691bdd 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -756,8 +756,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // For example, given the following MLIR code with with "source" and // "destination" accesses to the same memref labled, and symbols %M, %N, %K: // -// affine.for %i0 = 0 to 100 { -// affine.for %i1 = 0 to 50 { +// for %i0 = 0 to 100 { +// for %i1 = 0 to 50 { // %a0 = affine.apply // (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N] // // Source memref access. @@ -765,8 +765,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // } // } // -// affine.for %i2 = 0 to 100 { -// affine.for %i3 = 0 to 50 { +// for %i2 = 0 to 100 { +// for %i3 = 0 to 50 { // %a1 = affine.apply // (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M] // // Destination memref access. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 823fbbe9fcd..0499e866fe8 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -36,12 +36,12 @@ using namespace mlir; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'affine.for' instruction to the innermost one. +/// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); OpPointer currAffineForOp; - // Traverse up the hierarchy collecing all 'affine.for' instruction while + // Traverse up the hierarchy collecing all 'for' instruction while // skipping over 'affine.if' instructions. 
while (currInst && ((currAffineForOp = currInst->dyn_cast()) || currInst->isa())) { @@ -111,8 +111,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // For example, the memref region for this load operation at loopDepth = 1 will // be as below: // -// affine.for %i = 0 to 32 { -// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { +// for %i = 0 to 32 { +// for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -614,7 +614,7 @@ Optional mlir::getMemoryFootprintBytes(const Block &block, int memorySpace) { std::vector> regions; - // Walk this 'affine.for' instruction to gather all memory regions. + // Walk this 'for' instruction to gather all memory regions. bool error = false; const_cast(&block)->walk([&](Instruction *opInst) { if (!opInst->isa() && !opInst->isa()) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 90bc0b76efc..96582032b2b 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -189,7 +189,7 @@ unsigned Block::getNumSuccessors() const { return terminator->getNumSuccessors(); } assert(getParent() && "top-level block with no terminator"); - // Blocks inside 'affine.for'/'affine.if' instructions don't have successors. + // Blocks inside 'for'/'affine.if' instructions don't have successors. return 0; } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index dcb4828d0bf..bda98f46b61 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -338,7 +338,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, auto fastMemRefType = top.getMemRefType( fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace); - // Create the fast memory space buffer just before the 'affine.for' + // Create the fast memory space buffer just before the 'for' // instruction. fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); // Record it. @@ -457,7 +457,7 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { // approach is conservative in some cases at the moment, we do a check later // and report an error with location info. // TODO(bondhugula): An 'affine.if' instruction is being treated similar to an - // operation instruction. 'affine.if''s could have 'affine.for's in them; + // operation instruction. 'affine.if''s could have 'for's in them; // treat them separately. // Get to the first load, store, or for op. @@ -471,7 +471,7 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { if (auto forOp = it->dyn_cast()) { // We'll assume for now that loops with steps are tiled loops, and so DMAs // are not performed for that depth, but only further inside. - // If the memory footprint of the 'affine.for' loop is higher than fast + // If the memory footprint of the 'for' loop is higher than fast // memory capacity (when provided), we recurse to DMA at an inner level // until we find a depth at which footprint fits in the capacity. If the // footprint can't be calcuated, we assume for now it fits. @@ -490,11 +490,11 @@ bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) { consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); // Recurse onto the body of this loop. runOnBlock(forOp->getBody(), consumedCapacityBytes); - // The next region starts right after the 'affine.for' instruction. + // The next region starts right after the 'for' instruction. 
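To put a number on the capacity check above (the figures are illustrative only, not taken from the pass): with a 32 KiB fast-memory budget, a loop whose region over a single `f32` memref spans 128x128 elements needs 128 * 128 * 4 = 64 KiB and would force recursion into the loop body, whereas a 64x64 region (16 KiB) fits and gets its DMAs placed at that depth.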
curBegin = std::next(it); } else { // We have enough capacity, i.e., DMAs will be computed for the portion - // of the block until 'it', and for the 'affine.for' loop. For the + // of the block until 'it', and for the 'for' loop. For the // latter, they are placed just before this loop (for incoming DMAs) and // right after (for outgoing ones). consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 9e96b0800b3..8d5f51059bf 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -510,8 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&inst); - // Return false if a non 'affine.for' region was found (not currently - // supported). + // Return false if a non 'for' region was found (not currently supported). if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index f00c2e767e6..368a1dac1df 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -231,8 +231,7 @@ UtilResult mlir::tileCodeGen(MutableArrayRef> band, static void getTileableBands(Function *f, std::vector, 6>> *bands) { - // Get maximal perfect nest of 'affine.for' insts starting from root - // (inclusive). + // Get maximal perfect nest of 'for' insts starting from root (inclusive). auto getMaximalPerfectLoopNest = [&](OpPointer root) { SmallVector, 6> band; OpPointer currInst = root; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 025a86891df..3a7cfb85e08 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -164,7 +164,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return success(); } -/// Unrolls a 'affine.for' inst. Returns true if the loop was unrolled, false +/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. bool LoopUnroll::runOnAffineForOp(OpPointer forOp) { // Use the function callback if one was provided. diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 2f0249824dd..b2aed7d9d7f 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -105,7 +105,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) { return success(); } -/// Unroll and jam a 'affine.for' inst. Default unroll jam factor is +/// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. bool LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { // Unroll and jam by the factor that was passed if any. diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 5ce8a6258f4..ef6ff420912 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -283,8 +283,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, return value; } -// Convert a "affine.for" loop to a flow of blocks. Return `false` on -// success. +// Convert a "for" loop to a flow of blocks. Return `false` on success. // // Create an SESE region for the loop (including its body) and append it to the // end of the current region. 
The loop region consists of the initialization @@ -331,9 +330,8 @@ bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { auto loc = forOp->getLoc(); auto *forInst = forOp->getInstruction(); - // Start by splitting the block containing the 'affine.for' into two parts. - // The part before will get the init code, the part after will be the end - // point. + // Start by splitting the block containing the 'for' into two parts. The part + // before will get the init code, the part after will be the end point. auto *initBlock = forInst->getBlock(); auto *endBlock = initBlock->splitBlock(forInst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index e63d3c8111c..63fb45db9c5 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -126,9 +126,9 @@ private: /// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into /// // vector<32x256xf32> and pad with %f0 to handle the boundary case: /// %f0 = constant 0.0f : f32 -/// affine.for %i0 = 0 to %0 { -/// affine.for %i1 = 0 to %1 step 256 { -/// affine.for %i2 = 0 to %2 step 32 { +/// for %i0 = 0 to %0 { +/// for %i1 = 0 to %1 step 256 { +/// for %i2 = 0 to %2 step 32 { /// %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index, f32) -> vector<32x256xf32> @@ -139,8 +139,8 @@ private: /// MLIR resembling: /// /// ```mlir -/// affine.for %d1 = 0 to 256 { -/// affine.for %d2 = 0 to 32 { +/// for %d1 = 0 to 256 { +/// for %d2 = 0 to 32 { /// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 /// %tmp[%d2, %d1] = %s /// } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 4434ab5322e..be5a03bc416 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -101,10 +101,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : -/// vector<4x4x4xf32> affine.for %i0 = 0 to %M step 4 { -/// affine.for %i1 = 0 to %N step 4 { -/// affine.for %i2 = 0 to %O { -/// affine.for %i3 = 0 to %P step 4 { +/// vector<4x4x4xf32> for %i0 = 0 to %M step 4 { +/// for %i1 = 0 to %N step 4 { +/// for %i2 = 0 to %O { +/// for %i3 = 0 to %P step 4 { /// vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 /// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : /// vector<4x4x4xf32>, memref, @@ -120,10 +120,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> -/// affine.for %i0 = 0 to %arg0 step 4 { -/// affine.for %i1 = 0 to %arg1 step 4 { -/// affine.for %i2 = 0 to %arg2 { -/// affine.for %i3 = 0 to %arg3 step 4 { +/// for %i0 = 0 to %arg0 step 4 { +/// for %i1 = 0 to %arg1 step 4 { +/// for %i2 = 0 to %arg2 { +/// for %i3 = 0 to %arg3 step 4 { /// %1 = affine.apply (d0, d1, d2, d3) -> (d0, d1, d2, d3) /// (%i0, %i1, %i2, %i3) /// vector_transfer_write f1, %0, %1#0, %1#1, %1#2, %1#3 @@ -293,10 +293,10 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// super-vectorization has been applied: /// /// ```mlir -/// affine.for %i0 = 0 to %M { -/// affine.for %i1 = 0 to %N step 3 { -/// affine.for %i2 = 0 to %O { -/// affine.for %i3 = 0 to %P step 32 { +/// for %i0 = 0 to %M { +/// for %i1 = 0 to %N step 3 { +/// for %i2 = 0 to %O { +/// for %i3 = 0 to %P step 32 { 
/// %r = vector_transfer_read(%A, map(%i..)#0, map(%i..)#1, map(%i..)#2) /// -> vector<3x32xf32> /// ... diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 91a17764358..ad9801fea89 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -19,7 +19,7 @@ // potentially getting rid of intermediate memref's entirely. // TODO(mlir-team): In the future, similar techniques could be used to eliminate // dead memref store's and perform more complex forwarding when support for -// SSA scalars live out of 'affine.for'/'affine.if' statements is available. +// SSA scalars live out of 'for'/'affine.if' statements is available. //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" @@ -55,7 +55,7 @@ namespace { // // (* A dependence being satisfied at a block: a dependence that is satisfied by // virtue of the destination instruction appearing textually / lexically after -// the source instruction within the body of a 'affine.for' instruction; thus, a +// the source instruction within the body of a 'for' instruction; thus, a // dependence is always either satisfied by a loop or by a block). // // The above conditions are simple to check, sufficient, and powerful for most @@ -145,8 +145,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { // Check if this store is a candidate for forwarding; we only forward if // the dependence from the store is carried by the *body* of innermost // common surrounding loop. As an example this filters out cases like: - // affine.for %i0 - // affine.for %i1 + // for %i0 + // for %i1 // %idx = affine.apply (d0) -> (d0 + 1) (%i0) // store %A[%idx] // load %A[%i0] diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 84c8cd830dc..cfa045f2279 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -71,11 +71,11 @@ static unsigned getTagMemRefPos(const Instruction &dmaInst) { return 0; } -/// Doubles the buffer of the supplied memref on the specified 'affine.for' -/// instruction by adding a leading dimension of size two to the memref. -/// Replaces all uses of the old memref by the new one while indexing the newly -/// added dimension by the loop IV of the specified 'affine.for' instruction -/// modulo 2. Returns false if such a replacement cannot be performed. +/// Doubles the buffer of the supplied memref on the specified 'for' instruction +/// by adding a leading dimension of size two to the memref. Replaces all uses +/// of the old memref by the new one while indexing the newly added dimension by +/// the loop IV of the specified 'for' instruction modulo 2. Returns false if +/// such a replacement cannot be performed. static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); @@ -108,7 +108,7 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { dynamicDimCount++)); } - // Create and place the alloc right before the 'affine.for' instruction. + // Create and place the alloc right before the 'for' instruction. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. Value *newMemRef = @@ -137,9 +137,9 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { /// Returns success if the IR is in a valid state. 
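Sketching the double-buffering rewrite described above on a made-up staging buffer (the memref, sizes, and loop are hypothetical; DMA details elided):

```mlir
// Before: one staging buffer in fast memory (space 1), reused every iteration.
%buf = alloc() : memref<256xf32, 1>
for %i = 0 to %N {
  // ... DMA into %buf, wait on it, then compute out of %buf ...
}

// After doubleBuffer: a leading dimension of size two, indexed by the IV modulo 2,
// so iteration %i can fill one half while the previous iteration's half is consumed.
%buf2 = alloc() : memref<2x256xf32, 1>
for %i = 0 to %N {
  %p = affine.apply (d0) -> (d0 mod 2) (%i)
  // ... all former uses of %buf now address %buf2[%p, ...] ...
}
```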
PassResult PipelineDataTransfer::runOnFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'affine.for' instructions nested within would otherwise - // become invalid (erased) when the outer loop is pipelined (the pipelined one - // gets deleted and replaced by a prologue, a new steady-state loop and an + // necessary since 'for' instructions nested within would otherwise become + // invalid (erased) when the outer loop is pipelined (the pipelined one gets + // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); f->walkPostOrder( diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 110949f43d5..a1903ace026 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -138,8 +138,8 @@ void mlir::promoteSingleIterationLoops(Function *f) { [](OpPointer forOp) { promoteIfSingleIteration(forOp); }); } -/// Generates a 'affine.for' inst with the specified lower and upper bounds -/// while generating the right IV remappings for the shifted instructions. The +/// Generates a 'for' inst with the specified lower and upper bounds while +/// generating the right IV remappings for the shifted instructions. The /// instruction blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of /// the pair specifies the shift applied to that group of instructions; note @@ -194,10 +194,10 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the instructions in the body of a 'affine.for' instruction with the -/// specified instruction-wise shifts. The shifts are with respect to the -/// original execution order, and are multiplied by the loop 'step' before being -/// applied. A shift of zero for each instruction will lead to no change. +/// Skew the instructions in the body of a 'for' instruction with the specified +/// instruction-wise shifts. The shifts are with respect to the original +/// execution order, and are multiplied by the loop 'step' before being applied. +/// A shift of zero for each instruction will lead to no change. // The skewing of instructions with respect to one another can be used for // example to allow overlap of asynchronous operations (such as DMA // communication) with computation, or just relative shifting of instructions @@ -246,7 +246,7 @@ UtilResult mlir::instBodySkew(OpPointer forOp, // An array of instruction groups sorted by shift amount; each group has all // instructions with the same shift in the order in which they appear in the - // body of the 'affine.for' inst. + // body of the 'for' inst. std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; for (auto &inst : *forOp->getBody()) { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 90d28bf34df..41689be52fc 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -194,14 +194,14 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// /// Before /// -/// affine.for %i = 0 to #map(%N) +/// for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// "compute"(%idx) /// /// After /// -/// affine.for %i = 0 to #map(%N) +/// for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) 
/// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 1f4c7b9fcc8..5a8d5d24661 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -113,7 +113,7 @@ using namespace mlir; /// /// At a high level, a vectorized load in a loop will resemble: /// ```mlir -/// affine.for %i = ? to ? step ? { +/// for %i = ? to ? step ? { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -309,7 +309,7 @@ using namespace mlir; /// ```mlir /// mlfunc @fill(%A : memref<128xf32>) -> () { /// %f1 = constant 1.0 : f32 -/// affine.for %i0 = 0 to 32 { +/// for %i0 = 0 to 32 { /// store %f1, %A[%i0] : memref<128xf32, 0> /// } /// return @@ -322,7 +322,7 @@ using namespace mlir; /// is still subject to exploratory tradeoffs. In particular, say we want to /// vectorize by a factor 128, we want to transform the following input: /// ```mlir -/// affine.for %i = %M to %N { +/// for %i = %M to %N { /// %a = load A[%i] : memref /// } /// ``` @@ -331,8 +331,8 @@ using namespace mlir; /// memory promotion etc) say after stripmining (and potentially unrolling in /// the case of LLVM's SLP vectorizer): /// ```mlir -/// affine.for %i = floor(%M, 128) to ceil(%N, 128) { -/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { +/// for %i = floor(%M, 128) to ceil(%N, 128) { +/// for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { /// %a = load A[%ii] : memref /// } /// } @@ -341,7 +341,7 @@ using namespace mlir; /// Instead, we seek to vectorize early and freeze vector types before /// scheduling, so we want to generate a pattern that resembles: /// ```mlir -/// affine.for %i = ? to ? step ? { +/// for %i = ? to ? step ? 
{ /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -362,7 +362,7 @@ using namespace mlir; /// For the simple strawman example above, vectorizing for a 1-D vector /// abstraction of size 128 returns code similar to: /// ```mlir -/// affine.for %i = %M to %N step 128 { +/// for %i = %M to %N step 128 { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -391,20 +391,20 @@ using namespace mlir; /// %C = alloc (%M, %N) : memref /// %f1 = constant 1.0 : f32 /// %f2 = constant 2.0 : f32 -/// affine.for %i0 = 0 to %M { -/// affine.for %i1 = 0 to %N { +/// for %i0 = 0 to %M { +/// for %i1 = 0 to %N { /// // non-scoped %f1 /// store %f1, %A[%i0, %i1] : memref /// } /// } -/// affine.for %i2 = 0 to %M { -/// affine.for %i3 = 0 to %N { +/// for %i2 = 0 to %M { +/// for %i3 = 0 to %N { /// // non-scoped %f2 /// store %f2, %B[%i2, %i3] : memref /// } /// } -/// affine.for %i4 = 0 to %M { -/// affine.for %i5 = 0 to %N { +/// for %i4 = 0 to %M { +/// for %i5 = 0 to %N { /// %a5 = load %A[%i4, %i5] : memref /// %b5 = load %B[%i4, %i5] : memref /// %s5 = addf %a5, %b5 : f32 @@ -438,24 +438,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// affine.for %i0 = 0 to %arg0 { -/// affine.for %i1 = 0 to %arg1 step 256 { +/// for %i0 = 0 to %arg0 { +/// for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// affine.for %i2 = 0 to %arg0 { -/// affine.for %i3 = 0 to %arg1 step 256 { +/// for %i2 = 0 to %arg0 { +/// for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// affine.for %i4 = 0 to %arg0 { -/// affine.for %i5 = 0 to %arg1 step 256 { +/// for %i4 = 0 to %arg0 { +/// for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : @@ -494,24 +494,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// affine.for %i0 = 0 to %arg0 step 32 { -/// affine.for %i1 = 0 to %arg1 step 256 { +/// for %i0 = 0 to %arg0 step 32 { +/// for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// affine.for %i2 = 0 to %arg0 step 32 { -/// affine.for %i3 = 0 to %arg1 step 256 { +/// for %i2 = 0 to %arg0 step 32 { +/// for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// affine.for %i4 = 0 to %arg0 step 32 { -/// affine.for %i5 = 0 to %arg1 step 256 { +/// for %i4 = 0 to %arg0 step 32 { +/// for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<32x256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir index 163cfbe0985..ad6f39f3496 100644 --- a/mlir/test/AffineOps/canonicalize.mlir +++ b/mlir/test/AffineOps/canonicalize.mlir 
@@ -32,7 +32,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { %0 = alloc() : memref<4x4xf32> - affine.for %i0 = 0 to 15 { + for %i0 = 0 to 15 { // Test load[%x, %x] %x0 = affine.apply (d0) -> (d0 - 1) (%i0) @@ -78,7 +78,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { func @compose_affine_maps_1dto2d_with_symbols() { %0 = alloc() : memref<4x4xf32> - affine.for %i0 = 0 to 15 { + for %i0 = 0 to 15 { // Test load[%x0, %x0] with symbol %c4 %c4 = constant 4 : index %x0 = affine.apply (d0)[s0] -> (d0 - s0) (%i0)[%c4] @@ -119,13 +119,13 @@ func @compose_affine_maps_2d_tile() { %c4 = constant 4 : index %c8 = constant 8 : index - affine.for %i0 = 0 to 3 { + for %i0 = 0 to 3 { %x0 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i0)[%c4] - affine.for %i1 = 0 to 3 { + for %i1 = 0 to 3 { %x1 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i1)[%c8] - affine.for %i2 = 0 to 3 { + for %i2 = 0 to 3 { %x2 = affine.apply (d0)[s0] -> (d0 mod s0) (%i2)[%c4] - affine.for %i3 = 0 to 3 { + for %i3 = 0 to 3 { %x3 = affine.apply (d0)[s0] -> (d0 mod s0) (%i3)[%c8] %x40 = affine.apply (d0, d1, d2, d3)[s0, s1] -> @@ -151,9 +151,9 @@ func @compose_affine_maps_dependent_loads() { %0 = alloc() : memref<16x32xf32> %1 = alloc() : memref<16x32xf32> - affine.for %i0 = 0 to 3 { - affine.for %i1 = 0 to 3 { - affine.for %i2 = 0 to 3 { + for %i0 = 0 to 3 { + for %i1 = 0 to 3 { + for %i2 = 0 to 3 { %c3 = constant 3 : index %c7 = constant 7 : index @@ -197,7 +197,7 @@ func @compose_affine_maps_dependent_loads() { func @compose_affine_maps_diamond_dependency() { %0 = alloc() : memref<4x4xf32> - affine.for %i0 = 0 to 15 { + for %i0 = 0 to 15 { %a = affine.apply (d0) -> (d0 - 1) (%i0) %b = affine.apply (d0) -> (d0 + 7) (%a) %c = affine.apply (d0) -> (d0 * 4) (%a) @@ -217,8 +217,8 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { %c9 = constant 9 : index %1 = alloc() : memref<100x100xf32, 1> %2 = alloc() : memref<1xi32> - affine.for %i0 = 0 to 100 { - affine.for %i1 = 0 to 100 { + for %i0 = 0 to 100 { + for %i1 = 0 to 100 { %3 = affine.apply (d0, d1)[s0, s1] -> (d1 + s0 + s1) (%i0, %i1)[%arg1, %c9] %4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1)) @@ -238,7 +238,7 @@ func @trivial_maps() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %1 = affine.apply ()[s0] -> (s0)()[%c0] store %cst, %0[%1] : memref<10xf32> %2 = load %0[%c0] : memref<10xf32> @@ -277,20 +277,20 @@ func @constant_fold_bounds(%N : index) { %c3 = affine.apply (d0, d1) -> (d0 + d1) (%c1, %c2) %l = "foo"() : () -> index - // CHECK: affine.for %i0 = 5 to 7 { - affine.for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { + // CHECK: for %i0 = 5 to 7 { + for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { "foo"(%i, %c3) : (index, index) -> () } // Bound takes a non-constant argument but can still be folded. - // CHECK: affine.for %i1 = 1 to 7 { - affine.for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { + // CHECK: for %i1 = 1 to 7 { + for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { "foo"(%j, %c3) : (index, index) -> () } // None of the bounds can be folded. 
- // CHECK: affine.for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { - affine.for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { + // CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { + for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { "foo"(%k, %c3) : (index, index) -> () } return diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 99e0f682216..bd7e062063c 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,35 +204,35 @@ func @illegaltype(i0) // expected-error {{invalid integer width}} // ----- func @malformed_for_percent() { - affine.for i = 1 to 10 { // expected-error {{expected SSA operand}} + for i = 1 to 10 { // expected-error {{expected SSA operand}} // ----- func @malformed_for_equal() { - affine.for %i 1 to 10 { // expected-error {{expected '='}} + for %i 1 to 10 { // expected-error {{expected '='}} // ----- func @malformed_for_to() { - affine.for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} + for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} } } // ----- func @incomplete_for() { - affine.for %i = 1 to 10 step 2 + for %i = 1 to 10 step 2 } // expected-error {{expected '{' to begin block list}} // ----- func @nonconstant_step(%1 : i32) { - affine.for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} + for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} // ----- func @for_negative_stride() { - affine.for %i = 1 to 10 step -1 + for %i = 1 to 10 step -1 } // expected-error@-1 {{expected step to be representable as a positive signed integer}} // ----- @@ -244,7 +244,7 @@ func @non_instruction() { // ----- func @invalid_if_conditional2() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -252,7 +252,7 @@ func @invalid_if_conditional2() { // ----- func @invalid_if_conditional3() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i)[N] : (i == 1) // expected-error {{expected '0' after '=='}} } } @@ -260,7 +260,7 @@ func @invalid_if_conditional3() { // ----- func @invalid_if_conditional4() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i)[N] : (i >= 2) // expected-error {{expected '0' after '>='}} } } @@ -268,7 +268,7 @@ func @invalid_if_conditional4() { // ----- func @invalid_if_conditional5() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i)[N] : (i <= 0 ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -276,7 +276,7 @@ func @invalid_if_conditional5() { // ----- func @invalid_if_conditional6() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i) : (i) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -284,7 +284,7 @@ func @invalid_if_conditional6() { // ----- // TODO (support affine.if (1)? 
func @invalid_if_conditional7() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if (i) : (1) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -438,8 +438,8 @@ func @undef() { // ----- func @duplicate_induction_var() { - affine.for %i = 1 to 10 { // expected-error {{previously defined here}} - affine.for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} + for %i = 1 to 10 { // expected-error {{previously defined here}} + for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} } } return @@ -448,7 +448,7 @@ func @duplicate_induction_var() { // ----- func @dominance_failure() { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { } "xxx"(%i) : (index)->() // expected-error {{operand #0 does not dominate this use}} return @@ -475,7 +475,7 @@ func @return_type_mismatch() -> i32 { // ----- func @return_inside_loop() -> i8 { - affine.for %i = 1 to 100 { + for %i = 1 to 100 { %a = "foo"() : ()->i8 return %a : i8 // expected-error@-1 {{'return' op may only be at the top level of a function}} @@ -521,7 +521,7 @@ func @referer() { #map1 = (i)[j] -> (i+j) func @bound_symbol_mismatch(%N : index) { - affine.for %i = #map1(%N) to 100 { + for %i = #map1(%N) to 100 { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} } return @@ -532,7 +532,7 @@ func @bound_symbol_mismatch(%N : index) { #map1 = (i)[j] -> (i+j) func @bound_dim_mismatch(%N : index) { - affine.for %i = #map1(%N, %N)[%N] to 100 { + for %i = #map1(%N, %N)[%N] to 100 { // expected-error@-1 {{dim operand count and integer set dim count must match}} } return @@ -541,7 +541,7 @@ func @bound_dim_mismatch(%N : index) { // ----- func @large_bound() { - affine.for %i = 1 to 9223372036854775810 { + for %i = 1 to 9223372036854775810 { // expected-error@-1 {{integer constant out of range for attribute}} } return @@ -550,7 +550,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - affine.for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} + for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} } return } @@ -558,7 +558,7 @@ func @max_in_upper_bound(%N : index) { // ----- func @step_typo() { - affine.for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} + for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} } return } @@ -566,7 +566,7 @@ func @step_typo() { // ----- func @invalid_bound_map(%N : i32) { - affine.for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} + for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} } return } @@ -579,7 +579,7 @@ func @invalid_bound_map(%N : i32) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands1(%N : index) { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if #set0(%i) { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} @@ -587,7 +587,7 @@ func @invalid_if_operands1(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands2(%N : index) { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if #set0()[%N] { // expected-error@-1 {{dim operand count and integer set dim count must match}} @@ -595,7 +595,7 @@ func @invalid_if_operands2(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands3(%N : index) { - affine.for %i = 1 to 10 { + for %i = 1 to 10 { affine.if #set0(%i)[%i] { // expected-error@-1 {{operand cannot be used as a symbol}} } 
@@ -736,11 +736,11 @@ func @f(f32) { // ----- func @f(%m : memref) { - affine.for %i0 = 0 to 42 { + for %i0 = 0 to 42 { // expected-error@+1 {{operand #2 does not dominate this use}} %x = load %m[%i0, %i1] : memref } - affine.for %i1 = 0 to 42 { + for %i1 = 0 to 42 { } return } @@ -790,7 +790,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t // Check ill-formed opaque tensor. func @complex_loops() { - affine.for %i1 = 1 to 100 { + for %i1 = 1 to 100 { // expected-error @+1 {{expected '"' in string literal}} "opaqueIntTensor"(){bar: opaque, "0x686]>} : () -> () @@ -824,7 +824,7 @@ func @invalid_affine_structure() { func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{lower loop bound affine map with multiple results requires 'max' prefix}} - affine.for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { + for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { } return } @@ -833,7 +833,7 @@ func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{upper loop bound affine map with multiple results requires 'min' prefix}} - affine.for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { + for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { } return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index ac4925e3e52..3b27301cfae 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -13,7 +13,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) // CHECK: } loc(fused["foo", "mysource.cc":10:8]) - affine.for %i0 = 0 to 8 { + for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index a194c52344a..8fa3116a139 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -208,8 +208,8 @@ func @identity_functor(%a : () -> ()) -> (() -> ()) { func @func_ops_in_loop() { // CHECK: %0 = "foo"() : () -> i64 %a = "foo"() : ()->i64 - // CHECK: affine.for %i0 = 1 to 10 { - affine.for %i = 1 to 10 { + // CHECK: for %i0 = 1 to 10 { + for %i = 1 to 10 { // CHECK: %1 = "doo"() : () -> f32 %b = "doo"() : ()->f32 // CHECK: "bar"(%0, %1) : (i64, f32) -> () @@ -224,10 +224,10 @@ func @func_ops_in_loop() { // CHECK-LABEL: func @loops() { func @loops() { - // CHECK: affine.for %i0 = 1 to 100 step 2 { - affine.for %i = 1 to 100 step 2 { - // CHECK: affine.for %i1 = 1 to 200 { - affine.for %j = 1 to 200 { + // CHECK: for %i0 = 1 to 100 step 2 { + for %i = 1 to 100 step 2 { + // CHECK: for %i1 = 1 to 200 { + for %j = 1 to 200 { } // CHECK: } } // CHECK: } return // CHECK: return @@ -235,14 +235,14 @@ func @loops() { // CHECK-LABEL: func @complex_loops() { func @complex_loops() { - affine.for %i1 = 1 to 100 { // CHECK: affine.for %i0 = 1 to 100 { - affine.for %j1 = 1 to 100 { // CHECK: affine.for %i1 = 1 to 100 { + for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 { + for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 { // CHECK: "foo"(%i0, %i1) : (index, index) -> () "foo"(%i1, %j1) : (index,index) -> () } // CHECK: } "boo"() : () -> () // CHECK: "boo"() : () -> () - affine.for %j2 = 1 to 10 { // CHECK: affine.for %i2 = 1 to 10 { - affine.for %k2 = 1 to 10 { // CHECK: affine.for %i3 = 1 to 10 { + for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 { + for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 { "goo"() : () -> 
() // CHECK: "goo"() : () -> () } // CHECK: } } // CHECK: } @@ -253,8 +253,8 @@ func @complex_loops() { // CHECK: func @triang_loop(%arg0: index, %arg1: memref) { func @triang_loop(%arg0: index, %arg1: memref) { %c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32 - affine.for %i0 = 1 to %arg0 { // CHECK: affine.for %i0 = 1 to %arg0 { - affine.for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { + for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 { + for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { store %c, %arg1[%i0, %i1] : memref // CHECK: store %c0_i32, %arg1[%i0, %i1] } // CHECK: } } // CHECK: } @@ -263,8 +263,8 @@ func @triang_loop(%arg0: index, %arg1: memref) { // CHECK: func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { - // CHECK: affine.for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { - affine.for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { + // CHECK: for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { + for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { // CHECK: "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () } // CHECK: } @@ -275,24 +275,24 @@ func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @loop_bounds(%N : index) { // CHECK: %0 = "foo"(%arg0) : (index) -> index %s = "foo"(%N) : (index) -> index - // CHECK: affine.for %i0 = %0 to %arg0 - affine.for %i = %s to %N { - // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to 0 - affine.for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { + // CHECK: for %i0 = %0 to %arg0 + for %i = %s to %N { + // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0 + for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { // CHECK: %1 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w1 = affine.apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s] // CHECK: %2 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w2 = affine.apply(d0, d1)[s0] -> (s0+1) (%i, %j) [%s] - // CHECK: affine.for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { - affine.for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { + // CHECK: for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { + for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { // CHECK: "foo"(%i0, %i1, %i2) : (index, index, index) -> () "foo"(%i, %j, %k) : (index, index, index)->() // CHECK: %c30 = constant 30 : index %c = constant 30 : index // CHECK: %3 = affine.apply #map{{.*}}(%arg0, %c30) %u = affine.apply (d0, d1)->(d0+d1) (%N, %c) - // CHECK: affine.for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { - affine.for %l = max #bound_map2(%i)[%u] to min #bound_map2(%k)[%c] { + // CHECK: for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { + for %l = max #bound_map2(%i)[%u] to min #bound_map2(%k)[%c] { // CHECK: "bar"(%i3) : (index) -> () "bar"(%l) : (index) -> () } // CHECK: } @@ -305,7 +305,7 @@ func @loop_bounds(%N : index) { // CHECK-LABEL: func @ifinst(%arg0: index) { func @ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { + for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { affine.if #set0(%i)[%N, %c] { // CHECK affine.if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -328,7 +328,7 @@ func @ifinst(%N: index) 
{ // CHECK-LABEL: func @simple_ifinst(%arg0: index) { func @simple_ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { + for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { affine.if #set0(%i)[%N, %c] { // CHECK affine.if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -544,18 +544,18 @@ func @funcattrwithblock() -> () #map_non_simple2 = ()[s0, s1] -> (s0 + s1) #map_non_simple3 = ()[s0] -> (s0 + 3) func @funcsimplemap(%arg0: index, %arg1: index) -> () { - affine.for %i0 = 0 to #map_simple0()[] { - // CHECK: affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to #map_simple1()[%arg1] { - // CHECK: affine.for %i1 = 0 to %arg1 { - affine.for %i2 = 0 to #map_non_simple0(%i0)[] { - // CHECK: affine.for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { - affine.for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { - // CHECK: affine.for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { - affine.for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { - // CHECK: affine.for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { - affine.for %i5 = 0 to #map_non_simple3()[%arg0] { - // CHECK: affine.for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { + for %i0 = 0 to #map_simple0()[] { + // CHECK: for %i0 = 0 to 10 { + for %i1 = 0 to #map_simple1()[%arg1] { + // CHECK: for %i1 = 0 to %arg1 { + for %i2 = 0 to #map_non_simple0(%i0)[] { + // CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { + for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { + // CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { + for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { + // CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { + for %i5 = 0 to #map_non_simple3()[%arg0] { + // CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { %c42_i32 = constant 42 : i32 } } @@ -749,9 +749,9 @@ func @sparsevectorattr() -> () { // CHECK-LABEL: func @loops_with_blockids() { func @loops_with_blockids() { ^block0: - affine.for %i = 1 to 100 step 2 { + for %i = 1 to 100 step 2 { ^block1: - affine.for %j = 1 to 200 { + for %j = 1 to 200 { ^block2: } } diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index defde9e9c70..bc5a319c99e 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -18,7 +18,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10)))) // CHECK: } ["foo", mysource.cc:10:8] - affine.for %i0 = 0 to 8 { + for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } <"myPass">["foo", "foo2"] diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index e896e0588d3..b82ac08fe59 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -6,8 +6,8 @@ // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { %A = alloc () : memref<7x42xf32> - affine.for %i0 = 0 to 7 step 4 { - affine.for %i1 = 0 to 42 step 4 { + for %i0 = 0 to 7 step 4 { + for %i1 = 0 to 42 step 4 { %f1 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> %ip1 = affine.apply (d0) -> (d0 + 1) (%i1) %f2 = vector_transfer_read %A, %i0, %ip1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> @@ -29,11 
+29,11 @@ func @materialize_read_1d() { // CHECK-LABEL: func @materialize_read_1d_partially_specialized func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %dyn4 : index) { %A = alloc (%dyn1, %dyn2, %dyn4) : memref<7x?x?x42x?xf32> - affine.for %i0 = 0 to 7 { - affine.for %i1 = 0 to %dyn1 { - affine.for %i2 = 0 to %dyn2 { - affine.for %i3 = 0 to 42 step 2 { - affine.for %i4 = 0 to %dyn4 { + for %i0 = 0 to 7 { + for %i1 = 0 to %dyn1 { + for %i2 = 0 to %dyn2 { + for %i3 = 0 to 42 step 2 { + for %i4 = 0 to %dyn4 { %f1 = vector_transfer_read %A, %i0, %i1, %i2, %i3, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> %i3p1 = affine.apply (d0) -> (d0 + 1) (%i3) %f2 = vector_transfer_read %A, %i0, %i1, %i2, %i3p1, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> @@ -54,10 +54,10 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d // CHECK-LABEL: func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref - // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 { - // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { - // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 { + // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK: {{.*}} = dim %0, 0 : memref @@ -66,9 +66,9 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: {{.*}} = dim %0, 3 : memref // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast %[[ALLOC]] : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] @@ -109,10 +109,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: return // CHECK-NEXT:} %A = alloc (%M, %N, %O, %P) : memref - affine.for %i0 = 0 to %M step 3 { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %O { - affine.for %i3 = 0 to %P step 5 { + for %i0 = 0 to %M step 3 { + for %i1 = 0 to %N { + for %i2 = 0 to %O { + for %i3 = 0 to %P step 5 { %f = vector_transfer_read %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, 0, d0)} : (memref, index, index, index, index) -> vector<5x4x3xf32> } } @@ -125,10 +125,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: %cst = constant splat, 1.000000e+00> : vector<5x4x3xf32> - // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 
{ - // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { - // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 step 4 { + // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK: {{.*}} = dim %0, 0 : memref @@ -138,9 +138,9 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast {{.*}} : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> // CHECK-NEXT: store %cst, {{.*}}[%[[C0]]] : memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32> // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index @@ -184,10 +184,10 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT:} %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<5x4x3xf32> - affine.for %i0 = 0 to %M step 3 { - affine.for %i1 = 0 to %N step 4 { - affine.for %i2 = 0 to %O { - affine.for %i3 = 0 to %P step 5 { + for %i0 = 0 to %M step 3 { + for %i1 = 0 to %N step 4 { + for %i2 = 0 to %O { + for %i3 = 0 to %P step 5 { vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : vector<5x4x3xf32>, memref, index, index, index, index } } diff --git a/mlir/test/Transforms/Vectorize/materialize.mlir b/mlir/test/Transforms/Vectorize/materialize.mlir index ce445ec75bb..80458c75333 100644 --- a/mlir/test/Transforms/Vectorize/materialize.mlir +++ b/mlir/test/Transforms/Vectorize/materialize.mlir @@ -10,10 +10,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> - // CHECK: affine.for %i0 = 0 to %arg0 step 4 { - // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 4 { - // CHECK-NEXT: affine.for %i2 = 0 to %arg2 { - // CHECK-NEXT: affine.for %i3 = 0 to %arg3 step 4 { + // CHECK: for %i0 = 0 to %arg0 step 4 { + // CHECK-NEXT: for %i1 = 0 to %arg1 step 4 { + // CHECK-NEXT: for %i2 = 0 to %arg2 { + // CHECK-NEXT: for %i3 = 0 to %arg3 step 4 { // CHECK-NEXT: %[[a:[0-9]+]] = {{.*}}[[ID1]](%i0) // CHECK-NEXT: %[[b:[0-9]+]] = {{.*}}[[ID1]](%i1) // CHECK-NEXT: %[[c:[0-9]+]] = {{.*}}[[ID1]](%i2) @@ -25,10 +25,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b2]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index // CHECK: %[[b3:[0-9]+]] = {{.*}}[[D0P3]](%i1) // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b3]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index - affine.for %i0 = 0 to %M step 4 { - affine.for %i1 = 0 to %N step 4 { - affine.for %i2 = 0 to %O { - affine.for %i3 = 0 to %P step 4 { + for %i0 = 0 to %M step 4 { + for %i1 = 0 to %N step 4 { + for %i2 = 0 to %O { + for %i3 = 0 to %P step 4 { 
"vector_transfer_write"(%f1, %A, %i0, %i1, %i2, %i3) {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : (vector<4x4x4xf32>, memref, index, index, index, index) -> () } } diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir index 71c442b965e..b5f771d7e62 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 4x unroll (jammed by construction). - // CHECK: affine.for %i0 = 0 to %arg0 { - // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { + // CHECK: for %i0 = 0 to %arg0 { + // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> @@ -34,15 +34,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { + for %i0 = 0 to %M { + for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 4x unroll (jammed by construction). - // CHECK: affine.for %i2 = 0 to %arg0 { - // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { + // CHECK: for %i2 = 0 to %arg0 { + // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> @@ -60,15 +60,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to %N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 4x unroll (jammed by construction). - // CHECK: affine.for %i4 = 0 to %arg0 { - // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { + // CHECK: for %i4 = 0 to %arg0 { + // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -110,8 +110,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - affine.for %i4 = 0 to %M { - affine.for %i5 = 0 to %N { + for %i4 = 0 to %M { + for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir index 62149c323b6..92df49fa8fa 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // (3x2)x unroll (jammed by construction). 
- // CHECK: affine.for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 16 { + // CHECK: for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i1 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> @@ -41,26 +41,26 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL50:%.*]] = affine.apply [[D0P2]](%i0) // CHECK-NEXT: [[VAL51:%.*]] = affine.apply [[D0P8]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL50]], [[VAL51]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { + for %i0 = 0 to %M { + for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: affine.for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 16 { + // CHECK: for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i3 = 0 to %arg1 step 16 { // ..... - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to %N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { // non-scoped %f2 // CHECK does (3x4)x unrolling. store %f2, %B[%i2, %i3] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: affine.for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 16 { + // CHECK: for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i5 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -122,8 +122,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - affine.for %i4 = 0 to %M { - affine.for %i5 = 0 to %N { + for %i4 = 0 to %M { + for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir index 59705eca69e..36ec96e30b4 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir @@ -13,8 +13,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 2x unroll (jammed by construction). - // CHECK: affine.for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { + // CHECK: for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i0) @@ -24,15 +24,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { + for %i0 = 0 to %M { + for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 2x unroll (jammed by construction). 
- // CHECK: affine.for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { + // CHECK: for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i2) @@ -42,15 +42,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i3) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to %N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 2x unroll (jammed by construction). - // CHECK: affine.for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { + // CHECK: for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -72,8 +72,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - affine.for %i4 = 0 to %M { - affine.for %i5 = 0 to %N { + for %i4 = 0 to %M { + for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/normalize_maps.mlir b/mlir/test/Transforms/Vectorize/normalize_maps.mlir index 076d2c75633..9569dbe07fe 100644 --- a/mlir/test/Transforms/Vectorize/normalize_maps.mlir +++ b/mlir/test/Transforms/Vectorize/normalize_maps.mlir @@ -9,19 +9,19 @@ // CHECK-LABEL: func @simple() func @simple() { - affine.for %i0 = 0 to 7 { + for %i0 = 0 to 7 { %0 = affine.apply (d0) -> (d0) (%i0) %1 = affine.apply (d0) -> (d0) (%0) %2 = affine.apply (d0, d1) -> (d0 + d1) (%0, %0) %3 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) } - // CHECK-NEXT: affine.for %i0 = 0 to 7 + // CHECK-NEXT: for %i0 = 0 to 7 // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[D0TIMES2]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[ZERO]]() - affine.for %i1 = 0 to 7 { - affine.for %i2 = 0 to 42 { + for %i1 = 0 to 7 { + for %i2 = 0 to 42 { %20 = affine.apply (d0, d1) -> (d1) (%i1, %i2) %21 = affine.apply (d0, d1) -> (d0) (%i1, %i2) %22 = affine.apply (d0, d1) -> (d0 + d1) (%20, %21) @@ -29,15 +29,15 @@ func @simple() { %24 = affine.apply (d0, d1) -> (-d0 + d1) (%20, %21) } } - // CHECK: affine.for %i1 = 0 to 7 - // CHECK-NEXT: affine.for %i2 = 0 to 42 + // CHECK: for %i1 = 0 to 7 + // CHECK-NEXT: for %i2 = 0 to 42 // CHECK-NEXT: {{.*}} affine.apply #[[D0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[MINSD0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[D0MINUSD1]](%i1, %i2) - affine.for %i3 = 0 to 16 { - affine.for %i4 = 0 to 47 step 2 { - affine.for %i5 = 0 to 78 step 16 { + for %i3 = 0 to 16 { + for %i4 = 0 to 47 step 2 { + for %i5 = 0 to 78 step 16 { %50 = affine.apply (d0) -> (d0) (%i3) %51 = affine.apply (d0) -> (d0) (%i4) %52 = affine.apply (d0) -> (d0) (%i5) @@ -47,9 +47,9 @@ func @simple() { } } } - // CHECK: affine.for %i3 = 0 to 16 - // CHECK-NEXT: affine.for %i4 = 0 to 47 step 2 - // CHECK-NEXT: affine.for %i5 = 0 to 78 step 16 + // CHECK: for %i3 = 0 to 16 + // CHECK-NEXT: for %i4 = 0 to 47 step 2 + // CHECK-NEXT: for %i5 = 0 to 78 step 16 // CHECK-NEXT: 
{{.*}} affine.apply #[[ID1]](%i3) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i4) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i5) diff --git a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir index c812db2d498..05e31dbdea5 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir @@ -23,17 +23,17 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for {{.*}} step 128 // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : (memref, index, index) -> vector<128xf32> - affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector + for %i0 = 0 to %M { // vectorized due to scalar -> vector %a0 = load %A[%cst0, %cst0] : memref } // // CHECK:for {{.*}} [[ARG_M]] { - affine.for %i1 = 0 to %M { // not vectorized + for %i1 = 0 to %M { // not vectorized %a1 = load %A[%i1, %i1] : memref } // -// CHECK: affine.for %i{{[0-9]*}} = 0 to [[ARG_M]] { - affine.for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 +// CHECK: for %i{{[0-9]*}} = 0 to [[ARG_M]] { + for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 %r2 = affine.apply (d0) -> (d0) (%i2) %a2 = load %A[%r2#0, %cst0] : memref } @@ -41,7 +41,7 @@ func @vec1d(%A : memref, %B : memref) { // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK-NEXT: [[APP3:%[a-zA-Z0-9]+]] = affine.apply {{.*}}[[IV3]] // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - affine.for %i3 = 0 to %M { // vectorized + for %i3 = 0 to %M { // vectorized %r3 = affine.apply (d0) -> (d0) (%i3) %a3 = load %A[%cst0, %r3#0] : memref } @@ -51,8 +51,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP50:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: [[APP51:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP50]], [[APP51]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - affine.for %i4 = 0 to %M { // vectorized - affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 + for %i4 = 0 to %M { // vectorized + for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 %r50 = affine.apply (d0, d1) -> (d1) (%i4, %i5) %r51 = affine.apply (d0, d1) -> (d0) (%i4, %i5) %a5 = load %A[%r50, %r51] : memref @@ -61,8 +61,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV6:%[i0-9]*]] = 0 to [[ARG_M]] { // CHECK-NEXT: for [[IV7:%[i0-9]*]] = 0 to [[ARG_N]] { - affine.for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 - affine.for %i7 = 0 to %N { // not vectorized, can never vectorize + for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 + for %i7 = 0 to %N { // not vectorized, can never vectorize %r70 = affine.apply (d0, d1) -> (d1 + d0) (%i6, %i7) %r71 = affine.apply (d0, d1) -> (d0) (%i6, %i7) %a7 = load %A[%r70, %r71] : memref @@ -74,8 +74,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP9_0:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: [[APP9_1:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP9_0]], [[APP9_1]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - affine.for %i8 = 0 to %M { // vectorized - 
affine.for %i9 = 0 to %N { + for %i8 = 0 to %M { // vectorized + for %i9 = 0 to %N { %r90 = affine.apply (d0, d1) -> (d1) (%i8, %i9) %r91 = affine.apply (d0, d1) -> (d0 + d1) (%i8, %i9) %a9 = load %A[%r90, %r91] : memref @@ -84,8 +84,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} { - affine.for %i10 = 0 to %M { // not vectorized, need per load transposes - affine.for %i11 = 0 to %N { // not vectorized, need per load transposes + for %i10 = 0 to %M { // not vectorized, need per load transposes + for %i11 = 0 to %N { // not vectorized, need per load transposes %r11_0 = affine.apply (d0, d1) -> (d0) (%i10, %i11) %r11_1 = affine.apply (d0, d1) -> (d1) (%i10, %i11) %a11 = load %A[%r11_0, %r11_1] : memref @@ -98,9 +98,9 @@ func @vec1d(%A : memref, %B : memref) { // CHECK: for [[IV12:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV13:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV14:%[i0-9]+]] = 0 to [[ARG_P]] step 128 - affine.for %i12 = 0 to %M { // not vectorized, can never vectorize - affine.for %i13 = 0 to %N { // not vectorized, can never vectorize - affine.for %i14 = 0 to %P { // vectorized + for %i12 = 0 to %M { // not vectorized, can never vectorize + for %i13 = 0 to %N { // not vectorized, can never vectorize + for %i14 = 0 to %P { // vectorized %r14_0 = affine.apply (d0, d1, d2) -> (d1) (%i12, %i13, %i14) %r14_1 = affine.apply (d0, d1, d2) -> (d0 + d1) (%i12, %i13, %i14) %r14_2 = affine.apply (d0, d1, d2) -> (d0 + d2) (%i12, %i13, %i14) @@ -109,24 +109,24 @@ func @vec1d(%A : memref, %B : memref) { } } // -// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - affine.for %i15 = 0 to %M { // not vectorized due to condition below +// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + for %i15 = 0 to %M { // not vectorized due to condition below affine.if #set0(%i15) { %a15 = load %A[%cst0, %cst0] : memref } } // -// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - affine.for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load +// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load %a16 = alloc(%M) : memref> %l16 = load %a16[%i16] : memref> } // -// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { +// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : {{.*}} -> vector<128xf32> - affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 - affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector + for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 + for %i18 = 0 to %M { // vectorized due to scalar -> vector %a18 = load %A[%cst0, %cst0] : memref } } @@ -139,24 +139,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { + for %i0 = 0 to %M { + for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to 
%N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - affine.for %i4 = 0 to %M { - affine.for %i5 = 0 to %N { + for %i4 = 0 to %M { + for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<128xf32> @@ -188,10 +188,10 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-LABEL: @vec_rejected func @vec_rejected(%A : memref, %C : memref) { %N = dim %A, 0 : memref - affine.for %i = 0 to %N { + for %i = 0 to %N { // CHECK-NOT: vector %a = load %A[%i, %i] : memref // not vectorized - affine.for %j = 0 to %N { + for %j = 0 to %N { %b = load %A[%i, %j] : memref // may be vectorized // CHECK-NOT: vector %c = addf %a, %b : f32 // not vectorized because %a wasn't diff --git a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir index 59c7483749b..d847f6bb5ce 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir @@ -11,13 +11,13 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %1 step 32 // CHECK: for {{.*}} = 0 to %2 step 256 // Example: - // affine.for %i0 = 0 to %0 { - // affine.for %i1 = 0 to %1 step 32 { - // affine.for %i2 = 0 to %2 step 256 { + // for %i0 = 0 to %0 { + // for %i1 = 0 to %1 step 32 { + // for %i2 = 0 to %2 step 256 { // %3 = "vector_transfer_read"(%arg0, %i0, %i1, %i2) : (memref, index, index, index) -> vector<32x256xf32> - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %P { + for %i0 = 0 to %M { + for %i1 = 0 to %N { + for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -27,9 +27,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=1 --test-fastest-varying=0 no // vectorization happens because of loop nesting order . 
- affine.for %i3 = 0 to %M { - affine.for %i4 = 0 to %N { - affine.for %i5 = 0 to %P { + for %i3 = 0 to %M { + for %i4 = 0 to %N { + for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -43,24 +43,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { + for %i0 = 0 to %M { + for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to %N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - affine.for %i4 = 0 to %M { - affine.for %i5 = 0 to %N { + for %i4 = 0 to %M { + for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32> diff --git a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir index 08ca27dbeee..1a6bee585ee 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir @@ -7,17 +7,17 @@ func @vec3d(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: affine.for %i0 = 0 to %0 { - // CHECK: affine.for %i1 = 0 to %0 { - // CHECK: affine.for %i2 = 0 to %0 step 32 { - // CHECK: affine.for %i3 = 0 to %1 step 64 { - // CHECK: affine.for %i4 = 0 to %2 step 256 { + // CHECK: for %i0 = 0 to %0 { + // CHECK: for %i1 = 0 to %0 { + // CHECK: for %i2 = 0 to %0 step 32 { + // CHECK: for %i3 = 0 to %1 step 64 { + // CHECK: for %i4 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i3, %i4 {permutation_map: #[[map_proj_d0d1d2_d0d1d2]]} : (memref, index, index, index) -> vector<32x64x256xf32> - affine.for %t0 = 0 to %0 { - affine.for %t1 = 0 to %0 { - affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 { - affine.for %i2 = 0 to %2 { + for %t0 = 0 to %0 { + for %t1 = 0 to %0 { + for %i0 = 0 to %0 { + for %i1 = 0 to %1 { + for %i2 = 0 to %2 { %a2 = load %A[%i0, %i1, %i2] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir index d00b99f1716..4654ab810df 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir @@ -7,13 +7,13 @@ func @vec2d(%A : memref) { %M = dim %A, 0 : memref %N = dim %A, 1 : memref %P = dim %A, 2 : memref - // CHECK: affine.for %i0 = 0 to %0 step 32 - // CHECK: affine.for %i1 = 0 to %1 { - // CHECK: affine.for %i2 = 0 to %2 step 256 + // CHECK: for %i0 = 0 to %0 step 32 + // CHECK: for %i1 = 0 to %1 { + // CHECK: for %i2 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i0, %i1, %i2 {permutation_map: #[[map_proj_d0d1d2_d0d2]]} : (memref, 
index, index, index) -> vector<32x256xf32> - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %P { + for %i0 = 0 to %M { + for %i1 = 0 to %N { + for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -23,9 +23,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=2 --test-fastest-varying=0 no // vectorization happens because of loop nesting order - affine.for %i3 = 0 to %M { - affine.for %i4 = 0 to %N { - affine.for %i5 = 0 to %P { + for %i3 = 0 to %M { + for %i4 = 0 to %N { + for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir index a8a8d5d7790..0eebf816535 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=2 no // vectorization happens because of loop nesting order. - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %P { + for %i0 = 0 to %M { + for %i1 = 0 to %N { + for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: affine.for %i3 = 0 to %0 step 32 - // CHECK: affine.for %i4 = 0 to %1 step 256 - // CHECK: affine.for %i5 = 0 to %2 { + // CHECK: for %i3 = 0 to %0 step 32 + // CHECK: for %i4 = 0 to %1 step 256 + // CHECK: for %i5 = 0 to %2 { // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - affine.for %i3 = 0 to %M { - affine.for %i4 = 0 to %N { - affine.for %i5 = 0 to %P { + for %i3 = 0 to %M { + for %i4 = 0 to %N { + for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: affine.for %i0 = 0 to %0 step 32 { - // CHECK: affine.for %i1 = 0 to %1 { - // CHECK: affine.for %i2 = 0 to %2 step 256 { + // CHECK: for %i0 = 0 to %0 step 32 { + // CHECK: for %i1 = 0 to %1 { + // CHECK: for %i2 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: affine.for %i3 = 0 to %1 step 256 { - // CHECK: affine.for %i4 = 0 to %2 { + // CHECK: for %i3 = 0 to %1 step 256 { + // CHECK: for %i4 = 0 to %2 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: affine.for %i5 = 0 to %2 { + // CHECK: for %i5 = 0 to %2 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 { - affine.for %i2 = 0 to %2 { + for %i0 = 0 to %0 { + for %i1 = 0 to %1 { + for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - affine.for %i3 = 0 to %1 { - affine.for %i4 = 0 to %2 { + for %i3 = 0 to %1 { + for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - affine.for %i5 = 0 to %2 { + for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git 
a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir index b8e4e075890..1ba563b3442 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=1 no // vectorization happens because of loop nesting order. - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %P { + for %i0 = 0 to %M { + for %i1 = 0 to %N { + for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: affine.for %i3 = 0 to %0 step 32 - // CHECK: affine.for %i4 = 0 to %1 { - // CHECK: affine.for %i5 = 0 to %2 step 256 + // CHECK: for %i3 = 0 to %0 step 32 + // CHECK: for %i4 = 0 to %1 { + // CHECK: for %i5 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - affine.for %i3 = 0 to %M { - affine.for %i4 = 0 to %N { - affine.for %i5 = 0 to %P { + for %i3 = 0 to %M { + for %i4 = 0 to %N { + for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: affine.for %i0 = 0 to %0 step 32 { - // CHECK: affine.for %i1 = 0 to %1 step 256 { - // CHECK: affine.for %i2 = 0 to %2 { + // CHECK: for %i0 = 0 to %0 step 32 { + // CHECK: for %i1 = 0 to %1 step 256 { + // CHECK: for %i2 = 0 to %2 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: affine.for %i3 = 0 to %1 { - // CHECK: affine.for %i4 = 0 to %2 step 256 { + // CHECK: for %i3 = 0 to %1 { + // CHECK: for %i4 = 0 to %2 step 256 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: affine.for %i5 = 0 to %2 step 256 { + // CHECK: for %i5 = 0 to %2 step 256 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - affine.for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 { - affine.for %i2 = 0 to %2 { + for %i0 = 0 to %0 { + for %i1 = 0 to %1 { + for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - affine.for %i3 = 0 to %1 { - affine.for %i4 = 0 to %2 { + for %i3 = 0 to %1 { + for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - affine.for %i5 = 0 to %2 { + for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index cc295751748..29accf4ffc1 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -213,10 +213,10 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref, memref %c = alloc(%K, %N) : memref - // CHECK: affine.for %i0 = - affine.for %i = 0 to %L { - // CHECK-NEXT: affine.for %i1 = - affine.for %j = 0 to 10 { + // CHECK: for %i0 = + for %i = 0 to %L { + // CHECK-NEXT: for %i1 = + for %j = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref // CHECK-NEXT: store %4, %1[%c0, %c0, %i0, %i1, %c0] : memref<4x1024x8x512x?xf32> %v = load %a[%i, %j] : memref @@ -242,8 +242,8 @@ func 
@merge_constants() -> (index, index) { // CHECK-LABEL: func @hoist_constant func @hoist_constant(%arg0: memref<8xi32>) { // CHECK-NEXT: %c42_i32 = constant 42 : i32 - // CHECK-NEXT: affine.for %i0 = 0 to 8 { - affine.for %i0 = 0 to 8 { + // CHECK-NEXT: for %i0 = 0 to 8 { + for %i0 = 0 to 8 { // CHECK-NEXT: store %c42_i32, %arg0[%i0] %c42_i32 = constant 42 : i32 store %c42_i32, %arg0[%i0] : memref<8xi32> diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 1c23914d7a2..6043e478c5a 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: @test(%arg0: memref) { func @test(%p : memref) { - affine.for %i0 = 0 to 128 { - affine.for %i1 = 0 to 8 { // CHECK: affine.for %i1 = 0 to 8 { + for %i0 = 0 to 128 { + for %i1 = 0 to 8 { // CHECK: for %i1 = 0 to 8 { %0 = constant 4.5 : f32 %1 = constant 1.5 : f32 diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 31a7e13b73e..c4c0da7053e 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -123,8 +123,8 @@ func @down_propagate_for_ml() { // CHECK: %c1_i32 = constant 1 : i32 %0 = constant 1 : i32 - // CHECK-NEXT: affine.for %i0 = 0 to 4 { - affine.for %i = 0 to 4 { + // CHECK-NEXT: for %i0 = 0 to 4 { + for %i = 0 to 4 { // CHECK-NEXT: "foo"(%c1_i32, %c1_i32) : (i32, i32) -> () %1 = constant 1 : i32 "foo"(%0, %1) : (i32, i32) -> () @@ -155,8 +155,8 @@ func @down_propagate_cfg() -> i32 { /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_ml func @up_propagate_ml() -> i32 { - // CHECK: affine.for %i0 = 0 to 4 { - affine.for %i = 0 to 4 { + // CHECK: for %i0 = 0 to 4 { + for %i = 0 to 4 { // CHECK-NEXT: %c1_i32 = constant 1 : i32 // CHECK-NEXT: "foo"(%c1_i32) : (i32) -> () %0 = constant 1 : i32 diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 864a61d3abd..a954bdb96a1 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -32,7 +32,7 @@ func @loop_nest_1d() { // Second DMA transfer. // CHECK: dma_start %1[%c256], %5[%c0], %c256_0, %6[%c0] : memref<512xf32>, memref<256xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %6[%c0], %c256_0 : memref<1xi32> - // CHECK: affine.for %i0 = 0 to 256 { + // CHECK: for %i0 = 0 to 256 { // CHECK-NEXT: %7 = load %3[%i0] : memref<256xf32, 1> // CHECK: %8 = affine.apply [[MAP_PLUS_256]](%i0) // CHECK: %9 = affine.apply [[MAP_MINUS_256]](%8) @@ -41,7 +41,7 @@ func @loop_nest_1d() { // CHECK: %11 = load %2[%i0] : memref<256xf32, 1> // CHECK-NEXT: } // CHECK-NEXT: return - affine.for %i = 0 to 256 { + for %i = 0 to 256 { load %A[%i] : memref<256 x f32> %idx = affine.apply (d0) -> (d0 + 256)(%i) load %B[%idx] : memref<512 x f32> @@ -68,20 +68,20 @@ func @loop_nest_1d() { // INCOMING DMA for C. 
// CHECK-DAG: dma_start %arg2[%c0, %c0], [[BUFC]][%c0, %c0], %c16384_0, [[TAGC]][%c0] : memref<512x32xf32>, memref<512x32xf32, 1>, memref<1xi32> // CHECK-DAG: dma_wait [[TAGC]][%c0], %c16384_0 : memref<1xi32> -// CHECK-NEXT: affine.for %i0 = 0 to 32 { -// CHECK-NEXT: affine.for %i1 = 0 to 32 { -// CHECK-NEXT: affine.for %i2 = 0 to 32 { -// CHECK-NEXT: affine.for %i3 = 0 to 16 { +// CHECK-NEXT: for %i0 = 0 to 32 { +// CHECK-NEXT: for %i1 = 0 to 32 { +// CHECK-NEXT: for %i2 = 0 to 32 { +// CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %7 = affine.apply #map{{[0-9]+}}(%i1, %i3) // CHECK-NEXT: %8 = load [[BUFB]][%7, %i0] : memref<512x32xf32, 1> // CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i4 = 0 to 16 { +// CHECK-NEXT: for %i4 = 0 to 16 { // CHECK-NEXT: %9 = affine.apply #map{{[0-9]+}}(%i2, %i4) // CHECK-NEXT: %10 = load [[BUFA]][%9, %i1] : memref<512x32xf32, 1> // CHECK-NEXT: "bar"(%10) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i5 = 0 to 16 { +// CHECK-NEXT: for %i5 = 0 to 16 { // CHECK-NEXT: %11 = "abc_compute"() : () -> f32 // CHECK-NEXT: %12 = affine.apply #map{{[0-9]+}}(%i2, %i5) // CHECK-NEXT: %13 = load [[BUFC]][%12, %i0] : memref<512x32xf32, 1> @@ -102,20 +102,20 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // DMAs will be performed at this level (jT is the first loop without a stride). // A and B are read, while C is both read and written. A total of three new buffers // are allocated and existing load's/store's are replaced by accesses to those buffers. - affine.for %jT = 0 to 32 { - affine.for %kT = 0 to 32 { - affine.for %iT = 0 to 32 { - affine.for %kk = 0 to 16 { // k intratile + for %jT = 0 to 32 { + for %kT = 0 to 32 { + for %iT = 0 to 32 { + for %kk = 0 to 16 { // k intratile %k = affine.apply (d0, d1) -> (16*d0 + d1) (%kT, %kk) %v0 = load %B[%k, %jT] : memref<512 x 32 x f32> "foo"(%v0) : (f32) -> () } - affine.for %ii = 0 to 16 { // i intratile. + for %ii = 0 to 16 { // i intratile. %i = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii) %v1 = load %A[%i, %kT] : memref<512 x 32 x f32> "bar"(%v1) : (f32) -> () } - affine.for %ii_ = 0 to 16 { // i intratile. + for %ii_ = 0 to 16 { // i intratile. %v2 = "abc_compute"() : () -> f32 %i_ = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii_) %v3 = load %C[%i_, %jT] : memref<512 x 32 x f32> @@ -134,13 +134,13 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // // CHECK-LABEL: func @loop_nest_modulo() { // CHECK: %0 = alloc() : memref<256x8xf32> -// CHECK-NEXT: affine.for %i0 = 0 to 32 step 4 { +// CHECK-NEXT: for %i0 = 0 to 32 step 4 { // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // CHECK-NEXT: %2 = alloc() : memref<1x2xf32, 1> // CHECK-NEXT: %3 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%1, %c0], %2[%c0, %c0], %c2, %3[%c0] : memref<256x8xf32>, memref<1x2xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %3[%c0], %c2 : memref<1xi32> -// CHECK-NEXT: affine.for %i1 = 0 to 8 { +// CHECK-NEXT: for %i1 = 0 to 8 { // ... // ... // CHECK: } @@ -148,9 +148,9 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // CHECK-NEXT: return func @loop_nest_modulo() { %A = alloc() : memref<256 x 8 x f32> - affine.for %i = 0 to 32 step 4 { + for %i = 0 to 32 step 4 { // DMAs will be performed at this level (%j is the first unit stride loop) - affine.for %j = 0 to 8 { + for %j = 0 to 8 { %idx = affine.apply (d0) -> (d0 mod 2) (%j) // A buffer of size 32 x 2 will be allocated (original buffer was 256 x 8). 
%v = load %A[%i, %idx] : memref<256 x 8 x f32> @@ -164,17 +164,17 @@ func @loop_nest_modulo() { // CHECK-LABEL: func @loop_nest_tiled() -> memref<256x1024xf32> { func @loop_nest_tiled() -> memref<256x1024xf32> { %0 = alloc() : memref<256x1024xf32> - affine.for %i0 = 0 to 256 step 32 { - affine.for %i1 = 0 to 1024 step 32 { + for %i0 = 0 to 256 step 32 { + for %i1 = 0 to 1024 step 32 { // CHECK: %3 = alloc() : memref<32x32xf32, 1> // CHECK-NEXT: %4 = alloc() : memref<1xi32> // Strided DMA here: 32 x 32 tile in a 256 x 1024 memref. // CHECK-NEXT: dma_start %0[%1, %2], %3[%c0, %c0], %c1024, %4[%c0], %c1024_0, %c32 : memref<256x1024xf32>, memref<32x32xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait -// CHECK-NEXT: affine.for %i2 = #map -// CHECK-NEXT: affine.for %i3 = #map - affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { - affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { +// CHECK-NEXT: for %i2 = #map +// CHECK-NEXT: for %i3 = #map + for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { + for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { // CHECK-NEXT: %5 = affine.apply [[MAP_INDEX_DIFF_EVEN]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP_INDEX_DIFF_ODD]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %7 = load %3[%5, %6] : memref<32x32xf32, 1> @@ -196,8 +196,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // No strided DMA needed here. // CHECK: dma_start %arg0[%c1, %c0], %0[%c0, %c0], %c100, %1[%c0] : memref<100x100xf32>, memref<1x100xf32, 1>, // CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32> - affine.for %i = 0 to 100 { - affine.for %j = 0 to ()[s0] -> (s0) ()[%N] { + for %i = 0 to 100 { + for %j = 0 to ()[s0] -> (s0) ()[%N] { // CHECK: %2 = affine.apply [[MAP_D0_MINUS_ONE]](%c1_0, %i1) // CHECK: %3 = affine.apply [[MAP_D1]](%c1_0, %i1) // CHECK-NEXT: %4 = load %0[%2, %3] : memref<1x100xf32, 1> @@ -210,8 +210,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // CHECK-LABEL: func @dma_with_symbolic_accesses func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { %N = constant 9 : index - affine.for %i = 0 to 100 { - affine.for %j = 0 to 100 { + for %i = 0 to 100 { + for %j = 0 to 100 { %idy = affine.apply (d0, d1) [s0, s1] -> (d1 + s0 + s1)(%i, %j)[%M, %N] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -221,8 +221,8 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { // CHECK-NEXT: %2 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %0], %1[%c0, %c0], %c10000, %2[%c0] // CHECK-NEXT: dma_wait %2[%c0], %c10000 -// CHECK-NEXT: affine.for %i0 = 0 to 100 { -// CHECK-NEXT: affine.for %i1 = 0 to 100 { +// CHECK-NEXT: for %i0 = 0 to 100 { +// CHECK-NEXT: for %i1 = 0 to 100 { // CHECK-NEXT: %3 = affine.apply [[MAP_SYM_SHIFT]](%i0, %i1)[%arg1, %c9] // CHECK-NEXT: %4 = affine.apply [[MAP_3D_D1]](%arg1, %i0, %3) // CHECK-NEXT: %5 = affine.apply [[MAP_SUB_OFFSET]](%arg1, %i0, %3) @@ -241,8 +241,8 @@ func @dma_with_symbolic_loop_bounds(%A : memref<100x100xf32>, %M : index, %N: in // CHECK-NEXT: %1 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %c0], %0[%c0, %c0], %c10000, %1[%c0] : memref<100x100xf32>, memref<100x100xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %1[%c0], %c10000 : memref<1xi32> - affine.for %i = 0 to 100 { - affine.for %j = %M to %N { + for %i = 0 to 100 { + for %j = %M to %N { %idy = affine.apply (d1) [s0] -> (d1 + s0)(%j)[%K] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -256,8 +256,8 @@ func @dma_with_symbolic_loop_bounds(%A : 
memref<100x100xf32>, %M : index, %N: in func @dma_unknown_size(%arg0: memref) { %M = dim %arg0, 0 : memref %N = dim %arg0, 0 : memref - affine.for %i = 0 to %M { - affine.for %j = 0 to %N { + for %i = 0 to %M { + for %j = 0 to %N { // If this loop nest isn't tiled, the access requires a non-constant DMA // size -- not yet implemented. // CHECK: %2 = load %arg0[%i0, %i1] : memref @@ -272,9 +272,9 @@ func @dma_unknown_size(%arg0: memref) { // CHECK-LABEL: func @dma_memref_3d func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { - affine.for %i = 0 to 1024 { - affine.for %j = 0 to 1024 { - affine.for %k = 0 to 1024 { + for %i = 0 to 1024 { + for %j = 0 to 1024 { + for %k = 0 to 1024 { %idx = affine.apply (d0) -> (d0 mod 128)(%i) %idy = affine.apply (d0) -> (d0 mod 128)(%j) %idz = affine.apply (d0) -> (d0 mod 128)(%k) @@ -308,8 +308,8 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // CHECK-LABEL: func @multi_load_store_union() { func @multi_load_store_union() { %A = alloc() : memref<512 x 512 x f32> - affine.for %i = 0 to 256 { - affine.for %j = 0 to 256 { + for %i = 0 to 256 { + for %j = 0 to 256 { %idx = affine.apply (d0) -> (d0 + 64)(%i) %idy = affine.apply (d0) -> (d0 + 128)(%j) %ishift = affine.apply (d0) -> (d0 + 2)(%i) @@ -333,8 +333,8 @@ func @multi_load_store_union() { // CHECK-NEXT: dma_start %0[%c2_1, %c2_2], %1[%c0, %c0], %c170372_3, %2[%c0], %c512_4, %c446_5 : memref<512x512xf32>, memref<382x446xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %2[%c0], %c170372_3 : memref<1xi32> // CHECK-NEXT: %3 = alloc() : memref<1xi32> -// CHECK-NEXT: affine.for %i0 = 0 to 256 { -// CHECK-NEXT: affine.for %i1 = 0 to 256 { +// CHECK-NEXT: for %i0 = 0 to 256 { +// CHECK-NEXT: for %i1 = 0 to 256 { // CHECK-NEXT: %4 = affine.apply [[MAP_PLUS_64]](%i0) // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_128]](%i1) // CHECK-NEXT: %6 = affine.apply [[MAP_PLUS_2]](%i0) @@ -370,7 +370,7 @@ func @dma_loop_straightline_interspersed() { %c255 = constant 255 : index %A = alloc() : memref<256 x f32> %v = load %A[%c0] : memref<256 x f32> - affine.for %i = 1 to 255 { + for %i = 1 to 255 { load %A[%i] : memref<256 x f32> } %l = load %A[%c255] : memref<256 x f32> @@ -389,7 +389,7 @@ func @dma_loop_straightline_interspersed() { // CHECK-NEXT: %5 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 1>, memref<1xi32> // CHECK-NEXT: dma_wait %5[%c0], %c254 : memref<1xi32> -// CHECK-NEXT: affine.for %i0 = 1 to 255 { +// CHECK-NEXT: for %i0 = 1 to 255 { // CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0) // CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 1> // CHECK-NEXT: } @@ -410,10 +410,10 @@ func @dma_loop_straightline_interspersed() { func @dma_mixed_loop_blocks() { %c0 = constant 0 : index %A = alloc() : memref<256 x 256 x vector<8 x f32>> - affine.for %i = 0 to 256 { + for %i = 0 to 256 { %v = load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> "foo"(%v) : (vector<8 x f32>) -> () - affine.for %j = 0 to 256 { + for %j = 0 to 256 { %w = load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> "bar"(%w) : (vector<8 x f32>) -> () } @@ -425,7 +425,7 @@ func @dma_mixed_loop_blocks() { // CHECK-DAG: [[TAG:%[0-9]+]] = alloc() : memref<1xi32> // CHECK: dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], %c65536, [[TAG]][%c0] : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 1>, memref<1xi32> // CHECK-NEXT: dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> -// CHECK-NEXT: affine.for %i0 = 0 to 256 { +// CHECK-NEXT: for %i0 = 0 to 256 { // 
CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 1> -// CHECK: affine.for %i1 = 0 to 256 { +// CHECK: for %i1 = 0 to 256 { // CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 1> diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 7fbf7097be3..439e93137a4 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -16,13 +16,13 @@ func @should_fuse_raw_dep_for_locality() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -44,23 +44,23 @@ func @should_fuse_reduction_to_pointwise() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x10xf32> %v3 = addf %v0, %v1 : f32 store %v3, %b[%i0] : memref<10xf32> } } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v4 = load %b[%i2] : memref<10xf32> store %v4, %c[%i2] : memref<10xf32> } // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. - // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: %5 = load %1[%i0, %i1] : memref<10x10xf32> @@ -88,15 +88,15 @@ func @should_fuse_loop_nests_with_shifts() { %a = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 9 { - affine.for %i1 = 0 to 9 { + for %i0 = 0 to 9 { + for %i1 = 0 to 9 { %idx = affine.apply (d0) -> (d0 + 1) (%i0) %idy = affine.apply (d0) -> (d0 + 1) (%i1) store %cf7, %a[%idx, %idy] : memref<10x10xf32> } } - affine.for %i2 = 1 to 10 { - affine.for %i3 = 1 to 10 { + for %i2 = 1 to 10 { + for %i3 = 1 to 10 { %v0 = load %a[%i2, %i3] : memref<10x10xf32> } } @@ -109,8 +109,8 @@ func @should_fuse_loop_nests_with_shifts() { // *) Fifth affine apply shifts the loads access function by '-1', because // of the offset induced by reducing the memref shape from 10x10 to 9x9. // NOTE: Should create a private memref with reduced shape 9x9xf32. 
- // CHECK: affine.for %i0 = 1 to 10 { - // CHECK-NEXT: affine.for %i1 = 1 to 10 { + // CHECK: for %i0 = 1 to 10 { + // CHECK-NEXT: for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) // CHECK-NEXT: %3 = affine.apply [[MAP_SHIFT_BY_ONE]](%1) @@ -138,27 +138,27 @@ func @should_fuse_loop_nest() { %b = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %a[%i0, %i1] : memref<10x10xf32> } } - affine.for %i2 = 0 to 10 { - affine.for %i3 = 0 to 10 { + for %i2 = 0 to 10 { + for %i3 = 0 to 10 { %v0 = load %a[%i3, %i2] : memref<10x10xf32> store %v0, %b[%i2, %i3] : memref<10x10xf32> } } - affine.for %i4 = 0 to 10 { - affine.for %i5 = 0 to 10 { + for %i4 = 0 to 10 { + for %i5 = 0 to 10 { %v1 = load %b[%i4, %i5] : memref<10x10xf32> } } // Expecting private memref for '%a' first, then private memref for '%b'. // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> - // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<1x1xf32> @@ -189,23 +189,23 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %c[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Should fuse first loop (past second loop with no dependences) into third. // Note that fusion creates a private memref '%2' for the fused loop nest. - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> @@ -227,13 +227,13 @@ func @should_fuse_all_loops() { %cf7 = constant 7.0 : f32 // Set up flow dependences from first and second loops to third. - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v0 = load %a[%i2] : memref<10xf32> %v1 = load %b[%i2] : memref<10xf32> } @@ -242,7 +242,7 @@ func @should_fuse_all_loops() { // Expecting private memref for '%a' first, then private memref for '%b'. 
// CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) @@ -268,27 +268,27 @@ func @should_fuse_first_and_second_loops() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %cf7, %b[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v1 = load %c[%i2] : memref<10xf32> } // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: %6 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -310,28 +310,28 @@ func @should_not_fuse_would_create_cycle() { // 1) loop0 -> loop1 on memref '%a' // 2) loop0 -> loop2 on memref '%b' // 3) loop1 -> loop2 on memref '%c' - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %cf7, %c[%i2] : memref<10xf32> } // Should not fuse: fusing loop first loop into last would create a cycle. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i2 = 0 to 10 { + // CHECK: for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -346,23 +346,23 @@ func @should_not_fuse_across_waw_dep() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %m[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v1 = load %m[%i2] : memref<10xf32> } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i2 = 0 to 10 { + // CHECK: for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -379,27 +379,27 @@ func @should_fuse_and_move_to_preserve_war_dep() { %b = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Loops '%i1' and '%i2' have no dependences. We can fuse a slice of '%i0' // into '%i2' if we move the fused loop nest before '%i1', which preserves // the WAR dependence from load '%a' in '%i0' to the store '%a' in loop '%i1'. - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %2 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %2, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -415,20 +415,20 @@ func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index %v1 = load %m[%c0] : memref<10xf32> // Top-level load to '%m' should prevent fusion. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) @@ -446,13 +446,13 @@ func @should_fuse_no_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -471,20 +471,20 @@ func @should_not_fuse_if_inst_at_top_level() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index affine.if #set0(%c0) { } // Top-level IfOp should prevent fusion. - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } return @@ -500,20 +500,20 @@ func @should_not_fuse_if_inst_in_loop_nest() { %cf7 = constant 7.0 : f32 %c4 = constant 4 : index - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { affine.if #set0(%c4) { } %v0 = load %m[%i1] : memref<10xf32> } // IfOp in ForInst should prevent fusion. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { + // CHECK: for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%c4) { // CHECK-NEXT: } // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> @@ -532,24 +532,24 @@ func @permute_and_fuse() { %m = alloc() : memref<10x20x30xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 20 { - affine.for %i2 = 0 to 30 { + for %i0 = 0 to 10 { + for %i1 = 0 to 20 { + for %i2 = 0 to 30 { store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> } } } - affine.for %i3 = 0 to 30 { - affine.for %i4 = 0 to 10 { - affine.for %i5 = 0 to 20 { + for %i3 = 0 to 30 { + for %i4 = 0 to 10 { + for %i5 = 0 to 20 { %v0 = load %m[%i4, %i5, %i3] : memref<10x20x30xf32> "foo"(%v0) : (f32) -> () } } } -// CHECK: affine.for %i0 = 0 to 30 { -// CHECK-NEXT: affine.for %i1 = 0 to 10 { -// CHECK-NEXT: affine.for %i2 = 0 to 20 { +// CHECK: for %i0 = 0 to 30 { +// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: for %i2 = 0 to 20 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) @@ -578,22 +578,22 @@ func @permute_and_fuse() { func @fuse_reshape_64_16_4(%in : memref<64xf32>) { %out = alloc() : memref<16x4xf32> - affine.for %i0 = 0 to 64 { + for %i0 = 0 to 64 { %v = load %in[%i0] : memref<64xf32> %idx = affine.apply (d0) -> (d0 floordiv 4) (%i0) %idy = affine.apply (d0) -> (d0 mod 4) (%i0) store %v, %out[%idx, %idy] : memref<16x4xf32> } - affine.for %i1 = 0 to 16 { - affine.for %i2 = 0 to 4 { + for %i1 = 0 to 16 { + for %i2 = 0 to 4 { %w = load %out[%i1, %i2] : memref<16x4xf32> "foo"(%w) : (f32) -> () } } return - // CHECK: affine.for %i0 = - // CHECK-NEXT: affine.for %i1 = + // CHECK: for %i0 = + // CHECK-NEXT: for %i1 = // CHECK-NOT: for // CHECK: } // CHECK-NEXT: } @@ -612,19 +612,19 @@ func @fuse_reshape_16_4_64() { %in = alloc() : memref<16x4xf32> %out = alloc() : memref<64xf32> - affine.for %i0 = 0 to 16 { - affine.for %i1 = 0 to 4 { + for %i0 = 0 to 16 { + for %i1 = 0 to 4 { %v = load %in[%i0, %i1] : memref<16x4xf32> %idx = affine.apply (d0, d1) -> (4*d0 + d1) (%i0, %i1) store %v, %out[%idx] : memref<64xf32> } } - affine.for %i2 = 0 to 64 { + for %i2 = 0 to 64 { %w = load %out[%i2] : memref<64xf32> "foo"(%w) : (f32) -> () } -// CHECK: affine.for %i0 = 0 to 64 { +// CHECK: for %i0 = 0 to 64 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0) // CHECK-NEXT: %3 = affine.apply [[MAP1]](%i0) // CHECK-NEXT: %4 = load %1[%2, %3] : memref<16x4xf32> @@ -650,12 +650,12 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { %live_out = alloc() : memref<64x9xi32> // Initialize input. - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 2 { - affine.for %i2 = 0 to 3 { - affine.for %i3 = 0 to 3 { - affine.for %i4 = 0 to 16 { - affine.for %i5 = 0 to 1 { + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + for %i4 = 0 to 16 { + for %i5 = 0 to 1 { %val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32 store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> } @@ -665,8 +665,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - affine.for %ii = 0 to 64 { - affine.for %jj = 0 to 9 { + for %ii = 0 to 64 { + for %jj = 0 to 9 { // Convert output coordinates to linear index. 
%a0 = affine.apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj) %0 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 16 * 1))(%a0) @@ -680,8 +680,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - affine.for %i = 0 to 64 { - affine.for %j = 0 to 9 { + for %i = 0 to 64 { + for %j = 0 to 9 { %a = load %out[%i, %j] : memref<64x9xi32> %b = muli %a, %a : i32 store %b, %live_out[%i, %j] : memref<64x9xi32> @@ -717,8 +717,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK: %0 = alloc() : memref<1x2x3x3x16x1xi32> // CHECK: %1 = alloc() : memref<1x1xi32> // CHECK: %2 = alloc() : memref<64x9xi32> -// CHECK-NEXT: affine.for %i0 = 0 to 64 { -// CHECK-NEXT: affine.for %i1 = 0 to 9 { +// CHECK-NEXT: for %i0 = 0 to 64 { +// CHECK-NEXT: for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1) @@ -768,14 +768,14 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { %c0 = constant 0.0 : f32 %s = constant 5 : index - affine.for %i0 = 0 to %M { - affine.for %i1 = 0 to (d0) -> (d0 + 5) (%N) { + for %i0 = 0 to %M { + for %i1 = 0 to (d0) -> (d0 + 5) (%N) { store %c0, %m[%i0, %i1] : memref } } - affine.for %i2 = 0 to %M { - affine.for %i3 = 0 to %N { + for %i2 = 0 to %M { + for %i3 = 0 to %N { %idy = affine.apply (d0)[s0] -> (d0 + s0) (%i3)[%s] %v = load %m[%i2, %idy] : memref } @@ -792,16 +792,16 @@ func @should_fuse_reduction_at_depth1() { %a = alloc() : memref<10x100xf32> %b = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 100 { + for %i0 = 0 to 10 { + for %i1 = 0 to 100 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x100xf32> %v2 = "maxf"(%v0, %v1) : (f32, f32) -> f32 store %v2, %b[%i0] : memref<10xf32> } } - affine.for %i2 = 0 to 10 { - affine.for %i3 = 0 to 100 { + for %i2 = 0 to 10 { + for %i3 = 0 to 100 { %v3 = load %b[%i2] : memref<10xf32> %v4 = load %a[%i2, %i3] : memref<10x100xf32> %v5 = subf %v4, %v3 : f32 @@ -812,8 +812,8 @@ func @should_fuse_reduction_at_depth1() { // loop nest, which improves locality and enables subsequence passes to // decrease the reduction memref size and possibly place it in a faster // memory space. 
- // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: affine.for %i1 = 0 to 100 { + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 100 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: %4 = load %1[%i0, %i1] : memref<10x100xf32> @@ -821,7 +821,7 @@ func @should_fuse_reduction_at_depth1() { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %5, %0[%6] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 100 { + // CHECK-NEXT: for %i2 = 0 to 100 { // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> // CHECK-NEXT: %9 = load %1[%i0, %i2] : memref<10x100xf32> @@ -843,19 +843,19 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { %a = alloc() : memref<100x16xf32> %b = alloc() : memref<100x16xf32> - affine.for %i0 = 0 to 100 { - affine.for %i1 = 0 to 16 { + for %i0 = 0 to 100 { + for %i1 = 0 to 16 { %v0 = load %a[%i0, %i1] : memref<100x16xf32> "op0"(%v0) : (f32) -> () } - affine.for %i2 = 0 to 16 { + for %i2 = 0 to 16 { %v1 = "op1"() : () -> (f32) store %v1, %b[%i0, %i2] : memref<100x16xf32> } } - affine.for %i3 = 0 to 100 { - affine.for %i4 = 0 to 16 { + for %i3 = 0 to 100 { + for %i4 = 0 to 16 { %v2 = load %b[%i3, %i4] : memref<100x16xf32> "op2"(%v2) : (f32) -> () } @@ -865,18 +865,18 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // destination loop nest at depth2 causes extra computation. Instead, // the fusion algorithm should detect that the source loop should be sliced // at depth 1 and the slice should be inserted at depth 1. - // CHECK: affine.for %i0 = 0 to 100 { - // CHECK-NEXT: affine.for %i1 = 0 to 16 { + // CHECK: for %i0 = 0 to 100 { + // CHECK-NEXT: for %i1 = 0 to 16 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<100x16xf32> // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 16 { + // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %3, %0[%4, %5] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i3 = 0 to 16 { + // CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %7 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %8 = load %0[%6, %7] : memref<1x16xf32> @@ -896,20 +896,20 @@ func @should_fuse_src_depth1_at_dst_depth2() { %a = alloc() : memref<100xf32> %c0 = constant 0.0 : f32 - affine.for %i0 = 0 to 100 { + for %i0 = 0 to 100 { store %c0, %a[%i0] : memref<100xf32> } - affine.for %i1 = 0 to 10 { - affine.for %i2 = 0 to 10 { + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { %a0 = affine.apply (d0, d1) -> (d0 * 10 + d1) (%i1, %i2) %v0 = load %a[%a0] : memref<100xf32> } } // The source loop nest slice loop bound is a function of both destination // loop IVs, so we should slice at depth 1 and insert the slice at depth 2. 
- // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> @@ -930,10 +930,10 @@ func @fusion_at_depth0_not_currently_supported() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %1 = load %0[%c0] : memref<10xf32> } // NOTE: Should shrink memref size to 1 element access by load in dst loop @@ -966,18 +966,18 @@ func @should_fuse_deep_loop_nests() { %c1 = constant 1 : index %c1_0 = constant 1 : index %cst = constant 0.000000e+00 : f32 - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 2 { - affine.for %i2 = 0 to 3 { - affine.for %i3 = 0 to 3 { - affine.for %i4 = 0 to 16 { - affine.for %i5 = 0 to 10 { + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + for %i4 = 0 to 16 { + for %i5 = 0 to 10 { %3 = load %0[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> } } - affine.for %i6 = 0 to 16 { - affine.for %i7 = 0 to 10 { + for %i6 = 0 to 16 { + for %i7 = 0 to 10 { store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> } @@ -986,22 +986,22 @@ func @should_fuse_deep_loop_nests() { } } } - affine.for %i8 = 0 to 3 { - affine.for %i9 = 0 to 3 { - affine.for %i10 = 0 to 2 { - affine.for %i11 = 0 to 2 { - affine.for %i12 = 0 to 3 { - affine.for %i13 = 0 to 3 { - affine.for %i14 = 0 to 2 { - affine.for %i15 = 0 to 2 { - affine.for %i16 = 0 to 16 { - affine.for %i17 = 0 to 10 { + for %i8 = 0 to 3 { + for %i9 = 0 to 3 { + for %i10 = 0 to 2 { + for %i11 = 0 to 2 { + for %i12 = 0 to 3 { + for %i13 = 0 to 3 { + for %i14 = 0 to 2 { + for %i15 = 0 to 2 { + for %i16 = 0 to 16 { + for %i17 = 0 to 10 { %5 = load %0[%i14, %i15, %i12, %i13, %i16, %i17] : memref<2x2x3x3x16x10xf32, 2> } } - affine.for %i18 = 0 to 16 { - affine.for %i19 = 0 to 10 { + for %i18 = 0 to 16 { + for %i19 = 0 to 10 { %6 = load %1[%i10, %i11, %i8, %i9, %i18, %i19] : memref<2x2x3x3x16x10xf32, 2> } @@ -1019,19 +1019,19 @@ func @should_fuse_deep_loop_nests() { // where the destination loops nests have been interchanged. 
// CHECK-DAG: %0 = alloc() : memref<1x1x1x1x16x10xf32, 2> -// CHECK: affine.for %i0 = 0 to 3 { -// CHECK-NEXT: affine.for %i1 = 0 to 3 { -// CHECK-NEXT: affine.for %i2 = 0 to 2 { -// CHECK-NEXT: affine.for %i3 = 0 to 2 { -// CHECK-NEXT: affine.for %i4 = 0 to 3 { -// CHECK-NEXT: affine.for %i5 = 0 to 3 { -// CHECK-NEXT: affine.for %i6 = 0 to 16 { -// CHECK-NEXT: affine.for %i7 = 0 to 10 { +// CHECK: for %i0 = 0 to 3 { +// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK-NEXT: for %i2 = 0 to 2 { +// CHECK-NEXT: for %i3 = 0 to 2 { +// CHECK-NEXT: for %i4 = 0 to 3 { +// CHECK-NEXT: for %i5 = 0 to 3 { +// CHECK-NEXT: for %i6 = 0 to 16 { +// CHECK-NEXT: for %i7 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i8 = 0 to 16 { -// CHECK-NEXT: affine.for %i9 = 0 to 10 { +// CHECK-NEXT: for %i8 = 0 to 16 { +// CHECK-NEXT: for %i9 = 0 to 10 { // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %6 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) @@ -1041,15 +1041,15 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: store %cst, %0[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i10 = 0 to 2 { -// CHECK-NEXT: affine.for %i11 = 0 to 2 { -// CHECK-NEXT: affine.for %i12 = 0 to 16 { -// CHECK-NEXT: affine.for %i13 = 0 to 10 { +// CHECK-NEXT: for %i10 = 0 to 2 { +// CHECK-NEXT: for %i11 = 0 to 2 { +// CHECK-NEXT: for %i12 = 0 to 16 { +// CHECK-NEXT: for %i13 = 0 to 10 { // CHECK-NEXT: %10 = load %1[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i14 = 0 to 16 { -// CHECK-NEXT: affine.for %i15 = 0 to 10 { +// CHECK-NEXT: for %i14 = 0 to 16 { +// CHECK-NEXT: for %i15 = 0 to 10 { // CHECK-NEXT: %11 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %12 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %13 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) @@ -1083,17 +1083,17 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - affine.for %i0 = 0 to 4 { - affine.for %i1 = 0 to 256 { + for %i0 = 0 to 4 { + for %i1 = 0 to 256 { %v0 = load %b[%i0, %i1] : memref<4x256xf32> } - affine.for %i2 = 0 to 256 { + for %i2 = 0 to 256 { store %cf0, %a[%i0, %i2] : memref<4x256xf32> } } - affine.for %d0 = 0 to 4 { - affine.for %d1 = 0 to 16 { + for %d0 = 0 to 4 { + for %d1 = 0 to 16 { %v1 = load %a[%d0, %d1] : memref<4x256xf32> } } @@ -1107,16 +1107,16 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // is reduced from the original shape from 4x256 to 4x16 because of the // data accessed by the load. 
// CHECK-DAG: %0 = alloc() : memref<1x16xf32> - // CHECK: affine.for %i0 = 0 to 4 { - // CHECK-NEXT: affine.for %i1 = 0 to 256 { + // CHECK: for %i0 = 0 to 4 { + // CHECK-NEXT: for %i1 = 0 to 256 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 16 { + // CHECK-NEXT: for %i2 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %cst, %0[%3, %4] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i3 = 0 to 16 { + // CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %7 = load %0[%5, %6] : memref<1x16xf32> @@ -1134,31 +1134,31 @@ func @should_fuse_at_depth1_with_trip_count_20() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - affine.for %i0 = 0 to 100 { + for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - affine.for %i1 = 0 to 5 { - affine.for %i2 = 0 to 10 { + for %i1 = 0 to 5 { + for %i2 = 0 to 10 { %v0 = load %a[%i2]: memref<100xf32> } - affine.for %i3 = 0 to 10 { - affine.for %i4 = 0 to 20 { + for %i3 = 0 to 10 { + for %i4 = 0 to 20 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 20xf32 // CHECK-DAG: %0 = alloc() : memref<20xf32> - // CHECK: affine.for %i0 = 0 to 5 { - // CHECK-NEXT: affine.for %i1 = 0 to 20 { + // CHECK: for %i0 = 0 to 5 { + // CHECK-NEXT: for %i1 = 0 to 20 { // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 10 { + // CHECK-NEXT: for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i3 = 0 to 10 { - // CHECK-NEXT: affine.for %i4 = 0 to 20 { + // CHECK-NEXT: for %i3 = 0 to 10 { + // CHECK-NEXT: for %i4 = 0 to 20 { // CHECK-NEXT: %2 = load %0[%i4] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1175,31 +1175,31 @@ func @should_fuse_at_depth1_with_trip_count_19() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - affine.for %i0 = 0 to 100 { + for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - affine.for %i1 = 0 to 5 { - affine.for %i2 = 0 to 19 { + for %i1 = 0 to 5 { + for %i2 = 0 to 19 { %v0 = load %a[%i2]: memref<100xf32> } - affine.for %i3 = 0 to 10 { - affine.for %i4 = 0 to 10 { + for %i3 = 0 to 10 { + for %i4 = 0 to 10 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 19xf32 // CHECK-DAG: %0 = alloc() : memref<19xf32> - // CHECK: affine.for %i0 = 0 to 5 { - // CHECK-NEXT: affine.for %i1 = 0 to 19 { + // CHECK: for %i0 = 0 to 5 { + // CHECK-NEXT: for %i1 = 0 to 19 { // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 19 { + // CHECK-NEXT: for %i2 = 0 to 19 { // CHECK-NEXT: %1 = load %0[%i2] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i3 = 0 to 10 { - // CHECK-NEXT: affine.for %i4 = 0 to 10 { + // CHECK-NEXT: for %i3 = 0 to 10 { + // CHECK-NEXT: for %i4 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i4] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1217,26 +1217,26 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { %m = alloc() : memref<100xf32> %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 100 { + for %i0 = 0 to 100 { store %cf7, %m[%i0] : memref<100xf32> } - affine.for %i1 = 0 to 17 { + for %i1 = 0 to 17 { %v0 = 
load %m[%i1] : memref<100xf32> } - affine.for %i2 = 0 to 82 { + for %i2 = 0 to 82 { %v1 = load %m[%i2] : memref<100xf32> } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: affine.for %i0 = 0 to 82 { + // CHECK: for %i0 = 0 to 82 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 17 { + // CHECK-NEXT: for %i1 = 0 to 17 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) @@ -1252,18 +1252,18 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %arg0[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %arg0[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0'. - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1276,19 +1276,19 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
// CHECK-DAG: %0 = alloc() : memref<10xf32> - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> @@ -1303,17 +1303,17 @@ func @R3_to_R2_reshape() { %c0 = constant 0 : index - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 3 { - affine.for %i2 = 0 to 16 { + for %i0 = 0 to 2 { + for %i1 = 0 to 3 { + for %i2 = 0 to 16 { %val = "foo"(%i0, %i1, %i2) : (index, index, index) -> i32 store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> } } } - affine.for %ii = 0 to 32 { - affine.for %jj = 0 to 3 { + for %ii = 0 to 32 { + for %jj = 0 to 3 { %a0 = affine.apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) %idx = affine.apply (d0) -> (d0 floordiv (3 * 16)) (%a0) %v = load %in[%idx, %jj, %c0] @@ -1332,8 +1332,8 @@ func @R3_to_R2_reshape() { // CHECK-LABEL: func @R3_to_R2_reshape() // CHECK-DAG: %0 = alloc() : memref<1x1x1xi32> -// CHECK: affine.for %i0 = 0 to 32 { -// CHECK-NEXT: affine.for %i1 = 0 to 3 { +// CHECK: for %i0 = 0 to 32 { +// CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]]()[%c0] // CHECK-NEXT: %3 = "foo"(%1, %i1, %2) : (index, index, index) -> i32 @@ -1360,19 +1360,19 @@ func @should_not_fuse_multi_output_producer() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> } - // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1389,30 +1389,30 @@ func @fusion_preventing_deps_on_middle_loop() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %v2, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%b', because of the WAR dep from '%i0' to '%i1' on memref '%a' and // because of the WAR dep from '%i1' to '%i2' on memref '%c'. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %3, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 10 { + // CHECK-NEXT: for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %5, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1432,17 +1432,17 @@ func @should_fuse_and_move_to_preserve_war_dep() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %v0, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 3 { + for %i1 = 0 to 3 { %v2 = load %c[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 5 { + for %i2 = 0 to 5 { store %cf7, %b[%i2] : memref<10xf32> } - affine.for %i3 = 0 to 10 { + for %i3 = 0 to 10 { %v1 = load %a[%i3] : memref<10xf32> store %cf7, %c[%i3] : memref<10xf32> } @@ -1461,10 +1461,10 @@ func @should_fuse_and_move_to_preserve_war_dep() { // if the fused loop nest is inserted between loops '%i1' and '%i2'. // CHECK-DAG: %0 = alloc() : memref<1xf32> - // CHECK: affine.for %i0 = 0 to 3 { + // CHECK: for %i0 = 0 to 3 { // CHECK-NEXT: %3 = load %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %4 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %4, %0[%5] : memref<1xf32> @@ -1472,7 +1472,7 @@ func @should_fuse_and_move_to_preserve_war_dep() { // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> // CHECK-NEXT: store %cst, %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 5 { + // CHECK-NEXT: for %i2 = 0 to 5 { // CHECK-NEXT: store %cst, %1[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1489,30 +1489,30 @@ func @fusion_preventing_dep_on_constant() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } %cf11 = constant 11.0 : f32 - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%a', because of the WAR dep from '%i0' to '%i1' on memref '%b' and // because of the SSA value dep from '%cf11' def to use in '%i2'. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: affine.for %i2 = 0 to 10 { + // CHECK-NEXT: for %i2 = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: store %cst_0, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1532,14 +1532,14 @@ func @should_fuse_and_preserve_dep_on_constant() { %cf7 = constant 7.0 : f32 %cf11 = constant 11.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } @@ -1549,7 +1549,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // the SSA value dep from '%cf11' def to use in '%i2'. // CHECK: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%4] : memref<1xf32> @@ -1557,7 +1557,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> // CHECK-NEXT: store %cst_0, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1575,25 +1575,25 @@ func @should_fuse_and_preserve_dep_on_constant() { func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) { %out = alloc() : memref<64x4xf32> %0 = constant 0.0 : f32 - affine.for %i0 = 0 to 64 { - affine.for %i1 = 0 to 4 { + for %i0 = 0 to 64 { + for %i1 = 0 to 4 { store %0, %out[%i0, %i1] : memref<64x4xf32> } } - affine.for %i2 = 0 to 4 { - affine.for %i3 = 0 to 4 { - affine.for %i4 = 0 to 16 { + for %i2 = 0 to 4 { + for %i3 = 0 to 4 { + for %i4 = 0 to 16 { %1 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i3, %i4) %2 = load %arg1[%1, %i2] : memref<64x4xf32> "op0"(%2) : (f32) -> () } - affine.for %i5 = 0 to 4 { - affine.for %i6 = 0 to 16 { + for %i5 = 0 to 4 { + for %i6 = 0 to 16 { %3 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i5, %i6) %4 = load %arg0[%3, %i3] : memref<64x4xf32> "op1"(%4) : (f32) -> () } - affine.for %i7 = 0 to 16 { + for %i7 = 0 to 16 { %5 = "op2"() : () -> (f32) %6 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i5, %i7) %7 = load %out[%6, %i2] : memref<64x4xf32> @@ -1613,25 +1613,25 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // memref size can be reduced to 128x1xf32. 
// CHECK: %0 = alloc() : memref<64x1xf32> - // CHECK: affine.for %i0 = 0 to 4 { - // CHECK-NEXT: affine.for %i1 = 0 to 64 { + // CHECK: for %i0 = 0 to 4 { + // CHECK-NEXT: for %i1 = 0 to 64 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0) // CHECK-NEXT: store %cst, %0[%1, %2] : memref<64x1xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i2 = 0 to 4 { - // CHECK-NEXT: affine.for %i3 = 0 to 16 { + // CHECK-NEXT: for %i2 = 0 to 4 { + // CHECK-NEXT: for %i3 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i2, %i3) // CHECK-NEXT: %4 = load %arg1[%3, %i0] : memref<64x4xf32> // CHECK-NEXT: "op0"(%4) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i4 = 0 to 4 { - // CHECK-NEXT: affine.for %i5 = 0 to 16 { + // CHECK-NEXT: for %i4 = 0 to 4 { + // CHECK-NEXT: for %i5 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i4, %i5) // CHECK-NEXT: %6 = load %arg0[%5, %i2] : memref<64x4xf32> // CHECK-NEXT: "op1"(%6) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i6 = 0 to 16 { + // CHECK-NEXT: for %i6 = 0 to 16 { // CHECK-NEXT: %7 = "op2"() : () -> f32 // CHECK-NEXT: %8 = affine.apply [[MAP3]](%i4, %i6) // CHECK-NEXT: %9 = affine.apply [[MAP0]](%i0, %8, %i0) @@ -1660,14 +1660,14 @@ func @should_fuse_after_private_memref_creation() { %cf7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %v0, %b[%i1] : memref<10xf32> } - affine.for %i2 = 0 to 10 { + for %i2 = 0 to 10 { %v1 = load %a[%i2] : memref<10xf32> store %v1, %b[%i2] : memref<10xf32> } @@ -1678,14 +1678,14 @@ func @should_fuse_after_private_memref_creation() { // private memref, the dependence between '%i0' and '%i1' on memref '%a' no // longer exists, so '%i0' can now be fused into '%i2'. 
- // CHECK: affine.for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i1, %i1) diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index a1f9d717fab..c2fdbd4f80f 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -8,12 +8,12 @@ // CHECK-DAG: [[UB_INTRA_TILE:#map[0-9]+]] = (d0, d1, d2) -> (d2 + 32, s0, 4096 floordiv s1) // CHECK-LABEL: func @loop_tiling() -// CHECK-NEXT: affine.for %i0 = 0 to 256 step 32 { -// CHECK-NEXT: affine.for %i1 = 0 to 512 step 32 { -// CHECK-NEXT: affine.for %i2 = 0 to 1024 step 32 { -// CHECK-NEXT: affine.for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { -// CHECK-NEXT: affine.for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { -// CHECK-NEXT: affine.for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { +// CHECK-NEXT: for %i0 = 0 to 256 step 32 { +// CHECK-NEXT: for %i1 = 0 to 512 step 32 { +// CHECK-NEXT: for %i2 = 0 to 1024 step 32 { +// CHECK-NEXT: for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { +// CHECK-NEXT: for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { +// CHECK-NEXT: for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { // CHECK-NEXT: "foo"(%i3, %i4, %i5) : (index, index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -21,32 +21,32 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i6 = 0 to 50 step 32 { -// CHECK-NEXT: affine.for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { +// CHECK-NEXT: for %i6 = 0 to 50 step 32 { +// CHECK-NEXT: for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { // CHECK-NEXT: "bar"(%i7, %i7) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: affine.for %i8 = 0 to 21 step 32 { -// CHECK-NEXT: affine.for %i9 = [[IDENTITY]](%i8) to 21 { +// CHECK-NEXT: for %i8 = 0 to 21 step 32 { +// CHECK-NEXT: for %i9 = [[IDENTITY]](%i8) to 21 { // CHECK-NEXT: "foobar"(%i9) : (index) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return func @loop_tiling() { - affine.for %i = 0 to 256 { - affine.for %j = 0 to 512 { - affine.for %k = 0 to 1024 { + for %i = 0 to 256 { + for %j = 0 to 512 { + for %k = 0 to 1024 { "foo"(%i, %j, %k) : (index, index, index) -> () } } } - affine.for %x = 0 to 50 { + for %x = 0 to 50 { "bar"(%x, %x) : (index, index) -> () } // Intra-tile loop won't need a min expression. 
- affine.for %y = 0 to 21 { + for %y = 0 to 21 { "foobar"(%y) : (index) -> () } @@ -58,12 +58,12 @@ func @loop_tiling() { // CHECK-LABEL: func @loop_max_min_bound(%arg0: memref, %arg1: index, %arg2: index) { func @loop_max_min_bound(%A : memref, %L : index, %U : index) { %M = dim %A, 0 : memref - affine.for %iTT = max #lb()[%L] to min #ub()[%M, %U] { + for %iTT = max #lb()[%L] to min #ub()[%M, %U] { %out = affine.apply (d0) -> (d0) (%iTT) } return -// CHECK: affine.for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { -// CHECK-NEXT: affine.for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { +// CHECK: for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { +// CHECK-NEXT: for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { // CHECK-NEXT: %1 = affine.apply [[IDENTITY]](%i1) // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir index 5882da5c749..22e9f4b9fd4 100644 --- a/mlir/test/Transforms/lower-affine.mlir +++ b/mlir/test/Transforms/lower-affine.mlir @@ -24,7 +24,7 @@ func @body(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @simple_loop() { - affine.for %i = 1 to 42 { + for %i = 1 to 42 { call @body(%i) : (index) -> () } return @@ -65,9 +65,9 @@ func @post(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @imperfectly_nested_loops() { - affine.for %i = 0 to 42 { + for %i = 0 to 42 { call @pre(%i) : (index) -> () - affine.for %j = 7 to 56 step 2 { + for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @post(%i) : (index) -> () @@ -122,13 +122,13 @@ func @body3(index, index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @more_imperfectly_nested_loops() { - affine.for %i = 0 to 42 { + for %i = 0 to 42 { call @pre(%i) : (index) -> () - affine.for %j = 7 to 56 step 2 { + for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @mid(%i) : (index) -> () - affine.for %k = 18 to 37 step 3 { + for %k = 18 to 37 step 3 { call @body3(%i, %k) : (index, index) -> () } call @post(%i) : (index) -> () @@ -161,8 +161,8 @@ func @more_imperfectly_nested_loops() { // CHECK-NEXT: return // CHECK-NEXT: } func @affine_apply_loops_shorthand(%N : index) { - affine.for %i = 0 to %N { - affine.for %j = %i to 42 { + for %i = 0 to %N { + for %j = %i to 42 { call @body2(%i, %j) : (index, index) -> () } } @@ -360,7 +360,7 @@ func @if_for() { // CHECK-NEXT: [[outerEndBB]]: // CHECK-NEXT: br [[outerLoopInit:\^bb[0-9]+]] affine.if #set1(%i) { - affine.for %j = 0 to 42 { + for %j = 0 to 42 { affine.if #set2(%j) { call @body2(%i, %j) : (index, index) -> () } @@ -397,9 +397,9 @@ func @if_for() { // CHECK-NEXT: %c1_9 = constant 1 : index // CHECK-NEXT: %16 = addi %9, %c1_9 : index // CHECK-NEXT: br [[outerLoopCond]](%16 : index) - affine.for %k = 0 to 42 { + for %k = 0 to 42 { affine.if #set2(%k) { - affine.for %l = 0 to 42 { + for %l = 0 to 42 { call @body3(%k, %l) : (index, index) -> () } } @@ -446,8 +446,8 @@ func @if_for() { // CHECK-NEXT: return // CHECK-NEXT: } func @loop_min_max(%N : index) { - affine.for %i = 0 to 42 { - affine.for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { + for %i = 0 to 42 { + for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { call @body2(%i, %j) : (index, index) -> () } } @@ -486,7 +486,7 @@ func @loop_min_max(%N : index) { // CHECK-NEXT: return // CHECK-NEXT: } func @min_reduction_tree(%v : index) { - affine.for %i = 0 to min #map_7_values(%v)[] { + for %i = 0 to min 
#map_7_values(%v)[] { call @body(%i) : (index) -> () } return diff --git a/mlir/test/Transforms/memref-bound-check.mlir b/mlir/test/Transforms/memref-bound-check.mlir index b3d5b23e70f..2926bf1afbc 100644 --- a/mlir/test/Transforms/memref-bound-check.mlir +++ b/mlir/test/Transforms/memref-bound-check.mlir @@ -11,8 +11,8 @@ func @test() { %A = alloc() : memref<9 x 9 x i32> %B = alloc() : memref<111 x i32> - affine.for %i = -1 to 10 { - affine.for %j = -1 to 10 { + for %i = -1 to 10 { + for %j = -1 to 10 { %idx0 = affine.apply (d0, d1) -> (d0)(%i, %j) %idx1 = affine.apply (d0, d1) -> (d1)(%i, %j) // Out of bound access. @@ -27,7 +27,7 @@ func @test() { } } - affine.for %k = 0 to 10 { + for %k = 0 to 10 { // In bound. %u = load %B[%zero] : memref<111 x i32> // Out of bounds. @@ -43,8 +43,8 @@ func @test_mod_floordiv_ceildiv() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - affine.for %i = 0 to 256 { - affine.for %j = 0 to 256 { + for %i = 0 to 256 { + for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -69,8 +69,8 @@ func @test_no_out_of_bounds() { %C = alloc() : memref<257 x i32> %B = alloc() : memref<1 x i32> - affine.for %i = 0 to 256 { - affine.for %j = 0 to 256 { + for %i = 0 to 256 { + for %j = 0 to 256 { // All of these accesses are in bound; check that no errors are emitted. // CHECK: %3 = affine.apply {{#map.*}}(%i0, %i1) // CHECK-NEXT: %4 = load %0[%3, %c0] : memref<257x256xi32> @@ -93,8 +93,8 @@ func @mod_div() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - affine.for %i = 0 to 256 { - affine.for %j = 0 to 256 { + for %i = 0 to 256 { + for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -115,8 +115,8 @@ func @mod_div() { // CHECK-LABEL: func @mod_floordiv_nested() { func @mod_floordiv_nested() { %A = alloc() : memref<256 x 256 x i32> - affine.for %i = 0 to 256 { - affine.for %j = 0 to 256 { + for %i = 0 to 256 { + for %j = 0 to 256 { %idx0 = affine.apply (d0, d1) -> ((d0 mod 1024) floordiv 4)(%i, %j) %idx1 = affine.apply (d0, d1) -> ((((d1 mod 128) mod 32) ceildiv 4) * 32)(%i, %j) load %A[%idx0, %idx1] : memref<256 x 256 x i32> // expected-error {{'load' op memref out of upper bound access along dimension #2}} @@ -128,7 +128,7 @@ func @mod_floordiv_nested() { // CHECK-LABEL: func @test_semi_affine_bailout func @test_semi_affine_bailout(%N : index) { %B = alloc() : memref<10 x i32> - affine.for %i = 0 to 10 { + for %i = 0 to 10 { %idx = affine.apply (d0)[s0] -> (d0 * s0)(%i)[%N] %y = load %B[%idx] : memref<10 x i32> } @@ -138,7 +138,7 @@ func @test_semi_affine_bailout(%N : index) { // CHECK-LABEL: func @multi_mod_floordiv func @multi_mod_floordiv() { %A = alloc() : memref<2x2xi32> - affine.for %ii = 0 to 64 { + for %ii = 0 to 64 { %idx0 = affine.apply (d0) -> ((d0 mod 147456) floordiv 1152) (%ii) %idx1 = affine.apply (d0) -> (((d0 mod 147456) mod 1152) floordiv 384) (%ii) %v = load %A[%idx0, %idx1] : memref<2x2xi32> @@ -153,8 +153,8 @@ func @delinearize_mod_floordiv() { %out = alloc() : memref<64x9xi32> // Reshape '%in' into '%out'. 
- affine.for %ii = 0 to 64 { - affine.for %jj = 0 to 9 { + for %ii = 0 to 64 { + for %jj = 0 to 9 { %a0 = affine.apply (d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) @@ -189,7 +189,7 @@ func @out_of_bounds() { %in = alloc() : memref<1xi32> %c9 = constant 9 : i32 - affine.for %i0 = 10 to 11 { + for %i0 = 10 to 11 { %idy = affine.apply (d0) -> (100 * d0 floordiv 1000) (%i0) store %c9, %in[%idy] : memref<1xi32> // expected-error {{'store' op memref out of upper bound access along dimension #1}} } diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index ed39d71eefd..710d14c1cf9 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -10,14 +10,14 @@ func @simple_store_load() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: return @@ -30,7 +30,7 @@ func @multi_store_load() { %cf8 = constant 8.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -45,7 +45,7 @@ func @multi_store_load() { // CHECK-NEXT: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %cst_1 = constant 9.000000e+00 : f32 -// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: %1 = mulf %cst_1, %cst_1 : f32 // CHECK-NEXT: } @@ -59,8 +59,8 @@ func @multi_store_load() { func @store_load_affine_apply() -> memref<10x10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10x10xf32> - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { %t0 = affine.apply (d0, d1) -> (d1 + 1)(%i0, %i1) %t1 = affine.apply (d0, d1) -> (d0)(%i0, %i1) %idx0 = affine.apply (d0, d1) -> (d1) (%t0, %t1) @@ -75,8 +75,8 @@ func @store_load_affine_apply() -> memref<10x10xf32> { return %m : memref<10x10xf32> // CHECK: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %0 = alloc() : memref<10x10xf32> -// CHECK-NEXT: affine.for %i0 = 0 to 10 { -// CHECK-NEXT: affine.for %i1 = 0 to 10 { +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%1, %2) @@ -92,17 +92,17 @@ func @store_load_affine_apply() -> memref<10x10xf32> { func @store_load_nested(%N : index) { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to %N { + for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: affine.for %i0 = 0 to 10 { -// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { +// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: for %i1 = 0 to %arg0 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: } @@ -117,12 +117,12 @@ 
func @multi_store_load_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf8 = constant 8.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to %N { + for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - affine.for %i2 = 0 to %N { + for %i2 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -138,9 +138,9 @@ func @store_load_store_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to %N { + for %i1 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -159,16 +159,16 @@ func @multi_store_load_nested_fwd(%N : index) { %cf9 = constant 9.0 : f32 %cf10 = constant 10.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to %N { + for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - affine.for %i2 = 0 to %N { + for %i2 = 0 to %N { store %cf9, %m[%i2] : memref<10xf32> } store %cf10, %m[%i0] : memref<10xf32> - affine.for %i3 = 0 to %N { + for %i3 = 0 to %N { // CHECK-NOT: %{{[0-9]+}} = load %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -182,10 +182,10 @@ func @multi_store_load_nested_fwd(%N : index) { func @store_load_no_fwd() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to 10 { - affine.for %i2 = 0 to 10 { + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { // CHECK: load %{{[0-9]+}} %v0 = load %m[%i2] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -202,9 +202,9 @@ func @store_load_fwd() { %c0 = constant 0 : index %m = alloc() : memref<10xf32> store %cf7, %m[%c0] : memref<10xf32> - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { - affine.for %i2 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + for %i2 = 0 to 10 { // CHECK-NOT: load %{{[0-9]}}+ %v0 = load %m[%c0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -223,9 +223,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %c0 = constant 0 : index %c1 = constant 1 : index %m = alloc() : memref<10xf32> - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - affine.for %i1 = 0 to %N { + for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 %idx = affine.apply (d0) -> (d0 + 1) (%i0) @@ -236,9 +236,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %v3 = load %m[%c1] : memref<10xf32> return %v3 : f32 // CHECK: %0 = alloc() : memref<10xf32> -// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> -// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { +// CHECK-NEXT: for %i1 = 0 to %arg0 { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: %2 = affine.apply [[MAP4]](%i0) // CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32> diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 00d0e730098..0accc30630b 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -13,14 +13,14 @@ func 
@store_may_execute_before_load() { // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. affine.if #set0(%c0) { - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -37,13 +37,13 @@ func @dependent_loops() { %cst = constant 7.000000e+00 : f32 // There is a dependence from 0 to 1 at depth 1 (common surrounding loops 0) // because the first loop with the store dominates the second loop. - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } - affine.for %i1 = 0 to 10 { + for %i1 = 0 to 10 { %1 = load %0[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -231,7 +231,7 @@ func @store_range_load_after_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -254,7 +254,7 @@ func @store_load_func_symbol(%arg0: index, %arg1: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - affine.for %i0 = 0 to %arg1 { + for %i0 = 0 to %arg1 { %a0 = affine.apply (d0) -> (d0) (%arg0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, +inf]}} @@ -277,7 +277,7 @@ func @store_range_load_last_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) // For dependence from 0 to 1, we do not have a loop carried dependence // because only the final write in the loop accesses the same element as the @@ -305,7 +305,7 @@ func @store_range_load_before_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - affine.for %i0 = 1 to 11 { + for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -328,7 +328,7 @@ func @store_range_load_first_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - affine.for %i0 = 1 to 11 { + for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) // Dependence from 0 to 1 at depth 1 is a range because all loads at // constant index zero are reads after first store at index zero during @@ -353,7 +353,7 @@ func @store_range_load_first_in_range() { func @store_plus_3() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 1 to 11 { + for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0 + 3) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = 
false}} @@ -375,7 +375,7 @@ func @store_plus_3() { func @load_minus_2() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 2 to 11 { + for %i0 = 2 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -397,8 +397,8 @@ func @load_minus_2() { func @perfectly_nested_loops_loop_independent() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 11 { - affine.for %i1 = 0 to 11 { + for %i0 = 0 to 11 { + for %i1 = 0 to 11 { // Dependence from access 0 to 1 is loop independent at depth = 3. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -428,8 +428,8 @@ func @perfectly_nested_loops_loop_independent() { func @perfectly_nested_loops_loop_carried_at_depth1() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 9 { - affine.for %i1 = 0 to 9 { + for %i0 = 0 to 9 { + for %i1 = 0 to 9 { // Dependence from access 0 to 1 is loop carried at depth 1. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -459,8 +459,8 @@ func @perfectly_nested_loops_loop_carried_at_depth1() { func @perfectly_nested_loops_loop_carried_at_depth2() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { // Dependence from access 0 to 1 is loop carried at depth 2. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -491,8 +491,8 @@ func @one_common_loop() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 // There is a loop-independent dependence from access 0 to 1 at depth 2. - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) store %c7, %m[%a00, %a01] : memref<10x10xf32> @@ -502,7 +502,7 @@ func @one_common_loop() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} // expected-note@-5 {{dependence from 0 to 1 at depth 2 = true}} } - affine.for %i2 = 0 to 9 { + for %i2 = 0 to 9 { %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i2) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i2) %v0 = load %m[%a10, %a11] : memref<10x10xf32> @@ -525,7 +525,7 @@ func @dependence_cycle() { // Dependences: // *) loop-independent dependence from access 1 to 2 at depth 2. // *) loop-carried dependence from access 3 to 0 at depth 1. 
- affine.for %i0 = 0 to 9 { + for %i0 = 0 to 9 { %a0 = affine.apply (d0) -> (d0) (%i0) %v0 = load %m.a[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -575,8 +575,8 @@ func @dependence_cycle() { func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to %arg0 { - affine.for %i1 = 0 to %arg1 { + for %i0 = 0 to %arg0 { + for %i1 = 0 to %arg1 { %a00 = affine.apply (d0, d1) -> (d0 - 1) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1 + 1) (%i0, %i1) %v0 = load %m[%a00, %a01] : memref<10x10xf32> @@ -605,8 +605,8 @@ func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { func @war_raw_waw_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { - affine.for %i1 = 0 to 10 { + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 + 1) (%i1) %v0 = load %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -633,7 +633,7 @@ func @war_raw_waw_deps() { func @mod_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 10 { + for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 mod 2) (%i0) // Results are conservative here since we currently don't have a way to // represent strided sets in FlatAffineConstraints. @@ -658,8 +658,8 @@ func @loop_nest_depth() { %0 = alloc() : memref<100x100xf32> %c7 = constant 7.0 : f32 - affine.for %i0 = 0 to 128 { - affine.for %i1 = 0 to 8 { + for %i0 = 0 to 128 { + for %i1 = 0 to 8 { store %c7, %0[%i0, %i1] : memref<100x100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -667,10 +667,10 @@ func @loop_nest_depth() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = true}} } } - affine.for %i2 = 0 to 8 { - affine.for %i3 = 0 to 8 { - affine.for %i4 = 0 to 8 { - affine.for %i5 = 0 to 16 { + for %i2 = 0 to 8 { + for %i3 = 0 to 8 { + for %i4 = 0 to 8 { + for %i5 = 0 to 16 { %8 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i4, %i5) %9 = load %0[%8, %i3] : memref<100x100xf32> // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} @@ -693,9 +693,9 @@ func @loop_nest_depth() { func @mod_div_3d() { %M = alloc() : memref<2x2x2xi32> %c0 = constant 0 : i32 - affine.for %i0 = 0 to 8 { - affine.for %i1 = 0 to 8 { - affine.for %i2 = 0 to 8 { + for %i0 = 0 to 8 { + for %i1 = 0 to 8 { + for %i2 = 0 to 8 { %idx0 = affine.apply (d0, d1, d2) -> (d0 floordiv 4) (%i0, %i1, %i2) %idx1 = affine.apply (d0, d1, d2) -> (d1 mod 2) (%i0, %i1, %i2) %idx2 = affine.apply (d0, d1, d2) -> (d2 floordiv 4) (%i0, %i1, %i2) @@ -719,12 +719,12 @@ func @delinearize_mod_floordiv() { %in = alloc() : memref<2x2x3x3x16x1xi32> %out = alloc() : memref<64x9xi32> - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 2 { - affine.for %i2 = 0 to 3 { - affine.for %i3 = 0 to 3 { - affine.for %i4 = 0 to 16 { - affine.for %i5 = 0 to 1 { + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + for %i4 = 0 to 16 { + for %i5 = 0 to 1 { store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -742,8 +742,8 @@ func @delinearize_mod_floordiv() { } } - affine.for %ii = 0 to 64 { - affine.for %jj = 0 to 9 { + for %ii = 0 to 64 { + for %jj = 0 to 9 { %a0 = affine.apply 
(d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index ede5c63fbac..30f98db2583 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -16,13 +16,13 @@ func @loop_nest_dma() { %zero = constant 0 : index %num_elts = constant 128 : index - affine.for %i = 0 to 8 { + for %i = 0 to 8 { dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> dma_wait %tag[%zero], %num_elts : memref<1 x f32> %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> - affine.for %j = 0 to 128 { + for %j = 0 to 128 { "do_more_compute"(%i, %j) : (index, index) -> () } } @@ -34,7 +34,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %3 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: %4 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: dma_start %0[%c0], %1[%3, %c0], %c128, %2[%4, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> -// CHECK-NEXT: affine.for %i0 = 1 to 8 { +// CHECK-NEXT: for %i0 = 1 to 8 { // CHECK-NEXT: %5 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: %6 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: dma_start %0[%i0], %1[%5, %i0], %c128, %2[%6, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> @@ -45,7 +45,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1> // CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32 // CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1> -// CHECK-NEXT: affine.for %i1 = 0 to 128 { +// CHECK-NEXT: for %i1 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -56,7 +56,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1> // CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32 // CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1> -// CHECK-NEXT: affine.for %i2 = 0 to 128 { +// CHECK-NEXT: for %i2 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: return @@ -68,7 +68,7 @@ func @loop_step(%arg0: memref<512xf32>, %arg1: memref<512xf32>) { %c0 = constant 0 : index %c4 = constant 4 : index - affine.for %i0 = 0 to 512 step 4 { + for %i0 = 0 to 512 step 4 { %1 = alloc() : memref<4xf32, 1> %2 = alloc() : memref<1xi32> dma_start %arg0[%i0], %1[%c0], %c4, %2[%c0] @@ -82,7 +82,7 @@ func @loop_step(%arg0: memref<512xf32>, // CHECK: %2 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK: %3 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK-NEXT: dma_start %arg0[%c0], %0[%2, %c0_0], %c4, [[TAG]][%3, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> -// CHECK-NEXT: affine.for %i0 = 4 to 512 step 4 { +// CHECK-NEXT: for %i0 = 4 to 512 step 4 { // CHECK-NEXT: %4 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: %5 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: dma_start %arg0[%i0], %0[%4, %c0_0], %c4, [[TAG]][%5, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> @@ -114,8 +114,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // Prologue for DMA overlap on arg2. 
// CHECK:[[TAG_ARG2:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg2[ - // CHECK: affine.for %i0 = 1 to 8 { - affine.for %i0 = 0 to 8 { + // CHECK: for %i0 = 1 to 8 { + for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -127,8 +127,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK-NEXT affine.for %i1 = 1 to 8 { - affine.for %i1 = 0 to 8 { + // CHECK-NEXT for %i1 = 1 to 8 { + for %i1 = 0 to 8 { %7 = affine.apply #map1(%i0, %i1) %8 = affine.apply #map2(%i1) dma_start %arg0[%7, %c0], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> @@ -140,8 +140,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0]] // CHECK: dma_wait [[TAG_ARG1]] - // CHECK-NEXT: affine.for %i2 = 0 to 4 { - affine.for %i2 = 0 to 4 { + // CHECK-NEXT: for %i2 = 0 to 4 { + for %i2 = 0 to 4 { "foo"() : () -> () } } @@ -155,16 +155,16 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1_NESTED:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK: affine.for %i4 = 1 to 8 { + // CHECK: for %i4 = 1 to 8 { // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: affine.for %i5 = 0 to 4 { + // CHECK: for %i5 = 0 to 4 { // CHECK: "foo"() : () -> () // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: affine.for %i6 = 0 to 4 { + // CHECK: for %i6 = 0 to 4 { } return // CHECK: } @@ -185,8 +185,8 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) { // The two DMAs below are dependent (incoming and outgoing on the same // memref) in the same iteration; so no pipelining here. // CHECK-NOT: dma_start - // CHECK: affine.for %i0 = 0 to 8 { - affine.for %i0 = 0 to 8 { + // CHECK: for %i0 = 0 to 8 { + for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -206,8 +206,8 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: affine.for %i0 = 0 to 16 { - affine.for %kTT = 0 to 16 { + // CHECK: for %i0 = 0 to 16 { + for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> @@ -230,14 +230,14 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: affine.for %i0 = 0 to 16 { - affine.for %kTT = 0 to 16 { + // CHECK: for %i0 = 0 to 16 { + for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } - // Use live out of 'affine.for' inst; no DMA pipelining will be done. 
+ // Use live out of 'for' inst; no DMA pipelining will be done. %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> return %v : f32 // CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> @@ -261,14 +261,14 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { // CHECK: %5 = affine.apply [[MOD_2]](%c0) // CHECK: %6 = affine.apply [[MOD_2]](%c0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%5, %c0_0, %c0_0], %c512, %4[%6, %c0_0] - affine.for %kTT = 0 to 16 { + for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } return -// CHECK-NEXT: affine.for %i0 = 1 to 16 { +// CHECK-NEXT: for %i0 = 1 to 16 { // CHECK: %7 = affine.apply [[MOD_2]](%i0) // CHECK: %8 = affine.apply [[MOD_2]](%i0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%7, %c0_0, %c0_0], %c512, %4[%8, %c0_0] diff --git a/mlir/test/Transforms/simplify-affine-structures.mlir b/mlir/test/Transforms/simplify-affine-structures.mlir index feb3a99b70b..2459604f369 100644 --- a/mlir/test/Transforms/simplify-affine-structures.mlir +++ b/mlir/test/Transforms/simplify-affine-structures.mlir @@ -73,8 +73,8 @@ // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() { func @test_gaussian_elimination_empty_set0() { - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (2 == 0)(%i0, %i1) { } @@ -85,8 +85,8 @@ func @test_gaussian_elimination_empty_set0() { // CHECK-LABEL: func @test_gaussian_elimination_empty_set1() { func @test_gaussian_elimination_empty_set1() { - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (1 >= 0, -1 >= 0) (%i0, %i1) { } @@ -97,8 +97,8 @@ func @test_gaussian_elimination_empty_set1() { // CHECK-LABEL: func @test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_non_empty_set2() { - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: #set1(%i0, %i1) affine.if #set2(%i0, %i1) { } @@ -111,8 +111,8 @@ func @test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_empty_set3() { %c7 = constant 7 : index %c11 = constant 11 : index - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] affine.if #set3(%i0, %i1)[%c7, %c11] { } @@ -125,8 +125,8 @@ func @test_gaussian_elimination_empty_set3() { func @test_gaussian_elimination_non_empty_set4() { %c7 = constant 7 : index %c11 = constant 11 : index - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: #set3(%i0, %i1)[%c7, %c11] affine.if #set4(%i0, %i1)[%c7, %c11] { } @@ -139,8 +139,8 @@ func @test_gaussian_elimination_non_empty_set4() { func @test_gaussian_elimination_empty_set5() { %c7 = constant 7 : index %c11 = constant 11 : index - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] affine.if #set5(%i0, %i1)[%c7, %c11] { } @@ -151,8 +151,8 @@ func @test_gaussian_elimination_empty_set5() { // CHECK-LABEL: func @test_fuzz_explosion func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) { - affine.for %i0 = 1 to 10 { - affine.for %i1 = 1 
to 100 { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { affine.if #set_fuzz_virus(%i0, %i1, %arg0, %arg1, %arg2, %arg3) { } } @@ -163,8 +163,8 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i // CHECK-LABEL: func @test_empty_set(%arg0: index) { func @test_empty_set(%N : index) { - affine.for %i = 0 to 10 { - affine.for %j = 0 to 10 { + for %i = 0 to 10 { + for %j = 0 to 10 { // CHECK: affine.if [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)(%i, %j) { "foo"() : () -> () @@ -198,8 +198,8 @@ func @test_empty_set(%N : index) { } } // The tests below test GCDTightenInequalities(). - affine.for %k = 0 to 10 { - affine.for %l = 0 to 10 { + for %k = 0 to 10 { + for %l = 0 to 10 { // Empty because no multiple of 8 lies between 4 and 7. // CHECK: affine.if [[SET_EMPTY_1D]](%i2) affine.if (d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)(%k) { @@ -226,7 +226,7 @@ func @test_empty_set(%N : index) { } } - affine.for %m = 0 to 10 { + for %m = 0 to 10 { // CHECK: affine.if [[SET_EMPTY_1D]](%i{{[0-9]+}}) affine.if (d0) : (d0 mod 2 - 3 == 0) (%m) { "foo"() : () -> () diff --git a/mlir/test/Transforms/strip-debuginfo.mlir b/mlir/test/Transforms/strip-debuginfo.mlir index 181481279d0..fdabd5d12e0 100644 --- a/mlir/test/Transforms/strip-debuginfo.mlir +++ b/mlir/test/Transforms/strip-debuginfo.mlir @@ -10,7 +10,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %1 = "foo"() : () -> i32 loc("foo") // CHECK: } loc(unknown) - affine.for %i0 = 0 to 8 { + for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(unknown) diff --git a/mlir/test/Transforms/unroll-jam.mlir b/mlir/test/Transforms/unroll-jam.mlir index 98d284aeede..da4f965676f 100644 --- a/mlir/test/Transforms/unroll-jam.mlir +++ b/mlir/test/Transforms/unroll-jam.mlir @@ -7,13 +7,13 @@ // CHECK-LABEL: func @unroll_jam_imperfect_nest() { func @unroll_jam_imperfect_nest() { // CHECK: %c100 = constant 100 : index - // CHECK-NEXT: affine.for %i0 = 0 to 99 step 2 { - affine.for %i = 0 to 101 { + // CHECK-NEXT: for %i0 = 0 to 99 step 2 { + for %i = 0 to 101 { // CHECK: %0 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %1 = affine.apply [[MAP_PLUS_1]](%i0) // CHECK-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 %x = "addi32"(%i, %i) : (index, index) -> i32 - affine.for %j = 0 to 17 { + for %j = 0 to 17 { // CHECK: %3 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32 // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_1]](%i0) @@ -29,7 +29,7 @@ func @unroll_jam_imperfect_nest() { } // CHECK } // cleanup loop (single iteration) // CHECK: %11 = "addi32"(%c100, %c100) : (index, index) -> i32 - // CHECK-NEXT: affine.for %i2 = 0 to 17 { + // CHECK-NEXT: for %i2 = 0 to 17 { // CHECK-NEXT: %12 = "addi32"(%c100, %c100) : (index, index) -> i32 // CHECK-NEXT: %13 = "addi32"(%12, %12) : (i32, i32) -> i32 // CHECK-NEXT: } @@ -39,8 +39,8 @@ func @unroll_jam_imperfect_nest() { // UNROLL-BY-4-LABEL: func @loop_nest_unknown_count_1(%arg0: index) { func @loop_nest_unknown_count_1(%N : index) { - // UNROLL-BY-4-NEXT: affine.for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 { - // UNROLL-BY-4-NEXT: affine.for %i1 = 1 to 100 { + // UNROLL-BY-4-NEXT: for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -48,14 +48,14 @@ func @loop_nest_unknown_count_1(%N : index) { 
// UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // A cleanup loop should be generated here. - // UNROLL-BY-4-NEXT: affine.for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { - // UNROLL-BY-4-NEXT: affine.for %i3 = 1 to 100 { + // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { + // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 { // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // UNROLL-BY-4_NEXT: } // Specify the lower bound in a form so that both lb and ub operands match. - affine.for %i = ()[s0] -> (1)()[%N] to %N { - affine.for %j = 1 to 100 { + for %i = ()[s0] -> (1)()[%N] to %N { + for %j = 1 to 100 { %x = "foo"() : () -> i32 } } @@ -64,8 +64,8 @@ func @loop_nest_unknown_count_1(%N : index) { // UNROLL-BY-4-LABEL: func @loop_nest_unknown_count_2(%arg0: index) { func @loop_nest_unknown_count_2(%arg : index) { - // UNROLL-BY-4-NEXT: affine.for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 { - // UNROLL-BY-4-NEXT: affine.for %i1 = 1 to 100 { + // UNROLL-BY-4-NEXT: for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4-NEXT: for %i1 = 1 to 100 { // UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (index) -> i32 // UNROLL-BY-4-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (index) -> i32 @@ -77,12 +77,12 @@ func @loop_nest_unknown_count_2(%arg : index) { // UNROLL-BY-4-NEXT: } // The cleanup loop is a single iteration one and is promoted. // UNROLL-BY-4-NEXT: %7 = affine.apply [[M1:#map{{[0-9]+}}]]()[%arg0] - // UNROLL-BY-4-NEXT: affine.for %i3 = 1 to 100 { + // UNROLL-BY-4-NEXT: for %i3 = 1 to 100 { // UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // Specify the lower bound in a form so that both lb and ub operands match. - affine.for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] { - affine.for %j = 1 to 100 { + for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] { + for %j = 1 to 100 { %x = "foo"(%i) : (index) -> i32 } } diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index 013f65367cb..c023561faa8 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -46,13 +46,13 @@ // CHECK-LABEL: func @loop_nest_simplest() { func @loop_nest_simplest() { - // CHECK: affine.for %i0 = 0 to 100 step 2 { - affine.for %i = 0 to 100 step 2 { + // CHECK: for %i0 = 0 to 100 step 2 { + for %i = 0 to 100 step 2 { // CHECK: %c1_i32 = constant 1 : i32 // CHECK-NEXT: %c1_i32_0 = constant 1 : i32 // CHECK-NEXT: %c1_i32_1 = constant 1 : i32 // CHECK-NEXT: %c1_i32_2 = constant 1 : i32 - affine.for %j = 0 to 4 { + for %j = 0 to 4 { %x = constant 1 : i32 } } // CHECK: } @@ -62,8 +62,8 @@ func @loop_nest_simplest() { // CHECK-LABEL: func @loop_nest_simple_iv_use() { func @loop_nest_simple_iv_use() { // CHECK: %c0 = constant 0 : index - // CHECK-NEXT: affine.for %i0 = 0 to 100 step 2 { - affine.for %i = 0 to 100 step 2 { + // CHECK-NEXT: for %i0 = 0 to 100 step 2 { + for %i = 0 to 100 step 2 { // CHECK: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // CHECK: %1 = affine.apply [[MAP0]](%c0) // CHECK-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -71,7 +71,7 @@ func @loop_nest_simple_iv_use() { // CHECK-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> i32 // CHECK: %5 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 - affine.for %j = 0 to 4 { + for %j = 0 to 4 { %x = "addi32"(%j, %j) : (index, index) -> i32 } } // CHECK: } @@ -82,8 +82,8 @@ func @loop_nest_simple_iv_use() { // CHECK-LABEL: func 
@loop_nest_body_def_use() { func @loop_nest_body_def_use() { // CHECK: %c0 = constant 0 : index - // CHECK-NEXT: affine.for %i0 = 0 to 100 step 2 { - affine.for %i = 0 to 100 step 2 { + // CHECK-NEXT: for %i0 = 0 to 100 step 2 { + for %i = 0 to 100 step 2 { // CHECK: %c0_0 = constant 0 : index %c0 = constant 0 : index // CHECK: %0 = affine.apply [[MAP0]](%c0) @@ -97,7 +97,7 @@ func @loop_nest_body_def_use() { // CHECK-NEXT: %8 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %9 = affine.apply [[MAP0]](%8) // CHECK-NEXT: %10 = "addi32"(%9, %c0_0) : (index, index) -> index - affine.for %j = 0 to 4 { + for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %c0) : (index, index) -> index @@ -110,14 +110,14 @@ func @loop_nest_body_def_use() { func @loop_nest_strided() { // CHECK: %c2 = constant 2 : index // CHECK-NEXT: %c2_0 = constant 2 : index - // CHECK-NEXT: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + // CHECK-NEXT: for %i0 = 0 to 100 { + for %i = 0 to 100 { // CHECK: %0 = affine.apply [[MAP0]](%c2_0) // CHECK-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // CHECK-NEXT: %2 = affine.apply [[MAP1]](%c2_0) // CHECK-NEXT: %3 = affine.apply [[MAP0]](%2) // CHECK-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> index - affine.for %j = 2 to 6 step 2 { + for %j = 2 to 6 step 2 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -130,7 +130,7 @@ func @loop_nest_strided() { // CHECK-NEXT: %10 = affine.apply [[MAP3]](%c2) // CHECK-NEXT: %11 = affine.apply [[MAP0]](%10) // CHECK-NEXT: %12 = "addi32"(%11, %11) : (index, index) -> index - affine.for %k = 2 to 7 step 2 { + for %k = 2 to 7 step 2 { %z = "affine.apply" (%k) { map: (d0) -> (d0 + 1) } : (index) -> (index) %w = "addi32"(%z, %z) : (index, index) -> index @@ -142,8 +142,8 @@ func @loop_nest_strided() { // CHECK-LABEL: func @loop_nest_multiple_results() { func @loop_nest_multiple_results() { // CHECK: %c0 = constant 0 : index - // CHECK-NEXT: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + // CHECK-NEXT: for %i0 = 0 to 100 { + for %i = 0 to 100 { // CHECK: %0 = affine.apply [[MAP4]](%i0, %c0) // CHECK-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // CHECK-NEXT: %2 = affine.apply #map{{.*}}(%i0, %c0) @@ -153,7 +153,7 @@ func @loop_nest_multiple_results() { // CHECK-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> index // CHECK-NEXT: %7 = affine.apply #map{{.*}}(%i0, %4) // CHECK-NEXT: %8 = "fma"(%7, %5, %5) : (index, index, index) -> (index, index) - affine.for %j = 0 to 2 step 1 { + for %j = 0 to 2 step 1 { %x = affine.apply (d0, d1) -> (d0 + 1) (%i, %j) %y = "addi32"(%x, %x) : (index, index) -> index %z = affine.apply (d0, d1) -> (d0 + 3) (%i, %j) @@ -170,8 +170,8 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // CHECK: %c0 = constant 0 : index // CHECK-NEXT: %c128 = constant 128 : index %c128 = constant 128 : index - // CHECK: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + // CHECK: for %i0 = 0 to 100 { + for %i = 0 to 100 { // CHECK: %0 = "vld"(%i0) : (index) -> i32 %ld = "vld"(%i) : (index) -> i32 // CHECK: %1 = affine.apply [[MAP0]](%c0) @@ -189,7 +189,7 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // CHECK-NEXT: %13 = affine.apply [[MAP0]](%12) // CHECK-NEXT: %14 = "vmulf"(%12, %13) : (index, index) -> index // CHECK-NEXT: %15 = "vaddf"(%14, %14) : (index, index) -> index - affine.for %j = 0 to 4 { + for %j = 0 to 4 { %x = "affine.apply" 
(%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "vmulf"(%j, %x) : (index, index) -> index @@ -218,7 +218,7 @@ func @loop_nest_seq_multiple() { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%c0_0) // CHECK-NEXT: %6 = affine.apply [[MAP0]](%5) // CHECK-NEXT: "mul"(%6, %6) : (index, index) -> () - affine.for %j = 0 to 4 { + for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) "mul"(%x, %x) : (index, index) -> () @@ -226,8 +226,8 @@ func @loop_nest_seq_multiple() { // CHECK: %c99 = constant 99 : index %k = "constant"(){value: 99} : () -> index - // CHECK: affine.for %i0 = 0 to 100 step 2 { - affine.for %m = 0 to 100 step 2 { + // CHECK: for %i0 = 0 to 100 step 2 { + for %m = 0 to 100 step 2 { // CHECK: %7 = affine.apply [[MAP0]](%c0) // CHECK-NEXT: %8 = affine.apply [[MAP6]](%c0)[%c99] // CHECK-NEXT: %9 = affine.apply [[MAP0]](%c0) @@ -239,7 +239,7 @@ func @loop_nest_seq_multiple() { // CHECK-NEXT: %15 = affine.apply [[MAP2]](%c0) // CHECK-NEXT: %16 = affine.apply [[MAP0]](%15) // CHECK-NEXT: %17 = affine.apply [[MAP6]](%15)[%c99] - affine.for %n = 0 to 4 { + for %n = 0 to 4 { %y = "affine.apply" (%n) { map: (d0) -> (d0 + 1) } : (index) -> (index) %z = "affine.apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } : @@ -251,16 +251,16 @@ func @loop_nest_seq_multiple() { // SHORT-LABEL: func @loop_nest_outer_unroll() { func @loop_nest_outer_unroll() { - // SHORT: affine.for %i0 = 0 to 4 { + // SHORT: for %i0 = 0 to 4 { // SHORT-NEXT: %0 = affine.apply [[MAP0]](%i0) // SHORT-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // SHORT-NEXT: } - // SHORT-NEXT: affine.for %i1 = 0 to 4 { + // SHORT-NEXT: for %i1 = 0 to 4 { // SHORT-NEXT: %2 = affine.apply [[MAP0]](%i1) // SHORT-NEXT: %3 = "addi32"(%2, %2) : (index, index) -> index // SHORT-NEXT: } - affine.for %i = 0 to 2 { - affine.for %j = 0 to 4 { + for %i = 0 to 2 { + for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -284,28 +284,28 @@ func @loop_nest_seq_long() -> i32 { %zero_idx = constant 0 : index - affine.for %n0 = 0 to 512 { - affine.for %n1 = 0 to 8 { + for %n0 = 0 to 512 { + for %n1 = 0 to 8 { store %one, %A[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %two, %B[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %zero, %C[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> } } - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 2 { - affine.for %i2 = 0 to 8 { + for %i0 = 0 to 2 { + for %i1 = 0 to 2 { + for %i2 = 0 to 8 { %b2 = "affine.apply" (%i1, %i2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %x = load %B[%i0, %b2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op1"(%x) : (i32) -> () } - affine.for %j1 = 0 to 8 { - affine.for %j2 = 0 to 8 { + for %j1 = 0 to 8 { + for %j2 = 0 to 8 { %a2 = "affine.apply" (%i1, %j2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %v203 = load %A[%j1, %a2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op2"(%v203) : (i32) -> () } - affine.for %k2 = 0 to 8 { + for %k2 = 0 to 8 { %s0 = "op3"() : () -> i32 %c2 = "affine.apply" (%i0, %k2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %s1 = load %C[%j1, %c2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> @@ -322,8 +322,8 @@ func @loop_nest_seq_long() -> i32 { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_no_cleanup() { func @unroll_unit_stride_no_cleanup() { - // UNROLL-BY-4: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + 
// UNROLL-BY-4: for %i0 = 0 to 100 { + for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 0 to 8 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -337,13 +337,13 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - affine.for %j = 0 to 8 { + for %j = 0 to 8 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } // empty loop - // UNROLL-BY-4: affine.for %i2 = 0 to 8 { - affine.for %k = 0 to 8 { + // UNROLL-BY-4: for %i2 = 0 to 8 { + for %k = 0 to 8 { } } return @@ -351,8 +351,8 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_cleanup() { func @unroll_unit_stride_cleanup() { - // UNROLL-BY-4: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + // UNROLL-BY-4: for %i0 = 0 to 100 { + for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 0 to 7 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -370,7 +370,7 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - affine.for %j = 0 to 10 { + for %j = 0 to 10 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -380,8 +380,8 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_non_unit_stride_cleanup() { func @unroll_non_unit_stride_cleanup() { - // UNROLL-BY-4: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { + // UNROLL-BY-4: for %i0 = 0 to 100 { + for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 2 to 37 step 20 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -399,7 +399,7 @@ func @unroll_non_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - affine.for %j = 2 to 48 step 5 { + for %j = 2 to 48 step 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -411,8 +411,8 @@ func @unroll_non_unit_stride_cleanup() { func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4: %c0 = constant 0 : index // UNROLL-BY-4: %c4 = constant 4 : index - // UNROLL-BY-4: affine.for %i0 = 0 to %arg0 { - affine.for %i = 0 to %N { + // UNROLL-BY-4: for %i0 = 0 to %arg0 { + for %i = 0 to %N { // UNROLL-BY-4: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = affine.apply [[MAP0]](%c0) // UNROLL-BY-4-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -422,7 +422,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %7 = "addi32"(%c4, %c4) : (index, index) -> i32 // UNROLL-BY-4-NOT: for - affine.for %j = 0 to 5 { + for %j = 0 to 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 } // UNROLL-BY-4-NOT: } } // UNROLL-BY-4: } @@ -434,8 +434,8 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // No cleanup will be generated here. 
// UNROLL-BY-4-LABEL: func @loop_nest_operand1() { func @loop_nest_operand1() { -// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: affine.for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 +// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4 // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -443,8 +443,8 @@ func @loop_nest_operand1() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - affine.for %i = 0 to 100 step 2 { - affine.for %j = (d0) -> (0) (%i) to (d0) -> (d0 - d0 mod 4) (%i) { + for %i = 0 to 100 step 2 { + for %j = (d0) -> (0) (%i) to (d0) -> (d0 - d0 mod 4) (%i) { %x = "foo"() : () -> i32 } } @@ -454,8 +454,8 @@ func @loop_nest_operand1() { // No cleanup will be generated here. // UNROLL-BY-4-LABEL: func @loop_nest_operand2() { func @loop_nest_operand2() { -// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { +// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -463,8 +463,8 @@ func @loop_nest_operand2() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - affine.for %i = 0 to 100 step 2 { - affine.for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { + for %i = 0 to 100 step 2 { + for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { %x = "foo"() : () -> i32 } } @@ -475,16 +475,16 @@ func @loop_nest_operand2() { // factor. The cleanup loop happens to be a single iteration one and is promoted. // UNROLL-BY-4-LABEL: func @loop_nest_operand3() { func @loop_nest_operand3() { - // UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { - affine.for %i = 0 to 100 step 2 { - // UNROLL-BY-4: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { + // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { + for %i = 0 to 100 step 2 { + // UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 - affine.for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { + for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { %x = "foo"() : () -> i32 } } // UNROLL-BY-4: } @@ -493,20 +493,20 @@ func @loop_nest_operand3() { // UNROLL-BY-4-LABEL: func @loop_nest_operand4(%arg0: index) { func @loop_nest_operand4(%N : index) { - // UNROLL-BY-4: affine.for %i0 = 0 to 100 { - affine.for %i = 0 to 100 { - // UNROLL-BY-4: affine.for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4: for %i0 = 0 to 100 { + for %i = 0 to 100 { + // UNROLL-BY-4: for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 { // UNROLL-BY-4: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // A cleanup loop will be be generated here. 
- // UNROLL-BY-4-NEXT: affine.for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { + // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } // Specify the lower bound so that both lb and ub operands match. - affine.for %j = ()[s0] -> (0)()[%N] to %N { + for %j = ()[s0] -> (0)()[%N] to %N { %x = "foo"() : () -> i32 } } @@ -518,7 +518,7 @@ func @loop_nest_unroll_full() { // CHECK-NEXT: %0 = "foo"() : () -> i32 // CHECK-NEXT: %1 = "bar"() : () -> i32 // CHECK-NEXT: return - affine.for %i = 0 to 1 { + for %i = 0 to 1 { %x = "foo"() : () -> i32 %y = "bar"() : () -> i32 } @@ -527,7 +527,7 @@ func @loop_nest_unroll_full() { // UNROLL-BY-1-LABEL: func @unroll_by_one_should_promote_single_iteration_loop() func @unroll_by_one_should_promote_single_iteration_loop() { - affine.for %i = 0 to 1 { + for %i = 0 to 1 { %x = "foo"(%i) : (index) -> i32 } return diff --git a/mlir/utils/emacs/mlir-mode.el b/mlir/utils/emacs/mlir-mode.el index 8918890b8be..efc61cbe92a 100644 --- a/mlir/utils/emacs/mlir-mode.el +++ b/mlir/utils/emacs/mlir-mode.el @@ -42,7 +42,7 @@ ;; Keywords `(,(regexp-opt '(;; Toplevel entities - "br" "ceildiv" "func" "cond_br" "else" "extfunc" "false" "floordiv" "affine.for" "affine.if" "mod" "return" "size" "step" "to" "true" "??" ) 'symbols) . font-lock-keyword-face)) + "br" "ceildiv" "func" "cond_br" "else" "extfunc" "false" "floordiv" "for" "affine.if" "mod" "return" "size" "step" "to" "true" "??" ) 'symbols) . font-lock-keyword-face)) "Syntax highlighting for MLIR.") ;; Emacs 23 compatibility. diff --git a/mlir/utils/vim/mlir.vim b/mlir/utils/vim/mlir.vim index 0e2797f5603..91478d62136 100644 --- a/mlir/utils/vim/mlir.vim +++ b/mlir/utils/vim/mlir.vim @@ -10,9 +10,9 @@ syn keyword mlirType index i1 i2 i4 i8 i13 i16 i32 i64 \ f16 f32 tf_control syn keyword mlirType memref tensor vector -syntax keyword mlirKeywords extfunc func to step return +syntax keyword mlirKeywords extfunc cfgfunc mlfunc for to step return syntax keyword mlirConditional affine.if else -syntax keyword mlirCoreOps dim addf addi subf subi mulf muli cmpi select constant affine.apply affine.for call call_indirect extract_element getTensor memref_cast tensor_cast load store alloc dealloc dma_start dma_wait +syntax keyword mlirCoreOps dim addf addi subf subi mulf muli cmpi select constant affine.apply call call_indirect extract_element getTensor memref_cast tensor_cast load store alloc dealloc dma_start dma_wait syn match mlirInt "-\=\<\d\+\>" syn match mlirFloat "-\=\<\d\+\.\d\+\>" -- cgit v1.2.3 From 8f5f2c765d8073f4cbac073be0ace8daa59ff23f Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 15 Feb 2019 09:32:18 -0800 Subject: LoopFusion: perform a series of loop interchanges to increase the loop depth at which slices of producer loop nests can be fused into constumer loop nests. *) Adds utility to LoopUtils to perform loop interchange of two AffineForOps. *) Adds utility to LoopUtils to sink a loop to a specified depth within a loop nest, using a series of loop interchanges. *) Computes dependences between all loads and stores in the loop nest, and classifies each loop as parallel or sequential. *) Computes loop interchange permutation required to sink sequential loops (and raise parallel loop nests) while preserving relative order among them. *) Checks each dependence against the permutation to make sure that dependences would not be violated by the loop interchange transformation. 
*) Calls loop interchange in LoopFusion pass on consumer loop nests before fusing in producers, sinking loops with loop carried dependences deeper into the consumer loop nest.
*) Adds and updates related unit tests.
PiperOrigin-RevId: 234158370
---
mlir/include/mlir/Transforms/LoopUtils.h | 10 +++
mlir/lib/Transforms/LoopFusion.cpp | 141 +++++++++++++++++++++++++++++++
mlir/lib/Transforms/Utils/LoopUtils.cpp | 34 ++++++++
mlir/test/Transforms/loop-fusion.mlir | 99 +++++++++++++++++++++-
4 files changed, 283 insertions(+), 1 deletion(-)
(limited to 'mlir/lib/Transforms/LoopFusion.cpp')
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index f3d9b9fe9fd..470e94950a5 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -94,6 +94,16 @@ UtilResult instBodySkew(OpPointer forOp, ArrayRef shifts,
UtilResult tileCodeGen(MutableArrayRef> band,
ArrayRef tileSizes);
+/// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
+/// and 'forOpB' are part of a perfectly nested sequence of loops.
+void interchangeLoops(OpPointer forOpA,
+ OpPointer forOpB);
+
+/// Sinks 'forOp' by 'loopDepth' levels by performing a series of loop
+/// interchanges. Requires that 'forOp' is part of a perfect nest with
+/// 'loopDepth' AffineForOps consecutively nested under it.
+void sinkLoop(OpPointer forOp, unsigned loopDepth);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 8d5f51059bf..cf0f07345a4 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -827,6 +827,142 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts,
return loopDepth;
}
+// Compute loop interchange permutation:
+// *) Computes dependence components between all op pairs in 'ops' for loop
+// depths in range [1, 'maxLoopDepth'].
+// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either
+// parallel or sequential.
+// *) Computes the loop permutation which sinks sequential loops deeper into
+// the loop nest, while preserving the relative order between other loops.
+// *) Checks each dependence component against the permutation to see if the
+// desired loop interchange would violate dependences by making a
+// dependence component lexicographically negative.
+// TODO(andydavis) Move this function to LoopUtils.
+static bool
+computeLoopInterchangePermutation(ArrayRef ops,
+ unsigned maxLoopDepth,
+ SmallVectorImpl *loopPermMap) {
+ // Gather dependence components for dependences between all ops in 'ops'
+ // at loop depths in range [1, maxLoopDepth].
+ // TODO(andydavis) Refactor this loop into a LoopUtil utility function:
+ // mlir::getDependenceComponents().
+ // TODO(andydavis) Split this loop into two: first check all dependences,
+ // and construct dep vectors. Then, scan through them to detect the parallel
+ // ones.
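+ // For example, with maxLoopDepth = 3 and the surrounding loops classified
+ // below as {parallel, sequential, parallel} from outermost to innermost,
+ // this function produces loopPermMap = {0, 2, 1}: the two parallel loops
+ // keep their relative order in the outer positions, and the sequential loop
+ // is assigned the innermost position. The final check then walks each
+ // dependence in the permuted loop order (via loopPermMapInv) and returns
+ // false if the first non-zero dependence component would become negative.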
+ std::vector> depCompsVec;
+ llvm::SmallVector isParallelLoop(maxLoopDepth, true);
+ unsigned numOps = ops.size();
+ for (unsigned d = 1; d <= maxLoopDepth; ++d) {
+ for (unsigned i = 0; i < numOps; ++i) {
+ auto *srcOpInst = ops[i];
+ MemRefAccess srcAccess(srcOpInst);
+ for (unsigned j = 0; j < numOps; ++j) {
+ auto *dstOpInst = ops[j];
+ MemRefAccess dstAccess(dstOpInst);
+
+ FlatAffineConstraints dependenceConstraints;
+ llvm::SmallVector depComps;
+ // TODO(andydavis,bondhugula) Explore whether it would be profitable
+ // to pre-compute and store deps instead of repeatedly checking.
+ if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
+ &dependenceConstraints, &depComps)) {
+ isParallelLoop[d - 1] = false;
+ depCompsVec.push_back(depComps);
+ }
+ }
+ }
+ }
+ // Count the number of parallel loops.
+ unsigned numParallelLoops = 0;
+ for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
+ if (isParallelLoop[i])
+ ++numParallelLoops;
+
+ // Compute permutation of loops that sinks sequential loops (and thus raises
+ // parallel loops) while preserving relative order.
+ llvm::SmallVector loopPermMapInv;
+ loopPermMapInv.resize(maxLoopDepth);
+ loopPermMap->resize(maxLoopDepth);
+ unsigned nextSequentialLoop = numParallelLoops;
+ unsigned nextParallelLoop = 0;
+ for (unsigned i = 0; i < maxLoopDepth; ++i) {
+ if (isParallelLoop[i]) {
+ (*loopPermMap)[i] = nextParallelLoop;
+ loopPermMapInv[nextParallelLoop++] = i;
+ } else {
+ (*loopPermMap)[i] = nextSequentialLoop;
+ loopPermMapInv[nextSequentialLoop++] = i;
+ }
+ }
+
+ // Check each dependence component against the permutation to see if the
+ // desired loop interchange permutation would make the dependence vectors
+ // lexicographically negative.
+ // Example 1: [-1, 1][0, 0]
+ // Example 2: [0, 0][-1, 1]
+ for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
+ llvm::SmallVector &depComps = depCompsVec[i];
+ assert(depComps.size() >= maxLoopDepth);
+ // Check if the first non-zero dependence component is positive.
+ for (unsigned j = 0; j < maxLoopDepth; ++j) {
+ unsigned permIndex = loopPermMapInv[j];
+ assert(depComps[permIndex].lb.hasValue());
+ int64_t depCompLb = depComps[permIndex].lb.getValue();
+ if (depCompLb > 0)
+ break;
+ if (depCompLb < 0)
+ return false;
+ }
+ }
+ return true;
+}
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the
+// outermost (while again preserving relative order among them).
+// This can increase the loop depth at which we can fuse a slice, since we are
+// pushing loop carried dependence to a greater depth in the loop nest.
+static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
+ assert(node->inst->isa());
+ // Get perfectly nested sequence of loops starting at root of loop nest.
+ // TODO(andydavis,bondhugula) Share this with similar code in loop tiling.
+ SmallVector, 4> loops;
+ OpPointer curr = node->inst->cast();
+ loops.push_back(curr);
+ auto *currBody = curr->getBody();
+ while (!currBody->empty() &&
+ std::next(currBody->begin()) == currBody->end() &&
+ (curr = curr->getBody()->front().dyn_cast())) {
+ loops.push_back(curr);
+ currBody = curr->getBody();
+ }
+ if (loops.size() < 2)
+ return;
+
+ // Merge loads and stores into the same array.
+ SmallVector memOps(node->loads.begin(), node->loads.end());
+ memOps.append(node->stores.begin(), node->stores.end());
+
+ // Compute loop permutation in 'loopPermMap'.
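+ // For example, if 'loops' holds three perfectly nested loops and
+ // 'loopPermMap' comes back as {2, 0, 1} (outermost loop sequential, both
+ // inner loops parallel), only loops[0] satisfies permIndex > i, so it is
+ // sunk two levels via sinkLoop(loops[0], 2), yielding the order loops[1],
+ // loops[2], loops[0]; loops[1] (permIndex == 0) becomes the new loop nest
+ // root recorded in node->inst.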
+ llvm::SmallVector loopPermMap; + if (!computeLoopInterchangePermutation(memOps, loops.size(), &loopPermMap)) + return; + + int loopNestRootIndex = -1; + for (int i = loops.size() - 1; i >= 0; --i) { + int permIndex = static_cast(loopPermMap[i]); + // Store the index of the for loop which will be the new loop nest root. + if (permIndex == 0) + loopNestRootIndex = i; + if (permIndex > i) { + // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest. + sinkLoop(loops[i], permIndex - i); + } + } + assert(loopNestRootIndex != -1 && "invalid root index"); + node->inst = loops[loopNestRootIndex]->getInstruction(); +} + // Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB' // using a rectangular bounding box. // TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' @@ -1407,6 +1543,11 @@ public: // Skip if 'dstNode' is not a loop nest. if (!dstNode->inst->isa()) continue; + // Sink sequential loops in 'dstNode' (and thus raise parallel loops) + // while preserving relative order. This can increase the maximum loop + // depth at which we can fuse a slice of a producer loop nest into a + // consumer loop nest. + sinkSequentialLoops(dstNode); SmallVector loads = dstNode->loads; SmallVector dstLoadOpInsts; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a1903ace026..6b1a0be3bd3 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -22,9 +22,11 @@ #include "mlir/Transforms/LoopUtils.h" #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AffineStructures.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -452,3 +454,35 @@ bool mlir::loopUnrollByFactor(OpPointer forOp, return true; } + +/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is +/// nested within 'forOpA' as the only instruction in its block. +void mlir::interchangeLoops(OpPointer forOpA, + OpPointer forOpB) { + auto *forOpAInst = forOpA->getInstruction(); + // 1) Slice forOpA's instruction list (which is just forOpB) just before + // forOpA (in forOpA's parent's block) this should leave 'forOpA's + // instruction list empty (because its perfectly nested). + assert(&*forOpA->getBody()->begin() == forOpB->getInstruction()); + forOpAInst->getBlock()->getInstructions().splice( + Block::iterator(forOpAInst), forOpA->getBody()->getInstructions()); + // 2) Slice forOpB's instruction list into forOpA's instruction list (this + // leaves forOpB's instruction list empty). + forOpA->getBody()->getInstructions().splice( + forOpA->getBody()->begin(), forOpB->getBody()->getInstructions()); + // 3) Slice forOpA into forOpB's instruction list. + forOpB->getBody()->getInstructions().splice( + forOpB->getBody()->begin(), forOpAInst->getBlock()->getInstructions(), + Block::iterator(forOpAInst)); +} + +/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels +/// deeper in the loop nest. 
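The three splice steps in interchangeLoops() above can be pictured on a toy block structure. In the sketch below the types are purely illustrative: a Block's instruction list is modelled as a std::list, and the point being illustrated is that splicing moves list nodes without invalidating iterators, which is why the same two handles stay usable across all three steps.

// Toy model of interchangeLoops()'s three splices (illustrative types only).
#include <cassert>
#include <list>
#include <string>

struct Inst {
  std::string name;
  std::list<Inst> body; // Non-empty only for "loop" instructions.
                        // (std::list allows an incomplete element type here.)
};

// Interchanges the loop at 'itA' (in 'parent') with the single loop nested
// immediately inside it, mirroring the splice steps above.
void interchange(std::list<Inst> &parent, std::list<Inst>::iterator itA) {
  assert(itA->body.size() == 1 && "expected a perfectly nested loop pair");
  auto itB = itA->body.begin();
  // 1) Move A's body (just B) into the parent block, right before A.
  parent.splice(itA, itA->body);
  // 2) Move B's body into A's now-empty body.
  itA->body.splice(itA->body.begin(), itB->body);
  // 3) Move A into B's now-empty body.
  itB->body.splice(itB->body.begin(), parent, itA);
}

int main() {
  // parent: for.i { for.j { work } }
  std::list<Inst> parent;
  parent.push_back({"for.i", {}});
  parent.front().body.push_back({"for.j", {}});
  parent.front().body.front().body.push_back({"work", {}});

  interchange(parent, parent.begin());

  // Now: for.j { for.i { work } }
  assert(parent.size() == 1 && parent.front().name == "for.j");
  assert(parent.front().body.front().name == "for.i");
  assert(parent.front().body.front().body.front().name == "work");
  return 0;
}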
+void mlir::sinkLoop(OpPointer forOp, unsigned loopDepth) { + for (unsigned i = 0; i < loopDepth; ++i) { + assert(forOp->getBody()->front().isa()); + OpPointer nextForOp = + forOp->getBody()->front().cast(); + interchangeLoops(forOp, nextForOp); + } +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 16afcaa8a17..72b0cddb514 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -938,7 +938,7 @@ func @fusion_at_depth0_not_currently_supported() { // NOTE: Should shrink memref size to 1 element access by load in dst loop // nest, and make the store in the slice store to the same element. // CHECK-DAG: %0 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%c0] : memref<1xf32> // CHECK-NEXT: %1 = load %0[%c0_0] : memref<1xf32> // CHECK-NEXT: } @@ -1691,3 +1691,100 @@ func @should_fuse_after_private_memref_creation() { // CHECK-NEXT: return return } + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) + +// CHECK-LABEL: func @should_fuse_after_one_loop_interchange() { +func @should_fuse_after_one_loop_interchange() { + %a = alloc() : memref<10xf32> + + %cf0 = constant 0.0 : f32 + for %i0 = 0 to 10 { + store %cf0, %a[%i0] : memref<10xf32> + } + + for %i1 = 0 to 5 { + for %i2 = 0 to 10 { + %v0 = load %a[%i2] : memref<10xf32> + store %v0, %a[%i2] : memref<10xf32> + } + } + + // The dependence between the load and store is carried on loop '%i1', and + // cannot be fused with loop '%i0' without violating this dependence. + // Once loops '%i1' and %i2' are interchanged, loop '%i0' can be fused + // at loop depth 1, because the loop carrying the dependence has been + // interchanged and is now at depth 2. + + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: for %i1 = 0 to 5 { + // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) + +// CHECK-LABEL: func @should_fuse_after_two_loop_interchanges() { +func @should_fuse_after_two_loop_interchanges() { + %a = alloc() : memref<6x8xf32> + + %cf0 = constant 0.0 : f32 + for %i0 = 0 to 6 { + for %i1 = 0 to 8 { + store %cf0, %a[%i0, %i1] : memref<6x8xf32> + } + } + + for %i2 = 0 to 4 { + for %i3 = 0 to 6 { + for %i4 = 0 to 2 { + for %i5 = 0 to 8 { + %v0 = load %a[%i3, %i5] : memref<6x8xf32> + %v1 = addf %v0, %v0 : f32 + store %v1, %a[%i3, %i5] : memref<6x8xf32> + } + } + } + } + + // The dependence between the load and store is carried on loops '%i2' and + // '%i4', and cannot be fused with loop '%i0' without violating this + // dependence. + // Once loop '%i2' is interchanged with loop '%i3', and again with loop + // '%i5', then loop '%i0' can be fused at loop depth 2, because the loop + // carring the dependences have been interchanged with loops at depth > 2. 
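To make the interchange sequence in this test concrete, the toy sketch below tracks only nesting order (loop names reused for illustration, and the parallel/sequential classification is taken from the comment above: %i2 and %i4 carry the dependence). Sinking %i4 one level and %i2 two levels, innermost loop first as the pass does, leaves the parallel loops %i3 and %i5 outermost, which is what enables fusion at depth 2:

// Sketch of how a series of adjacent interchanges reorders the consumer nest.
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Sink the loop currently at position 'i' by 'depth' levels using adjacent
// interchanges, as sinkLoop() does with interchangeLoops().
void sinkByInterchanges(std::vector<std::string> &nest, unsigned i,
                        unsigned depth) {
  for (unsigned k = 0; k < depth; ++k)
    std::swap(nest[i + k], nest[i + k + 1]);
}

int main() {
  // Original consumer nest, outermost first. %i2 and %i4 carry dependences,
  // %i3 and %i5 are parallel.
  std::vector<std::string> nest = {"%i2", "%i3", "%i4", "%i5"};
  // Processing innermost-first: %i4 sinks one level, then %i2 sinks two.
  sinkByInterchanges(nest, 2, 1);
  sinkByInterchanges(nest, 0, 2);
  // Parallel loops are now outermost, so fusion can happen at depth 2.
  assert((nest == std::vector<std::string>{"%i3", "%i5", "%i2", "%i4"}));
  return 0;
}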
+ + // CHECK: for %i0 = 0 to 6 { + // CHECK-NEXT: for %i1 = 0 to 8 { + // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: store %cst, %0[%1, %2] : memref<1x1xf32> + // CHECK-NEXT: for %i2 = 0 to 4 { + // CHECK-NEXT: for %i3 = 0 to 2 { + // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32> + // CHECK-NEXT: %6 = addf %5, %5 : f32 + // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: %8 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) + // CHECK-NEXT: store %6, %0[%7, %8] : memref<1x1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From 58aa383e6092cd56dc3b5586b82b53674912dfad Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 15 Feb 2019 17:12:19 -0800 Subject: Support fusing producer loop nests which write to a memref which is live out, provided that the write region of the consumer loop nest to the same memref is a super set of the producer's write region. PiperOrigin-RevId: 234240958 --- mlir/lib/Transforms/LoopFusion.cpp | 117 ++++++++++++++++++++++++++++------ mlir/test/Transforms/loop-fusion.mlir | 22 +++++++ 2 files changed, 118 insertions(+), 21 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cf0f07345a4..aebf2716c4e 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -177,6 +177,15 @@ public: } return storeOpCount; } + + // Returns all store ups in 'storeOps' which access 'memref'. + void getStoreOpsForMemref(Value *memref, + SmallVectorImpl *storeOps) { + for (auto *storeOpInst : stores) { + if (memref == storeOpInst->cast()->getMemRef()) + storeOps->push_back(storeOpInst); + } + } }; // Edge represents a data dependece between nodes in the graph. @@ -258,10 +267,10 @@ public: for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast()->getMemRef(); auto *inst = memref->getDefiningInst(); - // Return false if 'memref' is a block argument. + // Return true if 'memref' is a block argument. if (!inst) return true; - // Return false if any use of 'memref' escapes the function. + // Return true if any use of 'memref' escapes the function. for (auto &use : memref->getUses()) if (!isMemRefDereferencingOp(*use.getOwner())) return true; @@ -1157,6 +1166,63 @@ static uint64_t getSliceIterationCount( return iterCount; } +// Checks if node 'srcId' (which writes to a live out memref), can be safely +// fused into node 'dstId'. Returns true if the following conditions are met: +// *) 'srcNode' writes only writes to live out 'memref'. +// *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId'). +// *) 'dstNode' does write to 'memref'. +// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write +// region to 'memref'. +// TODO(andydavis) Generalize this to handle more live in/out cases. +static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, + Value *memref, + MemRefDependenceGraph *mdg) { + auto *srcNode = mdg->getNode(srcId); + auto *dstNode = mdg->getNode(dstId); + + // Return false if any of the following are true: + // *) 'srcNode' writes to a live in/out memref other than 'memref'. + // *) 'srcNode' has more than one output edge on 'memref'. 
+ // *) 'dstNode' does not write to 'memref'. + if (srcNode->getStoreOpCount(memref) != 1 || + mdg->getOutEdgeCount(srcNode->id, memref) != 1 || + dstNode->getStoreOpCount(memref) == 0) + return false; + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. + auto *srcStoreOpInst = srcNode->stores.front(); + MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); + srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0); + SmallVector srcShape; + // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. + // by 'srcStoreOpInst' at depth 'dstLoopDepth'. + Optional srcNumElements = + srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); + if (!srcNumElements.hasValue()) + return false; + + // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. + SmallVector dstStoreOps; + dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + assert(dstStoreOps.size() == 1); + auto *dstStoreOpInst = dstStoreOps[0]; + MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); + dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0); + SmallVector dstShape; + // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'. + // by 'dstStoreOpInst' at depth 'dstLoopDepth'. + Optional dstNumElements = + dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape); + if (!dstNumElements.hasValue()) + return false; + + // Return false if write region is not a superset of 'srcNodes' write + // region to 'memref'. + // TODO(andydavis) Check the shape and lower bounds here too. + if (srcNumElements != dstNumElements) + return false; + return true; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // Returns true if it is profitable to fuse the candidate loop nests. Returns @@ -1593,8 +1659,12 @@ public: if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) continue; - // Skip if 'srcNode' writes to any live in or escaping memrefs. - if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) + // Skip if 'srcNode' writes to any live in or escaping memrefs, + // and cannot be fused. + bool writesToLiveInOrOut = + mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); + if (writesToLiveInOrOut && + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg)) continue; // Compute an instruction list insertion point for the fused loop @@ -1639,22 +1709,24 @@ public: for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); } - // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector storesForMemref; - for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (storeOpInst->cast()->getMemRef() == memref) - storesForMemref.push_back(storeOpInst); + if (!writesToLiveInOrOut) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector storesForMemref; + for (auto *storeOpInst : sliceCollector.storeOpInsts) { + if (storeOpInst->cast()->getMemRef() == memref) + storesForMemref.push_back(storeOpInst); + } + assert(storesForMemref.size() == 1); + auto *newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = + mdg->addNode(newMemRef->getDefiningInst()); + // Add edge from 'newMemRef' node to dstNode. 
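The superset check above compares only the two regions' element counts, and its TODO notes that shape and lower bounds should eventually be checked as well. For intuition, a stricter rectangular-containment test might look like the sketch below; the Region type and function names are hypothetical, not the MemRefRegion API:

// Sketch of a per-dimension containment check between two write regions.
#include <cstdint>
#include <vector>

// A hyper-rectangular region: per dimension, the half-open range [lb, lb+size).
struct Region {
  std::vector<int64_t> lb;
  std::vector<int64_t> size;
};

// Returns true if 'dst' covers every element written by 'src', so erasing the
// producer nest after fusion cannot lose a value observable after the loops.
bool dstWriteRegionCoversSrc(const Region &src, const Region &dst) {
  if (src.lb.size() != dst.lb.size())
    return false;
  for (size_t d = 0, e = src.lb.size(); d < e; ++d) {
    if (dst.lb[d] > src.lb[d])
      return false;
    if (dst.lb[d] + dst.size[d] < src.lb[d] + src.size[d])
      return false;
  }
  return true;
}

int main() {
  // Producer writes [0, 10) of a 1-D memref, consumer writes [0, 10): covered.
  Region src{{0}, {10}}, dst{{0}, {10}};
  return dstWriteRegionCoversSrc(src, dst) ? 0 : 1;
}

The equal-element-count test committed here is a conservative special case of this containment check.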
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } - assert(storesForMemref.size() == 1); - auto *newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef->getDefiningInst()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; @@ -1674,8 +1746,11 @@ public: dstLoopCollector.storeOpInsts); // Remove old src loop nest if it no longer has outgoing dependence // edges, and it does not write to a memref which escapes the - // function. - if (mdg->canRemoveNode(srcNode->id)) { + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has + // been fused into 'dstNode' and write region of 'dstNode' covers + // the write region of 'srcNode', and 'srcNode' has no other users + // so it is safe to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); srcNode->inst->erase(); } else { diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 72b0cddb514..c671adc6cf9 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1788,3 +1788,25 @@ func @should_fuse_after_two_loop_interchanges() { // CHECK-NEXT: return return } + +// ----- + +func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { + %cst = constant 0.000000e+00 : f32 + for %i0 = 0 to 10 { + store %cst, %arg0[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %1 = load %arg0[%i1] : memref<10xf32> + store %1, %arg0[%i1] : memref<10xf32> + } + return %arg0 : memref<10xf32> + + // CHECK: %cst = constant 0.000000e+00 : f32 + // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: store %0, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return %arg0 : memref<10xf32> +} -- cgit v1.2.3 From 48ccae247639bd49bbf8de3330a96f2253e9c417 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 19 Feb 2019 17:17:46 -0800 Subject: NFC: Refactor the files related to passes. * PassRegistry is split into its own source file. * Pass related files are moved to a new library 'Pass'. 
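As a usage sketch for the headers introduced below: a hypothetical FunctionPass definition and its registration. It assumes the conventional static 'passID' member whose address PassRegistration takes; the pass name and body are made up for illustration and are not part of this change.

#include "mlir/Pass/Pass.h"

namespace {
// Illustrative pass: does nothing beyond visiting each function.
struct HypotheticalPass : public mlir::FunctionPass {
  static char passID; // Unique address used as this pass's identifier.
  HypotheticalPass() : FunctionPass(&passID) {}

  mlir::PassResult runOnFunction(mlir::Function *fn) override {
    // A real pass would inspect or rewrite 'fn' here. Returning success()
    // (which converts to 'false') lets FunctionPass::runOnModule continue.
    return success();
  }
};
char HypotheticalPass::passID = 0;
} // end anonymous namespace

// Makes the pass selectable from mlir-opt via the registered argument.
static mlir::PassRegistration<HypotheticalPass>
    hypotheticalPassReg("hypothetical-pass", "Illustrative pass registration");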
PiperOrigin-RevId: 234705771 --- mlir/bindings/python/pybind.cpp | 2 +- mlir/include/mlir/Pass.h | 157 --------------------- mlir/include/mlir/Pass/Pass.h | 92 ++++++++++++ mlir/include/mlir/Pass/PassRegistry.h | 104 ++++++++++++++ mlir/include/mlir/Support/PassNameParser.h | 40 ------ mlir/include/mlir/Transforms/DialectConversion.h | 2 +- .../mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/OpStats.cpp | 2 +- mlir/lib/Analysis/Pass.cpp | 92 ------------ mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 2 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 2 +- mlir/lib/Pass/Pass.cpp | 47 ++++++ mlir/lib/Pass/PassRegistry.cpp | 65 +++++++++ mlir/lib/Transforms/CSE.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- mlir/lib/Transforms/ConstantFold.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 2 +- mlir/tools/mlir-opt/mlir-opt.cpp | 3 +- 35 files changed, 336 insertions(+), 318 deletions(-) delete mode 100644 mlir/include/mlir/Pass.h create mode 100644 mlir/include/mlir/Pass/Pass.h create mode 100644 mlir/include/mlir/Pass/PassRegistry.h delete mode 100644 mlir/include/mlir/Support/PassNameParser.h delete mode 100644 mlir/lib/Analysis/Pass.cpp create mode 100644 mlir/lib/Pass/Pass.cpp create mode 100644 mlir/lib/Pass/PassRegistry.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 7b642924d72..5455eba1350 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -11,7 +11,7 @@ #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h" -#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Passes.h" #include "pybind11/pybind11.h" diff --git a/mlir/include/mlir/Pass.h b/mlir/include/mlir/Pass.h deleted file mode 100644 index d8ab143f3bf..00000000000 --- a/mlir/include/mlir/Pass.h +++ /dev/null @@ -1,157 +0,0 @@ -//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef MLIR_PASS_H -#define MLIR_PASS_H - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Compiler.h" -#include - -namespace mlir { -class Function; -class Module; - -// Values that can be used by to signal success/failure. This can be implicitly -// converted to/from boolean values, with false representing success and true -// failure. -struct LLVM_NODISCARD PassResult { - enum ResultEnum { Success, Failure } value; - PassResult(ResultEnum v) : value(v) {} - operator bool() const { return value == Failure; } -}; - -class PassInfo; - -class Pass { -public: - explicit Pass(const void *passID) : passID(passID) {} - virtual ~Pass() = default; - virtual PassResult runOnModule(Module *m) = 0; - - /// Returns the unique identifier that corresponds to this pass. - const void *getPassID() const { return passID; } - - static PassResult success() { return PassResult::Success; } - static PassResult failure() { return PassResult::Failure; } - - /// Returns the pass info for the specified pass class or null if unknown. - static const PassInfo *lookupPassInfo(const void *passID); - - /// Returns the pass info for this pass. - const PassInfo *lookupPassInfo() const { return lookupPassInfo(passID); } - -private: - /// Out of line virtual method to ensure vtables and metadata are emitted to a - /// single .o file. - virtual void anchor(); - - /// Unique identifier for pass. - const void *const passID; -}; - -class ModulePass : public Pass { -public: - explicit ModulePass(const void *passID) : Pass(passID) {} - - virtual PassResult runOnModule(Module *m) override = 0; - -private: - /// Out of line virtual method to ensure vtables and metadata are emitted to a - /// single .o file. - virtual void anchor(); -}; - -/// FunctionPass's are run on every function in a module, and multiple functions -/// may be optimized concurrently by different instances of the function pass. -/// By subclassing this, your pass promises only to look at the function psased -/// in to it, it isn't allowed to inspect or modify other functions in the -/// module. -class FunctionPass : public Pass { -public: - explicit FunctionPass(const void *passID) : Pass(passID) {} - - /// Implement this function to be run on every function in the module. - virtual PassResult runOnFunction(Function *fn) = 0; - - // Iterates over all functions in a module, halting upon failure. - virtual PassResult runOnModule(Module *m) override; -}; - -using PassAllocatorFunction = std::function; - -/// Structure to group information about a pass (argument to invoke via -/// mlir-opt, description, pass allocator and unique ID). -class PassInfo { -public: - /// PassInfo constructor should not be invoked directly, instead use - /// PassRegistration or registerPass. - PassInfo(StringRef arg, StringRef description, const void *passID, - PassAllocatorFunction allocator) - : arg(arg), description(description), allocator(allocator), - passID(passID) {} - - /// Returns an allocated instance of this pass. 
- Pass *createPass() const { - assert(allocator && - "Cannot call createPass on PassInfo without default allocator"); - return allocator(); - } - - /// Returns the command line option that may be passed to 'mlir-opt' that will - /// cause this pass to run or null if there is no such argument. - StringRef getPassArgument() const { return arg; } - - /// Returns a description for the pass, this never returns null. - StringRef getPassDescription() const { return description; } - -private: - // The argument with which to invoke the pass via mlir-opt. - StringRef arg; - - // Description of the pass. - StringRef description; - - // Allocator to construct an instance of this pass. - PassAllocatorFunction allocator; - - // Unique identifier for pass. - const void *passID; -}; - -/// Register a specific dialect creation function with the system, typically -/// used through the PassRegistration template. -void registerPass(StringRef arg, StringRef description, const void *passID, - const PassAllocatorFunction &function); - -/// PassRegistration provides a global initializer that registers a Pass -/// allocation routine. -/// -/// Usage: -/// -/// // At namespace scope. -/// static PassRegistration Unused("unused", "Unused pass"); -template struct PassRegistration { - PassRegistration(StringRef arg, StringRef description) { - registerPass(arg, description, &ConcretePass::passID, - [&]() { return new ConcretePass(); }); - } -}; -} // end namespace mlir - -#endif // MLIR_PASS_H diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h new file mode 100644 index 00000000000..c489dafb20b --- /dev/null +++ b/mlir/include/mlir/Pass/Pass.h @@ -0,0 +1,92 @@ +//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_PASS_PASS_H +#define MLIR_PASS_PASS_H + +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +class Function; +class Module; + +// Values that can be used by to signal success/failure. This can be implicitly +// converted to/from boolean values, with false representing success and true +// failure. +struct LLVM_NODISCARD PassResult { + enum ResultEnum { Success, Failure } value; + PassResult(ResultEnum v) : value(v) {} + operator bool() const { return value == Failure; } +}; + +class Pass { +public: + explicit Pass(const void *passID) : passID(passID) {} + virtual ~Pass() = default; + virtual PassResult runOnModule(Module *m) = 0; + + /// Returns the unique identifier that corresponds to this pass. + const void *getPassID() const { return passID; } + + static PassResult success() { return PassResult::Success; } + static PassResult failure() { return PassResult::Failure; } + + /// Returns the pass info for the specified pass class or null if unknown. + static const PassInfo *lookupPassInfo(const void *passID); + + /// Returns the pass info for this pass. 
+ const PassInfo *lookupPassInfo() const { return lookupPassInfo(passID); } + +private: + /// Out of line virtual method to ensure vtables and metadata are emitted to a + /// single .o file. + virtual void anchor(); + + /// Unique identifier for pass. + const void *const passID; +}; + +class ModulePass : public Pass { +public: + explicit ModulePass(const void *passID) : Pass(passID) {} + + virtual PassResult runOnModule(Module *m) override = 0; + +private: + /// Out of line virtual method to ensure vtables and metadata are emitted to a + /// single .o file. + virtual void anchor(); +}; + +/// FunctionPass's are run on every function in a module, and multiple functions +/// may be optimized concurrently by different instances of the function pass. +/// By subclassing this, your pass promises only to look at the function psased +/// in to it, it isn't allowed to inspect or modify other functions in the +/// module. +class FunctionPass : public Pass { +public: + explicit FunctionPass(const void *passID) : Pass(passID) {} + + /// Implement this function to be run on every function in the module. + virtual PassResult runOnFunction(Function *fn) = 0; + + // Iterates over all functions in a module, halting upon failure. + virtual PassResult runOnModule(Module *m) override; +}; +} // end namespace mlir + +#endif // MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h new file mode 100644 index 00000000000..8f324bf0f69 --- /dev/null +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -0,0 +1,104 @@ +//===- PassRegistry.h - Pass Registration Utilities -------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains utilities for registering information about compiler +// passes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_PASS_PASSREGISTRY_H_ +#define MLIR_PASS_PASSREGISTRY_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include + +namespace mlir { +class Pass; + +using PassAllocatorFunction = std::function; + +/// Structure to group information about a pass (argument to invoke via +/// mlir-opt, description, pass allocator and unique ID). +class PassInfo { +public: + /// PassInfo constructor should not be invoked directly, instead use + /// PassRegistration or registerPass. + PassInfo(StringRef arg, StringRef description, const void *passID, + PassAllocatorFunction allocator) + : arg(arg), description(description), allocator(allocator), + passID(passID) {} + + /// Returns an allocated instance of this pass. 
+ Pass *createPass() const { + assert(allocator && + "Cannot call createPass on PassInfo without default allocator"); + return allocator(); + } + + /// Returns the command line option that may be passed to 'mlir-opt' that will + /// cause this pass to run or null if there is no such argument. + StringRef getPassArgument() const { return arg; } + + /// Returns a description for the pass, this never returns null. + StringRef getPassDescription() const { return description; } + +private: + // The argument with which to invoke the pass via mlir-opt. + StringRef arg; + + // Description of the pass. + StringRef description; + + // Allocator to construct an instance of this pass. + PassAllocatorFunction allocator; + + // Unique identifier for pass. + const void *passID; +}; + +/// Register a specific dialect creation function with the system, typically +/// used through the PassRegistration template. +void registerPass(StringRef arg, StringRef description, const void *passID, + const PassAllocatorFunction &function); + +/// PassRegistration provides a global initializer that registers a Pass +/// allocation routine. +/// +/// Usage: +/// +/// // At namespace scope. +/// static PassRegistration Unused("unused", "Unused pass"); +template struct PassRegistration { + PassRegistration(StringRef arg, StringRef description) { + registerPass(arg, description, &ConcretePass::passID, + [&]() { return new ConcretePass(); }); + } +}; + +/// Adds command line option for each registered pass. +struct PassNameParser : public llvm::cl::parser { + PassNameParser(llvm::cl::Option &opt); + + void printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const override; +}; +} // end namespace mlir + +#endif // MLIR_PASS_PASSREGISTRY_H_ diff --git a/mlir/include/mlir/Support/PassNameParser.h b/mlir/include/mlir/Support/PassNameParser.h deleted file mode 100644 index bbdf433b9ab..00000000000 --- a/mlir/include/mlir/Support/PassNameParser.h +++ /dev/null @@ -1,40 +0,0 @@ -//===- PassNameParser.h - Base classes for compiler passes ------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// The PassNameParser class adds all passes linked in to the system that are -// creatable to the tool. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_SUPPORT_PASSNAMEPARSER_H_ -#define MLIR_SUPPORT_PASSNAMEPARSER_H_ - -#include "llvm/Support/CommandLine.h" - -namespace mlir { -class PassInfo; - -/// Adds command line option for each registered pass. 
-struct PassNameParser : public llvm::cl::parser { - PassNameParser(llvm::cl::Option &opt); - - void printOptionInfo(const llvm::cl::Option &O, - size_t GlobalWidth) const override; -}; -} // end namespace mlir - -#endif // MLIR_SUPPORT_PASSNAMEPARSER_H_ diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 7bd08bbd766..b547e21c28a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -23,7 +23,7 @@ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index c5be3322f43..1abd85a1d2b 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -24,7 +24,7 @@ #define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include namespace mlir { diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 3482f24dfcc..9f6efff3187 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 578121e546e..43bc0c98916 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index f05f8737b16..6ae7ec59c50 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -18,7 +18,7 @@ #include "mlir/IR/Instruction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Analysis/Pass.cpp b/mlir/lib/Analysis/Pass.cpp deleted file mode 100644 index a10fe0324d5..00000000000 --- a/mlir/lib/Analysis/Pass.cpp +++ /dev/null @@ -1,92 +0,0 @@ -//===- Pass.cpp - Pass infrastructure implementation ----------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= -// -// This file implements common pass infrastructure. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Pass.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Module.h" -#include "mlir/Support/PassNameParser.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/Support/ManagedStatic.h" -using namespace mlir; - -/// Out of line virtual method to ensure vtables and metadata are emitted to a -/// single .o file. -void Pass::anchor() {} - -/// Out of line virtual method to ensure vtables and metadata are emitted to a -/// single .o file. -void ModulePass::anchor() {} - -/// Function passes walk a module and look at each function with their -/// corresponding hooks and terminates upon error encountered. -PassResult FunctionPass::runOnModule(Module *m) { - for (auto &fn : *m) { - // All function passes ignore external functions. - if (fn.empty()) - continue; - - if (runOnFunction(&fn)) - return failure(); - } - return success(); -} - -// TODO: The pass registry and pass name parsing should be moved out. -static llvm::ManagedStatic> passRegistry; - -void mlir::registerPass(StringRef arg, StringRef description, - const void *passID, - const PassAllocatorFunction &function) { - bool inserted = passRegistry - ->insert(std::make_pair( - passID, PassInfo(arg, description, passID, function))) - .second; - assert(inserted && "Pass registered multiple times"); - (void)inserted; -} - -/// Returns the pass info for the specified pass class or null if unknown. -const PassInfo *mlir::Pass::lookupPassInfo(const void *passID) { - auto it = passRegistry->find(passID); - if (it == passRegistry->end()) - return nullptr; - return &it->getSecond(); -} - -PassNameParser::PassNameParser(llvm::cl::Option &opt) - : llvm::cl::parser(opt) { - for (const auto &kv : *passRegistry) { - addLiteralOption(kv.second.getPassArgument(), &kv.second, - kv.second.getPassDescription()); - } -} - -void PassNameParser::printOptionInfo(const llvm::cl::Option &O, - size_t GlobalWidth) const { - PassNameParser *TP = const_cast(this); - llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(), - [](const PassNameParser::OptionInfo *VT1, - const PassNameParser::OptionInfo *VT2) { - return VT1->Name.compare(VT2->Name); - }); - using llvm::cl::parser; - parser::printOptionInfo(O, GlobalWidth); -} diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 073113a342a..25fdadd397d 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -23,7 +23,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index b8ca360756b..01278aad8af 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -22,7 +22,7 @@ #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index e72840695dc..71c335082ff 100644 --- 
a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/LLVMIR/LLVMDialect.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp new file mode 100644 index 00000000000..6a41cafcf1e --- /dev/null +++ b/mlir/lib/Pass/Pass.cpp @@ -0,0 +1,47 @@ +//===- Pass.cpp - Pass infrastructure implementation ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements common pass infrastructure. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/IR/Module.h" + +using namespace mlir; + +/// Out of line virtual method to ensure vtables and metadata are emitted to a +/// single .o file. +void Pass::anchor() {} + +/// Out of line virtual method to ensure vtables and metadata are emitted to a +/// single .o file. +void ModulePass::anchor() {} + +/// Function passes walk a module and look at each function with their +/// corresponding hooks and terminates upon error encountered. +PassResult FunctionPass::runOnModule(Module *m) { + for (auto &fn : *m) { + // All function passes ignore external functions. + if (fn.empty()) + continue; + + if (runOnFunction(&fn)) + return failure(); + } + return success(); +} diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp new file mode 100644 index 00000000000..c26da1f4099 --- /dev/null +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -0,0 +1,65 @@ +//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" + +using namespace mlir; + +/// Static mapping of all of the registered passes. 
+static llvm::ManagedStatic> passRegistry; + +void mlir::registerPass(StringRef arg, StringRef description, + const void *passID, + const PassAllocatorFunction &function) { + bool inserted = passRegistry + ->insert(std::make_pair( + passID, PassInfo(arg, description, passID, function))) + .second; + assert(inserted && "Pass registered multiple times"); + (void)inserted; +} + +/// Returns the pass info for the specified pass class or null if unknown. +const PassInfo *mlir::Pass::lookupPassInfo(const void *passID) { + auto it = passRegistry->find(passID); + if (it == passRegistry->end()) + return nullptr; + return &it->getSecond(); +} + +PassNameParser::PassNameParser(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { + for (const auto &kv : *passRegistry) { + addLiteralOption(kv.second.getPassArgument(), &kv.second, + kv.second.getPassDescription()); + } +} + +void PassNameParser::printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const { + PassNameParser *TP = const_cast(this); + llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(), + [](const PassNameParser::OptionInfo *VT1, + const PassNameParser::OptionInfo *VT2) { + return VT1->Name.compare(VT2->Name); + }); + using llvm::cl::parser; + parser::printOptionInfo(O, GlobalWidth); +} diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index de10fe8a461..cd205fe773b 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -24,7 +24,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index e6d32e02e58..0388744d4d2 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -22,7 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" using namespace mlir; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 51ec123d658..7634d9ec16a 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -18,7 +18,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index cafb9ba4d70..fb5719aca02 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index aebf2716c4e..63a681a4fdc 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -28,7 +28,7 @@ #include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include 
"mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 44798dcee85..76e8e9254c9 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -24,7 +24,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 84b0d2279db..b452b4f76e2 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index b2aed7d9d7f..76668c7f0b5 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -50,7 +50,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 0d8eb8a4761..8b62601ab41 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index ff97e9197f3..bd43a637665 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -36,7 +36,7 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6c7820fced9..3cd33c20fae 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -36,7 +36,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index d9f940a01f3..68bce854222 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include 
"mlir/Transforms/Passes.h" #include "llvm/ADT/SmallPtrSet.h" diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index a85f428bde6..b6bfde58494 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -26,7 +26,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 897498e8346..20961180b83 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -23,7 +23,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #define DEBUG_TYPE "simplify-affine-structure" diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index c5e42b622ed..2eb4a37445c 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -17,7 +17,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" using namespace mlir; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 1a83fce2eeb..3fa08bba096 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5a8d5d24661..40a2c9794ae 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -29,7 +29,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index e46dc503ea9..4865859b9ec 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -17,7 +17,7 @@ #include "mlir/Transforms/ViewFunctionGraph.h" #include "mlir/IR/FunctionGraphTraits.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 12593a8cae5..4c0c8fdc296 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -28,9 +28,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" -#include "mlir/Pass.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/FileUtilities.h" -#include "mlir/Support/PassNameParser.h" #include "mlir/TensorFlow/ControlFlowOps.h" #include "mlir/TensorFlow/Passes.h" #include "mlir/TensorFlowLite/Passes.h" -- cgit 
v1.2.3 From a1dad3a5d9903d3a71735612dd5b51a9fd45a5a7 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 19 Feb 2019 18:17:19 -0800 Subject: Extend/improve getSliceBounds() / complete TODO + update unionBoundingBox - compute slices precisely where the destination iteration depends on multiple source iterations (instead of over-approximating to the whole source loop extent) - update unionBoundingBox to deal with input with non-matching symbols - reenable disabled backend test case PiperOrigin-RevId: 234714069 --- mlir/include/mlir/IR/AffineExpr.h | 9 ++ mlir/include/mlir/IR/AffineStructures.h | 17 ++- mlir/lib/IR/AffineExpr.cpp | 9 +- mlir/lib/IR/AffineStructures.cpp | 257 +++++++++++++++++++++++++++----- mlir/lib/Transforms/LoopFusion.cpp | 26 ++-- mlir/test/Transforms/loop-fusion.mlir | 58 +++++++ 6 files changed, 322 insertions(+), 54 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index d7eab0f1312..a652ff6a22f 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -222,6 +222,15 @@ AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context); AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); +/// Constructs an affine expression from a flat ArrayRef. If there are local +/// identifiers (neither dimensional nor symbolic) that appear in the sum of +/// products expression, 'localExprs' is expected to have the AffineExpr +/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the +/// format [dims, symbols, locals, constant term]. +AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, + unsigned numSymbols, ArrayRef localExprs, + MLIRContext *context); + raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr); template bool AffineExpr::isa() const { diff --git a/mlir/include/mlir/IR/AffineStructures.h b/mlir/include/mlir/IR/AffineStructures.h index c90731cfd64..20ca7d71052 100644 --- a/mlir/include/mlir/IR/AffineStructures.h +++ b/mlir/include/mlir/IR/AffineStructures.h @@ -424,7 +424,8 @@ public: bool findId(const Value &id, unsigned *pos) const; // Add identifiers of the specified kind - specified positions are relative to - // the kind of identifier. 'id' is the Value corresponding to the + // the kind of identifier. The coefficient column corresponding to the added + // identifier is initialized to zero. 'id' is the Value corresponding to the // identifier that can optionally be provided. void addDimId(unsigned pos, Value *id = nullptr); void addSymbolId(unsigned pos, Value *id = nullptr); @@ -579,6 +580,17 @@ public: /// one; None otherwise. Optional getConstantUpperBound(unsigned pos) const; + /// Gets the lower and upper bound of the pos^th identifier treating + /// [dimStartPos, symbStartPos) as dimensions and [symStartPos, + /// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps + /// in the pair represent the max and min of potentially multiple affine + /// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed + /// AffineExpr's for all local identifiers in the system. + std::pair + getLowerAndUpperBound(unsigned pos, unsigned dimStartPos, + unsigned symStartPos, ArrayRef localExprs, + MLIRContext *context); + /// Returns true if the set can be trivially detected as being /// hyper-rectangular on the specified contiguous set of identifiers. 
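The flat coefficient layout documented for toAffineExpr(), [dims, symbols, locals, constant term], can be illustrated with a small pretty-printer. The sketch below ignores local identifiers (which the real function substitutes from 'localExprs') and only renders a row as text; it does not build a real AffineExpr, and all names are illustrative:

// Sketch: render a flat coefficient row [dims, symbols, constant] as text.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

std::string flatRowToString(const std::vector<int64_t> &eq, unsigned numDims,
                            unsigned numSymbols) {
  auto name = [&](unsigned i) {
    if (i < numDims)
      return "d" + std::to_string(i);
    return "s" + std::to_string(i - numDims);
  };
  std::string out;
  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
    if (eq[i] == 0)
      continue;
    if (!out.empty())
      out += eq[i] > 0 ? " + " : " - ";
    else if (eq[i] < 0)
      out += "-";
    int64_t a = std::abs(eq[i]);
    out += (a == 1 ? "" : std::to_string(a) + " * ") + name(i);
  }
  int64_t c = eq.back();
  if (c != 0 || out.empty())
    out += out.empty() ? std::to_string(c)
                       : (c > 0 ? " + " : " - ") + std::to_string(std::abs(c));
  return out;
}

int main() {
  // Two dims (d0, d1), one symbol (s0): row [1, -1, 2, 5].
  std::cout << flatRowToString({1, -1, 2, 5}, /*numDims=*/2, /*numSymbols=*/1)
            << "\n"; // prints: d0 - d1 + 2 * s0 + 5
}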
bool isHyperRectangular(unsigned pos, unsigned num) const; @@ -588,6 +600,9 @@ public: /// constraint. void removeTrivialRedundancy(); + /// A more expensive check to detect redundant inequalities. + void removeRedundantInequalities(); + // Removes all equalities and inequalities. void clearConstraints(); diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index c029ef3df52..5cfb1461590 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -301,11 +301,10 @@ raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) { /// products expression, 'localExprs' is expected to have the AffineExpr /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the /// format [dims, symbols, locals, constant term]. -// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here. -static AffineExpr toAffineExpr(ArrayRef eq, unsigned numDims, - unsigned numSymbols, - ArrayRef localExprs, - MLIRContext *context) { +AffineExpr mlir::toAffineExpr(ArrayRef eq, unsigned numDims, + unsigned numSymbols, + ArrayRef localExprs, + MLIRContext *context) { // Assert expected numLocals = eq.size() - numDims - numSymbols - 1 assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() && "unexpected number of local expressions"); diff --git a/mlir/lib/IR/AffineStructures.cpp b/mlir/lib/IR/AffineStructures.cpp index d043e78f059..5114f56bcfc 100644 --- a/mlir/lib/IR/AffineStructures.cpp +++ b/mlir/lib/IR/AffineStructures.cpp @@ -809,9 +809,6 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, if (posStart >= posLimit) return 0; - LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", " - << posLimit << ")\n"); - GCDTightenInequalities(); unsigned pivotCol = 0; @@ -909,25 +906,36 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, return false; } +// Gather lower and upper bounds for the pos^th identifier. +static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst, + unsigned pos, + SmallVectorImpl *lbIndices, + SmallVectorImpl *ubIndices) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices->push_back(r); + } else if (cst.atIneq(r, pos) <= -1) { + // Upper bound. + ubIndices->push_back(r); + } + } +} + // Check if the pos^th identifier can be expressed as a floordiv of an affine // function of other identifiers (where the divisor is a positive constant). // For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, SmallVectorImpl *memo, MLIRContext *context) { assert(pos < cst.getNumIds() && "invalid position"); - SmallVector lbIndices, ubIndices; - // Gather all lower bounds and upper bound constraints of this identifier. - // Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint - // is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { - if (cst.atIneq(r, pos) >= 1) - // Lower bound. - lbIndices.push_back(r); - else if (cst.atIneq(r, pos) <= -1) - // Upper bound. 
- ubIndices.push_back(r); - } + SmallVector lbIndices, ubIndices; + getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices); // Check if any lower bound, upper bound pair is of the form: // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' @@ -993,6 +1001,107 @@ bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, return false; } +// Fills an inequality row with the value 'val'. +static inline void fillInequality(FlatAffineConstraints *cst, unsigned r, + int64_t val) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = val; + } +} + +// Negates an inequality. +static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = -cst->atIneq(r, c); + } +} + +// A more complex check to eliminate redundant inequalities. +void FlatAffineConstraints::removeRedundantInequalities() { + SmallVector redun(getNumInequalities(), false); + // To check if an inequality is redundant, we replace the inequality by its + // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting + // system is empty. If it is, the inequality is redundant. + FlatAffineConstraints tmpCst(*this); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + // Change the inequality to its complement. + negateInequality(&tmpCst, r); + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; + if (tmpCst.isEmpty()) { + redun[r] = true; + // Zero fill the redundant inequality. + fillInequality(this, r, /*val=*/0); + fillInequality(&tmpCst, r, /*val=*/0); + } else { + // Reverse the change (to avoid recreating tmpCst each time). + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; + negateInequality(&tmpCst, r); + } + } + + // Scan to get rid of all rows marked redundant, in-place. + auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redun[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); +} + +std::pair FlatAffineConstraints::getLowerAndUpperBound( + unsigned pos, unsigned dimStartPos, unsigned symStartPos, + ArrayRef localExprs, MLIRContext *context) { + assert(pos < dimStartPos && "invalid dim start pos"); + assert(symStartPos >= dimStartPos && "invalid sym start pos"); + assert(getNumLocalIds() == localExprs.size() && + "incorrect local exprs count"); + + SmallVector lbIndices, ubIndices; + getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices); + + SmallVector lb, ub; + SmallVector exprs; + unsigned dimCount = symStartPos - dimStartPos; + unsigned symCount = getNumDimAndSymbolIds() - symStartPos; + exprs.reserve(lbIndices.size()); + // Lower bound expressions. + for (auto idx : lbIndices) { + auto ineq = getInequality(idx); + // Extract the lower bound (in terms of other coeff's + const), i.e., if + // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j + // - 1. + lb.assign(ineq.begin() + dimStartPos, ineq.end()); + std::transform(lb.begin(), lb.end(), lb.begin(), std::negate()); + auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); + exprs.push_back(expr); + } + auto lbMap = exprs.empty() ? AffineMap() + : AffineMap::get(dimCount, symCount, exprs, {}); + + exprs.clear(); + exprs.reserve(ubIndices.size()); + // Upper bound expressions. 
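// A minimal usage sketch of removeRedundantInequalities() defined above, with
// made-up constraints: in the 1-d system {i >= 0, i <= 9, i <= 15}, the last
// inequality is redundant since its complement i >= 16 contradicts i <= 9.
FlatAffineConstraints cst(/*numDims=*/1);
cst.addInequality({1, 0});   //  i >= 0
cst.addInequality({-1, 9});  // -i + 9 >= 0, i.e., i <= 9
cst.addInequality({-1, 15}); // -i + 15 >= 0, i.e., i <= 15 (redundant)
cst.removeRedundantInequalities();
// cst is left with the two non-redundant inequalities: i >= 0 and i <= 9.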
+  for (auto idx : ubIndices) { + auto ineq = getInequality(idx); + // Extract the upper bound (in terms of other coeff's + const). + ub.assign(ineq.begin() + dimStartPos, ineq.end()); + auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); + // Upper bound is exclusive. + exprs.push_back(expr + 1); + } + auto ubMap = exprs.empty() ? AffineMap() + : AffineMap::get(dimCount, symCount, exprs, {}); + + return {lbMap, ubMap}; +} + /// Computes the lower and upper bounds of the first 'num' dimensional /// identifiers as affine maps of the remaining identifiers (dimensional and /// symbolic identifiers). Local identifiers are themselves explicitly computed @@ -1097,6 +1206,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, // Set the lower and upper bound maps for all the identifiers that were // computed as affine expressions of the rest as the "detected expr" and // "detected expr + 1" respectively; set the undetected ones to Null(). + Optional tmpClone; for (unsigned pos = 0; pos < num; pos++) { unsigned numMapDims = getNumDimIds() - num; unsigned numMapSymbols = getNumSymbolIds(); @@ -1108,24 +1218,49 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); } else { - // TODO(andydavis, bondhugula) Add support for computing slice bounds - // symbolic in the identifies [num, numIds). - auto lbConst = getConstantLowerBound(pos); - auto ubConst = getConstantUpperBound(pos); - if (lbConst.hasValue() && ubConst.hasValue()) { - (*lbMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(lbConst.getValue(), context), {}); - (*ubMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(ubConst.getValue() + 1, context), {}); - } else { - (*lbMaps)[pos] = AffineMap(); - (*ubMaps)[pos] = AffineMap(); + // TODO(bondhugula): Whenever there are local identifiers in the + // dependence constraints, we'll conservatively over-approximate, since we + // don't always explicitly compute them above (in the while loop). + if (getNumLocalIds() == 0) { + // Work on a copy so that we don't update this constraint system. + if (!tmpClone) { + tmpClone.emplace(FlatAffineConstraints(*this)); + // Removing redundant inequalities is necessary so that we don't get + // redundant loop bounds. + tmpClone->removeRedundantInequalities(); + } + std::tie((*lbMaps)[pos], (*ubMaps)[pos]) = + tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {}, + context); + } + + // If the above fails, we'll just use the constant lower bound and the + // constant upper bound (if they exist) as the slice bounds.
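// Worked example (editorial illustration): if the only constraints relating a
// source iv i0 to a destination iv d0 are i0 - 16*d0 >= 0 and
// -i0 + 16*d0 + 15 >= 0, getLowerAndUpperBound() returns the pair
// lb: (d0) -> (d0 * 16), ub: (d0) -> (d0 * 16 + 16), the upper bound being
// exclusive; this is the [[MAP_LB]]/[[MAP_UB]] pair checked by the slice_tile
// test added below.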
+  if (!(*lbMaps)[pos]) { + LLVM_DEBUG(llvm::dbgs() + << "WARNING: Potentially over-approximating slice lb\n"); + auto lbConst = getConstantLowerBound(pos); + if (lbConst.hasValue()) { + (*lbMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(lbConst.getValue(), context), {}); + } + } + if (!(*ubMaps)[pos]) { + LLVM_DEBUG(llvm::dbgs() + << "WARNING: Potentially over-approximating slice ub\n"); + auto ubConst = getConstantUpperBound(pos); + if (ubConst.hasValue()) { + (*ubMaps)[pos] = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(ubConst.getValue() + 1, context), {}); + } } } LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG(expr.dump();); + LLVM_DEBUG((*lbMaps)[pos].dump();); + LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: "); + LLVM_DEBUG((*ubMaps)[pos].dump();); } } @@ -1454,6 +1589,7 @@ Optional FlatAffineConstraints::getConstantBoundOnDimSize( break; } if (c < getNumDimIds()) + // Not a pure symbolic bound. continue; if (atIneq(r, pos) >= 1) // Lower bound. @@ -2037,14 +2173,53 @@ static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { } }; // namespace +// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of +// the symbols. Assumes the common symbols appear in the same order (the +// current/common use case). +static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) { + SmallVector symbolsA, symbolsB; + A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA); + B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB); + + // Both symbol lists typically have only a handful of symbols each (3-4); + // a merge quadratic in complexity with a linear search is fine. + for (auto *symbolB : symbolsB) { + if (!llvm::is_contained(symbolsA, symbolB)) { + A->addSymbolId(symbolsA.size(), symbolB); + symbolsA.push_back(symbolB); + } + } + // symbolsA now holds the merged symbol list. + symbolsB.reserve(symbolsA.size()); + unsigned iB = 0; + for (auto *symbolA : symbolsA) { + assert(iB < symbolsB.size()); + if (symbolA != symbolsB[iB]) { + symbolsB.insert(symbolsB.begin() + iB, symbolA); + B->addSymbolId(iB, symbolA); + } + ++iB; + } +} + // Compute the bounding box with respect to 'other' by finding the min of the // lower bounds and the max of the upper bounds along each of the dimensions. bool FlatAffineConstraints::unionBoundingBox( - const FlatAffineConstraints &other) { - assert(other.getNumDimIds() == numDims); - assert(other.getNumSymbolIds() == getNumSymbolIds()); - assert(other.getNumLocalIds() == 0); - assert(getNumLocalIds() == 0); + const FlatAffineConstraints &otherArg) { + assert(otherArg.getNumDimIds() == numDims && "dims mismatch"); + + Optional copy; + if (!otherArg.getIds().equals(getIds())) { + copy.emplace(FlatAffineConstraints(otherArg)); + mergeSymbols(this, &copy.getValue()); + assert(getIds().equals(copy->getIds()) && "merge failed"); + } + + const auto &other = copy ? *copy : otherArg; + + assert(other.getNumLocalIds() == 0 && "local ids not eliminated"); + assert(getNumLocalIds() == 0 && "local ids not eliminated"); + std::vector> boundingLbs; std::vector> boundingUbs; boundingLbs.reserve(2 * getNumDimIds()); @@ -2082,7 +2257,11 @@ bool FlatAffineConstraints::unionBoundingBox( minLb = otherLb; } else { // Uncomparable.
- return false; + auto constLb = getConstantLowerBound(d); + auto constOtherLb = other.getConstantLowerBound(d); + if (!constLb.hasValue() || !constOtherLb.hasValue()) + return false; + minLb = std::min(constLb.getValue(), constOtherLb.getValue()); } // Do the same for ub's but max of upper bounds. @@ -2098,7 +2277,11 @@ bool FlatAffineConstraints::unionBoundingBox( maxUb = otherUb; } else { // Uncomparable. - return false; + auto constUb = getConstantUpperBound(d); + auto constOtherUb = other.getConstantUpperBound(d); + if (!constUb.hasValue() || !constOtherUb.hasValue()) + return false; + maxUb = std::max(constUb.getValue(), constOtherUb.getValue()); } SmallVector newLb(getNumCols(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 63a681a4fdc..524b34bb86f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -58,8 +58,8 @@ static llvm::cl::opt /// A threshold in percent of additional computation allowed when fusing. static llvm::cl::opt clFusionAddlComputeTolerance( "fusion-compute-tolerance", llvm::cl::Hidden, - llvm::cl::desc("Fractional increase in additional" - " computation tolerated while fusing"), + llvm::cl::desc("Fractional increase in additional " + "computation tolerated while fusing"), llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clFusionFastMemorySpace( @@ -1260,12 +1260,9 @@ static bool isFusionProfitable(Instruction *srcOpInst, unsigned *dstLoopDepth) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between:\n"; - llvm::dbgs() << " "; - srcOpInst->dump(); - llvm::dbgs() << " and \n"; + llvm::dbgs() << " " << *srcOpInst << " and \n"; for (auto dstOpInst : dstLoadOpInsts) { - llvm::dbgs() << " "; - dstOpInst->dump(); + llvm::dbgs() << " " << *dstOpInst << "\n"; }; }); @@ -1423,7 +1420,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, << 100.0 * additionalComputeFraction << "%\n" << " storage reduction factor: " << storageReduction << "x\n" << " fused nest cost: " << fusedLoopNestComputeCost << "\n" - << " slice iteration count: " << sliceIterationCount << "\n"; + << " slice iteration count: " << sliceIterationCount << "\n" + << " src write region size: " << srcWriteRegionSizeBytes << "\n" + << " slice write region size: " << sliceWriteRegionSizeBytes + << "\n"; llvm::dbgs() << msg.str(); }); @@ -1450,9 +1450,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, // -maximal-fusion is set, fuse nevertheless. if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) { - LLVM_DEBUG(llvm::dbgs() - << "All fusion choices involve more than the threshold amount of" - "redundant computation; NOT fusing.\n"); + LLVM_DEBUG( + llvm::dbgs() + << "All fusion choices involve more than the threshold amount of " + "redundant computation; NOT fusing.\n"); return false; } @@ -1694,6 +1695,9 @@ public: auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { + LLVM_DEBUG(llvm::dbgs() + << "\tslice loop nest:\n" + << *sliceLoopNest->getInstruction() << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. 
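// Worked example for unionBoundingBox above (editorial illustration): taking
// the union of a system with 0 <= d0 <= 9 and one with 5 <= d0 <= 19 keeps
// min(0, 5) = 0 as the lower bound and max(9, 19) = 19 as the upper bound;
// with this change, bounds that cannot be compared symbolically fall back to
// their constant bounds instead of making the whole union fail.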
auto dstAffineForOp = dstNode->inst->cast(); if (insertPointInst != dstAffineForOp->getInstruction()) { diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index c671adc6cf9..2458049f22c 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1810,3 +1810,61 @@ func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { // CHECK-NEXT: } // CHECK-NEXT: return %arg0 : memref<10xf32> } + +// ----- + +// The fused slice has 16 iterations from along %i0. + +// CHECK-DAG: [[MAP_LB:#map[0-9]+]] = (d0) -> (d0 * 16) +// CHECK-DAG: [[MAP_UB:#map[0-9]+]] = (d0) -> (d0 * 16 + 16) + +#map = (d0, d1) -> (d0 * 16 + d1) + +// CHECK-LABEL: slice_tile +func @slice_tile(%arg1: memref<32x8xf32>, %arg2: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> { + for %i0 = 0 to 32 { + for %i1 = 0 to 8 { + store %0, %arg2[%i0, %i1] : memref<32x8xf32> + } + } + for %i = 0 to 2 { + for %j = 0 to 8 { + for %k = 0 to 8 { + for %kk = 0 to 16 { + %1 = affine.apply #map(%k, %kk) + %2 = load %arg1[%1, %j] : memref<32x8xf32> + %3 = "foo"(%2) : (f32) -> f32 + } + for %ii = 0 to 16 { + %6 = affine.apply #map(%i, %ii) + %7 = load %arg2[%6, %j] : memref<32x8xf32> + %8 = addf %7, %7 : f32 + store %8, %arg2[%6, %j] : memref<32x8xf32> + } + } + } + } + return %arg2 : memref<32x8xf32> +} +// CHECK: for %i0 = 0 to 2 { +// CHECK-NEXT: for %i1 = 0 to 8 { +// CHECK-NEXT: for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) { +// CHECK-NEXT: store %arg2, %arg1[%i2, %i1] : memref<32x8xf32> +// CHECK-NEXT: } +// CHECK-NEXT: for %i3 = 0 to 8 { +// CHECK-NEXT: for %i4 = 0 to 16 { +// CHECK-NEXT: %0 = affine.apply #map{{[0-9]+}}(%i3, %i4) +// CHECK-NEXT: %1 = load %arg0[%0, %i1] : memref<32x8xf32> +// CHECK-NEXT: %2 = "foo"(%1) : (f32) -> f32 +// CHECK-NEXT: } +// CHECK-NEXT: for %i5 = 0 to 16 { +// CHECK-NEXT: %3 = affine.apply #map{{[0-9]+}}(%i0, %i5) +// CHECK-NEXT: %4 = load %arg1[%3, %i1] : memref<32x8xf32> +// CHECK-NEXT: %5 = addf %4, %4 : f32 +// CHECK-NEXT: store %5, %arg1[%3, %i1] : memref<32x8xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return %arg1 : memref<32x8xf32> +// CHECK-NEXT:} -- cgit v1.2.3 From 3e656599f1c1ab2d4b810ba054d59c99d5628b35 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 21 Feb 2019 18:01:09 -0800 Subject: Define a PassID class to use when defining a pass. This allows for the type used for the ID field to be self documenting. It also allows for the compiler to know the set alignment of the ID object, which is useful for storing pointer identifiers within llvm data structures. 
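Under the new scheme, a pass exposes a constexpr static PassID member and hands its address to the base-class constructor, replacing the old static char member and its out-of-line definition. A minimal sketch with a hypothetical pass name and registration string (the diffs below update each in-tree pass in the same way):

namespace {
struct MyExamplePass : public FunctionPass {
  MyExamplePass() : FunctionPass(&MyExamplePass::passID) {}
  PassResult runOnFunction(Function *fn) override { return success(); }
  constexpr static PassID passID = {};
};
} // end anonymous namespace

static PassRegistration<MyExamplePass> pass("my-example-pass",
                                            "An illustrative example pass");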
PiperOrigin-RevId: 235107957 --- mlir/g3doc/QuickstartRewrites.md | 2 +- mlir/include/mlir/Pass/Pass.h | 16 ++++++++++------ mlir/include/mlir/Pass/PassRegistry.h | 7 ++++--- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- mlir/include/mlir/Transforms/MLPatternLoweringPass.h | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +--- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 4 +--- mlir/lib/Analysis/OpStats.cpp | 4 +--- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 4 +--- mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 4 +--- mlir/lib/Pass/PassRegistry.cpp | 7 ++++--- mlir/lib/Transforms/CSE.cpp | 4 +--- mlir/lib/Transforms/Canonicalizer.cpp | 4 +--- mlir/lib/Transforms/ConstantFold.cpp | 4 +--- mlir/lib/Transforms/DmaGeneration.cpp | 4 +--- mlir/lib/Transforms/LoopFusion.cpp | 4 +--- mlir/lib/Transforms/LoopTiling.cpp | 4 +--- mlir/lib/Transforms/LoopUnroll.cpp | 4 +--- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 +--- mlir/lib/Transforms/LowerAffine.cpp | 4 +--- mlir/lib/Transforms/LowerVectorTransfers.cpp | 4 +--- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 4 +--- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +--- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 4 +--- mlir/lib/Transforms/StripDebugInfo.cpp | 4 +--- mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 4 +--- mlir/lib/Transforms/Vectorize.cpp | 4 +--- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +--- 29 files changed, 44 insertions(+), 84 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/QuickstartRewrites.md b/mlir/g3doc/QuickstartRewrites.md index a7548c00392..2ccbfa174ab 100644 --- a/mlir/g3doc/QuickstartRewrites.md +++ b/mlir/g3doc/QuickstartRewrites.md @@ -222,7 +222,7 @@ struct TestPass : public FunctionPass { TestPass() : FunctionPass(&TestPass::passID) {} PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index c489dafb20b..6edcc3bb8a8 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -24,6 +24,10 @@ namespace mlir { class Function; class Module; +/// A special type used by transformation passes to provide an address that can +/// act as a unique identifier during pass registration. +struct alignas(8) PassID {}; + // Values that can be used by to signal success/failure. This can be implicitly // converted to/from boolean values, with false representing success and true // failure. @@ -35,18 +39,18 @@ struct LLVM_NODISCARD PassResult { class Pass { public: - explicit Pass(const void *passID) : passID(passID) {} + explicit Pass(const PassID *passID) : passID(passID) {} virtual ~Pass() = default; virtual PassResult runOnModule(Module *m) = 0; /// Returns the unique identifier that corresponds to this pass. - const void *getPassID() const { return passID; } + const PassID *getPassID() const { return passID; } static PassResult success() { return PassResult::Success; } static PassResult failure() { return PassResult::Failure; } /// Returns the pass info for the specified pass class or null if unknown. - static const PassInfo *lookupPassInfo(const void *passID); + static const PassInfo *lookupPassInfo(const PassID *passID); /// Returns the pass info for this pass. const PassInfo *lookupPassInfo() const { return lookupPassInfo(passID); } @@ -57,12 +61,12 @@ private: virtual void anchor(); /// Unique identifier for pass. 
- const void *const passID; + const PassID *const passID; }; class ModulePass : public Pass { public: - explicit ModulePass(const void *passID) : Pass(passID) {} + explicit ModulePass(const PassID *passID) : Pass(passID) {} virtual PassResult runOnModule(Module *m) override = 0; @@ -79,7 +83,7 @@ private: /// module. class FunctionPass : public Pass { public: - explicit FunctionPass(const void *passID) : Pass(passID) {} + explicit FunctionPass(const PassID *passID) : Pass(passID) {} /// Implement this function to be run on every function in the module. virtual PassResult runOnFunction(Function *fn) = 0; diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index 8f324bf0f69..c8d85f1dd69 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -31,6 +31,7 @@ namespace mlir { class Pass; +class PassID; using PassAllocatorFunction = std::function; @@ -40,7 +41,7 @@ class PassInfo { public: /// PassInfo constructor should not be invoked directly, instead use /// PassRegistration or registerPass. - PassInfo(StringRef arg, StringRef description, const void *passID, + PassInfo(StringRef arg, StringRef description, const PassID *passID, PassAllocatorFunction allocator) : arg(arg), description(description), allocator(allocator), passID(passID) {} @@ -70,12 +71,12 @@ private: PassAllocatorFunction allocator; // Unique identifier for pass. - const void *passID; + const PassID *passID; }; /// Register a specific dialect creation function with the system, typically /// used through the PassRegistration template. -void registerPass(StringRef arg, StringRef description, const void *passID, +void registerPass(StringRef arg, StringRef description, const PassID *passID, const PassAllocatorFunction &function); /// PassRegistration provides a global initializer that registers a Pass diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index b547e21c28a..b1e87e7223e 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -151,7 +151,7 @@ class DialectConversion : public ModulePass { public: /// Construct a pass given its unique identifier. - DialectConversion(const void *passID) : ModulePass(passID) {} + DialectConversion(const PassID *passID) : ModulePass(passID) {} /// Run the pass on the module. 
PassResult runOnModule(Module *m) override; diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 1abd85a1d2b..15e4f215c61 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -101,7 +101,7 @@ using OwningMLLoweringPatternList = template class MLPatternLoweringPass : public FunctionPass { public: - explicit MLPatternLoweringPass(void *ID) : FunctionPass(ID) {} + explicit MLPatternLoweringPass(const PassID *ID) : FunctionPass(ID) {} virtual std::unique_ptr makeFuncWiseState(Function *f) const { diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 9f6efff3187..b86651793f9 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -42,13 +42,11 @@ struct MemRefBoundCheck : public FunctionPass { PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char MemRefBoundCheck::passID = 0; - FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 43bc0c98916..0b5c9b997a5 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -44,13 +44,11 @@ struct MemRefDependenceCheck : public FunctionPass { PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char MemRefDependenceCheck::passID = 0; - FunctionPass *mlir::createMemRefDependenceCheckPass() { return new MemRefDependenceCheck(); } diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 6ae7ec59c50..c1fcacac15a 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -36,7 +36,7 @@ struct PrintOpStatsPass : public ModulePass { // Print summary of op stats. 
void printSummary(); - static char passID; + constexpr static PassID passID = {}; private: llvm::StringMap opCount; @@ -44,8 +44,6 @@ private: }; } // namespace -char PrintOpStatsPass::passID = 0; - PassResult PrintOpStatsPass::runOnModule(Module *m) { opCount.clear(); diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 41b2031bdbc..67d4fb38080 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -35,12 +35,10 @@ struct LowerEDSCTestPass : public FunctionPass { LowerEDSCTestPass() : FunctionPass(&LowerEDSCTestPass::passID) {} PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LowerEDSCTestPass::passID = 0; - #include "mlir/EDSC/reference-impl.inc" PassResult LowerEDSCTestPass::runOnFunction(Function *f) { diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index e1abeffe604..ba3619e38b6 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -1030,7 +1030,7 @@ class LLVMLowering : public DialectConversion { public: LLVMLowering() : DialectConversion(&passID) {} - const static char passID = '\0'; + constexpr static PassID passID = {}; protected: // Create a set of converters that live in the pass object by passing them a @@ -1078,8 +1078,6 @@ private: llvm::Module *module; }; -const char LLVMLowering::passID; - ModulePass *mlir::createConvertToLLVMIRPass() { return new LLVMLowering; } static PassRegistration diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index c26da1f4099..e90fb2217a2 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -23,10 +23,11 @@ using namespace mlir; /// Static mapping of all of the registered passes. -static llvm::ManagedStatic> passRegistry; +static llvm::ManagedStatic> + passRegistry; void mlir::registerPass(StringRef arg, StringRef description, - const void *passID, + const PassID *passID, const PassAllocatorFunction &function) { bool inserted = passRegistry ->insert(std::make_pair( @@ -37,7 +38,7 @@ void mlir::registerPass(StringRef arg, StringRef description, } /// Returns the pass info for the specified pass class or null if unknown. -const PassInfo *mlir::Pass::lookupPassInfo(const void *passID) { +const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) { auto it = passRegistry->find(passID); if (it == passRegistry->end()) return nullptr; diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index cd205fe773b..e83be30655a 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -83,7 +83,7 @@ namespace { struct CSE : public FunctionPass { CSE() : FunctionPass(&CSE::passID) {} - static char passID; + constexpr static PassID passID = {}; /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< @@ -125,8 +125,6 @@ private: }; } // end anonymous namespace -char CSE::passID = 0; - /// Attempt to eliminate a redundant operation. 
bool CSE::simplifyOperation(Instruction *op) { // TODO(riverriddle) We currently only eliminate non side-effecting diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 0388744d4d2..ac77e201acf 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -37,12 +37,10 @@ struct Canonicalizer : public FunctionPass { Canonicalizer() : FunctionPass(&Canonicalizer::passID) {} PassResult runOnFunction(Function *fn) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char Canonicalizer::passID = 0; - PassResult Canonicalizer::runOnFunction(Function *fn) { auto *context = fn->getContext(); OwningRewritePatternList patterns; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 7634d9ec16a..4817baaa23e 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -37,12 +37,10 @@ struct ConstantFold : public FunctionPass { void foldInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char ConstantFold::passID = 0; - /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5df938634e7..5083bc4d586 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -116,13 +116,11 @@ struct DmaGeneration : public FunctionPass { // Constant zero index to avoid too many duplicates. Value *zeroIndex = nullptr; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char DmaGeneration::passID = 0; - /// Generates DMAs for memref's living in 'slowMemorySpace' into newly created /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 524b34bb86f..303efc69ceb 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -88,7 +88,7 @@ struct LoopFusion : public FunctionPass { LoopFusion() : FunctionPass(&LoopFusion::passID) {} PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; // Any local buffers smaller than this size will be created in // `fastMemorySpace` if provided. @@ -102,8 +102,6 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -char LoopFusion::passID = 0; - FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } namespace { diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 76e8e9254c9..2253d1d354a 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -52,13 +52,11 @@ struct LoopTiling : public FunctionPass { PassResult runOnFunction(Function *f) override; constexpr static unsigned kDefaultTileSize = 4; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LoopTiling::passID = 0; - // Tile size to use for all loops (overridden by -tile-sizes if provided). 
static llvm::cl::opt clTileSize("tile-size", llvm::cl::init(LoopTiling::kDefaultTileSize), diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index b452b4f76e2..3b4a0517f0d 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -86,12 +86,10 @@ struct LoopUnroll : public FunctionPass { static const unsigned kDefaultUnrollFactor = 4; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LoopUnroll::passID = 0; - PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. struct InnermostLoopGatherer { diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 76668c7f0b5..87e2770aa41 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -82,12 +82,10 @@ struct LoopUnrollAndJam : public FunctionPass { PassResult runOnFunction(Function *f) override; bool runOnAffineForOp(OpPointer forOp); - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LoopUnrollAndJam::passID = 0; - FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return new LoopUnrollAndJam( unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 8b62601ab41..83620516994 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -251,12 +251,10 @@ public: bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LowerAffinePass::passID = 0; - // Given a range of values, emit the code that reduces them with "min" or "max" // depending on the provided comparison predicate. The predicate defines which // comparison to perform, "lt" for "min", "gt" for "max" and is used for the diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index bd43a637665..ac8f7e064f5 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -434,13 +434,11 @@ struct LowerVectorTransfersPass // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. edsc::ScopedEDSCContext raiiContext; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char LowerVectorTransfersPass::passID = 0; - FunctionPass *mlir::createLowerVectorTransfersPass() { return new LowerVectorTransfersPass(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 3cd33c20fae..6177ca1233b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -201,13 +201,11 @@ struct MaterializeVectorsPass : public FunctionPass { PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char MaterializeVectorsPass::passID = 0; - /// Given a shape with sizes greater than 0 along all dimensions, /// returns the distance, in number of elements, between a slice in a dimension /// and the next slice in the same dimension. 
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 68bce854222..0ba06fecae0 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -84,13 +84,11 @@ struct MemRefDataFlowOpt : public FunctionPass { DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char MemRefDataFlowOpt::passID = 0; - /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. FunctionPass *mlir::createMemRefDataFlowOptPass() { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index b6bfde58494..f41f56efd8f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -45,13 +45,11 @@ struct PipelineDataTransfer : public FunctionPass { std::vector> forOps; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char PipelineDataTransfer::passID = 0; - /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. FunctionPass *mlir::createPipelineDataTransferPass() { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 20961180b83..4ddfd9f06fb 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -42,13 +42,11 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char SimplifyAffineStructures::passID = 0; - FunctionPass *mlir::createSimplifyAffineStructuresPass() { return new SimplifyAffineStructures(); } diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 2eb4a37445c..fc2b0eb0a95 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -28,12 +28,10 @@ struct StripDebugInfo : public FunctionPass { PassResult runOnFunction(Function *f) override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char StripDebugInfo::passID = 0; - PassResult StripDebugInfo::runOnFunction(Function *f) { UnknownLoc unknownLoc = UnknownLoc::get(f->getContext()); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 3fa08bba096..2363c5638ee 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -96,13 +96,11 @@ struct VectorizerTestPass : public FunctionPass { void testComposeMaps(Function *f); void testNormalizeMaps(Function *f); - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char VectorizerTestPass::passID = 0; - void VectorizerTestPass::testVectorShapeRatio(Function *f) { using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 40a2c9794ae..e009696fa1f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -656,13 +656,11 @@ struct Vectorize : public FunctionPass { PassResult runOnFunction(Function *f) 
override; - static char passID; + constexpr static PassID passID = {}; }; } // end anonymous namespace -char Vectorize::passID = 0; - /////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. ////// namespace { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 4865859b9ec..14e21770e25 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -83,7 +83,7 @@ struct PrintCFGPass : public FunctionPass { return success(); } - static char passID; + constexpr static PassID passID = {}; private: llvm::raw_ostream &os; @@ -92,8 +92,6 @@ private: }; } // namespace -char PrintCFGPass::passID = 0; - FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title) { -- cgit v1.2.3 From 8564b274dbd7cf8a295a8ec16de67b330dffc694 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 22 Feb 2019 07:48:59 -0800 Subject: Internal change PiperOrigin-RevId: 235191129 --- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 303efc69ceb..5dbefb875da 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1511,7 +1511,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, LLVM_DEBUG({ std::stringstream msg; msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " - << setprecision(2) << additionalComputeFraction + << std::setprecision(2) << additionalComputeFraction << "% redundant computation and a "; msg << (storageReduction.hasValue() ? 
std::to_string(storageReduction.getValue()) diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index e009696fa1f..5722b9d17da 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -979,7 +979,7 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, OperationState state( b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, {vectorType}, - {make_pair(Identifier::get("value", b.getContext()), attr)}); + {std::make_pair(Identifier::get("value", b.getContext()), attr)}); return b.createOperation(state)->getResult(0); } -- cgit v1.2.3 From dfe07b7bf6077040cbb2b4392cbd81dc443570b2 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 22 Feb 2019 16:51:08 -0800 Subject: Refactor AffineExprFlattener and move FlatAffineConstraints out of IR into Analysis - NFC - refactor AffineExprFlattener (-> SimpleAffineExprFlattener) so that it doesn't depend on FlatAffineConstraints, and so that FlatAffineConstraints could be moved out of IR/; the simplification that the IR needs for AffineExpr's doesn't depend on FlatAffineConstraints - have AffineExprFlattener derive from SimpleAffineExprFlattener to use for all Analysis/Transforms purposes; override addLocalFloorDivId in the derived class - turn addAffineForOpDomain into a method on FlatAffineConstraints - turn AffineForOp::getAsValueMap into an AffineValueMap ctor PiperOrigin-RevId: 235283610 --- mlir/include/mlir/AffineOps/AffineOps.h | 15 - mlir/include/mlir/Analysis/AffineStructures.h | 761 +++++++ mlir/include/mlir/Analysis/Utils.h | 4 +- mlir/include/mlir/IR/AffineExpr.h | 21 +- mlir/include/mlir/IR/AffineExprVisitor.h | 136 ++ mlir/include/mlir/IR/AffineStructures.h | 708 ------ mlir/lib/AffineOps/AffineOps.cpp | 113 - mlir/lib/Analysis/AffineAnalysis.cpp | 4 +- mlir/lib/Analysis/AffineStructures.cpp | 2530 ++++++++++++++++++++++ mlir/lib/Analysis/LoopAnalysis.cpp | 5 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/IR/AffineExpr.cpp | 484 ++--- mlir/lib/IR/AffineStructures.cpp | 2312 -------------------- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- 21 files changed, 3636 insertions(+), 3477 deletions(-) create mode 100644 mlir/include/mlir/Analysis/AffineStructures.h delete mode 100644 mlir/include/mlir/IR/AffineStructures.h create mode 100644 mlir/lib/Analysis/AffineStructures.cpp delete mode 100644 mlir/lib/IR/AffineStructures.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 0ae43426db0..f4d03380123 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -64,9 +64,6 @@ public: return getAttrOfType("map").getValue(); } - /// Returns an AffineValueMap representing this affine apply. - AffineValueMap getAsAffineValueMap(); - /// Returns true if the result of this operation can be used as dimension id. 
bool isValidDim() const; @@ -253,18 +250,6 @@ ConstOpPointer getForInductionVarOwner(const Value *val); void extractForInductionVars(ArrayRef> forInsts, SmallVectorImpl *ivs); -/// Adds constraints (lower and upper bounds) for the specified 'for' -/// instruction's Value using IR information stored in its bound maps. The -/// right identifier is first looked up using forOp's Value. Returns -/// false for the yet unimplemented/unsupported cases, and true if the -/// information is successfully added. Asserts if the Value corresponding to -/// the 'for' instruction isn't found in the constraint system. Any new -/// identifiers that are found in the bound operands of the 'for' instruction -/// are added as trailing identifiers (either dimensional or symbolic -/// depending on whether the operand is a valid ML Function symbol). -// TODO(bondhugula): add support for non-unit strides. -bool addAffineForOpDomain(ConstOpPointer forOp, - FlatAffineConstraints *constraints); /// AffineBound represents a lower or upper bound in the for instruction. /// This class does not own the underlying operands. Instead, it refers diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h new file mode 100644 index 00000000000..be9dcd40f67 --- /dev/null +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -0,0 +1,761 @@ +//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Structures for affine/polyhedral analysis of ML functions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_AFFINE_STRUCTURES_H +#define MLIR_ANALYSIS_AFFINE_STRUCTURES_H + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +class AffineApplyOp; +class AffineBound; +class AffineCondition; +class AffineMap; +class AffineForOp; +class IntegerSet; +class MLIRContext; +class Value; +class HyperRectangularSet; +class MemRefType; + +/// A mutable affine map. Its affine expressions are however unique. +struct MutableAffineMap { +public: + MutableAffineMap() {} + MutableAffineMap(AffineMap map); + + ArrayRef getResults() const { return results; } + AffineExpr getResult(unsigned idx) const { return results[idx]; } + void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } + unsigned getNumResults() const { return results.size(); } + unsigned getNumDims() const { return numDims; } + void setNumDims(unsigned d) { numDims = d; } + unsigned getNumSymbols() const { return numSymbols; } + void setNumSymbols(unsigned d) { numSymbols = d; } + MLIRContext *getContext() const { return context; } + + /// Returns true if the idx'th result expression is a multiple of factor. + bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Resets this MutableAffineMap with 'map'. 
+ void reset(AffineMap map); + + /// Simplify the (result) expressions in this map using analysis (used by + //-simplify-affine-expr pass). + void simplify(); + /// Get the AffineMap corresponding to this MutableAffineMap. Note that an + /// AffineMap will be uniqued and stored in context, while a mutable one + /// isn't. + AffineMap getAffineMap() const; + +private: + // Same meaning as AffineMap's fields. + SmallVector results; + SmallVector rangeSizes; + unsigned numDims; + unsigned numSymbols; + /// A pointer to the IR's context to store all newly created + /// AffineExprStorage's. + MLIRContext *context; +}; + +/// A mutable integer set. Its affine expressions are however unique. +struct MutableIntegerSet { +public: + MutableIntegerSet(IntegerSet set, MLIRContext *context); + + /// Create a universal set (no constraints). + MutableIntegerSet(unsigned numDims, unsigned numSymbols, + MLIRContext *context); + + unsigned getNumDims() const { return numDims; } + unsigned getNumSymbols() const { return numSymbols; } + unsigned getNumConstraints() const { return constraints.size(); } + + void clear() { + constraints.clear(); + eqFlags.clear(); + } + +private: + unsigned numDims; + unsigned numSymbols; + + SmallVector constraints; + SmallVector eqFlags; + /// A pointer to the IR's context to store all newly created + /// AffineExprStorage's. + MLIRContext *context; +}; + +/// An AffineValueMap is an affine map plus its ML value operands and +/// results for analysis purposes. The structure is still a tree form that is +/// same as that of an affine map or an AffineApplyOp. However, its operands, +/// results, and its map can themselves change as a result of +/// substitutions, simplifications, and other analysis. +// An affine value map can readily be constructed from an AffineApplyOp, or an +// AffineBound of a AffineForOp. It can be further transformed, substituted +// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and +// destroyed during analysis. Only the AffineMap expressions that are pointed by +// them are unique'd. An affine value map, and the operations on it, maintain +// the invariant that operands are always positionally aligned with the +// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. +// TODO(bondhugula): Some of these classes could go into separate files. +class AffineValueMap { +public: + // Creates an empty AffineValueMap (users should call 'reset' to reset map + // and operands). + AffineValueMap() {} + AffineValueMap(AffineMap map); + AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + explicit AffineValueMap(OpPointer applyOp); + explicit AffineValueMap(AffineBound bound); + + ~AffineValueMap(); + + // Resets this AffineValueMap with 'map', 'operands', and 'results'. + void reset(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); + + /// Return true if the idx^th result can be proved to be a multiple of + /// 'factor', false otherwise. + inline bool isMultipleOf(unsigned idx, int64_t factor) const; + + /// Return true if the idx^th result depends on 'value', false otherwise. + bool isFunctionOf(unsigned idx, Value *value) const; + + /// Return true if the result at 'idx' is a constant, false + /// otherwise. + bool isConstant(unsigned idx) const; + + /// Return true if this is an identity map. 
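// A minimal sketch (editorial): with this refactor, an AffineValueMap is
// constructed directly from an affine apply op or a loop bound instead of
// going through AffineApplyOp::getAsAffineValueMap(). Assumes 'applyOp' is an
// OpPointer<AffineApplyOp> in scope.
AffineValueMap valueMap(applyOp);
if (valueMap.isMultipleOf(0, 4)) {
  // The first result is provably a multiple of 4.
}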
+  bool isIdentity() const; + + inline unsigned getNumOperands() const { return operands.size(); } + inline unsigned getNumDims() const { return map.getNumDims(); } + inline unsigned getNumSymbols() const { return map.getNumSymbols(); } + inline unsigned getNumResults() const { return map.getNumResults(); } + + Value *getOperand(unsigned i) const; + ArrayRef getOperands() const; + AffineMap getAffineMap() const; + +private: + // A mutable affine map. + MutableAffineMap map; + + // TODO: make these trailing objects? + /// The SSA operands binding to the dim's and symbols of 'map'. + SmallVector operands; + /// The SSA results binding to the results of 'map'. + SmallVector results; +}; + +/// An IntegerValueSet is an integer set plus its operands. +// Both the integer set being pointed to and the operands can change during +// analysis, simplification, and transformation. +class IntegerValueSet { + /// Constructs an integer value set from an affine value map. + // This will lead to a single equality in 'set'. + explicit IntegerValueSet(const AffineValueMap &avm); + + /// Returns true if this integer set is determined to be empty. Emptiness is + /// checked by eliminating identifiers successively (through either + /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial + /// invalid constraint check. Returns 'true' if the constraint system is found + /// to be empty; false otherwise. This method is exact for rational spaces but + /// not integer spaces - thus, if it returns true, the set is provably integer + /// empty as well, but if it returns false, it doesn't necessarily mean an + /// integer point exists in it. This method also returns false where an + /// explosion of constraints is detected - due to the super-exponential + /// worst-case complexity of Fourier-Motzkin elimination (rare for realistic + /// problem cases but possible for artificial adversarial or improperly + // constructed ones), this method returns false conservatively. + bool isEmpty() const; + + bool getNumDims() const { return set.getNumDims(); } + bool getNumSymbols() const { return set.getNumSymbols(); } + +private: + // The set pointed to may itself change unlike in IR structures like + // 'AffineCondition'. + MutableIntegerSet set; + /// The SSA operands binding to the dim's and symbols of 'set'. + SmallVector operands; +}; + +/// A flat list of affine equalities and inequalities in the form: +/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0 +/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0 +/// +/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer +/// for equalities and one for inequalities). The size of each buffer is +/// numReservedCols * number of inequalities (or equalities). The reserved size +/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A +/// coefficient (r, c) lives at the location numReservedCols * r + c in the +/// buffer. The extra space between getNumCols() and numReservedCols exists to +/// prevent frequent movement of data when adding columns, especially at the +/// end. +/// +/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, +/// symbolic identifiers, and local identifiers. The local identifiers +/// correspond to local/internal variables created when converting from +/// AffineExpr's containing mod's and div's; they are thus needed to increase +/// representational power.
Each local identifier is always (by construction) a +/// floordiv of a pure add/mul affine function of dimensional, symbolic, and +/// other local identifiers, in a non-mutually recursive way. Hence, every local +/// identifier can ultimately always be recovered as an affine function of +/// dimensional and symbolic identifiers (involving floordiv's); note however +/// that some floordiv combinations are converted to mod's by AffineExpr +/// construction. +/// +class FlatAffineConstraints { +public: + enum IdKind { Dimension, Symbol, Local }; + + /// Constructs a constraint system reserving memory for the specified number + /// of constraints and identifiers.. + FlatAffineConstraints(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims = 0, + unsigned numSymbols = 0, unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numReservedCols), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + equalities.reserve(numReservedCols * numReservedEqualities); + inequalities.reserve(numReservedCols * numReservedInequalities); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numReservedCols); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + /// Constructs a constraint system with the specified number of + /// dimensions and symbols. + FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, + ArrayRef> idArgs = {}) + : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), + numSymbols(numSymbols) { + assert(numReservedCols >= numDims + numSymbols + 1); + assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); + numIds = numDims + numSymbols + numLocals; + ids.reserve(numIds); + if (idArgs.empty()) + ids.resize(numIds, None); + else + ids.append(idArgs.begin(), idArgs.end()); + } + + explicit FlatAffineConstraints(const HyperRectangularSet &set); + + /// Create a flat affine constraint system from an AffineValueMap or a list of + /// these. The constructed system will only include equalities. + // TODO(bondhugula) + explicit FlatAffineConstraints(const AffineValueMap &avm); + explicit FlatAffineConstraints(ArrayRef avmRef); + + /// Creates an affine constraint system from an IntegerSet. + explicit FlatAffineConstraints(IntegerSet set); + + /// Create an affine constraint system from an IntegerValueSet. + // TODO(bondhugula) + explicit FlatAffineConstraints(const IntegerValueSet &set); + + FlatAffineConstraints(const FlatAffineConstraints &other); + + FlatAffineConstraints(ArrayRef avmRef, + IntegerSet set); + + FlatAffineConstraints(const MutableAffineMap &map); + + ~FlatAffineConstraints() {} + + // Clears any existing data and reserves memory for the specified constraints. + void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, + unsigned numReservedCols, unsigned numDims, unsigned numSymbols, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + void reset(unsigned numDims = 0, unsigned numSymbols = 0, + unsigned numLocals = 0, ArrayRef idArgs = {}); + + /// Appends constraints from 'other' into this. This is equivalent to an + /// intersection with no simplification of any sort attempted. 
+ void append(const FlatAffineConstraints &other); + + // Checks for emptiness by performing variable elimination on all identifiers, + // running the GCD test on each equality constraint, and checking for invalid + // constraints. + // Returns true if the GCD test fails for any equality, or if any invalid + // constraints are discovered on any row. Returns false otherwise. + bool isEmpty() const; + + // Runs the GCD test on all equality constraints. Returns 'true' if this test + // fails on any equality. Returns 'false' otherwise. + // This test can be used to disprove the existence of a solution. If it + // returns true, no integer solution to the equality constraints can exist. + bool isEmptyByGCDTest() const; + + // Clones this object. + std::unique_ptr clone() const; + + /// Returns the value at the specified equality row and column. + inline int64_t atEq(unsigned i, unsigned j) const { + return equalities[i * numReservedCols + j]; + } + inline int64_t &atEq(unsigned i, unsigned j) { + return equalities[i * numReservedCols + j]; + } + + inline int64_t atIneq(unsigned i, unsigned j) const { + return inequalities[i * numReservedCols + j]; + } + + inline int64_t &atIneq(unsigned i, unsigned j) { + return inequalities[i * numReservedCols + j]; + } + + /// Returns the number of columns in the constraint system. + inline unsigned getNumCols() const { return numIds + 1; } + + inline unsigned getNumEqualities() const { + assert(equalities.size() % numReservedCols == 0 && + "inconsistent equality buffer size"); + return equalities.size() / numReservedCols; + } + + inline unsigned getNumInequalities() const { + assert(inequalities.size() % numReservedCols == 0 && + "inconsistent inequality buffer size"); + return inequalities.size() / numReservedCols; + } + + inline unsigned getNumReservedEqualities() const { + return equalities.capacity() / numReservedCols; + } + + inline unsigned getNumReservedInequalities() const { + return inequalities.capacity() / numReservedCols; + } + + inline ArrayRef getEquality(unsigned idx) const { + return ArrayRef(&equalities[idx * numReservedCols], getNumCols()); + } + + inline ArrayRef getInequality(unsigned idx) const { + return ArrayRef(&inequalities[idx * numReservedCols], + getNumCols()); + } + + AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); + + /// Adds constraints (lower and upper bounds) for the specified 'for' + /// instruction's Value using IR information stored in its bound maps. The + /// right identifier is first looked up using forOp's Value. Returns + /// false for the yet unimplemented/unsupported cases, and true if the + /// information is successfully added. Asserts if the Value corresponding to + /// the 'for' instruction isn't found in the constraint system. Any new + /// identifiers that are found in the bound operands of the 'for' instruction + /// are added as trailing identifiers (either dimensional or symbolic + /// depending on whether the operand is a valid ML Function symbol). + // TODO(bondhugula): add support for non-unit strides. + bool addAffineForOpDomain(ConstOpPointer forOp); + + /// Computes the lower and upper bounds of the first 'num' dimensional + /// identifiers as an affine map of the remaining identifiers (dimensional and + /// symbolic). This method is able to detect identifiers as floordiv's + /// and mod's of affine expressions of other identifiers with respect to + /// (positive) constants. Sets bound map to a null AffineMap if such a bound + /// can't be found (or yet unimplemented). 
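A short usage sketch for addAffineForOpDomain together with addDimId and getSliceBounds (declared just below); 'forOp' is assumed to be a handle to an affine 'for' instruction, 'iv' its induction variable Value, and 'context' the MLIRContext:

  FlatAffineConstraints cst;
  cst.addDimId(0, iv);                 // column 0 <-> the loop induction variable
  if (cst.addAffineForOpDomain(forOp)) {
    // Recover iv's bounds back as affine maps of the remaining identifiers.
    SmallVector<AffineMap, 4> lbMaps, ubMaps;
    cst.getSliceBounds(/*num=*/1, context, &lbMaps, &ubMaps);
  }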
+ void getSliceBounds(unsigned num, MLIRContext *context,
+ SmallVectorImpl *lbMaps,
+ SmallVectorImpl *ubMaps);
+
+ /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
+ /// bounds in 'ubMaps' to the constraint system. Note that both lower/upper
+ /// bounds share the same operand list 'operands'.
+ /// This function assumes that 'lbMaps.size()' == 'ubMaps.size()',
+ /// and that positions [0, lbMaps.size) represent dimensional identifiers
+ /// which correspond to the loop IVs whose iteration bounds are being sliced.
+ /// Note that both lower/upper bounds use operands from 'operands'.
+ /// Returns true on success, returns false for unimplemented cases.
+ bool addSliceBounds(ArrayRef lbMaps, ArrayRef ubMaps,
+ ArrayRef operands);
+
+ // Adds an inequality (>= 0) from the coefficients specified in inEq.
+ void addInequality(ArrayRef inEq);
+ // Adds an equality from the coefficients specified in eq.
+ void addEquality(ArrayRef eq);
+
+ /// Adds a constant lower bound constraint for the specified identifier.
+ void addConstantLowerBound(unsigned pos, int64_t lb);
+ /// Adds a constant upper bound constraint for the specified identifier.
+ void addConstantUpperBound(unsigned pos, int64_t ub);
+
+ /// Adds a new local identifier as the floordiv of an affine function of other
+ /// identifiers, the coefficients of which are provided in 'dividend' and with
+ /// respect to a positive constant 'divisor'. Two constraints are added to the
+ /// system to capture equivalence with the floordiv:
+ /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1.
+ void addLocalFloorDiv(ArrayRef dividend, int64_t divisor);
+
+ /// Adds a constant lower bound constraint for the specified expression.
+ void addConstantLowerBound(ArrayRef expr, int64_t lb);
+ /// Adds a constant upper bound constraint for the specified expression.
+ void addConstantUpperBound(ArrayRef expr, int64_t ub);
+
+ /// Sets the identifier at the specified position to a constant.
+ void setIdToConstant(unsigned pos, int64_t val);
+
+ /// Sets the identifier corresponding to the specified Value id to a
+ /// constant. Asserts if the 'id' is not found.
+ void setIdToConstant(const Value &id, int64_t val);
+
+ /// Looks up the identifier with the specified Value. Returns false if not
+ /// found, true if found. pos is set to the (column) position of the
+ /// identifier.
+ bool findId(const Value &id, unsigned *pos) const;
+
+ // Add identifiers of the specified kind - specified positions are relative to
+ // the kind of identifier. The coefficient column corresponding to the added
+ // identifier is initialized to zero. 'id' is the Value corresponding to the
+ // identifier that can optionally be provided.
+ void addDimId(unsigned pos, Value *id = nullptr);
+ void addSymbolId(unsigned pos, Value *id = nullptr);
+ void addLocalId(unsigned pos);
+ void addId(IdKind kind, unsigned pos, Value *id = nullptr);
+
+ /// Composes the affine value map with this FlatAffineConstraints, adding the
+ /// results of the map as dimensions at the front [0, vMap->getNumResults())
+ /// and with the dimensions set to the equalities specified by the value map.
+ /// Returns false if the composition fails (when vMap is a semi-affine map).
+ /// The vMap's operand Value's are used to look up the right positions in
+ /// the FlatAffineConstraints with which to associate.
The dimensional and + /// symbolic operands of vMap should match 1:1 (in the same order) with those + /// of this constraint system, but the latter could have additional trailing + /// operands. + bool composeMap(AffineValueMap *vMap); + + /// Projects out (aka eliminates) 'num' identifiers starting at position + /// 'pos'. The resulting constraint system is the shadow along the dimensions + /// that still exist. This method may not always be integer exact. + // TODO(bondhugula): deal with integer exactness when necessary - can return a + // value to mark exactness for example. + void projectOut(unsigned pos, unsigned num); + inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + + /// Projects out the identifier that is associate with Value *. + void projectOut(Value *id); + + void removeId(IdKind idKind, unsigned pos); + void removeId(unsigned pos); + + void removeDim(unsigned pos); + + void removeEquality(unsigned pos); + void removeInequality(unsigned pos); + + /// Changes the partition between dimensions and symbols. Depending on the new + /// symbol count, either a chunk of trailing dimensional identifiers becomes + /// symbols, or some of the leading symbols become dimensions. + void setDimSymbolSeparation(unsigned newSymbolCount); + + /// Sets the specified identifier to a constant and removes it. + void setAndEliminate(unsigned pos, int64_t constVal); + + /// Tries to fold the specified identifier to a constant using a trivial + /// equality detection; if successful, the constant is substituted for the + /// identifier everywhere in the constraint system and then removed from the + /// system. Returns true if the folding happens, false otherwise. + bool constantFoldId(unsigned pos); + + /// This method calls constantFoldId for the specified range of identifiers, + /// 'num' identifiers starting at position 'pos'. + void constantFoldIdRange(unsigned pos, unsigned num); + + /// Returns true if all the identifiers in the specified range [start, limit) + /// can only take a single value each if the remaining identifiers are treated + /// as symbols/parameters, i.e., for given values of the latter, there only + /// exists a unique value for each of the dimensions in the specified range. + bool isRangeOneToOne(unsigned start, unsigned limit) const; + + /// Updates the constraints to be the smallest bounding (enclosing) box that + /// contains the points of 'this' set and that of 'other', with the symbols + /// being treated specially. For each of the dimensions, the min of the lower + /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed + /// to determine such a bounding box. + /// + /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the + /// output is {0 <= d0 <= 192}. + /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 + + /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. + /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 + /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. 
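A concrete instance of the elimination performed by projectOut above (a sketch): eliminating d1 from {d1 >= d0, d1 <= 5} leaves the shadow constraint d0 <= 5.

  FlatAffineConstraints cst(/*numDims=*/2);   // columns: [d0, d1, constant]
  cst.addInequality({-1, 1, 0});              // d1 - d0 >= 0
  cst.addInequality({0, -1, 5});              // -d1 + 5 >= 0
  cst.projectOut(/*pos=*/1);                  // eliminate d1 (Fourier-Motzkin)
  // The remaining system over [d0, constant] contains -d0 + 5 >= 0.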
+ bool unionBoundingBox(const FlatAffineConstraints &other); + + unsigned getNumConstraints() const { + return getNumInequalities() + getNumEqualities(); + } + inline unsigned getNumIds() const { return numIds; } + inline unsigned getNumDimIds() const { return numDims; } + inline unsigned getNumSymbolIds() const { return numSymbols; } + inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } + inline unsigned getNumLocalIds() const { + return numIds - numDims - numSymbols; + } + + inline ArrayRef> getIds() const { + return {ids.data(), ids.size()}; + } + + /// Returns the Value associated with the pos^th identifier. Asserts if + /// no Value identifier was associated. + inline Value *getIdValue(unsigned pos) const { + assert(ids[pos].hasValue() && "identifier's Value not set"); + return ids[pos].getValue(); + } + + /// Returns the Values associated with identifiers in range [start, end). + /// Asserts if no Value was associated with one of these identifiers. + void getIdValues(unsigned start, unsigned end, + SmallVectorImpl *values) const { + assert((start < numIds || start == end) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + values->clear(); + values->reserve(end - start); + for (unsigned i = start; i < end; i++) { + values->push_back(getIdValue(i)); + } + } + inline void getAllIdValues(SmallVectorImpl *values) const { + getIdValues(0, numIds, values); + } + + /// Sets Value associated with the pos^th identifier. + inline void setIdValue(unsigned pos, Value *val) { + assert(pos < numIds && "invalid id position"); + ids[pos] = val; + } + /// Sets Values associated with identifiers in the range [start, end). + void setIdValues(unsigned start, unsigned end, ArrayRef values) { + assert((start < numIds || end == start) && "invalid start position"); + assert(end <= numIds && "invalid end position"); + assert(values.size() == end - start); + for (unsigned i = start; i < end; ++i) + ids[i] = values[i - start]; + } + + /// Clears this list of constraints and copies other into it. + void clearAndCopyFrom(const FlatAffineConstraints &other); + + /// Returns the smallest known constant bound for the extent of the specified + /// identifier (pos^th), i.e., the smallest known constant that is greater + /// than or equal to 'exclusive upper bound' - 'lower bound' of the + /// identifier. Returns None if it's not a constant. This method employs + /// trivial (low complexity / cost) checks and detection. Symbolic identifiers + /// are treated specially, i.e., it looks for constant differences between + /// affine expressions involving only the symbolic identifiers. See comments + /// at function definition for examples. 'lb' and 'lbDivisor', if provided, + /// are used to express the lower bound associated with the constant + /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., + /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three + /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. + Optional + getConstantBoundOnDimSize(unsigned pos, + SmallVectorImpl *lb = nullptr, + int64_t *lbFloorDivisor = nullptr) const; + + /// Returns the constant lower bound for the pos^th identifier if there is + /// one; None otherwise. + Optional getConstantLowerBound(unsigned pos) const; + + /// Returns the constant upper bound for the pos^th identifier if there is + /// one; None otherwise. 
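For example (a sketch), given a system 'cst' containing s0 <= d0 <= s0 + 63 and 0 <= d1 <= 127, the constant bound queries above and just below would be expected to behave as follows:

  Optional<int64_t> lb1 = cst.getConstantLowerBound(/*pos=*/1);   // -> 0
  Optional<int64_t> ub1 = cst.getConstantUpperBound(/*pos=*/1);   // -> 127
  // d0's bounds are symbolic, but its extent is still the constant 64.
  Optional<int64_t> extent = cst.getConstantBoundOnDimSize(/*pos=*/0);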
+ Optional getConstantUpperBound(unsigned pos) const;
+
+ /// Gets the lower and upper bound of the pos^th identifier treating
+ /// [dimStartPos, symbStartPos) as dimensions and [symStartPos,
+ /// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps
+ /// in the pair represent the max and min of potentially multiple affine
+ /// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed
+ /// AffineExpr's for all local identifiers in the system.
+ std::pair
+ getLowerAndUpperBound(unsigned pos, unsigned dimStartPos,
+ unsigned symStartPos, ArrayRef localExprs,
+ MLIRContext *context);
+
+ /// Returns true if the set can be trivially detected as being
+ /// hyper-rectangular on the specified contiguous set of identifiers.
+ bool isHyperRectangular(unsigned pos, unsigned num) const;
+
+ /// Removes duplicates and trivially true constraints: a constraint of the
+ /// form >= 0 is considered a trivially true
+ /// constraint.
+ void removeTrivialRedundancy();
+
+ /// A more expensive check to detect redundant inequalities.
+ void removeRedundantInequalities();
+
+ // Removes all equalities and inequalities.
+ void clearConstraints();
+
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+private:
+ /// Returns false if the fields corresponding to various identifier counts, or
+ /// equality/inequality buffer sizes aren't consistent; true otherwise. This
+ /// is meant to be used within an assert internally.
+ bool hasConsistentState() const;
+
+ /// Checks all rows of equality/inequality constraints for trivial
+ /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
+ /// after elimination. Returns 'true' if an invalid constraint is found;
+ /// 'false' otherwise.
+ bool hasInvalidConstraint() const;
+
+ /// Returns the constant lower bound if isLower is true, and the upper
+ /// bound if isLower is false.
+ template
+ Optional computeConstantLowerOrUpperBound(unsigned pos);
+
+ // Eliminates a single identifier at 'position' from equality and inequality
+ // constraints. Returns 'true' if the identifier was eliminated, and false
+ // otherwise.
+ inline bool gaussianEliminateId(unsigned position) {
+ return gaussianEliminateIds(position, position + 1) == 1;
+ }
+
+ // Eliminates identifiers from equality and inequality constraints
+ // in column range [posStart, posLimit).
+ // Returns the number of variables eliminated.
+ unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit);
+
+ /// Eliminates identifier at the specified position using Fourier-Motzkin
+ /// variable elimination, but uses Gaussian elimination if there is an
+ /// equality involving that identifier. If the result of the elimination is
+ /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is
+ /// set to true, a potential under approximation (subset) of the rational
+ /// shadow / exact integer shadow is computed.
+ // See implementation comments for more details.
+ void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
+ bool *isResultIntegerExact = nullptr);
+
+ /// Tightens inequalities given that we are dealing with integer spaces. This
+ /// is similar to the GCD test but applied to inequalities. The constant term
+ /// can be reduced to the preceding multiple of the GCD of the coefficients,
+ /// i.e.,
+ /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
+ /// fast method (linear in the number of coefficients).
+ void GCDTightenInequalities();
+
+ /// Normalizes each constraint by the GCD of its coefficients.
+ void normalizeConstraintsByGCD();
+
+ /// Removes identifiers in column range [idStart, idLimit), and copies any
+ /// remaining valid data into place, updates member variables, and resizes
+ /// arrays as needed.
+ void removeIdRange(unsigned idStart, unsigned idLimit);
+
+ /// Coefficients of affine equalities (in == 0 form).
+ SmallVector equalities;
+
+ /// Coefficients of affine inequalities (in >= 0 form).
+ SmallVector inequalities;
+
+ /// Number of columns reserved. Actual ones in use are returned by
+ /// getNumCols().
+ unsigned numReservedCols;
+
+ /// Total number of identifiers.
+ unsigned numIds;
+
+ /// Number of identifiers corresponding to real dimensions.
+ unsigned numDims;
+
+ /// Number of identifiers corresponding to symbols (unknown but constant for
+ /// analysis).
+ unsigned numSymbols;
+
+ /// Values corresponding to the (column) identifiers of this constraint
+ /// system appearing in the order the identifiers correspond to columns.
+ /// Temporary ones or those that aren't associated to any Value are to be
+ /// set to None.
+ SmallVector, 8> ids;
+
+ /// A parameter that controls detection of an unrealistic number of
+ /// constraints. If the number of constraints is this many times the number of
+ /// variables, we consider such a system out of line with the intended use
+ /// case of FlatAffineConstraints.
+ // The rationale for 32 is that in the typical simplest of cases, an
+ // identifier is expected to have one lower bound and one upper bound
+ // constraint. With a level of tiling or a connection to another identifier
+ // through a div or mod, an extra pair of bounds gets added. As a limit, we
+ // don't expect an identifier to have more than 32 lower/upper/equality
+ // constraints. This is conservatively set low and can be raised if needed.
+ constexpr static unsigned kExplosionFactor = 32;
+};
+
+/// Simplify an affine expression by flattening and some amount of
+/// simple analysis. This has complexity linear in the number of nodes in
+/// 'expr'. Returns the simplified expression, which is the same as the input
+/// expression if it can't be simplified.
+AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+ unsigned numSymbols);
+
+/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
+/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled).
+/// 'cst' contains constraints that connect newly introduced local identifiers
+/// to existing dimensional and symbolic identifiers. See documentation for
+/// AffineExprFlattener on how mod's and div's are flattened.
+bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
+ unsigned numSymbols,
+ llvm::SmallVectorImpl *flattenedExpr,
+ FlatAffineConstraints *cst = nullptr);
+
+/// Flattens the result expressions of the map to their corresponding flattened
+/// forms and set in 'flattenedExprs'. Returns true on success or false
+/// if any expression in the map could not be flattened (i.e., semi-affine is
+/// not yet handled). 'cst' contains constraints that connect newly introduced
+/// local identifiers to existing dimensional and symbolic identifiers. See
+/// documentation for AffineExprFlattener on how mod's and div's are flattened.
+/// For all affine expressions that share the same operands (like those of an +/// affine map), this method should be used instead of repeatedly calling +/// getFlattenedAffineExpr since local variables added to deal with div's and +/// mod's will be reused across expressions. +bool getFlattenedAffineExprs( + AffineMap map, std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); +bool getFlattenedAffineExprs( + IntegerSet set, std::vector> *flattenedExprs, + FlatAffineConstraints *cst = nullptr); + +} // end namespace mlir. + +#endif // MLIR_ANALYSIS_AFFINE_STRUCTURES_H diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 7f38d327a44..6adec6703cf 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -25,8 +25,8 @@ #ifndef MLIR_ANALYSIS_UTILS_H #define MLIR_ANALYSIS_UTILS_H +#include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Block.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" @@ -176,7 +176,7 @@ struct MemRefRegion { Optional getConstantBoundOnDimSize(unsigned pos, SmallVectorImpl *lb = nullptr, - int64_t *lbDivisor = nullptr) const { + int64_t *lbFloorDivisor = nullptr) const { assert(pos < getRank() && "invalid position"); return cst.getConstantBoundOnDimSize(pos, lb); } diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index a652ff6a22f..55b6d046769 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -33,7 +33,6 @@ namespace mlir { class MLIRContext; class AffineMap; class IntegerSet; -class FlatAffineConstraints; namespace detail { @@ -272,25 +271,19 @@ AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, /// AffineExprFlattener on how mod's and div's are flattened. bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *cst = nullptr); + llvm::SmallVectorImpl *flattenedExpr); /// Flattens the result expressions of the map to their corresponding flattened /// forms and set in 'flattenedExprs'. Returns true on success or false /// if any expression in the map could not be flattened (i.e., semi-affine is -/// not yet handled). 'cst' contains constraints that connect newly introduced -/// local identifiers to existing dimensional and / symbolic identifiers. See -/// documentation for AffineExprFlattener on how mod's and div's are flattened. -/// For all affine expressions that share the same operands (like those of an -/// affine map), this method should be used instead of repeatedly calling -/// getFlattenedAffineExpr since local variables added to deal with div's and -/// mod's will be reused across expressions. +/// not yet handled). For all affine expressions that share the same operands +/// (like those of an affine map), this method should be used instead of +/// repeatedly calling getFlattenedAffineExpr since local variables added to +/// deal with div's and mod's will be reused across expressions. 
bool getFlattenedAffineExprs( - AffineMap map, std::vector> *flattenedExprs, - FlatAffineConstraints *cst = nullptr); + AffineMap map, std::vector> *flattenedExprs); bool getFlattenedAffineExprs( - IntegerSet set, std::vector> *flattenedExprs, - FlatAffineConstraints *cst = nullptr); + IntegerSet set, std::vector> *flattenedExprs); } // namespace mlir diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index b3995352e61..a286781b651 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -195,6 +195,142 @@ private: } }; +// This class is used to flatten a pure affine expression (AffineExpr, +// which is in a tree form) into a sum of products (w.r.t constants) when +// possible, and in that process simplifying the expression. For a modulo, +// floordiv, or a ceildiv expression, an additional identifier, called a local +// identifier, is introduced to rewrite the expression as a sum of product +// affine expression. Each local identifier is always and by construction a +// floordiv of a pure add/mul affine function of dimensional, symbolic, and +// other local identifiers, in a non-mutually recursive way. Hence, every local +// identifier can ultimately always be recovered as an affine function of +// dimensional and symbolic identifiers (involving floordiv's); note however +// that by AffineExpr construction, some floordiv combinations are converted to +// mod's. The result of the flattening is a flattened expression and a set of +// constraints involving just the local variables. +// +// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local +// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. +// +// The simplification performed includes the accumulation of contributions for +// each dimensional and symbolic identifier together, the simplification of +// floordiv/ceildiv/mod expressions and other simplifications that in turn +// happen as a result. A simplification that this flattening naturally performs +// is of simplifying the numerator and denominator of floordiv/ceildiv, and +// folding a modulo expression to a zero, if possible. Three examples are below: +// +// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 +// (d0 - d0 mod 4 + 4) mod 4 simplified to 0 +// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 +// +// The way the flattening works for the second example is as follows: d0 % 4 is +// replaced by d0 - 4*q with q being introduced: the expression then simplifies +// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to +// zero. Note that an affine expression may not always be expressible purely as +// a sum of products involving just the original dimensional and symbolic +// identifiers due to the presence of modulo/floordiv/ceildiv expressions that +// may not be eliminated after simplification; in such cases, the final +// expression can be reconstructed by replacing the local identifiers with their +// corresponding explicit form stored in 'localExprs' (note that each of the +// explicit forms itself would have been simplified). +// +// The expression walk method here performs a linear time post order walk that +// performs the above simplifications through visit methods, with partial +// results being stored in 'operandExprStack'. 
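For example, a sketch of the resulting flattened layout [dims, symbols, locals, constant], via the getFlattenedAffineExpr entry point declared earlier; 'expr0' and 'expr1' are assumed AffineExprs:

  SmallVector<int64_t, 8> flat;
  // expr0 = d0 + 2*d1 + 3 (2 dims, 0 symbols) flattens to {1, 2, 3}: no locals.
  getFlattenedAffineExpr(expr0, /*numDims=*/2, /*numSymbols=*/0, &flat);
  // expr1 = d2 + (d0 + d1) floordiv 4 (3 dims) flattens to {0, 0, 1, 1, 0}: one
  // local 'q' is introduced, with 4*q <= d0 + d1 <= 4*q + 3 recorded as the
  // constraints connecting 'q' to d0 and d1.
  getFlattenedAffineExpr(expr1, /*numDims=*/3, /*numSymbols=*/0, &flat);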
When a parent expr is visited,
+// the flattened expressions corresponding to its two operands would already be
+// on the stack - the parent expression looks at the two flattened expressions
+// and combines the two. It pops off the operand expressions and pushes the
+// combined result (although this is done in-place on its LHS operand expr).
+// When the walk is completed, the flattened form of the top-level expression
+// would be left on the stack.
+//
+// A flattener can be repeatedly used for multiple affine expressions that bind
+// to the same operands, for example, for all result expressions of an
+// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
+// is more efficient than creating a new flattener for each expression since
+// common identical div and mod expressions appearing across different
+// expressions are mapped to the same local identifier (same column position in
+// 'localVarCst').
+class SimpleAffineExprFlattener
+ : public AffineExprVisitor {
+public:
+ // Flattened expression layout: [dims, symbols, locals, constant]
+ // Stack that holds the LHS and RHS operands while visiting a binary op expr.
+ // In future, consider adding a prepass to determine how big the SmallVector's
+ // will be, and linearize this to std::vector to prevent
+ // SmallVector moves on re-allocation.
+ std::vector> operandExprStack;
+ // Constraints connecting newly introduced local variables (for mod's and
+ // div's) to existing (dimensional and symbolic) ones. These are always
+ // inequalities.
+
+ unsigned numDims;
+ unsigned numSymbols;
+ // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
+ // expressions that could not be simplified.
+ unsigned numLocals;
+ // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
+ // which new identifiers were introduced; if the latter do not get canceled
+ // out, these expressions can be readily used to reconstruct the AffineExpr
+ // (tree) form. Note that these expressions themselves would have been
+ // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
+ // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
+ // ceildiv 2 would be the local expression stored for q.
+ SmallVector localExprs;
+ MLIRContext *context;
+
+ SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context);
+
+ virtual ~SimpleAffineExprFlattener() = default;
+
+ void visitMulExpr(AffineBinaryOpExpr expr);
+ void visitAddExpr(AffineBinaryOpExpr expr);
+
+ //
+ // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
+ //
+ // A mod expression "expr mod c" is thus flattened by introducing a new local
+ // variable q (= expr floordiv c), such that expr mod c is replaced with
+ // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
+ void visitModExpr(AffineBinaryOpExpr expr);
+
+ // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
+ // A floordiv is thus flattened by introducing a new local variable q, and
+ // replacing that expression with 'q' while adding the constraints
+ // c * q <= expr <= c * q + c - 1 to localVarCst (done by
+ // FlatAffineConstraints::addLocalFloorDiv).
+ // + // A ceildiv is similarly flattened: + // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c + void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); + + void visitDimExpr(AffineDimExpr expr); + void visitSymbolExpr(AffineSymbolExpr expr); + void visitConstantExpr(AffineConstantExpr expr); + void visitCeilDivExpr(AffineBinaryOpExpr expr); + void visitFloorDivExpr(AffineBinaryOpExpr expr); + +protected: + // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). + // The local identifier added is always a floordiv of a pure add/mul affine + // function of other identifiers, coefficients of which are specified in + // dividend and with respect to a positive constant divisor. localExpr is the + // simplified tree expression (AffineExpr) corresponding to the quantifier. + virtual void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, + AffineExpr localExpr); + + int findLocalId(AffineExpr localExpr); + + inline unsigned getNumCols() const { + return numDims + numSymbols + numLocals + 1; + } + inline unsigned getConstantIndex() const { return getNumCols() - 1; } + inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } + inline unsigned getSymbolStartIndex() const { return numDims; } + inline unsigned getDimStartIndex() const { return 0; } +}; + } // end namespace mlir #endif // MLIR_IR_AFFINE_EXPR_VISITOR_H diff --git a/mlir/include/mlir/IR/AffineStructures.h b/mlir/include/mlir/IR/AffineStructures.h deleted file mode 100644 index 20ca7d71052..00000000000 --- a/mlir/include/mlir/IR/AffineStructures.h +++ /dev/null @@ -1,708 +0,0 @@ -//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Structures for affine/polyhedral analysis of ML functions. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_AFFINE_STRUCTURES_H -#define MLIR_IR_AFFINE_STRUCTURES_H - -#include "mlir/IR/AffineExpr.h" - -namespace mlir { - -class AffineCondition; -class AffineMap; -class IntegerSet; -class MLIRContext; -class Value; -class HyperRectangularSet; -class MemRefType; - -/// A mutable affine map. Its affine expressions are however unique. 
-struct MutableAffineMap { -public: - MutableAffineMap() {} - MutableAffineMap(AffineMap map); - - ArrayRef getResults() const { return results; } - AffineExpr getResult(unsigned idx) const { return results[idx]; } - void setResult(unsigned idx, AffineExpr result) { results[idx] = result; } - unsigned getNumResults() const { return results.size(); } - unsigned getNumDims() const { return numDims; } - void setNumDims(unsigned d) { numDims = d; } - unsigned getNumSymbols() const { return numSymbols; } - void setNumSymbols(unsigned d) { numSymbols = d; } - MLIRContext *getContext() const { return context; } - - /// Returns true if the idx'th result expression is a multiple of factor. - bool isMultipleOf(unsigned idx, int64_t factor) const; - - /// Resets this MutableAffineMap with 'map'. - void reset(AffineMap map); - - /// Simplify the (result) expressions in this map using analysis (used by - //-simplify-affine-expr pass). - void simplify(); - /// Get the AffineMap corresponding to this MutableAffineMap. Note that an - /// AffineMap will be uniqued and stored in context, while a mutable one - /// isn't. - AffineMap getAffineMap() const; - -private: - // Same meaning as AffineMap's fields. - SmallVector results; - SmallVector rangeSizes; - unsigned numDims; - unsigned numSymbols; - /// A pointer to the IR's context to store all newly created - /// AffineExprStorage's. - MLIRContext *context; -}; - -/// A mutable integer set. Its affine expressions are however unique. -struct MutableIntegerSet { -public: - MutableIntegerSet(IntegerSet set, MLIRContext *context); - - /// Create a universal set (no constraints). - MutableIntegerSet(unsigned numDims, unsigned numSymbols, - MLIRContext *context); - - unsigned getNumDims() const { return numDims; } - unsigned getNumSymbols() const { return numSymbols; } - unsigned getNumConstraints() const { return constraints.size(); } - - void clear() { - constraints.clear(); - eqFlags.clear(); - } - -private: - unsigned numDims; - unsigned numSymbols; - - SmallVector constraints; - SmallVector eqFlags; - /// A pointer to the IR's context to store all newly created - /// AffineExprStorage's. - MLIRContext *context; -}; - -/// An AffineValueMap is an affine map plus its ML value operands and -/// results for analysis purposes. The structure is still a tree form that is -/// same as that of an affine map or an AffineApplyOp. However, its operands, -/// results, and its map can themselves change as a result of -/// substitutions, simplifications, and other analysis. -// An affine value map can readily be constructed from an AffineApplyOp, or an -// AffineBound of a AffineForOp. It can be further transformed, substituted -// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and -// destroyed during analysis. Only the AffineMap expressions that are pointed by -// them are unique'd. An affine value map, and the operations on it, maintain -// the invariant that operands are always positionally aligned with the -// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap. -// TODO(bondhugula): Some of these classes could go into separate files. -class AffineValueMap { -public: - // Creates an empty AffineValueMap (users should call 'reset' to reset map - // and operands). - AffineValueMap() {} - AffineValueMap(AffineMap map); - AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); - - ~AffineValueMap(); - - // Resets this AffineValueMap with 'map', 'operands', and 'results'. 
- void reset(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); - - /// Return true if the idx^th result can be proved to be a multiple of - /// 'factor', false otherwise. - inline bool isMultipleOf(unsigned idx, int64_t factor) const; - - /// Return true if the idx^th result depends on 'value', false otherwise. - bool isFunctionOf(unsigned idx, Value *value) const; - - /// Return true if the result at 'idx' is a constant, false - /// otherwise. - bool isConstant(unsigned idx) const; - - /// Return true if this is an identity map. - bool isIdentity() const; - - inline unsigned getNumOperands() const { return operands.size(); } - inline unsigned getNumDims() const { return map.getNumDims(); } - inline unsigned getNumSymbols() const { return map.getNumSymbols(); } - inline unsigned getNumResults() const { return map.getNumResults(); } - - Value *getOperand(unsigned i) const; - ArrayRef getOperands() const; - AffineMap getAffineMap() const; - -private: - // A mutable affine map. - MutableAffineMap map; - - // TODO: make these trailing objects? - /// The SSA operands binding to the dim's and symbols of 'map'. - SmallVector operands; - /// The SSA results binding to the results of 'map'. - SmallVector results; -}; - -/// An IntegerValueSet is an integer set plus its operands. -// Both, the integer set being pointed to and the operands can change during -// analysis, simplification, and transformation. -class IntegerValueSet { - /// Constructs an integer value set from an affine value map. - // This will lead to a single equality in 'set'. - explicit IntegerValueSet(const AffineValueMap &avm); - - /// Returns true if this integer set is determined to be empty. Emptiness is - /// checked by by eliminating identifiers successively (through either - /// Gaussian or Fourier-Motzkin) while using the GCD test and a trivial - /// invalid constraint check. Returns 'true' if the constaint system is found - /// to be empty; false otherwise. This method is exact for rational spaces but - /// not integer spaces - thus, if it returns true, the set is provably integer - /// empty as well, but if it returns false, it doesn't necessarily mean an - /// integer point exists in it. This method also returns false where an - /// explosion of constraints is detected - due to the super-exponential - /// worse-case complexity of Fourier-Motzkin elimination (rare for realistic - /// problem cases but possible for artificial adversarial or improperly - // constructed ones), this method returns false conservatively. - bool isEmpty() const; - - bool getNumDims() const { return set.getNumDims(); } - bool getNumSymbols() const { return set.getNumSymbols(); } - -private: - // The set pointed to may itself change unlike in IR structures like - // 'AffineCondition'. - MutableIntegerSet set; - /// The SSA operands binding to the dim's and symbols of 'set'. - SmallVector operands; -}; - -/// A flat list of affine equalities and inequalities in the form. -/// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} == 0 -/// Equality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} >= 0 -/// -/// FlatAffineConstraints stores coefficients in a contiguous buffer (one buffer -/// for equalities and one for inequalities). The size of each buffer is -/// numReservedCols * number of inequalities (or equalities). The reserved size -/// is numReservedCols * numReservedInequalities (or numReservedEqualities). A -/// coefficient (r, c) lives at the location numReservedCols * r + c in the -/// buffer. 
The extra space between getNumCols() and numReservedCols exists to -/// prevent frequent movement of data when adding columns, especially at the -/// end. -/// -/// The identifiers x_0, x_1, ... appear in the order: dimensional identifiers, -/// symbolic identifiers, and local identifiers. The local identifiers -/// correspond to local/internal variables created when converting from -/// AffineExpr's containing mod's and div's; they are thus needed to increase -/// representational power. Each local identifier is always (by construction) a -/// floordiv of a pure add/mul affine function of dimensional, symbolic, and -/// other local identifiers, in a non-mutually recursive way. Hence, every local -/// identifier can ultimately always be recovered as an affine function of -/// dimensional and symbolic identifiers (involving floordiv's); note however -/// that some floordiv combinations are converted to mod's by AffineExpr -/// construction. -/// -class FlatAffineConstraints { -public: - enum IdKind { Dimension, Symbol, Local }; - - /// Constructs a constraint system reserving memory for the specified number - /// of constraints and identifiers.. - FlatAffineConstraints(unsigned numReservedInequalities, - unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims = 0, - unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) - : numReservedCols(numReservedCols), numDims(numDims), - numSymbols(numSymbols) { - assert(numReservedCols >= numDims + numSymbols + 1); - assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); - equalities.reserve(numReservedCols * numReservedEqualities); - inequalities.reserve(numReservedCols * numReservedInequalities); - numIds = numDims + numSymbols + numLocals; - ids.reserve(numReservedCols); - if (idArgs.empty()) - ids.resize(numIds, None); - else - ids.append(idArgs.begin(), idArgs.end()); - } - - /// Constructs a constraint system with the specified number of - /// dimensions and symbols. - FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, - ArrayRef> idArgs = {}) - : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), - numSymbols(numSymbols) { - assert(numReservedCols >= numDims + numSymbols + 1); - assert(idArgs.empty() || idArgs.size() == numDims + numSymbols + numLocals); - numIds = numDims + numSymbols + numLocals; - ids.reserve(numIds); - if (idArgs.empty()) - ids.resize(numIds, None); - else - ids.append(idArgs.begin(), idArgs.end()); - } - - explicit FlatAffineConstraints(const HyperRectangularSet &set); - - /// Create a flat affine constraint system from an AffineValueMap or a list of - /// these. The constructed system will only include equalities. - // TODO(bondhugula) - explicit FlatAffineConstraints(const AffineValueMap &avm); - explicit FlatAffineConstraints(ArrayRef avmRef); - - /// Creates an affine constraint system from an IntegerSet. - explicit FlatAffineConstraints(IntegerSet set); - - /// Create an affine constraint system from an IntegerValueSet. - // TODO(bondhugula) - explicit FlatAffineConstraints(const IntegerValueSet &set); - - FlatAffineConstraints(const FlatAffineConstraints &other); - - FlatAffineConstraints(ArrayRef avmRef, - IntegerSet set); - - FlatAffineConstraints(const MutableAffineMap &map); - - ~FlatAffineConstraints() {} - - // Clears any existing data and reserves memory for the specified constraints. 
- void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0, ArrayRef idArgs = {}); - - void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, ArrayRef idArgs = {}); - - /// Appends constraints from 'other' into this. This is equivalent to an - /// intersection with no simplification of any sort attempted. - void append(const FlatAffineConstraints &other); - - // Checks for emptiness by performing variable elimination on all identifiers, - // running the GCD test on each equality constraint, and checking for invalid - // constraints. - // Returns true if the GCD test fails for any equality, or if any invalid - // constraints are discovered on any row. Returns false otherwise. - bool isEmpty() const; - - // Runs the GCD test on all equality constraints. Returns 'true' if this test - // fails on any equality. Returns 'false' otherwise. - // This test can be used to disprove the existence of a solution. If it - // returns true, no integer solution to the equality constraints can exist. - bool isEmptyByGCDTest() const; - - // Clones this object. - std::unique_ptr clone() const; - - /// Returns the value at the specified equality row and column. - inline int64_t atEq(unsigned i, unsigned j) const { - return equalities[i * numReservedCols + j]; - } - inline int64_t &atEq(unsigned i, unsigned j) { - return equalities[i * numReservedCols + j]; - } - - inline int64_t atIneq(unsigned i, unsigned j) const { - return inequalities[i * numReservedCols + j]; - } - - inline int64_t &atIneq(unsigned i, unsigned j) { - return inequalities[i * numReservedCols + j]; - } - - /// Returns the number of columns in the constraint system. - inline unsigned getNumCols() const { return numIds + 1; } - - inline unsigned getNumEqualities() const { - assert(equalities.size() % numReservedCols == 0 && - "inconsistent equality buffer size"); - return equalities.size() / numReservedCols; - } - - inline unsigned getNumInequalities() const { - assert(inequalities.size() % numReservedCols == 0 && - "inconsistent inequality buffer size"); - return inequalities.size() / numReservedCols; - } - - inline unsigned getNumReservedEqualities() const { - return equalities.capacity() / numReservedCols; - } - - inline unsigned getNumReservedInequalities() const { - return inequalities.capacity() / numReservedCols; - } - - inline ArrayRef getEquality(unsigned idx) const { - return ArrayRef(&equalities[idx * numReservedCols], getNumCols()); - } - - inline ArrayRef getInequality(unsigned idx) const { - return ArrayRef(&inequalities[idx * numReservedCols], - getNumCols()); - } - - AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); - - /// Computes the lower and upper bounds of the first 'num' dimensional - /// identifiers as an affine map of the remaining identifiers (dimensional and - /// symbolic). This method is able to detect identifiers as floordiv's - /// and mod's of affine expressions of other identifiers with respect to - /// (positive) constants. Sets bound map to a null AffineMap if such a bound - /// can't be found (or yet unimplemented). - void getSliceBounds(unsigned num, MLIRContext *context, - SmallVectorImpl *lbMaps, - SmallVectorImpl *ubMaps); - - /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper - /// bounds in 'ubMaps' to the constraint system. Note that both lower/upper - /// bounds share the same operand list 'operands'. 
- /// This function assumes that position 'lbMaps.size' == 'ubMaps.size', - /// and that positions [0, lbMaps.size) represent dimensional identifiers - /// which correspond to the loop IVs whose iteration bounds are being sliced. - /// Note that both lower/upper bounds use operands from 'operands'. - /// Returns true on success, returns false for unimplemented cases. - bool addSliceBounds(ArrayRef lbMaps, ArrayRef ubMaps, - ArrayRef operands); - - // Adds an inequality (>= 0) from the coefficients specified in inEq. - void addInequality(ArrayRef inEq); - // Adds an equality from the coefficients specified in eq. - void addEquality(ArrayRef eq); - - /// Adds a constant lower bound constraint for the specified identifier. - void addConstantLowerBound(unsigned pos, int64_t lb); - /// Adds a constant upper bound constraint for the specified identifier. - void addConstantUpperBound(unsigned pos, int64_t ub); - - /// Adds a new local identifier as the floordiv of an affine function of other - /// identifiers, the coefficients of which are provided in 'dividend' and with - /// respect to a positive constant 'divisor'. Two constraints are added to the - /// system to capture equivalence with the floordiv: - /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. - void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); - - /// Adds a constant lower bound constraint for the specified expression. - void addConstantLowerBound(ArrayRef expr, int64_t lb); - /// Adds a constant upper bound constraint for the specified expression. - void addConstantUpperBound(ArrayRef expr, int64_t ub); - - /// Sets the identifier at the specified position to a constant. - void setIdToConstant(unsigned pos, int64_t val); - - /// Sets the identifier corresponding to the specified Value id to a - /// constant. Asserts if the 'id' is not found. - void setIdToConstant(const Value &id, int64_t val); - - /// Looks up the identifier with the specified Value. Returns false if not - /// found, true if found. pos is set to the (column) position of the - /// identifier. - bool findId(const Value &id, unsigned *pos) const; - - // Add identifiers of the specified kind - specified positions are relative to - // the kind of identifier. The coefficient column corresponding to the added - // identifier is initialized to zero. 'id' is the Value corresponding to the - // identifier that can optionally be provided. - void addDimId(unsigned pos, Value *id = nullptr); - void addSymbolId(unsigned pos, Value *id = nullptr); - void addLocalId(unsigned pos); - void addId(IdKind kind, unsigned pos, Value *id = nullptr); - - /// Composes the affine value map with this FlatAffineConstrains, adding the - /// results of the map as dimensions at the front [0, vMap->getNumResults()) - /// and with the dimensions set to the equalities specified by the value map. - /// Returns false if the composition fails (when vMap is a semi-affine map). - /// The vMap's operand Value's are used to look up the right positions in - /// the FlatAffineConstraints with which to associate. The dimensional and - /// symbolic operands of vMap should match 1:1 (in the same order) with those - /// of this constraint system, but the latter could have additional trailing - /// operands. - bool composeMap(AffineValueMap *vMap); - - /// Projects out (aka eliminates) 'num' identifiers starting at position - /// 'pos'. The resulting constraint system is the shadow along the dimensions - /// that still exist. This method may not always be integer exact. 
- // TODO(bondhugula): deal with integer exactness when necessary - can return a - // value to mark exactness for example. - void projectOut(unsigned pos, unsigned num); - inline void projectOut(unsigned pos) { return projectOut(pos, 1); } - - /// Projects out the identifier that is associate with Value *. - void projectOut(Value *id); - - void removeId(IdKind idKind, unsigned pos); - void removeId(unsigned pos); - - void removeDim(unsigned pos); - - void removeEquality(unsigned pos); - void removeInequality(unsigned pos); - - /// Changes the partition between dimensions and symbols. Depending on the new - /// symbol count, either a chunk of trailing dimensional identifiers becomes - /// symbols, or some of the leading symbols become dimensions. - void setDimSymbolSeparation(unsigned newSymbolCount); - - /// Sets the specified identifier to a constant and removes it. - void setAndEliminate(unsigned pos, int64_t constVal); - - /// Tries to fold the specified identifer to a constant using a trivial - /// equality detection; if successful, the constant is substituted for the - /// identifier everywhere in the constraint system and then removed from the - /// system. Returns true if the folding happens, false otherwise. - bool constantFoldId(unsigned pos); - - /// This method calls constantFoldId for the specified range of identifiers, - /// 'num' identifiers starting at position 'pos'. - void constantFoldIdRange(unsigned pos, unsigned num); - - /// Returns true if all the identifiers in the specified range [start, limit) - /// can only take a single value each if the remaining identifiers are treated - /// as symbols/parameters, i.e., for given values of the latter, there only - /// exists a unique value for each of the dimensions in the specified range. - bool isRangeOneToOne(unsigned start, unsigned limit) const; - - /// Updates the constraints to be the smallest bounding (enclosing) box that - /// contains the points of 'this' set and that of 'other', with the symbols - /// being treated specially. For each of the dimensions, the min of the lower - /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed - /// to determine such a bounding box. - /// - /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the - /// output is {0 <= d0 <= 192}. - /// 2) 'this' = {s0 + 5 <= d0 <= s0 + 20}, 'other' is {s0 + 1 <= d0 <= s0 + - /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. - /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 - /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. - bool unionBoundingBox(const FlatAffineConstraints &other); - - unsigned getNumConstraints() const { - return getNumInequalities() + getNumEqualities(); - } - inline unsigned getNumIds() const { return numIds; } - inline unsigned getNumDimIds() const { return numDims; } - inline unsigned getNumSymbolIds() const { return numSymbols; } - inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; } - inline unsigned getNumLocalIds() const { - return numIds - numDims - numSymbols; - } - - inline ArrayRef> getIds() const { - return {ids.data(), ids.size()}; - } - - /// Returns the Value associated with the pos^th identifier. Asserts if - /// no Value identifier was associated. - inline Value *getIdValue(unsigned pos) const { - assert(ids[pos].hasValue() && "identifier's Value not set"); - return ids[pos].getValue(); - } - - /// Returns the Values associated with identifiers in range [start, end). 
- /// Asserts if no Value was associated with one of these identifiers. - void getIdValues(unsigned start, unsigned end, - SmallVectorImpl *values) const { - assert((start < numIds || start == end) && "invalid start position"); - assert(end <= numIds && "invalid end position"); - values->clear(); - values->reserve(end - start); - for (unsigned i = start; i < end; i++) { - values->push_back(getIdValue(i)); - } - } - inline void getAllIdValues(SmallVectorImpl *values) const { - getIdValues(0, numIds, values); - } - - /// Sets Value associated with the pos^th identifier. - inline void setIdValue(unsigned pos, Value *val) { - assert(pos < numIds && "invalid id position"); - ids[pos] = val; - } - /// Sets Values associated with identifiers in the range [start, end). - void setIdValues(unsigned start, unsigned end, ArrayRef values) { - assert((start < numIds || end == start) && "invalid start position"); - assert(end <= numIds && "invalid end position"); - assert(values.size() == end - start); - for (unsigned i = start; i < end; ++i) - ids[i] = values[i - start]; - } - - /// Clears this list of constraints and copies other into it. - void clearAndCopyFrom(const FlatAffineConstraints &other); - - /// Returns the smallest known constant bound for the extent of the specified - /// identifier (pos^th), i.e., the smallest known constant that is greater - /// than or equal to 'exclusive upper bound' - 'lower bound' of the - /// identifier. Returns None if it's not a constant. This method employs - /// trivial (low complexity / cost) checks and detection. Symbolic identifiers - /// are treated specially, i.e., it looks for constant differences between - /// affine expressions involving only the symbolic identifiers. See comments - /// at function definition for examples. 'lb' and 'lbDivisor', if provided, - /// are used to express the lower bound associated with the constant - /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg., - /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three - /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32. - Optional - getConstantBoundOnDimSize(unsigned pos, - SmallVectorImpl *lb = nullptr, - int64_t *lbDivisor = nullptr) const; - - /// Returns the constant lower bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantLowerBound(unsigned pos) const; - - /// Returns the constant upper bound for the pos^th identifier if there is - /// one; None otherwise. - Optional getConstantUpperBound(unsigned pos) const; - - /// Gets the lower and upper bound of the pos^th identifier treating - /// [dimStartPos, symbStartPos) as dimensions and [symStartPos, - /// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps - /// in the pair represent the max and min of potentially multiple affine - /// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed - /// AffineExpr's for all local identifiers in the system. - std::pair - getLowerAndUpperBound(unsigned pos, unsigned dimStartPos, - unsigned symStartPos, ArrayRef localExprs, - MLIRContext *context); - - /// Returns true if the set can be trivially detected as being - /// hyper-rectangular on the specified contiguous set of identifiers. - bool isHyperRectangular(unsigned pos, unsigned num) const; - - /// Removes duplicates and trivially true constraints: a constraint of the - /// form >= 0 is considered a trivially true - /// constraint. 
- void removeTrivialRedundancy(); - - /// A more expensive check to detect redundant inequalities. - void removeRedundantInequalities(); - - // Removes all equalities and inequalities. - void clearConstraints(); - - void print(raw_ostream &os) const; - void dump() const; - -private: - /// Returns false if the fields corresponding to various identifier counts, or - /// equality/inequality buffer sizes aren't consistent; true otherwise. This - /// is meant to be used within an assert internally. - bool hasConsistentState() const; - - /// Checks all rows of equality/inequality constraints for trivial - /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced - /// after elimination. Returns 'true' if an invalid constraint is found; - /// 'false'otherwise. - bool hasInvalidConstraint() const; - - /// Returns the constant lower bound bound if isLower is true, and the upper - /// bound if isLower is false. - template - Optional computeConstantLowerOrUpperBound(unsigned pos); - - // Eliminates a single identifier at 'position' from equality and inequality - // constraints. Returns 'true' if the identifier was eliminated, and false - // otherwise. - inline bool gaussianEliminateId(unsigned position) { - return gaussianEliminateIds(position, position + 1) == 1; - } - - // Eliminates identifiers from equality and inequality constraints - // in column range [posStart, posLimit). - // Returns the number of variables eliminated. - unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit); - - /// Eliminates identifier at the specified position using Fourier-Motzkin - /// variable elimination, but uses Gaussian elimination if there is an - /// equality involving that identifier. If the result of the elimination is - /// integer exact, *isResultIntegerExact is set to true. If 'darkShadow' is - /// set to true, a potential under approximation (subset) of the rational - /// shadow / exact integer shadow is computed. - // See implementation comments for more details. - void FourierMotzkinEliminate(unsigned pos, bool darkShadow = false, - bool *isResultIntegerExact = nullptr); - - /// Tightens inequalities given that we are dealing with integer spaces. This - /// is similar to the GCD test but applied to inequalities. The constant term - /// can be reduced to the preceding multiple of the GCD of the coefficients, - /// i.e., - /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a - /// fast method (linear in the number of coefficients). - void GCDTightenInequalities(); - - /// Normalized each constraints by the GCD of its coefficients. - void normalizeConstraintsByGCD(); - - /// Removes identifiers in column range [idStart, idLimit), and copies any - /// remaining valid data into place, updates member variables, and resizes - /// arrays as needed. - void removeIdRange(unsigned idStart, unsigned idLimit); - - /// Coefficients of affine equalities (in == 0 form). - SmallVector equalities; - - /// Coefficients of affine inequalities (in >= 0 form). - SmallVector inequalities; - - /// Number of columns reserved. Actual ones in used are returned by - /// getNumCols(). - unsigned numReservedCols; - - /// Total number of identifiers. - unsigned numIds; - - /// Number of identifiers corresponding to real dimensions. - unsigned numDims; - - /// Number of identifiers corresponding to symbols (unknown but constant for - /// analysis). 
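
The Fourier-Motzkin elimination declared above pairs every lower bound of the eliminated identifier with every upper bound. A worked instance (illustrative, not from the patch): to eliminate x from

  2*x - i >= 0    (lower bound: x >= i ceildiv 2)
  -3*x + j >= 0   (upper bound: x <= j floordiv 3)

scale both rows to the lcm of the x coefficients (6) and add them:

  3*(2*x - i) + 2*(-3*x + j)  =>  -3*i + 2*j >= 0

This combined constraint is the rational shadow: exact over the rationals, but it may keep points for which no integer x exists. The dark shadow option mentioned above instead tightens the constant term of each combined constraint so that any point it keeps is guaranteed to contain an integer x, yielding a subset (under-approximation) of the exact integer shadow.
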
- unsigned numSymbols; - - /// Values corresponding to the (column) identifiers of this constraint - /// system appearing in the order the identifiers correspond to columns. - /// Temporary ones or those that aren't associated to any Value are to be - /// set to None. - SmallVector, 8> ids; - - /// A parameter that controls detection of an unrealistic number of - /// constraints. If the number of constraints is this many times the number of - /// variables, we consider such a system out of line with the intended use - /// case of FlatAffineConstraints. - // The rationale for 32 is that in the typical simplest of cases, an - // identifier is expected to have one lower bound and one upper bound - // constraint. With a level of tiling or a connection to another identifier - // through a div or mod, an extra pair of bounds gets added. As a limit, we - // don't expect an identifier to have more than 32 lower/upper/equality - // constraints. This is conservatively set low and can be raised if needed. - constexpr static unsigned kExplosionFactor = 32; -}; - -} // end namespace mlir. - -#endif // MLIR_IR_AFFINE_STRUCTURES_H diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 0b82df271f5..bfed6f8a645 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -16,7 +16,6 @@ // ============================================================================= #include "mlir/AffineOps/AffineOps.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -187,12 +186,6 @@ bool AffineApplyOp::verify() const { return false; } -/// Returns an AffineValueMap representing this affine apply. -AffineValueMap AffineApplyOp::getAsAffineValueMap() { - SmallVector operands(getOperands()); - return AffineValueMap(getAffineMap(), operands, getResult()); -} - // The result of the affine apply operation can be used as a dimension id if it // is a CFG value or if it is an Value, and all the operands are valid // dimension ids. @@ -1033,112 +1026,6 @@ void mlir::extractForInductionVars(ArrayRef> forInsts, ivs->push_back(forInst->getInductionVar()); } -bool mlir::addAffineForOpDomain(ConstOpPointer forOp, - FlatAffineConstraints *constraints) { - unsigned pos; - // Pre-condition for this method. - if (!constraints->findId(*forOp->getInductionVar(), &pos)) { - assert(0 && "Value not found"); - return false; - } - - if (forOp->getStep() != 1) - LLVM_DEBUG(llvm::dbgs() - << "Domain conservative: non-unit stride not handled\n"); - - int64_t step = forOp->getStep(); - - // Adds a lower or upper bound when the bounds aren't constant. - auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = - lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); - for (const auto &operand : operands) { - unsigned pos; - if (!constraints->findId(*operand, &pos)) { - if (isValidSymbol(operand)) { - constraints->addSymbolId(constraints->getNumSymbolIds(), - const_cast(operand)); - pos = constraints->getNumDimAndSymbolIds() - 1; - // Check if the symbol is a constant. - if (auto *opInst = operand->getDefiningInst()) { - if (auto constOp = opInst->dyn_cast()) { - constraints->setIdToConstant(*operand, constOp->getValue()); - } - } - } else { - constraints->addDimId(constraints->getNumDimIds(), - const_cast(operand)); - pos = constraints->getNumDimIds() - 1; - if (auto loop = getForInductionVarOwner(operand)) { - // Outer loop IVs could be used in forOp's bounds. 
- if (!addAffineForOpDomain(loop, constraints)) - return false; - } - } - } - } - // Record positions of the operands in the constraint system. - SmallVector positions; - for (const auto &operand : operands) { - unsigned pos; - if (!constraints->findId(*operand, &pos)) - assert(0 && "expected to be found"); - positions.push_back(pos); - } - - auto boundMap = - lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); - - FlatAffineConstraints localVarCst; - std::vector> flatExprs; - if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { - LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); - return false; - } - if (localVarCst.getNumLocalIds() > 0) { - LLVM_DEBUG(llvm::dbgs() - << "loop bounds with mod/floordiv expr's not yet supported\n"); - return false; - } - - for (const auto &flatExpr : flatExprs) { - SmallVector ineq(constraints->getNumCols(), 0); - ineq[pos] = lower ? 1 : -1; - for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { - ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; - } - // Constant term. - ineq[constraints->getNumCols() - 1] = - lower ? -flatExpr[flatExpr.size() - 1] - // Upper bound in flattenedExpr is an exclusive one. - : flatExpr[flatExpr.size() - 1] - step; - constraints->addInequality(ineq); - } - return true; - }; - - if (forOp->hasConstantLowerBound()) { - constraints->addConstantLowerBound(pos, forOp->getConstantLowerBound()); - } else { - // Non-constant lower bound case. - if (!addLowerOrUpperBound(/*lower=*/true)) - return false; - } - - if (forOp->hasConstantUpperBound()) { - constraints->addConstantUpperBound(pos, - forOp->getConstantUpperBound() - step); - return true; - } - // Non-constant upper bound case. - return addLowerOrUpperBound(/*lower=*/false); -} - -/// Returns an AffineValueMap representing this bound. -AffineValueMap AffineBound::getAsAffineValueMap() { - SmallVector operands(getOperands()); - return AffineValueMap(getMap(), operands); -} //===----------------------------------------------------------------------===// // AffineIfOp diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index c60e11dbe0f..a2e679b182d 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -22,9 +22,9 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" @@ -107,7 +107,7 @@ bool mlir::getIndexSet(MutableArrayRef> forOps, domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto forOp : forOps) { // Add constraints from forOp's bounds. - if (!addAffineForOpDomain(forOp, domain)) + if (!domain->addAffineForOpDomain(forOp)) return false; } return true; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp new file mode 100644 index 00000000000..d1f05be9cc4 --- /dev/null +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -0,0 +1,2530 @@ +//===- AffineStructures.cpp - MLIR Affine Structures Class-------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Structures for affine/polyhedral analysis of MLIR functions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Instruction.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Support/MathExtras.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "affine-structures" + +using namespace mlir; +using namespace llvm; + +namespace { + +// See comments for SimpleAffineExprFlattener. +// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording +// constraint information associated with mod's, floordiv's, and ceildiv's +// in localVarCst. +struct AffineExprFlattener : public SimpleAffineExprFlattener { +public: + // Constraints connecting newly introduced local variables (for mod's and + // div's) to existing (dimensional and symbolic) ones. These are always + // inequalities. + FlatAffineConstraints localVarCst; + + AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx) + : SimpleAffineExprFlattener(nDims, nSymbols, ctx) { + localVarCst.reset(nDims, nSymbols, /*numLocals=*/0); + } + +private: + // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). + // The local identifier added is always a floordiv of a pure add/mul affine + // function of other identifiers, coefficients of which are specified in + // dividend and with respect to a positive constant divisor. localExpr is the + // simplified tree expression (AffineExpr) corresponding to the quantifier. + void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, + AffineExpr localExpr) override { + SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); + // Update localVarCst. + localVarCst.addLocalFloorDiv(dividend, divisor); + } +}; + +} // end anonymous namespace + +// Flattens the expressions in map. Returns true on success or false +// if 'expr' was unable to be flattened (i.e., semi-affine expressions not +// handled yet). +static bool getFlattenedAffineExprs( + ArrayRef exprs, unsigned numDims, unsigned numSymbols, + std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (exprs.empty()) { + localVarCst->reset(numDims, numSymbols); + return true; + } + + AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); + // Use the same flattener to simplify each expression successively. This way + // local identifiers / expressions are shared. 
+ for (auto expr : exprs) { + if (!expr.isPureAffine()) + return false; + + flattener.walkPostOrder(expr); + } + + assert(flattener.operandExprStack.size() == exprs.size()); + flattenedExprs->clear(); + flattenedExprs->assign(flattener.operandExprStack.begin(), + flattener.operandExprStack.end()); + + if (localVarCst) { + localVarCst->clearAndCopyFrom(flattener.localVarCst); + } + + return true; +} + +// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false +// if 'expr' was unable to be flattened (semi-affine expressions not handled +// yet). +bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *localVarCst) { + std::vector> flattenedExprs; + bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, + &flattenedExprs, localVarCst); + *flattenedExpr = flattenedExprs[0]; + return ret; +} + +/// Flattens the expressions in map. Returns true on success or false +/// if 'expr' was unable to be flattened (i.e., semi-affine expressions not +/// handled yet). +bool mlir::getFlattenedAffineExprs( + AffineMap map, std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (map.getNumResults() == 0) { + localVarCst->reset(map.getNumDims(), map.getNumSymbols()); + return true; + } + return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), + map.getNumSymbols(), flattenedExprs, + localVarCst); +} + +bool mlir::getFlattenedAffineExprs( + IntegerSet set, std::vector> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (set.getNumConstraints() == 0) { + localVarCst->reset(set.getNumDims(), set.getNumSymbols()); + return true; + } + return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), + set.getNumSymbols(), flattenedExprs, + localVarCst); +} + +//===----------------------------------------------------------------------===// +// MutableAffineMap. +//===----------------------------------------------------------------------===// + +MutableAffineMap::MutableAffineMap(AffineMap map) + : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), + // A map always has at least 1 result by construction + context(map.getResult(0).getContext()) { + for (auto result : map.getResults()) + results.push_back(result); + for (auto rangeSize : map.getRangeSizes()) + results.push_back(rangeSize); +} + +void MutableAffineMap::reset(AffineMap map) { + results.clear(); + rangeSizes.clear(); + numDims = map.getNumDims(); + numSymbols = map.getNumSymbols(); + // A map always has at least 1 result by construction + context = map.getResult(0).getContext(); + for (auto result : map.getResults()) + results.push_back(result); + for (auto rangeSize : map.getRangeSizes()) + results.push_back(rangeSize); +} + +bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { + if (results[idx].isMultipleOf(factor)) + return true; + + // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to + // complete this (for a more powerful analysis). + return false; +} + +// Simplifies the result affine expressions of this map. The expressions have to +// be pure for the simplification implemented. +void MutableAffineMap::simplify() { + // Simplify each of the results if possible. 
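
A small illustration of the flattened form these routines produce (a sketch, not from the patch; the column order is dimensions, then symbols, then locals, then the constant term, matching the constraint columns used later in this file): with dimensions (d0, d1) and symbol s0, the expression 16*d0 + d1 + s0 + 3 flattens to the coefficient vector [16, 1, 1, 3]. An expression such as d0 floordiv 4 is not a pure add/mul form, so the flattener introduces a local identifier q through addLocalFloorDivId; the result then flattens to [0, 0, 0, 1, 0] while localVarCst carries the connecting constraints 0 <= d0 - 4*q <= 3.
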
+ // TODO(ntv): functional-style map + for (unsigned i = 0, e = getNumResults(); i < e; i++) { + results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); + } +} + +AffineMap MutableAffineMap::getAffineMap() const { + return AffineMap::get(numDims, numSymbols, results, rangeSizes); +} + +MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) + : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()), + context(context) { + // TODO(bondhugula) +} + +// Universal set. +MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, + MLIRContext *context) + : numDims(numDims), numSymbols(numSymbols), context(context) {} + +//===----------------------------------------------------------------------===// +// AffineValueMap. +//===----------------------------------------------------------------------===// + +AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results) + : map(map), operands(operands.begin(), operands.end()), + results(results.begin(), results.end()) {} + +AffineValueMap::AffineValueMap(OpPointer applyOp) + : map(applyOp->getAffineMap()), + operands(applyOp->operand_begin(), applyOp->operand_end()) { + results.push_back(applyOp->getResult()); +} + +AffineValueMap::AffineValueMap(AffineBound bound) + : map(bound.getMap()), + operands(bound.operand_begin(), bound.operand_end()) {} + +void AffineValueMap::reset(AffineMap map, ArrayRef operands, + ArrayRef results) { + this->map.reset(map); + this->operands.assign(operands.begin(), operands.end()); + this->results.assign(results.begin(), results.end()); +} + +// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in +// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. +static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, + unsigned indexStart, unsigned *indexOfMatch) { + unsigned size = valuesToSearch.size(); + for (unsigned i = indexStart; i < size; ++i) { + if (valueToMatch == valuesToSearch[i]) { + *indexOfMatch = i; + return true; + } + } + return false; +} + +inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { + return map.isMultipleOf(idx, factor); +} + +/// This method uses the invariant that operands are always positionally aligned +/// with the AffineDimExpr in the underlying AffineMap. +bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { + unsigned index; + if (!findIndex(value, operands, /*indexStart=*/0, &index)) { + return false; + } + auto expr = const_cast(this)->getAffineMap().getResult(idx); + // TODO(ntv): this is better implemented on a flattened representation. + // At least for now it is conservative. + return expr.isFunctionOfDim(index); +} + +Value *AffineValueMap::getOperand(unsigned i) const { + return static_cast(operands[i]); +} + +ArrayRef AffineValueMap::getOperands() const { + return ArrayRef(operands); +} + +AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } + +AffineValueMap::~AffineValueMap() {} + +//===----------------------------------------------------------------------===// +// FlatAffineConstraints. +//===----------------------------------------------------------------------===// + +// Copy constructor. 
+FlatAffineConstraints::FlatAffineConstraints( + const FlatAffineConstraints &other) { + numReservedCols = other.numReservedCols; + numDims = other.getNumDimIds(); + numSymbols = other.getNumSymbolIds(); + numIds = other.getNumIds(); + + auto otherIds = other.getIds(); + ids.reserve(numReservedCols); + ids.append(otherIds.begin(), otherIds.end()); + + unsigned numReservedEqualities = other.getNumReservedEqualities(); + unsigned numReservedInequalities = other.getNumReservedInequalities(); + + equalities.reserve(numReservedEqualities * numReservedCols); + inequalities.reserve(numReservedInequalities * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +// Clones this object. +std::unique_ptr FlatAffineConstraints::clone() const { + return std::make_unique(*this); +} + +// Construct from an IntegerSet. +FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) + : numReservedCols(set.getNumOperands() + 1), + numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), + numSymbols(set.getNumSymbols()) { + equalities.reserve(set.getNumEqualities() * numReservedCols); + inequalities.reserve(set.getNumInequalities() * numReservedCols); + ids.resize(numIds, None); + + // Flatten expressions and add them to the constraint system. + std::vector> flatExprs; + FlatAffineConstraints localVarCst; + if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) { + assert(false && "flattening unimplemented for semi-affine integer sets"); + return; + } + assert(flatExprs.size() == set.getNumConstraints()); + for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { + addLocalId(getNumLocalIds()); + } + + for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { + const auto &flatExpr = flatExprs[i]; + assert(flatExpr.size() == getNumCols()); + if (set.getEqFlags()[i]) { + addEquality(flatExpr); + } else { + addInequality(flatExpr); + } + } + // Add the other constraints involving local id's from flattening. 
+ append(localVarCst); +} + +void FlatAffineConstraints::reset(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned newNumReservedCols, + unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef idArgs) { + assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && + "minimum 1 column"); + numReservedCols = newNumReservedCols; + numDims = newNumDims; + numSymbols = newNumSymbols; + numIds = numDims + numSymbols + newNumLocals; + assert(idArgs.empty() || idArgs.size() == numIds); + + clearConstraints(); + if (numReservedEqualities >= 1) + equalities.reserve(newNumReservedCols * numReservedEqualities); + if (numReservedInequalities >= 1) + inequalities.reserve(newNumReservedCols * numReservedInequalities); + if (idArgs.empty()) { + ids.resize(numIds, None); + } else { + ids.assign(idArgs.begin(), idArgs.end()); + } +} + +void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef idArgs) { + reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, + newNumSymbols, newNumLocals, idArgs); +} + +void FlatAffineConstraints::append(const FlatAffineConstraints &other) { + assert(other.getNumCols() == getNumCols()); + assert(other.getNumDimIds() == getNumDimIds()); + assert(other.getNumSymbolIds() == getNumSymbolIds()); + + inequalities.reserve(inequalities.size() + + other.getNumInequalities() * numReservedCols); + equalities.reserve(equalities.size() + + other.getNumEqualities() * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +void FlatAffineConstraints::addLocalId(unsigned pos) { + addId(IdKind::Local, pos); +} + +void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { + addId(IdKind::Dimension, pos, id); +} + +void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { + addId(IdKind::Symbol, pos, id); +} + +/// Adds a dimensional identifier. The added column is initialized to +/// zero. +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { + if (kind == IdKind::Dimension) { + assert(pos <= getNumDimIds()); + } else if (kind == IdKind::Symbol) { + assert(pos <= getNumSymbolIds()); + } else { + assert(pos <= getNumLocalIds()); + } + + unsigned oldNumReservedCols = numReservedCols; + + // Check if a resize is necessary. + if (getNumCols() + 1 > numReservedCols) { + equalities.resize(getNumEqualities() * (getNumCols() + 1)); + inequalities.resize(getNumInequalities() * (getNumCols() + 1)); + numReservedCols++; + } + + unsigned absolutePos; + + if (kind == IdKind::Dimension) { + absolutePos = pos; + numDims++; + } else if (kind == IdKind::Symbol) { + absolutePos = pos + getNumDimIds(); + numSymbols++; + } else { + absolutePos = pos + getNumDimIds() + getNumSymbolIds(); + } + numIds++; + + // Note that getNumCols() now will already return the new size, which will be + // at least one. 
+ int numInequalities = static_cast(getNumInequalities()); + int numEqualities = static_cast(getNumEqualities()); + int numCols = static_cast(getNumCols()); + for (int r = numInequalities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + if (c < absolutePos) + atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; + else + atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; + } + atIneq(r, absolutePos) = 0; + } + + for (int r = numEqualities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + // All values in column absolutePositions < absolutePos have the same + // coordinates in the 2-d view of the coefficient buffer. + if (c < absolutePos) + atEq(r, c) = equalities[r * oldNumReservedCols + c]; + else + // Those at absolutePosition >= absolutePos, get a shifted + // absolutePosition. + atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; + } + // Initialize added dimension to zero. + atEq(r, absolutePos) = 0; + } + + // If an 'id' is provided, insert it; otherwise use None. + if (id) { + ids.insert(ids.begin() + absolutePos, id); + } else { + ids.insert(ids.begin() + absolutePos, None); + } + assert(ids.size() == getNumIds()); +} + +// This routine may add additional local variables if the flattened expression +// corresponding to the map has such variables due to the presence of +// mod's, ceildiv's, and floordiv's. +bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { + // Assert if the map and this constraint set aren't associated with the same + // identifiers in the same order. + assert(vMap->getNumDims() <= getNumDimIds()); + assert(vMap->getNumSymbols() <= getNumSymbolIds()); + for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) { + assert(ids[i].hasValue()); + assert(vMap->getOperand(i) == ids[i].getValue()); + } + for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) { + assert(ids[numDims + i].hasValue()); + assert(vMap->getOperand(vMap->getNumDims() + i) == + ids[numDims + i].getValue()); + } + + std::vector> flatExprs; + FlatAffineConstraints cst; + if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { + LLVM_DEBUG(llvm::dbgs() + << "composition unimplemented for semi-affine maps\n"); + return false; + } + assert(flatExprs.size() == vMap->getNumResults()); + + // Make the value map and the flat affine cst dimensions compatible. + // A lot of this code will be refactored/cleaned up. + // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs + // to be factored out into an FlatAffineConstraints::alignAndMerge(). + for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { + addLocalId(0); + } + + for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { + // TODO: Consider using a batched version to add a range of IDs. + addDimId(0); + cst.addDimId(0); + } + + assert(cst.getNumDimIds() <= getNumDimIds()); + for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { + // Dimensions that are in 'this' but not in vMap/cst are added at the end. + cst.addDimId(cst.getNumDimIds()); + } + assert(cst.getNumSymbolIds() <= getNumSymbolIds()); + for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e; + t++) { + // Dimensions that are in 'this' but not in vMap/cst are added at the end. 
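
A tiny worked instance of the column shuffle performed in addId above (illustrative, not from the patch): with columns ordered [d0, s0, const], an equality row [2, 3, 5] (i.e. 2*d0 + 3*s0 + 5 = 0) becomes [2, 0, 3, 5] after addDimId(1) inserts a new dimension at position 1; coefficients at or beyond the insertion point shift one column to the right and the new column is zero-initialized.
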
+ cst.addSymbolId(cst.getNumSymbolIds()); + } + assert(cst.getNumLocalIds() <= getNumLocalIds()); + for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; + t++) { + cst.addLocalId(cst.getNumLocalIds()); + } + /// Finally, append cst to this constraint set. + append(cst); + + // We add one equality for each result connecting the result dim of the map to + // the other identifiers. + // For eg: if the expression is 16*i0 + i1, and this is the r^th + // iteration/result of the value map, we are adding the equality: + // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we + // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. + for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { + const auto &flatExpr = flatExprs[r]; + // eqToAdd is the equality corresponding to the flattened affine expression. + SmallVector eqToAdd(getNumCols(), 0); + // Set the coefficient for this result to one. + eqToAdd[r] = 1; + + assert(flatExpr.size() >= vMap->getNumOperands() + 1); + + // Dims and symbols. + for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { + unsigned loc; + bool ret = findId(*vMap->getOperand(i), &loc); + assert(ret && "value map's id can't be found"); + (void)ret; + // We need to negate 'eq[r]' since the newly added dimension is going to + // be set to this one. + eqToAdd[loc] = -flatExpr[i]; + } + // Local vars common to eq and cst are at the beginning. + int j = getNumDimIds() + getNumSymbolIds(); + int end = flatExpr.size() - 1; + for (int i = vMap->getNumOperands(); i < end; i++, j++) { + eqToAdd[j] = -flatExpr[i]; + } + + // Constant term. + eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; + + // Add the equality connecting the result of the map to this constraint set. + addEquality(eqToAdd); + } + + return true; +} + +bool FlatAffineConstraints::addAffineForOpDomain( + ConstOpPointer forOp) { + unsigned pos; + // Pre-condition for this method. + if (!findId(*forOp->getInductionVar(), &pos)) { + assert(0 && "Value not found"); + return false; + } + + if (forOp->getStep() != 1) + LLVM_DEBUG(llvm::dbgs() + << "Domain conservative: non-unit stride not handled\n"); + + int64_t step = forOp->getStep(); + + // Adds a lower or upper bound when the bounds aren't constant. + auto addLowerOrUpperBound = [&](bool lower) -> bool { + auto operands = + lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); + for (const auto &operand : operands) { + unsigned pos; + if (!findId(*operand, &pos)) { + if (isValidSymbol(operand)) { + addSymbolId(getNumSymbolIds(), const_cast(operand)); + pos = getNumDimAndSymbolIds() - 1; + // Check if the symbol is a constant. + if (auto *opInst = operand->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast()) { + setIdToConstant(*operand, constOp->getValue()); + } + } + } else { + addDimId(getNumDimIds(), const_cast(operand)); + pos = getNumDimIds() - 1; + if (auto loop = getForInductionVarOwner(operand)) { + // Outer loop IVs could be used in forOp's bounds. + if (!this->addAffineForOpDomain(loop)) + return false; + } + } + } + } + // Record positions of the operands in the constraint system. + SmallVector positions; + for (const auto &operand : operands) { + unsigned pos; + if (!findId(*operand, &pos)) + assert(0 && "expected to be found"); + positions.push_back(pos); + } + + auto boundMap = + lower ? 
forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); + + FlatAffineConstraints localVarCst; + std::vector> flatExprs; + if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { + LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); + return false; + } + if (localVarCst.getNumLocalIds() > 0) { + LLVM_DEBUG(llvm::dbgs() + << "loop bounds with mod/floordiv expr's not yet supported\n"); + return false; + } + + for (const auto &flatExpr : flatExprs) { + SmallVector ineq(getNumCols(), 0); + ineq[pos] = lower ? 1 : -1; + for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { + ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; + } + // Constant term. + ineq[getNumCols() - 1] = + lower ? -flatExpr[flatExpr.size() - 1] + // Upper bound in flattenedExpr is an exclusive one. + : flatExpr[flatExpr.size() - 1] - step; + addInequality(ineq); + } + return true; + }; + + if (forOp->hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp->getConstantLowerBound()); + } else { + // Non-constant lower bound case. + if (!addLowerOrUpperBound(/*lower=*/true)) + return false; + } + + if (forOp->hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp->getConstantUpperBound() - step); + return true; + } + // Non-constant upper bound case. + return addLowerOrUpperBound(/*lower=*/false); +} + +// Searches for a constraint with a non-zero coefficient at 'colIdx' in +// equality (isEq=true) or inequality (isEq=false) constraints. +// Returns true and sets row found in search in 'rowIdx'. +// Returns false otherwise. +static bool +findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, + unsigned colIdx, bool isEq, unsigned *rowIdx) { + auto at = [&](unsigned rowIdx) -> int64_t { + return isEq ? constraints.atEq(rowIdx, colIdx) + : constraints.atIneq(rowIdx, colIdx); + }; + unsigned e = + isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); + for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { + if (at(*rowIdx) != 0) { + return true; + } + } + return false; +} + +// Normalizes the coefficient values across all columns in 'rowIDx' by their +// GCD in equality or inequality contraints as specified by 'isEq'. +template +static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, + unsigned rowIdx) { + auto at = [&](unsigned colIdx) -> int64_t { + return isEq ? constraints->atEq(rowIdx, colIdx) + : constraints->atIneq(rowIdx, colIdx); + }; + uint64_t gcd = std::abs(at(0)); + for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); + } + if (gcd > 0 && gcd != 1) { + for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { + int64_t v = at(j) / static_cast(gcd); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } + } +} + +void FlatAffineConstraints::normalizeConstraintsByGCD() { + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + normalizeConstraintByGCD(this, i); + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + normalizeConstraintByGCD(this, i); + } +} + +bool FlatAffineConstraints::hasConsistentState() const { + if (inequalities.size() != getNumInequalities() * numReservedCols) + return false; + if (equalities.size() != getNumEqualities() * numReservedCols) + return false; + if (ids.size() != getNumIds()) + return false; + + // Catches errors where numDims, numSymbols, numIds aren't consistent. 
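
Tying the bound handling in addAffineForOpDomain above to concrete rows (an illustrative sketch, not from the patch): a unit-step loop over i with constant lower bound 0 and exclusive constant upper bound 128 contributes i >= 0 and i <= 127 (the upper bound minus the step). With the same lower bound but a symbolic upper bound map (d0) -> (d0) applied to an operand N, the upper-bound inequality built above is -i + N - 1 >= 0, i.e. i <= N - 1, since the flattened bound expression is exclusive and the step is subtracted from its constant term.
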
+ if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) + return false; + + return true; +} + +/// Checks all rows of equality/inequality constraints for trivial +/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced +/// after elimination. Returns 'true' if an invalid constraint is found; +/// 'false' otherwise. +bool FlatAffineConstraints::hasInvalidConstraint() const { + assert(hasConsistentState()); + auto check = [&](bool isEq) -> bool { + unsigned numCols = getNumCols(); + unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); + for (unsigned i = 0, e = numRows; i < e; ++i) { + unsigned j; + for (j = 0; j < numCols - 1; ++j) { + int64_t v = isEq ? atEq(i, j) : atIneq(i, j); + // Skip rows with non-zero variable coefficients. + if (v != 0) + break; + } + if (j < numCols - 1) { + continue; + } + // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. + // Example invalid constraints include: '1 == 0' or '-1 >= 0' + int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); + if ((isEq && v != 0) || (!isEq && v < 0)) { + return true; + } + } + return false; + }; + if (check(/*isEq=*/true)) + return true; + return check(/*isEq=*/false); +} + +// Eliminate identifier from constraint at 'rowIdx' based on coefficient at +// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be +// updated as they have already been eliminated. +static void eliminateFromConstraint(FlatAffineConstraints *constraints, + unsigned rowIdx, unsigned pivotRow, + unsigned pivotCol, unsigned elimColStart, + bool isEq) { + // Skip if equality 'rowIdx' if same as 'pivotRow'. + if (isEq && rowIdx == pivotRow) + return; + auto at = [&](unsigned i, unsigned j) -> int64_t { + return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); + }; + int64_t leadCoeff = at(rowIdx, pivotCol); + // Skip if leading coefficient at 'rowIdx' is already zero. + if (leadCoeff == 0) + return; + int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); + int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; + int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); + int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); + int64_t rowMultiplier = lcm / std::abs(leadCoeff); + + unsigned numCols = constraints->getNumCols(); + for (unsigned j = 0; j < numCols; ++j) { + // Skip updating column 'j' if it was just eliminated. + if (j >= elimColStart && j < pivotCol) + continue; + int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + + rowMultiplier * at(rowIdx, j); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } +} + +// Remove coefficients in column range [colStart, colLimit) in place. +// This removes in data in the specified column range, and copies any +// remaining valid data into place. +static void shiftColumnsToLeft(FlatAffineConstraints *constraints, + unsigned colStart, unsigned colLimit, + bool isEq) { + assert(colStart >= 0 && colLimit <= constraints->getNumIds()); + if (colLimit <= colStart) + return; + + unsigned numCols = constraints->getNumCols(); + unsigned numRows = isEq ? 
constraints->getNumEqualities() + : constraints->getNumInequalities(); + unsigned numToEliminate = colLimit - colStart; + for (unsigned r = 0, e = numRows; r < e; ++r) { + for (unsigned c = colLimit; c < numCols; ++c) { + if (isEq) { + constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c); + } else { + constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c); + } + } + } +} + +// Removes identifiers in column range [idStart, idLimit), and copies any +// remaining valid data into place, and updates member variables. +void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) { + assert(idLimit < getNumCols() && "invalid id limit"); + + if (idStart >= idLimit) + return; + + // We are going to be removing one or more identifiers from the range. + assert(idStart < numIds && "invalid idStart position"); + + // TODO(andydavis) Make 'removeIdRange' a lambda called from here. + // Remove eliminated identifiers from equalities. + shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true); + + // Remove eliminated identifiers from inequalities. + shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false); + + // Update members numDims, numSymbols and numIds. + unsigned numDimsEliminated = 0; + unsigned numLocalsEliminated = 0; + unsigned numColsEliminated = idLimit - idStart; + if (idStart < numDims) { + numDimsEliminated = std::min(numDims, idLimit) - idStart; + } + // Check how many local id's were removed. Note that our identifier order is + // [dims, symbols, locals]. Local id start at position numDims + numSymbols. + if (idLimit > numDims + numSymbols) { + numLocalsEliminated = std::min( + idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); + } + unsigned numSymbolsEliminated = + numColsEliminated - numDimsEliminated - numLocalsEliminated; + + numDims -= numDimsEliminated; + numSymbols -= numSymbolsEliminated; + numIds = numIds - numColsEliminated; + + ids.erase(ids.begin() + idStart, ids.begin() + idLimit); + + // No resize necessary. numReservedCols remains the same. +} + +/// Returns the position of the identifier that has the minimum times from the specified range of +/// identifiers [start, end). It is often best to eliminate in the increasing +/// order of these counts when doing Fourier-Motzkin elimination since FM adds +/// that many new constraints. +static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, + unsigned start, unsigned end) { + assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); + + auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { + unsigned numLb = 0; + unsigned numUb = 0; + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) > 0) { + ++numLb; + } else if (cst.atIneq(r, pos) < 0) { + ++numUb; + } + } + return numLb * numUb; + }; + + unsigned minLoc = start; + unsigned min = getProductOfNumLowerUpperBounds(start); + for (unsigned c = start + 1; c < end; c++) { + unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); + if (numLbUbProduct < min) { + min = numLbUbProduct; + minLoc = c; + } + } + return minLoc; +} + +// Checks for emptiness of the set by eliminating identifiers successively and +// using the GCD test (on all equality constraints) and checking for trivially +// invalid constraints. Returns 'true' if the constraint system is found to be +// empty; false otherwise. 
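
The ordering heuristic above can be read as a cost estimate (a sketch): Fourier-Motzkin elimination of an identifier with L lower-bound and U upper-bound inequalities removes those L + U rows and adds one combined row per (lower, upper) pair, i.e. L * U new inequalities; for example, 2 lower and 3 upper bounds produce 6. Eliminating identifiers in increasing order of this product keeps the intermediate systems small.
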
+bool FlatAffineConstraints::isEmpty() const { + if (isEmptyByGCDTest() || hasInvalidConstraint()) + return true; + + // First, eliminate as many identifiers as possible using Gaussian + // elimination. + FlatAffineConstraints tmpCst(*this); + unsigned currentPos = 0; + while (currentPos < tmpCst.getNumIds()) { + tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); + ++currentPos; + // We check emptiness through trivial checks after eliminating each ID to + // detect emptiness early. Since the checks isEmptyByGCDTest() and + // hasInvalidConstraint() are linear time and single sweep on the constraint + // buffer, this appears reasonable - but can optimize in the future. + if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) + return true; + } + + // Eliminate the remaining using FM. + for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { + tmpCst.FourierMotzkinEliminate( + getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); + // Check for a constraint explosion. This rarely happens in practice, but + // this check exists as a safeguard against improperly constructed + // constraint systems or artifically created arbitrarily complex systems + // that aren't the intended use case for FlatAffineConstraints. This is + // needed since FM has a worst case exponential complexity in theory. + if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { + LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected"); + return false; + } + + // FM wouldn't have modified the equalities in any way. So no need to again + // run GCD test. Check for trivial invalid constraints. + if (tmpCst.hasInvalidConstraint()) + return true; + } + return false; +} + +// Runs the GCD test on all equality constraints. Returns 'true' if this test +// fails on any equality. Returns 'false' otherwise. +// This test can be used to disprove the existence of a solution. If it returns +// true, no integer solution to the equality constraints can exist. +// +// GCD test definition: +// +// The equality constraint: +// +// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 +// +// has an integer solution iff: +// +// GCD of c_1, c_2, ..., c_n divides c_0. +// +bool FlatAffineConstraints::isEmptyByGCDTest() const { + assert(hasConsistentState()); + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + uint64_t gcd = std::abs(atEq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); + } + int64_t v = std::abs(atEq(i, numCols - 1)); + if (gcd > 0 && (v % gcd != 0)) { + return true; + } + } + return false; +} + +/// Tightens inequalities given that we are dealing with integer spaces. This is +/// analogous to the GCD test but applied to inequalities. The constant term can +/// be reduced to the preceding multiple of the GCD of the coefficients, i.e., +/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a +/// fast method - linear in the number of coefficients. +// Example on how this affects practical cases: consider the scenario: +// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield +// j >= 100 instead of the tighter (exact) j >= 128. 
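
A minimal standalone sketch of the GCD test defined above (editorial illustration only; it is independent of FlatAffineConstraints, and the helper name failsGCDTest is made up here):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <vector>

// Returns true if c_1*x_1 + ... + c_n*x_n = c0 provably has no integer
// solution, i.e. gcd(c_1, ..., c_n) does not divide c0.
static bool failsGCDTest(const std::vector<int64_t> &coeffs, int64_t c0) {
  uint64_t gcd = 0;
  for (int64_t c : coeffs)
    gcd = std::gcd(gcd, static_cast<uint64_t>(std::llabs(c)));
  return gcd > 0 && static_cast<uint64_t>(std::llabs(c0)) % gcd != 0;
}

int main() {
  // 2*x + 4*y = 7: gcd(2, 4) = 2 does not divide 7 -> provably empty.
  std::cout << failsGCDTest({2, 4}, 7) << '\n'; // prints 1
  // 2*x + 4*y = 6: 2 divides 6, so the test cannot disprove a solution.
  std::cout << failsGCDTest({2, 4}, 6) << '\n'; // prints 0
  return 0;
}
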
+void FlatAffineConstraints::GCDTightenInequalities() { + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + uint64_t gcd = std::abs(atIneq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); + } + if (gcd > 0) { + int64_t gcdI = static_cast(gcd); + atIneq(i, numCols - 1) = + gcdI * mlir::floorDiv(atIneq(i, numCols - 1), gcdI); + } + } +} + +// Eliminates all identifer variables in column range [posStart, posLimit). +// Returns the number of variables eliminated. +unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, + unsigned posLimit) { + // Return if identifier positions to eliminate are out of range. + assert(posLimit <= numIds); + assert(hasConsistentState()); + + if (posStart >= posLimit) + return 0; + + GCDTightenInequalities(); + + unsigned pivotCol = 0; + for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { + // Find a row which has a non-zero coefficient in column 'j'. + unsigned pivotRow; + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, + &pivotRow)) { + // No pivot row in equalities with non-zero at 'pivotCol'. + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, + &pivotRow)) { + // If inequalities are also non-zero in 'pivotCol', it can be + // eliminated. + continue; + } + break; + } + + // Eliminate identifier at 'pivotCol' from each equality row. + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/true); + normalizeConstraintByGCD(this, i); + } + + // Eliminate identifier at 'pivotCol' from each inequality row. + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/false); + normalizeConstraintByGCD(this, i); + } + removeEquality(pivotRow); + } + // Update position limit based on number eliminated. + posLimit = pivotCol; + // Remove eliminated columns from all constraints. + removeIdRange(posStart, posLimit); + return posLimit - posStart; +} + +// Detect the identifier at 'pos' (say id_r) as modulo of another identifier +// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) +// could be detected as the floordiv of n. For eg: +// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> +// id_r = id_n mod 4, id_q = id_n floordiv 4. +// lbConst and ubConst are the constant lower and upper bounds for 'pos' - +// pre-detected at the caller. +static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, + int64_t lbConst, int64_t ubConst, + SmallVectorImpl *memo) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to + // id_n - divisor * id_q. If these are true, then id_n becomes the dividend + // and id_q the quotient when dividing id_n by the divisor. + + if (lbConst != 0 || ubConst < 1) + return false; + + int64_t divisor = ubConst + 1; + + // Now check for: id_r = id_n - divisor * id_q. As an example, we + // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. + unsigned seenQuotient = 0, seenDividend = 0; + int quotientPos = -1, dividendPos = -1; + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + // id_n should have coeff 1 or -1. 
+ if (std::abs(cst.atEq(r, pos)) != 1) + continue; + for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { + // The coeff of the quotient should be -divisor if the coefficient of + // the pos^th identifier is -1, and divisor if the latter is -1. + if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) { + seenQuotient++; + quotientPos = c; + } else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) { + seenDividend++; + dividendPos = c; + } + } + // We are looking for exactly one identifier as part of the dividend. + // TODO(bondhugula): could be extended to cover multiple ones in the + // dividend to detect mod of an affine function of identifiers. + if (seenDividend == 1 && seenQuotient >= 1) { + if (!(*memo)[dividendPos]) + return false; + // Successfully detected a mod. + (*memo)[pos] = (*memo)[dividendPos] % divisor; + if (seenQuotient == 1 && !(*memo)[quotientPos]) + // Successfully detected a floordiv as well. + (*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor); + return true; + } + } + return false; +} + +// Gather lower and upper bounds for the pos^th identifier. +static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst, + unsigned pos, + SmallVectorImpl *lbIndices, + SmallVectorImpl *ubIndices) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices->push_back(r); + } else if (cst.atIneq(r, pos) <= -1) { + // Upper bound. + ubIndices->push_back(r); + } + } +} + +// Check if the pos^th identifier can be expressed as a floordiv of an affine +// function of other identifiers (where the divisor is a positive constant). +// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. +bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, + SmallVectorImpl *memo, MLIRContext *context) { + assert(pos < cst.getNumIds() && "invalid position"); + + SmallVector lbIndices, ubIndices; + getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices); + + // Check if any lower bound, upper bound pair is of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). + // + // For example, if -32*k + 16*i + j >= 0 + // 32*k - 16*i - j + 31 >= 0 <=> + // k = ( 16*i + j ) floordiv 32 + unsigned seenDividends = 0; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Check if lower bound's constant term is 'divisor - 1'. The 'divisor' + // here is cst.atIneq(lbPos, pos) and we already know that it's positive + // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. + if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) + continue; + // Check if upper bound's constant term is 0. + if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) + continue; + // For the remaining part, check if the lower bound expr's coeff's are + // negations of corresponding upper bound ones'. 
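
The lower/upper bound pair being matched here can be justified as follows (a sketch): q = e floordiv d, with d > 1, holds exactly when d*q <= e <= d*q + d - 1, which rearranges to the pair d*q - e + (d - 1) >= 0 (a lower bound on q whose constant term is divisor - 1) and e - d*q >= 0 (an upper bound on q whose constant term is 0) -- precisely the constant terms tested above. The file's own example fits this: k = (16*i + j) floordiv 32 corresponds to -32*k + 16*i + j >= 0 and 32*k - 16*i - j + 31 >= 0.
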
+ unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) + break; + if (c != pos && cst.atIneq(lbPos, c) != 0) + seenDividends++; + } + // Lb coeff's aren't negative of ub coeff's (for the non constant term + // part). + if (c < f) + continue; + if (seenDividends >= 1) { + // The divisor is the constant term of the lower bound expression. + // We already know that cst.atIneq(lbPos, pos) > 0. + int64_t divisor = cst.atIneq(lbPos, pos); + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(0, context); + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (c == pos) + continue; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + if (!(*memo)[c]) + break; + dividendExpr = dividendExpr + ubVal * (*memo)[c]; + } + // Expression can't be constructed as it depends on a yet unknown + // identifier. + // TODO(mlir-team): Visit/compute the identifiers in an order so that + // this doesn't happen. More complex but much more efficient. + if (c < f) + continue; + // Successfully detected the floordiv. + (*memo)[pos] = dividendExpr.floorDiv(divisor); + return true; + } + } + } + return false; +} + +// Fills an inequality row with the value 'val'. +static inline void fillInequality(FlatAffineConstraints *cst, unsigned r, + int64_t val) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = val; + } +} + +// Negates an inequality. +static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = -cst->atIneq(r, c); + } +} + +// A more complex check to eliminate redundant inequalities. +void FlatAffineConstraints::removeRedundantInequalities() { + SmallVector redun(getNumInequalities(), false); + // To check if an inequality is redundant, we replace the inequality by its + // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting + // system is empty. If it is, the inequality is redundant. + FlatAffineConstraints tmpCst(*this); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + // Change the inequality to its complement. + negateInequality(&tmpCst, r); + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; + if (tmpCst.isEmpty()) { + redun[r] = true; + // Zero fill the redundant inequality. + fillInequality(this, r, /*val=*/0); + fillInequality(&tmpCst, r, /*val=*/0); + } else { + // Reverse the change (to avoid recreating tmpCst each time). + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; + negateInequality(&tmpCst, r); + } + } + + // Scan to get rid of all rows marked redundant, in-place. 
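
A small worked instance of the complement trick used in removeRedundantInequalities above (illustrative, not from the patch): in the system {i >= 0, i + 5 >= 0}, to test i + 5 >= 0 the row [1, 5] is negated to [-1, -5] and its constant term decremented, giving -i - 6 >= 0 (i.e. i <= -6); together with i >= 0 the system is empty, so the original inequality is implied by the rest and gets zero-filled.
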
+ auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redun[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); +} + +std::pair FlatAffineConstraints::getLowerAndUpperBound( + unsigned pos, unsigned dimStartPos, unsigned symStartPos, + ArrayRef localExprs, MLIRContext *context) { + assert(pos < dimStartPos && "invalid dim start pos"); + assert(symStartPos >= dimStartPos && "invalid sym start pos"); + assert(getNumLocalIds() == localExprs.size() && + "incorrect local exprs count"); + + SmallVector lbIndices, ubIndices; + getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices); + + SmallVector lb, ub; + SmallVector exprs; + unsigned dimCount = symStartPos - dimStartPos; + unsigned symCount = getNumDimAndSymbolIds() - symStartPos; + exprs.reserve(lbIndices.size()); + // Lower bound expressions. + for (auto idx : lbIndices) { + auto ineq = getInequality(idx); + // Extract the lower bound (in terms of other coeff's + const), i.e., if + // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j + // - 1. + lb.assign(ineq.begin() + dimStartPos, ineq.end()); + std::transform(lb.begin(), lb.end(), lb.begin(), std::negate()); + auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); + exprs.push_back(expr); + } + auto lbMap = exprs.empty() ? AffineMap() + : AffineMap::get(dimCount, symCount, exprs, {}); + + exprs.clear(); + exprs.reserve(ubIndices.size()); + // Upper bound expressions. + for (auto idx : ubIndices) { + auto ineq = getInequality(idx); + // Extract the upper bound (in terms of other coeff's + const). + ub.assign(ineq.begin() + dimStartPos, ineq.end()); + auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); + // Upper bound is exclusive. + exprs.push_back(expr + 1); + } + auto ubMap = exprs.empty() ? AffineMap() + : AffineMap::get(dimCount, symCount, exprs, {}); + + return {lbMap, ubMap}; +} + +/// Computes the lower and upper bounds of the first 'num' dimensional +/// identifiers as affine maps of the remaining identifiers (dimensional and +/// symbolic identifiers). Local identifiers are themselves explicitly computed +/// as affine functions of other identifiers in this process if needed. +void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, + SmallVectorImpl *lbMaps, + SmallVectorImpl *ubMaps) { + assert(num < getNumDimIds() && "invalid range"); + + // Basic simplification. + normalizeConstraintsByGCD(); + + LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n"); + LLVM_DEBUG(dump()); + + // Record computed/detected identifiers. + SmallVector memo(getNumIds(), AffineExpr::Null()); + // Initialize dimensional and symbolic identifiers. + for (unsigned i = num, e = getNumDimIds(); i < e; i++) + memo[i] = getAffineDimExpr(i - num, context); + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) + memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); + + bool changed; + do { + changed = false; + // Identify yet unknown identifiers as constants or mod's / floordiv's of + // other identifiers if possible. 
+ for (unsigned pos = 0; pos < getNumIds(); pos++) { + if (memo[pos]) + continue; + + auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + // Detect equality to a constant. + if (lbConst.getValue() == ubConst.getValue()) { + memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); + changed = true; + continue; + } + + // Detect an identifier as modulo of another identifier w.r.t a + // constant. + if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), + &memo)) { + changed = true; + continue; + } + } + + // Detect an identifier as floordiv of another identifier w.r.t a + // constant. + if (detectAsFloorDiv(*this, pos, &memo, context)) { + changed = true; + continue; + } + + // Detect an identifier as an expression of other identifiers. + unsigned idx; + if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { + continue; + } + + // Build AffineExpr solving for identifier 'pos' in terms of all others. + auto expr = getAffineConstantExpr(0, context); + unsigned j, e; + for (j = 0, e = getNumIds(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = atEq(idx, j); + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!memo[j]) + break; + expr = expr + memo[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + continue; + + // Add constant term to AffineExpr. + expr = expr + atEq(idx, getNumIds()); + int64_t vPos = atEq(idx, pos); + assert(vPos != 0 && "expected non-zero here"); + if (vPos > 0) + expr = (-expr).floorDiv(vPos); + else + // vPos < 0. + expr = expr.floorDiv(-vPos); + // Successfully constructed expression. + memo[pos] = expr; + changed = true; + } + // This loop is guaranteed to reach a fixed point - since once an + // identifier's explicit form is computed (in memo[pos]), it's not updated + // again. + } while (changed); + + // Set the lower and upper bound maps for all the identifiers that were + // computed as affine expressions of the rest as the "detected expr" and + // "detected expr + 1" respectively; set the undetected ones to Null(). + Optional tmpClone; + for (unsigned pos = 0; pos < num; pos++) { + unsigned numMapDims = getNumDimIds() - num; + unsigned numMapSymbols = getNumSymbolIds(); + AffineExpr expr = memo[pos]; + if (expr) + expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); + + if (expr) { + (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); + (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); + } else { + // TODO(bondhugula): Whenever there have local identifiers in the + // dependence constraints, we'll conservatively over-approximate, since we + // don't always explicitly compute them above (in the while loop). + if (getNumLocalIds() == 0) { + // Work on a copy so that we don't update this constraint system. + if (!tmpClone) { + tmpClone.emplace(FlatAffineConstraints(*this)); + // Removing redudnant inequalities is necessary so that we don't get + // redundant loop bounds. + tmpClone->removeRedundantInequalities(); + } + std::tie((*lbMaps)[pos], (*ubMaps)[pos]) = + tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {}, + context); + } + + // If the above fails, we'll just use the constant lower bound and the + // constant upper bound (if they exist) as the slice bounds. 
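For the equality-substitution step above (solving for an identifier whose coefficient may be negative), the sign handling can be checked with plain integer arithmetic; the values below are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// floorDiv rounding toward negative infinity, as in affine expressions.
static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b, r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
  // Equality 2*x - d0 - 3 = 0: the coefficient of x is vPos = 2 > 0 and the
  // accumulated expression excluding the x term is expr = -d0 - 3.
  // The code then sets x = (-expr) floordiv vPos, i.e. (d0 + 3) floordiv 2.
  int64_t d0 = 5;
  int64_t expr = -d0 - 3;
  assert(floorDiv(-expr, 2) == 4);

  // With a negative coefficient, -2*x + d0 + 3 = 0 (vPos = -2), the code
  // uses x = expr floordiv (-vPos), which gives the same value.
  int64_t expr2 = d0 + 3;
  assert(floorDiv(expr2, 2) == 4);
  return 0;
}
```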
+      if (!(*lbMaps)[pos]) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice lb\n");
+        auto lbConst = getConstantLowerBound(pos);
+        if (lbConst.hasValue()) {
+          (*lbMaps)[pos] = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(lbConst.getValue(), context), {});
+        }
+      }
+      if (!(*ubMaps)[pos]) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice ub\n");
+        auto ubConst = getConstantUpperBound(pos);
+        if (ubConst.hasValue()) {
+          (*ubMaps)[pos] = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(ubConst.getValue() + 1, context), {});
+        }
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
+    LLVM_DEBUG((*lbMaps)[pos].dump(););
+    LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: ");
+    LLVM_DEBUG((*ubMaps)[pos].dump(););
+  }
+}
+
+// Adds slice lower/upper bounds from 'lbMaps'/'ubMaps' to the constraint
+// system. This function assumes that 'lbMaps.size()' == 'ubMaps.size()', and
+// that positions [0, lbMaps.size) represent dimensional identifiers which
+// correspond to the loop IVs whose iteration bounds are being sliced.
+// Note that both lower/upper bounds use operands from 'operands'.
+// Returns true on success. Returns false for unimplemented cases such as
+// semi-affine expressions or expressions with mod/floordiv.
+bool FlatAffineConstraints::addSliceBounds(ArrayRef lbMaps,
+                                           ArrayRef ubMaps,
+                                           ArrayRef operands) {
+  assert(lbMaps.size() == ubMaps.size());
+  // Record positions of the operands in the constraint system.
+  SmallVector positions;
+  for (const auto &operand : operands) {
+    unsigned loc;
+    if (!findId(*operand, &loc))
+      assert(0 && "expected to be found");
+    positions.push_back(loc);
+  }
+
+  auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap,
+                                  bool lower) -> bool {
+    FlatAffineConstraints localVarCst;
+    std::vector> flatExprs;
+    if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) {
+      LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
+      return false;
+    }
+    if (localVarCst.getNumLocalIds() > 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "loop bounds with mod/floordiv expr's not yet supported\n");
+      return false;
+    }
+
+    for (const auto &flatExpr : flatExprs) {
+      SmallVector ineq(getNumCols(), 0);
+      ineq[pos] = lower ? 1 : -1;
+      for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
+        ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
+      }
+      // Constant term.
+      ineq[getNumCols() - 1] =
+          lower ? -flatExpr[flatExpr.size() - 1]
+                // Upper bound in flattenedExpr is an exclusive one.
+ : flatExpr[flatExpr.size() - 1] - 1; + addInequality(ineq); + } + return true; + }; + + for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { + if (!addLowerOrUpperBound(i, lbMaps[i], /*lower=*/true)) + return false; + if (!addLowerOrUpperBound(i, ubMaps[i], /*lower=*/false)) + return false; + } + + return true; +} + +void FlatAffineConstraints::addEquality(ArrayRef eq) { + assert(eq.size() == getNumCols()); + unsigned offset = equalities.size(); + equalities.resize(equalities.size() + numReservedCols); + std::copy(eq.begin(), eq.end(), equalities.begin() + offset); +} + +void FlatAffineConstraints::addInequality(ArrayRef inEq) { + assert(inEq.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); +} + +void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = 1; + inequalities[offset + getNumCols() - 1] = -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = -1; + inequalities[offset + getNumCols() - 1] = ub; +} + +void FlatAffineConstraints::addConstantLowerBound(ArrayRef expr, + int64_t lb) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); + inequalities[offset + getNumCols() - 1] += -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(ArrayRef expr, + int64_t ub) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + for (unsigned i = 0, e = getNumCols(); i < e; i++) { + inequalities[offset + i] = -expr[i]; + } + inequalities[offset + getNumCols() - 1] += ub; +} + +/// Adds a new local identifier as the floordiv of an affine function of other +/// identifiers, the coefficients of which are provided in 'dividend' and with +/// respect to a positive constant 'divisor'. Two constraints are added to the +/// system to capture equivalence with the floordiv. +/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. +void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, + int64_t divisor) { + assert(dividend.size() == getNumCols() && "incorrect dividend size"); + assert(divisor > 0 && "positive divisor expected"); + + addLocalId(getNumLocalIds()); + + // Add two constraints for this new identifier 'q'. 
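A quick illustration of the row encodings produced by addConstantLowerBound / addConstantUpperBound above, with a made-up one-identifier system (columns [i, const]; every row is read as row · (i, 1) >= 0):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // i >= 3 is stored as  i - 3 >= 0  -> row { 1, -3}
  // i <= 7 is stored as -i + 7 >= 0  -> row {-1,  7}
  std::vector<int64_t> lbRow = {1, -3};
  std::vector<int64_t> ubRow = {-1, 7};
  for (int64_t i = 0; i <= 10; ++i) {
    bool satisfiesBoth = (lbRow[0] * i + lbRow[1] >= 0) &&
                         (ubRow[0] * i + ubRow[1] >= 0);
    assert(satisfiesBoth == (i >= 3 && i <= 7));
  }
  return 0;
}
```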
+  SmallVector bound(dividend.size() + 1);
+
+  // dividend - q * divisor >= 0
+  std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
+            bound.begin());
+  bound.back() = dividend.back();
+  bound[getNumIds() - 1] = -divisor;
+  addInequality(bound);
+
+  // -dividend + divisor * q + divisor - 1 >= 0
+  std::transform(bound.begin(), bound.end(), bound.begin(),
+                 std::negate());
+  bound[bound.size() - 1] += divisor - 1;
+  addInequality(bound);
+}
+
+bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
+  unsigned i = 0;
+  for (const auto &mayBeId : ids) {
+    if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
+      *pos = i;
+      return true;
+    }
+    i++;
+  }
+  return false;
+}
+
+void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
+  assert(newSymbolCount <= numDims + numSymbols &&
+         "invalid separation position");
+  numDims = numDims + numSymbols - newSymbolCount;
+  numSymbols = newSymbolCount;
+}
+
+/// Sets the specified identifier to a constant value.
+void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
+  unsigned offset = equalities.size();
+  equalities.resize(equalities.size() + numReservedCols);
+  std::fill(equalities.begin() + offset,
+            equalities.begin() + offset + getNumCols(), 0);
+  equalities[offset + pos] = 1;
+  equalities[offset + getNumCols() - 1] = -val;
+}
+
+/// Sets the specified identifier to a constant value; asserts if the id is not
+/// found.
+void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
+  unsigned pos;
+  if (!findId(id, &pos))
+    // This is a pre-condition for this method.
+    assert(0 && "id not found");
+  setIdToConstant(pos, val);
+}
+
+void FlatAffineConstraints::removeEquality(unsigned pos) {
+  unsigned numEqualities = getNumEqualities();
+  assert(pos < numEqualities);
+  unsigned outputIndex = pos * numReservedCols;
+  unsigned inputIndex = (pos + 1) * numReservedCols;
+  unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
+  std::copy(equalities.begin() + inputIndex,
+            equalities.begin() + inputIndex + numElemsToCopy,
+            equalities.begin() + outputIndex);
+  equalities.resize(equalities.size() - numReservedCols);
+}
+
+/// Finds an equality that equates the specified identifier to a constant.
+/// Returns the position of the equality row. If 'symbolic' is set to true,
+/// symbols are also treated like a constant, i.e., an affine function of the
+/// symbols is also treated like a constant.
+static int findEqualityToConstant(const FlatAffineConstraints &cst,
+                                  unsigned pos, bool symbolic = false) {
+  assert(pos < cst.getNumIds() && "invalid position");
+  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+    int64_t v = cst.atEq(r, pos);
+    if (v * v != 1)
+      continue;
+    unsigned c;
+    unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
+    // This checks for zeros in all positions other than 'pos' in [0, f)
+    for (c = 0; c < f; c++) {
+      if (c == pos)
+        continue;
+      if (cst.atEq(r, c) != 0) {
+        // Dependent on another identifier.
+        break;
+      }
+    }
+    if (c == f)
+      // Equality is free of other identifiers.
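The pair of inequalities that addLocalFloorDiv attaches to a new local q can also be checked numerically: for every integer value of the dividend they admit exactly one q, namely floor(dividend / divisor). A standalone brute-force check (divisor 4 chosen arbitrarily, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t c = 4; // divisor
  for (int64_t expr = -10; expr <= 10; ++expr) {
    // Find the unique q satisfying both added constraints:
    //   expr - c*q >= 0   and   -expr + c*q + c - 1 >= 0
    int64_t found = INT64_MIN;
    for (int64_t q = -10; q <= 10; ++q) {
      if (expr - c * q >= 0 && -expr + c * q + c - 1 >= 0) {
        assert(found == INT64_MIN && "q must be unique");
        found = q;
      }
    }
    // The unique q is floor division rounding toward negative infinity.
    int64_t fl = (expr >= 0) ? expr / c : -((-expr + c - 1) / c);
    assert(found == fl);
  }
  return 0;
}
```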
+ return r; + } + return -1; +} + +void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { + assert(pos < getNumIds() && "invalid position"); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; + } + removeId(pos); +} + +bool FlatAffineConstraints::constantFoldId(unsigned pos) { + assert(pos < getNumIds() && "invalid position"); + int rowIdx; + if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) + return false; + + // atEq(rowIdx, pos) is either -1 or 1. + assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); + int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); + setAndEliminate(pos, constVal); + return true; +} + +void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { + for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { + if (!constantFoldId(t)) + t++; + } +} + +/// Returns the extent (upper bound - lower bound) of the specified +/// identifier if it is found to be a constant; returns None if it's not a +/// constant. This methods treats symbolic identifiers specially, i.e., +/// it looks for constant differences between affine expressions involving +/// only the symbolic identifiers. See comments at function definition for +/// example. 'lb', if provided, is set to the lower bound associated with the +/// constant difference. Note that 'lb' is purely symbolic and thus will contain +/// the coefficients of the symbolic identifiers and the constant coefficient. +// Egs: 0 <= i <= 15, return 16. +// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) +// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. +// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = +// ceil(s0 - 7 / 8) = floor(s0 / 8)). +Optional FlatAffineConstraints::getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor) const { + assert(pos < getNumDimIds() && "Invalid identifier position"); + assert(getNumLocalIds() == 0); + + // TODO(bondhugula): eliminate all remaining dimensional identifiers (other + // than the one at 'pos' to make this more powerful. Not needed for + // hyper-rectangular spaces. + + // Find an equality for 'pos'^th identifier that equates it to some function + // of the symbolic identifiers (+ constant). + int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); + if (eqRow != -1) { + // This identifier can only take a single value. + if (lb) { + // Set lb to the symbolic value. + lb->resize(getNumSymbolIds() + 1); + for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { + int64_t v = atEq(eqRow, pos); + // atEq(eqRow, pos) is either -1 or 1. + assert(v * v == 1); + (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v + : -atEq(eqRow, getNumDimIds() + c) / v; + } + assert(lbFloorDivisor && + "both lb and divisor or none should be provided"); + *lbFloorDivisor = 1; + } + return 1; + } + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + // Positions of constraints that are lower/upper bounds on the variable. + SmallVector lbIndices, ubIndices; + + // Gather all symbolic lower bounds and upper bounds of the variable. 
Since + // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a + // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned c, f; + for (c = 0, f = getNumDimIds(); c < f; c++) { + if (c != pos && atIneq(r, c) != 0) + break; + } + if (c < getNumDimIds()) + // Not a pure symbolic bound. + continue; + if (atIneq(r, pos) >= 1) + // Lower bound. + lbIndices.push_back(r); + else if (atIneq(r, pos) <= -1) + // Upper bound. + ubIndices.push_back(r); + } + + // TODO(bondhugula): eliminate other dimensional identifiers to make this more + // powerful. Not needed for hyper-rectangular iteration spaces. + + Optional minDiff = None; + unsigned minLbPosition; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Look for a lower bound and an upper bound that only differ by a + // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. + // For example, if ii is the pos^th variable, we are looking for + // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The + // minimum among all such constant differences is kept since that's the + // constant bounding the extent of the pos^th variable. + unsigned j, e; + for (j = 0, e = getNumCols() - 1; j < e; j++) + if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { + break; + } + if (j < getNumCols() - 1) + continue; + int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); + if (minDiff == None || diff < minDiff) { + minDiff = diff; + minLbPosition = lbPos; + } + } + } + if (lb && minDiff.hasValue()) { + // Set lb to the symbolic lower bound. + lb->resize(getNumSymbolIds() + 1); + // The lower bound is the ceildiv of the lb constraint over the coefficient + // of the variable at 'pos'. We express the ceildiv equivalently as a floor + // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + + // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). + *lbFloorDivisor = atIneq(minLbPosition, pos); + for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { + // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of + // 'atIneq(minLbPosition, pos) - 1'. + (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c) + + atIneq(minLbPosition, pos) - 1; + } + } + return minDiff; +} + +template +Optional +FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) { + assert(pos < getNumIds() && "invalid position"); + // Project to 'pos'. + projectOut(0, pos); + projectOut(1, getNumIds() - 1); + // Check if there's an equality equating the '0'^th identifier to a constant. + int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); + if (eqRowIdx != -1) + // atEq(rowIdx, 0) is either -1 or 1. + return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, 0) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + Optional minOrMaxConst = None; + + // Take the max across all const lower bounds (or min across all constant + // upper bounds). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (isLower) { + if (atIneq(r, 0) <= 0) + // Not a lower bound. + continue; + } else if (atIneq(r, 0) >= 0) { + // Not an upper bound. 
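Plugging the documented examples into the constant-difference formula used by getConstantBoundOnDimSize above (diff = floordiv(c_ub + c_lb + 1, lbCoeff), where c_ub and c_lb are the constant terms of the matched upper- and lower-bound rows):

```cpp
#include <cassert>
#include <cstdint>

static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b, r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
  // 0 <= i <= 15:        lb row  i >= 0       (lbCoeff 1, c_lb 0)
  //                      ub row -i + 15 >= 0  (c_ub 15)
  assert(floorDiv(15 + 0 + 1, 1) == 16);

  // s0 - 7 <= 8*j <= s0: lb row 8*j - s0 + 7 >= 0 (lbCoeff 8, c_lb 7)
  //                      ub row -8*j + s0 >= 0    (c_ub 0)
  assert(floorDiv(0 + 7 + 1, 8) == 1);
  return 0;
}
```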
+ continue; + } + unsigned c, f; + for (c = 0, f = getNumCols() - 1; c < f; c++) + if (c != 0 && atIneq(r, c) != 0) + break; + if (c < getNumCols() - 1) + // Not a constant bound. + continue; + + int64_t boundConst = + isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) + : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); + if (isLower) { + if (minOrMaxConst == None || boundConst > minOrMaxConst) + minOrMaxConst = boundConst; + } else { + if (minOrMaxConst == None || boundConst < minOrMaxConst) + minOrMaxConst = boundConst; + } + } + return minOrMaxConst; +} + +Optional +FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { + FlatAffineConstraints tmpCst(*this); + return tmpCst.computeConstantLowerOrUpperBound(pos); +} + +Optional +FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { + FlatAffineConstraints tmpCst(*this); + return tmpCst.computeConstantLowerOrUpperBound(pos); +} + +// A simple (naive and conservative) check for hyper-rectangularlity. +bool FlatAffineConstraints::isHyperRectangular(unsigned pos, + unsigned num) const { + assert(pos < getNumCols() - 1); + // Check for two non-zero coefficients in the range [pos, pos + sum). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atIneq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atEq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + return true; +} + +void FlatAffineConstraints::print(raw_ostream &os) const { + assert(hasConsistentState()); + os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() + << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() + << " constraints)\n"; + os << "("; + for (unsigned i = 0, e = getNumIds(); i < e; i++) { + if (ids[i] == None) + os << "None "; + else + os << "Value "; + } + os << " const)\n"; + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atEq(i, j) << " "; + } + os << "= 0\n"; + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atIneq(i, j) << " "; + } + os << ">= 0\n"; + } + os << '\n'; +} + +void FlatAffineConstraints::dump() const { print(llvm::errs()); } + +/// Removes duplicate constraints and trivially true constraints: a constraint +/// of the form >= 0 is considered a trivially true +/// constraint. +// Uses a DenseSet to hash and detect duplicates followed by a linear scan to +// remove duplicates in place. +void FlatAffineConstraints::removeTrivialRedundancy() { + DenseSet> rowSet; + + // Check if constraint is of the form >= 0. + auto isTriviallyValid = [&](unsigned r) -> bool { + for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { + if (atIneq(r, c) != 0) + return false; + } + return atIneq(r, getNumCols() - 1) >= 0; + }; + + // Detect and mark redundant constraints. 
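The ceildiv/floordiv choice in computeConstantLowerOrUpperBound above can be sanity-checked on small numbers: a constant lower bound rounds up, a constant upper bound rounds down (standalone sketch, values made up):

```cpp
#include <cassert>
#include <cstdint>

static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b, r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}
static int64_t ceilDiv(int64_t a, int64_t b) { return -floorDiv(-a, b); }

int main() {
  // 2*i - 5 >= 0, i.e. i >= 5/2: tightest integer lower bound is ceil(5/2) = 3.
  assert(ceilDiv(5, 2) == 3);
  // -3*i + 10 >= 0, i.e. i <= 10/3: tightest integer upper bound is floor(10/3) = 3.
  assert(floorDiv(10, 3) == 3);
  return 0;
}
```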
+ std::vector redunIneq(getNumInequalities(), false); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + int64_t *rowStart = inequalities.data() + numReservedCols * r; + auto row = ArrayRef(rowStart, getNumCols()); + if (isTriviallyValid(r) || !rowSet.insert(row).second) { + redunIneq[r] = true; + } + } + + auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + + // Scan to get rid of all rows marked redundant, in-place. + unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redunIneq[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); + + // TODO(bondhugula): consider doing this for equalities as well, but probably + // not worth the savings. +} + +void FlatAffineConstraints::clearAndCopyFrom( + const FlatAffineConstraints &other) { + FlatAffineConstraints copy(other); + std::swap(*this, copy); + assert(copy.getNumIds() == copy.getIds().size()); +} + +void FlatAffineConstraints::removeId(unsigned pos) { + removeIdRange(pos, pos + 1); +} + +static std::pair +getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { + unsigned numDims = cst.getNumDimIds(); + unsigned numSymbols = cst.getNumSymbolIds(); + unsigned newNumDims, newNumSymbols; + if (pos < numDims) { + newNumDims = numDims - 1; + newNumSymbols = numSymbols; + } else if (pos < numDims + numSymbols) { + assert(numSymbols >= 1); + newNumDims = numDims; + newNumSymbols = numSymbols - 1; + } else { + newNumDims = numDims; + newNumSymbols = numSymbols; + } + return {newNumDims, newNumSymbols}; +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "fm" + +/// Eliminates identifier at the specified position using Fourier-Motzkin +/// variable elimination. This technique is exact for rational spaces but +/// conservative (in "rare" cases) for integer spaces. The operation corresponds +/// to a projection operation yielding the (convex) set of integer points +/// contained in the rational shadow of the set. An emptiness test that relies +/// on this method will guarantee emptiness, i.e., it disproves the existence of +/// a solution if it says it's empty. +/// If a non-null isResultIntegerExact is passed, it is set to true if the +/// result is also integer exact. If it's set to false, the obtained solution +/// *may* not be exact, i.e., it may contain integer points that do not have an +/// integer pre-image in the original set. +/// +/// Eg: +/// j >= 0, j <= i + 1 +/// i >= 0, i <= N + 1 +/// Eliminating i yields, +/// j >= 0, 0 <= N + 1, j - 1 <= N + 1 +/// +/// If darkShadow = true, this method computes the dark shadow on elimination; +/// the dark shadow is a convex integer subset of the exact integer shadow. A +/// non-empty dark shadow proves the existence of an integer solution. The +/// elimination in such a case could however be an under-approximation, and thus +/// should not be used for scanning sets or used by itself for dependence +/// checking. +/// +/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. 
+/// ^ +/// | +/// | * * * * o o +/// i | * * o o o o +/// | o * * * * * +/// ---------------> +/// j -> +/// +/// Eliminating i from this system (projecting on the j dimension): +/// rational shadow / integer light shadow: 1 <= j <= 6 +/// dark shadow: 3 <= j <= 6 +/// exact integer shadow: j = 1 \union 3 <= j <= 6 +/// holes/splinters: j = 2 +/// +/// darkShadow = false, isResultIntegerExact = nullptr are default values. +// TODO(bondhugula): a slight modification to yield dark shadow version of FM +// (tightened), which can prove the existence of a solution if there is one. +void FlatAffineConstraints::FourierMotzkinEliminate( + unsigned pos, bool darkShadow, bool *isResultIntegerExact) { + LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); + LLVM_DEBUG(dump()); + assert(pos < getNumIds() && "invalid position"); + assert(hasConsistentState()); + + // Check if this identifier can be eliminated through a substitution. + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + if (atEq(r, pos) != 0) { + // Use Gaussian elimination here (since we have an equality). + bool ret = gaussianEliminateId(pos); + (void)ret; + assert(ret && "Gaussian elimination guaranteed to succeed"); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); + return; + } + } + + // A fast linear time tightening. + GCDTightenInequalities(); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == getNumInequalities()) { + // If it doesn't appear, just remove the column and return. + // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. + removeId(pos); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); + return; + } + + // Positions of constraints that are lower bounds on the variable. + SmallVector lbIndices; + // Positions of constraints that are lower bounds on the variable. + SmallVector ubIndices; + // Positions of constraints that do not involve the variable. + std::vector nbIndices; + nbIndices.reserve(getNumInequalities()); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) == 0) { + // Id does not appear in bound. + nbIndices.push_back(r); + } else if (atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices.push_back(r); + } else { + // Upper bound. + ubIndices.push_back(r); + } + } + + // Set the number of dimensions, symbols in the resulting system. + const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this); + unsigned newNumDims = dimsSymbols.first; + unsigned newNumSymbols = dimsSymbols.second; + + SmallVector, 8> newIds; + newIds.reserve(numIds - 1); + newIds.append(ids.begin(), ids.begin() + pos); + newIds.append(ids.begin() + pos + 1, ids.end()); + + /// Create the new system which has one identifier less. + FlatAffineConstraints newFac( + lbIndices.size() * ubIndices.size() + nbIndices.size(), + getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, + /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds); + + assert(newFac.getIds().size() == newFac.getNumIds()); + + // This will be used to check if the elimination was integer exact. + unsigned lcmProducts = 1; + + // Let x be the variable we are eliminating. 
+ // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note + // that c_l, c_u >= 1) we have: + // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u + // We thus generate a constraint: + // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. + // Note if c_l = c_u = 1, all integer points captured by the resulting + // constraint correspond to integer points in the original system (i.e., they + // have integer pre-images). Hence, if the lcm's are all 1, the elimination is + // integer exact. + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + SmallVector ineq; + ineq.reserve(newFac.getNumCols()); + int64_t lbCoeff = atIneq(lbPos, pos); + // Note that in the comments above, ubCoeff is the negation of the + // coefficient in the canonical form as the view taken here is that of the + // term being moved to the other size of '>='. + int64_t ubCoeff = -atIneq(ubPos, pos); + // TODO(bondhugula): refactor this loop to avoid all branches inside. + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); + int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); + ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + + atIneq(lbPos, l) * (lcm / lbCoeff)); + lcmProducts *= lcm; + } + if (darkShadow) { + // The dark shadow is a convex subset of the exact integer shadow. If + // there is a point here, it proves the existence of a solution. + ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; + } + // TODO: we need to have a way to add inequalities in-place in + // FlatAffineConstraints instead of creating and copying over. + newFac.addInequality(ineq); + } + } + + if (lcmProducts == 1 && isResultIntegerExact) + *isResultIntegerExact = 1; + + // Copy over the constraints not involving this variable. + for (auto nbPos : nbIndices) { + SmallVector ineq; + ineq.reserve(getNumCols() - 1); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + ineq.push_back(atIneq(nbPos, l)); + } + newFac.addInequality(ineq); + } + + assert(newFac.getNumConstraints() == + lbIndices.size() * ubIndices.size() + nbIndices.size()); + + // Copy over the equalities. + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + SmallVector eq; + eq.reserve(newFac.getNumCols()); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + eq.push_back(atEq(r, l)); + } + newFac.addEquality(eq); + } + + newFac.removeTrivialRedundancy(); + clearAndCopyFrom(newFac); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "affine-structures" + +void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { + if (num == 0) + return; + + // 'pos' can be at most getNumCols() - 2 if num > 0. + assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position"); + assert(pos + num < getNumCols() && "invalid range"); + + // Eliminate as many identifiers as possible using Gaussian elimination. + unsigned currentPos = pos; + unsigned numToEliminate = num; + unsigned numGaussianEliminated = 0; + + while (currentPos < getNumIds()) { + unsigned curNumEliminated = + gaussianEliminateIds(currentPos, currentPos + numToEliminate); + ++currentPos; + numToEliminate -= curNumEliminated + 1; + numGaussianEliminated += curNumEliminated; + } + + // Eliminate the remaining using Fourier-Motzkin. 
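To put the Fourier-Motzkin bound combination above on concrete numbers: eliminating x from 2*x - j >= 0 (lower bound, c_l = 2) and -3*x + N >= 0 (upper bound, c_u = 3) scales both rows to lcm(2, 3) = 6 and yields -3*j + 2*N >= 0, i.e. 3*j <= 2*N. A standalone sketch of that row construction (the column layout [x, j, N, const] is made up for the example):

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // Columns [x, j, N, const].
  std::vector<int64_t> lbRow = {2, -1, 0, 0};  //  2*x - j >= 0
  std::vector<int64_t> ubRow = {-3, 0, 1, 0};  // -3*x + N >= 0
  const unsigned pos = 0;                      // eliminate x
  int64_t lbCoeff = lbRow[pos];                // 2
  int64_t ubCoeff = -ubRow[pos];               // 3
  int64_t lcm = std::lcm(lbCoeff, ubCoeff);    // 6

  std::vector<int64_t> combined;
  for (unsigned c = 0; c < lbRow.size(); ++c) {
    if (c == pos)
      continue;
    combined.push_back(ubRow[c] * (lcm / ubCoeff) + lbRow[c] * (lcm / lbCoeff));
  }
  // Resulting row over [j, N, const]: -3*j + 2*N >= 0, i.e. 3*j <= 2*N.
  assert((combined == std::vector<int64_t>{-3, 2, 0}));
  return 0;
}
```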
+ for (unsigned i = 0; i < num - numGaussianEliminated; i++) { + unsigned numToEliminate = num - numGaussianEliminated - i; + FourierMotzkinEliminate( + getBestIdToEliminate(*this, pos, pos + numToEliminate)); + } + + // Fast/trivial simplifications. + GCDTightenInequalities(); + // Normalize constraints after tightening since the latter impacts this, but + // not the other way round. + normalizeConstraintsByGCD(); +} + +void FlatAffineConstraints::projectOut(Value *id) { + unsigned pos; + bool ret = findId(*id, &pos); + assert(ret); + (void)ret; + FourierMotzkinEliminate(pos); +} + +bool FlatAffineConstraints::isRangeOneToOne(unsigned start, + unsigned limit) const { + assert(start <= getNumIds() - 1 && "invalid start position"); + assert(limit > start && limit <= getNumIds() && "invalid limit"); + + FlatAffineConstraints tmpCst(*this); + + if (start != 0) { + // Move [start, limit) to the left. + for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atIneq(r, c - start) = atIneq(r, c); + else if (c < start) + tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); + else + tmpCst.atIneq(r, c) = atIneq(r, c); + } + } + for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { + for (unsigned c = 0, f = getNumCols(); c < f; ++c) { + if (c >= start && c < limit) + tmpCst.atEq(r, c - start) = atEq(r, c); + else if (c < start) + tmpCst.atEq(r, c + limit - start) = atEq(r, c); + else + tmpCst.atEq(r, c) = atEq(r, c); + } + } + } + + // Mark everything to the right as symbols so that we can check the extents in + // a symbolic way below. + tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); + + // Check if the extents of all the specified dimensions are just one (when + // treating the rest as symbols). + for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { + auto extent = tmpCst.getConstantBoundOnDimSize(pos); + if (!extent.hasValue() || extent.getValue() != 1) + return false; + } + return true; +} + +void FlatAffineConstraints::clearConstraints() { + equalities.clear(); + inequalities.clear(); +} + +namespace { + +enum BoundCmpResult { Greater, Less, Equal, Unknown }; + +/// Compares two affine bounds whose coefficients are provided in 'first' and +/// 'second'. The last coefficient is the constant term. +static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { + assert(a.size() == b.size()); + + // For the bounds to be comparable, their corresponding identifier + // coefficients should be equal; the constant terms are then compared to + // determine less/greater/equal. + + if (!std::equal(a.begin(), a.end() - 1, b.begin())) + return Unknown; + + if (a.back() == b.back()) + return Equal; + + return a.back() < b.back() ? Less : Greater; +} +}; // namespace + +// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of +// the symbols. Assumes the common symbols appear in the same order (the +// current/common use case). +static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) { + SmallVector symbolsA, symbolsB; + A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA); + B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB); + + // Both symbol list have a handful symbols each typically (3-4); a merge + // quadratic in complexity with a linear search is fine. 
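A small usage sketch of the bound comparison defined above: two symbolic bounds are comparable only when their identifier coefficients agree, in which case the constant terms decide; otherwise the result is Unknown. (Standalone restatement for illustration only, not the patch's code.)

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

enum BoundCmpResult { Greater, Less, Equal, Unknown };

static BoundCmpResult compareBounds(const std::vector<int64_t> &a,
                                    const std::vector<int64_t> &b) {
  assert(a.size() == b.size());
  if (!std::equal(a.begin(), a.end() - 1, b.begin()))
    return Unknown;
  if (a.back() == b.back())
    return Equal;
  return a.back() < b.back() ? Less : Greater;
}

int main() {
  // Bounds over one symbol s0, layout [s0, const].
  std::vector<int64_t> b1 = {1, 4};  // s0 + 4
  std::vector<int64_t> b2 = {1, 7};  // s0 + 7
  std::vector<int64_t> b3 = {2, 0};  // 2*s0
  assert(compareBounds(b1, b2) == Less);     // s0 + 4 < s0 + 7 for all s0
  assert(compareBounds(b1, b3) == Unknown);  // ordering depends on s0
  return 0;
}
```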
+  for (auto *symbolB : symbolsB) {
+    if (!llvm::is_contained(symbolsA, symbolB)) {
+      A->addSymbolId(symbolsA.size(), symbolB);
+      symbolsA.push_back(symbolB);
+    }
+  }
+  // symbolsA now holds the merged symbol list.
+  symbolsB.reserve(symbolsA.size());
+  unsigned iB = 0;
+  for (auto *symbolA : symbolsA) {
+    assert(iB < symbolsB.size());
+    if (symbolA != symbolsB[iB]) {
+      symbolsB.insert(symbolsB.begin() + iB, symbolA);
+      B->addSymbolId(iB, symbolA);
+    }
+    ++iB;
+  }
+}
+
+// Compute the bounding box with respect to 'other' by finding the min of the
+// lower bounds and the max of the upper bounds along each of the dimensions.
+bool FlatAffineConstraints::unionBoundingBox(
+    const FlatAffineConstraints &otherArg) {
+  assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
+
+  Optional copy;
+  if (!otherArg.getIds().equals(getIds())) {
+    copy.emplace(FlatAffineConstraints(otherArg));
+    mergeSymbols(this, &copy.getValue());
+    assert(getIds().equals(copy->getIds()) && "merge failed");
+  }
+
+  const auto &other = copy ? *copy : otherArg;
+
+  assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
+  assert(getNumLocalIds() == 0 && "local ids not eliminated");
+
+  std::vector> boundingLbs;
+  std::vector> boundingUbs;
+  boundingLbs.reserve(2 * getNumDimIds());
+  boundingUbs.reserve(2 * getNumDimIds());
+
+  SmallVector lb, otherLb;
+  lb.reserve(getNumSymbolIds() + 1);
+  otherLb.reserve(getNumSymbolIds() + 1);
+  int64_t lbDivisor, otherLbDivisor;
+  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
+    lb.clear();
+    auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor);
+    if (!extent.hasValue())
+      // TODO(bondhugula): symbolic extents when necessary.
+      // TODO(bondhugula): handle union if a dimension is unbounded.
+      return false;
+
+    otherLb.clear();
+    auto otherExtent =
+        other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor);
+    if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor)
+      // TODO(bondhugula): symbolic extents when necessary.
+      return false;
+
+    assert(lbDivisor > 0 && "divisor always expected to be positive");
+
+    // Compute min of lower bounds and max of upper bounds.
+    ArrayRef minLb, maxUb;
+
+    auto res = compareBounds(lb, otherLb);
+    // Identify min.
+    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
+      minLb = lb;
+    } else if (res == BoundCmpResult::Greater) {
+      minLb = otherLb;
+    } else {
+      // Uncomparable.
+      auto constLb = getConstantLowerBound(d);
+      auto constOtherLb = other.getConstantLowerBound(d);
+      if (!constLb.hasValue() || !constOtherLb.hasValue())
+        return false;
+      minLb = std::min(constLb.getValue(), constOtherLb.getValue());
+    }
+
+    // Do the same for ub's but max of upper bounds.
+    SmallVector ub(lb), otherUb(otherLb);
+    ub.back() += extent.getValue() - 1;
+    otherUb.back() += otherExtent.getValue() - 1;
+
+    // Identify max.
+    auto uRes = compareBounds(ub, otherUb);
+    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
+      maxUb = ub;
+    } else if (uRes == BoundCmpResult::Less) {
+      maxUb = otherUb;
+    } else {
+      // Uncomparable.
+      auto constUb = getConstantUpperBound(d);
+      auto constOtherUb = other.getConstantUpperBound(d);
+      if (!constUb.hasValue() || !constOtherUb.hasValue())
+        return false;
+      maxUb = std::max(constUb.getValue(), constOtherUb.getValue());
+    }
+
+    SmallVector newLb(getNumCols(), 0);
+    SmallVector newUb(getNumCols(), 0);
+
+    // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
+    // and so it's the divisor for newLb and newUb as well.
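A worked instance of the per-dimension union logic above, with one symbol s0 and divisor 1 (all numbers hypothetical): box A has lower bound s0 and extent 16, box B has lower bound s0 + 4 and extent 20; the union keeps min(lb) = s0 and max(ub) = s0 + 23.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Symbolic bounds over [s0, const], divisor 1.
  std::vector<int64_t> lbA = {1, 0};  // s0
  std::vector<int64_t> lbB = {1, 4};  // s0 + 4
  int64_t extentA = 16, extentB = 20;

  // Inclusive upper bounds: lb + extent - 1.
  std::vector<int64_t> ubA = lbA, ubB = lbB;
  ubA.back() += extentA - 1;  // s0 + 15
  ubB.back() += extentB - 1;  // s0 + 23

  // The symbol coefficients match, so the constant terms decide min/max.
  std::vector<int64_t> minLb = (lbA.back() <= lbB.back()) ? lbA : lbB;
  std::vector<int64_t> maxUb = (ubA.back() >= ubB.back()) ? ubA : ubB;
  assert((minLb == std::vector<int64_t>{1, 0}));   // s0
  assert((maxUb == std::vector<int64_t>{1, 23}));  // s0 + 23
  return 0;
}
```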
+ newLb[d] = lbDivisor; + newUb[d] = -lbDivisor; + // Copy over the symbolic part + constant term. + std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); + std::transform(newLb.begin() + getNumDimIds(), newLb.end(), + newLb.begin() + getNumDimIds(), std::negate()); + std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); + + boundingLbs.push_back(newLb); + boundingUbs.push_back(newUb); + } + + // Clear all constraints and add the lower/upper bounds for the bounding box. + clearConstraints(); + for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { + addInequality(boundingLbs[d]); + addInequality(boundingUbs[d]); + } + + return true; +} diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 545735fd6fd..c0deb805bdf 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -23,9 +23,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" @@ -147,8 +147,7 @@ bool mlir::isAccessInvariant(const Value &iv, const Value &index) { auto composeOp = affineApplyOps[0]->cast(); // We need yet another level of indirection because the `dim` index of the // access may not correspond to the `dim` index of composeOp. - return !composeOp->getAsAffineValueMap().isFunctionOf( - 0, const_cast(&iv)); + return !(AffineValueMap(composeOp).isFunctionOf(0, const_cast(&iv))); } llvm::DenseSet diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index b86651793f9..8a0cb44f0cc 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -21,9 +21,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 0b5c9b997a5..93d4fde1fd9 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -20,9 +20,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 4e176f63503..9947f1621b5 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -24,7 +24,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/IR/AffineStructures.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/StandardOps/StandardOps.h" @@ -185,7 +185,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // bounds expressions involve outer loops or other symbols. 
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!addAffineForOpDomain(loop, &cst)) + if (!cst.addAffineForOpDomain(loop)) return false; } else { // Has to be a valid symbol. diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 5cfb1461590..9081183cb3a 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -19,7 +19,6 @@ #include "AffineExprDetail.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/STLExtras.h" @@ -336,126 +335,39 @@ AffineExpr mlir::toAffineExpr(ArrayRef eq, unsigned numDims, return expr; } -namespace { - -// This class is used to flatten a pure affine expression (AffineExpr, -// which is in a tree form) into a sum of products (w.r.t constants) when -// possible, and in that process simplifying the expression. For a modulo, -// floordiv, or a ceildiv expression, an additional identifier, called a local -// identifier, is introduced to rewrite the expression as a sum of product -// affine expression. Each local identifier is always and by construction a -// floordiv of a pure add/mul affine function of dimensional, symbolic, and -// other local identifiers, in a non-mutually recursive way. Hence, every local -// identifier can ultimately always be recovered as an affine function of -// dimensional and symbolic identifiers (involving floordiv's); note however -// that by AffineExpr construction, some floordiv combinations are converted to -// mod's. The result of the flattening is a flattened expression and a set of -// constraints involving just the local variables. -// -// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local -// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. -// -// The simplification performed includes the accumulation of contributions for -// each dimensional and symbolic identifier together, the simplification of -// floordiv/ceildiv/mod expressions and other simplifications that in turn -// happen as a result. A simplification that this flattening naturally performs -// is of simplifying the numerator and denominator of floordiv/ceildiv, and -// folding a modulo expression to a zero, if possible. Three examples are below: -// -// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 -// (d0 - d0 mod 4 + 4) mod 4 simplified to 0 -// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 -// -// The way the flattening works for the second example is as follows: d0 % 4 is -// replaced by d0 - 4*q with q being introduced: the expression then simplifies -// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to -// zero. Note that an affine expression may not always be expressible purely as -// a sum of products involving just the original dimensional and symbolic -// identifiers due to the presence of modulo/floordiv/ceildiv expressions that -// may not be eliminated after simplification; in such cases, the final -// expression can be reconstructed by replacing the local identifiers with their -// corresponding explicit form stored in 'localExprs' (note that each of the -// explicit forms itself would have been simplified). 
-// -// The expression walk method here performs a linear time post order walk that -// performs the above simplifications through visit methods, with partial -// results being stored in 'operandExprStack'. When a parent expr is visited, -// the flattened expressions corresponding to its two operands would already be -// on the stack - the parent expression looks at the two flattened expressions -// and combines the two. It pops off the operand expressions and pushes the -// combined result (although this is done in-place on its LHS operand expr). -// When the walk is completed, the flattened form of the top-level expression -// would be left on the stack. -// -// A flattener can be repeatedly used for multiple affine expressions that bind -// to the same operands, for example, for all result expressions of an -// AffineMap or AffineValueMap. In such cases, using it for multiple expressions -// is more efficient than creating a new flattener for each expression since -// common idenical div and mod expressions appearing across different -// expressions are mapped to the same local identifier (same column position in -// 'localVarCst'). -struct AffineExprFlattener : public AffineExprVisitor { -public: - // Flattend expression layout: [dims, symbols, locals, constant] - // Stack that holds the LHS and RHS operands while visiting a binary op expr. - // In future, consider adding a prepass to determine how big the SmallVector's - // will be, and linearize this to std::vector to prevent - // SmallVector moves on re-allocation. - std::vector> operandExprStack; - // Constraints connecting newly introduced local variables (for mod's and - // div's) to existing (dimensional and symbolic) ones. These are always - // inequalities. - FlatAffineConstraints localVarCst; - - unsigned numDims; - unsigned numSymbols; - // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv - // expressions that could not be simplified. - unsigned numLocals; - // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for - // which new identifiers were introduced; if the latter do not get canceled - // out, these expressions can be readily used to reconstruct the AffineExpr - // (tree) form. Note that these expressions themselves would have been - // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 - // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) - // ceildiv 2 would be the local expression stored for q. - SmallVector localExprs; - MLIRContext *context; - - AffineExprFlattener(unsigned numDims, unsigned numSymbols, - MLIRContext *context) - : numDims(numDims), numSymbols(numSymbols), numLocals(0), - context(context) { - operandExprStack.reserve(8); - localVarCst.reset(numDims, numSymbols, numLocals); - } +SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, + unsigned numSymbols, + MLIRContext *context) + : numDims(numDims), numSymbols(numSymbols), numLocals(0), context(context) { + operandExprStack.reserve(8); +} - void visitMulExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - // Get the RHS constant. - auto rhsConst = operandExprStack.back()[getConstantIndex()]; - operandExprStack.pop_back(); - // Update the LHS in place instead of pop and push. 
- auto &lhs = operandExprStack.back(); - for (unsigned i = 0, e = lhs.size(); i < e; i++) { - lhs[i] *= rhsConst; - } +void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + // This is a pure affine expr; the RHS will be a constant. + assert(expr.getRHS().isa()); + // Get the RHS constant. + auto rhsConst = operandExprStack.back()[getConstantIndex()]; + operandExprStack.pop_back(); + // Update the LHS in place instead of pop and push. + auto &lhs = operandExprStack.back(); + for (unsigned i = 0, e = lhs.size(); i < e; i++) { + lhs[i] *= rhsConst; } +} - void visitAddExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - const auto &rhs = operandExprStack.back(); - auto &lhs = operandExprStack[operandExprStack.size() - 2]; - assert(lhs.size() == rhs.size()); - // Update the LHS in place. - for (unsigned i = 0, e = rhs.size(); i < e; i++) { - lhs[i] += rhs[i]; - } - // Pop off the RHS. - operandExprStack.pop_back(); +void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + const auto &rhs = operandExprStack.back(); + auto &lhs = operandExprStack[operandExprStack.size() - 2]; + assert(lhs.size() == rhs.size()); + // Update the LHS in place. + for (unsigned i = 0, e = rhs.size(); i < e; i++) { + lhs[i] += rhs[i]; } + // Pop off the RHS. + operandExprStack.pop_back(); +} // // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 @@ -463,86 +375,85 @@ public: // A mod expression "expr mod c" is thus flattened by introducing a new local // variable q (= expr floordiv c), such that expr mod c is replaced with // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. - void visitModExpr(AffineBinaryOpExpr expr) { - assert(operandExprStack.size() >= 2); - // This is a pure affine expr; the RHS will be a constant. - assert(expr.getRHS().isa()); - auto rhsConst = operandExprStack.back()[getConstantIndex()]; - operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); - // TODO(bondhugula): handle modulo by zero case when this issue is fixed - // at the other places in the IR. - assert(rhsConst > 0 && "RHS constant has to be positive"); - - // Check if the LHS expression is a multiple of modulo factor. - unsigned i, e; - for (i = 0, e = lhs.size(); i < e; i++) - if (lhs[i] % rhsConst != 0) - break; - // If yes, modulo expression here simplifies to zero. - if (i == lhs.size()) { - std::fill(lhs.begin(), lhs.end(), 0); - return; - } - - // Add a local variable for the quotient, i.e., expr % c is replaced by - // (expr - q * c) where q = expr floordiv c. Do this while canceling out - // the GCD of expr and c. - SmallVector floorDividend(lhs); - uint64_t gcd = rhsConst; - for (unsigned i = 0, e = lhs.size(); i < e; i++) - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); - // Simplify the numerator and the denominator. - if (gcd != 1) { - for (unsigned i = 0, e = floorDividend.size(); i < e; i++) - floorDividend[i] = floorDividend[i] / static_cast(gcd); - } - int64_t floorDivisor = rhsConst / static_cast(gcd); - - // Construct the AffineExpr form of the floordiv to store in localExprs. 
- auto dividendExpr = - toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context); - auto divisorExpr = getAffineConstantExpr(floorDivisor, context); - auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); - int loc; - if ((loc = findLocalId(floorDivExpr)) == -1) { - addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); - // Set result at top of stack to "lhs - rhsConst * q". - lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; - } else { - // Reuse the existing local id. - lhs[getLocalVarStartIndex() + loc] = -rhsConst; - } +void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { + assert(operandExprStack.size() >= 2); + // This is a pure affine expr; the RHS will be a constant. + assert(expr.getRHS().isa()); + auto rhsConst = operandExprStack.back()[getConstantIndex()]; + operandExprStack.pop_back(); + auto &lhs = operandExprStack.back(); + // TODO(bondhugula): handle modulo by zero case when this issue is fixed + // at the other places in the IR. + assert(rhsConst > 0 && "RHS constant has to be positive"); + + // Check if the LHS expression is a multiple of modulo factor. + unsigned i, e; + for (i = 0, e = lhs.size(); i < e; i++) + if (lhs[i] % rhsConst != 0) + break; + // If yes, modulo expression here simplifies to zero. + if (i == lhs.size()) { + std::fill(lhs.begin(), lhs.end(), 0); + return; } - void visitCeilDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/true); - } - void visitFloorDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/false); + // Add a local variable for the quotient, i.e., expr % c is replaced by + // (expr - q * c) where q = expr floordiv c. Do this while canceling out + // the GCD of expr and c. + SmallVector floorDividend(lhs); + uint64_t gcd = rhsConst; + for (unsigned i = 0, e = lhs.size(); i < e; i++) + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); + // Simplify the numerator and the denominator. + if (gcd != 1) { + for (unsigned i = 0, e = floorDividend.size(); i < e; i++) + floorDividend[i] = floorDividend[i] / static_cast(gcd); } + int64_t floorDivisor = rhsConst / static_cast(gcd); - void visitDimExpr(AffineDimExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - assert(expr.getPosition() < numDims && "Inconsistent number of dims"); - eq[getDimStartIndex() + expr.getPosition()] = 1; + // Construct the AffineExpr form of the floordiv to store in localExprs. + auto dividendExpr = + toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context); + auto divisorExpr = getAffineConstantExpr(floorDivisor, context); + auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); + int loc; + if ((loc = findLocalId(floorDivExpr)) == -1) { + addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); + // Set result at top of stack to "lhs - rhsConst * q". + lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; + } else { + // Reuse the existing local id. 
+ lhs[getLocalVarStartIndex() + loc] = -rhsConst; } +} - void visitSymbolExpr(AffineSymbolExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); - eq[getSymbolStartIndex() + expr.getPosition()] = 1; - } +void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { + visitDivExpr(expr, /*isCeil=*/true); +} +void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { + visitDivExpr(expr, /*isCeil=*/false); +} - void visitConstantExpr(AffineConstantExpr expr) { - operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); - auto &eq = operandExprStack.back(); - eq[getConstantIndex()] = expr.getValue(); - } +void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + assert(expr.getPosition() < numDims && "Inconsistent number of dims"); + eq[getDimStartIndex() + expr.getPosition()] = 1; +} + +void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); + eq[getSymbolStartIndex() + expr.getPosition()] = 1; +} + +void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { + operandExprStack.emplace_back(SmallVector(getNumCols(), 0)); + auto &eq = operandExprStack.back(); + eq[getConstantIndex()] = expr.getValue(); +} -private: // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 // A floordiv is thus flattened by introducing a new local variable q, and // replacing that expression with 'q' while adding the constraints @@ -551,97 +462,86 @@ private: // // A ceildiv is similarly flattened: // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c - void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { - assert(operandExprStack.size() >= 2); - assert(expr.getRHS().isa()); - - // This is a pure affine expr; the RHS is a positive constant. - int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; - // TODO(bondhugula): handle division by zero at the same time the issue is - // fixed at other places. - assert(rhsConst > 0 && "RHS constant has to be positive"); - operandExprStack.pop_back(); - auto &lhs = operandExprStack.back(); - - // Simplify the floordiv, ceildiv if possible by canceling out the greatest - // common divisors of the numerator and denominator. - uint64_t gcd = std::abs(rhsConst); +void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, + bool isCeil) { + assert(operandExprStack.size() >= 2); + assert(expr.getRHS().isa()); + + // This is a pure affine expr; the RHS is a positive constant. + int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; + // TODO(bondhugula): handle division by zero at the same time the issue is + // fixed at other places. + assert(rhsConst > 0 && "RHS constant has to be positive"); + operandExprStack.pop_back(); + auto &lhs = operandExprStack.back(); + + // Simplify the floordiv, ceildiv if possible by canceling out the greatest + // common divisors of the numerator and denominator. + uint64_t gcd = std::abs(rhsConst); + for (unsigned i = 0, e = lhs.size(); i < e; i++) + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); + // Simplify the numerator and the denominator. 
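The two rewrites described above (a mod expressed through a floordiv local, and a ceildiv expressed as a shifted floordiv) can be verified with plain integer arithmetic; a small standalone check with divisor 4:

```cpp
#include <cassert>
#include <cstdint>

static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b, r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
  const int64_t c = 4;
  for (int64_t d0 = -8; d0 <= 8; ++d0) {
    // mod: d0 mod c is rewritten as d0 - c*q with local q = d0 floordiv c.
    int64_t q = floorDiv(d0, c);
    int64_t mod = d0 - c * q;
    assert(mod >= 0 && mod < c);
    // ceildiv: d0 ceildiv c is rewritten as (d0 + c - 1) floordiv c.
    int64_t ceil = floorDiv(d0 + c - 1, c);
    assert(ceil * c >= d0 && (ceil - 1) * c < d0);
  }
  return 0;
}
```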
+ if (gcd != 1) { for (unsigned i = 0, e = lhs.size(); i < e; i++) - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); - // Simplify the numerator and the denominator. - if (gcd != 1) { - for (unsigned i = 0, e = lhs.size(); i < e; i++) - lhs[i] = lhs[i] / static_cast(gcd); - } - int64_t divisor = rhsConst / static_cast(gcd); - // If the divisor becomes 1, the updated LHS is the result. (The - // divisor can't be negative since rhsConst is positive). - if (divisor == 1) - return; - - // If the divisor cannot be simplified to one, we will have to retain - // the ceil/floor expr (simplified up until here). Add an existential - // quantifier to express its result, i.e., expr1 div expr2 is replaced - // by a new identifier, q. - auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context); - auto b = getAffineConstantExpr(divisor, context); - - int loc; - auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); - if ((loc = findLocalId(divExpr)) == -1) { - if (!isCeil) { - SmallVector dividend(lhs); - addLocalFloorDivId(dividend, divisor, divExpr); - } else { - // lhs ceildiv c <=> (lhs + c - 1) floordiv c - SmallVector dividend(lhs); - dividend.back() += divisor - 1; - addLocalFloorDivId(dividend, divisor, divExpr); - } + lhs[i] = lhs[i] / static_cast(gcd); + } + int64_t divisor = rhsConst / static_cast(gcd); + // If the divisor becomes 1, the updated LHS is the result. (The + // divisor can't be negative since rhsConst is positive). + if (divisor == 1) + return; + + // If the divisor cannot be simplified to one, we will have to retain + // the ceil/floor expr (simplified up until here). Add an existential + // quantifier to express its result, i.e., expr1 div expr2 is replaced + // by a new identifier, q. + auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context); + auto b = getAffineConstantExpr(divisor, context); + + int loc; + auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); + if ((loc = findLocalId(divExpr)) == -1) { + if (!isCeil) { + SmallVector dividend(lhs); + addLocalFloorDivId(dividend, divisor, divExpr); + } else { + // lhs ceildiv c <=> (lhs + c - 1) floordiv c + SmallVector dividend(lhs); + dividend.back() += divisor - 1; + addLocalFloorDivId(dividend, divisor, divExpr); } - // Set the expression on stack to the local var introduced to capture the - // result of the division (floor or ceil). - std::fill(lhs.begin(), lhs.end(), 0); - if (loc == -1) - lhs[getLocalVarStartIndex() + numLocals - 1] = 1; - else - lhs[getLocalVarStartIndex() + loc] = 1; } + // Set the expression on stack to the local var introduced to capture the + // result of the division (floor or ceil). + std::fill(lhs.begin(), lhs.end(), 0); + if (loc == -1) + lhs[getLocalVarStartIndex() + numLocals - 1] = 1; + else + lhs[getLocalVarStartIndex() + loc] = 1; +} // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). // The local identifier added is always a floordiv of a pure add/mul affine // function of other identifiers, coefficients of which are specified in // dividend and with respect to a positive constant divisor. localExpr is the // simplified tree expression (AffineExpr) corresponding to the quantifier. - void addLocalFloorDivId(ArrayRef dividend, int64_t divisor, - AffineExpr localExpr) { - assert(divisor > 0 && "positive constant divisor expected"); - for (auto &subExpr : operandExprStack) - subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); - localExprs.push_back(localExpr); - numLocals++; - // Update localVarCst. 
- localVarCst.addLocalFloorDiv(dividend, divisor); - } - - int findLocalId(AffineExpr localExpr) { - SmallVectorImpl::iterator it; - if ((it = std::find(localExprs.begin(), localExprs.end(), localExpr)) == - localExprs.end()) - return -1; - return it - localExprs.begin(); - } - - inline unsigned getNumCols() const { - return numDims + numSymbols + numLocals + 1; - } - inline unsigned getConstantIndex() const { return getNumCols() - 1; } - inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } - inline unsigned getSymbolStartIndex() const { return numDims; } - inline unsigned getDimStartIndex() const { return 0; } -}; +void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef dividend, + int64_t divisor, + AffineExpr localExpr) { + assert(divisor > 0 && "positive constant divisor expected"); + for (auto &subExpr : operandExprStack) + subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); + localExprs.push_back(localExpr); + numLocals++; + // dividend and divisor are ignored; an override of this method uses it. +} -} // end anonymous namespace +int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { + SmallVectorImpl::iterator it; + if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) + return -1; + return it - localExprs.begin(); +} /// Simplify the affine expression by flattening it and reconstructing it. AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, @@ -651,7 +551,7 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, if (!expr.isPureAffine()) return expr; - AffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); + SimpleAffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); flattener.walkPostOrder(expr); ArrayRef flattenedExpr = flattener.operandExprStack.back(); auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, @@ -667,17 +567,13 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, // handled yet). static bool getFlattenedAffineExprs( ArrayRef exprs, unsigned numDims, unsigned numSymbols, - std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { + std::vector> *flattenedExprs) { if (exprs.empty()) { - localVarCst->reset(numDims, numSymbols); return true; } - flattenedExprs->clear(); - flattenedExprs->reserve(exprs.size()); - - AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); + SimpleAffineExprFlattener flattener(numDims, numSymbols, + exprs[0].getContext()); // Use the same flattener to simplify each expression successively. This way // local identifiers / expressions are shared. for (auto expr : exprs) { @@ -687,12 +583,10 @@ static bool getFlattenedAffineExprs( flattener.walkPostOrder(expr); } + flattenedExprs->clear(); assert(flattener.operandExprStack.size() == exprs.size()); - flattenedExprs->insert(flattenedExprs->end(), - flattener.operandExprStack.begin(), + flattenedExprs->assign(flattener.operandExprStack.begin(), flattener.operandExprStack.end()); - if (localVarCst) - localVarCst->clearAndCopyFrom(flattener.localVarCst); return true; } @@ -700,13 +594,12 @@ static bool getFlattenedAffineExprs( // Flattens 'expr' into 'flattenedExpr'. Returns true on success or false // if 'expr' was unable to be flattened (semi-affine expressions not handled // yet). 
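For orientation, the flattened form produced by these routines is just a vector of coefficients laid out as [dims, symbols, locals, constant]. A minimal sketch follows; FlattenedRow is a made-up struct for illustration only (the flattener itself keeps this layout implicitly via its start-index helpers).

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Column layout of one flattened expression: [dims | symbols | locals | const].
  struct FlattenedRow {
    unsigned numDims, numSymbols, numLocals;
    std::vector<int64_t> coeffs; // size == numDims + numSymbols + numLocals + 1

    unsigned symbolStart() const { return numDims; }
    unsigned localStart() const { return numDims + numSymbols; }
    unsigned constantIndex() const { return numDims + numSymbols + numLocals; }
  };

  int main() {
    // d0 + 2*s0 + 3, with one dim and one symbol, flattens to {1, 2, 3}.
    FlattenedRow row{/*numDims=*/1, /*numSymbols=*/1, /*numLocals=*/0, {1, 2, 3}};
    assert(row.coeffs[0] == 1);                   // coefficient of d0
    assert(row.coeffs[row.symbolStart()] == 2);   // coefficient of s0
    assert(row.coeffs[row.constantIndex()] == 3); // constant term
    return 0;
  }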
-bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *localVarCst) { +bool mlir::getFlattenedAffineExpr( + AffineExpr expr, unsigned numDims, unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr) { std::vector> flattenedExprs; - bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, - &flattenedExprs, localVarCst); + bool ret = + ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs); *flattenedExpr = flattenedExprs[0]; return ret; } @@ -715,25 +608,20 @@ bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, /// if 'expr' was unable to be flattened (i.e., semi-affine expressions not /// handled yet). bool mlir::getFlattenedAffineExprs( - AffineMap map, std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { + AffineMap map, std::vector> *flattenedExprs) { if (map.getNumResults() == 0) { - localVarCst->reset(map.getNumDims(), map.getNumSymbols()); return true; } return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), - map.getNumSymbols(), flattenedExprs, - localVarCst); + map.getNumSymbols(), flattenedExprs); } bool mlir::getFlattenedAffineExprs( - IntegerSet set, std::vector> *flattenedExprs, - FlatAffineConstraints *localVarCst) { + IntegerSet set, + std::vector> *flattenedExprs) { if (set.getNumConstraints() == 0) { - localVarCst->reset(set.getNumDims(), set.getNumSymbols()); return true; } return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), - set.getNumSymbols(), flattenedExprs, - localVarCst); + set.getNumSymbols(), flattenedExprs); } diff --git a/mlir/lib/IR/AffineStructures.cpp b/mlir/lib/IR/AffineStructures.cpp deleted file mode 100644 index 5114f56bcfc..00000000000 --- a/mlir/lib/IR/AffineStructures.cpp +++ /dev/null @@ -1,2312 +0,0 @@ -//===- AffineStructures.cpp - MLIR Affine Structures Class-------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Structures for affine/polyhedral analysis of MLIR functions. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/AffineStructures.h" -#include "mlir/IR/AffineExprVisitor.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Instruction.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/Support/MathExtras.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -#define DEBUG_TYPE "affine-structures" - -using namespace mlir; -using namespace llvm; - -//===----------------------------------------------------------------------===// -// MutableAffineMap. 
-//===----------------------------------------------------------------------===// - -MutableAffineMap::MutableAffineMap(AffineMap map) - : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), - // A map always has at least 1 result by construction - context(map.getResult(0).getContext()) { - for (auto result : map.getResults()) - results.push_back(result); - for (auto rangeSize : map.getRangeSizes()) - results.push_back(rangeSize); -} - -void MutableAffineMap::reset(AffineMap map) { - results.clear(); - rangeSizes.clear(); - numDims = map.getNumDims(); - numSymbols = map.getNumSymbols(); - // A map always has at least 1 result by construction - context = map.getResult(0).getContext(); - for (auto result : map.getResults()) - results.push_back(result); - for (auto rangeSize : map.getRangeSizes()) - results.push_back(rangeSize); -} - -bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { - if (results[idx].isMultipleOf(factor)) - return true; - - // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to - // complete this (for a more powerful analysis). - return false; -} - -// Simplifies the result affine expressions of this map. The expressions have to -// be pure for the simplification implemented. -void MutableAffineMap::simplify() { - // Simplify each of the results if possible. - // TODO(ntv): functional-style map - for (unsigned i = 0, e = getNumResults(); i < e; i++) { - results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); - } -} - -AffineMap MutableAffineMap::getAffineMap() const { - return AffineMap::get(numDims, numSymbols, results, rangeSizes); -} - -MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) - : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()), - context(context) { - // TODO(bondhugula) -} - -// Universal set. -MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, - MLIRContext *context) - : numDims(numDims), numSymbols(numSymbols), context(context) {} - -//===----------------------------------------------------------------------===// -// AffineValueMap. -//===----------------------------------------------------------------------===// - -AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results) - : map(map), operands(operands.begin(), operands.end()), - results(results.begin(), results.end()) {} - -void AffineValueMap::reset(AffineMap map, ArrayRef operands, - ArrayRef results) { - this->map.reset(map); - this->operands.assign(operands.begin(), operands.end()); - this->results.assign(results.begin(), results.end()); -} - -// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in -// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. -static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, - unsigned indexStart, unsigned *indexOfMatch) { - unsigned size = valuesToSearch.size(); - for (unsigned i = indexStart; i < size; ++i) { - if (valueToMatch == valuesToSearch[i]) { - *indexOfMatch = i; - return true; - } - } - return false; -} - -inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { - return map.isMultipleOf(idx, factor); -} - -/// This method uses the invariant that operands are always positionally aligned -/// with the AffineDimExpr in the underlying AffineMap. 
-bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { - unsigned index; - if (!findIndex(value, operands, /*indexStart=*/0, &index)) { - return false; - } - auto expr = const_cast(this)->getAffineMap().getResult(idx); - // TODO(ntv): this is better implemented on a flattened representation. - // At least for now it is conservative. - return expr.isFunctionOfDim(index); -} - -Value *AffineValueMap::getOperand(unsigned i) const { - return static_cast(operands[i]); -} - -ArrayRef AffineValueMap::getOperands() const { - return ArrayRef(operands); -} - -AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } - -AffineValueMap::~AffineValueMap() {} - -//===----------------------------------------------------------------------===// -// FlatAffineConstraints. -//===----------------------------------------------------------------------===// - -// Copy constructor. -FlatAffineConstraints::FlatAffineConstraints( - const FlatAffineConstraints &other) { - numReservedCols = other.numReservedCols; - numDims = other.getNumDimIds(); - numSymbols = other.getNumSymbolIds(); - numIds = other.getNumIds(); - - auto otherIds = other.getIds(); - ids.reserve(numReservedCols); - ids.append(otherIds.begin(), otherIds.end()); - - unsigned numReservedEqualities = other.getNumReservedEqualities(); - unsigned numReservedInequalities = other.getNumReservedInequalities(); - - equalities.reserve(numReservedEqualities * numReservedCols); - inequalities.reserve(numReservedInequalities * numReservedCols); - - for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { - addInequality(other.getInequality(r)); - } - for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { - addEquality(other.getEquality(r)); - } -} - -// Clones this object. -std::unique_ptr FlatAffineConstraints::clone() const { - return std::make_unique(*this); -} - -// Construct from an IntegerSet. -FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) - : numReservedCols(set.getNumOperands() + 1), - numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), - numSymbols(set.getNumSymbols()) { - equalities.reserve(set.getNumEqualities() * numReservedCols); - inequalities.reserve(set.getNumInequalities() * numReservedCols); - ids.resize(numIds, None); - - // Flatten expressions and add them to the constraint system. - std::vector> flatExprs; - FlatAffineConstraints localVarCst; - if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) { - assert(false && "flattening unimplemented for semi-affine integer sets"); - return; - } - assert(flatExprs.size() == set.getNumConstraints()); - for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { - addLocalId(getNumLocalIds()); - } - - for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { - const auto &flatExpr = flatExprs[i]; - assert(flatExpr.size() == getNumCols()); - if (set.getEqFlags()[i]) { - addEquality(flatExpr); - } else { - addInequality(flatExpr); - } - } - // Add the other constraints involving local id's from flattening. 
- append(localVarCst); -} - -void FlatAffineConstraints::reset(unsigned numReservedInequalities, - unsigned numReservedEqualities, - unsigned newNumReservedCols, - unsigned newNumDims, unsigned newNumSymbols, - unsigned newNumLocals, - ArrayRef idArgs) { - assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && - "minimum 1 column"); - numReservedCols = newNumReservedCols; - numDims = newNumDims; - numSymbols = newNumSymbols; - numIds = numDims + numSymbols + newNumLocals; - assert(idArgs.empty() || idArgs.size() == numIds); - - clearConstraints(); - if (numReservedEqualities >= 1) - equalities.reserve(newNumReservedCols * numReservedEqualities); - if (numReservedInequalities >= 1) - inequalities.reserve(newNumReservedCols * numReservedInequalities); - if (idArgs.empty()) { - ids.resize(numIds, None); - } else { - ids.assign(idArgs.begin(), idArgs.end()); - } -} - -void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, - unsigned newNumLocals, - ArrayRef idArgs) { - reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, - newNumSymbols, newNumLocals, idArgs); -} - -void FlatAffineConstraints::append(const FlatAffineConstraints &other) { - assert(other.getNumCols() == getNumCols()); - assert(other.getNumDimIds() == getNumDimIds()); - assert(other.getNumSymbolIds() == getNumSymbolIds()); - - inequalities.reserve(inequalities.size() + - other.getNumInequalities() * numReservedCols); - equalities.reserve(equalities.size() + - other.getNumEqualities() * numReservedCols); - - for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { - addInequality(other.getInequality(r)); - } - for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { - addEquality(other.getEquality(r)); - } -} - -void FlatAffineConstraints::addLocalId(unsigned pos) { - addId(IdKind::Local, pos); -} - -void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { - addId(IdKind::Dimension, pos, id); -} - -void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { - addId(IdKind::Symbol, pos, id); -} - -/// Adds a dimensional identifier. The added column is initialized to -/// zero. -void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { - if (kind == IdKind::Dimension) { - assert(pos <= getNumDimIds()); - } else if (kind == IdKind::Symbol) { - assert(pos <= getNumSymbolIds()); - } else { - assert(pos <= getNumLocalIds()); - } - - unsigned oldNumReservedCols = numReservedCols; - - // Check if a resize is necessary. - if (getNumCols() + 1 > numReservedCols) { - equalities.resize(getNumEqualities() * (getNumCols() + 1)); - inequalities.resize(getNumInequalities() * (getNumCols() + 1)); - numReservedCols++; - } - - unsigned absolutePos; - - if (kind == IdKind::Dimension) { - absolutePos = pos; - numDims++; - } else if (kind == IdKind::Symbol) { - absolutePos = pos + getNumDimIds(); - numSymbols++; - } else { - absolutePos = pos + getNumDimIds() + getNumSymbolIds(); - } - numIds++; - - // Note that getNumCols() now will already return the new size, which will be - // at least one. 
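The column insertion performed by addId below can be pictured on a plain row-major buffer. This is a simplified, out-of-place sketch (the removed code shifts in place within its reserved columns); insertZeroColumn is an illustrative helper, not part of the patch.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Insert a zero column at 'pos' into a row-major buffer of 'numRows' rows,
  // each currently 'numCols' wide.
  void insertZeroColumn(std::vector<int64_t> &buf, unsigned numRows,
                        unsigned numCols, unsigned pos) {
    std::vector<int64_t> out;
    out.reserve(numRows * (numCols + 1));
    for (unsigned r = 0; r < numRows; ++r) {
      for (unsigned c = 0; c < numCols + 1; ++c) {
        if (c < pos)
          out.push_back(buf[r * numCols + c]);
        else if (c == pos)
          out.push_back(0); // the newly added identifier's coefficient
        else
          out.push_back(buf[r * numCols + (c - 1)]);
      }
    }
    buf.swap(out);
  }

  int main() {
    // One row encoding d0 + 2*s0 + 3; add a new identifier column between them.
    std::vector<int64_t> buf = {1, 2, 3};
    insertZeroColumn(buf, /*numRows=*/1, /*numCols=*/3, /*pos=*/1);
    assert((buf == std::vector<int64_t>{1, 0, 2, 3}));
    return 0;
  }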
- int numInequalities = static_cast(getNumInequalities()); - int numEqualities = static_cast(getNumEqualities()); - int numCols = static_cast(getNumCols()); - for (int r = numInequalities - 1; r >= 0; r--) { - for (int c = numCols - 2; c >= 0; c--) { - if (c < absolutePos) - atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; - else - atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; - } - atIneq(r, absolutePos) = 0; - } - - for (int r = numEqualities - 1; r >= 0; r--) { - for (int c = numCols - 2; c >= 0; c--) { - // All values in column absolutePositions < absolutePos have the same - // coordinates in the 2-d view of the coefficient buffer. - if (c < absolutePos) - atEq(r, c) = equalities[r * oldNumReservedCols + c]; - else - // Those at absolutePosition >= absolutePos, get a shifted - // absolutePosition. - atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; - } - // Initialize added dimension to zero. - atEq(r, absolutePos) = 0; - } - - // If an 'id' is provided, insert it; otherwise use None. - if (id) { - ids.insert(ids.begin() + absolutePos, id); - } else { - ids.insert(ids.begin() + absolutePos, None); - } - assert(ids.size() == getNumIds()); -} - -// This routine may add additional local variables if the flattened expression -// corresponding to the map has such variables due to the presence of -// mod's, ceildiv's, and floordiv's. -bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { - // Assert if the map and this constraint set aren't associated with the same - // identifiers in the same order. - assert(vMap->getNumDims() <= getNumDimIds()); - assert(vMap->getNumSymbols() <= getNumSymbolIds()); - for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) { - assert(ids[i].hasValue()); - assert(vMap->getOperand(i) == ids[i].getValue()); - } - for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) { - assert(ids[numDims + i].hasValue()); - assert(vMap->getOperand(vMap->getNumDims() + i) == - ids[numDims + i].getValue()); - } - - std::vector> flatExprs; - FlatAffineConstraints cst; - if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { - LLVM_DEBUG(llvm::dbgs() - << "composition unimplemented for semi-affine maps\n"); - return false; - } - assert(flatExprs.size() == vMap->getNumResults()); - - // Make the value map and the flat affine cst dimensions compatible. - // A lot of this code will be refactored/cleaned up. - // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs - // to be factored out into an FlatAffineConstraints::alignAndMerge(). - for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { - addLocalId(0); - } - - for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { - // TODO: Consider using a batched version to add a range of IDs. - addDimId(0); - cst.addDimId(0); - } - - assert(cst.getNumDimIds() <= getNumDimIds()); - for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. - cst.addDimId(cst.getNumDimIds()); - } - assert(cst.getNumSymbolIds() <= getNumSymbolIds()); - for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e; - t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. 
- cst.addSymbolId(cst.getNumSymbolIds()); - } - assert(cst.getNumLocalIds() <= getNumLocalIds()); - for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; - t++) { - cst.addLocalId(cst.getNumLocalIds()); - } - /// Finally, append cst to this constraint set. - append(cst); - - // We add one equality for each result connecting the result dim of the map to - // the other identifiers. - // For eg: if the expression is 16*i0 + i1, and this is the r^th - // iteration/result of the value map, we are adding the equality: - // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we - // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. - for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { - const auto &flatExpr = flatExprs[r]; - // eqToAdd is the equality corresponding to the flattened affine expression. - SmallVector eqToAdd(getNumCols(), 0); - // Set the coefficient for this result to one. - eqToAdd[r] = 1; - - assert(flatExpr.size() >= vMap->getNumOperands() + 1); - - // Dims and symbols. - for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { - unsigned loc; - bool ret = findId(*vMap->getOperand(i), &loc); - assert(ret && "value map's id can't be found"); - (void)ret; - // We need to negate 'eq[r]' since the newly added dimension is going to - // be set to this one. - eqToAdd[loc] = -flatExpr[i]; - } - // Local vars common to eq and cst are at the beginning. - int j = getNumDimIds() + getNumSymbolIds(); - int end = flatExpr.size() - 1; - for (int i = vMap->getNumOperands(); i < end; i++, j++) { - eqToAdd[j] = -flatExpr[i]; - } - - // Constant term. - eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; - - // Add the equality connecting the result of the map to this constraint set. - addEquality(eqToAdd); - } - - return true; -} - -// Searches for a constraint with a non-zero coefficient at 'colIdx' in -// equality (isEq=true) or inequality (isEq=false) constraints. -// Returns true and sets row found in search in 'rowIdx'. -// Returns false otherwise. -static bool -findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, - unsigned colIdx, bool isEq, unsigned *rowIdx) { - auto at = [&](unsigned rowIdx) -> int64_t { - return isEq ? constraints.atEq(rowIdx, colIdx) - : constraints.atIneq(rowIdx, colIdx); - }; - unsigned e = - isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); - for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { - if (at(*rowIdx) != 0) { - return true; - } - } - return false; -} - -// Normalizes the coefficient values across all columns in 'rowIDx' by their -// GCD in equality or inequality contraints as specified by 'isEq'. -template -static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, - unsigned rowIdx) { - auto at = [&](unsigned colIdx) -> int64_t { - return isEq ? constraints->atEq(rowIdx, colIdx) - : constraints->atIneq(rowIdx, colIdx); - }; - uint64_t gcd = std::abs(at(0)); - for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); - } - if (gcd > 0 && gcd != 1) { - for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { - int64_t v = at(j) / static_cast(gcd); - isEq ? 
constraints->atEq(rowIdx, j) = v - : constraints->atIneq(rowIdx, j) = v; - } - } -} - -void FlatAffineConstraints::normalizeConstraintsByGCD() { - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - normalizeConstraintByGCD(this, i); - } - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - normalizeConstraintByGCD(this, i); - } -} - -bool FlatAffineConstraints::hasConsistentState() const { - if (inequalities.size() != getNumInequalities() * numReservedCols) - return false; - if (equalities.size() != getNumEqualities() * numReservedCols) - return false; - if (ids.size() != getNumIds()) - return false; - - // Catches errors where numDims, numSymbols, numIds aren't consistent. - if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) - return false; - - return true; -} - -/// Checks all rows of equality/inequality constraints for trivial -/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced -/// after elimination. Returns 'true' if an invalid constraint is found; -/// 'false' otherwise. -bool FlatAffineConstraints::hasInvalidConstraint() const { - assert(hasConsistentState()); - auto check = [&](bool isEq) -> bool { - unsigned numCols = getNumCols(); - unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); - for (unsigned i = 0, e = numRows; i < e; ++i) { - unsigned j; - for (j = 0; j < numCols - 1; ++j) { - int64_t v = isEq ? atEq(i, j) : atIneq(i, j); - // Skip rows with non-zero variable coefficients. - if (v != 0) - break; - } - if (j < numCols - 1) { - continue; - } - // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. - // Example invalid constraints include: '1 == 0' or '-1 >= 0' - int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); - if ((isEq && v != 0) || (!isEq && v < 0)) { - return true; - } - } - return false; - }; - if (check(/*isEq=*/true)) - return true; - return check(/*isEq=*/false); -} - -// Eliminate identifier from constraint at 'rowIdx' based on coefficient at -// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be -// updated as they have already been eliminated. -static void eliminateFromConstraint(FlatAffineConstraints *constraints, - unsigned rowIdx, unsigned pivotRow, - unsigned pivotCol, unsigned elimColStart, - bool isEq) { - // Skip if equality 'rowIdx' if same as 'pivotRow'. - if (isEq && rowIdx == pivotRow) - return; - auto at = [&](unsigned i, unsigned j) -> int64_t { - return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); - }; - int64_t leadCoeff = at(rowIdx, pivotCol); - // Skip if leading coefficient at 'rowIdx' is already zero. - if (leadCoeff == 0) - return; - int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); - int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; - int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); - int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); - int64_t rowMultiplier = lcm / std::abs(leadCoeff); - - unsigned numCols = constraints->getNumCols(); - for (unsigned j = 0; j < numCols; ++j) { - // Skip updating column 'j' if it was just eliminated. - if (j >= elimColStart && j < pivotCol) - continue; - int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + - rowMultiplier * at(rowIdx, j); - isEq ? constraints->atEq(rowIdx, j) = v - : constraints->atIneq(rowIdx, j) = v; - } -} - -// Remove coefficients in column range [colStart, colLimit) in place. -// This removes in data in the specified column range, and copies any -// remaining valid data into place. 
-static void shiftColumnsToLeft(FlatAffineConstraints *constraints, - unsigned colStart, unsigned colLimit, - bool isEq) { - assert(colStart >= 0 && colLimit <= constraints->getNumIds()); - if (colLimit <= colStart) - return; - - unsigned numCols = constraints->getNumCols(); - unsigned numRows = isEq ? constraints->getNumEqualities() - : constraints->getNumInequalities(); - unsigned numToEliminate = colLimit - colStart; - for (unsigned r = 0, e = numRows; r < e; ++r) { - for (unsigned c = colLimit; c < numCols; ++c) { - if (isEq) { - constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c); - } else { - constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c); - } - } - } -} - -// Removes identifiers in column range [idStart, idLimit), and copies any -// remaining valid data into place, and updates member variables. -void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) { - assert(idLimit < getNumCols() && "invalid id limit"); - - if (idStart >= idLimit) - return; - - // We are going to be removing one or more identifiers from the range. - assert(idStart < numIds && "invalid idStart position"); - - // TODO(andydavis) Make 'removeIdRange' a lambda called from here. - // Remove eliminated identifiers from equalities. - shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true); - - // Remove eliminated identifiers from inequalities. - shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false); - - // Update members numDims, numSymbols and numIds. - unsigned numDimsEliminated = 0; - unsigned numLocalsEliminated = 0; - unsigned numColsEliminated = idLimit - idStart; - if (idStart < numDims) { - numDimsEliminated = std::min(numDims, idLimit) - idStart; - } - // Check how many local id's were removed. Note that our identifier order is - // [dims, symbols, locals]. Local id start at position numDims + numSymbols. - if (idLimit > numDims + numSymbols) { - numLocalsEliminated = std::min( - idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); - } - unsigned numSymbolsEliminated = - numColsEliminated - numDimsEliminated - numLocalsEliminated; - - numDims -= numDimsEliminated; - numSymbols -= numSymbolsEliminated; - numIds = numIds - numColsEliminated; - - ids.erase(ids.begin() + idStart, ids.begin() + idLimit); - - // No resize necessary. numReservedCols remains the same. -} - -/// Returns the position of the identifier that has the minimum times from the specified range of -/// identifiers [start, end). It is often best to eliminate in the increasing -/// order of these counts when doing Fourier-Motzkin elimination since FM adds -/// that many new constraints. 
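A standalone sketch of this elimination-order heuristic, assuming inequalities are stored as rows of [coefficients..., constant] with each row >= 0; bestColumnToEliminate is an illustrative stand-in that only counts bounds the same way.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Fourier-Motzkin creates one inequality per (lower bound, upper bound) pair
  // of the eliminated identifier, so the cheapest column minimizes numLb * numUb.
  unsigned bestColumnToEliminate(const std::vector<std::vector<int64_t>> &ineqs,
                                 unsigned numCols) {
    unsigned best = 0;
    uint64_t bestCost = ~0ull;
    for (unsigned c = 0; c < numCols; ++c) {
      uint64_t numLb = 0, numUb = 0;
      for (const auto &row : ineqs) {
        if (row[c] > 0)
          ++numLb; // positive coefficient: lower bound on column c
        else if (row[c] < 0)
          ++numUb; // negative coefficient: upper bound on column c
      }
      if (numLb * numUb < bestCost) {
        bestCost = numLb * numUb;
        best = c;
      }
    }
    return best;
  }

  int main() {
    // Rows are [x, y, const] with row >= 0. x has 2 lower and 2 upper bounds
    // (cost 4); y has 1 of each (cost 1), so y is the cheaper one to eliminate.
    std::vector<std::vector<int64_t>> ineqs = {
        {1, 0, 0}, {1, 1, 0}, {-1, 0, 10}, {-1, -1, 20}};
    assert(bestColumnToEliminate(ineqs, 2) == 1);
    return 0;
  }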
-static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, - unsigned start, unsigned end) { - assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); - - auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { - unsigned numLb = 0; - unsigned numUb = 0; - for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { - if (cst.atIneq(r, pos) > 0) { - ++numLb; - } else if (cst.atIneq(r, pos) < 0) { - ++numUb; - } - } - return numLb * numUb; - }; - - unsigned minLoc = start; - unsigned min = getProductOfNumLowerUpperBounds(start); - for (unsigned c = start + 1; c < end; c++) { - unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); - if (numLbUbProduct < min) { - min = numLbUbProduct; - minLoc = c; - } - } - return minLoc; -} - -// Checks for emptiness of the set by eliminating identifiers successively and -// using the GCD test (on all equality constraints) and checking for trivially -// invalid constraints. Returns 'true' if the constraint system is found to be -// empty; false otherwise. -bool FlatAffineConstraints::isEmpty() const { - if (isEmptyByGCDTest() || hasInvalidConstraint()) - return true; - - // First, eliminate as many identifiers as possible using Gaussian - // elimination. - FlatAffineConstraints tmpCst(*this); - unsigned currentPos = 0; - while (currentPos < tmpCst.getNumIds()) { - tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); - ++currentPos; - // We check emptiness through trivial checks after eliminating each ID to - // detect emptiness early. Since the checks isEmptyByGCDTest() and - // hasInvalidConstraint() are linear time and single sweep on the constraint - // buffer, this appears reasonable - but can optimize in the future. - if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) - return true; - } - - // Eliminate the remaining using FM. - for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { - tmpCst.FourierMotzkinEliminate( - getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); - // Check for a constraint explosion. This rarely happens in practice, but - // this check exists as a safeguard against improperly constructed - // constraint systems or artifically created arbitrarily complex systems - // that aren't the intended use case for FlatAffineConstraints. This is - // needed since FM has a worst case exponential complexity in theory. - if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { - LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected"); - return false; - } - - // FM wouldn't have modified the equalities in any way. So no need to again - // run GCD test. Check for trivial invalid constraints. - if (tmpCst.hasInvalidConstraint()) - return true; - } - return false; -} - -// Runs the GCD test on all equality constraints. Returns 'true' if this test -// fails on any equality. Returns 'false' otherwise. -// This test can be used to disprove the existence of a solution. If it returns -// true, no integer solution to the equality constraints can exist. -// -// GCD test definition: -// -// The equality constraint: -// -// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 -// -// has an integer solution iff: -// -// GCD of c_1, c_2, ..., c_n divides c_0. 
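The test is cheap to state on a single row. A minimal sketch in plain C++17, assuming the equality is stored as [c_1, ..., c_n, constant] with the constant column holding -c_0 (so the whole row equals zero):

  #include <cassert>
  #include <cstdint>
  #include <cstdlib>
  #include <numeric>
  #include <vector>

  // Returns true if the GCD test proves the equality row has no integer solution.
  bool gcdTestSaysEmpty(const std::vector<int64_t> &eqRow) {
    uint64_t gcd = 0;
    for (size_t j = 0; j + 1 < eqRow.size(); ++j)
      gcd = std::gcd(gcd, static_cast<uint64_t>(std::llabs(eqRow[j])));
    uint64_t c0 = static_cast<uint64_t>(std::llabs(eqRow.back()));
    return gcd > 0 && c0 % gcd != 0;
  }

  int main() {
    // 2*x + 4*y = 7 has no integer solution: gcd(2, 4) = 2 does not divide 7.
    assert(gcdTestSaysEmpty({2, 4, -7}));
    // 2*x + 4*y = 6 passes the test (x = 1, y = 1 is a solution).
    assert(!gcdTestSaysEmpty({2, 4, -6}));
    return 0;
  }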
-// -bool FlatAffineConstraints::isEmptyByGCDTest() const { - assert(hasConsistentState()); - unsigned numCols = getNumCols(); - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - uint64_t gcd = std::abs(atEq(i, 0)); - for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); - } - int64_t v = std::abs(atEq(i, numCols - 1)); - if (gcd > 0 && (v % gcd != 0)) { - return true; - } - } - return false; -} - -/// Tightens inequalities given that we are dealing with integer spaces. This is -/// analogous to the GCD test but applied to inequalities. The constant term can -/// be reduced to the preceding multiple of the GCD of the coefficients, i.e., -/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a -/// fast method - linear in the number of coefficients. -// Example on how this affects practical cases: consider the scenario: -// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield -// j >= 100 instead of the tighter (exact) j >= 128. -void FlatAffineConstraints::GCDTightenInequalities() { - unsigned numCols = getNumCols(); - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - uint64_t gcd = std::abs(atIneq(i, 0)); - for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); - } - if (gcd > 0) { - int64_t gcdI = static_cast(gcd); - atIneq(i, numCols - 1) = - gcdI * mlir::floorDiv(atIneq(i, numCols - 1), gcdI); - } - } -} - -// Eliminates all identifer variables in column range [posStart, posLimit). -// Returns the number of variables eliminated. -unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, - unsigned posLimit) { - // Return if identifier positions to eliminate are out of range. - assert(posLimit <= numIds); - assert(hasConsistentState()); - - if (posStart >= posLimit) - return 0; - - GCDTightenInequalities(); - - unsigned pivotCol = 0; - for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { - // Find a row which has a non-zero coefficient in column 'j'. - unsigned pivotRow; - if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, - &pivotRow)) { - // No pivot row in equalities with non-zero at 'pivotCol'. - if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, - &pivotRow)) { - // If inequalities are also non-zero in 'pivotCol', it can be - // eliminated. - continue; - } - break; - } - - // Eliminate identifier at 'pivotCol' from each equality row. - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, - /*isEq=*/true); - normalizeConstraintByGCD(this, i); - } - - // Eliminate identifier at 'pivotCol' from each inequality row. - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, - /*isEq=*/false); - normalizeConstraintByGCD(this, i); - } - removeEquality(pivotRow); - } - // Update position limit based on number eliminated. - posLimit = pivotCol; - // Remove eliminated columns from all constraints. - removeIdRange(posStart, posLimit); - return posLimit - posStart; -} - -// Detect the identifier at 'pos' (say id_r) as modulo of another identifier -// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) -// could be detected as the floordiv of n. For eg: -// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> -// id_r = id_n mod 4, id_q = id_n floordiv 4. 
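The identity this mod detection relies on can be spot-checked numerically. A standalone sketch; floorDiv and floorMod are illustrative helpers with the affine floordiv/mod semantics (positive divisor, remainder in [0, divisor - 1]).

  #include <cassert>
  #include <cstdint>

  int64_t floorDiv(int64_t a, int64_t b) {
    return (a >= 0) ? a / b : -((-a + b - 1) / b);
  }
  int64_t floorMod(int64_t a, int64_t b) { return a - b * floorDiv(a, b); }

  int main() {
    for (int64_t n = -10; n <= 10; ++n) {
      int64_t q = floorDiv(n, 4), r = floorMod(n, 4);
      // Exactly the pattern being detected: n - 4*q - r = 0 with 0 <= r <= 3.
      assert(n - 4 * q - r == 0);
      assert(0 <= r && r <= 3);
    }
    return 0;
  }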
-// lbConst and ubConst are the constant lower and upper bounds for 'pos' - -// pre-detected at the caller. -static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, - int64_t lbConst, int64_t ubConst, - SmallVectorImpl *memo) { - assert(pos < cst.getNumIds() && "invalid position"); - - // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to - // id_n - divisor * id_q. If these are true, then id_n becomes the dividend - // and id_q the quotient when dividing id_n by the divisor. - - if (lbConst != 0 || ubConst < 1) - return false; - - int64_t divisor = ubConst + 1; - - // Now check for: id_r = id_n - divisor * id_q. As an example, we - // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. - unsigned seenQuotient = 0, seenDividend = 0; - int quotientPos = -1, dividendPos = -1; - for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - // id_n should have coeff 1 or -1. - if (std::abs(cst.atEq(r, pos)) != 1) - continue; - for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { - // The coeff of the quotient should be -divisor if the coefficient of - // the pos^th identifier is -1, and divisor if the latter is -1. - if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) { - seenQuotient++; - quotientPos = c; - } else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) { - seenDividend++; - dividendPos = c; - } - } - // We are looking for exactly one identifier as part of the dividend. - // TODO(bondhugula): could be extended to cover multiple ones in the - // dividend to detect mod of an affine function of identifiers. - if (seenDividend == 1 && seenQuotient >= 1) { - if (!(*memo)[dividendPos]) - return false; - // Successfully detected a mod. - (*memo)[pos] = (*memo)[dividendPos] % divisor; - if (seenQuotient == 1 && !(*memo)[quotientPos]) - // Successfully detected a floordiv as well. - (*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor); - return true; - } - } - return false; -} - -// Gather lower and upper bounds for the pos^th identifier. -static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst, - unsigned pos, - SmallVectorImpl *lbIndices, - SmallVectorImpl *ubIndices) { - assert(pos < cst.getNumIds() && "invalid position"); - - // Gather all lower bounds and upper bounds of the variable. Since the - // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower - // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { - if (cst.atIneq(r, pos) >= 1) { - // Lower bound. - lbIndices->push_back(r); - } else if (cst.atIneq(r, pos) <= -1) { - // Upper bound. - ubIndices->push_back(r); - } - } -} - -// Check if the pos^th identifier can be expressed as a floordiv of an affine -// function of other identifiers (where the divisor is a positive constant). -// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. -bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, - SmallVectorImpl *memo, MLIRContext *context) { - assert(pos < cst.getNumIds() && "invalid position"); - - SmallVector lbIndices, ubIndices; - getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices); - - // Check if any lower bound, upper bound pair is of the form: - // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' - // divisor * id <= expr <-- Upper bound for 'id' - // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). 
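A quick numeric check of that lower/upper bound pattern, independent of the constraint-system classes: for every (i, j), q = (i + j) floordiv 4 is the only integer satisfying 4*q <= i + j <= 4*q + 3. floorDiv below is an illustrative helper.

  #include <cassert>
  #include <cstdint>

  int64_t floorDiv(int64_t a, int64_t b) {
    return (a >= 0) ? a / b : -((-a + b - 1) / b);
  }

  int main() {
    for (int64_t i = -6; i <= 6; ++i) {
      for (int64_t j = -6; j <= 6; ++j) {
        int64_t expr = i + j;
        int64_t q = floorDiv(expr, 4);
        // The floordiv value satisfies the lb/ub pair...
        assert(4 * q <= expr && expr <= 4 * q + 3);
        // ...and no neighboring integer does.
        assert(!(4 * (q + 1) <= expr) || !(expr <= 4 * (q + 1) + 3));
        assert(!(4 * (q - 1) <= expr) || !(expr <= 4 * (q - 1) + 3));
      }
    }
    return 0;
  }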
- // - // For example, if -32*k + 16*i + j >= 0 - // 32*k - 16*i - j + 31 >= 0 <=> - // k = ( 16*i + j ) floordiv 32 - unsigned seenDividends = 0; - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - // Check if lower bound's constant term is 'divisor - 1'. The 'divisor' - // here is cst.atIneq(lbPos, pos) and we already know that it's positive - // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. - if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) - continue; - // Check if upper bound's constant term is 0. - if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) - continue; - // For the remaining part, check if the lower bound expr's coeff's are - // negations of corresponding upper bound ones'. - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) - break; - if (c != pos && cst.atIneq(lbPos, c) != 0) - seenDividends++; - } - // Lb coeff's aren't negative of ub coeff's (for the non constant term - // part). - if (c < f) - continue; - if (seenDividends >= 1) { - // The divisor is the constant term of the lower bound expression. - // We already know that cst.atIneq(lbPos, pos) > 0. - int64_t divisor = cst.atIneq(lbPos, pos); - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(0, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - if (!(*memo)[c]) - break; - dividendExpr = dividendExpr + ubVal * (*memo)[c]; - } - // Expression can't be constructed as it depends on a yet unknown - // identifier. - // TODO(mlir-team): Visit/compute the identifiers in an order so that - // this doesn't happen. More complex but much more efficient. - if (c < f) - continue; - // Successfully detected the floordiv. - (*memo)[pos] = dividendExpr.floorDiv(divisor); - return true; - } - } - } - return false; -} - -// Fills an inequality row with the value 'val'. -static inline void fillInequality(FlatAffineConstraints *cst, unsigned r, - int64_t val) { - for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { - cst->atIneq(r, c) = val; - } -} - -// Negates an inequality. -static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) { - for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { - cst->atIneq(r, c) = -cst->atIneq(r, c); - } -} - -// A more complex check to eliminate redundant inequalities. -void FlatAffineConstraints::removeRedundantInequalities() { - SmallVector redun(getNumInequalities(), false); - // To check if an inequality is redundant, we replace the inequality by its - // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting - // system is empty. If it is, the inequality is redundant. - FlatAffineConstraints tmpCst(*this); - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - // Change the inequality to its complement. - negateInequality(&tmpCst, r); - tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; - if (tmpCst.isEmpty()) { - redun[r] = true; - // Zero fill the redundant inequality. - fillInequality(this, r, /*val=*/0); - fillInequality(&tmpCst, r, /*val=*/0); - } else { - // Reverse the change (to avoid recreating tmpCst each time). - tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; - negateInequality(&tmpCst, r); - } - } - - // Scan to get rid of all rows marked redundant, in-place. 
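The complement trick used by this redundancy check can be shown on a one-dimensional toy system. Ineq and feasible below are illustrative only, and feasibility is brute-forced over a small range that safely contains every bound in this example (the removed code instead calls isEmpty on the full system).

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // A 1-D inequality a*i + b >= 0.
  struct Ineq { int64_t a, b; };

  bool feasible(const std::vector<Ineq> &rows) {
    for (int64_t i = -64; i <= 64; ++i) {
      bool ok = true;
      for (const Ineq &r : rows)
        ok = ok && (r.a * i + r.b >= 0);
      if (ok)
        return true;
    }
    return false;
  }

  int main() {
    // i >= 0, i <= 10, i >= -5; the last row is redundant.
    std::vector<Ineq> rows = {{1, 0}, {-1, 10}, {1, 5}};
    // Complement row 2: i + 5 >= 0 becomes -i - 5 - 1 >= 0, i.e. i <= -6.
    std::vector<Ineq> complemented = rows;
    complemented[2] = {-rows[2].a, -rows[2].b - 1};
    // The complemented system is infeasible, so the original row was redundant.
    assert(!feasible(complemented));
    return 0;
  }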
- auto copyRow = [&](unsigned src, unsigned dest) { - if (src == dest) - return; - for (unsigned c = 0, e = getNumCols(); c < e; c++) { - atIneq(dest, c) = atIneq(src, c); - } - }; - unsigned pos = 0; - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (!redun[r]) - copyRow(r, pos++); - } - inequalities.resize(numReservedCols * pos); -} - -std::pair FlatAffineConstraints::getLowerAndUpperBound( - unsigned pos, unsigned dimStartPos, unsigned symStartPos, - ArrayRef localExprs, MLIRContext *context) { - assert(pos < dimStartPos && "invalid dim start pos"); - assert(symStartPos >= dimStartPos && "invalid sym start pos"); - assert(getNumLocalIds() == localExprs.size() && - "incorrect local exprs count"); - - SmallVector lbIndices, ubIndices; - getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices); - - SmallVector lb, ub; - SmallVector exprs; - unsigned dimCount = symStartPos - dimStartPos; - unsigned symCount = getNumDimAndSymbolIds() - symStartPos; - exprs.reserve(lbIndices.size()); - // Lower bound expressions. - for (auto idx : lbIndices) { - auto ineq = getInequality(idx); - // Extract the lower bound (in terms of other coeff's + const), i.e., if - // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j - // - 1. - lb.assign(ineq.begin() + dimStartPos, ineq.end()); - std::transform(lb.begin(), lb.end(), lb.begin(), std::negate()); - auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); - exprs.push_back(expr); - } - auto lbMap = exprs.empty() ? AffineMap() - : AffineMap::get(dimCount, symCount, exprs, {}); - - exprs.clear(); - exprs.reserve(ubIndices.size()); - // Upper bound expressions. - for (auto idx : ubIndices) { - auto ineq = getInequality(idx); - // Extract the upper bound (in terms of other coeff's + const). - ub.assign(ineq.begin() + dimStartPos, ineq.end()); - auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); - // Upper bound is exclusive. - exprs.push_back(expr + 1); - } - auto ubMap = exprs.empty() ? AffineMap() - : AffineMap::get(dimCount, symCount, exprs, {}); - - return {lbMap, ubMap}; -} - -/// Computes the lower and upper bounds of the first 'num' dimensional -/// identifiers as affine maps of the remaining identifiers (dimensional and -/// symbolic identifiers). Local identifiers are themselves explicitly computed -/// as affine functions of other identifiers in this process if needed. -void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, - SmallVectorImpl *lbMaps, - SmallVectorImpl *ubMaps) { - assert(num < getNumDimIds() && "invalid range"); - - // Basic simplification. - normalizeConstraintsByGCD(); - - LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n"); - LLVM_DEBUG(dump()); - - // Record computed/detected identifiers. - SmallVector memo(getNumIds(), AffineExpr::Null()); - // Initialize dimensional and symbolic identifiers. - for (unsigned i = num, e = getNumDimIds(); i < e; i++) - memo[i] = getAffineDimExpr(i - num, context); - for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) - memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); - - bool changed; - do { - changed = false; - // Identify yet unknown identifiers as constants or mod's / floordiv's of - // other identifiers if possible. 
- for (unsigned pos = 0; pos < getNumIds(); pos++) { - if (memo[pos]) - continue; - - auto lbConst = getConstantLowerBound(pos); - auto ubConst = getConstantUpperBound(pos); - if (lbConst.hasValue() && ubConst.hasValue()) { - // Detect equality to a constant. - if (lbConst.getValue() == ubConst.getValue()) { - memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); - changed = true; - continue; - } - - // Detect an identifier as modulo of another identifier w.r.t a - // constant. - if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), - &memo)) { - changed = true; - continue; - } - } - - // Detect an identifier as floordiv of another identifier w.r.t a - // constant. - if (detectAsFloorDiv(*this, pos, &memo, context)) { - changed = true; - continue; - } - - // Detect an identifier as an expression of other identifiers. - unsigned idx; - if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { - continue; - } - - // Build AffineExpr solving for identifier 'pos' in terms of all others. - auto expr = getAffineConstantExpr(0, context); - unsigned j, e; - for (j = 0, e = getNumIds(); j < e; ++j) { - if (j == pos) - continue; - int64_t c = atEq(idx, j); - if (c == 0) - continue; - // If any of the involved IDs hasn't been found yet, we can't proceed. - if (!memo[j]) - break; - expr = expr + memo[j] * c; - } - if (j < e) - // Can't construct expression as it depends on a yet uncomputed - // identifier. - continue; - - // Add constant term to AffineExpr. - expr = expr + atEq(idx, getNumIds()); - int64_t vPos = atEq(idx, pos); - assert(vPos != 0 && "expected non-zero here"); - if (vPos > 0) - expr = (-expr).floorDiv(vPos); - else - // vPos < 0. - expr = expr.floorDiv(-vPos); - // Successfully constructed expression. - memo[pos] = expr; - changed = true; - } - // This loop is guaranteed to reach a fixed point - since once an - // identifier's explicit form is computed (in memo[pos]), it's not updated - // again. - } while (changed); - - // Set the lower and upper bound maps for all the identifiers that were - // computed as affine expressions of the rest as the "detected expr" and - // "detected expr + 1" respectively; set the undetected ones to Null(). - Optional tmpClone; - for (unsigned pos = 0; pos < num; pos++) { - unsigned numMapDims = getNumDimIds() - num; - unsigned numMapSymbols = getNumSymbolIds(); - AffineExpr expr = memo[pos]; - if (expr) - expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); - - if (expr) { - (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); - (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); - } else { - // TODO(bondhugula): Whenever there have local identifiers in the - // dependence constraints, we'll conservatively over-approximate, since we - // don't always explicitly compute them above (in the while loop). - if (getNumLocalIds() == 0) { - // Work on a copy so that we don't update this constraint system. - if (!tmpClone) { - tmpClone.emplace(FlatAffineConstraints(*this)); - // Removing redudnant inequalities is necessary so that we don't get - // redundant loop bounds. - tmpClone->removeRedundantInequalities(); - } - std::tie((*lbMaps)[pos], (*ubMaps)[pos]) = - tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {}, - context); - } - - // If the above fails, we'll just use the constant lower bound and the - // constant upper bound (if they exist) as the slice bounds. 
- if (!(*lbMaps)[pos]) { - LLVM_DEBUG(llvm::dbgs() - << "WARNING: Potentially over-approximating slice lb\n"); - auto lbConst = getConstantLowerBound(pos); - if (lbConst.hasValue()) { - (*lbMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(lbConst.getValue(), context), {}); - } - } - if (!(*ubMaps)[pos]) { - LLVM_DEBUG(llvm::dbgs() - << "WARNING: Potentially over-approximating slice ub\n"); - auto ubConst = getConstantUpperBound(pos); - if (ubConst.hasValue()) { - (*ubMaps)[pos] = AffineMap::get( - numMapDims, numMapSymbols, - getAffineConstantExpr(ubConst.getValue() + 1, context), {}); - } - } - } - LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG((*lbMaps)[pos].dump();); - LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG((*ubMaps)[pos].dump();); - } -} - -// Adds slice lower/upper bounds from 'lbMaps'/'upMaps' to the constraint -// system. This function assumes that position 'lbMaps.size' == 'ubMaps.size', -// and that positions [0, lbMaps.size) represent dimensional identifiers which -// correspond to the loop IVs whose iteration bounds are being sliced. -// Note that both lower/upper bounds use operands from 'operands'. -// Returns true on success. Returns false for unimplemented cases such as -// semi-affine expressions or expressions with mod/floordiv. -bool FlatAffineConstraints::addSliceBounds(ArrayRef lbMaps, - ArrayRef ubMaps, - ArrayRef operands) { - assert(lbMaps.size() == ubMaps.size()); - // Record positions of the operands in the constraint system. - SmallVector positions; - for (const auto &operand : operands) { - unsigned loc; - if (!findId(*operand, &loc)) - assert(0 && "expected to be found"); - positions.push_back(loc); - } - - auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap, - bool lower) -> bool { - FlatAffineConstraints localVarCst; - std::vector> flatExprs; - if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { - LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); - return false; - } - if (localVarCst.getNumLocalIds() > 0) { - LLVM_DEBUG(llvm::dbgs() - << "loop bounds with mod/floordiv expr's not yet supported\n"); - return false; - } - - for (const auto &flatExpr : flatExprs) { - SmallVector ineq(getNumCols(), 0); - ineq[pos] = lower ? 1 : -1; - for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { - ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; - } - // Constant term. - ineq[getNumCols() - 1] = - lower ? -flatExpr[flatExpr.size() - 1] - // Upper bound in flattenedExpr is an exclusive one. 
- : flatExpr[flatExpr.size() - 1] - 1; - addInequality(ineq); - } - return true; - }; - - for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { - if (!addLowerOrUpperBound(i, lbMaps[i], /*lower=*/true)) - return false; - if (!addLowerOrUpperBound(i, ubMaps[i], /*lower=*/false)) - return false; - } - - return true; -} - -void FlatAffineConstraints::addEquality(ArrayRef eq) { - assert(eq.size() == getNumCols()); - unsigned offset = equalities.size(); - equalities.resize(equalities.size() + numReservedCols); - std::copy(eq.begin(), eq.end(), equalities.begin() + offset); -} - -void FlatAffineConstraints::addInequality(ArrayRef inEq) { - assert(inEq.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); -} - -void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { - assert(pos < getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - inequalities[offset + pos] = 1; - inequalities[offset + getNumCols() - 1] = -lb; -} - -void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { - assert(pos < getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - inequalities[offset + pos] = -1; - inequalities[offset + getNumCols() - 1] = ub; -} - -void FlatAffineConstraints::addConstantLowerBound(ArrayRef expr, - int64_t lb) { - assert(expr.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); - inequalities[offset + getNumCols() - 1] += -lb; -} - -void FlatAffineConstraints::addConstantUpperBound(ArrayRef expr, - int64_t ub) { - assert(expr.size() == getNumCols()); - unsigned offset = inequalities.size(); - inequalities.resize(inequalities.size() + numReservedCols); - std::fill(inequalities.begin() + offset, - inequalities.begin() + offset + getNumCols(), 0); - for (unsigned i = 0, e = getNumCols(); i < e; i++) { - inequalities[offset + i] = -expr[i]; - } - inequalities[offset + getNumCols() - 1] += ub; -} - -/// Adds a new local identifier as the floordiv of an affine function of other -/// identifiers, the coefficients of which are provided in 'dividend' and with -/// respect to a positive constant 'divisor'. Two constraints are added to the -/// system to capture equivalence with the floordiv. -/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. -void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, - int64_t divisor) { - assert(dividend.size() == getNumCols() && "incorrect dividend size"); - assert(divisor > 0 && "positive divisor expected"); - - addLocalId(getNumLocalIds()); - - // Add two constraints for this new identifier 'q'. 
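Concretely, for q = (i + 5) floordiv 4 the two rows added are i - 4*q + 5 >= 0 and -i + 4*q - 2 >= 0, and together they pin q to the floordiv value. A standalone numeric check (floorDiv is an illustrative helper, not the patch's API):

  #include <cassert>
  #include <cstdint>

  int64_t floorDiv(int64_t a, int64_t b) {
    return (a >= 0) ? a / b : -((-a + b - 1) / b);
  }

  int main() {
    // q = (i + 5) floordiv 4, columns [i, q, const], each row >= 0:
    //   dividend - 4*q >= 0           ->  { 1, -4,  5}
    //   -dividend + 4*q + 4 - 1 >= 0  ->  {-1,  4, -2}
    int64_t lbRow[3] = {1, -4, 5};
    int64_t ubRow[3] = {-1, 4, -2};
    for (int64_t i = -10; i <= 10; ++i) {
      int64_t q = floorDiv(i + 5, 4);
      // The floordiv value satisfies both rows...
      assert(lbRow[0] * i + lbRow[1] * q + lbRow[2] >= 0);
      assert(ubRow[0] * i + ubRow[1] * q + ubRow[2] >= 0);
      // ...and the neighboring integer values each violate one of them.
      assert(lbRow[0] * i + lbRow[1] * (q + 1) + lbRow[2] < 0);
      assert(ubRow[0] * i + ubRow[1] * (q - 1) + ubRow[2] < 0);
    }
    return 0;
  }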
-  SmallVector<int64_t, 8> bound(dividend.size() + 1);
-
-  // dividend - q * divisor >= 0
-  std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
-            bound.begin());
-  bound.back() = dividend.back();
-  bound[getNumIds() - 1] = -divisor;
-  addInequality(bound);
-
-  // -dividend + divisor * q + divisor - 1 >= 0
-  std::transform(bound.begin(), bound.end(), bound.begin(),
-                 std::negate<int64_t>());
-  bound[bound.size() - 1] += divisor - 1;
-  addInequality(bound);
-}
-
-bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
-  unsigned i = 0;
-  for (const auto &mayBeId : ids) {
-    if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
-      *pos = i;
-      return true;
-    }
-    i++;
-  }
-  return false;
-}
-
-void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
-  assert(newSymbolCount <= numDims + numSymbols &&
-         "invalid separation position");
-  numDims = numDims + numSymbols - newSymbolCount;
-  numSymbols = newSymbolCount;
-}
-
-/// Sets the specified identifier to a constant value.
-void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
-  unsigned offset = equalities.size();
-  equalities.resize(equalities.size() + numReservedCols);
-  std::fill(equalities.begin() + offset,
-            equalities.begin() + offset + getNumCols(), 0);
-  equalities[offset + pos] = 1;
-  equalities[offset + getNumCols() - 1] = -val;
-}
-
-/// Sets the specified identifier to a constant value; asserts if the id is not
-/// found.
-void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
-  unsigned pos;
-  if (!findId(id, &pos))
-    // This is a pre-condition for this method.
-    assert(0 && "id not found");
-  setIdToConstant(pos, val);
-}
-
-void FlatAffineConstraints::removeEquality(unsigned pos) {
-  unsigned numEqualities = getNumEqualities();
-  assert(pos < numEqualities);
-  unsigned outputIndex = pos * numReservedCols;
-  unsigned inputIndex = (pos + 1) * numReservedCols;
-  unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
-  std::copy(equalities.begin() + inputIndex,
-            equalities.begin() + inputIndex + numElemsToCopy,
-            equalities.begin() + outputIndex);
-  equalities.resize(equalities.size() - numReservedCols);
-}
-
-/// Finds an equality that equates the specified identifier to a constant.
-/// Returns the position of the equality row. If 'symbolic' is set to true,
-/// symbols are also treated like a constant, i.e., an affine function of the
-/// symbols is also treated like a constant.
-static int findEqualityToConstant(const FlatAffineConstraints &cst,
-                                  unsigned pos, bool symbolic = false) {
-  assert(pos < cst.getNumIds() && "invalid position");
-  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
-    int64_t v = cst.atEq(r, pos);
-    if (v * v != 1)
-      continue;
-    unsigned c;
-    unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
-    // This checks for zeros in all positions other than 'pos' in [0, f)
-    for (c = 0; c < f; c++) {
-      if (c == pos)
-        continue;
-      if (cst.atEq(r, c) != 0) {
-        // Dependent on another identifier.
-        break;
-      }
-    }
-    if (c == f)
-      // Equality is free of other identifiers.
- return r; - } - return -1; -} - -void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { - assert(pos < getNumIds() && "invalid position"); - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; - } - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; - } - removeId(pos); -} - -bool FlatAffineConstraints::constantFoldId(unsigned pos) { - assert(pos < getNumIds() && "invalid position"); - int rowIdx; - if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) - return false; - - // atEq(rowIdx, pos) is either -1 or 1. - assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); - int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); - setAndEliminate(pos, constVal); - return true; -} - -void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { - for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { - if (!constantFoldId(t)) - t++; - } -} - -/// Returns the extent (upper bound - lower bound) of the specified -/// identifier if it is found to be a constant; returns None if it's not a -/// constant. This methods treats symbolic identifiers specially, i.e., -/// it looks for constant differences between affine expressions involving -/// only the symbolic identifiers. See comments at function definition for -/// example. 'lb', if provided, is set to the lower bound associated with the -/// constant difference. Note that 'lb' is purely symbolic and thus will contain -/// the coefficients of the symbolic identifiers and the constant coefficient. -// Egs: 0 <= i <= 15, return 16. -// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) -// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. -// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = -// ceil(s0 - 7 / 8) = floor(s0 / 8)). -Optional FlatAffineConstraints::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor) const { - assert(pos < getNumDimIds() && "Invalid identifier position"); - assert(getNumLocalIds() == 0); - - // TODO(bondhugula): eliminate all remaining dimensional identifiers (other - // than the one at 'pos' to make this more powerful. Not needed for - // hyper-rectangular spaces. - - // Find an equality for 'pos'^th identifier that equates it to some function - // of the symbolic identifiers (+ constant). - int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); - if (eqRow != -1) { - // This identifier can only take a single value. - if (lb) { - // Set lb to the symbolic value. - lb->resize(getNumSymbolIds() + 1); - for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { - int64_t v = atEq(eqRow, pos); - // atEq(eqRow, pos) is either -1 or 1. - assert(v * v == 1); - (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v - : -atEq(eqRow, getNumDimIds() + c) / v; - } - assert(lbFloorDivisor && - "both lb and divisor or none should be provided"); - *lbFloorDivisor = 1; - } - return 1; - } - - // Check if the identifier appears at all in any of the inequalities. - unsigned r, e; - for (r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, pos) != 0) - break; - } - if (r == e) - // If it doesn't, there isn't a bound on it. - return None; - - // Positions of constraints that are lower/upper bounds on the variable. - SmallVector lbIndices, ubIndices; - - // Gather all symbolic lower bounds and upper bounds of the variable. 
Since - // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a - // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - unsigned c, f; - for (c = 0, f = getNumDimIds(); c < f; c++) { - if (c != pos && atIneq(r, c) != 0) - break; - } - if (c < getNumDimIds()) - // Not a pure symbolic bound. - continue; - if (atIneq(r, pos) >= 1) - // Lower bound. - lbIndices.push_back(r); - else if (atIneq(r, pos) <= -1) - // Upper bound. - ubIndices.push_back(r); - } - - // TODO(bondhugula): eliminate other dimensional identifiers to make this more - // powerful. Not needed for hyper-rectangular iteration spaces. - - Optional minDiff = None; - unsigned minLbPosition; - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - // Look for a lower bound and an upper bound that only differ by a - // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. - // For example, if ii is the pos^th variable, we are looking for - // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The - // minimum among all such constant differences is kept since that's the - // constant bounding the extent of the pos^th variable. - unsigned j, e; - for (j = 0, e = getNumCols() - 1; j < e; j++) - if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { - break; - } - if (j < getNumCols() - 1) - continue; - int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) + - atIneq(lbPos, getNumCols() - 1) + 1, - atIneq(lbPos, pos)); - if (minDiff == None || diff < minDiff) { - minDiff = diff; - minLbPosition = lbPos; - } - } - } - if (lb && minDiff.hasValue()) { - // Set lb to the symbolic lower bound. - lb->resize(getNumSymbolIds() + 1); - // The lower bound is the ceildiv of the lb constraint over the coefficient - // of the variable at 'pos'. We express the ceildiv equivalently as a floor - // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + - // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). - *lbFloorDivisor = atIneq(minLbPosition, pos); - for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { - // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of - // 'atIneq(minLbPosition, pos) - 1'. - (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c) + - atIneq(minLbPosition, pos) - 1; - } - } - return minDiff; -} - -template -Optional -FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) { - assert(pos < getNumIds() && "invalid position"); - // Project to 'pos'. - projectOut(0, pos); - projectOut(1, getNumIds() - 1); - // Check if there's an equality equating the '0'^th identifier to a constant. - int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); - if (eqRowIdx != -1) - // atEq(rowIdx, 0) is either -1 or 1. - return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); - - // Check if the identifier appears at all in any of the inequalities. - unsigned r, e; - for (r = 0, e = getNumInequalities(); r < e; r++) { - if (atIneq(r, 0) != 0) - break; - } - if (r == e) - // If it doesn't, there isn't a bound on it. - return None; - - Optional minOrMaxConst = None; - - // Take the max across all const lower bounds (or min across all constant - // upper bounds). - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (isLower) { - if (atIneq(r, 0) <= 0) - // Not a lower bound. - continue; - } else if (atIneq(r, 0) >= 0) { - // Not an upper bound. 
- continue; - } - unsigned c, f; - for (c = 0, f = getNumCols() - 1; c < f; c++) - if (c != 0 && atIneq(r, c) != 0) - break; - if (c < getNumCols() - 1) - // Not a constant bound. - continue; - - int64_t boundConst = - isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) - : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); - if (isLower) { - if (minOrMaxConst == None || boundConst > minOrMaxConst) - minOrMaxConst = boundConst; - } else { - if (minOrMaxConst == None || boundConst < minOrMaxConst) - minOrMaxConst = boundConst; - } - } - return minOrMaxConst; -} - -Optional -FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { - FlatAffineConstraints tmpCst(*this); - return tmpCst.computeConstantLowerOrUpperBound(pos); -} - -Optional -FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { - FlatAffineConstraints tmpCst(*this); - return tmpCst.computeConstantLowerOrUpperBound(pos); -} - -// A simple (naive and conservative) check for hyper-rectangularlity. -bool FlatAffineConstraints::isHyperRectangular(unsigned pos, - unsigned num) const { - assert(pos < getNumCols() - 1); - // Check for two non-zero coefficients in the range [pos, pos + sum). - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - unsigned sum = 0; - for (unsigned c = pos; c < pos + num; c++) { - if (atIneq(r, c) != 0) - sum++; - } - if (sum > 1) - return false; - } - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - unsigned sum = 0; - for (unsigned c = pos; c < pos + num; c++) { - if (atEq(r, c) != 0) - sum++; - } - if (sum > 1) - return false; - } - return true; -} - -void FlatAffineConstraints::print(raw_ostream &os) const { - assert(hasConsistentState()); - os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() - << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() - << " constraints)\n"; - os << "("; - for (unsigned i = 0, e = getNumIds(); i < e; i++) { - if (ids[i] == None) - os << "None "; - else - os << "Value "; - } - os << " const)\n"; - for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - os << atEq(i, j) << " "; - } - os << "= 0\n"; - } - for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { - for (unsigned j = 0, f = getNumCols(); j < f; ++j) { - os << atIneq(i, j) << " "; - } - os << ">= 0\n"; - } - os << '\n'; -} - -void FlatAffineConstraints::dump() const { print(llvm::errs()); } - -/// Removes duplicate constraints and trivially true constraints: a constraint -/// of the form >= 0 is considered a trivially true -/// constraint. -// Uses a DenseSet to hash and detect duplicates followed by a linear scan to -// remove duplicates in place. -void FlatAffineConstraints::removeTrivialRedundancy() { - DenseSet> rowSet; - - // Check if constraint is of the form >= 0. - auto isTriviallyValid = [&](unsigned r) -> bool { - for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { - if (atIneq(r, c) != 0) - return false; - } - return atIneq(r, getNumCols() - 1) >= 0; - }; - - // Detect and mark redundant constraints. 
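-  // Illustrative sketch (hypothetical rows over columns (d0, d1, const)):
-  //   (0, 0, 5)  encodes 5 >= 0, which is trivially valid and gets dropped;
-  //   (1, 0, 0)  encodes d0 >= 0 and is kept (and hashed into 'rowSet');
-  //   (1, 0, 0)  appearing again is a duplicate and is marked redundant.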
- std::vector redunIneq(getNumInequalities(), false); - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - int64_t *rowStart = inequalities.data() + numReservedCols * r; - auto row = ArrayRef(rowStart, getNumCols()); - if (isTriviallyValid(r) || !rowSet.insert(row).second) { - redunIneq[r] = true; - } - } - - auto copyRow = [&](unsigned src, unsigned dest) { - if (src == dest) - return; - for (unsigned c = 0, e = getNumCols(); c < e; c++) { - atIneq(dest, c) = atIneq(src, c); - } - }; - - // Scan to get rid of all rows marked redundant, in-place. - unsigned pos = 0; - for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - if (!redunIneq[r]) - copyRow(r, pos++); - } - inequalities.resize(numReservedCols * pos); - - // TODO(bondhugula): consider doing this for equalities as well, but probably - // not worth the savings. -} - -void FlatAffineConstraints::clearAndCopyFrom( - const FlatAffineConstraints &other) { - FlatAffineConstraints copy(other); - std::swap(*this, copy); - assert(copy.getNumIds() == copy.getIds().size()); -} - -void FlatAffineConstraints::removeId(unsigned pos) { - removeIdRange(pos, pos + 1); -} - -static std::pair -getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { - unsigned numDims = cst.getNumDimIds(); - unsigned numSymbols = cst.getNumSymbolIds(); - unsigned newNumDims, newNumSymbols; - if (pos < numDims) { - newNumDims = numDims - 1; - newNumSymbols = numSymbols; - } else if (pos < numDims + numSymbols) { - assert(numSymbols >= 1); - newNumDims = numDims; - newNumSymbols = numSymbols - 1; - } else { - newNumDims = numDims; - newNumSymbols = numSymbols; - } - return {newNumDims, newNumSymbols}; -} - -#undef DEBUG_TYPE -#define DEBUG_TYPE "fm" - -/// Eliminates identifier at the specified position using Fourier-Motzkin -/// variable elimination. This technique is exact for rational spaces but -/// conservative (in "rare" cases) for integer spaces. The operation corresponds -/// to a projection operation yielding the (convex) set of integer points -/// contained in the rational shadow of the set. An emptiness test that relies -/// on this method will guarantee emptiness, i.e., it disproves the existence of -/// a solution if it says it's empty. -/// If a non-null isResultIntegerExact is passed, it is set to true if the -/// result is also integer exact. If it's set to false, the obtained solution -/// *may* not be exact, i.e., it may contain integer points that do not have an -/// integer pre-image in the original set. -/// -/// Eg: -/// j >= 0, j <= i + 1 -/// i >= 0, i <= N + 1 -/// Eliminating i yields, -/// j >= 0, 0 <= N + 1, j - 1 <= N + 1 -/// -/// If darkShadow = true, this method computes the dark shadow on elimination; -/// the dark shadow is a convex integer subset of the exact integer shadow. A -/// non-empty dark shadow proves the existence of an integer solution. The -/// elimination in such a case could however be an under-approximation, and thus -/// should not be used for scanning sets or used by itself for dependence -/// checking. -/// -/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. 
-///  ^
-///  |
-///  |   * * * * o o
-/// i|   * * o o o o
-///  |   o * * * * *
-///  --------------->
-///        j ->
-///
-/// Eliminating i from this system (projecting on the j dimension):
-/// rational shadow / integer light shadow:  1 <= j <= 6
-/// dark shadow:                             3 <= j <= 6
-/// exact integer shadow:                    j = 1 \union  3 <= j <= 6
-/// holes/splinters:                         j = 2
-///
-/// darkShadow = false, isResultIntegerExact = nullptr are default values.
-// TODO(bondhugula): a slight modification to yield dark shadow version of FM
-// (tightened), which can prove the existence of a solution if there is one.
-void FlatAffineConstraints::FourierMotzkinEliminate(
-    unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
-  LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
-  LLVM_DEBUG(dump());
-  assert(pos < getNumIds() && "invalid position");
-  assert(hasConsistentState());
-
-  // Check if this identifier can be eliminated through a substitution.
-  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
-    if (atEq(r, pos) != 0) {
-      // Use Gaussian elimination here (since we have an equality).
-      bool ret = gaussianEliminateId(pos);
-      (void)ret;
-      assert(ret && "Gaussian elimination guaranteed to succeed");
-      LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
-      LLVM_DEBUG(dump());
-      return;
-    }
-  }
-
-  // A fast linear time tightening.
-  GCDTightenInequalities();
-
-  // Check if the identifier appears at all in any of the inequalities.
-  unsigned r, e;
-  for (r = 0, e = getNumInequalities(); r < e; r++) {
-    if (atIneq(r, pos) != 0)
-      break;
-  }
-  if (r == getNumInequalities()) {
-    // If it doesn't appear, just remove the column and return.
-    // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
-    removeId(pos);
-    LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
-    LLVM_DEBUG(dump());
-    return;
-  }
-
-  // Positions of constraints that are lower bounds on the variable.
-  SmallVector<unsigned, 4> lbIndices;
-  // Positions of constraints that are upper bounds on the variable.
-  SmallVector<unsigned, 4> ubIndices;
-  // Positions of constraints that do not involve the variable.
-  std::vector<unsigned> nbIndices;
-  nbIndices.reserve(getNumInequalities());
-
-  // Gather all lower bounds and upper bounds of the variable. Since the
-  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
-  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
-    if (atIneq(r, pos) == 0) {
-      // Id does not appear in bound.
-      nbIndices.push_back(r);
-    } else if (atIneq(r, pos) >= 1) {
-      // Lower bound.
-      lbIndices.push_back(r);
-    } else {
-      // Upper bound.
-      ubIndices.push_back(r);
-    }
-  }
-
-  // Set the number of dimensions, symbols in the resulting system.
-  const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
-  unsigned newNumDims = dimsSymbols.first;
-  unsigned newNumSymbols = dimsSymbols.second;
-
-  SmallVector<Optional<Value *>, 8> newIds;
-  newIds.reserve(numIds - 1);
-  newIds.append(ids.begin(), ids.begin() + pos);
-  newIds.append(ids.begin() + pos + 1, ids.end());
-
-  /// Create the new system which has one identifier less.
-  FlatAffineConstraints newFac(
-      lbIndices.size() * ubIndices.size() + nbIndices.size(),
-      getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
-      /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
-
-  assert(newFac.getIds().size() == newFac.getNumIds());
-
-  // This will be used to check if the elimination was integer exact.
-  unsigned lcmProducts = 1;
-
-  // Let x be the variable we are eliminating.
- // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note - // that c_l, c_u >= 1) we have: - // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u - // We thus generate a constraint: - // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. - // Note if c_l = c_u = 1, all integer points captured by the resulting - // constraint correspond to integer points in the original system (i.e., they - // have integer pre-images). Hence, if the lcm's are all 1, the elimination is - // integer exact. - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - SmallVector ineq; - ineq.reserve(newFac.getNumCols()); - int64_t lbCoeff = atIneq(lbPos, pos); - // Note that in the comments above, ubCoeff is the negation of the - // coefficient in the canonical form as the view taken here is that of the - // term being moved to the other size of '>='. - int64_t ubCoeff = -atIneq(ubPos, pos); - // TODO(bondhugula): refactor this loop to avoid all branches inside. - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); - int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); - ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + - atIneq(lbPos, l) * (lcm / lbCoeff)); - lcmProducts *= lcm; - } - if (darkShadow) { - // The dark shadow is a convex subset of the exact integer shadow. If - // there is a point here, it proves the existence of a solution. - ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; - } - // TODO: we need to have a way to add inequalities in-place in - // FlatAffineConstraints instead of creating and copying over. - newFac.addInequality(ineq); - } - } - - if (lcmProducts == 1 && isResultIntegerExact) - *isResultIntegerExact = 1; - - // Copy over the constraints not involving this variable. - for (auto nbPos : nbIndices) { - SmallVector ineq; - ineq.reserve(getNumCols() - 1); - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - ineq.push_back(atIneq(nbPos, l)); - } - newFac.addInequality(ineq); - } - - assert(newFac.getNumConstraints() == - lbIndices.size() * ubIndices.size() + nbIndices.size()); - - // Copy over the equalities. - for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - SmallVector eq; - eq.reserve(newFac.getNumCols()); - for (unsigned l = 0, e = getNumCols(); l < e; l++) { - if (l == pos) - continue; - eq.push_back(atEq(r, l)); - } - newFac.addEquality(eq); - } - - newFac.removeTrivialRedundancy(); - clearAndCopyFrom(newFac); - LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); - LLVM_DEBUG(dump()); -} - -#undef DEBUG_TYPE -#define DEBUG_TYPE "affine-structures" - -void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { - if (num == 0) - return; - - // 'pos' can be at most getNumCols() - 2 if num > 0. - assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position"); - assert(pos + num < getNumCols() && "invalid range"); - - // Eliminate as many identifiers as possible using Gaussian elimination. - unsigned currentPos = pos; - unsigned numToEliminate = num; - unsigned numGaussianEliminated = 0; - - while (currentPos < getNumIds()) { - unsigned curNumEliminated = - gaussianEliminateIds(currentPos, currentPos + numToEliminate); - ++currentPos; - numToEliminate -= curNumEliminated + 1; - numGaussianEliminated += curNumEliminated; - } - - // Eliminate the remaining using Fourier-Motzkin. 
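-  // Sketch of a single such step (hypothetical constraints): eliminating x
-  // from the lower bound -j + 2*x >= 0 (c_l = 2) and the upper bound
-  // N - 3*x >= 0 (c_u = 3) uses lcm(2, 3) = 6 and yields
-  //   2*N - 3*j >= 0,
-  // and since the lcm is larger than 1, the result may not be integer exact.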
- for (unsigned i = 0; i < num - numGaussianEliminated; i++) { - unsigned numToEliminate = num - numGaussianEliminated - i; - FourierMotzkinEliminate( - getBestIdToEliminate(*this, pos, pos + numToEliminate)); - } - - // Fast/trivial simplifications. - GCDTightenInequalities(); - // Normalize constraints after tightening since the latter impacts this, but - // not the other way round. - normalizeConstraintsByGCD(); -} - -void FlatAffineConstraints::projectOut(Value *id) { - unsigned pos; - bool ret = findId(*id, &pos); - assert(ret); - (void)ret; - FourierMotzkinEliminate(pos); -} - -bool FlatAffineConstraints::isRangeOneToOne(unsigned start, - unsigned limit) const { - assert(start <= getNumIds() - 1 && "invalid start position"); - assert(limit > start && limit <= getNumIds() && "invalid limit"); - - FlatAffineConstraints tmpCst(*this); - - if (start != 0) { - // Move [start, limit) to the left. - for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { - for (unsigned c = 0, f = getNumCols(); c < f; ++c) { - if (c >= start && c < limit) - tmpCst.atIneq(r, c - start) = atIneq(r, c); - else if (c < start) - tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); - else - tmpCst.atIneq(r, c) = atIneq(r, c); - } - } - for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { - for (unsigned c = 0, f = getNumCols(); c < f; ++c) { - if (c >= start && c < limit) - tmpCst.atEq(r, c - start) = atEq(r, c); - else if (c < start) - tmpCst.atEq(r, c + limit - start) = atEq(r, c); - else - tmpCst.atEq(r, c) = atEq(r, c); - } - } - } - - // Mark everything to the right as symbols so that we can check the extents in - // a symbolic way below. - tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); - - // Check if the extents of all the specified dimensions are just one (when - // treating the rest as symbols). - for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { - auto extent = tmpCst.getConstantBoundOnDimSize(pos); - if (!extent.hasValue() || extent.getValue() != 1) - return false; - } - return true; -} - -void FlatAffineConstraints::clearConstraints() { - equalities.clear(); - inequalities.clear(); -} - -namespace { - -enum BoundCmpResult { Greater, Less, Equal, Unknown }; - -/// Compares two affine bounds whose coefficients are provided in 'first' and -/// 'second'. The last coefficient is the constant term. -static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { - assert(a.size() == b.size()); - - // For the bounds to be comparable, their corresponding identifier - // coefficients should be equal; the constant terms are then compared to - // determine less/greater/equal. - - if (!std::equal(a.begin(), a.end() - 1, b.begin())) - return Unknown; - - if (a.back() == b.back()) - return Equal; - - return a.back() < b.back() ? Less : Greater; -} -}; // namespace - -// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of -// the symbols. Assumes the common symbols appear in the same order (the -// current/common use case). -static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) { - SmallVector symbolsA, symbolsB; - A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA); - B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB); - - // Both symbol list have a handful symbols each typically (3-4); a merge - // quadratic in complexity with a linear search is fine. 
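-  // For instance (hypothetical symbols): with symbolsA = {N, K} and
-  // symbolsB = {K, M}, the first loop below appends M to A, and the second
-  // loop inserts N at position 0 of B, leaving both systems with the symbol
-  // list {N, K, M} in the same order.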
-  for (auto *symbolB : symbolsB) {
-    if (!llvm::is_contained(symbolsA, symbolB)) {
-      A->addSymbolId(symbolsA.size(), symbolB);
-      symbolsA.push_back(symbolB);
-    }
-  }
-  // symbolsA now holds the merged symbol list.
-  symbolsB.reserve(symbolsA.size());
-  unsigned iB = 0;
-  for (auto *symbolA : symbolsA) {
-    assert(iB < symbolsB.size());
-    if (symbolA != symbolsB[iB]) {
-      symbolsB.insert(symbolsB.begin() + iB, symbolA);
-      B->addSymbolId(iB, symbolA);
-    }
-    ++iB;
-  }
-}
-
-// Compute the bounding box with respect to 'other' by finding the min of the
-// lower bounds and the max of the upper bounds along each of the dimensions.
-bool FlatAffineConstraints::unionBoundingBox(
-    const FlatAffineConstraints &otherArg) {
-  assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
-
-  Optional<FlatAffineConstraints> copy;
-  if (!otherArg.getIds().equals(getIds())) {
-    copy.emplace(FlatAffineConstraints(otherArg));
-    mergeSymbols(this, &copy.getValue());
-    assert(getIds().equals(copy->getIds()) && "merge failed");
-  }
-
-  const auto &other = copy ? *copy : otherArg;
-
-  assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
-  assert(getNumLocalIds() == 0 && "local ids not eliminated");
-
-  std::vector<SmallVector<int64_t, 8>> boundingLbs;
-  std::vector<SmallVector<int64_t, 8>> boundingUbs;
-  boundingLbs.reserve(2 * getNumDimIds());
-  boundingUbs.reserve(2 * getNumDimIds());
-
-  SmallVector<int64_t, 4> lb, otherLb;
-  lb.reserve(getNumSymbolIds() + 1);
-  otherLb.reserve(getNumSymbolIds() + 1);
-  int64_t lbDivisor, otherLbDivisor;
-  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
-    lb.clear();
-    auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor);
-    if (!extent.hasValue())
-      // TODO(bondhugula): symbolic extents when necessary.
-      // TODO(bondhugula): handle union if a dimension is unbounded.
-      return false;
-
-    otherLb.clear();
-    auto otherExtent =
-        other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor);
-    if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor)
-      // TODO(bondhugula): symbolic extents when necessary.
-      return false;
-
-    assert(lbDivisor > 0 && "divisor always expected to be positive");
-
-    // Compute min of lower bounds and max of upper bounds.
-    ArrayRef<int64_t> minLb, maxUb;
-
-    auto res = compareBounds(lb, otherLb);
-    // Identify min.
-    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
-      minLb = lb;
-    } else if (res == BoundCmpResult::Greater) {
-      minLb = otherLb;
-    } else {
-      // Uncomparable.
-      auto constLb = getConstantLowerBound(d);
-      auto constOtherLb = other.getConstantLowerBound(d);
-      if (!constLb.hasValue() || !constOtherLb.hasValue())
-        return false;
-      minLb = std::min(constLb.getValue(), constOtherLb.getValue());
-    }
-
-    // Do the same for ub's but max of upper bounds.
-    SmallVector<int64_t, 4> ub(lb), otherUb(otherLb);
-    ub.back() += extent.getValue() - 1;
-    otherUb.back() += otherExtent.getValue() - 1;
-
-    // Identify max.
-    auto uRes = compareBounds(ub, otherUb);
-    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
-      maxUb = ub;
-    } else if (uRes == BoundCmpResult::Less) {
-      maxUb = otherUb;
-    } else {
-      // Uncomparable.
-      auto constUb = getConstantUpperBound(d);
-      auto constOtherUb = other.getConstantUpperBound(d);
-      if (!constUb.hasValue() || !constOtherUb.hasValue())
-        return false;
-      maxUb = std::max(constUb.getValue(), constOtherUb.getValue());
-    }
-
-    SmallVector<int64_t, 8> newLb(getNumCols(), 0);
-    SmallVector<int64_t, 8> newUb(getNumCols(), 0);
-
-    // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
-    // and so it's the divisor for newLb and newUb as well.
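-    // Worked case (hypothetical bounds, no symbols, lbDivisor = 1): if this
-    // system has 0 <= d <= 15 and 'other' has 8 <= d <= 25, then minLb = (0)
-    // and maxUb = (25), and the rows added below encode
-    //   d >= 0   and   -d + 25 >= 0,
-    // i.e. the union is over-approximated by the box 0 <= d <= 25.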
- newLb[d] = lbDivisor; - newUb[d] = -lbDivisor; - // Copy over the symbolic part + constant term. - std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); - std::transform(newLb.begin() + getNumDimIds(), newLb.end(), - newLb.begin() + getNumDimIds(), std::negate()); - std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); - - boundingLbs.push_back(newLb); - boundingUbs.push_back(newUb); - } - - // Clear all constraints and add the lower/upper bounds for the bounding box. - clearConstraints(); - for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { - addInequality(boundingLbs[d]); - addInequality(boundingUbs[d]); - } - - return true; -} diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5083bc4d586..4fb6f34ed53 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -22,8 +22,8 @@ //===----------------------------------------------------------------------===// #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 5dbefb875da..677710b00a6 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -21,11 +21,11 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2253d1d354a..240b2b6d9b6 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -21,8 +21,8 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 4ddfd9f06fb..d0fdcb5527f 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -19,7 +19,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/AffineStructures.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 6b1a0be3bd3..2a7738924a7 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -23,10 +23,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git 
a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 519885b3a50..80dc49f1aab 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -24,9 +24,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Module.h" #include "mlir/StandardOps/StandardOps.h" -- cgit v1.2.3 From d4b3ff1096fbe5fec0f02abcb88b49a4d9bfc6dc Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 26 Feb 2019 16:10:19 -0800 Subject: Loop fusion comand line options cleanup - clean up loop fusion CL options for promoting local buffers to fast memory space - add parameters to loop fusion pass instantiation PiperOrigin-RevId: 235813419 --- mlir/include/mlir/Transforms/Passes.h | 7 +++++-- mlir/lib/Transforms/LoopFusion.cpp | 35 +++++++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 12 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index e0fc934a620..6b472526366 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -72,8 +72,11 @@ FunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates an simplification pass for affine structures. FunctionPass *createSimplifyAffineStructuresPass(); -/// Creates a loop fusion pass which fuses loops. -FunctionPass *createLoopFusionPass(); +/// Creates a loop fusion pass which fuses loops. Buffers of size less than or +/// equal to `localBufSizeThreshold` are promoted to memory space +/// `fastMemorySpace'. +FunctionPass *createLoopFusionPass(unsigned fastMemorySpace = 0, + uint64_t localBufSizeThreshold = 0); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 677710b00a6..72176bacf9a 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -67,9 +67,11 @@ static llvm::cl::opt clFusionFastMemorySpace( llvm::cl::desc("Faster memory space number to promote fusion buffers to"), llvm::cl::cat(clOptionsCategory)); -static llvm::cl::opt clFusionLocalBufThreshold( +// A local buffer of size less than or equal to this size is promoted to fast +// memory. +static llvm::cl::opt clFusionLocalBufThreshold( "fusion-local-buf-threshold", llvm::cl::Hidden, - llvm::cl::desc("Threshold size (bytes) for promoting local buffers to fast " + llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast " "memory space"), llvm::cl::cat(clOptionsCategory)); @@ -85,14 +87,17 @@ namespace { // and add support for more general loop fusion algorithms. struct LoopFusion : public FunctionPass { - LoopFusion() : FunctionPass(&LoopFusion::passID) {} + LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0) + : FunctionPass(&LoopFusion::passID), + localBufSizeThreshold(localBufSizeThreshold), + fastMemorySpace(fastMemorySpace) {} PassResult runOnFunction(Function *f) override; constexpr static PassID passID = {}; - // Any local buffers smaller than this size will be created in + // Any local buffers smaller than this size (in bytes) will be created in // `fastMemorySpace` if provided. 
- unsigned localBufSizeThreshold = 1024; + uint64_t localBufSizeThreshold; Optional fastMemorySpace = None; // The amount of additional computation that is tolerated while fusing @@ -102,7 +107,10 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } +FunctionPass *mlir::createLoopFusionPass(unsigned fastMemorySpace, + uint64_t localBufSizeThreshold) { + return new LoopFusion(fastMemorySpace, localBufSizeThreshold); +} namespace { @@ -632,7 +640,7 @@ struct LoopNestStatsCollector { unsigned count = 0; stats->opCountMap[forInst] = 0; for (auto &inst : *forOp->getBody()) { - if (!(inst.isa() || inst.isa())) + if (!inst.isa() && !inst.isa()) ++count; } stats->opCountMap[forInst] = count; @@ -1048,7 +1056,7 @@ static Value *createPrivateMemRef(OpPointer forOp, Instruction *srcStoreOpInst, unsigned dstLoopDepth, Optional fastMemorySpace, - unsigned localBufSizeThreshold) { + uint64_t localBufSizeThreshold) { auto *forInst = forOp->getInstruction(); // Create builder to insert alloc op just before 'forOp'. @@ -1102,7 +1110,7 @@ static Value *createPrivateMemRef(OpPointer forOp, uint64_t bufSize = getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue(); unsigned newMemSpace; - if (bufSize < localBufSizeThreshold && fastMemorySpace.hasValue()) { + if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) { newMemSpace = fastMemorySpace.getValue(); } else { newMemSpace = oldMemRefType.getMemorySpace(); @@ -1414,7 +1422,8 @@ static bool isFusionProfitable(Instruction *srcOpInst, LLVM_DEBUG({ std::stringstream msg; msg << " evaluating fusion profitability at depth : " << i << "\n" - << std::setprecision(2) << " additional compute fraction: " + << std::fixed << std::setprecision(2) + << " additional compute fraction: " << 100.0 * additionalComputeFraction << "%\n" << " storage reduction factor: " << storageReduction << "x\n" << " fused nest cost: " << fusedLoopNestComputeCost << "\n" @@ -1795,10 +1804,16 @@ public: } // end anonymous namespace PassResult LoopFusion::runOnFunction(Function *f) { + // Override if a command line argument was provided. if (clFusionFastMemorySpace.getNumOccurrences() > 0) { fastMemorySpace = clFusionFastMemorySpace.getValue(); } + // Override if a command line argument was provided. + if (clFusionLocalBufThreshold.getNumOccurrences() > 0) { + localBufSizeThreshold = clFusionLocalBufThreshold * 1024; + } + MemRefDependenceGraph g; if (g.init(f)) GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); -- cgit v1.2.3 From 7aa60a383f222ca75a00a668c8fe60cf9217f508 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 26 Feb 2019 17:32:47 -0800 Subject: Temp change in FlatAffineConstraints::getSliceBounds() to deal with TODO in LoopFusion - getConstDifference in LoopFusion is pending a refactoring to handle bounds with min's and max's; it currently asserts on some useful test cases that we want to experiment with. This CL changes getSliceBounds to be more conservative so as to not trigger the assertion. Filed b/126426796 to track this. 
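As a rough illustration of the more conservative behavior (hypothetical maps, not taken from the commit): if the exact slice lower bound computed for a position were a multi-result map such as (d0) -> (0, d0 - 2), getSliceBounds() now discards it and falls back to a single-result constant map such as (d0) -> (0), over-approximating the slice but preserving the single-result assumption that getConstDifference() asserts on.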
PiperOrigin-RevId: 235826538 --- mlir/lib/Analysis/AffineStructures.cpp | 27 ++++++++++++++++----------- mlir/lib/Transforms/LoopFusion.cpp | 3 +-- 2 files changed, 17 insertions(+), 13 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 32129166afa..276db4712c5 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1433,9 +1433,12 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, if (expr) expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); + AffineMap &lbMap = (*lbMaps)[pos]; + AffineMap &ubMap = (*ubMaps)[pos]; + if (expr) { - (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {}); - (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); + lbMap = AffineMap::get(numMapDims, numMapSymbols, expr, {}); + ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); } else { // TODO(bondhugula): Whenever there have local identifiers in the // dependence constraints, we'll conservatively over-approximate, since we @@ -1448,38 +1451,40 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, // redundant loop bounds. tmpClone->removeRedundantInequalities(); } - std::tie((*lbMaps)[pos], (*ubMaps)[pos]) = - tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {}, - context); + std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( + pos, num, getNumDimIds(), {}, context); } // If the above fails, we'll just use the constant lower bound and the // constant upper bound (if they exist) as the slice bounds. - if (!(*lbMaps)[pos]) { + // TODO(b/126426796): being conservative for the moment in cases that + // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is + // fixed (b/126426796). + if (!lbMap || lbMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice lb\n"); auto lbConst = getConstantLowerBound(pos); if (lbConst.hasValue()) { - (*lbMaps)[pos] = AffineMap::get( + lbMap = AffineMap::get( numMapDims, numMapSymbols, getAffineConstantExpr(lbConst.getValue(), context), {}); } } - if (!(*ubMaps)[pos]) { + if (!ubMap || ubMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice ub\n"); auto ubConst = getConstantUpperBound(pos); if (ubConst.hasValue()) { - (*ubMaps)[pos] = AffineMap::get( + (ubMap) = AffineMap::get( numMapDims, numMapSymbols, getAffineConstantExpr(ubConst.getValue() + 1, context), {}); } } } LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG((*lbMaps)[pos].dump();); + LLVM_DEBUG(lbMap.dump();); LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: "); - LLVM_DEBUG((*ubMaps)[pos].dump();); + LLVM_DEBUG(ubMap.dump();); } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 72176bacf9a..0f4e45c372a 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -704,13 +704,12 @@ static int64_t getComputeCost( } // end anonymous namespace +// TODO(andydavis,b/126426796): extend this to handle multiple result maps. 
static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { assert(lbMap.getNumResults() == 1 && "expected single result bound map"); assert(ubMap.getNumResults() == 1 && "expected single result bound map"); assert(lbMap.getNumDims() == ubMap.getNumDims()); assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); - // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'. - // ub_expr - lb_expr AffineExpr lbExpr(lbMap.getResult(0)); AffineExpr ubExpr(ubMap.getResult(0)); auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), -- cgit v1.2.3 From c6c534493d625c10ce0046baa9dc6293f8dba405 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 27 Feb 2019 10:59:29 -0800 Subject: Port all of the existing passes over to the new pass manager infrastructure. This is largely NFC. PiperOrigin-RevId: 235952357 --- mlir/bindings/python/pybind.cpp | 30 ---------------- mlir/include/mlir/Analysis/Passes.h | 6 ++-- mlir/include/mlir/Pass/Pass.h | 42 +++------------------- mlir/include/mlir/Pass/PassManager.h | 2 +- mlir/include/mlir/Pass/PassRegistry.h | 2 +- mlir/include/mlir/Transforms/Passes.h | 42 +++++++++++----------- mlir/include/mlir/Transforms/ViewFunctionGraph.h | 8 ++--- mlir/lib/Analysis/MemRefBoundCheck.cpp | 14 +++----- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 15 +++----- mlir/lib/Analysis/OpStats.cpp | 13 +++---- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 11 +++--- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 42 ++++++++-------------- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 11 +++--- mlir/lib/Pass/Pass.cpp | 18 ---------- mlir/lib/Transforms/CSE.cpp | 16 ++++----- mlir/lib/Transforms/Canonicalizer.cpp | 20 +++++------ mlir/lib/Transforms/ConstantFold.cpp | 14 +++----- mlir/lib/Transforms/DmaGeneration.cpp | 20 +++++------ mlir/lib/Transforms/LoopFusion.cpp | 16 ++++----- mlir/lib/Transforms/LoopTiling.cpp | 12 +++---- mlir/lib/Transforms/LoopUnroll.cpp | 19 +++++----- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 22 +++++------- mlir/lib/Transforms/LowerAffine.cpp | 16 ++++----- mlir/lib/Transforms/LowerVectorTransfers.cpp | 15 ++++---- mlir/lib/Transforms/MaterializeVectors.cpp | 13 +++---- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 22 +++++------- mlir/lib/Transforms/PipelineDataTransfer.cpp | 13 +++---- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 16 ++++----- mlir/lib/Transforms/StripDebugInfo.cpp | 21 ++++++----- .../Vectorization/VectorizerTestPass.cpp | 12 +++---- mlir/lib/Transforms/Vectorize.cpp | 13 +++---- mlir/lib/Transforms/ViewFunctionGraph.cpp | 17 ++++----- mlir/tools/mlir-opt/mlir-opt.cpp | 19 +++++----- 33 files changed, 205 insertions(+), 367 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 2673c8a58b1..eca5d632b48 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -31,36 +31,6 @@ namespace mlir { namespace edsc { namespace python { -static std::vector> getDefaultPasses( - const std::vector &mlirPassInfoList = {}) { - std::vector> passList; - passList.reserve(mlirPassInfoList.size() + 4); - // Run each of the passes that were selected. - for (const auto *passInfo : mlirPassInfoList) { - passList.emplace_back(passInfo->createPass()); - } - // Append the extra passes for lowering to MLIR. 
- passList.emplace_back(mlir::createConstantFoldPass()); - passList.emplace_back(mlir::createCSEPass()); - passList.emplace_back(mlir::createCanonicalizerPass()); - passList.emplace_back(mlir::createLowerAffinePass()); - return passList; -} - -// Run the passes sequentially on the given module. -// Return `nullptr` immediately if any of the passes fails. -static bool runPasses(const std::vector> &passes, - Module *module) { - for (const auto &pass : passes) { - mlir::PassResult result = pass->runOnModule(module); - if (result == mlir::PassResult::Failure || module->verify()) { - llvm::errs() << "Pass failed\n"; - return true; - } - } - return false; -} - namespace py = pybind11; struct PythonBindable; diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index 8fd1f9c4bf9..5bd05462657 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -27,13 +27,13 @@ namespace mlir { -class FunctionPass; +class FunctionPassBase; /// Creates a pass to check memref accesses in an ML Function. -FunctionPass *createMemRefBoundCheckPass(); +FunctionPassBase *createMemRefBoundCheckPass(); /// Creates a pass to check memref access dependences in an ML Function. -FunctionPass *createMemRefDependenceCheckPass(); +FunctionPassBase *createMemRefDependenceCheckPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index cb20a15ea6d..26b76e317eb 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -42,9 +42,6 @@ public: virtual ~Pass() = default; - /// TODO: This is deprecated and should be removed. - virtual PassResult runOnModule(Module *m) { return failure(); } - /// Returns the unique identifier that corresponds to this pass. const PassID *getPassID() const { return passIDAndKind.getPointer(); } @@ -75,31 +72,6 @@ private: llvm::PointerIntPair passIDAndKind; }; -/// Deprecated Function and Module Pass definitions. -/// TODO(riverriddle) Remove these. -class ModulePass : public Pass { -public: - explicit ModulePass(const PassID *passID) : Pass(passID, Kind::ModulePass) {} - - virtual PassResult runOnModule(Module *m) override = 0; - -private: - /// Out of line virtual method to ensure vtables and metadata are emitted to a - /// single .o file. - virtual void anchor(); -}; -class FunctionPass : public Pass { -public: - explicit FunctionPass(const PassID *passID) - : Pass(passID, Kind::FunctionPass) {} - - /// Implement this function to be run on every function in the module. - virtual PassResult runOnFunction(Function *fn) = 0; - - // Iterates over all functions in a module, halting upon failure. - virtual PassResult runOnModule(Module *m) override; -}; - namespace detail { class FunctionPassExecutor; class ModulePassExecutor; @@ -173,10 +145,6 @@ private: /// The current execution state for the pass. llvm::Optional> passState; - /// TODO(riverriddle) Remove this using directive when the old pass - /// functionality is removed. - using Pass::runOnModule; - /// Allow access to 'run'. friend detail::ModulePassExecutor; }; @@ -195,9 +163,7 @@ protected: /// TODO(riverriddle) Provide additional utilities for cloning, getting the /// derived class name, etc.. }; - -// TODO(riverriddle): Move these to the mlir namespace when the current passes -// have been ported. +} // end namespace detail /// A model for providing function pass specific utilities. 
/// @@ -210,14 +176,14 @@ protected: /// Derived function passes are expected to provide the following: /// - A 'PassResult runOnFunction()' method. template -using FunctionPass = PassModel; +using FunctionPass = detail::PassModel; /// A model for providing module pass specific utilities. /// /// Derived module passes are expected to provide the following: /// - A 'PassResult runOnModule()' method. -template using ModulePass = PassModel; -} // end namespace detail +template +using ModulePass = detail::PassModel; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 744a5179afd..ec1ab2f7e44 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -91,7 +91,7 @@ private: /// An adaptor module pass used to run function passes over all of the /// non-external functions of a module. class ModuleToFunctionPassAdaptor - : public detail::ModulePass { + : public ModulePass { public: ModuleToFunctionPassAdaptor() = default; ModuleToFunctionPassAdaptor(ModuleToFunctionPassAdaptor &&) = default; diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index 7ac807f08d5..20f64cc274c 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -96,7 +96,7 @@ void registerPass(StringRef arg, StringRef description, const PassID *passID, /// static PassRegistration Unused("unused", "Unused pass"); template struct PassRegistration { PassRegistration(StringRef arg, StringRef description) { - registerPass(arg, description, &ConcretePass::passID, + registerPass(arg, description, PassID::getID(), [&]() { return new ConcretePass(); }); } }; diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 6b472526366..32c7d93307c 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -30,28 +30,28 @@ namespace mlir { class AffineForOp; template class ConstOpPointer; -class FunctionPass; -class ModulePass; +class FunctionPassBase; +class ModulePassBase; /// Creates a constant folding pass. -FunctionPass *createConstantFoldPass(); +FunctionPassBase *createConstantFoldPass(); /// Creates an instance of the Canonicalizer pass. -FunctionPass *createCanonicalizerPass(); +FunctionPassBase *createCanonicalizerPass(); /// Creates a pass to perform common sub expression elimination. -FunctionPass *createCSEPass(); +FunctionPassBase *createCSEPass(); /// Creates a pass to vectorize loops, operations and data types using a /// target-independent, n-D super-vector abstraction. -FunctionPass *createVectorizePass(); +FunctionPassBase *createVectorizePass(); /// Creates a pass to allow independent testing of vectorizer functionality with /// FileCheck. -FunctionPass *createVectorizerTestPass(); +FunctionPassBase *createVectorizerTestPass(); /// Creates a pass to lower super-vectors to target-dependent HW vectors. -FunctionPass *createMaterializeVectorsPass(); +FunctionPassBase *createMaterializeVectorsPass(); /// Creates a loop unrolling pass with the provided parameters. /// 'getUnrollFactor' is a function callback for clients to supply a function @@ -59,7 +59,7 @@ FunctionPass *createMaterializeVectorsPass(); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). 
-FunctionPass * +FunctionPassBase * createLoopUnrollPass(int unrollFactor = -1, int unrollFull = -1, const std::function)> &getUnrollFactor = nullptr); @@ -67,49 +67,49 @@ createLoopUnrollPass(int unrollFactor = -1, int unrollFull = -1, /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. -FunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1); +FunctionPassBase *createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates an simplification pass for affine structures. -FunctionPass *createSimplifyAffineStructuresPass(); +FunctionPassBase *createSimplifyAffineStructuresPass(); /// Creates a loop fusion pass which fuses loops. Buffers of size less than or /// equal to `localBufSizeThreshold` are promoted to memory space /// `fastMemorySpace'. -FunctionPass *createLoopFusionPass(unsigned fastMemorySpace = 0, - uint64_t localBufSizeThreshold = 0); +FunctionPassBase *createLoopFusionPass(unsigned fastMemorySpace = 0, + uint64_t localBufSizeThreshold = 0); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -FunctionPass *createPipelineDataTransferPass(); +FunctionPassBase *createPipelineDataTransferPass(); /// Lowers affine control flow instructions (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). -FunctionPass *createLowerAffinePass(); +FunctionPassBase *createLowerAffinePass(); /// Creates a pass to perform tiling on loop nests. -FunctionPass *createLoopTilingPass(); +FunctionPassBase *createLoopTilingPass(); /// Promotes all accessed memref regions to the specified faster memory space /// while generating DMAs to move data. -FunctionPass *createDmaGenerationPass( +FunctionPassBase *createDmaGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. -FunctionPass *createLowerVectorTransfersPass(); +FunctionPassBase *createLowerVectorTransfersPass(); /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -FunctionPass *createMemRefDataFlowOptPass(); +FunctionPassBase *createMemRefDataFlowOptPass(); /// Creates a pass to strip debug information from a function. -FunctionPass *createStripDebugInfoPass(); +FunctionPassBase *createStripDebugInfoPass(); /// Creates a pass to convert Standard and Builtin dialects into the LLVMIR /// dialect. -ModulePass *createConvertToLLVMIRPass(); +ModulePassBase *createConvertToLLVMIRPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h index 7b2b9db6e91..33b85b614a0 100644 --- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h +++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h @@ -29,7 +29,7 @@ namespace mlir { class Function; -class FunctionPass; +class FunctionPassBase; /// Displays the CFG in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. @@ -41,9 +41,9 @@ llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, const Function *function, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. 
-FunctionPass *createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(), - bool shortNames = false, - const llvm::Twine &title = ""); +FunctionPassBase *createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(), + bool shortNames = false, + const llvm::Twine &title = ""); } // end namespace mlir diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 8a0cb44f0cc..a6730f01199 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -37,22 +37,18 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass { - explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct MemRefBoundCheck : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createMemRefBoundCheckPass() { +FunctionPassBase *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -PassResult MemRefBoundCheck::runOnFunction(Function *f) { - f->walk([](Instruction *opInst) { +PassResult MemRefBoundCheck::runOnFunction() { + getFunction().walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 93d4fde1fd9..33488f0c7a8 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -37,19 +37,14 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. -struct MemRefDependenceCheck : public FunctionPass { +struct MemRefDependenceCheck : public FunctionPass { SmallVector loadsAndStores; - explicit MemRefDependenceCheck() - : FunctionPass(&MemRefDependenceCheck::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createMemRefDependenceCheckPass() { +FunctionPassBase *mlir::createMemRefDependenceCheckPass() { return new MemRefDependenceCheck(); } @@ -116,10 +111,10 @@ static void checkDependences(ArrayRef loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. -PassResult MemRefDependenceCheck::runOnFunction(Function *f) { +PassResult MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); - f->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa()) loadsAndStores.push_back(inst); }); diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index c1fcacac15a..a17be9d176b 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -26,29 +26,26 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass { - explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) - : ModulePass(&PrintOpStatsPass::passID), os(os) {} +struct PrintOpStatsPass : public ModulePass { + explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. 
- PassResult runOnModule(Module *m) override; + PassResult runOnModule() override; // Print summary of op stats. void printSummary(); - constexpr static PassID passID = {}; - private: llvm::StringMap opCount; llvm::raw_ostream &os; }; } // namespace -PassResult PrintOpStatsPass::runOnModule(Module *m) { +PassResult PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &fn : *m) + for (auto &fn : getModule()) fn.walk( [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index cdbe6e52ec0..2b6c38bf8c6 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -33,18 +33,15 @@ using namespace mlir; namespace { // Testing pass to lower EDSC. -struct LowerEDSCTestPass : public FunctionPass { - LowerEDSCTestPass() : FunctionPass(&LowerEDSCTestPass::passID) {} - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct LowerEDSCTestPass : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace #include "mlir/EDSC/reference-impl.inc" -PassResult LowerEDSCTestPass::runOnFunction(Function *f) { - f->walk([](Instruction *op) { +PassResult LowerEDSCTestPass::runOnFunction() { + getFunction().walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); if (!opName) { diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 01278aad8af..1a3dd6ffff0 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -22,7 +22,7 @@ #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" @@ -168,35 +168,20 @@ static inline Error make_string_error(const llvm::Twine &message) { // - CSE // - canonicalization // - affine lowering -static std::vector> -getDefaultPasses(const std::vector &mlirPassInfoList) { - std::vector> passList; - passList.reserve(mlirPassInfoList.size() + 4); +static void +getDefaultPasses(PassManager &manager, + const std::vector &mlirPassInfoList) { // Run each of the passes that were selected. for (const auto *passInfo : mlirPassInfoList) { - passList.emplace_back(passInfo->createPass()); + manager.addPass(passInfo->createPass()); } - // Append the extra passes for lowering to MLIR. - passList.emplace_back(mlir::createConstantFoldPass()); - passList.emplace_back(mlir::createCSEPass()); - passList.emplace_back(mlir::createCanonicalizerPass()); - passList.emplace_back(mlir::createLowerAffinePass()); - passList.emplace_back(mlir::createConvertToLLVMIRPass()); - return passList; -} -// Run the passes sequentially on the given module. -// Return `nullptr` immediately if any of the passes fails. -static bool runPasses(const std::vector> &passes, - Module *module) { - for (const auto &pass : passes) { - mlir::PassResult result = pass->runOnModule(module); - if (result == mlir::PassResult::Failure || module->verify()) { - llvm::errs() << "Pass failed\n"; - return true; - } - } - return false; + // Append the extra passes for lowering to MLIR. 
+ manager.addPass(mlir::createConstantFoldPass()); + manager.addPass(mlir::createCSEPass()); + manager.addPass(mlir::createCanonicalizerPass()); + manager.addPass(mlir::createLowerAffinePass()); + manager.addPass(mlir::createConvertToLLVMIRPass()); } // Setup LLVM target triple from the current machine. @@ -295,7 +280,10 @@ Expected> ExecutionEngine::create( if (!expectedJIT) return expectedJIT.takeError(); - if (runPasses(getDefaultPasses({}), m)) + // Construct and run the default MLIR pipeline. + PassManager manager; + getDefaultPasses(manager, {}); + if (manager.run(m)) return make_string_error("passes failed"); auto llvmModule = translateModuleToLLVMIR(*m); diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 7421ebbeaaa..64ee5862ae7 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -1137,13 +1137,10 @@ static void ensureDistinctSuccessors(Module *m) { /// A pass converting MLIR Standard and Builtin operations into the LLVM IR /// dialect. -class LLVMLowering : public ModulePass, public DialectConversion { +class LLVMLowering : public ModulePass, public DialectConversion { public: - LLVMLowering() : ModulePass(&passID) {} - - constexpr static PassID passID = {}; - - PassResult runOnModule(Module *m) override { + PassResult runOnModule() override { + Module *m = &getModule(); uniqueSuccessorsWithArguments(m); return DialectConversion::convert(m) ? failure() : success(); } @@ -1203,7 +1200,7 @@ private: llvm::Module *module; }; -ModulePass *mlir::createConvertToLLVMIRPass() { return new LLVMLowering; } +ModulePassBase *mlir::createConvertToLLVMIRPass() { return new LLVMLowering(); } static PassRegistration pass("convert-to-llvmir", "Convert all functions to the LLVM IR dialect"); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index b652f1f700b..c05c9d24aa4 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -29,24 +29,6 @@ using namespace mlir; /// single .o file. void Pass::anchor() {} -/// Out of line virtual method to ensure vtables and metadata are emitted to a -/// single .o file. -void ModulePass::anchor() {} - -/// Function passes walk a module and look at each function with their -/// corresponding hooks and terminates upon error encountered. -PassResult FunctionPass::runOnModule(Module *m) { - for (auto &fn : *m) { - // All function passes ignore external functions. - if (fn.isExternal()) - continue; - - if (runOnFunction(&fn)) - return failure(); - } - return success(); -} - /// Forwarding function to execute this pass. PassResult FunctionPassBase::run(Function *fn) { /// Initialize the pass state. diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index fee9d5a3828..24b53220613 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -80,11 +80,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { namespace { /// Simple common sub-expression elimination. -struct CSE : public FunctionPass { - CSE() : FunctionPass(&CSE::passID) {} - - constexpr static PassID passID = {}; - +struct CSE : public FunctionPass { /// Shared implementation of operation elimination and scoped map definitions. 
using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, @@ -115,7 +111,7 @@ struct CSE : public FunctionPass { void simplifyBlock(DominanceInfo &domInfo, Block *bb); void simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList); - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; private: /// A scoped hash table of defining operations within a function. @@ -220,9 +216,9 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { } } -PassResult CSE::runOnFunction(Function *f) { - DominanceInfo domInfo(f); - simplifyBlockList(domInfo, f->getBlockList()); +PassResult CSE::runOnFunction() { + DominanceInfo domInfo(&getFunction()); + simplifyBlockList(domInfo, getFunction().getBlockList()); /// Erase any operations that were marked as dead during simplification. for (auto *op : opsToErase) @@ -232,7 +228,7 @@ PassResult CSE::runOnFunction(Function *f) { return success(); } -FunctionPass *mlir::createCSEPass() { return new CSE(); } +FunctionPassBase *mlir::createCSEPass() { return new CSE(); } static PassRegistration pass("cse", "Eliminate common sub-expressions in functions"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index ac77e201acf..764f055a673 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -33,30 +33,30 @@ using namespace mlir; namespace { /// Canonicalize operations in functions. -struct Canonicalizer : public FunctionPass { - Canonicalizer() : FunctionPass(&Canonicalizer::passID) {} - PassResult runOnFunction(Function *fn) override; - - constexpr static PassID passID = {}; +struct Canonicalizer : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace -PassResult Canonicalizer::runOnFunction(Function *fn) { - auto *context = fn->getContext(); +PassResult Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; + auto &func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - for (auto *op : fn->getContext()->getRegisteredOperations()) + auto *context = func.getContext(); + for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); - applyPatternsGreedily(fn, std::move(patterns)); + applyPatternsGreedily(&func, std::move(patterns)); return success(); } /// Create a Canonicalizer pass. -FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); } +FunctionPassBase *mlir::createCanonicalizerPass() { + return new Canonicalizer(); +} static PassRegistration pass("canonicalize", "Canonicalize operations"); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 4817baaa23e..ed35c03755f 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -26,18 +26,14 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass { - ConstantFold() : FunctionPass(&ConstantFold::passID) {} - +struct ConstantFold : public FunctionPass { // All constants in the function post folding. SmallVector existingConstants; // Operations that were folded and that need to be erased. 
std::vector opInstsToErase; void foldInstruction(Instruction *op); - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -96,11 +92,11 @@ void ConstantFold::foldInstruction(Instruction *op) { // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. -PassResult ConstantFold::runOnFunction(Function *f) { +PassResult ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - f->walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no @@ -122,7 +118,7 @@ PassResult ConstantFold::runOnFunction(Function *f) { } /// Creates a constant folding pass. -FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); } +FunctionPassBase *mlir::createConstantFoldPass() { return new ConstantFold(); } static PassRegistration pass("constant-fold", "Constant fold operations in functions"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 4fb6f34ed53..82ba07acb5f 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -74,18 +74,17 @@ namespace { /// memory capacity provided. // TODO(bondhugula): We currently can't generate DMAs correctly when stores are // strided. Check for strided stores. -struct DmaGeneration : public FunctionPass { +struct DmaGeneration : public FunctionPass { explicit DmaGeneration(unsigned slowMemorySpace = 0, unsigned fastMemorySpace = clFastMemorySpace, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = clFastMemoryCapacity * 1024) - : FunctionPass(&DmaGeneration::passID), slowMemorySpace(slowMemorySpace), - fastMemorySpace(fastMemorySpace), + : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace), minDmaTransferSize(minDmaTransferSize), fastMemCapacityBytes(fastMemCapacityBytes) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; bool runOnBlock(Block *block); uint64_t runOnBlock(Block::iterator begin, Block::iterator end); @@ -115,8 +114,6 @@ struct DmaGeneration : public FunctionPass { // Constant zero index to avoid too many duplicates. Value *zeroIndex = nullptr; - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -125,10 +122,10 @@ struct DmaGeneration : public FunctionPass { /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. 
-FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, - unsigned fastMemorySpace, - int minDmaTransferSize, - uint64_t fastMemCapacityBytes) { +FunctionPassBase *mlir::createDmaGenerationPass(unsigned slowMemorySpace, + unsigned fastMemorySpace, + int minDmaTransferSize, + uint64_t fastMemCapacityBytes) { return new DmaGeneration(slowMemorySpace, fastMemorySpace, minDmaTransferSize, fastMemCapacityBytes); } @@ -757,7 +754,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { return totalDmaBuffersSizeInBytes; } -PassResult DmaGeneration::runOnFunction(Function *f) { +PassResult DmaGeneration::runOnFunction() { + Function *f = &getFunction(); FuncBuilder topBuilder(f); zeroIndex = topBuilder.create(f->getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0f4e45c372a..1528e394506 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -86,14 +86,12 @@ namespace { // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. -struct LoopFusion : public FunctionPass { +struct LoopFusion : public FunctionPass { LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0) - : FunctionPass(&LoopFusion::passID), - localBufSizeThreshold(localBufSizeThreshold), + : localBufSizeThreshold(localBufSizeThreshold), fastMemorySpace(fastMemorySpace) {} - PassResult runOnFunction(Function *f) override; - constexpr static PassID passID = {}; + PassResult runOnFunction() override; // Any local buffers smaller than this size (in bytes) will be created in // `fastMemorySpace` if provided. @@ -107,8 +105,8 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -FunctionPass *mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold) { +FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace, + uint64_t localBufSizeThreshold) { return new LoopFusion(fastMemorySpace, localBufSizeThreshold); } @@ -1802,7 +1800,7 @@ public: } // end anonymous namespace -PassResult LoopFusion::runOnFunction(Function *f) { +PassResult LoopFusion::runOnFunction() { // Override if a command line argument was provided. if (clFusionFastMemorySpace.getNumOccurrences() > 0) { fastMemorySpace = clFusionFastMemorySpace.getValue(); @@ -1814,7 +1812,7 @@ PassResult LoopFusion::runOnFunction(Function *f) { } MemRefDependenceGraph g; - if (g.init(f)) + if (g.init(&getFunction())) GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); return success(); } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 240b2b6d9b6..db0e8d51ad8 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -47,12 +47,10 @@ static llvm::cl::list clTileSizes( namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. -struct LoopTiling : public FunctionPass { - LoopTiling() : FunctionPass(&LoopTiling::passID) {} - PassResult runOnFunction(Function *f) override; +struct LoopTiling : public FunctionPass { + PassResult runOnFunction() override; constexpr static unsigned kDefaultTileSize = 4; - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -65,7 +63,7 @@ static llvm::cl::opt /// Creates a pass to perform loop tiling on all suitable loop nests of an /// Function. 
-FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } +FunctionPassBase *mlir::createLoopTilingPass() { return new LoopTiling(); } // Move the loop body of AffineForOp 'src' from 'src' into the specified // location in destination's body. @@ -255,9 +253,9 @@ getTileableBands(Function *f, getMaximalPerfectLoopNest(forOp); } -PassResult LoopTiling::runOnFunction(Function *f) { +PassResult LoopTiling::runOnFunction() { std::vector, 6>> bands; - getTileableBands(f, &bands); + getTileableBands(&getFunction(), &bands); for (auto &band : bands) { // Set up tile sizes; fill missing tile sizes at the end with default tile diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 3b4a0517f0d..231dba65720 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -65,7 +65,7 @@ namespace { /// full unroll threshold was specified, in which case, fully unrolls all loops /// with trip count less than the specified threshold. The latter is for testing /// purposes, especially for testing outer loop unrolling. -struct LoopUnroll : public FunctionPass { +struct LoopUnroll : public FunctionPass { const Optional unrollFactor; const Optional unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes @@ -76,21 +76,19 @@ struct LoopUnroll : public FunctionPass { Optional unrollFull = None, const std::function)> &getUnrollFactor = nullptr) - : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), - unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} + : unrollFactor(unrollFactor), unrollFull(unrollFull), + getUnrollFactor(getUnrollFactor) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; /// Unroll this for inst. Returns false if nothing was done. bool runOnAffineForOp(OpPointer forOp); static const unsigned kDefaultUnrollFactor = 4; - - constexpr static PassID passID = {}; }; } // end anonymous namespace -PassResult LoopUnroll::runOnFunction(Function *f) { +PassResult LoopUnroll::runOnFunction() { // Gathers all innermost loops through a post order pruned walk. struct InnermostLoopGatherer { // Store innermost loops as we walk. @@ -132,7 +130,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). - f->walkPostOrder([&](OpPointer forOp) { + getFunction().walkPostOrder([&](OpPointer forOp) { Optional tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); @@ -146,9 +144,10 @@ PassResult LoopUnroll::runOnFunction(Function *f) { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. 
+ Function *func = &getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(f); + ilg.walkPostOrder(func); auto &loops = ilg.loops; if (loops.empty()) break; @@ -184,7 +183,7 @@ bool LoopUnroll::runOnAffineForOp(OpPointer forOp) { return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } -FunctionPass *mlir::createLoopUnrollPass( +FunctionPassBase *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function)> &getUnrollFactor) { diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 87e2770aa41..e950d117ddc 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -71,34 +71,30 @@ static llvm::cl::opt namespace { /// Loop unroll jam pass. Currently, this just unroll jams the first /// outer loop in a Function. -struct LoopUnrollAndJam : public FunctionPass { +struct LoopUnrollAndJam : public FunctionPass { Optional unrollJamFactor; static const unsigned kDefaultUnrollJamFactor = 4; explicit LoopUnrollAndJam(Optional unrollJamFactor = None) - : FunctionPass(&LoopUnrollAndJam::passID), - unrollJamFactor(unrollJamFactor) {} + : unrollJamFactor(unrollJamFactor) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; bool runOnAffineForOp(OpPointer forOp); - - constexpr static PassID passID = {}; }; } // end anonymous namespace -FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { +FunctionPassBase *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return new LoopUnrollAndJam( unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); } -PassResult LoopUnrollAndJam::runOnFunction(Function *f) { +PassResult LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on - // any for Inst. - auto &entryBlock = f->front(); - if (!entryBlock.empty()) - if (auto forOp = entryBlock.front().dyn_cast()) - runOnAffineForOp(forOp); + // any for operation. + auto &entryBlock = getFunction().front(); + if (auto forOp = entryBlock.front().dyn_cast()) + runOnAffineForOp(forOp); return success(); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 83620516994..aecd4314d42 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -242,16 +242,12 @@ Optional> static expandAffineMap( } namespace { -class LowerAffinePass : public FunctionPass { -public: - LowerAffinePass() : FunctionPass(&passID) {} - PassResult runOnFunction(Function *function) override; +struct LowerAffinePass : public FunctionPass { + PassResult runOnFunction() override; bool lowerAffineFor(OpPointer forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -608,12 +604,12 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { // construction. When an Value is used, it gets replaced with the // corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. -PassResult LowerAffinePass::runOnFunction(Function *function) { +PassResult LowerAffinePass::runOnFunction() { SmallVector instsToRewrite; // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. 
// We do this as a prepass to avoid invalidating the walker with our rewrite. - function->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa() || inst->isa()) instsToRewrite.push_back(inst); @@ -638,7 +634,9 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { /// Lowers If and For instructions within a function into their lower level CFG /// equivalent blocks. -FunctionPass *mlir::createLowerAffinePass() { return new LowerAffinePass(); } +FunctionPassBase *mlir::createLowerAffinePass() { + return new LowerAffinePass(); +} static PassRegistration pass("lower-affine", diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 61f75ae76e6..ddeb524f5ab 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -424,25 +424,22 @@ public: } }; -struct LowerVectorTransfersPass : public FunctionPass { - LowerVectorTransfersPass() - : FunctionPass(&LowerVectorTransfersPass::passID) {} - - PassResult runOnFunction(Function *fn) override { +struct LowerVectorTransfersPass + : public FunctionPass { + PassResult runOnFunction() { + Function *f = &getFunction(); applyMLPatternsGreedily, - VectorTransferExpander>(fn); + VectorTransferExpander>(f); return success(); } // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. edsc::ScopedEDSCContext raiiContext; - - constexpr static PassID passID = {}; }; } // end anonymous namespace -FunctionPass *mlir::createLowerVectorTransfersPass() { +FunctionPassBase *mlir::createLowerVectorTransfersPass() { return new LowerVectorTransfersPass(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6177ca1233b..7b45af011ab 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -196,12 +196,8 @@ struct MaterializationState { DenseMap *substitutionsMap; }; -struct MaterializeVectorsPass : public FunctionPass { - MaterializeVectorsPass() : FunctionPass(&MaterializeVectorsPass::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct MaterializeVectorsPass : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -733,11 +729,12 @@ static bool materialize(Function *f, return false; } -PassResult MaterializeVectorsPass::runOnFunction(Function *f) { +PassResult MaterializeVectorsPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return success(); @@ -771,7 +768,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { return fail ? PassResult::Failure : PassResult::Success; } -FunctionPass *mlir::createMaterializeVectorsPass() { +FunctionPassBase *mlir::createMaterializeVectorsPass() { return new MaterializeVectorsPass(); } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 0ba06fecae0..067bfa4c94c 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -69,10 +69,8 @@ namespace { // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. 
// -struct MemRefDataFlowOpt : public FunctionPass { - explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} - - PassResult runOnFunction(Function *f) override; +struct MemRefDataFlowOpt : public FunctionPass { + PassResult runOnFunction() override; void forwardStoreToLoad(OpPointer loadOp); @@ -83,15 +81,13 @@ struct MemRefDataFlowOpt : public FunctionPass { DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; - - constexpr static PassID passID = {}; }; } // end anonymous namespace /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -FunctionPass *mlir::createMemRefDataFlowOptPass() { +FunctionPassBase *mlir::createMemRefDataFlowOptPass() { return new MemRefDataFlowOpt(); } @@ -213,22 +209,22 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { loadOpsToErase.push_back(loadOpInst); } -PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { +PassResult MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - if (f->getBlocks().size() != 1) + Function &f = getFunction(); + if (f.getBlocks().size() != 1) return success(); - DominanceInfo theDomInfo(f); + DominanceInfo theDomInfo(&f); domInfo = &theDomInfo; - PostDominanceInfo thePostDomInfo(f); + PostDominanceInfo thePostDomInfo(&f); postDomInfo = &thePostDomInfo; loadOpsToErase.clear(); memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f->walk( - [&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); + f.walk([&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index f41f56efd8f..42e1446211b 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -38,21 +38,18 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass { - PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} - PassResult runOnFunction(Function *f) override; +struct PipelineDataTransfer : public FunctionPass { + PassResult runOnFunction() override; PassResult runOnAffineForOp(OpPointer forOp); std::vector> forOps; - - constexpr static PassID passID = {}; }; } // end anonymous namespace /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -FunctionPass *mlir::createPipelineDataTransferPass() { +FunctionPassBase *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } @@ -142,14 +139,14 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { } /// Returns success if the IR is in a valid state. -PassResult PipelineDataTransfer::runOnFunction(Function *f) { +PassResult PipelineDataTransfer::runOnFunction() { // Do a post order walk so that inner loop DMAs are processed first. This is // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). 
forOps.clear(); - f->walkPostOrder( + getFunction().walkPostOrder( [&](OpPointer forOp) { forOps.push_back(forOp); }); bool ret = false; for (auto forOp : forOps) { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index d0fdcb5527f..4c0fed5b648 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -36,18 +36,14 @@ namespace { /// the Function. This is mainly to test the simplifyAffineExpr method. /// TODO(someone): This should just be defined as a canonicalization pattern /// on AffineMap and driven from the existing canonicalization pass. -struct SimplifyAffineStructures : public FunctionPass { - explicit SimplifyAffineStructures() - : FunctionPass(&SimplifyAffineStructures::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct SimplifyAffineStructures + : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createSimplifyAffineStructuresPass() { +FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { return new SimplifyAffineStructures(); } @@ -61,8 +57,8 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walk([&](Instruction *opInst) { +PassResult SimplifyAffineStructures::runOnFunction() { + getFunction().walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index fc2b0eb0a95..0f1ba02174b 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -23,26 +23,25 @@ using namespace mlir; namespace { -struct StripDebugInfo : public FunctionPass { - StripDebugInfo() : FunctionPass(&StripDebugInfo::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct StripDebugInfo : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace -PassResult StripDebugInfo::runOnFunction(Function *f) { - UnknownLoc unknownLoc = UnknownLoc::get(f->getContext()); +PassResult StripDebugInfo::runOnFunction() { + Function &func = getFunction(); + UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); // Strip the debug info from the function and its instructions. - f->setLoc(unknownLoc); - f->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func.setLoc(unknownLoc); + func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); return success(); } /// Creates a pass to strip debug information from a function. 
-FunctionPass *mlir::createStripDebugInfoPass() { return new StripDebugInfo(); } +FunctionPassBase *mlir::createStripDebugInfoPass() { + return new StripDebugInfo(); +} static PassRegistration pass("strip-debuginfo", "Strip debug info from functions and instructions"); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 2363c5638ee..60e58c42e6b 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -83,20 +83,17 @@ static llvm::cl::opt clTestNormalizeMaps( namespace { -struct VectorizerTestPass : public FunctionPass { +struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapOpName = "test_affine_map"; static constexpr auto kTestAffineMapAttrName = "affine_map"; - VectorizerTestPass() : FunctionPass(&VectorizerTestPass::passID) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; void testVectorShapeRatio(Function *f); void testForwardSlicing(Function *f); void testBackwardSlicing(Function *f); void testSlicing(Function *f); void testComposeMaps(Function *f); void testNormalizeMaps(Function *f); - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -263,11 +260,12 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { } } -PassResult VectorizerTestPass::runOnFunction(Function *f) { +PassResult VectorizerTestPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // Only support single block functions at this point. + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return success(); @@ -292,7 +290,7 @@ PassResult VectorizerTestPass::runOnFunction(Function *f) { return PassResult::Success; } -FunctionPass *mlir::createVectorizerTestPass() { +FunctionPassBase *mlir::createVectorizerTestPass() { return new VectorizerTestPass(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5722b9d17da..8a378a29c84 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -651,12 +651,8 @@ static std::vector makePatterns() { namespace { -struct Vectorize : public FunctionPass { - Vectorize() : FunctionPass(&Vectorize::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct Vectorize : public FunctionPass { + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -1264,10 +1260,11 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. -PassResult Vectorize::runOnFunction(Function *f) { +PassResult Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. 
NestedPatternContext mlContext; + Function *f = &getFunction(); for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); @@ -1301,7 +1298,7 @@ PassResult Vectorize::runOnFunction(Function *f) { return PassResult::Success; } -FunctionPass *mlir::createVectorizePass() { return new Vectorize(); } +FunctionPassBase *mlir::createVectorizePass() { return new Vectorize(); } static PassRegistration pass("vectorize", diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 14e21770e25..30fae94139f 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -73,18 +73,15 @@ void mlir::Function::viewGraph() const { } namespace { -struct PrintCFGPass : public FunctionPass { +struct PrintCFGPass : public FunctionPass { PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, const llvm::Twine &title = "") - : FunctionPass(&PrintCFGPass::passID), os(os), shortNames(shortNames), - title(title) {} - PassResult runOnFunction(Function *function) override { - mlir::writeGraph(os, function, shortNames, title); + : os(os), shortNames(shortNames), title(title) {} + PassResult runOnFunction() { + mlir::writeGraph(os, &getFunction(), shortNames, title); return success(); } - constexpr static PassID passID = {}; - private: llvm::raw_ostream &os; bool shortNames; @@ -92,9 +89,9 @@ private: }; } // namespace -FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, - bool shortNames, - const llvm::Twine &title) { +FunctionPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, + bool shortNames, + const llvm::Twine &title) { return new PrintCFGPass(os, shortNames, title); } diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 4c0c8fdc296..4a2b4e7489f 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -28,7 +28,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "mlir/TensorFlow/ControlFlowOps.h" #include "mlir/TensorFlow/Passes.h" @@ -125,16 +125,13 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { return OptFailure; // Run each of the passes that were selected. - for (const auto *passInfo : *passList) { - std::unique_ptr pass(passInfo->createPass()); - PassResult result = pass->runOnModule(module.get()); - if (result) - return OptFailure; - - // Verify that the result of the pass is still valid. - if (module->verify()) - return OptFailure; - } + // TODO(riverriddle) Make sure that the verifer is run after each pass when it + // is no longer run by default within the PassManager. + PassManager pm; + for (const auto *passInfo : *passList) + pm.addPass(passInfo->createPass()); + if (pm.run(module.get())) + return OptFailure; std::string errorMessage; auto output = openOutputFile(outputFilename, &errorMessage); -- cgit v1.2.3 From ed5fe2098be12d839cb4384e59a93f15f6f42e58 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 28 Feb 2019 14:50:42 -0800 Subject: Remove PassResult and have the runOnFunction/runOnModule functions return void instead. To signal a pass failure, passes should now invoke the 'signalPassFailure' method. 
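For illustration, a minimal sketch of the new style (the pass name and the
failure condition below are hypothetical, not part of this change):

    struct MyPass : public FunctionPass<MyPass> {
      void runOnFunction() override {
        // Transform getFunction() in place. There is no PassResult to
        // return; an unrecoverable error is reported by marking the pass
        // as failed, which PassManager::run() surfaces as a false return.
        if (brokeSomeInvariant())
          return signalPassFailure();
      }
    };
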
This provides the equivalent functionality when needed, but isn't an intrusive part of the API like PassResult. PiperOrigin-RevId: 236202029 --- mlir/include/mlir/Pass/Pass.h | 57 ++++++++++++---------- mlir/include/mlir/Pass/PassManager.h | 4 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 5 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 5 +- mlir/lib/Analysis/OpStats.cpp | 5 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 5 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 2 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 5 +- mlir/lib/Pass/Pass.cpp | 52 ++++++++++++-------- mlir/lib/Transforms/CSE.cpp | 6 +-- mlir/lib/Transforms/Canonicalizer.cpp | 5 +- mlir/lib/Transforms/ConstantFold.cpp | 6 +-- mlir/lib/Transforms/DmaGeneration.cpp | 9 ++-- mlir/lib/Transforms/LoopFusion.cpp | 5 +- mlir/lib/Transforms/LoopTiling.cpp | 10 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 7 ++- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 6 +-- mlir/lib/Transforms/LowerAffine.cpp | 12 ++--- mlir/lib/Transforms/LowerVectorTransfers.cpp | 3 +- mlir/lib/Transforms/MaterializeVectors.cpp | 10 ++-- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 9 ++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 30 +++++------- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 6 +-- mlir/lib/Transforms/StripDebugInfo.cpp | 5 +- .../Vectorization/VectorizerTestPass.cpp | 7 ++- mlir/lib/Transforms/Vectorize.cpp | 5 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 3 +- mlir/tools/mlir-opt/mlir-opt.cpp | 2 +- 28 files changed, 135 insertions(+), 151 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 26b76e317eb..6a1e2791ed4 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -18,6 +18,7 @@ #ifndef MLIR_PASS_PASS_H #define MLIR_PASS_PASS_H +#include "mlir/IR/Module.h" #include "mlir/Pass/PassRegistry.h" #include "llvm/ADT/PointerIntPair.h" @@ -25,15 +26,6 @@ namespace mlir { class Function; class Module; -// Values that can be used by to signal success/failure. This can be implicitly -// converted to/from boolean values, with false representing success and true -// failure. -struct LLVM_NODISCARD PassResult { - enum ResultEnum { Success, Failure } value; - PassResult(ResultEnum v) : value(v) {} - operator bool() const { return value == Failure; } -}; - /// The abstract base pass class. This class contains information describing the /// derived pass object, e.g its kind and abstract PassInfo. class Pass { @@ -45,9 +37,6 @@ public: /// Returns the unique identifier that corresponds to this pass. const PassID *getPassID() const { return passIDAndKind.getPointer(); } - static PassResult success() { return PassResult::Success; } - static PassResult failure() { return PassResult::Failure; } - /// Returns the pass info for the specified pass class or null if unknown. static const PassInfo *lookupPassInfo(const PassID *passID); template static const PassInfo *lookupPassInfo() { @@ -79,10 +68,10 @@ class ModulePassExecutor; /// The state for a single execution of a pass. This provides a unified /// interface for accessing and initializing necessary state for pass execution. template struct PassExecutionState { - explicit PassExecutionState(IRUnitT *ir) : ir(ir) {} + explicit PassExecutionState(IRUnitT *ir) : irAndPassFailed(ir, false) {} /// The current IR unit being transformed. 
- IRUnitT *ir; + llvm::PointerIntPair irAndPassFailed; }; } // namespace detail @@ -99,17 +88,24 @@ protected: explicit FunctionPassBase(const PassID *id) : Pass(id, Kind::FunctionPass) {} /// The polymorphic API that runs the pass over the currently held function. - virtual PassResult runOnFunction() = 0; + virtual void runOnFunction() = 0; /// Return the current function being transformed. Function &getFunction() { + return *getPassState().irAndPassFailed.getPointer(); + } + + /// Returns the current pass state. + detail::PassExecutionState &getPassState() { assert(passState && "pass state was never initialized"); - return *passState->ir; + return *passState; } private: - /// Forwarding function to execute this pass. - PassResult run(Function *fn); + /// Forwarding function to execute this pass. Returns false if the pass + /// execution failed, true otherwise. + LLVM_NODISCARD + bool run(Function *fn); /// The current execution state for the pass. llvm::Optional> passState; @@ -130,17 +126,22 @@ protected: explicit ModulePassBase(const PassID *id) : Pass(id, Kind::ModulePass) {} /// The polymorphic API that runs the pass over the currently held module. - virtual PassResult runOnModule() = 0; + virtual void runOnModule() = 0; /// Return the current module being transformed. - Module &getModule() { + Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); } + + /// Returns the current pass state. + detail::PassExecutionState &getPassState() { assert(passState && "pass state was never initialized"); - return *passState->ir; + return *passState; } private: - /// Forwarding function to execute this pass. - PassResult run(Module *module); + /// Forwarding function to execute this pass. Returns false if the pass + /// execution failed, true otherwise. + LLVM_NODISCARD + bool run(Module *module); /// The current execution state for the pass. llvm::Optional> passState; @@ -162,6 +163,12 @@ protected: /// TODO(riverriddle) Provide additional utilities for cloning, getting the /// derived class name, etc.. + + /// Signal that some invariant was broken when running. The IR is allowed to + /// be in an invalid state. + void signalPassFailure() { + this->getPassState().irAndPassFailed.setInt(true); + } }; } // end namespace detail @@ -174,14 +181,14 @@ protected: /// additional functions. /// /// Derived function passes are expected to provide the following: -/// - A 'PassResult runOnFunction()' method. +/// - A 'void runOnFunction()' method. template using FunctionPass = detail::PassModel; /// A model for providing module pass specific utilities. /// /// Derived module passes are expected to provide the following: -/// - A 'PassResult runOnModule()' method. +/// - A 'void runOnModule()' method. template using ModulePass = detail::PassModel; } // end namespace mlir diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 981d860888d..2c00e3dd902 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -50,7 +50,9 @@ public: /// executor if necessary. void addPass(FunctionPassBase *pass); - /// Run the passes within this manager on the provided module. + /// Run the passes within this manager on the provided module. Returns false + /// if the run failed, true otherwise. 
+ LLVM_NODISCARD bool run(Module *module); private: diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index a6730f01199..d709566c322 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -38,7 +38,7 @@ namespace { /// Checks for out of bound memef access subscripts.. struct MemRefBoundCheck : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -PassResult MemRefBoundCheck::runOnFunction() { +void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); @@ -56,7 +56,6 @@ PassResult MemRefBoundCheck::runOnFunction() { } // TODO(bondhugula): do this for DMA ops as well. }); - return success(); } static PassRegistration diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 33488f0c7a8..d0074dad7f2 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -39,7 +39,7 @@ namespace { /// Checks dependences between all pairs of memref accesses in a Function. struct MemRefDependenceCheck : public FunctionPass { SmallVector loadsAndStores; - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -111,7 +111,7 @@ static void checkDependences(ArrayRef loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. -PassResult MemRefDependenceCheck::runOnFunction() { +void MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); getFunction().walk([&](Instruction *inst) { @@ -120,7 +120,6 @@ PassResult MemRefDependenceCheck::runOnFunction() { }); checkDependences(loadsAndStores); - return success(); } static PassRegistration diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index a17be9d176b..3fdbc9c1d46 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -30,7 +30,7 @@ struct PrintOpStatsPass : public ModulePass { explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. - PassResult runOnModule() override; + void runOnModule() override; // Print summary of op stats. void printSummary(); @@ -41,7 +41,7 @@ private: }; } // namespace -PassResult PrintOpStatsPass::runOnModule() { +void PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. @@ -49,7 +49,6 @@ PassResult PrintOpStatsPass::runOnModule() { fn.walk( [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); - return success(); } void PrintOpStatsPass::printSummary() { diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 2b6c38bf8c6..db30eef17bf 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -34,13 +34,13 @@ using namespace mlir; namespace { // Testing pass to lower EDSC. 
struct LowerEDSCTestPass : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace #include "mlir/EDSC/reference-impl.inc" -PassResult LowerEDSCTestPass::runOnFunction() { +void LowerEDSCTestPass::runOnFunction() { getFunction().walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); @@ -56,7 +56,6 @@ PassResult LowerEDSCTestPass::runOnFunction() { printRefImplementation(opName.getValue(), function.getValue()); } }); - return success(); } static PassRegistration pass("lower-edsc-test", diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index f0835da77d4..d47c0832771 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -284,7 +284,7 @@ Expected> ExecutionEngine::create( // Construct and run the default MLIR pipeline. PassManager manager; getDefaultPasses(manager, {}); - if (manager.run(m)) + if (!manager.run(m)) return make_string_error("passes failed"); auto llvmModule = translateModuleToLLVMIR(*m); diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 1994d4da233..3ec934e329a 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -1097,10 +1097,11 @@ static void ensureDistinctSuccessors(Module *m) { /// dialect. class LLVMLowering : public ModulePass, public DialectConversion { public: - PassResult runOnModule() override { + void runOnModule() override { Module *m = &getModule(); uniqueSuccessorsWithArguments(m); - return DialectConversion::convert(m) ? failure() : success(); + if (DialectConversion::convert(m)) + signalPassFailure(); } protected: diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 33f3048321f..f3ac6def20f 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -35,21 +35,28 @@ using namespace mlir::detail; void Pass::anchor() {} /// Forwarding function to execute this pass. -PassResult FunctionPassBase::run(Function *fn) { +bool FunctionPassBase::run(Function *fn) { /// Initialize the pass state. passState.emplace(fn); /// Invoke the virtual runOnFunction function. - return runOnFunction(); + runOnFunction(); + + // Return false if the pass signaled a failure. + return !passState->irAndPassFailed.getInt(); } -/// Forwarding function to execute this pass. -PassResult ModulePassBase::run(Module *module) { +/// Forwarding function to execute this pass. Returns false if the pass +/// execution failed, true otherwise. +bool ModulePassBase::run(Module *module) { /// Initialize the pass state. passState.emplace(module); /// Invoke the virtual runOnModule function. - return runOnModule(); + runOnModule(); + + // Return false if the pass signaled a failure. + return !passState->irAndPassFailed.getInt(); } //===----------------------------------------------------------------------===// @@ -82,7 +89,9 @@ public: FunctionPassExecutor(const FunctionPassExecutor &) = delete; FunctionPassExecutor &operator=(const FunctionPassExecutor &) = delete; - /// Run the executor on the given function. + /// Run the executor on the given function. Returns false if the pass + /// execution failed, true otherwise. + LLVM_NODISCARD bool run(Function *function); /// Add a pass to the current executor. 
This takes ownership over the provided @@ -107,7 +116,9 @@ public: ModulePassExecutor(const ModulePassExecutor &) = delete; ModulePassExecutor &operator=(const ModulePassExecutor &) = delete; - /// Run the executor on the given module. + /// Run the executor on the given module. Returns false if the pass + /// execution failed, true otherwise. + LLVM_NODISCARD bool run(Module *module); /// Add a pass to the current executor. This takes ownership over the provided @@ -129,25 +140,25 @@ private: bool detail::FunctionPassExecutor::run(Function *function) { for (auto &pass : passes) { /// Create an execution state for this pass. - if (pass->run(function)) - return true; + if (!pass->run(function)) + return false; // TODO: This should be opt-out and handled separately. if (function->verify()) - return true; + return false; } - return false; + return true; } /// Run all of the passes in this manager over the current module. bool detail::ModulePassExecutor::run(Module *module) { for (auto &pass : passes) { - if (pass->run(module)) - return true; + if (!pass->run(module)) + return false; // TODO: This should be opt-out and handled separately. if (module->verify()) - return true; + return false; } - return false; + return true; } //===----------------------------------------------------------------------===// @@ -168,9 +179,9 @@ public: ModuleToFunctionPassAdaptor & operator=(const ModuleToFunctionPassAdaptor &) = delete; - /// run the held function pipeline over all non-external functions within the + /// Run the held function pipeline over all non-external functions within the /// module. - PassResult runOnModule() override; + void runOnModule() override; /// Returns the function pass executor for this adaptor. FunctionPassExecutor &getFunctionExecutor() { return fpe; } @@ -182,17 +193,16 @@ private: /// Execute the held function pass over all non-external functions within the /// module. -PassResult ModuleToFunctionPassAdaptor::runOnModule() { +void ModuleToFunctionPassAdaptor::runOnModule() { for (auto &func : getModule()) { // Skip external functions. if (func.isExternal()) continue; // Run the held function pipeline over the current function. - if (fpe.run(&func)) - return failure(); + if (!fpe.run(&func)) + return signalPassFailure(); } - return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 24b53220613..f4b51c3edd7 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -111,7 +111,7 @@ struct CSE : public FunctionPass { void simplifyBlock(DominanceInfo &domInfo, Block *bb); void simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList); - PassResult runOnFunction() override; + void runOnFunction() override; private: /// A scoped hash table of defining operations within a function. 
@@ -216,7 +216,7 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { } } -PassResult CSE::runOnFunction() { +void CSE::runOnFunction() { DominanceInfo domInfo(&getFunction()); simplifyBlockList(domInfo, getFunction().getBlockList()); @@ -224,8 +224,6 @@ PassResult CSE::runOnFunction() { for (auto *op : opsToErase) op->erase(); opsToErase.clear(); - - return success(); } FunctionPassBase *mlir::createCSEPass() { return new CSE(); } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 764f055a673..17259bb19da 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -34,11 +34,11 @@ namespace { /// Canonicalize operations in functions. struct Canonicalizer : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace -PassResult Canonicalizer::runOnFunction() { +void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; auto &func = getFunction(); @@ -50,7 +50,6 @@ PassResult Canonicalizer::runOnFunction() { op->getCanonicalizationPatterns(patterns, context); applyPatternsGreedily(&func, std::move(patterns)); - return success(); } /// Create a Canonicalizer pass. diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index ed35c03755f..6274d7dc857 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -33,7 +33,7 @@ struct ConstantFold : public FunctionPass { std::vector opInstsToErase; void foldInstruction(Instruction *op); - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -92,7 +92,7 @@ void ConstantFold::foldInstruction(Instruction *op) { // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. -PassResult ConstantFold::runOnFunction() { +void ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); @@ -113,8 +113,6 @@ PassResult ConstantFold::runOnFunction() { if (cst->use_empty()) cst->getDefiningInst()->erase(); } - - return success(); } /// Creates a constant folding pass. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 82ba07acb5f..53bc56173d2 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -84,7 +84,7 @@ struct DmaGeneration : public FunctionPass { minDmaTransferSize(minDmaTransferSize), fastMemCapacityBytes(fastMemCapacityBytes) {} - PassResult runOnFunction() override; + void runOnFunction() override; bool runOnBlock(Block *block); uint64_t runOnBlock(Block::iterator begin, Block::iterator end); @@ -754,16 +754,13 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { return totalDmaBuffersSizeInBytes; } -PassResult DmaGeneration::runOnFunction() { +void DmaGeneration::runOnFunction() { Function *f = &getFunction(); FuncBuilder topBuilder(f); zeroIndex = topBuilder.create(f->getLoc(), 0); - for (auto &block : *f) { + for (auto &block : *f) runOnBlock(&block); - } - // This function never leaves the IR in an invalid state. 
- return success(); } static PassRegistration diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1528e394506..61d13325d13 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -91,7 +91,7 @@ struct LoopFusion : public FunctionPass { : localBufSizeThreshold(localBufSizeThreshold), fastMemorySpace(fastMemorySpace) {} - PassResult runOnFunction() override; + void runOnFunction() override; // Any local buffers smaller than this size (in bytes) will be created in // `fastMemorySpace` if provided. @@ -1800,7 +1800,7 @@ public: } // end anonymous namespace -PassResult LoopFusion::runOnFunction() { +void LoopFusion::runOnFunction() { // Override if a command line argument was provided. if (clFusionFastMemorySpace.getNumOccurrences() > 0) { fastMemorySpace = clFusionFastMemorySpace.getValue(); @@ -1814,7 +1814,6 @@ PassResult LoopFusion::runOnFunction() { MemRefDependenceGraph g; if (g.init(&getFunction())) GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); - return success(); } static PassRegistration pass("loop-fusion", "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index db0e8d51ad8..4aebbc2e856 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -48,7 +48,7 @@ namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. struct LoopTiling : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; constexpr static unsigned kDefaultTileSize = 4; }; @@ -253,7 +253,7 @@ getTileableBands(Function *f, getMaximalPerfectLoopNest(forOp); } -PassResult LoopTiling::runOnFunction() { +void LoopTiling::runOnFunction() { std::vector, 6>> bands; getTileableBands(&getFunction(), &bands); @@ -265,11 +265,9 @@ PassResult LoopTiling::runOnFunction() { clTileSizes.begin() + std::min(clTileSizes.size(), band.size()), tileSizes.begin()); - if (tileCodeGen(band, tileSizes)) { - return failure(); - } + if (tileCodeGen(band, tileSizes)) + return signalPassFailure(); } - return success(); } static PassRegistration pass("loop-tile", "Tile loop nests"); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 231dba65720..2bf78ae258e 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -79,7 +79,7 @@ struct LoopUnroll : public FunctionPass { : unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} - PassResult runOnFunction() override; + void runOnFunction() override; /// Unroll this for inst. Returns false if nothing was done. bool runOnAffineForOp(OpPointer forOp); @@ -88,7 +88,7 @@ struct LoopUnroll : public FunctionPass { }; } // end anonymous namespace -PassResult LoopUnroll::runOnFunction() { +void LoopUnroll::runOnFunction() { // Gathers all innermost loops through a post order pruned walk. struct InnermostLoopGatherer { // Store innermost loops as we walk. @@ -137,7 +137,7 @@ PassResult LoopUnroll::runOnFunction() { }); for (auto forOp : loops) loopUnrollFull(forOp); - return success(); + return; } unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 @@ -158,7 +158,6 @@ PassResult LoopUnroll::runOnFunction() { // Break out if nothing was unrolled. break; } - return success(); } /// Unrolls a 'for' inst. 
Returns true if the loop was unrolled, false diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index e950d117ddc..87259497cef 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -78,7 +78,7 @@ struct LoopUnrollAndJam : public FunctionPass { explicit LoopUnrollAndJam(Optional unrollJamFactor = None) : unrollJamFactor(unrollJamFactor) {} - PassResult runOnFunction() override; + void runOnFunction() override; bool runOnAffineForOp(OpPointer forOp); }; } // end anonymous namespace @@ -88,15 +88,13 @@ FunctionPassBase *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); } -PassResult LoopUnrollAndJam::runOnFunction() { +void LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. auto &entryBlock = getFunction().front(); if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); - - return success(); } /// Unroll and jam a 'for' inst. Default unroll jam factor is diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index aecd4314d42..1070c10a2d4 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -243,7 +243,7 @@ Optional> static expandAffineMap( namespace { struct LowerAffinePass : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; bool lowerAffineFor(OpPointer forOp); bool lowerAffineIf(AffineIfOp *ifOp); @@ -604,7 +604,7 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { // construction. When an Value is used, it gets replaced with the // corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. -PassResult LowerAffinePass::runOnFunction() { +void LowerAffinePass::runOnFunction() { SmallVector instsToRewrite; // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. @@ -620,16 +620,14 @@ PassResult LowerAffinePass::runOnFunction() { for (auto *inst : instsToRewrite) { if (auto ifOp = inst->dyn_cast()) { if (lowerAffineIf(ifOp)) - return failure(); + return signalPassFailure(); } else if (auto forOp = inst->dyn_cast()) { if (lowerAffineFor(forOp)) - return failure(); + return signalPassFailure(); } else if (lowerAffineApply(inst->cast())) { - return failure(); + return signalPassFailure(); } } - - return success(); } /// Lowers If and For instructions within a function into their lower level CFG diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index ddeb524f5ab..261c360631f 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -426,11 +426,10 @@ public: struct LowerVectorTransfersPass : public FunctionPass { - PassResult runOnFunction() { + void runOnFunction() { Function *f = &getFunction(); applyMLPatternsGreedily, VectorTransferExpander>(f); - return success(); } // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. 
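On the driver side, run() now returns true on success (see the ExecutionEngine and mlir-opt hunks in this same change), so callers negate the result. A minimal, hedged sketch, assuming a Module *module is already in scope and that the pipeline has been populated:

PassManager pm;
// ... add the desired passes to 'pm' ...
if (!pm.run(module)) // 'false' now means some pass signaled failure.
  llvm::errs() << "pass pipeline failed\n";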
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 7b45af011ab..c41c75bb88f 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -197,7 +197,7 @@ struct MaterializationState { }; struct MaterializeVectorsPass : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -729,14 +729,14 @@ static bool materialize(Function *f, return false; } -PassResult MaterializeVectorsPass::runOnFunction() { +void MaterializeVectorsPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. Function *f = &getFunction(); if (f->getBlocks().size() != 1) - return success(); + return; using matcher::Op; LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); @@ -764,8 +764,8 @@ PassResult MaterializeVectorsPass::runOnFunction() { terminators.insert(m.getMatchedInstruction()); } - auto fail = materialize(f, terminators, &state); - return fail ? PassResult::Failure : PassResult::Success; + if (materialize(f, terminators, &state)) + signalPassFailure(); } FunctionPassBase *mlir::createMaterializeVectorsPass() { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 067bfa4c94c..55837f95d14 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -70,7 +70,7 @@ namespace { // than dealloc) remain. // struct MemRefDataFlowOpt : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; void forwardStoreToLoad(OpPointer loadOp); @@ -209,11 +209,11 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { loadOpsToErase.push_back(loadOpInst); } -PassResult MemRefDataFlowOpt::runOnFunction() { +void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. Function &f = getFunction(); if (f.getBlocks().size() != 1) - return success(); + return; DominanceInfo theDomInfo(&f); domInfo = &theDomInfo; @@ -254,9 +254,6 @@ PassResult MemRefDataFlowOpt::runOnFunction() { use.getOwner()->erase(); defInst->erase(); } - - // This function never leaves the IR in an invalid state. - return success(); } static PassRegistration diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 19d31fd9f26..9df1af9767f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -39,8 +39,8 @@ using namespace mlir; namespace { struct PipelineDataTransfer : public FunctionPass { - PassResult runOnFunction() override; - PassResult runOnAffineForOp(OpPointer forOp); + void runOnFunction() override; + void runOnAffineForOp(OpPointer forOp); std::vector> forOps; }; @@ -139,7 +139,7 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { } /// Returns success if the IR is in a valid state. -PassResult PipelineDataTransfer::runOnFunction() { +void PipelineDataTransfer::runOnFunction() { // Do a post order walk so that inner loop DMAs are processed first. 
This is // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets @@ -148,11 +148,8 @@ PassResult PipelineDataTransfer::runOnFunction() { forOps.clear(); getFunction().walkPostOrder( [&](OpPointer forOp) { forOps.push_back(forOp); }); - bool ret = false; - for (auto forOp : forOps) { - ret = ret | runOnAffineForOp(forOp); - } - return ret ? failure() : success(); + for (auto forOp : forOps) + runOnAffineForOp(forOp); } // Check if tags of the dma start op and dma wait op match. @@ -252,13 +249,12 @@ static void findMatchingStartFinishInsts( /// Overlap DMA transfers with computation in this loop. If successful, /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult -PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { +void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG( forOp->emitNote("won't pipeline due to unknown trip count loop")); - return success(); + return; } SmallVector, 4> startWaitPairs; @@ -266,7 +262,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { if (startWaitPairs.empty()) { LLVM_DEBUG(forOp->emitNote("No dma start/finish pairs\n")); - return success(); + return; } // Double the buffers for the higher memory space memref's. @@ -287,7 +283,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); LLVM_DEBUG(dmaStartInst->dump()); // IR still in a valid state. - return success(); + return; } // If the old memref has no more uses, remove its 'dead' alloc if it was // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim' @@ -315,7 +311,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); - return success(); + return; } // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. @@ -377,15 +373,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { if (!isInstwiseShiftValid(forOp, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); - return success(); + return; } if (instBodySkew(forOp, shifts)) { LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); - return success(); + return; } - - return success(); } static PassRegistration pass( diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 4c0fed5b648..3adcbe038ea 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -38,7 +38,7 @@ namespace { /// on AffineMap and driven from the existing canonicalization pass. 
struct SimplifyAffineStructures : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -57,7 +57,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction() { +void SimplifyAffineStructures::runOnFunction() { getFunction().walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { @@ -71,8 +71,6 @@ PassResult SimplifyAffineStructures::runOnFunction() { } } }); - - return success(); } static PassRegistration diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 0f1ba02174b..47244f94ac9 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -24,18 +24,17 @@ using namespace mlir; namespace { struct StripDebugInfo : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace -PassResult StripDebugInfo::runOnFunction() { +void StripDebugInfo::runOnFunction() { Function &func = getFunction(); UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); // Strip the debug info from the function and its instructions. func.setLoc(unknownLoc); func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); - return success(); } /// Creates a pass to strip debug information from a function. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 60e58c42e6b..c254790dbe7 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -87,7 +87,7 @@ struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapOpName = "test_affine_map"; static constexpr auto kTestAffineMapAttrName = "affine_map"; - PassResult runOnFunction() override; + void runOnFunction() override; void testVectorShapeRatio(Function *f); void testForwardSlicing(Function *f); void testBackwardSlicing(Function *f); @@ -260,14 +260,14 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { } } -PassResult VectorizerTestPass::runOnFunction() { +void VectorizerTestPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // Only support single block functions at this point. Function *f = &getFunction(); if (f->getBlocks().size() != 1) - return success(); + return; if (!clTestVectorShapeRatio.empty()) { testVectorShapeRatio(f); @@ -287,7 +287,6 @@ PassResult VectorizerTestPass::runOnFunction() { if (clTestNormalizeMaps) { testNormalizeMaps(f); } - return PassResult::Success; } FunctionPassBase *mlir::createVectorizerTestPass() { diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 8a378a29c84..50c6cdad0f9 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -652,7 +652,7 @@ static std::vector makePatterns() { namespace { struct Vectorize : public FunctionPass { - PassResult runOnFunction() override; + void runOnFunction() override; }; } // end anonymous namespace @@ -1260,7 +1260,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. 
-PassResult Vectorize::runOnFunction() { +void Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; @@ -1295,7 +1295,6 @@ PassResult Vectorize::runOnFunction() { } } LLVM_DEBUG(dbgs() << "\n"); - return PassResult::Success; } FunctionPassBase *mlir::createVectorizePass() { return new Vectorize(); } diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 30fae94139f..b2dfe6795b6 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -77,9 +77,8 @@ struct PrintCFGPass : public FunctionPass { PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, const llvm::Twine &title = "") : os(os), shortNames(shortNames), title(title) {} - PassResult runOnFunction() { + void runOnFunction() { mlir::writeGraph(os, &getFunction(), shortNames, title); - return success(); } private: diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 96c85b23719..2b9af3debc1 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -131,7 +131,7 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { PassManager pm; for (const auto *passEntry : *passList) passEntry->addToPipeline(pm); - if (pm.run(module.get())) + if (!pm.run(module.get())) return OptFailure; std::string errorMessage; -- cgit v1.2.3 From d038e3473522418cf11adf796af4596028a9fe67 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 1 Mar 2019 11:50:25 -0800 Subject: Loop fusion for input reuse. *) Breaks fusion pass into multiple sub passes over nodes in data dependence graph: - first pass fuses single-use producers into their unique consumer. - second pass enables fusing for input-reuse by fusing sibling nodes which read from the same memref, but which do not share dependence edges. - third pass fuses remaining producers into their consumers (Note that the sibling fusion pass may have transformed a producer with multiple uses into a single-use producer). *) Fusion for input reuse is enabled by computing a sibling node slice using the load/load accesses to the same memref, and fusion safety is guaranteed by checking that the sibling node memref write region (to a different memref) is preserved. *) Enables output vector and output matrix computations from KFAC patches-second-moment operation to fuse into a single loop nest and reuse input from the image patches operation. *) Adds a generic loop utilitiy for finding all sequential loops in a loop nest. *) Adds and updates unit tests. PiperOrigin-RevId: 236350987 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 5 +- mlir/include/mlir/Analysis/Utils.h | 5 + mlir/lib/Analysis/AffineAnalysis.cpp | 10 +- mlir/lib/Analysis/AffineStructures.cpp | 11 +- mlir/lib/Analysis/Utils.cpp | 70 ++++- mlir/lib/Transforms/LoopFusion.cpp | 458 +++++++++++++++++++++++++--- mlir/test/Transforms/loop-fusion.mlir | 151 ++++++++- 7 files changed, 651 insertions(+), 59 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 9d3887ddb70..51f0a9aaec1 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -94,13 +94,16 @@ struct DependenceComponent { /// the operation instruction, indices and memref associated with the access. 
/// Returns 'false' if it can be determined conclusively that the accesses do /// not access the same memref element. Returns 'true' otherwise. +/// If 'allowRAR' is true, will consider read-after-read dependences (typically +/// used by applications trying to optimize input reuse). // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into // a single struct. // TODO(andydavis) Make 'dependenceConstraints' optional arg. bool checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, - llvm::SmallVector *dependenceComponents); + llvm::SmallVector *dependenceComponents, + bool allowRAR = false); } // end namespace mlir #endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index dff0f57aacc..4baa1017352 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -62,6 +62,11 @@ void getLoopIVs(const Instruction &inst, /// surrounding this instruction. unsigned getNestingDepth(const Instruction &stmt); +/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted +/// at 'forOp'. +void getSequentialLoops(OpPointer forOp, + llvm::SmallDenseSet *sequentialLoops); + /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their /// associated operands for a set of loops within a loop nest (typically the /// set of loops surrounding a store operation). Loop bound AffineMaps which diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 591ed1f70a2..1a52b839343 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -768,7 +768,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { bool mlir::checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, - llvm::SmallVector *dependenceComponents) { + llvm::SmallVector *dependenceComponents, + bool allowRAR) { LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: " << Twine(loopDepth) << " between:\n";); LLVM_DEBUG(srcAccess.opInst->dump();); @@ -778,7 +779,8 @@ bool mlir::checkMemrefAccessDependence( if (srcAccess.memref != dstAccess.memref) return false; // Return 'false' if one of these accesses is not a StoreOp. - if (!srcAccess.opInst->isa() && !dstAccess.opInst->isa()) + if (!allowRAR && !srcAccess.opInst->isa() && + !dstAccess.opInst->isa()) return false; // Get composed access function for 'srcAccess'. @@ -802,9 +804,11 @@ bool mlir::checkMemrefAccessDependence( // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation // instruction of 'srcAccess' does not properly dominate the ancestor // operation instruction of 'dstAccess' in the same common instruction block. + // Note: this check is skipped if 'allowRAR' is true, because RAR + // deps can exist irrespective of lexicographic ordering b/w src and dst.
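A hedged usage sketch of the two analysis entry points added in this commit. 'loadA', 'loadB', and 'rootForOp' are placeholders rather than values from the patch, and the SmallDenseSet small-size parameter is assumed:

#include "mlir/Analysis/AffineAnalysis.h" // checkMemrefAccessDependence
#include "mlir/Analysis/Utils.h"          // getSequentialLoops

using namespace mlir;

// Read-after-read query: do two loads touch the same elements of a memref?
static bool reusesSameInput(Instruction *loadA, Instruction *loadB) {
  MemRefAccess srcAccess(loadA), dstAccess(loadB);
  FlatAffineConstraints dependenceConstraints;
  return checkMemrefAccessDependence(
      srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
      /*dependenceComponents=*/nullptr, /*allowRAR=*/true);
}

// Sequential-loop query: does any loop in the nest rooted at 'rootForOp'
// carry a dependence (and therefore limit how a slice may be placed)?
static bool nestIsFullyParallel(OpPointer<AffineForOp> rootForOp) {
  llvm::SmallDenseSet<Value *, 8> sequentialLoops; // small size '8' assumed
  getSequentialLoops(rootForOp, &sequentialLoops);
  return sequentialLoops.empty();
}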
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); assert(loopDepth <= numCommonLoops + 1); - if (loopDepth > numCommonLoops && + if (!allowRAR && loopDepth > numCommonLoops && !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain, numCommonLoops)) { return false; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 0fc37e6c7a2..26d72f55a7b 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1669,19 +1669,20 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef values, }; for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { - assert(lbMaps[i].getNumInputs() == operands.size()); - assert(ubMaps[i].getNumInputs() == operands.size()); unsigned pos; if (!findId(*values[i], &pos)) continue; - if (AffineMap lbMap = lbMaps[i]) + if (AffineMap lbMap = lbMaps[i]) { + assert(lbMaps[i].getNumInputs() == operands.size()); if (!addLowerOrUpperBound(pos, lbMap, /*lower=*/true)) return false; - - if (AffineMap ubMap = ubMaps[i]) + } + if (AffineMap ubMap = ubMaps[i]) { + assert(ubMaps[i].getNumInputs() == operands.size()); if (!addLowerOrUpperBound(pos, ubMap, /*lower=*/false)) return false; + } } return true; } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 6e3fc38d2f6..a7f0cbb5b29 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -173,7 +173,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, } } } - // We'll first associate the dims and symbols of the access map to the dims // and symbols resp. of cst. This will change below once cst is // fully constructed out. @@ -236,7 +235,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, } // Set all identifiers appearing after the first 'rank' identifiers as - // symbolic identifiers - so that the ones correspoding to the memref + // symbolic identifiers - so that the ones corresponding to the memref // dimensions are the dimensional identifiers for the memref region. cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank); @@ -442,10 +441,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned dstLoopDepth, ComputationSliceState *sliceState) { + bool readReadAccesses = + srcAccess.opInst->isa() && dstAccess.opInst->isa(); FlatAffineConstraints dependenceConstraints; - if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1, - &dependenceConstraints, - /*dependenceComponents=*/nullptr)) { + if (!checkMemrefAccessDependence( + srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, + /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) { return false; } // Get loop nest surrounding src operation. @@ -487,6 +488,25 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // canonicalization. sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + + // For read-read access pairs, clear any slice bounds on sequential loops. + if (readReadAccesses) { + // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. + llvm::SmallDenseSet sequentialLoops; + getSequentialLoops(srcLoopIVs[0], &sequentialLoops); + + // Clear all sliced loop bounds beginning at the first sequential loop. 
+ for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + Value *iv = srcLoopIVs[i]->getInductionVar(); + if (sequentialLoops.count(iv) == 0) + continue; + for (unsigned j = i; j < numSrcLoopIVs; ++j) { + sliceState->lbs[j] = AffineMap(); + sliceState->ubs[j] = AffineMap(); + } + break; + } + } return true; } @@ -675,3 +695,43 @@ mlir::getMemoryFootprintBytes(ConstOpPointer forOp, *forInst->getBlock(), Block::const_iterator(forInst), std::next(Block::const_iterator(forInst)), memorySpace); } + +/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted +/// at 'forOp'. +void mlir::getSequentialLoops( + OpPointer forOp, + llvm::SmallDenseSet *sequentialLoops) { + // Collect all load and store ops in loop nest rooted at 'forOp'. + SmallVector loadAndStoreOpInsts; + forOp->getInstruction()->walk([&](Instruction *opInst) { + if (opInst->isa() || opInst->isa()) + loadAndStoreOpInsts.push_back(opInst); + }); + + // Check dependences on all pairs of ops in 'loadAndStoreOpInsts' and record + // loops which carry dependences in 'sequentialLoops'. + for (unsigned i = 0, e = loadAndStoreOpInsts.size(); i < e; ++i) { + auto *srcOpInst = loadAndStoreOpInsts[i]; + MemRefAccess srcAccess(srcOpInst); + SmallVector, 4> srcLoopIVs; + getLoopIVs(*srcOpInst, &srcLoopIVs); + for (auto *dstOpInst : loadAndStoreOpInsts) { + MemRefAccess dstAccess(dstOpInst); + + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); + for (unsigned d = 1; d <= numCommonLoops; ++d) { + auto *iv = srcLoopIVs[d - 1]->getInductionVar(); + if (sequentialLoops->count(iv) > 0) + continue; + FlatAffineConstraints dependenceConstraints; + if (checkMemrefAccessDependence(srcAccess, dstAccess, d, + &dependenceConstraints, + /*dependenceComponents=*/nullptr)) { + // Record loop with carried dependence between srcAccess/dstAccess. + sequentialLoops->insert(iv); + } + } + } + } +} diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 61d13325d13..025813c6ca4 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -141,6 +141,7 @@ static bool isMemRefDereferencingOp(const Instruction &op) { return true; return false; } + // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level instructions in a Function which contain load/store ops, and edges // are memref dependences between the nodes. @@ -182,7 +183,7 @@ public: return storeOpCount; } - // Returns all store ups in 'storeOps' which access 'memref'. + // Returns all store ops in 'storeOps' which access 'memref'. void getStoreOpsForMemref(Value *memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { @@ -190,6 +191,29 @@ public: storeOps->push_back(storeOpInst); } } + + // Returns all load ops in 'loadOps' which access 'memref'. + void getLoadOpsForMemref(Value *memref, + SmallVectorImpl *loadOps) { + for (auto *loadOpInst : loads) { + if (memref == loadOpInst->cast()->getMemRef()) + loadOps->push_back(loadOpInst); + } + } + + // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node + // has at least one load and store operation. 
+ void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { + llvm::SmallDenseSet loadMemrefs; + for (auto *loadOpInst : loads) { + loadMemrefs.insert(loadOpInst->cast()->getMemRef()); + } + for (auto *storeOpInst : stores) { + auto *memref = storeOpInst->cast()->getMemRef(); + if (loadMemrefs.count(memref) > 0) + loadAndStoreMemrefSet->insert(memref); + } + } }; // Edge represents a data dependece between nodes in the graph. @@ -300,17 +324,18 @@ public: return true; } - // Returns true iff there is an edge from node 'srcId' to node 'dstId' for - // 'value'. Returns false otherwise. - bool hasEdge(unsigned srcId, unsigned dstId, Value *value) { + // Returns true iff there is an edge from node 'srcId' to node 'dstId' which + // is for 'value' if non-null, or for any value otherwise. Returns false + // otherwise. + bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) { if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { return false; } bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { - return edge.id == dstId && edge.value == value; + return edge.id == dstId && (!value || edge.value == value); }); bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { - return edge.id == srcId && edge.value == value; + return edge.id == srcId && (!value || edge.value == value); }); return hasOutEdge && hasInEdge; } @@ -349,8 +374,37 @@ public: } } + // Returns true if there is a path in the dependence graph from node 'srcId' + // to node 'dstId'. Returns false otherwise. + bool hasDependencePath(unsigned srcId, unsigned dstId) { + // Worklist state is: + SmallVector, 4> worklist; + worklist.push_back({srcId, 0}); + // Run DFS traversal to see if 'dstId' is reachable from 'srcId'. + while (!worklist.empty()) { + auto &idAndIndex = worklist.back(); + // Return true if we have reached 'dstId'. + if (idAndIndex.first == dstId) + return true; + // Pop and continue if node has no out edges, or if all out edges have + // already been visited. + if (outEdges.count(idAndIndex.first) == 0 || + idAndIndex.second == outEdges[idAndIndex.first].size()) { + worklist.pop_back(); + continue; + } + // Get graph edge to traverse. + Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; + // Increment next output edge index for 'idAndIndex'. + ++idAndIndex.second; + // Add node at 'edge.id' to worklist. + worklist.push_back({edge.id, 0}); + } + return false; + } + // Returns the input edge count for node 'id' and 'memref' from src nodes - // which access 'memref'. + // which access 'memref' with a store operation. unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) @@ -358,19 +412,19 @@ public: if (inEdge.value == memref) { Node *srcNode = getNode(inEdge.id); // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' - if (srcNode->getLoadOpCount(memref) > 0 || - srcNode->getStoreOpCount(memref) > 0) + if (srcNode->getStoreOpCount(memref) > 0) ++inEdgeCount; } return inEdgeCount; } - // Returns the output edge count for node 'id' and 'memref'. - unsigned getOutEdgeCount(unsigned id, Value *memref) { + // Returns the output edge count for node 'id' and 'memref' (if non-null), + // otherwise returns the total output edge count from node 'id'. 
+ unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) { unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) - if (outEdge.value == memref) + if (!memref || outEdge.value == memref) ++outEdgeCount; return outEdgeCount; } @@ -469,6 +523,32 @@ public: } } + // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion + // of sibling node 'sidId' into node 'dstId'. + void updateEdges(unsigned sibId, unsigned dstId) { + // For each edge in 'inEdges[sibId]': + // *) Add new edge from source node 'inEdge.id' to 'dstNode'. + // *) Remove edge from source node 'inEdge.id' to 'sibNode'. + if (inEdges.count(sibId) > 0) { + SmallVector oldInEdges = inEdges[sibId]; + for (auto &inEdge : oldInEdges) { + addEdge(inEdge.id, dstId, inEdge.value); + removeEdge(inEdge.id, sibId, inEdge.value); + } + } + + // For each edge in 'outEdges[sibId]' to node 'id' + // *) Add new edge from 'dstId' to 'outEdge.id'. + // *) Remove edge from 'sibId' to 'outEdge.id'. + if (outEdges.count(sibId) > 0) { + SmallVector oldOutEdges = outEdges[sibId]; + for (auto &outEdge : oldOutEdges) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(sibId, outEdge.id, outEdge.value); + } + } + } + // Adds ops in 'loads' and 'stores' to node at 'id'. void addToNode(unsigned id, const SmallVectorImpl &loads, const SmallVectorImpl &stores) { @@ -485,6 +565,37 @@ public: node->stores.clear(); } + // Calls 'callback' for each input edge incident to node 'id' which carries a + // memref dependence. + void forEachMemRefInputEdge(unsigned id, + const std::function &callback) { + if (inEdges.count(id) > 0) + forEachMemRefEdge(inEdges[id], callback); + } + // Calls 'callback' for each output edge from node 'id' which carries a + // memref dependence. + void forEachMemRefOutputEdge(unsigned id, + const std::function &callback) { + if (outEdges.count(id) > 0) + forEachMemRefEdge(outEdges[id], callback); + } + // Calls 'callback' for each edge in 'edges' which carries a memref + // dependence. + void forEachMemRefEdge(ArrayRef edges, + const std::function &callback) { + for (auto &edge : edges) { + // Skip if 'edge' is not a memref dependence edge. + if (!edge.value->getType().isa()) + continue; + assert(nodes.count(edge.id) > 0); + // Skip if 'edge.id' is not a loop nest. + if (!getNode(edge.id)->inst->isa()) + continue; + // Visit current input edge 'edge'. + callback(edge); + } + } + void print(raw_ostream &os) const { os << "\nMemRefDependenceGraph\n"; os << "\nNodes:\n"; @@ -1228,6 +1339,14 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. +// The argument 'srcStoreOpInst' is used to calculate the storage reduction on +// the memref being produced and consumed, which is an input to the cost model. +// For producer-constumer fusion, 'srcStoreOpInst' will be the same as +// 'srcOpInst', as we are slicing w.r.t to that producer. +// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which +// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' +// will be the unique store op in the src node, which will be used to check +// that the write region is the same after input-reuse fusion. // Returns true if it is profitable to fuse the candidate loop nests. Returns // false otherwise. 
`dstLoopDepth` is set to the most profitable depth at which // to materialize the source loop nest slice. @@ -1257,6 +1376,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // loop nest computed in the previous step, and returns true if the latter // is lower. static bool isFusionProfitable(Instruction *srcOpInst, + Instruction *srcStoreOpInst, ArrayRef dstLoadOpInsts, ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, @@ -1294,8 +1414,11 @@ static bool isFusionProfitable(Instruction *srcOpInst, return false; // Compute the maximum loop depth at which we can can insert the src slice - // and still satisfy dest loop nest dependences. - unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts); + // and still satisfy dest loop nest dependences, for producer-consumer fusion. + unsigned maxDstLoopDepth = + (srcOpInst == srcStoreOpInst) + ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) + : dstLoopIVs.size(); if (maxDstLoopDepth == 0) return false; @@ -1306,7 +1429,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // the cost of the slice and the cost of the slice inserted into the dst // loop nest at 'dstLoopDepth'. uint64_t minFusedLoopNestComputeCost = std::numeric_limits::max(); - uint64_t maxStorageReduction = 0; + double maxStorageReduction = 0.0; Optional sliceMemEstimate = None; SmallVector sliceStates; @@ -1321,8 +1444,8 @@ static bool isFusionProfitable(Instruction *srcOpInst, /*computeCostMap=*/nullptr); // Compute src loop nest write region size. - MemRefRegion srcWriteRegion(srcOpInst->getLoc()); - srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0); + MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); + srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0); Optional maybeSrcWriteRegionSizeBytes = srcWriteRegion.getRegionSize(); if (!maybeSrcWriteRegionSizeBytes.hasValue()) @@ -1345,6 +1468,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, if (!mlir::getBackwardComputationSliceState( srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1])) return false; + // Compute the union of slice bound of all ops in 'dstLoadOpInsts'. for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) { MemRefAccess dstAccess(dstLoadOpInsts[j]); @@ -1372,6 +1496,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, computeCostMap.clear(); // The store and loads to this memref will disappear. + // TODO(andydavis) Add load coalescing to memref data flow opt pass. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; @@ -1403,8 +1528,9 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute what the slice write MemRefRegion would be, if the src loop // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop // nest at loop depth 'i' - MemRefRegion sliceWriteRegion(srcOpInst->getLoc()); - sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]); + MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); + sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, + &sliceStates[i - 1]); Optional maybeSliceWriteRegionSizeBytes = sliceWriteRegion.getRegionSize(); if (!maybeSliceWriteRegionSizeBytes.hasValue() || @@ -1413,6 +1539,14 @@ static bool isFusionProfitable(Instruction *srcOpInst, int64_t sliceWriteRegionSizeBytes = maybeSliceWriteRegionSizeBytes.getValue(); + // If we are fusing for reuse, check that write regions remain the same. 
+ // TODO(andydavis) Write region check should check sizes and offsets in + // each dimension, so that we are sure they are covering the same memref + // region. Also, move this out to a isMemRefRegionSuperSet helper function. + if (srcOpInst != srcStoreOpInst && + sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) + continue; + double storageReduction = static_cast(srcWriteRegionSizeBytes) / static_cast(sliceWriteRegionSizeBytes); @@ -1547,12 +1681,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, return true; } -// GreedyFusion greedily fuses loop nests which have a producer/consumer -// relationship on a memref, with the goal of improving locality. Currently, -// this the producer/consumer relationship is required to be unique in the -// Function (there are TODOs to relax this constraint in the future). +// GreedyFusion greedily fuses loop nests which have a producer/consumer or +// input-reuse relationship on a memref, with the goal of improving locality. // -// The steps of the algorithm are as follows: +// The steps of the producer-consumer fusion algorithm are as follows: // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: @@ -1560,20 +1692,32 @@ static bool isFusionProfitable(Instruction *srcOpInst, // candidate destination AffineForOp into which fusion will be attempted. // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: -// *) Lookup dependent loop nests at earlier positions in the Function -// which have a single store op to the same memref. -// *) Check if dependences would be violated by the fusion. For example, -// the src loop nest may load from memrefs which are different than -// the producer-consumer memref between src and dest loop nests. +// *) Lookup dependent loop nests which have a single store op to the same +// memref. +// *) Check if dependences would be violated by the fusion. // *) Get a computation slice of 'srcLoopNest', which adjusts its loop // bounds to be functions of 'dstLoopNest' IVs and symbols. // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', -// just before the dst load op user. +// at a loop depth determined by the cost model in 'isFusionProfitable'. // *) Add the newly fused load/store operation instructions to the state, // and also add newly fuse load ops to 'dstLoopOps' to be considered // as fusion dst load ops in another iteration. // *) Remove old src loop nest and its associated state. // +// The steps of the input-reuse fusion algorithm are as follows: +// +// *) Initialize 'worklist' with node ids from the dependence graph. +// *) For each 'dstNode' in the worklist: +// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which +// loads from the same memref, but which has no dependence paths to/from. +// *) Get a computation slice of 'sibLoopNest', which adjusts its loop +// bounds to be functions of 'dstLoopNest' IVs and symbols. +// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', +// at a loop depth determined by the cost model in 'isFusionProfitable'. +// This function also checks that the memref write region of 'sibLoopNest', +// is preserved in the fused loop nest. +// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. 
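Distilled into code, the sibling test described in the steps above amounts to a predicate along these lines; 'areFusableSiblings' is an illustrative sketch within LoopFusion.cpp, and the real findSiblingNodeToFuse below adds several further restrictions (visited-node tracking, single-load siblings, a unique store memref):

static bool areFusableSiblings(MemRefDependenceGraph *mdg,
                               MemRefDependenceGraph::Node *a,
                               MemRefDependenceGraph::Node *b, Value *memref) {
  // Both nodes read 'memref', and neither can reach the other through
  // dependence edges, so they only share input reuse.
  return a->getLoadOpCount(memref) > 0 && b->getLoadOpCount(memref) > 0 &&
         !mdg->hasDependencePath(a->id, b->id) &&
         !mdg->hasDependencePath(b->id, a->id);
}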
+// // Given a graph where top-level instructions are vertices in the set 'V' and // edges in the set 'E' are dependences between vertices, this algorithm // takes O(V) time for initialization, and has runtime O(V + E). @@ -1582,25 +1726,54 @@ static bool isFusionProfitable(Instruction *srcOpInst, // fusing along single producer consumer edges, but there is a TODO to fix this. // // TODO(andydavis) Experiment with other fusion policies. -// TODO(andydavis) Add support for fusing for input reuse (perhaps by -// constructing a graph with edges which represent loads from the same memref -// in two different loop nests. struct GreedyFusion { public: + // The data dependence graph to traverse during fusion. MemRefDependenceGraph *mdg; + // Worklist of graph nodes visited during the fusion pass. SmallVector worklist; + // Set of graph nodes which are present on the worklist. llvm::SmallDenseSet worklistSet; + // Parameter for local buffer size threshold. + unsigned localBufSizeThreshold; + // Parameter for fast memory space. + Optional fastMemorySpace; + + using Node = MemRefDependenceGraph::Node; - GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) { - // Initialize worklist with nodes from 'mdg'. + GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold, + Optional fastMemorySpace) + : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold), + fastMemorySpace(fastMemorySpace) {} + + // Initializes 'worklist' with nodes from 'mdg' + void init() { // TODO(andydavis) Add a priority queue for prioritizing nodes by different // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). - worklist.resize(mdg->nodes.size()); - std::iota(worklist.begin(), worklist.end(), 0); - worklistSet.insert(worklist.begin(), worklist.end()); + worklist.clear(); + worklistSet.clear(); + for (auto &idAndNode : mdg->nodes) { + const Node &node = idAndNode.second; + worklist.push_back(node.id); + worklistSet.insert(node.id); + } } - void run(unsigned localBufSizeThreshold, Optional fastMemorySpace) { + // Run the GreedyFusion pass. + // *) First pass through the nodes fuses single-use producer nodes into their + // unique consumer. + // *) Second pass fuses sibling nodes which share no dependence edges. + // *) Third pass fuses any remaining producer nodes into their users. + void run() { + fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); + fuseSiblingNodes(); + fuseProducerConsumerNodes( + /*maxSrcUserCount=*/std::numeric_limits::max()); + eraseUnusedMemRefAllocations(); + } + + void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { + init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); @@ -1672,6 +1845,10 @@ public: !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg)) continue; + // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. + if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) + continue; + // Compute an instruction list insertion point for the fused loop // nest which preserves dependences. Instruction *insertPointInst = @@ -1690,8 +1867,8 @@ public: unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, + if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst, + dstLoadOpInsts, dstStoreOpInsts, &sliceState, &bestDstLoopDepth)) continue; @@ -1782,7 +1959,202 @@ public: } } } - // Clean up any allocs with no users. 
+ } + + // Visits each node in the graph, and for each node, attempts to fuse it with + // its sibling nodes (nodes which share a parent, but no dependence edges). + void fuseSiblingNodes() { + init(); + while (!worklist.empty()) { + unsigned dstId = worklist.back(); + worklist.pop_back(); + worklistSet.erase(dstId); + + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(dstId) == 0) + continue; + // Get 'dstNode' into which to attempt fusion. + auto *dstNode = mdg->getNode(dstId); + // Skip if 'dstNode' is not a loop nest. + if (!dstNode->inst->isa()) + continue; + // Attempt to fuse 'dstNode' with its sibling nodes in the graph. + fuseWithSiblingNodes(dstNode); + } + } + + // Attempt to fuse 'dstNode' with sibling nodes in the graph. + void fuseWithSiblingNodes(Node *dstNode) { + DenseSet visitedSibNodeIds; + std::pair idAndMemref; + while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { + unsigned sibId = idAndMemref.first; + Value *memref = idAndMemref.second; + // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other + // stores to the same memref in 'sibNode' loop nest. + auto *sibNode = mdg->getNode(sibId); + // Compute an instruction list insertion point for the fused loop + // nest which preserves dependences. + assert(sibNode->inst->getBlock() == dstNode->inst->getBlock()); + Instruction *insertPointInst = + sibNode->inst->isBeforeInBlock(dstNode->inst) + ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id) + : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id); + if (insertPointInst == nullptr) + continue; + + // Check if fusion would be profitable and at what depth. + + // Get unique 'sibNode' load op to 'memref'. + SmallVector sibLoadOpInsts; + sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts); + // Currently findSiblingNodeToFuse searches for siblings with one load. + assert(sibLoadOpInsts.size() == 1); + Instruction *sibLoadOpInst = sibLoadOpInsts[0]; + assert(!sibNode->stores.empty()); + // TODO(andydavis) Choose the store which postdominates all other stores. + auto *sibStoreOpInst = sibNode->stores.back(); + + // Gather 'dstNode' load ops to 'memref'. + SmallVector dstLoadOpInsts; + dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); + + // Gather 'dstNode' store ops to 'memref'. + SmallVector dstStoreOpInsts; + dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); + + unsigned bestDstLoopDepth; + mlir::ComputationSliceState sliceState; + + // Check if fusion would be profitable. + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, &bestDstLoopDepth)) + continue; + + // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. + auto sliceLoopNest = mlir::insertBackwardComputationSlice( + sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); + if (sliceLoopNest != nullptr) { + auto dstForInst = dstNode->inst->cast(); + // Update instruction position of fused loop nest (if needed). + if (insertPointInst != dstForInst->getInstruction()) { + dstForInst->getInstruction()->moveBefore(insertPointInst); + } + // Update data dependence graph state post fusion. + updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); + } + } + } + + // Searches the graph from 'dstNode' looking for a fusion candidate sibling + // node which shares no dependences with 'dstNode' but which loads from the + // same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns + // false otherwise. 
+ bool findSiblingNodeToFuse(Node *dstNode, + DenseSet *visitedSibNodeIds, + std::pair *idAndMemrefToFuse) { + // TODO(andydavis) Currently we discover siblings by following edges + // through an intermediate src node. We should also consider siblings + // which load from the same memref, but which do not necessarily share + // a src node parent (e.g. loading from a memref which is a function arg). + // Collect candidate 'dstNode' input edges in 'inEdges'. + SmallVector inEdges; + mdg->forEachMemRefInputEdge( + dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) { + // Add 'inEdge' if it is a read-after-write dependence. + if (dstNode->getLoadOpCount(inEdge.value) > 0 && + mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0) + inEdges.push_back(inEdge); + }); + + // Search for sibling nodes to fuse by visiting output edges from each input + // edge in 'inEdges'. + for (auto &inEdge : inEdges) { + // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'. + SmallVector outEdges; + mdg->forEachMemRefOutputEdge( + inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) { + unsigned sibNodeId = outEdge.id; + if (visitedSibNodeIds->count(sibNodeId) > 0) + return; + // Skip output edge if not a sibling using the same memref. + if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) + return; + auto *sibNode = mdg->getNode(sibNodeId); + if (!sibNode->inst->isa()) + return; + // Skip if 'outEdge' is not a read-after-write dependence. + // TODO(andydavis) Remove restrict to single load op restriction. + if (sibNode->getLoadOpCount(inEdge.value) != 1) + return; + // Skip if there exists a path of dependent edges between + // 'sibNode' and 'dstNode'. + if (mdg->hasDependencePath(sibNodeId, dstNode->id) || + mdg->hasDependencePath(dstNode->id, sibNodeId)) + return; + // Skip sib node if it loads to (and stores from) the same memref on + // which it also has an input dependence edge. + DenseSet loadAndStoreMemrefSet; + sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); + if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { + return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > + 0; + })) + return; + // Check that all stores are to the same memref. + DenseSet storeMemrefs; + for (auto *storeOpInst : sibNode->stores) { + storeMemrefs.insert(storeOpInst->cast()->getMemRef()); + } + if (storeMemrefs.size() != 1) + return; + // Add candidate 'outEdge' to sibling node. + outEdges.push_back(outEdge); + }); + + // Add first candidate if any were returned. + if (!outEdges.empty()) { + visitedSibNodeIds->insert(outEdges[0].id); + idAndMemrefToFuse->first = outEdges[0].id; + idAndMemrefToFuse->second = outEdges[0].value; + return true; + } + } + return false; + } + + void updateStateAfterSiblingFusion(OpPointer sliceLoopNest, + Node *sibNode, Node *dstNode) { + // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. + mdg->updateEdges(sibNode->id, dstNode->id); + + // Collect slice loop stats. + LoopNestStateCollector sliceCollector; + sliceCollector.collect(sliceLoopNest->getInstruction()); + // Promote single iteration slice loops to single IV value. + for (auto forOp : sliceCollector.forOps) { + promoteIfSingleIteration(forOp); + } + + // Collect dst loop stats after memref privatizaton transformation. 
+ auto dstForInst = dstNode->inst->cast(); + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstForInst->getInstruction()); + // Clear and add back loads and stores + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old sibling loop nest if it no longer has outgoing dependence + // edges, and it does not write to a memref which escapes the + // function. + if (mdg->getOutEdgeCount(sibNode->id) == 0) { + mdg->removeNode(sibNode->id); + sibNode->inst->cast()->erase(); + } + } + + // Clean up any allocs with no users. + void eraseUnusedMemRefAllocations() { for (auto &pair : mdg->memrefEdgeCount) { if (pair.second > 0) continue; @@ -1813,7 +2185,7 @@ void LoopFusion::runOnFunction() { MemRefDependenceGraph g; if (g.init(&getFunction())) - GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); + GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run(); } static PassRegistration pass("loop-fusion", "Fuse loop nests"); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 23fe7130e42..656db88e882 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1228,13 +1228,13 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // by loops %i1 and %i2. // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 82 { + // CHECK: for %i0 = 0 to 17 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 17 { + // CHECK-NEXT: for %i1 = 0 to 82 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) @@ -1915,3 +1915,150 @@ func @test_add_slice_bounds() { // CHECK-NEXT: } return } + +// ----- +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) +// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) + +func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) { + %0 = alloc() : memref<10x10xf32> + %cst = constant 0.000000e+00 : f32 + %cst_0 = constant 1.000000e+00 : f32 + %cst_1 = constant 7.000000e+00 : f32 + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + store %cst_1, %0[%i0, %i1] : memref<10x10xf32> + } + } + for %i2 = 0 to 3 { + for %i3 = 0 to 3 { + store %cst, %arg0[%i2, %i3] : memref<10x10xf32> + } + } + for %i4 = 0 to 3 { + for %i5 = 0 to 3 { + %1 = load %0[%i4, %i5] : memref<10x10xf32> + %2 = load %arg0[%i4, %i5] : memref<10x10xf32> + %3 = mulf %1, %2 : f32 + store %3, %arg0[%i4, %i5] : memref<10x10xf32> + } + } + for %i6 = 0 to 3 { + for %i7 = 0 to 3 { + store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32> + } + } + for %i8 = 0 to 3 { + for %i9 = 0 to 3 { + %4 = load %0[%i8, %i9] : memref<10x10xf32> + %5 = load %arg1[%i8, %i9] : memref<10x10xf32> + %6 = addf %4, %5 : f32 + store %6, %arg1[%i8, %i9] : memref<10x10xf32> + } + } + + // Pass 1: should fuse single-use producer loop nests into their unique user, + // so '%i2' will fuse into '%i4' and '%i6' will fuse into '%i8'. + // Pass 2: should fuse sibling loop nests which share no dependence edges, + // so should fuse '%i4' into '%i8'. + // Pass 3: should fuse single-use producer loop nest '%i0' into '%i8'. 
Note + // that loop nest '%i0' now has a single user after Pass 2 fused its + // two users together). + +// CHECK: for %i0 = 0 to 3 { +// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: store %cst_1, %0[%1, %2] : memref<1x1xf32> +// CHECK-NEXT: store %cst, %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32> +// CHECK-NEXT: %6 = load %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %7 = mulf %5, %6 : f32 +// CHECK-NEXT: store %7, %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: store %cst_0, %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %8 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: %9 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) +// CHECK-NEXT: %10 = load %0[%8, %9] : memref<1x1xf32> +// CHECK-NEXT: %11 = load %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %12 = addf %10, %11 : f32 +// CHECK-NEXT: store %12, %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return + + return +} + +// ----- +// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1) +// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2) + +func @two_matrix_vector_products() { + %in_matrix = alloc() : memref<10x10xf32> + %in_vec0 = alloc() : memref<10xf32> + %in_vec1 = alloc() : memref<10xf32> + %out_vec0 = alloc() : memref<10xf32> + %out_vec1 = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + // Populate input matrix. + for %i0 = 0 to 10 { + for %i1 = 0 to 10 { + store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32> + } + } + // out_vec0 = in_matrix x in_vec0 + for %i2 = 0 to 10 { + for %i3 = 0 to 10 { + %v0 = load %in_matrix[%i2, %i3] : memref<10x10xf32> + %v1 = load %in_vec0[%i3] : memref<10xf32> + %v2 = mulf %v0, %v1 : f32 + %v3 = load %out_vec0[%i3] : memref<10xf32> + %v4 = addf %v2, %v3 : f32 + store %v4, %out_vec0[%i3] : memref<10xf32> + } + } + // out_vec1 = in_matrix x in_vec1 + for %i4 = 0 to 10 { + for %i5 = 0 to 10 { + %v5 = load %in_matrix[%i4, %i5] : memref<10x10xf32> + %v6 = load %in_vec1[%i5] : memref<10xf32> + %v7 = mulf %v5, %v6 : f32 + %v8 = load %out_vec1[%i5] : memref<10xf32> + %v9 = addf %v7, %v8 : f32 + store %v9, %out_vec1[%i5] : memref<10xf32> + } + } + +// CHECK: for %i0 = 0 to 10 { +// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1, %i0) +// CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1, %i0) +// CHECK-NEXT: store %cst, %0[%5, %6] : memref<10x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: for %i2 = 0 to 10 { +// CHECK-NEXT: %7 = affine.apply [[MAP2]](%i0, %i2, %i0) +// CHECK-NEXT: %8 = affine.apply [[MAP3]](%i0, %i2, %i0) +// CHECK-NEXT: %9 = load %0[%7, %8] : memref<10x1xf32> +// CHECK-NEXT: %10 = load %1[%i0] : memref<10xf32> +// CHECK-NEXT: %11 = mulf %9, %10 : f32 +// CHECK-NEXT: %12 = load %3[%i0] : memref<10xf32> +// CHECK-NEXT: %13 = addf %11, %12 : f32 +// CHECK-NEXT: store %13, %3[%i0] : memref<10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: for %i3 = 0 to 10 { +// CHECK-NEXT: %14 = affine.apply [[MAP2]](%i0, %i3, %i0) +// CHECK-NEXT: %15 = affine.apply [[MAP3]](%i0, %i3, %i0) +// CHECK-NEXT: %16 = load %0[%14, %15] : memref<10x1xf32> +// CHECK-NEXT: %17 = load %2[%i0] : memref<10xf32> +// CHECK-NEXT: %18 = mulf %16, %17 : f32 +// CHECK-NEXT: %19 = load %4[%i0] : memref<10xf32> +// CHECK-NEXT: 
%20 = addf %18, %19 : f32 +// CHECK-NEXT: store %20, %4[%i0] : memref<10xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return + return +} \ No newline at end of file -- cgit v1.2.3 From 85d9b6c8f7119e519fa9cd34f0359625b7c315ae Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 1 Mar 2019 13:48:24 -0800 Subject: Use consistent names for dialect op source files This CL changes dialect op source files (.h, .cpp, .td) to follow the following convention: /Ops.{h|cpp|td} Builtin and standard dialects are specially treated, though. Both of them do not have dialect namespace; the former is still named as BuiltinOps.* and the latter is named as Ops.*. Purely mechanical. NFC. PiperOrigin-RevId: 236371358 --- mlir/include/mlir/IR/BuiltinOps.h | 2 +- mlir/include/mlir/IR/OpBase.td | 574 +++++++ mlir/include/mlir/IR/op_base.td | 574 ------- mlir/include/mlir/LLVMIR/LLVMDialect.h | 2 +- mlir/include/mlir/LLVMIR/LLVMOps.td | 207 +++ mlir/include/mlir/LLVMIR/llvm_ops.td | 207 --- mlir/include/mlir/StandardOps/Ops.h | 754 +++++++++ mlir/include/mlir/StandardOps/Ops.td | 135 ++ mlir/include/mlir/StandardOps/StandardOps.h | 754 --------- mlir/include/mlir/StandardOps/standard_ops.td | 135 -- mlir/lib/AffineOps/AffineOps.cpp | 2 +- mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/NestedMatcher.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/EDSC/MLIREmitter.cpp | 2 +- mlir/lib/EDSC/Types.cpp | 2 +- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 2 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 2 +- mlir/lib/StandardOps/DialectRegistration.cpp | 2 +- mlir/lib/StandardOps/Ops.cpp | 1609 ++++++++++++++++++++ mlir/lib/StandardOps/StandardOps.cpp | 1609 -------------------- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/EDSC/api-test.cpp | 2 +- mlir/test/mlir-tblgen/one-op-one-result.td | 2 +- mlir/test/mlir-tblgen/op-result.td | 2 +- mlir/test/mlir-tblgen/predicate.td | 2 +- mlir/test/mlir-tblgen/reference-impl.td | 2 +- 42 files changed, 3311 insertions(+), 3311 deletions(-) create mode 100644 mlir/include/mlir/IR/OpBase.td delete mode 100644 mlir/include/mlir/IR/op_base.td create mode 100644 mlir/include/mlir/LLVMIR/LLVMOps.td delete mode 100644 mlir/include/mlir/LLVMIR/llvm_ops.td create mode 100644 mlir/include/mlir/StandardOps/Ops.h create mode 100644 mlir/include/mlir/StandardOps/Ops.td delete mode 100644 mlir/include/mlir/StandardOps/StandardOps.h delete mode 100644 mlir/include/mlir/StandardOps/standard_ops.td create mode 100644 mlir/lib/StandardOps/Ops.cpp delete mode 100644 mlir/lib/StandardOps/StandardOps.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 77594c83a44..7b606fdbd1d 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -1,4 +1,4 @@ -//===- 
BuiltinOps.h - Builtin MLIR Operations -----------------*- C++ -*-===// +//===- BuiltinOps.h - Builtin MLIR Operations -------------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td new file mode 100644 index 00000000000..7e17e558d4c --- /dev/null +++ b/mlir/include/mlir/IR/OpBase.td @@ -0,0 +1,574 @@ +//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the base operation definition file. +// +//===----------------------------------------------------------------------===// + +#ifdef OP_BASE +#else +#define OP_BASE + +//===----------------------------------------------------------------------===// +// Predicates. +//===----------------------------------------------------------------------===// + +// A logical predicate. +class Pred; + +// Logical predicate wrapping a C expression. +class CPred : Pred { + code predCall = "(" # pred # ")"; +} + +// Kinds of combined logical predicates. These must closesly match the +// predicates implemented by the C++ backend (tblgen::PredCombinerKind). +class PredCombinerKind; +def PredCombinerAnd : PredCombinerKind; +def PredCombinerOr : PredCombinerKind; +def PredCombinerNot : PredCombinerKind; +def PredCombinerSubstLeaves : PredCombinerKind; + +// A predicate that combines other predicates as defined by PredCombinerKind. +// Instantiated below. +class CombinedPred c> : Pred { + PredCombinerKind kind = k; + list children = c; +} + +// A predicate that holds if all of its children hold. Always holds for zero +// children. +class AllOf children> : CombinedPred; + +// A predicate that holds if any of its children hold. Never holds for zero +// children. +class AnyOf children> : CombinedPred; + +// A predicate that holds if its child does not. +class Neg : CombinedPred; + +// A predicate that substitutes "pat" with "repl" in predicate calls of the +// leaves of the predicate tree (i.e., not CombinedPredicates). This is plain +// string substitution without regular expressions or captures, new predicates +// with more complex logical can be introduced should the need arise. +class SubstLeaves + : CombinedPred { + string pattern = pat; + string replacement = repl; +} + +//===----------------------------------------------------------------------===// +// Type predicates. ({0} is replaced by an instance of mlir::Type) +//===----------------------------------------------------------------------===// + +// Whether a type is a VectorType. +def IsVectorTypePred : CPred<"{0}.isa()">; + +// Whether a type is a TensorType. +def IsTensorTypePred : CPred<"{0}.isa()">; + +// For a TensorType, verify that it is a statically shaped tensor. 
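As an illustrative aside (not part of the patch itself), these predicate combinators nest in the expected way: the statically-shaped-tensor check defined just below could equivalently be composed from IsTensorTypePred plus an extra CPred. The def name here is invented, and the C++ expression assumes the usual mlir::Type cast API used throughout this file.

// Hypothetical composed predicate, shown only to illustrate AllOf/CPred nesting.
def IsStaticShapeTensorTypeComposedPred : AllOf<[
  IsTensorTypePred,
  CPred<"{0}.cast<TensorType>().hasStaticShape()">
]>;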
+def IsStaticShapeTensorTypePred : + CPred<"{0}.cast().hasStaticShape()">; + +//===----------------------------------------------------------------------===// +// Type constraints and types. +//===----------------------------------------------------------------------===// + +// A constraint on types. This can be used to check the validity of +// instruction arguments. +class TypeConstraint { + // The predicates that this type satisfies. + // Format: {0} will be expanded to the type. + Pred predicate = condition; + // User-readable description used, e.g., for error reporting. If empty, + // a generic message will be used instead. + string description = descr; +} + +// A type, carries type constraints. +class Type + : TypeConstraint; + +// A variadic type. It expands to zero or more of the base type. +// This class is used for supporting variadic operands/results. An op can +// declare no more than one variadic operand/result, and that operand/result +// must be the last one in the operand/result list. +class Variadic + // TODO: support variadic type conditions + : Type, descr> { + Type baseType = type; +} + +// A type that can be constructed using MLIR::Builder. +// Note that this does not "inherit" from Type because it would require +// duplicating Type subclasses for buildable and non-buildable cases to avoid +// diamond "inheritance". +// TODO(zinenko): we may extend this to a more general 'Buildable' trait, +// making some Types and some Attrs buildable. +class BuildableType { + // The builder call to invoke (if specified) to construct the BuildableType. + // Format: this will be affixed to the builder. + code builderCall = builder; +} + +// Integer types. +class IntegerBase : Type; + +// Any integer type irrespective of its width. +def Integer : IntegerBase()">, "integer">; + +// Index type. +def Index : IntegerBase()">, "index">; + +// Integer type of a specific width. +class I + : IntegerBase, + width # "-bit integer">, + BuildableType<"getIntegerType(" # width # ")"> { + int bitwidth = width; +} +def I1 : I<1>; +def I8 : I<8>; +def I16 : I<16>; +def I32 : I<32>; +def I64 : I<64>; + +// Floating point types. +class FloatBase : Type; + +// Any float type irrespective of its width. +def Float : FloatBase()">, "floating-point">; + +// Float type of a specific width. +class F + : FloatBase, + width # "-bit float">, + BuildableType<"getF" # width # "Type()"> { + int bitwidth = width; +} + +def F16 : F<16>; +def F32 : F<32>; +def F64 : F<64>; + +// A container type is a type that has another type embedded within it. +class ContainerType : + // First, check the container predicate. Then, substitute the extracted + // element into the element type checker. + Type(elementTypeCall), + etype.predicate>]>, + descr # " of " # etype.description # " values"> { + // The type of elements in the container. + Type elementType = etype; + + // Call to retrieve. + code getElementTypeCall = elementTypeCall; +} + +// Vector types. +class TypedVector : ContainerType().getElementType()", "vector">; + +class Vector dims> : ContainerType().getShape() == ArrayRef{{" # + !foldl("", dims, sum, element, sum # + !if(!empty(sum), "", ",") # !cast(element)) # "}">]>, + "{0}.cast().getElementType()", + "vector"> { + list dimensions = dims; +} + +// Tensor type. + +// This represents a generic tensor without constraints on elemental type, +// rank, size. As there is no constraint on elemental type, derive from Type +// directly instead of ContainerType. 
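Before the tensor definitions continue below, a brief illustrative aside (these defs are not part of the patch and the names are invented): the container helpers above are instantiated once per element type, either leaving the shape unconstrained or pinning it to fixed dimensions.

// Hypothetical instantiations of the vector container classes defined above.
def F32Vector : TypedVector<F32>;    // any vector with f32 elements
def Vector4xF32 : Vector<F32, [4]>;  // exactly a vector<4xf32>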
+def Tensor : Type; + +// A tensor with static shape but no other constraints. Note: as +// Tensor is a def this doesn't derive from it, but reuses the predicate +// that must hold for it to be a tensor. +def StaticShapeTensor + : Type, + "statically shaped tensor">; + +// For typed tensors. +class TypedTensor + : ContainerType().getElementType()", + "tensor">; + +def F32Tensor : TypedTensor; + +// Type constraint for integer-like types: integers, indices, vectors of +// integers, tensors of integers. +def IntegerLike : TypeConstraint.predicate, TypedTensor.predicate]>, + "integer-like">; + +// Type constraint for float-like types: floats, vectors or tensors thereof. +def FloatLike : TypeConstraint.predicate, TypedTensor.predicate]>, + "floating-point-like">; + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +// A constraint on attributes. This can be used to check the validity of +// instruction attributes. +class AttrConstraint { + // The predicates that this attribute satisfies. + // Format: {0} will be expanded to the attribute. + Pred predicate = condition; + // User-readable description used, e.g., for error reporting. + // If empty, a generic message will be used instead. + string description = descr; +} + +// Base class for all attributes. +class Attr : + AttrConstraint { + code storageType = ?; // The backing mlir::Attribute type + code returnType = ?; // The underlying C++ value type + + // Define converter method to convert from the storage type to the return + // type. For example, an enum can be stored as an int but returned as an + // enum class. + // + // Format: {0} will be expanded to the attribute. So + // '{0}.getValue().convertToFloat()' for 'FloatAttr val' will expand to + // 'getAttrOfType("val").getValue().convertToFloat()'. + code convertFromStorage = "{0}.getValue()"; + + // The call expression that builds an attribute from a constant value. + // + // Format: {0} will be expanded to an instance of mlir::Builder, {1} will be + // expanded to the constant value of the attribute. For example, + // '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to + // 'builder.getStringAttr("foo")'. + code constBuilderCall = ?; + + // Default value for attribute. + // Requires a constBuilderCall defined. + string defaultValue = ?; + + // Whether the attribute is optional. Typically requires a custom + // convertFromStorage method to handle the case where the attribute is + // not present. + bit isOptional = 0b0; +} + +// Decorates an attribute to have an (unvalidated) default value if not present. +class DefaultValuedAttr : + Attr { + // Construct this attribute with the input attribute and change only + // the default value. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = val; +} + +// Decorates an attribute as optional. The return type of the generated +// attribute accessor method will be Optional<>. +class OptionalAttr : + Attr { + // Rewrite the attribute to be optional. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = "Optional<" # attr.returnType #">"; + let convertFromStorage = "{0} ? 
" # returnType # "({0}.getValue())" # + " : (llvm::None)"; + let isOptional = 0b1; +} + +// A generic attribute that must be constructed around a specific type. +// Backed by a C++ class "attrName". +class TypeBasedAttr : + Attr, descr> { + let constBuilderCall = + "{0}.get" # attrName # "({0}." # t.builderCall # ", {1})"; + let storageType = attrName; +} + +// An attribute backed by a string type. +class StringBasedAttr : Attr, descr> { + let constBuilderCall = [{ {0}.getStringAttr("{1}") }]; + let storageType = [{ StringAttr }]; + let returnType = [{ StringRef }]; +} + +// Base class for instantiating float attributes of fixed width. +class FloatAttrBase : + TypeBasedAttr { + let returnType = [{ APFloat }]; +} + +// Base class for instantiating integer attributes of fixed width. +class IntegerAttrBase : + TypeBasedAttr; + +def BoolAttr : Attr, "bool"> { + let storageType = [{ BoolAttr }]; + let returnType = [{ bool }]; + let constBuilderCall = [{ {0}.getBoolAttr({1}) }]; +} +def ArrayAttr : Attr, "array"> { + let storageType = [{ ArrayAttr }]; + let returnType = [{ ArrayAttr }]; + code convertFromStorage = "{0}"; +} +class ElementsAttrBase : + Attr { + let storageType = [{ ElementsAttr }]; + let returnType = [{ ElementsAttr }]; + let convertFromStorage = "{0}"; +} +def ElementsAttr: ElementsAttrBase, "constant vector/tensor">; +def F32Attr : FloatAttrBase; +def F64Attr : FloatAttrBase; +def I32Attr : IntegerAttrBase { + let storageType = [{ IntegerAttr }]; + let returnType = [{ int }]; + let convertFromStorage = [{ {0}.getValue().getSExtValue() }]; +} +def StrAttr : StringBasedAttr<"string">; + +// DerivedAttr are attributes whose value is computed from properties +// of the operation. They do not require additional storage and are +// materialized as needed. +class DerivedAttr : Attr, "derived"> { + let returnType = ret; + code body = b; +} + +// Derived attribute that returns a mlir::Type. +class DerivedTypeAttr : DerivedAttr<"Type", body>; + +// Represents a constant attribute of specific Attr type. A constant +// attribute can be specified only of attributes that have a constant +// builder call defined. The constant value is specified as a string. +// +// If used as a constraint, it generates a matcher on a constant attribute by +// using the constant value builder of the attribute and the value. +class ConstantAttr : AttrConstraint< + CPred<"{0} == " # + !subst("{0}", "mlir::Builder(ctx)", !subst("{1}", val, + !cast(attribute.constBuilderCall)))>, + "constant attribute " # val> { + Attr attr = attribute; + string value = val; +} + +class ConstF32Attr : ConstantAttr; + +//===----------------------------------------------------------------------===// +// Op Traits +//===----------------------------------------------------------------------===// + +// OpTrait represents a trait regarding an op. +class OpTrait; + +// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The +// purpose to wrap around C++ symbol string with this class is to make +// traits specified for ops in TableGen less alien and more +// integrated. +class NativeOpTrait : OpTrait { + string trait = prop; +} + +// Specify a trait by way of a predicate on the operation. 
+class PredOpTrait : OpTrait { + string desc = d; + Pred pred = p; +} + +// op supports operand broadcast behavior +def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">; +// X op Y == Y op X +def Commutative : NativeOpTrait<"IsCommutative">; +// op has no side effect +def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; +// op has the same operand and result type +def SameValueType : NativeOpTrait<"SameOperandsAndResultType">; +// op is a terminator +def Terminator : NativeOpTrait<"IsTerminator">; + +//===----------------------------------------------------------------------===// +// Ops +//===----------------------------------------------------------------------===// + +// Marker used to identify the argument list for an op. +def ins; + +// Marker used to identify the result list for an op. +def outs; + +// Base class for all ops. +class Op props = []> { + // The mnemonic of the op. + string opName = mnemonic; + + // One-line human-readable description of what the op does. + string summary = ""; + + // Additional, longer human-readable description of what the op does. + string description = ""; + + // Dag containting the arguments of the op. Default to 0 arguments. Operands + // to the op need to precede attributes to ops in the argument specification. + dag arguments = (ins); + + // The list of results of the op. Default to 0 results. + dag results = (outs); + + // Attribute getters can be added to the op by adding an Attr member + // with the name and type of the attribute. E.g., adding int attribute + // with name "value" and type "i32": + // I32Attr value; + + // Define the hooks used for building, parsing, printing, verification. + + // Custom builder. + // If a derived class/def does not override this, then two default builders + // are generated, with the following signatures: + // + // static void build(Builder* builder, OperationState* result, + // Type resultType0, Type resultType1, ..., + // Value arg0, Value arg1, ..., + // Attribute , Attribute , ...); + // + // * where the attributes follow the same declaration order as in the op. + // + // static void build(Builder* builder, OperationState* result, + // ArrayRef resultTypes, + // ArrayRef args, + // ArrayRef attributes); + code builder = ?; + + // Custom parser. + code parser = ?; + + // Custom printer. + code printer = ?; + + // Custom verifier. + code verifier = ?; + + // Whether this op has associated canonicalization patterns. + // TODO(b/120163349): figure out a better way to write canonicalization + // patterns in TableGen rules directly instead of using this marker + // and C++ implementations. + bit hasCanonicalizer = 0b0; + + // Whether this op has a constant folder. + bit hasConstantFolder = 0b0; + + // Whether this op has a folder. + bit hasFolder = 0b0; + + // Op traits. + list traits = props; +} + +// The arguments of an op. +class Arguments { + dag arguments = args; +} + +// The results of an op. +class Results { + dag results = rets; +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +// Base class for op+ -> op+ rewrite patterns. These allow declaratively +// specifying rewrite patterns. +class Pattern results, list preds> { + dag patternToMatch = source; + list resultOps = results; + list constraints = preds; +} + +// Form of a pattern which produces a single result. +class Pat preds = []> : + Pattern; + +// Attribute matcher. 
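Purely as an illustration of how the pieces above fit together (not part of this patch; the dialect and op name are invented), a complete op definition built from the Op, Arguments, Results, trait, and type classes in this file would look roughly like this, the same shape used by the LLVM dialect definitions later in this change:

// Hypothetical example only: a commutative, side-effect-free elementwise op.
def Example_AddOp : Op<"example.add", [Commutative, NoSideEffect, SameValueType]>,
    Arguments<(ins F32Tensor:$lhs, F32Tensor:$rhs)>,
    Results<(outs F32Tensor)> {
  let summary = "element-wise addition of two f32 tensors";
}

Declarative rewrites between such ops would then use the Pattern/Pat helpers above, with the attribute matcher and transform classes that follow constraining and rewriting any attributes involved.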
This is the base class to specify a predicate +// that has to match. Used on the input attributes of a rewrite rule. +class mAttr : AttrConstraint; + +// Combine a list of attribute matchers into an attribute matcher that holds if +// any of the original matchers does. +class mAttrAnyOf attrs> : + mAttr, attrs, prev, attr, + !listconcat(prev, [attr.predicate]))>>; + +// Attribute transforms. This is the base class to specify a +// transformation of a matched attribute. Used on the output of a rewrite +// rule. +class tAttr { + // Code to transform the attribute. + // Format: {0} represents the attribute. + code attrTransform = transform; +} + +// Native code op creation method. This allows performing an arbitrary op +// creation/replacement by invoking a C++ function with the operands and +// attributes. The function specified needs to have the signature: +// +// void f(OperationInst *op, ArrayRef operands, +// ArrayRef attrs, PatternRewriter &rewriter); +// +// The operands and attributes are passed to this function in the order of +// the DAG specified. It is the responsibility of this function to replace the +// matched op(s) using the rewriter. This is intended for the long tail op +// creation and replacement. +class cOp { + // Function to invoke with the given arguments to construct a new op. The + // operands will be passed to the function first followed by the attributes + // (as in the function signature above and required by Op arguments). + string function = f; +} + +// Pattern matching predicate specification to constrain when a pattern may be +// used. For example, +// def : Pat<(... $l, (... $r)), (...), [(mPat<"foo"> $l, $r)]; +// will result in this pattern being considered only if `foo(l, r)` holds where +// `foo` is a C++ function and `l` and `r` are the C++ bound variables of +// $l and $r. +class mPat { + string function = f; +} + +// Marker used to indicate that no new result op are generated by applying the +// rewrite pattern, so to replace the matched DAG with an existing SSA value. +def replaceWithValue; + +#endif // OP_BASE diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td deleted file mode 100644 index 680d209e3eb..00000000000 --- a/mlir/include/mlir/IR/op_base.td +++ /dev/null @@ -1,574 +0,0 @@ -//===-- op_base.td - Base op definition file ---------------*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This is the base operation definition file. -// -//===----------------------------------------------------------------------===// - -#ifdef OP_BASE -#else -#define OP_BASE - -//===----------------------------------------------------------------------===// -// Predicates. -//===----------------------------------------------------------------------===// - -// A logical predicate. -class Pred; - -// Logical predicate wrapping a C expression. 
-class CPred : Pred { - code predCall = "(" # pred # ")"; -} - -// Kinds of combined logical predicates. These must closesly match the -// predicates implemented by the C++ backend (tblgen::PredCombinerKind). -class PredCombinerKind; -def PredCombinerAnd : PredCombinerKind; -def PredCombinerOr : PredCombinerKind; -def PredCombinerNot : PredCombinerKind; -def PredCombinerSubstLeaves : PredCombinerKind; - -// A predicate that combines other predicates as defined by PredCombinerKind. -// Instantiated below. -class CombinedPred c> : Pred { - PredCombinerKind kind = k; - list children = c; -} - -// A predicate that holds if all of its children hold. Always holds for zero -// children. -class AllOf children> : CombinedPred; - -// A predicate that holds if any of its children hold. Never holds for zero -// children. -class AnyOf children> : CombinedPred; - -// A predicate that holds if its child does not. -class Neg : CombinedPred; - -// A predicate that substitutes "pat" with "repl" in predicate calls of the -// leaves of the predicate tree (i.e., not CombinedPredicates). This is plain -// string substitution without regular expressions or captures, new predicates -// with more complex logical can be introduced should the need arise. -class SubstLeaves - : CombinedPred { - string pattern = pat; - string replacement = repl; -} - -//===----------------------------------------------------------------------===// -// Type predicates. ({0} is replaced by an instance of mlir::Type) -//===----------------------------------------------------------------------===// - -// Whether a type is a VectorType. -def IsVectorTypePred : CPred<"{0}.isa()">; - -// Whether a type is a TensorType. -def IsTensorTypePred : CPred<"{0}.isa()">; - -// For a TensorType, verify that it is a statically shaped tensor. -def IsStaticShapeTensorTypePred : - CPred<"{0}.cast().hasStaticShape()">; - -//===----------------------------------------------------------------------===// -// Type constraints and types. -//===----------------------------------------------------------------------===// - -// A constraint on types. This can be used to check the validity of -// instruction arguments. -class TypeConstraint { - // The predicates that this type satisfies. - // Format: {0} will be expanded to the type. - Pred predicate = condition; - // User-readable description used, e.g., for error reporting. If empty, - // a generic message will be used instead. - string description = descr; -} - -// A type, carries type constraints. -class Type - : TypeConstraint; - -// A variadic type. It expands to zero or more of the base type. -// This class is used for supporting variadic operands/results. An op can -// declare no more than one variadic operand/result, and that operand/result -// must be the last one in the operand/result list. -class Variadic - // TODO: support variadic type conditions - : Type, descr> { - Type baseType = type; -} - -// A type that can be constructed using MLIR::Builder. -// Note that this does not "inherit" from Type because it would require -// duplicating Type subclasses for buildable and non-buildable cases to avoid -// diamond "inheritance". -// TODO(zinenko): we may extend this to a more general 'Buildable' trait, -// making some Types and some Attrs buildable. -class BuildableType { - // The builder call to invoke (if specified) to construct the BuildableType. - // Format: this will be affixed to the builder. - code builderCall = builder; -} - -// Integer types. 
-class IntegerBase : Type; - -// Any integer type irrespective of its width. -def Integer : IntegerBase()">, "integer">; - -// Index type. -def Index : IntegerBase()">, "index">; - -// Integer type of a specific width. -class I - : IntegerBase, - width # "-bit integer">, - BuildableType<"getIntegerType(" # width # ")"> { - int bitwidth = width; -} -def I1 : I<1>; -def I8 : I<8>; -def I16 : I<16>; -def I32 : I<32>; -def I64 : I<64>; - -// Floating point types. -class FloatBase : Type; - -// Any float type irrespective of its width. -def Float : FloatBase()">, "floating-point">; - -// Float type of a specific width. -class F - : FloatBase, - width # "-bit float">, - BuildableType<"getF" # width # "Type()"> { - int bitwidth = width; -} - -def F16 : F<16>; -def F32 : F<32>; -def F64 : F<64>; - -// A container type is a type that has another type embedded within it. -class ContainerType : - // First, check the container predicate. Then, substitute the extracted - // element into the element type checker. - Type(elementTypeCall), - etype.predicate>]>, - descr # " of " # etype.description # " values"> { - // The type of elements in the container. - Type elementType = etype; - - // Call to retrieve. - code getElementTypeCall = elementTypeCall; -} - -// Vector types. -class TypedVector : ContainerType().getElementType()", "vector">; - -class Vector dims> : ContainerType().getShape() == ArrayRef{{" # - !foldl("", dims, sum, element, sum # - !if(!empty(sum), "", ",") # !cast(element)) # "}">]>, - "{0}.cast().getElementType()", - "vector"> { - list dimensions = dims; -} - -// Tensor type. - -// This represents a generic tensor without constraints on elemental type, -// rank, size. As there is no constraint on elemental type, derive from Type -// directly instead of ContainerType. -def Tensor : Type; - -// A tensor with static shape but no other constraints. Note: as -// Tensor is a def this doesn't derive from it, but reuses the predicate -// that must hold for it to be a tensor. -def StaticShapeTensor - : Type, - "statically shaped tensor">; - -// For typed tensors. -class TypedTensor - : ContainerType().getElementType()", - "tensor">; - -def F32Tensor : TypedTensor; - -// Type constraint for integer-like types: integers, indices, vectors of -// integers, tensors of integers. -def IntegerLike : TypeConstraint.predicate, TypedTensor.predicate]>, - "integer-like">; - -// Type constraint for float-like types: floats, vectors or tensors thereof. -def FloatLike : TypeConstraint.predicate, TypedTensor.predicate]>, - "floating-point-like">; - -//===----------------------------------------------------------------------===// -// Attributes -//===----------------------------------------------------------------------===// - -// A constraint on attributes. This can be used to check the validity of -// instruction attributes. -class AttrConstraint { - // The predicates that this attribute satisfies. - // Format: {0} will be expanded to the attribute. - Pred predicate = condition; - // User-readable description used, e.g., for error reporting. - // If empty, a generic message will be used instead. - string description = descr; -} - -// Base class for all attributes. -class Attr : - AttrConstraint { - code storageType = ?; // The backing mlir::Attribute type - code returnType = ?; // The underlying C++ value type - - // Define converter method to convert from the storage type to the return - // type. For example, an enum can be stored as an int but returned as an - // enum class. 
- // - // Format: {0} will be expanded to the attribute. So - // '{0}.getValue().convertToFloat()' for 'FloatAttr val' will expand to - // 'getAttrOfType("val").getValue().convertToFloat()'. - code convertFromStorage = "{0}.getValue()"; - - // The call expression that builds an attribute from a constant value. - // - // Format: {0} will be expanded to an instance of mlir::Builder, {1} will be - // expanded to the constant value of the attribute. For example, - // '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to - // 'builder.getStringAttr("foo")'. - code constBuilderCall = ?; - - // Default value for attribute. - // Requires a constBuilderCall defined. - string defaultValue = ?; - - // Whether the attribute is optional. Typically requires a custom - // convertFromStorage method to handle the case where the attribute is - // not present. - bit isOptional = 0b0; -} - -// Decorates an attribute to have an (unvalidated) default value if not present. -class DefaultValuedAttr : - Attr { - // Construct this attribute with the input attribute and change only - // the default value. - // Note: this has to be kept up to date with Attr above. - let storageType = attr.storageType; - let returnType = attr.returnType; - let convertFromStorage = attr.convertFromStorage; - let constBuilderCall = attr.constBuilderCall; - let defaultValue = val; -} - -// Decorates an attribute as optional. The return type of the generated -// attribute accessor method will be Optional<>. -class OptionalAttr : - Attr { - // Rewrite the attribute to be optional. - // Note: this has to be kept up to date with Attr above. - let storageType = attr.storageType; - let returnType = "Optional<" # attr.returnType #">"; - let convertFromStorage = "{0} ? " # returnType # "({0}.getValue())" # - " : (llvm::None)"; - let isOptional = 0b1; -} - -// A generic attribute that must be constructed around a specific type. -// Backed by a C++ class "attrName". -class TypeBasedAttr : - Attr, descr> { - let constBuilderCall = - "{0}.get" # attrName # "({0}." # t.builderCall # ", {1})"; - let storageType = attrName; -} - -// An attribute backed by a string type. -class StringBasedAttr : Attr, descr> { - let constBuilderCall = [{ {0}.getStringAttr("{1}") }]; - let storageType = [{ StringAttr }]; - let returnType = [{ StringRef }]; -} - -// Base class for instantiating float attributes of fixed width. -class FloatAttrBase : - TypeBasedAttr { - let returnType = [{ APFloat }]; -} - -// Base class for instantiating integer attributes of fixed width. -class IntegerAttrBase : - TypeBasedAttr; - -def BoolAttr : Attr, "bool"> { - let storageType = [{ BoolAttr }]; - let returnType = [{ bool }]; - let constBuilderCall = [{ {0}.getBoolAttr({1}) }]; -} -def ArrayAttr : Attr, "array"> { - let storageType = [{ ArrayAttr }]; - let returnType = [{ ArrayAttr }]; - code convertFromStorage = "{0}"; -} -class ElementsAttrBase : - Attr { - let storageType = [{ ElementsAttr }]; - let returnType = [{ ElementsAttr }]; - let convertFromStorage = "{0}"; -} -def ElementsAttr: ElementsAttrBase, "constant vector/tensor">; -def F32Attr : FloatAttrBase; -def F64Attr : FloatAttrBase; -def I32Attr : IntegerAttrBase { - let storageType = [{ IntegerAttr }]; - let returnType = [{ int }]; - let convertFromStorage = [{ {0}.getValue().getSExtValue() }]; -} -def StrAttr : StringBasedAttr<"string">; - -// DerivedAttr are attributes whose value is computed from properties -// of the operation. They do not require additional storage and are -// materialized as needed. 
-class DerivedAttr : Attr, "derived"> { - let returnType = ret; - code body = b; -} - -// Derived attribute that returns a mlir::Type. -class DerivedTypeAttr : DerivedAttr<"Type", body>; - -// Represents a constant attribute of specific Attr type. A constant -// attribute can be specified only of attributes that have a constant -// builder call defined. The constant value is specified as a string. -// -// If used as a constraint, it generates a matcher on a constant attribute by -// using the constant value builder of the attribute and the value. -class ConstantAttr : AttrConstraint< - CPred<"{0} == " # - !subst("{0}", "mlir::Builder(ctx)", !subst("{1}", val, - !cast(attribute.constBuilderCall)))>, - "constant attribute " # val> { - Attr attr = attribute; - string value = val; -} - -class ConstF32Attr : ConstantAttr; - -//===----------------------------------------------------------------------===// -// Op Traits -//===----------------------------------------------------------------------===// - -// OpTrait represents a trait regarding an op. -class OpTrait; - -// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The -// purpose to wrap around C++ symbol string with this class is to make -// traits specified for ops in TableGen less alien and more -// integrated. -class NativeOpTrait : OpTrait { - string trait = prop; -} - -// Specify a trait by way of a predicate on the operation. -class PredOpTrait : OpTrait { - string desc = d; - Pred pred = p; -} - -// op supports operand broadcast behavior -def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">; -// X op Y == Y op X -def Commutative : NativeOpTrait<"IsCommutative">; -// op has no side effect -def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; -// op has the same operand and result type -def SameValueType : NativeOpTrait<"SameOperandsAndResultType">; -// op is a terminator -def Terminator : NativeOpTrait<"IsTerminator">; - -//===----------------------------------------------------------------------===// -// Ops -//===----------------------------------------------------------------------===// - -// Marker used to identify the argument list for an op. -def ins; - -// Marker used to identify the result list for an op. -def outs; - -// Base class for all ops. -class Op props = []> { - // The mnemonic of the op. - string opName = mnemonic; - - // One-line human-readable description of what the op does. - string summary = ""; - - // Additional, longer human-readable description of what the op does. - string description = ""; - - // Dag containting the arguments of the op. Default to 0 arguments. Operands - // to the op need to precede attributes to ops in the argument specification. - dag arguments = (ins); - - // The list of results of the op. Default to 0 results. - dag results = (outs); - - // Attribute getters can be added to the op by adding an Attr member - // with the name and type of the attribute. E.g., adding int attribute - // with name "value" and type "i32": - // I32Attr value; - - // Define the hooks used for building, parsing, printing, verification. - - // Custom builder. - // If a derived class/def does not override this, then two default builders - // are generated, with the following signatures: - // - // static void build(Builder* builder, OperationState* result, - // Type resultType0, Type resultType1, ..., - // Value arg0, Value arg1, ..., - // Attribute , Attribute , ...); - // - // * where the attributes follow the same declaration order as in the op. 
- // - // static void build(Builder* builder, OperationState* result, - // ArrayRef resultTypes, - // ArrayRef args, - // ArrayRef attributes); - code builder = ?; - - // Custom parser. - code parser = ?; - - // Custom printer. - code printer = ?; - - // Custom verifier. - code verifier = ?; - - // Whether this op has associated canonicalization patterns. - // TODO(b/120163349): figure out a better way to write canonicalization - // patterns in TableGen rules directly instead of using this marker - // and C++ implementations. - bit hasCanonicalizer = 0b0; - - // Whether this op has a constant folder. - bit hasConstantFolder = 0b0; - - // Whether this op has a folder. - bit hasFolder = 0b0; - - // Op traits. - list traits = props; -} - -// The arguments of an op. -class Arguments { - dag arguments = args; -} - -// The results of an op. -class Results { - dag results = rets; -} - -//===----------------------------------------------------------------------===// -// Patterns -//===----------------------------------------------------------------------===// - -// Base class for op+ -> op+ rewrite patterns. These allow declaratively -// specifying rewrite patterns. -class Pattern results, list preds> { - dag patternToMatch = source; - list resultOps = results; - list constraints = preds; -} - -// Form of a pattern which produces a single result. -class Pat preds = []> : - Pattern; - -// Attribute matcher. This is the base class to specify a predicate -// that has to match. Used on the input attributes of a rewrite rule. -class mAttr : AttrConstraint; - -// Combine a list of attribute matchers into an attribute matcher that holds if -// any of the original matchers does. -class mAttrAnyOf attrs> : - mAttr, attrs, prev, attr, - !listconcat(prev, [attr.predicate]))>>; - -// Attribute transforms. This is the base class to specify a -// transformation of a matched attribute. Used on the output of a rewrite -// rule. -class tAttr { - // Code to transform the attribute. - // Format: {0} represents the attribute. - code attrTransform = transform; -} - -// Native code op creation method. This allows performing an arbitrary op -// creation/replacement by invoking a C++ function with the operands and -// attributes. The function specified needs to have the signature: -// -// void f(OperationInst *op, ArrayRef operands, -// ArrayRef attrs, PatternRewriter &rewriter); -// -// The operands and attributes are passed to this function in the order of -// the DAG specified. It is the responsibility of this function to replace the -// matched op(s) using the rewriter. This is intended for the long tail op -// creation and replacement. -class cOp { - // Function to invoke with the given arguments to construct a new op. The - // operands will be passed to the function first followed by the attributes - // (as in the function signature above and required by Op arguments). - string function = f; -} - -// Pattern matching predicate specification to constrain when a pattern may be -// used. For example, -// def : Pat<(... $l, (... $r)), (...), [(mPat<"foo"> $l, $r)]; -// will result in this pattern being considered only if `foo(l, r)` holds where -// `foo` is a C++ function and `l` and `r` are the C++ bound variables of -// $l and $r. -class mPat { - string function = f; -} - -// Marker used to indicate that no new result op are generated by applying the -// rewrite pattern, so to replace the matched DAG with an existing SSA value. 
-def replaceWithValue; - -#endif // OP_BASE diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index 6c171659796..750bb6f67e7 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -67,7 +67,7 @@ public: ///// Ops ///// #define GET_OP_CLASSES -#include "mlir/LLVMIR/llvm_ops.inc" +#include "mlir/LLVMIR/LLVMOps.inc" namespace LLVM { class LLVMDialect : public Dialect { diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td new file mode 100644 index 00000000000..41cfea95447 --- /dev/null +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -0,0 +1,207 @@ +//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the LLVM IR operation definition file. +// +//===----------------------------------------------------------------------===// + +#ifdef LLVMIR_OPS +#else +#define LLVMIR_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +// LLVM IR type wrapped in MLIR. +def LLVM_Type : Type()">, + "LLVM dialect type">; + +// Base class for LLVM operations. All operations get an "llvm." prefix in +// their name automatically. LLVM operations have either zero or one result, +// this class is specialized below for both cases and should not be used +// directly. +class LLVM_Op traits = []> : + Op; + +def LLVM_OneResultOpBuilder { + code builder = [{ + static void build(Builder *, OperationState *result, + Type resultType, ArrayRef operands, + ArrayRef attributes = {}) { + if (resultType) result->addTypes(resultType); + result->addOperands(operands); + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + } + }]; +} + +def LLVM_ZeroResultOpBuilder { + code builder = [{ + static void build(Builder *, OperationState *result, + ArrayRef operands, + ArrayRef attributes = {}) { + result->addOperands(operands); + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + } + }]; +} + +class LLVM_TwoBuilders { + code builder = !cast(!strconcat(!cast(b1), !cast(b2))); +} + +// Base class for LLVM operations with one result. +class LLVM_OneResultOp traits = []> : + LLVM_Op, Results<(outs LLVM_Type)> { + let builder = LLVM_OneResultOpBuilder.builder; +} + +// Base class for LLVM operations with zero results. +class LLVM_ZeroResultOp traits = []> : + LLVM_Op, Results<(outs)> { + let builder = LLVM_TwoBuilders<[{ + // Compatibility builder that takes an instance of wrapped llvm::VoidType + // to indicate no result. 
+ static void build(Builder *builder, OperationState *result, Type resultType, + ArrayRef operands, + ArrayRef attributes = {}) { + auto llvmType = resultType.dyn_cast(); + assert(llvmType && "result must be an LLVM type"); + assert(llvmType.getUnderlyingType() && + llvmType.getUnderlyingType()->isVoidTy() && + "for zero-result operands, only 'void' is accepted as result type"); + build(builder, result, operands, attributes); + } + }], + LLVM_ZeroResultOpBuilder.builder>.builder; +} + +// Base class for LLVM terminator operations. All terminator operations have +// zero results and an optional list of successors. +class LLVM_TerminatorOp traits = []> : + LLVM_Op, + Arguments<(ins Variadic)>, Results<(outs)> { + let builder = [{ + static void build(Builder *builder, OperationState *result, + ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands = {}, + ArrayRef attributes = {}) { + (void) builder; + result->addOperands(properOperands); + for (auto kvp : llvm::zip(destinations, operands)) { + result->addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); + } + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + } + }]; +} + +// Class for arithmetic binary instructions. +class LLVM_ArithmeticOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>; + +// Class for variadic instructions. +class LLVM_VariadicOneResultOp traits = []> : + LLVM_OneResultOp, Arguments<(ins Variadic)>; + +// Integer binary instructions. +def LLVM_AddOp : LLVM_ArithmeticOp<"add", [Commutative]>; +def LLVM_SubOp : LLVM_ArithmeticOp<"sub">; +def LLVM_MulOp : LLVM_ArithmeticOp<"mul", [Commutative]>; +def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv">; +def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv">; +def LLVM_URemOp : LLVM_ArithmeticOp<"urem">; +def LLVM_SRemOp : LLVM_ArithmeticOp<"srem">; + +// Other integer instructions. +def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>; + +// Floating point binary instructions. +def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd">; +def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub">; +def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul">; +def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv">; +def LLVM_FRemOp : LLVM_ArithmeticOp<"frem">; + +// Memory-related instructions. +def LLVM_AllocaOp : LLVM_OneResultOp<"alloca">, + Arguments<(ins LLVM_Type:$arraySize)>; +def LLVM_GEPOp : LLVM_VariadicOneResultOp<"getelementptr", [NoSideEffect]>; +def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>; +def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, + Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>; +def LLVM_BitcastOp : LLVM_OneResultOp<"bitcast", [NoSideEffect]>, + Arguments<(ins LLVM_Type)>; + + +// Call-related instructions. +def LLVM_CallOp : LLVM_Op<"call">, Arguments<(ins Variadic)>, + Results<(outs Variadic)> { + let builder = LLVM_TwoBuilders< + LLVM_OneResultOpBuilder.builder, + LLVM_ZeroResultOpBuilder.builder + >.builder; + + let verifier = [{ + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); + return false; + }]; +} +def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type)>; +def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +// Misc instructions. +def LLVM_SelectOp : LLVM_OneResultOp<"select", [NoSideEffect]>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +// Terminators. 
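A short illustrative aside, not part of the patch (both defs below are hypothetical additions, not ops this change introduces): with the helper classes above, further LLVM instructions are typically declared in one line each, for example a shift-left arithmetic op or an unreachable-style terminator.

// Hypothetical one-line additions using the helper classes defined above.
def LLVM_ShlOp : LLVM_ArithmeticOp<"shl">;
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", [NoSideEffect]>;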
+def LLVM_BrOp : LLVM_TerminatorOp<"br", [NoSideEffect]>; +def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", [NoSideEffect]> { + let verifier = [{ + if (getNumSuccessors() != 2) + return emitOpError("expected exactly two successors"); + if (getSuccessor(0) == getSuccessor(1) && + getNumSuccessorOperands(0) != 0) + return emitOpError( + "expected successors with arguments to be different blocks"); + return false; + }]; +} +def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]>; + +// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to +// work correctly). +def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>; +def LLVM_ConstantOp : LLVM_OneResultOp<"constant", [NoSideEffect]>, + Arguments<(ins)>; + +#endif // LLVMIR_OPS diff --git a/mlir/include/mlir/LLVMIR/llvm_ops.td b/mlir/include/mlir/LLVMIR/llvm_ops.td deleted file mode 100644 index 798b2472172..00000000000 --- a/mlir/include/mlir/LLVMIR/llvm_ops.td +++ /dev/null @@ -1,207 +0,0 @@ -//===-- llvm_ops.td - LLVM IR dialect op definition file ---*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This is the LLVM IR operation definition file. -// -//===----------------------------------------------------------------------===// - -#ifdef LLVMIR_OPS -#else -#define LLVMIR_OPS - -#ifdef OP_BASE -#else -include "mlir/IR/op_base.td" -#endif // OP_BASE - -// LLVM IR type wrapped in MLIR. -def LLVM_Type : Type()">, - "LLVM dialect type">; - -// Base class for LLVM operations. All operations get an "llvm." prefix in -// their name automatically. LLVM operations have either zero or one result, -// this class is specialized below for both cases and should not be used -// directly. -class LLVM_Op traits = []> : - Op; - -def LLVM_OneResultOpBuilder { - code builder = [{ - static void build(Builder *, OperationState *result, - Type resultType, ArrayRef operands, - ArrayRef attributes = {}) { - if (resultType) result->addTypes(resultType); - result->addOperands(operands); - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - } - }]; -} - -def LLVM_ZeroResultOpBuilder { - code builder = [{ - static void build(Builder *, OperationState *result, - ArrayRef operands, - ArrayRef attributes = {}) { - result->addOperands(operands); - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - } - }]; -} - -class LLVM_TwoBuilders { - code builder = !cast(!strconcat(!cast(b1), !cast(b2))); -} - -// Base class for LLVM operations with one result. -class LLVM_OneResultOp traits = []> : - LLVM_Op, Results<(outs LLVM_Type)> { - let builder = LLVM_OneResultOpBuilder.builder; -} - -// Base class for LLVM operations with zero results. 
-class LLVM_ZeroResultOp traits = []> : - LLVM_Op, Results<(outs)> { - let builder = LLVM_TwoBuilders<[{ - // Compatibility builder that takes an instance of wrapped llvm::VoidType - // to indicate no result. - static void build(Builder *builder, OperationState *result, Type resultType, - ArrayRef operands, - ArrayRef attributes = {}) { - auto llvmType = resultType.dyn_cast(); - assert(llvmType && "result must be an LLVM type"); - assert(llvmType.getUnderlyingType() && - llvmType.getUnderlyingType()->isVoidTy() && - "for zero-result operands, only 'void' is accepted as result type"); - build(builder, result, operands, attributes); - } - }], - LLVM_ZeroResultOpBuilder.builder>.builder; -} - -// Base class for LLVM terminator operations. All terminator operations have -// zero results and an optional list of successors. -class LLVM_TerminatorOp traits = []> : - LLVM_Op, - Arguments<(ins Variadic)>, Results<(outs)> { - let builder = [{ - static void build(Builder *builder, OperationState *result, - ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands = {}, - ArrayRef attributes = {}) { - (void) builder; - result->addOperands(properOperands); - for (auto kvp : llvm::zip(destinations, operands)) { - result->addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); - } - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - } - }]; -} - -// Class for arithmetic binary instructions. -class LLVM_ArithmeticOp traits = []> : - LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>; - -// Class for variadic instructions. -class LLVM_VariadicOneResultOp traits = []> : - LLVM_OneResultOp, Arguments<(ins Variadic)>; - -// Integer binary instructions. -def LLVM_AddOp : LLVM_ArithmeticOp<"add", [Commutative]>; -def LLVM_SubOp : LLVM_ArithmeticOp<"sub">; -def LLVM_MulOp : LLVM_ArithmeticOp<"mul", [Commutative]>; -def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv">; -def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv">; -def LLVM_URemOp : LLVM_ArithmeticOp<"urem">; -def LLVM_SRemOp : LLVM_ArithmeticOp<"srem">; - -// Other integer instructions. -def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>; - -// Floating point binary instructions. -def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd">; -def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub">; -def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul">; -def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv">; -def LLVM_FRemOp : LLVM_ArithmeticOp<"frem">; - -// Memory-related instructions. -def LLVM_AllocaOp : LLVM_OneResultOp<"alloca">, - Arguments<(ins LLVM_Type:$arraySize)>; -def LLVM_GEPOp : LLVM_VariadicOneResultOp<"getelementptr", [NoSideEffect]>; -def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>; -def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, - Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>; -def LLVM_BitcastOp : LLVM_OneResultOp<"bitcast", [NoSideEffect]>, - Arguments<(ins LLVM_Type)>; - - -// Call-related instructions. 
-def LLVM_CallOp : LLVM_Op<"call">, Arguments<(ins Variadic)>, - Results<(outs Variadic)> { - let builder = LLVM_TwoBuilders< - LLVM_OneResultOpBuilder.builder, - LLVM_ZeroResultOpBuilder.builder - >.builder; - - let verifier = [{ - if (getNumResults() > 1) - return emitOpError("must have 0 or 1 result"); - return false; - }]; -} -def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type)>; -def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type, LLVM_Type)>; - -// Misc instructions. -def LLVM_SelectOp : LLVM_OneResultOp<"select", [NoSideEffect]>, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -// Terminators. -def LLVM_BrOp : LLVM_TerminatorOp<"br", [NoSideEffect]>; -def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", [NoSideEffect]> { - let verifier = [{ - if (getNumSuccessors() != 2) - return emitOpError("expected exactly two successors"); - if (getSuccessor(0) == getSuccessor(1) && - getNumSuccessorOperands(0) != 0) - return emitOpError( - "expected successors with arguments to be different blocks"); - return false; - }]; -} -def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]>; - -// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to -// work correctly). -def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>; -def LLVM_ConstantOp : LLVM_OneResultOp<"constant", [NoSideEffect]>, - Arguments<(ins)>; - -#endif // LLVMIR_OPS diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h new file mode 100644 index 00000000000..7cca20cf039 --- /dev/null +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -0,0 +1,754 @@ +//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with standard operations +// in the MLIR instruction set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_STANDARDOPS_OPS_H +#define MLIR_STANDARDOPS_OPS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +class AffineMap; +class Builder; + +class StandardOpsDialect : public Dialect { +public: + StandardOpsDialect(MLIRContext *context); +}; + +#define GET_OP_CLASSES +#include "mlir/StandardOps/Ops.inc" + +/// The "alloc" operation allocates a region of memory, as specified by its +/// memref type. For example: +/// +/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> +/// +/// The optional list of dimension operands are bound to the dynamic dimensions +/// specified in its memref type. In the example below, the ssa value '%d' is +/// bound to the second dimension of the memref (which is dynamic). 
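A minimal sketch of how the build hook declared below pairs a dynamic dimension with its operand; the helper and value names are hypothetical and the operand list is assumed to be an ArrayRef of Value pointers. It corresponds to the alloc form shown next:

#include "mlir/StandardOps/Ops.h"

using namespace mlir;

// Hypothetical sketch: 'dynamicMemRefType' is assumed to have exactly one
// dynamic dimension (e.g. memref<8x?xf32, (d0, d1) -> (d0, d1), 1>), so a
// single dimension operand 'dynamicSize' (an index value) is supplied.
static void buildDynamicAlloc(Builder *builder, OperationState *result,
                              MemRefType dynamicMemRefType,
                              Value *dynamicSize) {
  AllocOp::build(builder, result, dynamicMemRefType,
                 /*operands=*/{dynamicSize});
}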
+/// +/// %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> +/// +/// The optional list of symbol operands are bound to the symbols of the +/// memrefs affine map. In the example below, the ssa value '%s' is bound to +/// the symbol 's0' in the affine map specified in the allocs memref type. +/// +/// %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> +/// +/// This operation returns a single ssa value of memref type, which can be used +/// by subsequent load and store operations. +class AllocOp + : public Op { +public: + /// The result of an alloc is always a MemRefType. + MemRefType getType() const { + return getResult()->getType().cast(); + } + + static StringRef getOperationName() { return "alloc"; } + + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + MemRefType memrefType, ArrayRef operands = {}); + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + +private: + friend class Instruction; + explicit AllocOp(const Instruction *state) : Op(state) {} +}; + +/// The "call" operation represents a direct call to a function. The operands +/// and result types of the call must match the specified function type. The +/// callee is encoded as a function attribute named "callee". +/// +/// %31 = call @my_add(%0, %1) +/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +class CallOp + : public Op { +public: + static StringRef getOperationName() { return "call"; } + + static void build(Builder *builder, OperationState *result, Function *callee, + ArrayRef operands); + + Function *getCallee() const { + return getAttrOfType("callee").getValue(); + } + + /// Get the argument operands to the called function. + llvm::iterator_range getArgOperands() const { + return {arg_operand_begin(), arg_operand_end()}; + } + llvm::iterator_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + const_operand_iterator arg_operand_begin() const { return operand_begin(); } + const_operand_iterator arg_operand_end() const { return operand_end(); } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + +protected: + friend class Instruction; + explicit CallOp(const Instruction *state) : Op(state) {} +}; + +/// The "call_indirect" operation represents an indirect call to a value of +/// function type. Functions are first class types in MLIR, and may be passed +/// as arguments and merged together with block arguments. The operands +/// and result types of the call must match the specified function type. +/// +/// %31 = call_indirect %15(%0, %1) +/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +/// +class CallIndirectOp : public Op { +public: + static StringRef getOperationName() { return "call_indirect"; } + + static void build(Builder *builder, OperationState *result, Value *callee, + ArrayRef operands); + + const Value *getCallee() const { return getOperand(0); } + Value *getCallee() { return getOperand(0); } + + /// Get the argument operands to the called function. 
+ llvm::iterator_range getArgOperands() const { + return {arg_operand_begin(), arg_operand_end()}; + } + llvm::iterator_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + const_operand_iterator arg_operand_begin() const { return ++operand_begin(); } + const_operand_iterator arg_operand_end() const { return operand_end(); } + + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + +protected: + friend class Instruction; + explicit CallIndirectOp(const Instruction *state) : Op(state) {} +}; + +/// The predicate indicates the type of the comparison to perform: +/// (in)equality; (un)signed less/greater than (or equal to). +enum class CmpIPredicate { + FirstValidValue, + // (In)equality comparisons. + EQ = FirstValidValue, + NE, + // Signed comparisons. + SLT, + SLE, + SGT, + SGE, + // Unsigned comparisons. + ULT, + ULE, + UGT, + UGE, + // Number of predicates. + NumPredicates +}; + +/// The "cmpi" operation compares its two operands according to the integer +/// comparison rules and the predicate specified by the respective attribute. +/// The predicate defines the type of comparison: (in)equality, (un)signed +/// less/greater than (or equal to). The operands must have the same type, and +/// this type must be an integer type, a vector or a tensor thereof. The result +/// is an i1, or a vector/tensor thereof having the same shape as the inputs. +/// Since integers are signless, the predicate also explicitly indicates +/// whether to interpret the operands as signed or unsigned integers for +/// less/greater than comparisons. For the sake of readability by humans, +/// custom assembly form for the instruction uses a string-typed attribute for +/// the predicate. The value of this attribute corresponds to lower-cased name +/// of the predicate constant, e.g., "slt" means "signed less than". The string +/// representation of the attribute is merely a syntactic sugar and is converted +/// to an integer attribute by the parser. +/// +/// %r1 = cmpi "eq" %0, %1 : i32 +/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> +/// %r3 = "cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 +class CmpIOp + : public Op::Impl, + OpTrait::OneResult, OpTrait::ResultsAreBoolLike, + OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { +public: + CmpIPredicate getPredicate() const { + return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + + static StringRef getOperationName() { return "cmpi"; } + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpIPredicate getPredicateByName(StringRef name); + + static void build(Builder *builder, OperationState *result, CmpIPredicate, + Value *lhs, Value *rhs); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + Attribute constantFold(ArrayRef operands, + MLIRContext *context) const; + +private: + friend class Instruction; + explicit CmpIOp(const Instruction *state) : Op(state) {} +}; + +/// The "dealloc" operation frees the region of memory referenced by a memref +/// which was originally created by the "alloc" operation. 
+/// The "dealloc" operation should not be called on memrefs which alias an
+/// alloc'd memref (i.e. memrefs returned by the "view" and "reshape"
+/// operations).
+///
+/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+///
+/// dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+///
+class DeallocOp
+    : public Op {
+public:
+  Value *getMemRef() { return getOperand(); }
+  const Value *getMemRef() const { return getOperand(); }
+  void setMemRef(Value *value) { setOperand(value); }
+
+  static StringRef getOperationName() { return "dealloc"; }
+
+  // Hooks to customize behavior of this op.
+  static void build(Builder *builder, OperationState *result, Value *memref);
+  bool verify() const;
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+
+private:
+  friend class Instruction;
+  explicit DeallocOp(const Instruction *state) : Op(state) {}
+};
+
+/// The "dim" operation takes a memref or tensor operand and returns an
+/// "index". It requires a single integer attribute named "index". It
+/// returns the size of the specified dimension. For example:
+///
+/// %1 = dim %0, 2 : tensor
+///
+class DimOp : public Op {
+public:
+  static void build(Builder *builder, OperationState *result,
+                    Value *memrefOrTensor, unsigned index);
+
+  Attribute constantFold(ArrayRef operands,
+                         MLIRContext *context) const;
+
+  /// This returns the dimension number that the 'dim' is inspecting.
+  unsigned getIndex() const {
+    return getAttrOfType("index").getValue().getZExtValue();
+  }
+
+  static StringRef getOperationName() { return "dim"; }
+
+  // Hooks to customize behavior of this op.
+  bool verify() const;
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+
+private:
+  friend class Instruction;
+  explicit DimOp(const Instruction *state) : Op(state) {}
+};
+
+// DmaStartOp starts a non-blocking DMA operation that transfers data from a
+// source memref to a destination memref. The source and destination memrefs
+// need not be of the same dimensionality, but need to have the same elemental
+// type. The operands include the source and destination memrefs, each followed
+// by its indices, the size of the data transfer in terms of the number of
+// elements (of the elemental type of the memref), a tag memref with its
+// indices, and optionally, at the end, a stride and a
+// number_of_elements_per_stride argument. The tag location is used by a
+// DmaWaitOp to check for completion. The indices of the source memref,
+// destination memref, and the tag memref have the same restrictions as any
+// load/store. The optional stride arguments should be of 'index' type, and
+// specify a stride for the slower memory space (memory space with a lower
+// memory space id), transferring chunks of number_of_elements_per_stride
+// every stride until %num_elements are transferred. Either both or no stride
+// arguments should be specified.
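To make the operand layout described above concrete, here is a minimal sketch that reads it back through the accessors declared in the class below; the helper name is hypothetical, and a 2-d source, 2-d destination and 1-d tag memref are assumed:

#include "mlir/StandardOps/Ops.h"

using namespace mlir;

// Operand layout implied by the description above (for a 2-d src, 2-d dst and
// 1-d tag): 0: %src, 1-2: src indices, 3: %dst, 4-5: dst indices,
// 6: %num_elements, 7: %tag, 8: tag index, and, only in the strided form,
// 9: %stride and 10: %num_elt_per_stride.
static void inspectDmaStart(const DmaStartOp &dma) {
  unsigned srcRank = dma.getSrcMemRefRank();       // 2 in the layout above
  unsigned dstRank = dma.getDstMemRefRank();       // 2 in the layout above
  const Value *numElements = dma.getNumElements(); // operand 6 above
  const Value *tagMemRef = dma.getTagMemRef();     // operand 7 above
  // The stride operands, when present, are always the last two operands; the
  // accessors return nullptr in the non-strided form.
  const Value *stride = dma.isStrided() ? dma.getStride() : nullptr;
  const Value *numEltPerStride =
      dma.isStrided() ? dma.getNumElementsPerStride() : nullptr;
  (void)srcRank;
  (void)dstRank;
  (void)numElements;
  (void)tagMemRef;
  (void)stride;
  (void)numEltPerStride;
}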
+//
+// For example, a DmaStartOp operation that transfers 256 elements of a memref
+// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
+// 1 at indices [%k, %l] would be specified as follows:
+//
+// %num_elements = constant 256
+// %idx = constant 0 : index
+// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
+// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
+// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
+// memref<1 x i32>, (d0) -> (d0), 2>
+//
+// If %stride and %num_elt_per_stride are specified, the DMA is expected to
+// transfer %num_elt_per_stride elements every %stride elements apart from
+// memory space 0 until %num_elements are transferred.
+//
+// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+// %num_elt_per_stride :
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels.
+// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
+class DmaStartOp
+    : public Op {
+public:
+  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
+                    ArrayRef srcIndices, Value *destMemRef,
+                    ArrayRef destIndices, Value *numElements,
+                    Value *tagMemRef, ArrayRef tagIndices,
+                    Value *stride = nullptr,
+                    Value *elementsPerStride = nullptr);
+
+  // Returns the source MemRefType for this DMA operation.
+  const Value *getSrcMemRef() const { return getOperand(0); }
+  // Returns the rank (number of indices) of the source MemRefType.
+  unsigned getSrcMemRefRank() const {
+    return getSrcMemRef()->getType().cast().getRank();
+  }
+  // Returns the source memref indices for this DMA operation.
+  llvm::iterator_range
+  getSrcIndices() const {
+    return {getInstruction()->operand_begin() + 1,
+            getInstruction()->operand_begin() + 1 + getSrcMemRefRank()};
+  }
+
+  // Returns the destination MemRefType for this DMA operation.
+  const Value *getDstMemRef() const {
+    return getOperand(1 + getSrcMemRefRank());
+  }
+  // Returns the rank (number of indices) of the destination MemRefType.
+  unsigned getDstMemRefRank() const {
+    return getDstMemRef()->getType().cast().getRank();
+  }
+  unsigned getSrcMemorySpace() const {
+    return getSrcMemRef()->getType().cast().getMemorySpace();
+  }
+  unsigned getDstMemorySpace() const {
+    return getDstMemRef()->getType().cast().getMemorySpace();
+  }
+
+  // Returns the destination memref indices for this DMA operation.
+  llvm::iterator_range
+  getDstIndices() const {
+    return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1,
+            getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+                getDstMemRefRank()};
+  }
+
+  // Returns the number of elements being transferred by this DMA operation.
+  const Value *getNumElements() const {
+    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
+  }
+
+  // Returns the Tag MemRef for this DMA operation.
+  const Value *getTagMemRef() const {
+    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+  }
+  // Returns the rank (number of indices) of the tag MemRefType.
+  unsigned getTagMemRefRank() const {
+    return getTagMemRef()->getType().cast().getRank();
+  }
+
+  // Returns the tag memref index for this DMA operation.
+ llvm::iterator_range + getTagIndices() const { + unsigned tagIndexStartPos = + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; + return {getInstruction()->operand_begin() + tagIndexStartPos, + getInstruction()->operand_begin() + tagIndexStartPos + + getTagMemRefRank()}; + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. + bool isDestMemorySpaceFaster() const { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isSrcMemorySpaceFaster() const { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() const { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; + } + + static StringRef getOperationName() { return "dma_start"; } + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + + bool isStrided() const { + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + + 1 + 1 + getTagMemRefRank(); + } + + Value *getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + const Value *getStride() const { + return const_cast(this)->getStride(); + } + + Value *getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } + const Value *getNumElementsPerStride() const { + return const_cast(this)->getNumElementsPerStride(); + } + +protected: + friend class Instruction; + explicit DmaStartOp(const Instruction *state) : Op(state) {} +}; + +// DmaWaitOp blocks until the completion of a DMA operation associated with the +// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index +// with the same restrictions as any load/store index. %num_elements is the +// number of elements associated with the DMA operation. For example: +// +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : +// memref<2048 x f32>, (d0) -> (d0), 0>, +// memref<256 x f32>, (d0) -> (d0), 1> +// memref<1 x i32>, (d0) -> (d0), 2> +// ... +// ... +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> +// +class DmaWaitOp + : public Op { +public: + static void build(Builder *builder, OperationState *result, Value *tagMemRef, + ArrayRef tagIndices, Value *numElements); + + static StringRef getOperationName() { return "dma_wait"; } + + // Returns the Tag MemRef associated with the DMA operation being waited on. + const Value *getTagMemRef() const { return getOperand(0); } + Value *getTagMemRef() { return getOperand(0); } + + // Returns the tag memref index for this DMA operation. + llvm::iterator_range + getTagIndices() const { + return {getInstruction()->operand_begin() + 1, + getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; + } + + // Returns the rank (number of indices) of the tag memref. 
+ unsigned getTagMemRefRank() const { + return getTagMemRef()->getType().cast().getRank(); + } + + // Returns the number of elements transferred in the associated DMA operation. + const Value *getNumElements() const { + return getOperand(1 + getTagMemRefRank()); + } + + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + +protected: + friend class Instruction; + explicit DmaWaitOp(const Instruction *state) : Op(state) {} +}; + +/// The "extract_element" op reads a tensor or vector and returns one element +/// from it specified by an index list. The output of extract is a new value +/// with the same type as the elements of the tensor or vector. The arity of +/// indices matches the rank of the accessed value (i.e., if a tensor is of rank +/// 3, then 3 indices are required for the extract). The indices should all be +/// of affine_int type. +/// +/// For example: +/// +/// %3 = extract_element %0[%1, %2] : vector<4x4xi32> +/// +class ExtractElementOp + : public Op { +public: + static void build(Builder *builder, OperationState *result, Value *aggregate, + ArrayRef indices = {}); + + Value *getAggregate() { return getOperand(0); } + const Value *getAggregate() const { return getOperand(0); } + + llvm::iterator_range getIndices() { + return {getInstruction()->operand_begin() + 1, + getInstruction()->operand_end()}; + } + + llvm::iterator_range getIndices() const { + return {getInstruction()->operand_begin() + 1, + getInstruction()->operand_end()}; + } + + static StringRef getOperationName() { return "extract_element"; } + + // Hooks to customize behavior of this op. + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + Attribute constantFold(ArrayRef operands, + MLIRContext *context) const; + +private: + friend class Instruction; + explicit ExtractElementOp(const Instruction *state) : Op(state) {} +}; + +/// The "load" op reads an element from a memref specified by an index list. The +/// output of load is a new value with the same type as the elements of the +/// memref. The arity of indices is the rank of the memref (i.e., if the memref +/// loaded from is of rank 3, then 3 indices are required for the load following +/// the memref identifier). For example: +/// +/// %3 = load %0[%1, %1] : memref<4x4xi32> +/// +class LoadOp + : public Op { +public: + // Hooks to customize behavior of this op. 
+  static void build(Builder *builder, OperationState *result, Value *memref,
+                    ArrayRef indices = {});
+
+  Value *getMemRef() { return getOperand(0); }
+  const Value *getMemRef() const { return getOperand(0); }
+  void setMemRef(Value *value) { setOperand(0, value); }
+  MemRefType getMemRefType() const {
+    return getMemRef()->getType().cast();
+  }
+
+  llvm::iterator_range getIndices() {
+    return {getInstruction()->operand_begin() + 1,
+            getInstruction()->operand_end()};
+  }
+
+  llvm::iterator_range getIndices() const {
+    return {getInstruction()->operand_begin() + 1,
+            getInstruction()->operand_end()};
+  }
+
+  static StringRef getOperationName() { return "load"; }
+
+  bool verify() const;
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
+
+private:
+  friend class Instruction;
+  explicit LoadOp(const Instruction *state) : Op(state) {}
+};
+
+/// The "memref_cast" operation converts a memref from one type to an equivalent
+/// type with a compatible shape. The source and destination types are
+/// compatible when both are memref types with the same element type, affine
+/// mappings, address space, and rank, but where the individual dimensions may
+/// add or remove constant dimensions from the memref type.
+///
+/// If the cast converts any dimensions from an unknown to a known size, then it
+/// acts as an assertion that fails at runtime if the dynamic dimensions
+/// disagree with the resultant destination size.
+///
+/// Assert that the input dynamic shape matches the destination static shape.
+/// %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
+/// Erase static shape information, replacing it with dynamic information.
+/// %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
+///
+class MemRefCastOp : public CastOp {
+public:
+  static StringRef getOperationName() { return "memref_cast"; }
+
+  /// The result of a memref_cast is always a memref.
+  MemRefType getType() const {
+    return getResult()->getType().cast();
+  }
+
+  bool verify() const;
+
+private:
+  friend class Instruction;
+  explicit MemRefCastOp(const Instruction *state) : CastOp(state) {}
+};
+
+/// The "select" operation chooses one value based on a binary condition
+/// supplied as its first operand. If the value of the first operand is 1, the
+/// second operand is chosen, otherwise the third operand is chosen. The second
+/// and the third operand must have the same type. The operation applies
+/// elementwise to vectors and tensors. The shape of all arguments must be
+/// identical. For example, the maximum operation is obtained by combining
+/// "select" with "cmpi" as follows.
+/// +/// %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 +/// %3 = select %2, %0, %1 : i32 +/// +class SelectOp : public Op::Impl, + OpTrait::OneResult, OpTrait::HasNoSideEffect> { +public: + static StringRef getOperationName() { return "select"; } + static void build(Builder *builder, OperationState *result, Value *condition, + Value *trueValue, Value *falseValue); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + + Value *getCondition() { return getOperand(0); } + const Value *getCondition() const { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + const Value *getTrueValue() const { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + const Value *getFalseValue() const { return getOperand(2); } + + Value *fold(); + +private: + friend class Instruction; + explicit SelectOp(const Instruction *state) : Op(state) {} +}; + +/// The "store" op writes an element to a memref specified by an index list. +/// The arity of indices is the rank of the memref (i.e. if the memref being +/// stored to is of rank 3, then 3 indices are required for the store following +/// the memref identifier). The store instruction does not produce a result. +/// +/// In the following example, the ssa value '%v' is stored in memref '%A' at +/// indices [%i, %j]: +/// +/// store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> +/// +class StoreOp + : public Op { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef indices = {}); + + Value *getValueToStore() { return getOperand(0); } + const Value *getValueToStore() const { return getOperand(0); } + + Value *getMemRef() { return getOperand(1); } + const Value *getMemRef() const { return getOperand(1); } + void setMemRef(Value *value) { setOperand(1, value); } + MemRefType getMemRefType() const { + return getMemRef()->getType().cast(); + } + + llvm::iterator_range getIndices() { + return {getInstruction()->operand_begin() + 2, + getInstruction()->operand_end()}; + } + + llvm::iterator_range getIndices() const { + return {getInstruction()->operand_begin() + 2, + getInstruction()->operand_end()}; + } + + static StringRef getOperationName() { return "store"; } + + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + +private: + friend class Instruction; + explicit StoreOp(const Instruction *state) : Op(state) {} +}; + +/// The "tensor_cast" operation converts a tensor from one type to an equivalent +/// type without changing any data elements. The source and destination types +/// must both be tensor types with the same element type, and the source and +/// destination types may not be the same. They must either have the same rank, +/// or one may be an unknown rank. The operation is invalid if converting to a +/// mismatching constant dimension. +/// +/// Convert from unknown rank to rank 2 with unknown dimension sizes. +/// %2 = tensor_cast %1 : tensor to tensor +/// +class TensorCastOp : public CastOp { +public: + static StringRef getOperationName() { return "tensor_cast"; } + + /// The result of a tensor_cast is always a tensor. 
+  TensorType getType() const {
+    return getResult()->getType().cast();
+  }
+
+  bool verify() const;
+
+private:
+  friend class Instruction;
+  explicit TensorCastOp(const Instruction *state) : CastOp(state) {}
+};
+
+} // end namespace mlir
+
+#endif // MLIR_STANDARDOPS_OPS_H
diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td
new file mode 100644
index 00000000000..f35103a3adf
--- /dev/null
+++ b/mlir/include/mlir/StandardOps/Ops.td
@@ -0,0 +1,135 @@
+//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines some MLIR standard operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef STANDARD_OPS
+#else
+#define STANDARD_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def AnyType : Type, "any type">;
+
+// Base class for standard arithmetic operations. Requires operands and
+// results to be of the same type, but does not constrain them to specific
+// types. Individual classes will have `lhs` and `rhs` accessors to operands.
+class ArithmeticOp traits = []> :
+    Op,
+    Results<(outs AnyType)> {
+
+  let opName = mnemonic;
+
+  let parser = [{
+    return impl::parseBinaryOp(parser, result);
+  }];
+
+  let printer = [{
+    return impl::printBinaryOp(this->getInstruction(), p);
+  }];
+}
+
+// Base class for standard arithmetic operations on integers, vectors and
+// tensors thereof. This operation takes two operands and returns one result;
+// each of these is required to be of the same type. This type may be an
+// integer scalar type, a vector whose element type is an integer type, or an
+// integer tensor. The custom assembly form of the operation is as follows
+//
+//   <op>i %0, %1 : i32
+class IntArithmeticOp traits = []> :
+    ArithmeticOp,
+    Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>;
+
+// Base class for standard arithmetic binary operations on floats, vectors and
+// tensors thereof. This operation has two operands and returns one result;
+// each of these is required to be of the same type. This type may be a
+// floating point scalar type, a vector whose element type is a floating point
+// type, or a floating point tensor. The custom assembly form of the operation
+// is as follows
+//
+//   <op>f %0, %1 : f32
+class FloatArithmeticOp traits = []> :
+    ArithmeticOp,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
+
+def AddFOp : FloatArithmeticOp<"addf"> {
+  let summary = "floating point addition operation";
+  let hasConstantFolder = 0b1;
+}
+
+def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
+  let summary = "integer addition operation";
+  let hasFolder = 1;
+  let hasConstantFolder = 0b1;
+}
+
+def DivFOp : FloatArithmeticOp<"divf"> {
+  let summary = "floating point division operation";
+}
+
+def DivISOp : IntArithmeticOp<"divis"> {
+  let summary = "signed integer division operation";
+  let hasConstantFolder = 0b1;
+}
+
+def DivIUOp : IntArithmeticOp<"diviu"> {
+  let summary = "unsigned integer division operation";
+  let hasConstantFolder = 0b1;
+}
+
+def MulFOp : FloatArithmeticOp<"mulf"> {
+  let summary = "floating point multiplication operation";
+  let hasConstantFolder = 0b1;
+}
+
+def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
+  let summary = "integer multiplication operation";
+  let hasConstantFolder = 0b1;
+  let hasFolder = 1;
+}
+
+def RemFOp : FloatArithmeticOp<"remf"> {
+  let summary = "floating point division remainder operation";
+}
+
+def RemISOp : IntArithmeticOp<"remis"> {
+  let summary = "signed integer division remainder operation";
+  let hasConstantFolder = 0b1;
+}
+
+def RemIUOp : IntArithmeticOp<"remiu"> {
+  let summary = "unsigned integer division remainder operation";
+  let hasConstantFolder = 0b1;
+}
+
+def SubFOp : FloatArithmeticOp<"subf"> {
+  let summary = "floating point subtraction operation";
+  let hasConstantFolder = 0b1;
+}
+
+def SubIOp : IntArithmeticOp<"subi"> {
+  let summary = "integer subtraction operation";
+  let hasConstantFolder = 0b1;
+  let hasCanonicalizer = 0b1;
+}
+
+#endif // STANDARD_OPS
diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h
deleted file mode 100644
index 9166b153fd0..00000000000
--- a/mlir/include/mlir/StandardOps/StandardOps.h
+++ /dev/null
@@ -1,754 +0,0 @@
-//===- StandardOps.h - Standard MLIR Operations -----------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines convenience types for working with standard operations
-// in the MLIR instruction set.
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_STANDARDOPS_STANDARDOPS_H -#define MLIR_STANDARDOPS_STANDARDOPS_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -class AffineMap; -class Builder; - -class StandardOpsDialect : public Dialect { -public: - StandardOpsDialect(MLIRContext *context); -}; - -#define GET_OP_CLASSES -#include "mlir/StandardOps/standard_ops.inc" - -/// The "alloc" operation allocates a region of memory, as specified by its -/// memref type. For example: -/// -/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> -/// -/// The optional list of dimension operands are bound to the dynamic dimensions -/// specified in its memref type. In the example below, the ssa value '%d' is -/// bound to the second dimension of the memref (which is dynamic). -/// -/// %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> -/// -/// The optional list of symbol operands are bound to the symbols of the -/// memrefs affine map. In the example below, the ssa value '%s' is bound to -/// the symbol 's0' in the affine map specified in the allocs memref type. -/// -/// %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> -/// -/// This operation returns a single ssa value of memref type, which can be used -/// by subsequent load and store operations. -class AllocOp - : public Op { -public: - /// The result of an alloc is always a MemRefType. - MemRefType getType() const { - return getResult()->getType().cast(); - } - - static StringRef getOperationName() { return "alloc"; } - - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, - MemRefType memrefType, ArrayRef operands = {}); - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class Instruction; - explicit AllocOp(const Instruction *state) : Op(state) {} -}; - -/// The "call" operation represents a direct call to a function. The operands -/// and result types of the call must match the specified function type. The -/// callee is encoded as a function attribute named "callee". -/// -/// %31 = call @my_add(%0, %1) -/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> -class CallOp - : public Op { -public: - static StringRef getOperationName() { return "call"; } - - static void build(Builder *builder, OperationState *result, Function *callee, - ArrayRef operands); - - Function *getCallee() const { - return getAttrOfType("callee").getValue(); - } - - /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } - llvm::iterator_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - const_operand_iterator arg_operand_begin() const { return operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - // Hooks to customize behavior of this op. 
- static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - -protected: - friend class Instruction; - explicit CallOp(const Instruction *state) : Op(state) {} -}; - -/// The "call_indirect" operation represents an indirect call to a value of -/// function type. Functions are first class types in MLIR, and may be passed -/// as arguments and merged together with block arguments. The operands -/// and result types of the call must match the specified function type. -/// -/// %31 = call_indirect %15(%0, %1) -/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> -/// -class CallIndirectOp : public Op { -public: - static StringRef getOperationName() { return "call_indirect"; } - - static void build(Builder *builder, OperationState *result, Value *callee, - ArrayRef operands); - - const Value *getCallee() const { return getOperand(0); } - Value *getCallee() { return getOperand(0); } - - /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } - llvm::iterator_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - const_operand_iterator arg_operand_begin() const { return ++operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -protected: - friend class Instruction; - explicit CallIndirectOp(const Instruction *state) : Op(state) {} -}; - -/// The predicate indicates the type of the comparison to perform: -/// (in)equality; (un)signed less/greater than (or equal to). -enum class CmpIPredicate { - FirstValidValue, - // (In)equality comparisons. - EQ = FirstValidValue, - NE, - // Signed comparisons. - SLT, - SLE, - SGT, - SGE, - // Unsigned comparisons. - ULT, - ULE, - UGT, - UGE, - // Number of predicates. - NumPredicates -}; - -/// The "cmpi" operation compares its two operands according to the integer -/// comparison rules and the predicate specified by the respective attribute. -/// The predicate defines the type of comparison: (in)equality, (un)signed -/// less/greater than (or equal to). The operands must have the same type, and -/// this type must be an integer type, a vector or a tensor thereof. The result -/// is an i1, or a vector/tensor thereof having the same shape as the inputs. -/// Since integers are signless, the predicate also explicitly indicates -/// whether to interpret the operands as signed or unsigned integers for -/// less/greater than comparisons. For the sake of readability by humans, -/// custom assembly form for the instruction uses a string-typed attribute for -/// the predicate. The value of this attribute corresponds to lower-cased name -/// of the predicate constant, e.g., "slt" means "signed less than". The string -/// representation of the attribute is merely a syntactic sugar and is converted -/// to an integer attribute by the parser. 
-/// -/// %r1 = cmpi "eq" %0, %1 : i32 -/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> -/// %r3 = "cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 -class CmpIOp - : public Op::Impl, - OpTrait::OneResult, OpTrait::ResultsAreBoolLike, - OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { -public: - CmpIPredicate getPredicate() const { - return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) - .getInt(); - } - - static StringRef getOperationName() { return "cmpi"; } - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - static void build(Builder *builder, OperationState *result, CmpIPredicate, - Value *lhs, Value *rhs); - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -private: - friend class Instruction; - explicit CmpIOp(const Instruction *state) : Op(state) {} -}; - -/// The "dealloc" operation frees the region of memory referenced by a memref -/// which was originally created by the "alloc" operation. -/// The "dealloc" operation should not be called on memrefs which alias an -// alloc'd memref (i.e. memrefs returned by the "view" and "reshape" -/// operations). -/// -/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> -/// -/// dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> -/// -class DeallocOp - : public Op { -public: - Value *getMemRef() { return getOperand(); } - const Value *getMemRef() const { return getOperand(); } - void setMemRef(Value *value) { setOperand(value); } - - static StringRef getOperationName() { return "dealloc"; } - - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, Value *memref); - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class Instruction; - explicit DeallocOp(const Instruction *state) : Op(state) {} -}; - -/// The "dim" operation takes a memref or tensor operand and returns an -/// "index". It requires a single integer attribute named "index". It -/// returns the size of the specified dimension. For example: -/// -/// %1 = dim %0, 2 : tensor -/// -class DimOp : public Op { -public: - static void build(Builder *builder, OperationState *result, - Value *memrefOrTensor, unsigned index); - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - - /// This returns the dimension number that the 'dim' is inspecting. - unsigned getIndex() const { - return getAttrOfType("index").getValue().getZExtValue(); - } - - static StringRef getOperationName() { return "dim"; } - - // Hooks to customize behavior of this op. - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - -private: - friend class Instruction; - explicit DimOp(const Instruction *state) : Op(state) {} -}; - -// DmaStartOp starts a non-blocking DMA operation that transfers data from a -// source memref to a destination memref. The source and destination memref need -// not be of the same dimensionality, but need to have the same elemental type. 
-// The operands include the source and destination memref's each followed by its -// indices, size of the data transfer in terms of the number of elements (of the -// elemental type of the memref), a tag memref with its indices, and optionally -// at the end, a stride and a number_of_elements_per_stride arguments. The tag -// location is used by a DmaWaitOp to check for completion. The indices of the -// source memref, destination memref, and the tag memref have the same -// restrictions as any load/store. The optional stride arguments should be of -// 'index' type, and specify a stride for the slower memory space (memory space -// with a lower memory space id), tranferring chunks of -// number_of_elements_per_stride every stride until %num_elements are -// transferred. Either both or no stride arguments should be specified. -// -// For example, a DmaStartOp operation that transfers 256 elements of a memref -// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space -// 1 at indices [%k, %l], would be specified as follows: -// -// %num_elements = constant 256 -// %idx = constant 0 : index -// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : -// memref<40 x 128 x f32>, (d0) -> (d0), 0>, -// memref<2 x 1024 x f32>, (d0) -> (d0), 1>, -// memref<1 x i32>, (d0) -> (d0), 2> -// -// If %stride and %num_elt_per_stride are specified, the DMA is expected to -// transfer %num_elt_per_stride elements every %stride elements apart from -// memory space 0 until %num_elements are transferred. -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, -// %num_elt_per_stride : -// -// TODO(mlir-team): add additional operands to allow source and destination -// striding, and multiple stride levels. -// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. -class DmaStartOp - : public Op { -public: - static void build(Builder *builder, OperationState *result, Value *srcMemRef, - ArrayRef srcIndices, Value *destMemRef, - ArrayRef destIndices, Value *numElements, - Value *tagMemRef, ArrayRef tagIndices, - Value *stride = nullptr, - Value *elementsPerStride = nullptr); - - // Returns the source MemRefType for this DMA operation. - const Value *getSrcMemRef() const { return getOperand(0); } - // Returns the rank (number of indices) of the source MemRefType. - unsigned getSrcMemRefRank() const { - return getSrcMemRef()->getType().cast().getRank(); - } - // Returns the source memerf indices for this DMA operation. - llvm::iterator_range - getSrcIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_begin() + 1 + getSrcMemRefRank()}; - } - - // Returns the destination MemRefType for this DMA operations. - const Value *getDstMemRef() const { - return getOperand(1 + getSrcMemRefRank()); - } - // Returns the rank (number of indices) of the destination MemRefType. - unsigned getDstMemRefRank() const { - return getDstMemRef()->getType().cast().getRank(); - } - unsigned getSrcMemorySpace() const { - return getSrcMemRef()->getType().cast().getMemorySpace(); - } - unsigned getDstMemorySpace() const { - return getDstMemRef()->getType().cast().getMemorySpace(); - } - - // Returns the destination memref indices for this DMA operation. 
- llvm::iterator_range - getDstIndices() const { - return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1, - getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 + - getDstMemRefRank()}; - } - - // Returns the number of elements being transferred by this DMA operation. - const Value *getNumElements() const { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); - } - - // Returns the Tag MemRef for this DMA operation. - const Value *getTagMemRef() const { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); - } - // Returns the rank (number of indices) of the tag MemRefType. - unsigned getTagMemRefRank() const { - return getTagMemRef()->getType().cast().getRank(); - } - - // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { - unsigned tagIndexStartPos = - 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; - return {getInstruction()->operand_begin() + tagIndexStartPos, - getInstruction()->operand_begin() + tagIndexStartPos + - getTagMemRefRank()}; - } - - /// Returns true if this is a DMA from a faster memory space to a slower one. - bool isDestMemorySpaceFaster() const { - return (getSrcMemorySpace() < getDstMemorySpace()); - } - - /// Returns true if this is a DMA from a slower memory space to a faster one. - bool isSrcMemorySpaceFaster() const { - // Assumes that a lower number is for a slower memory space. - return (getDstMemorySpace() < getSrcMemorySpace()); - } - - /// Given a DMA start operation, returns the operand position of either the - /// source or destination memref depending on the one that is at the higher - /// level of the memory hierarchy. Asserts failure if neither is true. - unsigned getFasterMemPos() const { - assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); - return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; - } - - static StringRef getOperationName() { return "dma_start"; } - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - - bool isStrided() const { - return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + - 1 + 1 + getTagMemRefRank(); - } - - Value *getStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1 - 1); - } - const Value *getStride() const { - return const_cast(this)->getStride(); - } - - Value *getNumElementsPerStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1); - } - const Value *getNumElementsPerStride() const { - return const_cast(this)->getNumElementsPerStride(); - } - -protected: - friend class Instruction; - explicit DmaStartOp(const Instruction *state) : Op(state) {} -}; - -// DmaWaitOp blocks until the completion of a DMA operation associated with the -// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index -// with the same restrictions as any load/store index. %num_elements is the -// number of elements associated with the DMA operation. For example: -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : -// memref<2048 x f32>, (d0) -> (d0), 0>, -// memref<256 x f32>, (d0) -> (d0), 1> -// memref<1 x i32>, (d0) -> (d0), 2> -// ... -// ... 
-// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> -// -class DmaWaitOp - : public Op { -public: - static void build(Builder *builder, OperationState *result, Value *tagMemRef, - ArrayRef tagIndices, Value *numElements); - - static StringRef getOperationName() { return "dma_wait"; } - - // Returns the Tag MemRef associated with the DMA operation being waited on. - const Value *getTagMemRef() const { return getOperand(0); } - Value *getTagMemRef() { return getOperand(0); } - - // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; - } - - // Returns the rank (number of indices) of the tag memref. - unsigned getTagMemRefRank() const { - return getTagMemRef()->getType().cast().getRank(); - } - - // Returns the number of elements transferred in the associated DMA operation. - const Value *getNumElements() const { - return getOperand(1 + getTagMemRefRank()); - } - - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -protected: - friend class Instruction; - explicit DmaWaitOp(const Instruction *state) : Op(state) {} -}; - -/// The "extract_element" op reads a tensor or vector and returns one element -/// from it specified by an index list. The output of extract is a new value -/// with the same type as the elements of the tensor or vector. The arity of -/// indices matches the rank of the accessed value (i.e., if a tensor is of rank -/// 3, then 3 indices are required for the extract). The indices should all be -/// of affine_int type. -/// -/// For example: -/// -/// %3 = extract_element %0[%1, %2] : vector<4x4xi32> -/// -class ExtractElementOp - : public Op { -public: - static void build(Builder *builder, OperationState *result, Value *aggregate, - ArrayRef indices = {}); - - Value *getAggregate() { return getOperand(0); } - const Value *getAggregate() const { return getOperand(0); } - - llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - - static StringRef getOperationName() { return "extract_element"; } - - // Hooks to customize behavior of this op. - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -private: - friend class Instruction; - explicit ExtractElementOp(const Instruction *state) : Op(state) {} -}; - -/// The "load" op reads an element from a memref specified by an index list. The -/// output of load is a new value with the same type as the elements of the -/// memref. The arity of indices is the rank of the memref (i.e., if the memref -/// loaded from is of rank 3, then 3 indices are required for the load following -/// the memref identifier). For example: -/// -/// %3 = load %0[%1, %1] : memref<4x4xi32> -/// -class LoadOp - : public Op { -public: - // Hooks to customize behavior of this op. 
- static void build(Builder *builder, OperationState *result, Value *memref, - ArrayRef indices = {}); - - Value *getMemRef() { return getOperand(0); } - const Value *getMemRef() const { return getOperand(0); } - void setMemRef(Value *value) { setOperand(0, value); } - MemRefType getMemRefType() const { - return getMemRef()->getType().cast(); - } - - llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - - static StringRef getOperationName() { return "load"; } - - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class Instruction; - explicit LoadOp(const Instruction *state) : Op(state) {} -}; - -/// The "memref_cast" operation converts a memref from one type to an equivalent -/// type with a compatible shape. The source and destination types are -/// when both are memref types with the same element type, affine mappings, -/// address space, and rank but where the individual dimensions may add or -/// remove constant dimensions from the memref type. -/// -/// If the cast converts any dimensions from an unknown to a known size, then it -/// acts as an assertion that fails at runtime of the dynamic dimensions -/// disagree with resultant destination size. -/// -/// Assert that the input dynamic shape matches the destination static shape. -/// %2 = memref_cast %1 : memref to memref<4x4xf32> -/// Erase static shape information, replacing it with dynamic information. -/// %3 = memref_cast %1 : memref<4xf32> to memref -/// -class MemRefCastOp : public CastOp { -public: - static StringRef getOperationName() { return "memref_cast"; } - - /// The result of a memref_cast is always a memref. - MemRefType getType() const { - return getResult()->getType().cast(); - } - - bool verify() const; - -private: - friend class Instruction; - explicit MemRefCastOp(const Instruction *state) : CastOp(state) {} -}; - -/// The "select" operation chooses one value based on a binary condition -/// supplied as its first operand. If the value of the first operand is 1, the -/// second operand is chosen, otherwise the third operand is chosen. The second -/// and the third operand must have the same type. The operation applies -/// elementwise to vectors and tensors. The shape of all arguments must be -/// identical. For example, the maximum operation is obtained by combining -/// "select" with "cmpi" as follows. 
-/// -/// %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 -/// %3 = select %2, %0, %1 : i32 -/// -class SelectOp : public Op::Impl, - OpTrait::OneResult, OpTrait::HasNoSideEffect> { -public: - static StringRef getOperationName() { return "select"; } - static void build(Builder *builder, OperationState *result, Value *condition, - Value *trueValue, Value *falseValue); - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } - Value *getTrueValue() { return getOperand(1); } - const Value *getTrueValue() const { return getOperand(1); } - Value *getFalseValue() { return getOperand(2); } - const Value *getFalseValue() const { return getOperand(2); } - - Value *fold(); - -private: - friend class Instruction; - explicit SelectOp(const Instruction *state) : Op(state) {} -}; - -/// The "store" op writes an element to a memref specified by an index list. -/// The arity of indices is the rank of the memref (i.e. if the memref being -/// stored to is of rank 3, then 3 indices are required for the store following -/// the memref identifier). The store instruction does not produce a result. -/// -/// In the following example, the ssa value '%v' is stored in memref '%A' at -/// indices [%i, %j]: -/// -/// store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> -/// -class StoreOp - : public Op { -public: - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef indices = {}); - - Value *getValueToStore() { return getOperand(0); } - const Value *getValueToStore() const { return getOperand(0); } - - Value *getMemRef() { return getOperand(1); } - const Value *getMemRef() const { return getOperand(1); } - void setMemRef(Value *value) { setOperand(1, value); } - MemRefType getMemRefType() const { - return getMemRef()->getType().cast(); - } - - llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 2, - getInstruction()->operand_end()}; - } - - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 2, - getInstruction()->operand_end()}; - } - - static StringRef getOperationName() { return "store"; } - - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class Instruction; - explicit StoreOp(const Instruction *state) : Op(state) {} -}; - -/// The "tensor_cast" operation converts a tensor from one type to an equivalent -/// type without changing any data elements. The source and destination types -/// must both be tensor types with the same element type, and the source and -/// destination types may not be the same. They must either have the same rank, -/// or one may be an unknown rank. The operation is invalid if converting to a -/// mismatching constant dimension. -/// -/// Convert from unknown rank to rank 2 with unknown dimension sizes. -/// %2 = tensor_cast %1 : tensor to tensor -/// -class TensorCastOp : public CastOp { -public: - static StringRef getOperationName() { return "tensor_cast"; } - - /// The result of a tensor_cast is always a tensor. 
- TensorType getType() const { - return getResult()->getType().cast(); - } - - bool verify() const; - -private: - friend class Instruction; - explicit TensorCastOp(const Instruction *state) : CastOp(state) {} -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/StandardOps/standard_ops.td b/mlir/include/mlir/StandardOps/standard_ops.td deleted file mode 100644 index 08ce19e671e..00000000000 --- a/mlir/include/mlir/StandardOps/standard_ops.td +++ /dev/null @@ -1,135 +0,0 @@ -//===- standard_ops.td - Standard operation definitions ----*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Defines some MLIR standard operations. -// -//===----------------------------------------------------------------------===// - -#ifdef STANDARD_OPS -#else -#define STANDARD_OPS - -#ifdef OP_BASE -#else -include "mlir/IR/op_base.td" -#endif // OP_BASE - -def AnyType : Type, "any type">; - -// Base class for standard arithmetic operations. Requires operands and -// results to be of the same type, but does not constrain them to specific -// types. Individual classes will have `lhs` and `rhs` accessor to operands. -class ArithmeticOp traits = []> : - Op, - Results<(outs AnyType)> { - - let opName = mnemonic; - - let parser = [{ - return impl::parseBinaryOp(parser, result); - }]; - - let printer = [{ - return impl::printBinaryOp(this->getInstruction(), p); - }]; -} - -// Base class for standard arithmetic operations on integers, vectors and -// tensors thereof. This operation takes two operands and returns one result, -// each of these is required to be of the same type. This type may be an -// integer scalar type, a vector whose element type is an integer type, or an -// integer tensor. The custom assembly form of the operaton is as follows -// -// i %0, %1 : i32 -class IntArithmeticOp traits = []> : - ArithmeticOp, - Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>; - -// Base class for standard arithmetic binary operations on floats, vectors and -// tensors thereof. This operation has two operands and returns one result, -// each of these is required to be of the same type. This type may be a -// floating point scalar type, a vector whose element type is a floating point -// type, or a floating point tensor. 
The custom assembly form of the operation -// is as follows -// -// f %0, %1 : f32 -class FloatArithmeticOp traits = []> : - ArithmeticOp, - Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; - -def AddFOp : FloatArithmeticOp<"addf"> { - let summary = "floating point addition operation"; - let hasConstantFolder = 0b1; -} - -def AddIOp : IntArithmeticOp<"addi", [Commutative]> { - let summary = "integer addition operation"; - let hasFolder = 1; - let hasConstantFolder = 0b1; -} - -def DivFOp : FloatArithmeticOp<"divf"> { - let summary = "floating point division operation"; -} - -def DivISOp : IntArithmeticOp<"divis"> { - let summary = "signed integer division operation"; - let hasConstantFolder = 0b1; -} - -def DivIUOp : IntArithmeticOp<"diviu"> { - let summary = "unsigned integer division operation"; - let hasConstantFolder = 0b1; -} - -def MulFOp : FloatArithmeticOp<"mulf"> { - let summary = "foating point multiplication operation"; - let hasConstantFolder = 0b1; -} - -def MulIOp : IntArithmeticOp<"muli", [Commutative]> { - let summary = "integer multiplication operation"; - let hasConstantFolder = 0b1; - let hasFolder = 1; -} - -def RemFOp : FloatArithmeticOp<"remf"> { - let summary = "floating point division remainder operation"; -} - -def RemISOp : IntArithmeticOp<"remis"> { - let summary = "signed integer division remainder operation"; - let hasConstantFolder = 0b1; -} - -def RemIUOp : IntArithmeticOp<"remiu"> { - let summary = "unsigned integer division remainder operation"; - let hasConstantFolder = 0b1; -} - -def SubFOp : FloatArithmeticOp<"subf"> { - let summary = "floating point subtraction operation"; - let hasConstantFolder = 0b1; -} - -def SubIOp : IntArithmeticOp<"subi"> { - let summary = "integer subtraction operation"; - let hasConstantFolder = 0b1; - let hasCanonicalizer = 0b1; -} - -#endif // STANDARD_OPS diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 21b68e9c1f5..7c0fd29c191 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -23,7 +23,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/Debug.h" using namespace mlir; diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 1a52b839343..ad364b17a45 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -29,7 +29,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index c0deb805bdf..d17f4560d69 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -29,7 +29,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/MathExtras.h" diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index d709566c322..f731ba17686 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ 
b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "memref-bound-check" diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index d0074dad7f2..7b303f0d070 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "memref-dependence-check" diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index ec1b60ee437..3e55291972b 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -17,7 +17,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/AffineOps/AffineOps.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 5467608d7c0..437cc2254af 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -27,7 +27,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 9985107008a..815831b7922 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -22,7 +22,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index db30eef17bf..b1ba9f0503c 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 6c55cb56800..5165b2d0527 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -32,7 +32,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 14544c8b288..5571913245e 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/StandardOps/StandardOps.h" 
+#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 3444b0ee4c7..613bfb704fa 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -68,7 +68,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context) addTypes(); #define GET_OP_LIST addOperations< -#include "mlir/LLVMIR/llvm_ops.inc" +#include "mlir/LLVMIR/LLVMOps.inc" >(); } diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 64b80c5577b..ee409e4982c 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/StandardOps/DialectRegistration.cpp b/mlir/lib/StandardOps/DialectRegistration.cpp index 4ae7330812d..1f71a3d014e 100644 --- a/mlir/lib/StandardOps/DialectRegistration.cpp +++ b/mlir/lib/StandardOps/DialectRegistration.cpp @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" using namespace mlir; // Static initialization for standard op dialect registration. diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp new file mode 100644 index 00000000000..50f258c84cd --- /dev/null +++ b/mlir/lib/StandardOps/Ops.cpp @@ -0,0 +1,1609 @@ +//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/raw_ostream.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// StandardOpsDialect +//===----------------------------------------------------------------------===// + +StandardOpsDialect::StandardOpsDialect(MLIRContext *context) + : Dialect(/*namePrefix=*/"", context) { + addOperations(); +} + +//===----------------------------------------------------------------------===// +// Common canonicalization pattern support logic +//===----------------------------------------------------------------------===// + +namespace { +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref_cast +/// into the root operation directly. +struct MemRefCastFolder : public RewritePattern { + /// The rootOpName is the name of the root operation to match against. + MemRefCastFolder(StringRef rootOpName, MLIRContext *context) + : RewritePattern(rootOpName, 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + for (auto *operand : op->getOperands()) + if (matchPattern(operand, m_Op())) + return matchSuccess(); + + return matchFailure(); + } + + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + if (auto *memref = op->getOperand(i)->getDefiningInst()) + if (auto cast = memref->dyn_cast()) + op->setOperand(i, cast->getOperand()); + rewriter.updatedRootInPlace(op); + } +}; + +/// Performs const folding `calculate` with element-wise behavior on the two +/// attributes in `operands` and returns the result if possible. +template > +Attribute constFoldBinaryOp(ArrayRef operands, + const CalculationT &calculate) { + assert(operands.size() == 2 && "binary op takes two operands"); + + if (auto lhs = operands[0].dyn_cast_or_null()) { + auto rhs = operands[1].dyn_cast_or_null(); + if (!rhs || lhs.getType() != rhs.getType()) + return {}; + + return AttrElementT::get(lhs.getType(), + calculate(lhs.getValue(), rhs.getValue())); + } else if (auto lhs = operands[0].dyn_cast_or_null()) { + auto rhs = operands[1].dyn_cast_or_null(); + if (!rhs || lhs.getType() != rhs.getType()) + return {}; + + auto elementResult = constFoldBinaryOp( + {lhs.getValue(), rhs.getValue()}, calculate); + if (!elementResult) + return {}; + + return SplatElementsAttr::get(lhs.getType(), elementResult); + } + return {}; +} +} // end anonymous namespace. 
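// A minimal sketch of what MemRefCastFolder is meant to achieve, using
// illustrative IR (the value %buf and the index %i are assumptions for this
// sketch, not taken from the patch): for a pattern rooted at "load",
//
//   %0 = memref_cast %buf : memref<16xf32> to memref<?xf32>
//   %1 = load %0[%i] : memref<?xf32>
//
// the rewrite replaces the load's operand with the cast's source, i.e.
// "load(memref_cast(x)) -> load(x)", so the load reads directly from %buf and
// the now-unused memref_cast can be cleaned up as dead code.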
+ +//===----------------------------------------------------------------------===// +// AddFOp +//===----------------------------------------------------------------------===// + +Attribute AddFOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// AddIOp +//===----------------------------------------------------------------------===// + +Attribute AddIOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a + b; }); +} + +Value *AddIOp::fold() { + /// addi(x, 0) -> x + if (matchPattern(getOperand(1), m_Zero())) + return getOperand(0); + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + +void AllocOp::build(Builder *builder, OperationState *result, + MemRefType memrefType, ArrayRef operands) { + result->addOperands(operands); + result->types.push_back(memrefType); +} + +void AllocOp::print(OpAsmPrinter *p) const { + MemRefType type = getType(); + *p << "alloc"; + // Print dynamic dimension operands. + printDimAndSymbolList(operand_begin(), operand_end(), + type.getNumDynamicDims(), p); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"}); + *p << " : " << type; +} + +bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { + MemRefType type; + + // Parse the dimension operands and optional symbol operands, followed by a + // memref type. + unsigned numDimOperands; + if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return true; + + // Check numDynamicDims against number of question marks in memref type. + // Note: this check remains here (instead of in verify()), because the + // partition between dim operands and symbol operands is lost after parsing. + // Verification still checks that the total number of operands matches + // the number of symbols in the affine map, plus the number of dynamic + // dimensions in the memref. + if (numDimOperands != type.getNumDynamicDims()) { + return parser->emitError(parser->getNameLoc(), + "dimension operand count does not equal memref " + "dynamic dimension count"); + } + result->types.push_back(type); + return false; +} + +bool AllocOp::verify() const { + auto memRefType = getResult()->getType().dyn_cast(); + if (!memRefType) + return emitOpError("result must be a memref"); + + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) { + AffineMap affineMap = memRefType.getAffineMaps()[0]; + // Store number of symbols used in affine map (used in subsequent check). + numSymbols = affineMap.getNumSymbols(); + // TODO(zinenko): this check does not belong to AllocOp, or any other op but + // to the type system itself. It has been partially hoisted to Parser but + // remains here in case an AllocOp gets constructed programmatically. + // Remove when we can emit errors directly from *Type::get(...) functions. + // + // Verify that the layout affine map matches the rank of the memref. 
+ if (affineMap.getNumDims() != memRefType.getRank()) + return emitOpError("affine map dimension count must equal memref rank"); + } + unsigned numDynamicDims = memRefType.getNumDynamicDims(); + // Check that the total number of operands matches the number of symbols in + // the affine map, plus the number of dynamic dimensions specified in the + // memref type. + if (getInstruction()->getNumOperands() != numDynamicDims + numSymbols) { + return emitOpError( + "operand count does not equal dimension plus symbol operand count"); + } + // Verify that all operands are of type Index. + for (auto *operand : getOperands()) { + if (!operand->getType().isIndex()) + return emitOpError("requires operands to be of type Index"); + } + return false; +} + +namespace { +/// Fold constant dimensions into an alloc instruction. +struct SimplifyAllocConst : public RewritePattern { + SimplifyAllocConst(MLIRContext *context) + : RewritePattern(AllocOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto alloc = op->cast(); + + // Check to see if any dimensions operands are constants. If so, we can + // substitute and drop them. + for (auto *operand : alloc->getOperands()) + if (matchPattern(operand, m_ConstantIndex())) + return matchSuccess(); + return matchFailure(); + } + + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + auto allocOp = op->cast(); + auto memrefType = allocOp->getType(); + + // Ok, we have one or more constant operands. Collect the non-constant ones + // and keep track of the resultant memref type to build. + SmallVector newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + SmallVector newOperands; + SmallVector droppedOperands; + + unsigned dynamicDimPos = 0; + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (dimSize != -1) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); + OpPointer constantIndexOp; + if (defOp && (constantIndexOp = defOp->dyn_cast())) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp->getValue()); + // Record to check for zero uses later below. + droppedOperands.push_back(constantIndexOp); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(-1); + newOperands.push_back(allocOp->getOperand(dynamicDimPos)); + } + dynamicDimPos++; + } + + // Create new memref type (which will have fewer dynamic dimensions). + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(newOperands.size() == newMemRefType.getNumDynamicDims()); + + // Create and insert the alloc op for the new memref. + auto newAlloc = + rewriter.create(allocOp->getLoc(), newMemRefType, newOperands); + // Insert a cast so we have the same type as the old alloc. + auto resultCast = rewriter.create(allocOp->getLoc(), newAlloc, + allocOp->getType()); + + rewriter.replaceOp(op, {resultCast}, droppedOperands); + } +}; + +/// Fold alloc instructions with no uses. Alloc has side effects on the heap, +/// but can still be deleted if it has zero uses. 
+struct SimplifyDeadAlloc : public RewritePattern { + SimplifyDeadAlloc(MLIRContext *context) + : RewritePattern(AllocOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto alloc = op->cast(); + // Check if the alloc'ed value has no uses. + return alloc->use_empty() ? matchSuccess() : matchFailure(); + } + + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + // Erase the alloc operation. + op->erase(); + } +}; +} // end anonymous namespace. + +void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.push_back(std::make_unique(context)); + results.push_back(std::make_unique(context)); +} + +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +void CallOp::build(Builder *builder, OperationState *result, Function *callee, + ArrayRef operands) { + result->addOperands(operands); + result->addAttribute("callee", builder->getFunctionAttr(callee)); + result->addTypes(callee->getType().getResults()); +} + +bool CallOp::parse(OpAsmParser *parser, OperationState *result) { + StringRef calleeName; + llvm::SMLoc calleeLoc; + FunctionType calleeType; + SmallVector operands; + Function *callee = nullptr; + if (parser->parseFunctionName(calleeName, calleeLoc) || + parser->parseOperandList(operands, /*requiredOperandCount=*/-1, + OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(calleeType) || + parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || + parser->addTypesToList(calleeType.getResults(), result->types) || + parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, + result->operands)) + return true; + + result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); + return false; +} + +void CallOp::print(OpAsmPrinter *p) const { + *p << "call "; + p->printFunctionReference(getCallee()); + *p << '('; + p->printOperands(getOperands()); + *p << ')'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"}); + *p << " : " << getCallee()->getType(); +} + +bool CallOp::verify() const { + // Check that the callee attribute was specified. + auto fnAttr = getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' function attribute"); + + // Verify that the operand and result types match the callee. + auto fnType = fnAttr.getValue()->getType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i)->getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch"); + } + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) + return emitOpError("result type mismatch"); + } + + return false; +} + +//===----------------------------------------------------------------------===// +// CallIndirectOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold indirect calls that have a constant function as the callee operand. 
+struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { + SimplifyIndirectCallWithKnownCallee(MLIRContext *context) + : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto indirectCall = op->cast(); + + // Check that the callee is a constant operation. + Value *callee = indirectCall->getCallee(); + Instruction *calleeInst = callee->getDefiningInst(); + if (!calleeInst || !calleeInst->isa()) + return matchFailure(); + + // Check that the constant callee is a function. + if (calleeInst->cast()->getValue().isa()) + return matchSuccess(); + return matchFailure(); + } + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + auto indirectCall = op->cast(); + auto calleeOp = + indirectCall->getCallee()->getDefiningInst()->cast(); + + // Replace with a direct call. + Function *calledFn = calleeOp->getValue().cast().getValue(); + SmallVector callOperands(indirectCall->getArgOperands()); + rewriter.replaceOpWithNewOp(op, calledFn, callOperands); + } +}; +} // end anonymous namespace. + +void CallIndirectOp::build(Builder *builder, OperationState *result, + Value *callee, ArrayRef operands) { + auto fnType = callee->getType().cast(); + result->operands.push_back(callee); + result->addOperands(operands); + result->addTypes(fnType.getResults()); +} + +bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector operands; + return parser->parseOperand(callee) || + parser->getCurrentLocation(&operandsLoc) || + parser->parseOperandList(operands, /*requiredOperandCount=*/-1, + OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(calleeType) || + parser->resolveOperand(callee, calleeType, result->operands) || + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result->operands) || + parser->addTypesToList(calleeType.getResults(), result->types); +} + +void CallIndirectOp::print(OpAsmPrinter *p) const { + *p << "call_indirect "; + p->printOperand(getCallee()); + *p << '('; + auto operandRange = getOperands(); + p->printOperands(++operandRange.begin(), operandRange.end()); + *p << ')'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"}); + *p << " : " << getCallee()->getType(); +} + +bool CallIndirectOp::verify() const { + // The callee must be a function. + auto fnType = getCallee()->getType().dyn_cast(); + if (!fnType) + return emitOpError("callee must have function type"); + + // Verify that the operand and result types match the callee. 
+ if (fnType.getNumInputs() != getNumOperands() - 1) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i + 1)->getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch"); + } + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) + return emitOpError("result type mismatch"); + } + + return false; +} + +void CallIndirectOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.push_back( + std::make_unique(context)); +} + +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +// Return the type of the same shape (scalar, vector or tensor) containing i1. +static Type getCheckedI1SameShape(Builder *build, Type type) { + auto i1Type = build->getI1Type(); + if (type.isIntOrIndexOrFloat()) + return i1Type; + if (auto tensorType = type.dyn_cast()) + return build->getTensorType(tensorType.getShape(), i1Type); + if (auto tensorType = type.dyn_cast()) + return build->getTensorType(i1Type); + if (auto vectorType = type.dyn_cast()) + return build->getVectorType(vectorType.getShape(), i1Type); + return Type(); +} + +static Type getI1SameShape(Builder *build, Type type) { + Type res = getCheckedI1SameShape(build, type); + assert(res && "expected type with valid i1 shape"); + return res; +} + +static inline bool isI1(Type type) { + return type.isa() && type.cast().getWidth() == 1; +} + +template +static inline bool implCheckI1SameShape(Ty pattern, Type type) { + auto specificType = type.dyn_cast(); + if (!specificType) + return true; + if (specificType.getShape() != pattern.getShape()) + return true; + return !isI1(specificType.getElementType()); +} + +// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern" +// and contains i1. +static bool checkI1SameShape(Type pattern, Type type) { + if (pattern.isIntOrIndexOrFloat()) + return !isI1(type); + if (auto patternTensorType = pattern.dyn_cast()) + return implCheckI1SameShape(patternTensorType, type); + if (auto patternVectorType = pattern.dyn_cast()) + return implCheckI1SameShape(patternVectorType, type); + + llvm_unreachable("unsupported type"); +} + +// Returns an array of mnemonics for CmpIPredicates, indexed by values thereof. +static inline const char *const *getPredicateNames() { + static const char *predicateNames[(int)CmpIPredicate::NumPredicates]{ + /*EQ*/ "eq", + /*NE*/ "ne", + /*SLT*/ "slt", + /*SLE*/ "sle", + /*SGT*/ "sgt", + /*SGE*/ "sge", + /*ULT*/ "ult", + /*ULE*/ "ule", + /*UGT*/ "ugt", + /*UGE*/ "uge"}; + return predicateNames; +}; + +// Returns a value of the predicate corresponding to the given mnemonic. +// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
+CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { + return llvm::StringSwitch(name) + .Case("eq", CmpIPredicate::EQ) + .Case("ne", CmpIPredicate::NE) + .Case("slt", CmpIPredicate::SLT) + .Case("sle", CmpIPredicate::SLE) + .Case("sgt", CmpIPredicate::SGT) + .Case("sge", CmpIPredicate::SGE) + .Case("ult", CmpIPredicate::ULT) + .Case("ule", CmpIPredicate::ULE) + .Case("ugt", CmpIPredicate::UGT) + .Case("uge", CmpIPredicate::UGE) + .Default(CmpIPredicate::NumPredicates); +} + +void CmpIOp::build(Builder *build, OperationState *result, + CmpIPredicate predicate, Value *lhs, Value *rhs) { + result->addOperands({lhs, rhs}); + result->types.push_back(getI1SameShape(build, lhs->getType())); + result->addAttribute(getPredicateAttrName(), + build->getIntegerAttr(build->getIntegerType(64), + static_cast(predicate))); +} + +bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Attribute predicateNameAttr; + Type type; + if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(), + attrs) || + parser->parseComma() || parser->parseOperandList(ops, 2) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(type) || + parser->resolveOperands(ops, type, result->operands)) + return true; + + if (!predicateNameAttr.isa()) + return parser->emitError(parser->getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast().getValue(); + auto predicate = getPredicateByName(predicateName); + if (predicate == CmpIPredicate::NumPredicates) + return parser->emitError(parser->getNameLoc(), + "unknown comparison predicate \"" + predicateName + + "\""); + + auto builder = parser->getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + result->attributes = attrs; + + result->addTypes({i1Type}); + return false; +} + +void CmpIOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << " "; + + auto predicateValue = + getAttrOfType(getPredicateAttrName()).getInt(); + assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && + predicateValue < static_cast(CmpIPredicate::NumPredicates) && + "unknown predicate index"); + Builder b(getInstruction()->getContext()); + auto predicateStringAttr = + b.getStringAttr(getPredicateNames()[predicateValue]); + p->printAttribute(predicateStringAttr); + + *p << ", "; + p->printOperand(getOperand(0)); + *p << ", "; + p->printOperand(getOperand(1)); + p->printOptionalAttrDict(getAttrs(), + /*elidedAttrs=*/{getPredicateAttrName()}); + *p << " : " << getOperand(0)->getType(); +} + +bool CmpIOp::verify() const { + auto predicateAttr = getAttrOfType(getPredicateAttrName()); + if (!predicateAttr) + return emitOpError("requires an integer attribute named 'predicate'"); + auto predicate = predicateAttr.getInt(); + if (predicate < (int64_t)CmpIPredicate::FirstValidValue || + predicate >= (int64_t)CmpIPredicate::NumPredicates) + return emitOpError("'predicate' attribute value out of range"); + + return false; +} + +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer +// comparison predicates. 
+static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, + const APInt &rhs) { + switch (predicate) { + case CmpIPredicate::EQ: + return lhs.eq(rhs); + case CmpIPredicate::NE: + return lhs.ne(rhs); + case CmpIPredicate::SLT: + return lhs.slt(rhs); + case CmpIPredicate::SLE: + return lhs.sle(rhs); + case CmpIPredicate::SGT: + return lhs.sgt(rhs); + case CmpIPredicate::SGE: + return lhs.sge(rhs); + case CmpIPredicate::ULT: + return lhs.ult(rhs); + case CmpIPredicate::ULE: + return lhs.ule(rhs); + case CmpIPredicate::UGT: + return lhs.ugt(rhs); + case CmpIPredicate::UGE: + return lhs.uge(rhs); + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +Attribute CmpIOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "cmpi takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold Dealloc instructions that are deallocating an AllocOp that is only used +/// by other Dealloc operations. +struct SimplifyDeadDealloc : public RewritePattern { + SimplifyDeadDealloc(MLIRContext *context) + : RewritePattern(DeallocOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto dealloc = op->cast(); + + // Check that the memref operand's defining instruction is an AllocOp. + Value *memref = dealloc->getMemRef(); + Instruction *defOp = memref->getDefiningInst(); + if (!defOp || !defOp->isa()) + return matchFailure(); + + // Check that all of the uses of the AllocOp are other DeallocOps. + for (auto &use : memref->getUses()) + if (!use.getOwner()->isa()) + return matchFailure(); + return matchSuccess(); + } + + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + // Erase the dealloc operation. + op->erase(); + } +}; +} // end anonymous namespace. 
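// A minimal sketch of how SimplifyDeadDealloc composes with SimplifyDeadAlloc
// above, using illustrative IR (the value %tmp is an assumption for this
// sketch, not taken from the patch), when %tmp has no other uses:
//
//   %tmp = alloc() : memref<4xf32>
//   dealloc %tmp : memref<4xf32>
//
// SimplifyDeadDealloc erases the dealloc because its memref is produced by an
// alloc whose only users are dealloc operations; that leaves the alloc with
// zero uses, so SimplifyDeadAlloc can then erase it as well.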
+ +void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) { + result->addOperands(memref); +} + +void DeallocOp::print(OpAsmPrinter *p) const { + *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); +} + +bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType memrefInfo; + MemRefType type; + + return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands); +} + +bool DeallocOp::verify() const { + if (!getMemRef()->getType().isa()) + return emitOpError("operand must be a memref"); + return false; +} + +void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dealloc(memrefcast) -> dealloc + results.push_back( + std::make_unique(getOperationName(), context)); + results.push_back(std::make_unique(context)); +} + +//===----------------------------------------------------------------------===// +// DimOp +//===----------------------------------------------------------------------===// + +void DimOp::build(Builder *builder, OperationState *result, + Value *memrefOrTensor, unsigned index) { + result->addOperands(memrefOrTensor); + auto type = builder->getIndexType(); + result->addAttribute("index", builder->getIntegerAttr(type, index)); + result->types.push_back(type); +} + +void DimOp::print(OpAsmPrinter *p) const { + *p << "dim " << *getOperand() << ", " << getIndex(); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"index"}); + *p << " : " << getOperand()->getType(); +} + +bool DimOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType operandInfo; + IntegerAttr indexAttr; + Type type; + Type indexType = parser->getBuilder().getIndexType(); + + return parser->parseOperand(operandInfo) || parser->parseComma() || + parser->parseAttribute(indexAttr, indexType, "index", + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types); +} + +bool DimOp::verify() const { + // Check that we have an integer index operand. + auto indexAttr = getAttrOfType("index"); + if (!indexAttr) + return emitOpError("requires an integer attribute named 'index'"); + uint64_t index = indexAttr.getValue().getZExtValue(); + + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast()) { + if (index >= tensorType.getRank()) + return emitOpError("index is out of range"); + } else if (auto memrefType = type.dyn_cast()) { + if (index >= memrefType.getRank()) + return emitOpError("index is out of range"); + + } else if (type.isa()) { + // ok, assumed to be in-range. + } else { + return emitOpError("requires an operand with tensor or memref type"); + } + + return false; +} + +Attribute DimOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + // Constant fold dim when the size along the index referred to is a constant. 
+ auto opType = getOperand()->getType(); + int64_t indexSize = -1; + if (auto tensorType = opType.dyn_cast()) { + indexSize = tensorType.getShape()[getIndex()]; + } else if (auto memrefType = opType.dyn_cast()) { + indexSize = memrefType.getShape()[getIndex()]; + } + + if (indexSize >= 0) + return IntegerAttr::get(IndexType::get(context), indexSize); + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// DivISOp +//===----------------------------------------------------------------------===// + +Attribute DivISOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "binary operation takes two operands"); + (void)context; + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + // Don't fold if it requires division by zero. + if (rhs.getValue().isNullValue()) { + return {}; + } + + // Don't fold if it would overflow. + bool overflow; + auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); + return overflow ? IntegerAttr{} : IntegerAttr::get(lhs.getType(), result); +} + +//===----------------------------------------------------------------------===// +// DivIUOp +//===----------------------------------------------------------------------===// + +Attribute DivIUOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "binary operation takes two operands"); + (void)context; + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + // Don't fold if it requires division by zero. + if (rhs.getValue().isNullValue()) { + return {}; + } + + return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhs.getValue())); +} + +// --------------------------------------------------------------------------- +// DmaStartOp +// --------------------------------------------------------------------------- + +void DmaStartOp::build(Builder *builder, OperationState *result, + Value *srcMemRef, ArrayRef srcIndices, + Value *destMemRef, ArrayRef destIndices, + Value *numElements, Value *tagMemRef, + ArrayRef tagIndices, Value *stride, + Value *elementsPerStride) { + result->addOperands(srcMemRef); + result->addOperands(srcIndices); + result->addOperands(destMemRef); + result->addOperands(destIndices); + result->addOperands(numElements); + result->addOperands(tagMemRef); + result->addOperands(tagIndices); + if (stride) { + result->addOperands(stride); + result->addOperands(elementsPerStride); + } +} + +void DmaStartOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << ' ' << *getSrcMemRef() << '['; + p->printOperands(getSrcIndices()); + *p << "], " << *getDstMemRef() << '['; + p->printOperands(getDstIndices()); + *p << "], " << *getNumElements(); + *p << ", " << *getTagMemRef() << '['; + p->printOperands(getTagIndices()); + *p << ']'; + if (isStrided()) { + *p << ", " << *getStride(); + *p << ", " << *getNumElementsPerStride(); + } + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getSrcMemRef()->getType(); + *p << ", " << getDstMemRef()->getType(); + *p << ", " << getTagMemRef()->getType(); +} + +// Parse DmaStartOp. 
+// Ex: +// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, +// %tag[%index], %stride, %num_elt_per_stride : +// : memref<3076 x f32, 0>, +// memref<1024 x f32, 2>, +// memref<1 x i32> +// +bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + SmallVector srcIndexInfos; + OpAsmParser::OperandType dstMemRefInfo; + SmallVector dstIndexInfos; + OpAsmParser::OperandType numElementsInfo; + OpAsmParser::OperandType tagMemrefInfo; + SmallVector tagIndexInfos; + SmallVector strideInfo; + + SmallVector types; + auto indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) source memref followed by its indices (in square brackets). + // *) destination memref followed by its indices (in square brackets). + // *) dma size in KiB. + if (parser->parseOperand(srcMemRefInfo) || + parser->parseOperandList(srcIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || + parser->parseOperandList(dstIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseComma() || parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, -1, + OpAsmParser::Delimiter::Square)) + return true; + + // Parse optional stride and elements per stride. + if (parser->parseTrailingOperandList(strideInfo)) { + return true; + } + if (!strideInfo.empty() && strideInfo.size() != 2) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + bool isStrided = strideInfo.size() == 2; + + if (parser->parseColonTypeList(types)) + return true; + + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcIndexInfos, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstIndexInfos, indexType, result->operands) || + // size should be an index. + parser->resolveOperand(numElementsInfo, indexType, result->operands) || + parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || + // tag indices should be index. + parser->resolveOperands(tagIndexInfos, indexType, result->operands)) + return true; + + if (!types[0].isa()) + return parser->emitError(parser->getNameLoc(), + "expected source to be of memref type"); + + if (!types[1].isa()) + return parser->emitError(parser->getNameLoc(), + "expected destination to be of memref type"); + + if (!types[2].isa()) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (isStrided) { + if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || + parser->resolveOperand(strideInfo[1], indexType, result->operands)) + return true; + } + + // Check that source/destination index list size matches associated rank. + if (srcIndexInfos.size() != types[0].cast().getRank() || + dstIndexInfos.size() != types[1].cast().getRank()) + return parser->emitError(parser->getNameLoc(), + "memref rank not equal to indices count"); + + if (tagIndexInfos.size() != types[2].cast().getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + return false; +} + +bool DmaStartOp::verify() const { + // DMAs from different memory spaces supported. 
+ if (getSrcMemorySpace() == getDstMemorySpace()) { + return emitOpError("DMA should be between different memory spaces"); + } + + if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 && + getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + return false; +} + +void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dma_start(memrefcast) -> dma_start + results.push_back( + std::make_unique(getOperationName(), context)); +} + +// --------------------------------------------------------------------------- +// DmaWaitOp +// --------------------------------------------------------------------------- + +void DmaWaitOp::build(Builder *builder, OperationState *result, + Value *tagMemRef, ArrayRef tagIndices, + Value *numElements) { + result->addOperands(tagMemRef); + result->addOperands(tagIndices); + result->addOperands(numElements); +} + +void DmaWaitOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << ' '; + // Print operands. + p->printOperand(getTagMemRef()); + *p << '['; + p->printOperands(getTagIndices()); + *p << "], "; + p->printOperand(getNumElements()); + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getTagMemRef()->getType(); +} + +// Parse DmaWaitOp. +// Eg: +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> +// +bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType tagMemrefInfo; + SmallVector tagIndexInfos; + Type type; + auto indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its indices, and dma size. 
+ if (parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseColonType(type) || + parser->resolveOperand(tagMemrefInfo, type, result->operands) || + parser->resolveOperands(tagIndexInfos, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return true; + + if (!type.isa()) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (tagIndexInfos.size() != type.cast().getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + return false; +} + +void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dma_wait(memrefcast) -> dma_wait + results.push_back( + std::make_unique(getOperationName(), context)); +} + +//===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + +void ExtractElementOp::build(Builder *builder, OperationState *result, + Value *aggregate, ArrayRef indices) { + auto aggregateType = aggregate->getType().cast(); + result->addOperands(aggregate); + result->addOperands(indices); + result->types.push_back(aggregateType.getElementType()); +} + +void ExtractElementOp::print(OpAsmPrinter *p) const { + *p << "extract_element " << *getAggregate() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getAggregate()->getType(); +} + +bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType aggregateInfo; + SmallVector indexInfo; + VectorOrTensorType type; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return parser->parseOperand(aggregateInfo) || + parser->parseOperandList(indexInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(aggregateInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types); +} + +bool ExtractElementOp::verify() const { + if (getNumOperands() == 0) + return emitOpError("expected an aggregate to index into"); + + auto aggregateType = getAggregate()->getType().dyn_cast(); + if (!aggregateType) + return emitOpError("first operand must be a vector or tensor"); + + if (getType() != aggregateType.getElementType()) + return emitOpError("result type must match element type of aggregate"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to extract_element must have 'index' type"); + + // Verify the # indices match if we have a ranked type. + auto aggregateRank = aggregateType.getRank(); + if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) + return emitOpError("incorrect number of indices for extract_element"); + + return false; +} + +Attribute ExtractElementOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() > 1 && "extract_element takes atleast one operands"); + + // The aggregate operand must be a known constant. + Attribute aggregate = operands.front(); + if (!aggregate) + return Attribute(); + + // If this is a splat elements attribute, simply return the value. 
All of the + // elements of a splat attribute are the same. + if (auto splatAggregate = aggregate.dyn_cast()) + return splatAggregate.getValue(); + + // Otherwise, collect the constant indices into the aggregate. + SmallVector indices; + for (Attribute indice : llvm::drop_begin(operands, 1)) { + if (!indice || !indice.isa()) + return Attribute(); + indices.push_back(indice.cast().getInt()); + } + + // If this is an elements attribute, query the value at the given indices. + if (auto elementsAttr = aggregate.dyn_cast()) + return elementsAttr.getValue(indices); + return Attribute(); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +void LoadOp::build(Builder *builder, OperationState *result, Value *memref, + ArrayRef indices) { + auto memrefType = memref->getType().cast(); + result->addOperands(memref); + result->addOperands(indices); + result->types.push_back(memrefType.getElementType()); +} + +void LoadOp::print(OpAsmPrinter *p) const { + *p << "load " << *getMemRef() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getMemRefType(); +} + +bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType memrefInfo; + SmallVector indexInfo; + MemRefType type; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types); +} + +bool LoadOp::verify() const { + if (getNumOperands() == 0) + return emitOpError("expected a memref to load from"); + + auto memRefType = getMemRef()->getType().dyn_cast(); + if (!memRefType) + return emitOpError("first operand must be a memref"); + + if (getType() != memRefType.getElementType()) + return emitOpError("result type must match element type of memref"); + + if (memRefType.getRank() != getNumOperands() - 1) + return emitOpError("incorrect number of indices for load"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + + // TODO: Verify we have the right number of indices. + + // TODO: in Function verify that the indices are parameters, IV's, or the + // result of an affine.apply. 
+ return false; +} + +void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// load(memrefcast) -> load + results.push_back( + std::make_unique(getOperationName(), context)); +} + +//===----------------------------------------------------------------------===// +// MemRefCastOp +//===----------------------------------------------------------------------===// + +bool MemRefCastOp::verify() const { + auto opType = getOperand()->getType().dyn_cast(); + auto resType = getType().dyn_cast(); + if (!opType || !resType) + return emitOpError("requires input and result types to be memrefs"); + + if (opType == resType) + return emitOpError("requires the input and result type to be different"); + + if (opType.getElementType() != resType.getElementType()) + return emitOpError( + "requires input and result element types to be the same"); + + if (opType.getAffineMaps() != resType.getAffineMaps()) + return emitOpError("requires input and result mappings to be the same"); + + if (opType.getMemorySpace() != resType.getMemorySpace()) + return emitOpError( + "requires input and result memory spaces to be the same"); + + // They must have the same rank, and any specified dimensions must match. + if (opType.getRank() != resType.getRank()) + return emitOpError("requires input and result ranks to match"); + + for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { + int64_t opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); + if (opDim != -1 && resultDim != -1 && opDim != resultDim) + return emitOpError("requires static dimensions to match"); + } + + return false; +} + +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + +Attribute MulFOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// MulIOp +//===----------------------------------------------------------------------===// + +Attribute MulIOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + // TODO: Handle the overflow case. + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a * b; }); +} + +Value *MulIOp::fold() { + /// muli(x, 0) -> 0 + if (matchPattern(getOperand(1), m_Zero())) + return getOperand(1); + /// muli(x, 1) -> x + if (matchPattern(getOperand(1), m_One())) + return getOperand(0); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// RemISOp +//===----------------------------------------------------------------------===// + +Attribute RemISOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "remis takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) + return {}; + + // x % 1 = 0 + if (rhs.getValue().isOneValue()) + return IntegerAttr::get(rhs.getType(), + APInt(rhs.getValue().getBitWidth(), 0)); + + // Don't fold if it requires division by zero. 
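+ // (Returning a null Attribute signals that no constant folding applies, so
+ // the original remis is left untouched.)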
+ if (rhs.getValue().isNullValue()) { + return {}; + } + + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + + return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhs.getValue())); +} + +//===----------------------------------------------------------------------===// +// RemIUOp +//===----------------------------------------------------------------------===// + +Attribute RemIUOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "remiu takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) + return {}; + + // x % 1 = 0 + if (rhs.getValue().isOneValue()) + return IntegerAttr::get(rhs.getType(), + APInt(rhs.getValue().getBitWidth(), 0)); + + // Don't fold if it requires division by zero. + if (rhs.getValue().isNullValue()) { + return {}; + } + + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + + return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue())); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::build(Builder *builder, OperationState *result, Value *condition, + Value *trueValue, Value *falseValue) { + result->addOperands({condition, trueValue, falseValue}); + result->addTypes(trueValue->getType()); +} + +bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Type type; + + if (parser->parseOperandList(ops, 3) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return true; + + auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + SmallVector types = {i1Type, type, type}; + return parser->resolveOperands(ops, types, parser->getNameLoc(), + result->operands) || + parser->addTypeToList(type, result->types); +} + +void SelectOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << ' '; + p->printOperands(getInstruction()->getOperands()); + *p << " : " << getTrueValue()->getType(); + p->printOptionalAttrDict(getAttrs()); +} + +bool SelectOp::verify() const { + auto conditionType = getCondition()->getType(); + auto trueType = getTrueValue()->getType(); + auto falseType = getFalseValue()->getType(); + + if (trueType != falseType) + return emitOpError( + "requires 'true' and 'false' arguments to be of the same type"); + + if (checkI1SameShape(trueType, conditionType)) + return emitOpError("requires the condition to have the same shape as " + "arguments with elemental type i1"); + + return false; +} + +Value *SelectOp::fold() { + auto *condition = getCondition(); + + // select true, %0, %1 => %0 + if (matchPattern(condition, m_One())) + return getTrueValue(); + + // select false, %0, %1 => %1 + if (matchPattern(condition, m_Zero())) + return getFalseValue(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +void StoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef indices) { + result->addOperands(valueToStore); + result->addOperands(memref); + result->addOperands(indices); +} + +void StoreOp::print(OpAsmPrinter *p) const { + *p << "store " << *getValueToStore(); + *p << 
", " << *getMemRef() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getMemRefType(); +} + +bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + SmallVector indexInfo; + MemRefType memrefType; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(memrefType) || + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands); +} + +bool StoreOp::verify() const { + if (getNumOperands() < 2) + return emitOpError("expected a value to store and a memref"); + + // Second operand is a memref type. + auto memRefType = getMemRef()->getType().dyn_cast(); + if (!memRefType) + return emitOpError("second operand must be a memref"); + + // First operand must have same type as memref element type. + if (getValueToStore()->getType() != memRefType.getElementType()) + return emitOpError("first operand must have same type memref element type"); + + if (getNumOperands() != 2 + memRefType.getRank()) + return emitOpError("store index operand count not equal to memref rank"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + + // TODO: Verify we have the right number of indices. + + // TODO: in Function verify that the indices are parameters, IV's, or the + // result of an affine.apply. + return false; +} + +void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// store(memrefcast) -> store + results.push_back( + std::make_unique(getOperationName(), context)); +} + +//===----------------------------------------------------------------------===// +// SubFOp +//===----------------------------------------------------------------------===// + +Attribute SubFOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// SubIOp +//===----------------------------------------------------------------------===// + +Attribute SubIOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a - b; }); +} + +namespace { +/// subi(x,x) -> 0 +/// +struct SimplifyXMinusX : public RewritePattern { + SimplifyXMinusX(MLIRContext *context) + : RewritePattern(SubIOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto subi = op->cast(); + if (subi->getOperand(0) == subi->getOperand(1)) + return matchSuccess(); + + return matchFailure(); + } + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + auto subi = op->cast(); + auto result = + rewriter.create(op->getLoc(), 0, subi->getType()); + + rewriter.replaceOp(op, {result}); + } +}; +} // end anonymous namespace. 
+ +void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.push_back(std::make_unique(context)); +} + +//===----------------------------------------------------------------------===// +// TensorCastOp +//===----------------------------------------------------------------------===// + +bool TensorCastOp::verify() const { + auto opType = getOperand()->getType().dyn_cast(); + auto resType = getType().dyn_cast(); + if (!opType || !resType) + return emitOpError("requires input and result types to be tensors"); + + if (opType == resType) + return emitOpError("requires the input and result type to be different"); + + if (opType.getElementType() != resType.getElementType()) + return emitOpError( + "requires input and result element types to be the same"); + + // If the source or destination are unranked, then the cast is valid. + auto opRType = opType.dyn_cast(); + auto resRType = resType.dyn_cast(); + if (!opRType || !resRType) + return false; + + // If they are both ranked, they have to have the same rank, and any specified + // dimensions must match. + if (opRType.getRank() != resRType.getRank()) + return emitOpError("requires input and result ranks to match"); + + for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { + int64_t opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); + if (opDim != -1 && resultDim != -1 && opDim != resultDim) + return emitOpError("requires static dimensions to match"); + } + + return false; +} + diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp deleted file mode 100644 index 6c3b32dfe07..00000000000 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ /dev/null @@ -1,1609 +0,0 @@ -//===- StandardOps.cpp - Standard MLIR Operations -------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#include "mlir/StandardOps/StandardOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/MathExtras.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/raw_ostream.h" -using namespace mlir; - -//===----------------------------------------------------------------------===// -// StandardOpsDialect -//===----------------------------------------------------------------------===// - -StandardOpsDialect::StandardOpsDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"", context) { - addOperations(); -} - -//===----------------------------------------------------------------------===// -// Common canonicalization pattern support logic -//===----------------------------------------------------------------------===// - -namespace { -/// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref_cast -/// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningInst()) - if (auto cast = memref->dyn_cast()) - op->setOperand(i, cast->getOperand()); - rewriter.updatedRootInPlace(op); - } -}; - -/// Performs const folding `calculate` with element-wise behavior on the two -/// attributes in `operands` and returns the result if possible. -template > -Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { - assert(operands.size() == 2 && "binary op takes two operands"); - - if (auto lhs = operands[0].dyn_cast_or_null()) { - auto rhs = operands[1].dyn_cast_or_null(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; - - return AttrElementT::get(lhs.getType(), - calculate(lhs.getValue(), rhs.getValue())); - } else if (auto lhs = operands[0].dyn_cast_or_null()) { - auto rhs = operands[1].dyn_cast_or_null(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; - - auto elementResult = constFoldBinaryOp( - {lhs.getValue(), rhs.getValue()}, calculate); - if (!elementResult) - return {}; - - return SplatElementsAttr::get(lhs.getType(), elementResult); - } - return {}; -} -} // end anonymous namespace. 
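constFoldBinaryOp above factors the attribute plumbing out of the arithmetic: each op's constantFold only supplies a lambda, while the shared helper performs the "both operands are known constants of the same type" checks (plus the splat-elements case). A standalone sketch of the scalar path, using a toy IntAttr struct in place of MLIR attributes (all names here are illustrative):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>

// Toy stand-in for an integer attribute: a value plus its bit width, which
// plays the role of the attribute "type" that both operands must share.
struct IntAttr {
  int bitWidth;
  int64_t value;
};

// Fold a binary op over two optional constant operands, mirroring the shape of
// constFoldBinaryOp: bail out unless both operands are known constants of the
// same type, otherwise apply `calculate`.
static std::optional<IntAttr>
foldBinary(std::optional<IntAttr> lhs, std::optional<IntAttr> rhs,
           const std::function<int64_t(int64_t, int64_t)> &calculate) {
  if (!lhs || !rhs || lhs->bitWidth != rhs->bitWidth)
    return std::nullopt;
  return IntAttr{lhs->bitWidth, calculate(lhs->value, rhs->value)};
}

int main() {
  auto sum = foldBinary(IntAttr{32, 40}, IntAttr{32, 2},
                        [](int64_t a, int64_t b) { return a + b; });
  if (sum)
    std::cout << "folded to " << sum->value << "\n"; // folded to 42
}
```

With this split, AddIOp, MulIOp, SubIOp and friends differ only in the lambda they pass in.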
- -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -Attribute AddFOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a + b; }); -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -Attribute AddIOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a + b; }); -} - -Value *AddIOp::fold() { - /// addi(x, 0) -> x - if (matchPattern(getOperand(1), m_Zero())) - return getOperand(0); - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// AllocOp -//===----------------------------------------------------------------------===// - -void AllocOp::build(Builder *builder, OperationState *result, - MemRefType memrefType, ArrayRef operands) { - result->addOperands(operands); - result->types.push_back(memrefType); -} - -void AllocOp::print(OpAsmPrinter *p) const { - MemRefType type = getType(); - *p << "alloc"; - // Print dynamic dimension operands. - printDimAndSymbolList(operand_begin(), operand_end(), - type.getNumDynamicDims(), p); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"}); - *p << " : " << type; -} - -bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { - MemRefType type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return true; - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type.getNumDynamicDims()) { - return parser->emitError(parser->getNameLoc(), - "dimension operand count does not equal memref " - "dynamic dimension count"); - } - result->types.push_back(type); - return false; -} - -bool AllocOp::verify() const { - auto memRefType = getResult()->getType().dyn_cast(); - if (!memRefType) - return emitOpError("result must be a memref"); - - unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) { - AffineMap affineMap = memRefType.getAffineMaps()[0]; - // Store number of symbols used in affine map (used in subsequent check). - numSymbols = affineMap.getNumSymbols(); - // TODO(zinenko): this check does not belong to AllocOp, or any other op but - // to the type system itself. It has been partially hoisted to Parser but - // remains here in case an AllocOp gets constructed programmatically. - // Remove when we can emit errors directly from *Type::get(...) functions. - // - // Verify that the layout affine map matches the rank of the memref. 
- if (affineMap.getNumDims() != memRefType.getRank()) - return emitOpError("affine map dimension count must equal memref rank"); - } - unsigned numDynamicDims = memRefType.getNumDynamicDims(); - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. - if (getInstruction()->getNumOperands() != numDynamicDims + numSymbols) { - return emitOpError( - "operand count does not equal dimension plus symbol operand count"); - } - // Verify that all operands are of type Index. - for (auto *operand : getOperands()) { - if (!operand->getType().isIndex()) - return emitOpError("requires operands to be of type Index"); - } - return false; -} - -namespace { -/// Fold constant dimensions into an alloc instruction. -struct SimplifyAllocConst : public RewritePattern { - SimplifyAllocConst(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto alloc = op->cast(); - - // Check to see if any dimensions operands are constants. If so, we can - // substitute and drop them. - for (auto *operand : alloc->getOperands()) - if (matchPattern(operand, m_ConstantIndex())) - return matchSuccess(); - return matchFailure(); - } - - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto allocOp = op->cast(); - auto memrefType = allocOp->getType(); - - // Ok, we have one or more constant operands. Collect the non-constant ones - // and keep track of the resultant memref type to build. - SmallVector newShapeConstants; - newShapeConstants.reserve(memrefType.getRank()); - SmallVector newOperands; - SmallVector droppedOperands; - - unsigned dynamicDimPos = 0; - for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { - int64_t dimSize = memrefType.getDimSize(dim); - // If this is already static dimension, keep it. - if (dimSize != -1) { - newShapeConstants.push_back(dimSize); - continue; - } - auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); - OpPointer constantIndexOp; - if (defOp && (constantIndexOp = defOp->dyn_cast())) { - // Dynamic shape dimension will be folded. - newShapeConstants.push_back(constantIndexOp->getValue()); - // Record to check for zero uses later below. - droppedOperands.push_back(constantIndexOp); - } else { - // Dynamic shape dimension not folded; copy operand from old memref. - newShapeConstants.push_back(-1); - newOperands.push_back(allocOp->getOperand(dynamicDimPos)); - } - dynamicDimPos++; - } - - // Create new memref type (which will have fewer dynamic dimensions). - auto newMemRefType = MemRefType::get( - newShapeConstants, memrefType.getElementType(), - memrefType.getAffineMaps(), memrefType.getMemorySpace()); - assert(newOperands.size() == newMemRefType.getNumDynamicDims()); - - // Create and insert the alloc op for the new memref. - auto newAlloc = - rewriter.create(allocOp->getLoc(), newMemRefType, newOperands); - // Insert a cast so we have the same type as the old alloc. - auto resultCast = rewriter.create(allocOp->getLoc(), newAlloc, - allocOp->getType()); - - rewriter.replaceOp(op, {resultCast}, droppedOperands); - } -}; - -/// Fold alloc instructions with no uses. Alloc has side effects on the heap, -/// but can still be deleted if it has zero uses. 
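SimplifyAllocConst above folds constant dimension operands into the memref's static shape and re-creates the alloc (plus a memref_cast back to the original type) with fewer dynamic operands. Below is a minimal standalone sketch of just the shape-folding step, with -1 marking a dynamic dimension as in the code above; the types and helper are hypothetical, not the MLIR API, and the SimplifyDeadAlloc pattern that the preceding comment introduces follows immediately after this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Result of folding: the updated shape plus the positions of the dynamic-dim
// operands that are still needed at runtime.
struct FoldedShape {
  std::vector<int64_t> shape;
  std::vector<unsigned> remainingOperandPositions;
};

// Fold known-constant dimension operands into a shape where -1 marks a dynamic
// dimension, mirroring the loop in SimplifyAllocConst::rewrite.
static FoldedShape
foldConstantDims(const std::vector<int64_t> &shape,
                 const std::vector<std::optional<int64_t>> &dimOperands) {
  FoldedShape out;
  unsigned dynamicPos = 0;
  for (int64_t dim : shape) {
    if (dim != -1) {            // Already static: keep as-is.
      out.shape.push_back(dim);
      continue;
    }
    std::optional<int64_t> operand = dimOperands[dynamicPos];
    if (operand) {              // Backed by a constant: fold it into the shape.
      out.shape.push_back(*operand);
    } else {                    // Still dynamic: keep -1 and the operand.
      out.shape.push_back(-1);
      out.remainingOperandPositions.push_back(dynamicPos);
    }
    ++dynamicPos;
  }
  return out;
}

int main() {
  // An alloc of memref<4x?x?xf32> whose first dynamic size is the constant 8
  // folds to memref<4x8x?xf32> with a single remaining dynamic operand.
  FoldedShape folded = foldConstantDims({4, -1, -1}, {8, std::nullopt});
  assert((folded.shape == std::vector<int64_t>{4, 8, -1}));
  assert(folded.remainingOperandPositions.size() == 1);
}
```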
-struct SimplifyDeadAlloc : public RewritePattern { - SimplifyDeadAlloc(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto alloc = op->cast(); - // Check if the alloc'ed value has no uses. - return alloc->use_empty() ? matchSuccess() : matchFailure(); - } - - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - // Erase the alloc operation. - op->erase(); - } -}; -} // end anonymous namespace. - -void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.push_back(std::make_unique(context)); - results.push_back(std::make_unique(context)); -} - -//===----------------------------------------------------------------------===// -// CallOp -//===----------------------------------------------------------------------===// - -void CallOp::build(Builder *builder, OperationState *result, Function *callee, - ArrayRef operands) { - result->addOperands(operands); - result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType().getResults()); -} - -bool CallOp::parse(OpAsmParser *parser, OperationState *result) { - StringRef calleeName; - llvm::SMLoc calleeLoc; - FunctionType calleeType; - SmallVector operands; - Function *callee = nullptr; - if (parser->parseFunctionName(calleeName, calleeLoc) || - parser->parseOperandList(operands, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || - parser->addTypesToList(calleeType.getResults(), result->types) || - parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, - result->operands)) - return true; - - result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); - return false; -} - -void CallOp::print(OpAsmPrinter *p) const { - *p << "call "; - p->printFunctionReference(getCallee()); - *p << '('; - p->printOperands(getOperands()); - *p << ')'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"}); - *p << " : " << getCallee()->getType(); -} - -bool CallOp::verify() const { - // Check that the callee attribute was specified. - auto fnAttr = getAttrOfType("callee"); - if (!fnAttr) - return emitOpError("requires a 'callee' function attribute"); - - // Verify that the operand and result types match the callee. - auto fnType = fnAttr.getValue()->getType(); - if (fnType.getNumInputs() != getNumOperands()) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { - if (getOperand(i)->getType() != fnType.getInput(i)) - return emitOpError("operand type mismatch"); - } - - if (fnType.getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType.getResult(i)) - return emitOpError("result type mismatch"); - } - - return false; -} - -//===----------------------------------------------------------------------===// -// CallIndirectOp -//===----------------------------------------------------------------------===// -namespace { -/// Fold indirect calls that have a constant function as the callee operand. 
-struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { - SimplifyIndirectCallWithKnownCallee(MLIRContext *context) - : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto indirectCall = op->cast(); - - // Check that the callee is a constant operation. - Value *callee = indirectCall->getCallee(); - Instruction *calleeInst = callee->getDefiningInst(); - if (!calleeInst || !calleeInst->isa()) - return matchFailure(); - - // Check that the constant callee is a function. - if (calleeInst->cast()->getValue().isa()) - return matchSuccess(); - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto indirectCall = op->cast(); - auto calleeOp = - indirectCall->getCallee()->getDefiningInst()->cast(); - - // Replace with a direct call. - Function *calledFn = calleeOp->getValue().cast().getValue(); - SmallVector callOperands(indirectCall->getArgOperands()); - rewriter.replaceOpWithNewOp(op, calledFn, callOperands); - } -}; -} // end anonymous namespace. - -void CallIndirectOp::build(Builder *builder, OperationState *result, - Value *callee, ArrayRef operands) { - auto fnType = callee->getType().cast(); - result->operands.push_back(callee); - result->addOperands(operands); - result->addTypes(fnType.getResults()); -} - -bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector operands; - return parser->parseOperand(callee) || - parser->getCurrentLocation(&operandsLoc) || - parser->parseOperandList(operands, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result->operands) || - parser->addTypesToList(calleeType.getResults(), result->types); -} - -void CallIndirectOp::print(OpAsmPrinter *p) const { - *p << "call_indirect "; - p->printOperand(getCallee()); - *p << '('; - auto operandRange = getOperands(); - p->printOperands(++operandRange.begin(), operandRange.end()); - *p << ')'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"}); - *p << " : " << getCallee()->getType(); -} - -bool CallIndirectOp::verify() const { - // The callee must be a function. - auto fnType = getCallee()->getType().dyn_cast(); - if (!fnType) - return emitOpError("callee must have function type"); - - // Verify that the operand and result types match the callee. 
- if (fnType.getNumInputs() != getNumOperands() - 1) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { - if (getOperand(i + 1)->getType() != fnType.getInput(i)) - return emitOpError("operand type mismatch"); - } - - if (fnType.getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType.getResult(i)) - return emitOpError("result type mismatch"); - } - - return false; -} - -void CallIndirectOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.push_back( - std::make_unique(context)); -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -// Return the type of the same shape (scalar, vector or tensor) containing i1. -static Type getCheckedI1SameShape(Builder *build, Type type) { - auto i1Type = build->getI1Type(); - if (type.isIntOrIndexOrFloat()) - return i1Type; - if (auto tensorType = type.dyn_cast()) - return build->getTensorType(tensorType.getShape(), i1Type); - if (auto tensorType = type.dyn_cast()) - return build->getTensorType(i1Type); - if (auto vectorType = type.dyn_cast()) - return build->getVectorType(vectorType.getShape(), i1Type); - return Type(); -} - -static Type getI1SameShape(Builder *build, Type type) { - Type res = getCheckedI1SameShape(build, type); - assert(res && "expected type with valid i1 shape"); - return res; -} - -static inline bool isI1(Type type) { - return type.isa() && type.cast().getWidth() == 1; -} - -template -static inline bool implCheckI1SameShape(Ty pattern, Type type) { - auto specificType = type.dyn_cast(); - if (!specificType) - return true; - if (specificType.getShape() != pattern.getShape()) - return true; - return !isI1(specificType.getElementType()); -} - -// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern" -// and contains i1. -static bool checkI1SameShape(Type pattern, Type type) { - if (pattern.isIntOrIndexOrFloat()) - return !isI1(type); - if (auto patternTensorType = pattern.dyn_cast()) - return implCheckI1SameShape(patternTensorType, type); - if (auto patternVectorType = pattern.dyn_cast()) - return implCheckI1SameShape(patternVectorType, type); - - llvm_unreachable("unsupported type"); -} - -// Returns an array of mnemonics for CmpIPredicates, indexed by values thereof. -static inline const char *const *getPredicateNames() { - static const char *predicateNames[(int)CmpIPredicate::NumPredicates]{ - /*EQ*/ "eq", - /*NE*/ "ne", - /*SLT*/ "slt", - /*SLE*/ "sle", - /*SGT*/ "sgt", - /*SGE*/ "sge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UGT*/ "ugt", - /*UGE*/ "uge"}; - return predicateNames; -}; - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
-CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch(name) - .Case("eq", CmpIPredicate::EQ) - .Case("ne", CmpIPredicate::NE) - .Case("slt", CmpIPredicate::SLT) - .Case("sle", CmpIPredicate::SLE) - .Case("sgt", CmpIPredicate::SGT) - .Case("sge", CmpIPredicate::SGE) - .Case("ult", CmpIPredicate::ULT) - .Case("ule", CmpIPredicate::ULE) - .Case("ugt", CmpIPredicate::UGT) - .Case("uge", CmpIPredicate::UGE) - .Default(CmpIPredicate::NumPredicates); -} - -void CmpIOp::build(Builder *build, OperationState *result, - CmpIPredicate predicate, Value *lhs, Value *rhs) { - result->addOperands({lhs, rhs}); - result->types.push_back(getI1SameShape(build, lhs->getType())); - result->addAttribute(getPredicateAttrName(), - build->getIntegerAttr(build->getIntegerType(64), - static_cast(predicate))); -} - -bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector ops; - SmallVector attrs; - Attribute predicateNameAttr; - Type type; - if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(), - attrs) || - parser->parseComma() || parser->parseOperandList(ops, 2) || - parser->parseOptionalAttributeDict(attrs) || - parser->parseColonType(type) || - parser->resolveOperands(ops, type, result->operands)) - return true; - - if (!predicateNameAttr.isa()) - return parser->emitError(parser->getNameLoc(), - "expected string comparison predicate attribute"); - - // Rewrite string attribute to an enum value. - StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = getPredicateByName(predicateName); - if (predicate == CmpIPredicate::NumPredicates) - return parser->emitError(parser->getNameLoc(), - "unknown comparison predicate \"" + predicateName + - "\""); - - auto builder = parser->getBuilder(); - Type i1Type = getCheckedI1SameShape(&builder, type); - if (!i1Type) - return parser->emitError(parser->getNameLoc(), - "expected type with valid i1 shape"); - - attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); - result->attributes = attrs; - - result->addTypes({i1Type}); - return false; -} - -void CmpIOp::print(OpAsmPrinter *p) const { - *p << getOperationName() << " "; - - auto predicateValue = - getAttrOfType(getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && - predicateValue < static_cast(CmpIPredicate::NumPredicates) && - "unknown predicate index"); - Builder b(getInstruction()->getContext()); - auto predicateStringAttr = - b.getStringAttr(getPredicateNames()[predicateValue]); - p->printAttribute(predicateStringAttr); - - *p << ", "; - p->printOperand(getOperand(0)); - *p << ", "; - p->printOperand(getOperand(1)); - p->printOptionalAttrDict(getAttrs(), - /*elidedAttrs=*/{getPredicateAttrName()}); - *p << " : " << getOperand(0)->getType(); -} - -bool CmpIOp::verify() const { - auto predicateAttr = getAttrOfType(getPredicateAttrName()); - if (!predicateAttr) - return emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpIPredicate::FirstValidValue || - predicate >= (int64_t)CmpIPredicate::NumPredicates) - return emitOpError("'predicate' attribute value out of range"); - - return false; -} - -// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer -// comparison predicates. 
-static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, - const APInt &rhs) { - switch (predicate) { - case CmpIPredicate::EQ: - return lhs.eq(rhs); - case CmpIPredicate::NE: - return lhs.ne(rhs); - case CmpIPredicate::SLT: - return lhs.slt(rhs); - case CmpIPredicate::SLE: - return lhs.sle(rhs); - case CmpIPredicate::SGT: - return lhs.sgt(rhs); - case CmpIPredicate::SGE: - return lhs.sge(rhs); - case CmpIPredicate::ULT: - return lhs.ult(rhs); - case CmpIPredicate::ULE: - return lhs.ule(rhs); - case CmpIPredicate::UGT: - return lhs.ugt(rhs); - case CmpIPredicate::UGE: - return lhs.uge(rhs); - default: - llvm_unreachable("unknown comparison predicate"); - } -} - -// Constant folding hook for comparisons. -Attribute CmpIOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "cmpi takes two arguments"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val)); -} - -//===----------------------------------------------------------------------===// -// DeallocOp -//===----------------------------------------------------------------------===// -namespace { -/// Fold Dealloc instructions that are deallocating an AllocOp that is only used -/// by other Dealloc operations. -struct SimplifyDeadDealloc : public RewritePattern { - SimplifyDeadDealloc(MLIRContext *context) - : RewritePattern(DeallocOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto dealloc = op->cast(); - - // Check that the memref operand's defining instruction is an AllocOp. - Value *memref = dealloc->getMemRef(); - Instruction *defOp = memref->getDefiningInst(); - if (!defOp || !defOp->isa()) - return matchFailure(); - - // Check that all of the uses of the AllocOp are other DeallocOps. - for (auto &use : memref->getUses()) - if (!use.getOwner()->isa()) - return matchFailure(); - return matchSuccess(); - } - - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - // Erase the dealloc operation. - op->erase(); - } -}; -} // end anonymous namespace. 
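applyCmpPredicate above is where the signed/unsigned split in cmpi becomes observable: the same bit pattern orders differently under slt and ult. A small self-contained illustration with fixed-width integers (plain C++ casts standing in for the APInt calls):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // The 8-bit pattern 0xFF is -1 when read as signed, 255 when read as unsigned.
  uint8_t lhs = 0xFF, rhs = 0x01;

  // "slt": compare after reinterpreting both operands as signed.
  bool slt = static_cast<int8_t>(lhs) < static_cast<int8_t>(rhs); // -1 < 1
  // "ult": compare the raw unsigned values.
  bool ult = lhs < rhs;                                           // 255 < 1

  std::cout << "slt: " << slt << ", ult: " << ult << "\n"; // slt: 1, ult: 0
}
```

This is why the predicate, not the operand type, carries the signedness in cmpi: the integer types themselves are signless.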
- -void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) { - result->addOperands(memref); -} - -void DeallocOp::print(OpAsmPrinter *p) const { - *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); -} - -bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - MemRefType type; - - return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands); -} - -bool DeallocOp::verify() const { - if (!getMemRef()->getType().isa()) - return emitOpError("operand must be a memref"); - return false; -} - -void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// dealloc(memrefcast) -> dealloc - results.push_back( - std::make_unique(getOperationName(), context)); - results.push_back(std::make_unique(context)); -} - -//===----------------------------------------------------------------------===// -// DimOp -//===----------------------------------------------------------------------===// - -void DimOp::build(Builder *builder, OperationState *result, - Value *memrefOrTensor, unsigned index) { - result->addOperands(memrefOrTensor); - auto type = builder->getIndexType(); - result->addAttribute("index", builder->getIntegerAttr(type, index)); - result->types.push_back(type); -} - -void DimOp::print(OpAsmPrinter *p) const { - *p << "dim " << *getOperand() << ", " << getIndex(); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"index"}); - *p << " : " << getOperand()->getType(); -} - -bool DimOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType operandInfo; - IntegerAttr indexAttr; - Type type; - Type indexType = parser->getBuilder().getIndexType(); - - return parser->parseOperand(operandInfo) || parser->parseComma() || - parser->parseAttribute(indexAttr, indexType, "index", - result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(operandInfo, type, result->operands) || - parser->addTypeToList(indexType, result->types); -} - -bool DimOp::verify() const { - // Check that we have an integer index operand. - auto indexAttr = getAttrOfType("index"); - if (!indexAttr) - return emitOpError("requires an integer attribute named 'index'"); - uint64_t index = indexAttr.getValue().getZExtValue(); - - auto type = getOperand()->getType(); - if (auto tensorType = type.dyn_cast()) { - if (index >= tensorType.getRank()) - return emitOpError("index is out of range"); - } else if (auto memrefType = type.dyn_cast()) { - if (index >= memrefType.getRank()) - return emitOpError("index is out of range"); - - } else if (type.isa()) { - // ok, assumed to be in-range. - } else { - return emitOpError("requires an operand with tensor or memref type"); - } - - return false; -} - -Attribute DimOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - // Constant fold dim when the size along the index referred to is a constant. 
- auto opType = getOperand()->getType(); - int64_t indexSize = -1; - if (auto tensorType = opType.dyn_cast()) { - indexSize = tensorType.getShape()[getIndex()]; - } else if (auto memrefType = opType.dyn_cast()) { - indexSize = memrefType.getShape()[getIndex()]; - } - - if (indexSize >= 0) - return IntegerAttr::get(IndexType::get(context), indexSize); - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// DivISOp -//===----------------------------------------------------------------------===// - -Attribute DivISOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "binary operation takes two operands"); - (void)context; - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - // Don't fold if it requires division by zero. - if (rhs.getValue().isNullValue()) { - return {}; - } - - // Don't fold if it would overflow. - bool overflow; - auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); - return overflow ? IntegerAttr{} : IntegerAttr::get(lhs.getType(), result); -} - -//===----------------------------------------------------------------------===// -// DivIUOp -//===----------------------------------------------------------------------===// - -Attribute DivIUOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "binary operation takes two operands"); - (void)context; - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - // Don't fold if it requires division by zero. - if (rhs.getValue().isNullValue()) { - return {}; - } - - return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhs.getValue())); -} - -// --------------------------------------------------------------------------- -// DmaStartOp -// --------------------------------------------------------------------------- - -void DmaStartOp::build(Builder *builder, OperationState *result, - Value *srcMemRef, ArrayRef srcIndices, - Value *destMemRef, ArrayRef destIndices, - Value *numElements, Value *tagMemRef, - ArrayRef tagIndices, Value *stride, - Value *elementsPerStride) { - result->addOperands(srcMemRef); - result->addOperands(srcIndices); - result->addOperands(destMemRef); - result->addOperands(destIndices); - result->addOperands(numElements); - result->addOperands(tagMemRef); - result->addOperands(tagIndices); - if (stride) { - result->addOperands(stride); - result->addOperands(elementsPerStride); - } -} - -void DmaStartOp::print(OpAsmPrinter *p) const { - *p << getOperationName() << ' ' << *getSrcMemRef() << '['; - p->printOperands(getSrcIndices()); - *p << "], " << *getDstMemRef() << '['; - p->printOperands(getDstIndices()); - *p << "], " << *getNumElements(); - *p << ", " << *getTagMemRef() << '['; - p->printOperands(getTagIndices()); - *p << ']'; - if (isStrided()) { - *p << ", " << *getStride(); - *p << ", " << *getNumElementsPerStride(); - } - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getSrcMemRef()->getType(); - *p << ", " << getDstMemRef()->getType(); - *p << ", " << getTagMemRef()->getType(); -} - -// Parse DmaStartOp. 
-// Ex: -// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, -// %tag[%index], %stride, %num_elt_per_stride : -// : memref<3076 x f32, 0>, -// memref<1024 x f32, 2>, -// memref<1 x i32> -// -bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType srcMemRefInfo; - SmallVector srcIndexInfos; - OpAsmParser::OperandType dstMemRefInfo; - SmallVector dstIndexInfos; - OpAsmParser::OperandType numElementsInfo; - OpAsmParser::OperandType tagMemrefInfo; - SmallVector tagIndexInfos; - SmallVector strideInfo; - - SmallVector types; - auto indexType = parser->getBuilder().getIndexType(); - - // Parse and resolve the following list of operands: - // *) source memref followed by its indices (in square brackets). - // *) destination memref followed by its indices (in square brackets). - // *) dma size in KiB. - if (parser->parseOperand(srcMemRefInfo) || - parser->parseOperandList(srcIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(dstMemRefInfo) || - parser->parseOperandList(dstIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseComma() || parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, -1, - OpAsmParser::Delimiter::Square)) - return true; - - // Parse optional stride and elements per stride. - if (parser->parseTrailingOperandList(strideInfo)) { - return true; - } - if (!strideInfo.empty() && strideInfo.size() != 2) { - return parser->emitError(parser->getNameLoc(), - "expected two stride related operands"); - } - bool isStrided = strideInfo.size() == 2; - - if (parser->parseColonTypeList(types)) - return true; - - if (types.size() != 3) - return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); - - if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || - parser->resolveOperands(srcIndexInfos, indexType, result->operands) || - parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || - parser->resolveOperands(dstIndexInfos, indexType, result->operands) || - // size should be an index. - parser->resolveOperand(numElementsInfo, indexType, result->operands) || - parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || - // tag indices should be index. - parser->resolveOperands(tagIndexInfos, indexType, result->operands)) - return true; - - if (!types[0].isa()) - return parser->emitError(parser->getNameLoc(), - "expected source to be of memref type"); - - if (!types[1].isa()) - return parser->emitError(parser->getNameLoc(), - "expected destination to be of memref type"); - - if (!types[2].isa()) - return parser->emitError(parser->getNameLoc(), - "expected tag to be of memref type"); - - if (isStrided) { - if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || - parser->resolveOperand(strideInfo[1], indexType, result->operands)) - return true; - } - - // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != types[0].cast().getRank() || - dstIndexInfos.size() != types[1].cast().getRank()) - return parser->emitError(parser->getNameLoc(), - "memref rank not equal to indices count"); - - if (tagIndexInfos.size() != types[2].cast().getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - return false; -} - -bool DmaStartOp::verify() const { - // DMAs from different memory spaces supported. 
- if (getSrcMemorySpace() == getDstMemorySpace()) {
-   return emitOpError("DMA should be between different memory spaces");
- }
-
- if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                             getDstMemRefRank() + 3 + 1 &&
-     getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                             getDstMemRefRank() + 3 + 1 + 2) {
-   return emitOpError("incorrect number of operands");
- }
- return false;
-}
-
-void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                             MLIRContext *context) {
- /// dma_start(memrefcast) -> dma_start
- results.push_back(
-     std::make_unique<MemRefCastFolder>(getOperationName(), context));
-}
-
-// ---------------------------------------------------------------------------
-// DmaWaitOp
-// ---------------------------------------------------------------------------
-
-void DmaWaitOp::build(Builder *builder, OperationState *result,
-                      Value *tagMemRef, ArrayRef<Value *> tagIndices,
-                      Value *numElements) {
- result->addOperands(tagMemRef);
- result->addOperands(tagIndices);
- result->addOperands(numElements);
-}
-
-void DmaWaitOp::print(OpAsmPrinter *p) const {
- *p << getOperationName() << ' ';
- // Print operands.
- p->printOperand(getTagMemRef());
- *p << '[';
- p->printOperands(getTagIndices());
- *p << "], ";
- p->printOperand(getNumElements());
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << getTagMemRef()->getType();
-}
-
-// Parse DmaWaitOp.
-// Eg:
-// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
-//
-bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType tagMemrefInfo;
- SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
- Type type;
- auto indexType = parser->getBuilder().getIndexType();
- OpAsmParser::OperandType numElementsInfo;
-
- // Parse tag memref, its indices, and dma size.
- if (parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseColonType(type) || - parser->resolveOperand(tagMemrefInfo, type, result->operands) || - parser->resolveOperands(tagIndexInfos, indexType, result->operands) || - parser->resolveOperand(numElementsInfo, indexType, result->operands)) - return true; - - if (!type.isa()) - return parser->emitError(parser->getNameLoc(), - "expected tag to be of memref type"); - - if (tagIndexInfos.size() != type.cast().getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - return false; -} - -void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// dma_wait(memrefcast) -> dma_wait - results.push_back( - std::make_unique(getOperationName(), context)); -} - -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::build(Builder *builder, OperationState *result, - Value *aggregate, ArrayRef indices) { - auto aggregateType = aggregate->getType().cast(); - result->addOperands(aggregate); - result->addOperands(indices); - result->types.push_back(aggregateType.getElementType()); -} - -void ExtractElementOp::print(OpAsmPrinter *p) const { - *p << "extract_element " << *getAggregate() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getAggregate()->getType(); -} - -bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType aggregateInfo; - SmallVector indexInfo; - VectorOrTensorType type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(aggregateInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(aggregateInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types); -} - -bool ExtractElementOp::verify() const { - if (getNumOperands() == 0) - return emitOpError("expected an aggregate to index into"); - - auto aggregateType = getAggregate()->getType().dyn_cast(); - if (!aggregateType) - return emitOpError("first operand must be a vector or tensor"); - - if (getType() != aggregateType.getElementType()) - return emitOpError("result type must match element type of aggregate"); - - for (auto *idx : getIndices()) - if (!idx->getType().isIndex()) - return emitOpError("index to extract_element must have 'index' type"); - - // Verify the # indices match if we have a ranked type. - auto aggregateRank = aggregateType.getRank(); - if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) - return emitOpError("incorrect number of indices for extract_element"); - - return false; -} - -Attribute ExtractElementOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.size() > 1 && "extract_element takes atleast one operands"); - - // The aggregate operand must be a known constant. - Attribute aggregate = operands.front(); - if (!aggregate) - return Attribute(); - - // If this is a splat elements attribute, simply return the value. 
All of the - // elements of a splat attribute are the same. - if (auto splatAggregate = aggregate.dyn_cast()) - return splatAggregate.getValue(); - - // Otherwise, collect the constant indices into the aggregate. - SmallVector indices; - for (Attribute indice : llvm::drop_begin(operands, 1)) { - if (!indice || !indice.isa()) - return Attribute(); - indices.push_back(indice.cast().getInt()); - } - - // If this is an elements attribute, query the value at the given indices. - if (auto elementsAttr = aggregate.dyn_cast()) - return elementsAttr.getValue(indices); - return Attribute(); -} - -//===----------------------------------------------------------------------===// -// LoadOp -//===----------------------------------------------------------------------===// - -void LoadOp::build(Builder *builder, OperationState *result, Value *memref, - ArrayRef indices) { - auto memrefType = memref->getType().cast(); - result->addOperands(memref); - result->addOperands(indices); - result->types.push_back(memrefType.getElementType()); -} - -void LoadOp::print(OpAsmPrinter *p) const { - *p << "load " << *getMemRef() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getMemRefType(); -} - -bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types); -} - -bool LoadOp::verify() const { - if (getNumOperands() == 0) - return emitOpError("expected a memref to load from"); - - auto memRefType = getMemRef()->getType().dyn_cast(); - if (!memRefType) - return emitOpError("first operand must be a memref"); - - if (getType() != memRefType.getElementType()) - return emitOpError("result type must match element type of memref"); - - if (memRefType.getRank() != getNumOperands() - 1) - return emitOpError("incorrect number of indices for load"); - - for (auto *idx : getIndices()) - if (!idx->getType().isIndex()) - return emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in Function verify that the indices are parameters, IV's, or the - // result of an affine.apply. 
- return false; -} - -void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// load(memrefcast) -> load - results.push_back( - std::make_unique(getOperationName(), context)); -} - -//===----------------------------------------------------------------------===// -// MemRefCastOp -//===----------------------------------------------------------------------===// - -bool MemRefCastOp::verify() const { - auto opType = getOperand()->getType().dyn_cast(); - auto resType = getType().dyn_cast(); - if (!opType || !resType) - return emitOpError("requires input and result types to be memrefs"); - - if (opType == resType) - return emitOpError("requires the input and result type to be different"); - - if (opType.getElementType() != resType.getElementType()) - return emitOpError( - "requires input and result element types to be the same"); - - if (opType.getAffineMaps() != resType.getAffineMaps()) - return emitOpError("requires input and result mappings to be the same"); - - if (opType.getMemorySpace() != resType.getMemorySpace()) - return emitOpError( - "requires input and result memory spaces to be the same"); - - // They must have the same rank, and any specified dimensions must match. - if (opType.getRank() != resType.getRank()) - return emitOpError("requires input and result ranks to match"); - - for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { - int64_t opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); - if (opDim != -1 && resultDim != -1 && opDim != resultDim) - return emitOpError("requires static dimensions to match"); - } - - return false; -} - -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -Attribute MulFOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a * b; }); -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -Attribute MulIOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - // TODO: Handle the overflow case. - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a * b; }); -} - -Value *MulIOp::fold() { - /// muli(x, 0) -> 0 - if (matchPattern(getOperand(1), m_Zero())) - return getOperand(1); - /// muli(x, 1) -> x - if (matchPattern(getOperand(1), m_One())) - return getOperand(0); - return nullptr; -} - -//===----------------------------------------------------------------------===// -// RemISOp -//===----------------------------------------------------------------------===// - -Attribute RemISOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "remis takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - - // x % 1 = 0 - if (rhs.getValue().isOneValue()) - return IntegerAttr::get(rhs.getType(), - APInt(rhs.getValue().getBitWidth(), 0)); - - // Don't fold if it requires division by zero. 
- if (rhs.getValue().isNullValue()) { - return {}; - } - - auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); - if (!lhs) - return {}; - - return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhs.getValue())); -} - -//===----------------------------------------------------------------------===// -// RemIUOp -//===----------------------------------------------------------------------===// - -Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "remiu takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); - if (!rhs) - return {}; - - // x % 1 = 0 - if (rhs.getValue().isOneValue()) - return IntegerAttr::get(rhs.getType(), - APInt(rhs.getValue().getBitWidth(), 0)); - - // Don't fold if it requires division by zero. - if (rhs.getValue().isNullValue()) { - return {}; - } - - auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); - if (!lhs) - return {}; - - return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue())); -} - -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -void SelectOp::build(Builder *builder, OperationState *result, Value *condition, - Value *trueValue, Value *falseValue) { - result->addOperands({condition, trueValue, falseValue}); - result->addTypes(trueValue->getType()); -} - -bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector<OpAsmParser::OperandType, 3> ops; - SmallVector<NamedAttribute, 4> attrs; - Type type; - - if (parser->parseOperandList(ops, 3) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return true; - - auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); - if (!i1Type) - return parser->emitError(parser->getNameLoc(), - "expected type with valid i1 shape"); - - SmallVector<Type, 3> types = {i1Type, type, type}; - return parser->resolveOperands(ops, types, parser->getNameLoc(), - result->operands) || - parser->addTypeToList(type, result->types); -} - -void SelectOp::print(OpAsmPrinter *p) const { - *p << getOperationName() << ' '; - p->printOperands(getInstruction()->getOperands()); - *p << " : " << getTrueValue()->getType(); - p->printOptionalAttrDict(getAttrs()); -} - -bool SelectOp::verify() const { - auto conditionType = getCondition()->getType(); - auto trueType = getTrueValue()->getType(); - auto falseType = getFalseValue()->getType(); - - if (trueType != falseType) - return emitOpError( - "requires 'true' and 'false' arguments to be of the same type"); - - if (checkI1SameShape(trueType, conditionType)) - return emitOpError("requires the condition to have the same shape as " - "arguments with elemental type i1"); - - return false; -} - -Value *SelectOp::fold() { - auto *condition = getCondition(); - - // select true, %0, %1 => %0 - if (matchPattern(condition, m_One())) - return getTrueValue(); - - // select false, %0, %1 => %1 - if (matchPattern(condition, m_Zero())) - return getFalseValue(); - return nullptr; -} - -//===----------------------------------------------------------------------===// -// StoreOp -//===----------------------------------------------------------------------===// - -void StoreOp::build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef<Value *> indices) { - result->addOperands(valueToStore); - result->addOperands(memref); - result->addOperands(indices); -} - -void StoreOp::print(OpAsmPrinter *p) const { - *p << "store " << *getValueToStore(); - *p << 
", " << *getMemRef() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getMemRefType(); -} - -bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType memrefType; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType.getElementType(), - result->operands) || - parser->resolveOperand(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands); -} - -bool StoreOp::verify() const { - if (getNumOperands() < 2) - return emitOpError("expected a value to store and a memref"); - - // Second operand is a memref type. - auto memRefType = getMemRef()->getType().dyn_cast(); - if (!memRefType) - return emitOpError("second operand must be a memref"); - - // First operand must have same type as memref element type. - if (getValueToStore()->getType() != memRefType.getElementType()) - return emitOpError("first operand must have same type memref element type"); - - if (getNumOperands() != 2 + memRefType.getRank()) - return emitOpError("store index operand count not equal to memref rank"); - - for (auto *idx : getIndices()) - if (!idx->getType().isIndex()) - return emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in Function verify that the indices are parameters, IV's, or the - // result of an affine.apply. - return false; -} - -void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// store(memrefcast) -> store - results.push_back( - std::make_unique(getOperationName(), context)); -} - -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -Attribute SubFOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a - b; }); -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -Attribute SubIOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a - b; }); -} - -namespace { -/// subi(x,x) -> 0 -/// -struct SimplifyXMinusX : public RewritePattern { - SimplifyXMinusX(MLIRContext *context) - : RewritePattern(SubIOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto subi = op->cast(); - if (subi->getOperand(0) == subi->getOperand(1)) - return matchSuccess(); - - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto subi = op->cast(); - auto result = - rewriter.create(op->getLoc(), 0, subi->getType()); - - rewriter.replaceOp(op, {result}); - } -}; -} // end anonymous namespace. 
- -void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.push_back(std::make_unique(context)); -} - -//===----------------------------------------------------------------------===// -// TensorCastOp -//===----------------------------------------------------------------------===// - -bool TensorCastOp::verify() const { - auto opType = getOperand()->getType().dyn_cast(); - auto resType = getType().dyn_cast(); - if (!opType || !resType) - return emitOpError("requires input and result types to be tensors"); - - if (opType == resType) - return emitOpError("requires the input and result type to be different"); - - if (opType.getElementType() != resType.getElementType()) - return emitOpError( - "requires input and result element types to be the same"); - - // If the source or destination are unranked, then the cast is valid. - auto opRType = opType.dyn_cast(); - auto resRType = resType.dyn_cast(); - if (!opRType || !resRType) - return false; - - // If they are both ranked, they have to have the same rank, and any specified - // dimensions must match. - if (opRType.getRank() != resRType.getRank()) - return emitOpError("requires input and result ranks to match"); - - for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { - int64_t opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); - if (opDim != -1 && resultDim != -1 && opDim != resultDim) - return emitOpError("requires static dimensions to match"); - } - - return false; -} - diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index ecbe56ea730..f3b7f842afe 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -22,7 +22,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Module.h" #include "mlir/LLVMIR/LLVMDialect.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" #include "mlir/Translation.h" diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 53bc56173d2..c39c0a52a80 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/MapVector.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 025813c6ca4..7b7c0bb22bb 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -29,7 +29,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 1070c10a2d4..9979c3736ef 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" using namespace mlir; diff 
--git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 261c360631f..3990e54006d 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -37,7 +37,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/MLPatternLoweringPass.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index c41c75bb88f..572281c5cc3 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -37,7 +37,7 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 55837f95d14..51e9debd0ad 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -26,7 +26,7 @@ #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SmallPtrSet.h" #include diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 331279cb827..cdce5230ba6 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -27,7 +27,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index eab131959b0..c13146ee2f5 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -31,7 +31,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 80dc49f1aab..8a21f273006 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -29,7 +29,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Module.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 8b7a17a7ff6..8277d4800ab 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -30,7 +30,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include 
"mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/test/EDSC/api-test.cpp b/mlir/test/EDSC/api-test.cpp index 09d3751229a..6ce0be0556b 100644 --- a/mlir/test/EDSC/api-test.cpp +++ b/mlir/test/EDSC/api-test.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/StandardOps.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "Test.h" diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td index 048763c76c7..7dc37b58946 100644 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ b/mlir/test/mlir-tblgen/one-op-one-result.td @@ -1,6 +1,6 @@ // RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s -include "mlir/IR/op_base.td" +include "mlir/IR/OpBase.td" // Create a Type and Attribute. def T : BuildableType<"buildT">; diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index d9484fd4b7c..057d07b5006 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -1,6 +1,6 @@ // RUN: mlir-tblgen -gen-op-definitions -I %S/../../include %s | FileCheck %s -include "mlir/IR/op_base.td" +include "mlir/IR/OpBase.td" def SameTypeOp : Op<"same_type_op", [SameValueType]> { let arguments = (ins I32:$x); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 347424e91d5..7f16b13d576 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -1,6 +1,6 @@ // RUN: mlir-tblgen -gen-op-definitions -I %S/../../include %s | FileCheck %s -include "mlir/IR/op_base.td" +include "mlir/IR/OpBase.td" def I32OrF32 : Type, "32-bit integer or floating-point type">; diff --git a/mlir/test/mlir-tblgen/reference-impl.td b/mlir/test/mlir-tblgen/reference-impl.td index 90e1efc2950..b722c7d5981 100644 --- a/mlir/test/mlir-tblgen/reference-impl.td +++ b/mlir/test/mlir-tblgen/reference-impl.td @@ -2,7 +2,7 @@ #ifdef OP_BASE #else -include "mlir/IR/op_base.td" +include "mlir/IR/OpBase.td" #endif // OP_BASE def X_AddOp : Op<"x.add">, -- cgit v1.2.3 From f37651c708d7ce1bc110e3f8b3f3507f06601c3e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 1 Mar 2019 16:58:00 -0800 Subject: NFC. Move all of the remaining operations left in BuiltinOps to StandardOps. The only thing left in BuiltinOps are the core MLIR types. The standard types can't be moved because they are referenced within the IR directory, e.g. in things like Builder. 
PiperOrigin-RevId: 236403665 --- mlir/bindings/python/pybind.cpp | 1 - mlir/include/mlir/IR/BuiltinOps.h | 357 ---------------- mlir/include/mlir/IR/Instruction.h | 3 - mlir/include/mlir/IR/Matchers.h | 7 +- mlir/include/mlir/StandardOps/Ops.h | 316 ++++++++++++++ mlir/include/mlir/Transforms/Utils.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 1 - mlir/lib/Analysis/AffineAnalysis.cpp | 1 - mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 1 - mlir/lib/Analysis/MemRefBoundCheck.cpp | 1 - mlir/lib/Analysis/MemRefDependenceCheck.cpp | 1 - mlir/lib/Analysis/SliceAnalysis.cpp | 1 - mlir/lib/Analysis/Utils.cpp | 1 - mlir/lib/Analysis/VectorAnalysis.cpp | 1 - mlir/lib/EDSC/LowerEDSCTestPass.cpp | 1 - mlir/lib/EDSC/MLIREmitter.cpp | 1 - mlir/lib/EDSC/Types.cpp | 1 - mlir/lib/IR/AsmPrinter.cpp | 30 +- mlir/lib/IR/BuiltinOps.cpp | 454 --------------------- mlir/lib/IR/Instruction.cpp | 4 +- mlir/lib/IR/MLIRContext.cpp | 11 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 1 - mlir/lib/Parser/Parser.cpp | 1 - mlir/lib/StandardOps/Ops.cpp | 426 ++++++++++++++++++- mlir/lib/Transforms/ConstantFold.cpp | 1 + mlir/lib/Transforms/DmaGeneration.cpp | 1 - mlir/lib/Transforms/LoopFusion.cpp | 1 - mlir/lib/Transforms/LoopUnroll.cpp | 1 - mlir/lib/Transforms/LoopUnrollAndJam.cpp | 1 - mlir/lib/Transforms/LowerAffine.cpp | 1 - mlir/lib/Transforms/LowerVectorTransfers.cpp | 1 - mlir/lib/Transforms/MaterializeVectors.cpp | 1 - .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 1 - .../Vectorization/VectorizerTestPass.cpp | 1 - mlir/lib/Transforms/Vectorize.cpp | 1 - mlir/test/EDSC/api-test.cpp | 1 - 38 files changed, 771 insertions(+), 869 deletions(-) delete mode 100644 mlir/include/mlir/IR/BuiltinOps.h delete mode 100644 mlir/lib/IR/BuiltinOps.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index fe80e071467..6be9eebc605 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -9,7 +9,6 @@ #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h deleted file mode 100644 index 7b606fdbd1d..00000000000 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ /dev/null @@ -1,357 +0,0 @@ -//===- BuiltinOps.h - Builtin MLIR Operations -------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines convenience types for working with builtin operations -// in the MLIR instruction set. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_BUILTINOPS_H -#define MLIR_IR_BUILTINOPS_H - -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -class Builder; - -class BuiltinDialect : public Dialect { -public: - BuiltinDialect(MLIRContext *context); -}; - -/// The "br" operation represents a branch instruction in a CFG function. -/// The operation takes variable number of operands and produces no results. -/// The operand number and types for each successor must match the -/// arguments of the block successor. For example: -/// -/// bb2: -/// %2 = call @someFn() -/// br bb3(%2 : tensor<*xf32>) -/// bb3(%3: tensor<*xf32>): -/// -class BranchOp : public Op { -public: - static StringRef getOperationName() { return "br"; } - - static void build(Builder *builder, OperationState *result, Block *dest, - ArrayRef operands = {}); - - // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - - /// Return the block this branch jumps to. - Block *getDest(); - const Block *getDest() const { - return const_cast(this)->getDest(); - } - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - -private: - friend class Instruction; - explicit BranchOp(const Instruction *state) : Op(state) {} -}; - -/// The "cond_br" operation represents a conditional branch instruction in a -/// CFG function. The operation takes variable number of operands and produces -/// no results. The operand number and types for each successor must match the -// arguments of the block successor. For example: -/// -/// bb0: -/// %0 = extract_element %arg0[] : tensor -/// cond_br %0, bb1, bb2 -/// bb1: -/// ... -/// bb2: -/// ... -/// -class CondBranchOp : public Op::Impl, - OpTrait::ZeroResult, OpTrait::IsTerminator> { - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - /// The operands list of a conditional branch operation is layed out as - /// follows: - /// { condition, [true_operands], [false_operands] } -public: - static StringRef getOperationName() { return "cond_br"; } - - static void build(Builder *builder, OperationState *result, Value *condition, - Block *trueDest, ArrayRef trueOperands, - Block *falseDest, ArrayRef falseOperands); - - // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest(); - const Block *getTrueDest() const { - return const_cast(this)->getTrueDest(); - } - - /// Return the destination if the condition is false. 
- Block *getFalseDest(); - const Block *getFalseDest() const { - return const_cast(this)->getFalseDest(); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - const Value *getTrueOperand(unsigned idx) const { - return const_cast(this)->getTrueOperand(idx); - } - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - llvm::iterator_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - const_operand_iterator true_operand_begin() const { - return operand_begin() + getTrueDestOperandIndex(); - } - const_operand_iterator true_operand_end() const { - return true_operand_begin() + getNumTrueOperands(); - } - llvm::iterator_range getTrueOperands() const { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() const; - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index); - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - const Value *getFalseOperand(unsigned idx) const { - return const_cast(this)->getFalseOperand(idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - llvm::iterator_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - const_operand_iterator false_operand_begin() const { - return true_operand_end(); - } - const_operand_iterator false_operand_end() const { - return false_operand_begin() + getNumFalseOperands(); - } - llvm::iterator_range getFalseOperands() const { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() const; - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index); - -private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() const { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() const { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - - friend class Instruction; - explicit CondBranchOp(const Instruction *state) : Op(state) {} -}; - -/// The "constant" operation requires a single attribute named "value". -/// It returns its value as an SSA value. For example: -/// -/// %1 = "constant"(){value: 42} : i32 -/// %2 = "constant"(){value: @foo} : (f32)->f32 -/// -class ConstantOp : public Op { -public: - /// Builds a constant op with the specified attribute value and result type. - static void build(Builder *builder, OperationState *result, Type type, - Attribute value); - - /// Builds a constant op with the specified attribute value and the - /// attribute's type. 
- static void build(Builder *builder, OperationState *result, Attribute value); - - Attribute getValue() const { return getAttr("value"); } - - static StringRef getOperationName() { return "constant"; } - - // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -protected: - friend class Instruction; - explicit ConstantOp(const Instruction *state) : Op(state) {} -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning a float value of FloatType. -/// -/// %1 = "constant"(){value: 42.0} : bf16 -/// -class ConstantFloatOp : public ConstantOp { -public: - /// Builds a constant float op producing a float of the specified type. - static void build(Builder *builder, OperationState *result, - const APFloat &value, FloatType type); - - APFloat getValue() const { - return getAttrOfType("value").getValue(); - } - - static bool isClassFor(const Instruction *op); - -private: - friend class Instruction; - explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {} -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of IntegerType. -/// -/// %1 = "constant"(){value: 42} : i32 -/// -class ConstantIntOp : public ConstantOp { -public: - /// Build a constant int op producing an integer of the specified width. - static void build(Builder *builder, OperationState *result, int64_t value, - unsigned width); - - /// Build a constant int op producing an integer with the specified type, - /// which must be an integer type. - static void build(Builder *builder, OperationState *result, int64_t value, - Type type); - - int64_t getValue() const { - return getAttrOfType("value").getInt(); - } - - static bool isClassFor(const Instruction *op); - -private: - friend class Instruction; - explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {} -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of Index type. -/// -/// %1 = "constant"(){value: 99} : () -> index -/// -class ConstantIndexOp : public ConstantOp { -public: - /// Build a constant int op producing an index. - static void build(Builder *builder, OperationState *result, int64_t value); - - int64_t getValue() const { - return getAttrOfType("value").getInt(); - } - - static bool isClassFor(const Instruction *op); - -private: - friend class Instruction; - explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {} -}; - -/// The "return" operation represents a return instruction within a function. -/// The operation takes variable number of operands and produces no results. -/// The operand number and types must match the signature of the function -/// that contains the operation. For example: -/// -/// mlfunc @foo() : (i32, f8) { -/// ... -/// return %0, %1 : i32, f8 -/// -class ReturnOp : public Op { -public: - static StringRef getOperationName() { return "return"; } - - static void build(Builder *builder, OperationState *result, - ArrayRef results = {}); - - // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - bool verify() const; - -private: - friend class Instruction; - explicit ReturnOp(const Instruction *state) : Op(state) {} -}; - -/// Prints dimension and symbol list. 
-void printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p); - -/// Parses dimension and symbol list and returns true if parsing failed. -bool parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims); - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 564a8f16d71..a5a599d13f4 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -268,9 +268,6 @@ public: /// take O(N) where N is the number of instructions within the parent block. bool isBeforeInBlock(const Instruction *other) const; - /// Check if this instruction is a return instruction. - bool isReturn() const; - void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 7de84b58d40..b6105e3bf0e 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -25,7 +25,7 @@ #define MLIR_MATCHERS_H #include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Instruction.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" #include @@ -134,11 +134,6 @@ inline bool matchPattern(Value *value, const Pattern &pattern) { return false; } -/// Matches a ConstantIndexOp. -inline detail::op_matcher m_ConstantIndex() { - return detail::op_matcher(); -} - /// Matches a constant holding a scalar/vector/tensor integer (splat) and /// writes the integer value to bind_value. inline detail::constant_int_op_binder diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 7cca20cf039..42c46c26381 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -83,6 +83,43 @@ private: explicit AllocOp(const Instruction *state) : Op(state) {} }; +/// The "br" operation represents a branch instruction in a function. +/// The operation takes variable number of operands and produces no results. +/// The operand number and types for each successor must match the +/// arguments of the block successor. For example: +/// +/// ^bb2: +/// %2 = call @someFn() +/// br ^bb3(%2 : tensor<*xf32>) +/// ^bb3(%3: tensor<*xf32>): +/// +class BranchOp : public Op { +public: + static StringRef getOperationName() { return "br"; } + + static void build(Builder *builder, OperationState *result, Block *dest, + ArrayRef operands = {}); + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + + /// Return the block this branch jumps to. + Block *getDest(); + const Block *getDest() const { + return const_cast(this)->getDest(); + } + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + +private: + friend class Instruction; + explicit BranchOp(const Instruction *state) : Op(state) {} +}; + /// The "call" operation represents a direct call to a function. The operands /// and result types of the call must match the specified function type. The /// callee is encoded as a function attribute named "callee". @@ -237,6 +274,248 @@ private: explicit CmpIOp(const Instruction *state) : Op(state) {} }; +/// The "cond_br" operation represents a conditional branch instruction in a +/// function. The operation takes variable number of operands and produces +/// no results. 
The operand number and types for each successor must match the +// arguments of the block successor. For example: +/// +/// ^bb0: +/// %0 = extract_element %arg0[] : tensor +/// cond_br %0, ^bb1, ^bb2 +/// ^bb1: +/// ... +/// ^bb2: +/// ... +/// +class CondBranchOp : public Op::Impl, + OpTrait::ZeroResult, OpTrait::IsTerminator> { + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + /// The operands list of a conditional branch operation is layed out as + /// follows: + /// { condition, [true_operands], [false_operands] } +public: + static StringRef getOperationName() { return "cond_br"; } + + static void build(Builder *builder, OperationState *result, Value *condition, + Block *trueDest, ArrayRef trueOperands, + Block *falseDest, ArrayRef falseOperands); + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + const Value *getCondition() const { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest(); + const Block *getTrueDest() const { + return const_cast(this)->getTrueDest(); + } + + /// Return the destination if the condition is false. + Block *getFalseDest(); + const Block *getFalseDest() const { + return const_cast(this)->getFalseDest(); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + const Value *getTrueOperand(unsigned idx) const { + return const_cast(this)->getTrueOperand(idx); + } + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + llvm::iterator_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + const_operand_iterator true_operand_begin() const { + return operand_begin() + getTrueDestOperandIndex(); + } + const_operand_iterator true_operand_end() const { + return true_operand_begin() + getNumTrueOperands(); + } + llvm::iterator_range getTrueOperands() const { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() const; + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index); + + // Accessors for operands to the 'false' destination. 
+ Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + const Value *getFalseOperand(unsigned idx) const { + return const_cast(this)->getFalseOperand(idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + llvm::iterator_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + const_operand_iterator false_operand_begin() const { + return true_operand_end(); + } + const_operand_iterator false_operand_end() const { + return false_operand_begin() + getNumFalseOperands(); + } + llvm::iterator_range getFalseOperands() const { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() const; + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index); + +private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() const { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() const { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + + friend class Instruction; + explicit CondBranchOp(const Instruction *state) : Op(state) {} +}; + +/// The "constant" operation requires a single attribute named "value". +/// It returns its value as an SSA value. For example: +/// +/// %1 = "constant"(){value: 42} : i32 +/// %2 = "constant"(){value: @foo} : (f32)->f32 +/// +class ConstantOp : public Op { +public: + /// Builds a constant op with the specified attribute value and result type. + static void build(Builder *builder, OperationState *result, Type type, + Attribute value); + + /// Builds a constant op with the specified attribute value and the + /// attribute's type. + static void build(Builder *builder, OperationState *result, Attribute value); + + Attribute getValue() const { return getAttr("value"); } + + static StringRef getOperationName() { return "constant"; } + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + Attribute constantFold(ArrayRef operands, + MLIRContext *context) const; + +protected: + friend class Instruction; + explicit ConstantOp(const Instruction *state) : Op(state) {} +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning a float value of FloatType. +/// +/// %1 = "constant"(){value: 42.0} : bf16 +/// +class ConstantFloatOp : public ConstantOp { +public: + /// Builds a constant float op producing a float of the specified type. + static void build(Builder *builder, OperationState *result, + const APFloat &value, FloatType type); + + APFloat getValue() const { + return getAttrOfType("value").getValue(); + } + + static bool isClassFor(const Instruction *op); + +private: + friend class Instruction; + explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {} +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning an integer value of IntegerType. 
+/// +/// %1 = "constant"(){value: 42} : i32 +/// +class ConstantIntOp : public ConstantOp { +public: + /// Build a constant int op producing an integer of the specified width. + static void build(Builder *builder, OperationState *result, int64_t value, + unsigned width); + + /// Build a constant int op producing an integer with the specified type, + /// which must be an integer type. + static void build(Builder *builder, OperationState *result, int64_t value, + Type type); + + int64_t getValue() const { + return getAttrOfType("value").getInt(); + } + + static bool isClassFor(const Instruction *op); + +private: + friend class Instruction; + explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {} +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning an integer value of Index type. +/// +/// %1 = "constant"(){value: 99} : () -> index +/// +class ConstantIndexOp : public ConstantOp { +public: + /// Build a constant int op producing an index. + static void build(Builder *builder, OperationState *result, int64_t value); + + int64_t getValue() const { + return getAttrOfType("value").getInt(); + } + + static bool isClassFor(const Instruction *op); + +private: + friend class Instruction; + explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {} +}; + /// The "dealloc" operation frees the region of memory referenced by a memref /// which was originally created by the "alloc" operation. /// The "dealloc" operation should not be called on memrefs which alias an @@ -636,6 +915,33 @@ private: explicit MemRefCastOp(const Instruction *state) : CastOp(state) {} }; +/// The "return" operation represents a return instruction within a function. +/// The operation takes variable number of operands and produces no results. +/// The operand number and types must match the signature of the function +/// that contains the operation. For example: +/// +/// mlfunc @foo() : (i32, f8) { +/// ... +/// return %0, %1 : i32, f8 +/// +class ReturnOp : public Op { +public: + static StringRef getOperationName() { return "return"; } + + static void build(Builder *builder, OperationState *result, + ArrayRef results = {}); + + // Hooks to customize behavior of this op. + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + +private: + friend class Instruction; + explicit ReturnOp(const Instruction *state) : Op(state) {} +}; + /// The "select" operation chooses one value based on a binary condition /// supplied as its first operand. If the value of the first operand is 1, the /// second operand is chosen, otherwise the third operand is chosen. The second @@ -749,6 +1055,16 @@ private: explicit TensorCastOp(const Instruction *state) : CastOp(state) {} }; +/// Prints dimension and symbol list. +void printDimAndSymbolList(Instruction::const_operand_iterator begin, + Instruction::const_operand_iterator end, + unsigned numDims, OpAsmPrinter *p); + +/// Parses dimension and symbol list and returns true if parsing failed. 
+bool parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims); + } // end namespace mlir #endif // MLIR_STANDARDOPS_OPS_H diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 3dab02a4cd7..78968ae2a7d 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -26,7 +26,7 @@ #define MLIR_TRANSFORMS_UTILS_H #include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 7c0fd29c191..f1810269c33 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -18,7 +18,6 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index ad364b17a45..2f39899e756 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/Ops.h" diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 26d72f55a7b..f28a2423d77 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -23,9 +23,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallPtrSet.h" diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index d17f4560d69..96ba1958105 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -27,7 +27,6 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/StandardOps/Ops.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index f731ba17686..8edf79d6db3 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 7b303f0d070..0206765f880 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -24,7 +24,6 @@ #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp 
b/mlir/lib/Analysis/SliceAnalysis.cpp index 877a0f2f364..6360e89289b 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -22,7 +22,6 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/VectorAnalysis.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 437cc2254af..5f7c6aa19a5 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 815831b7922..5ca3a829cbd 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -19,7 +19,6 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/Ops.h" diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index b1ba9f0503c..f904536e71d 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -19,7 +19,6 @@ #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 5165b2d0527..202d254aee0 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -26,7 +26,6 @@ #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 5571913245e..d7abb8f5368 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -23,7 +23,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 845d56289d1..b96501960c4 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -23,11 +23,12 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" @@ -1205,20 +1206,23 @@ void FunctionPrinter::numberValueID(const Value *value) { // Give constant integers special names. if (auto *op = value->getDefiningInst()) { - if (auto intOp = op->dyn_cast()) { - // i1 constants get special names. 
- if (intOp->getType().isInteger(1)) { - specialName << (intOp->getValue() ? "true" : "false"); - } else { - specialName << 'c' << intOp->getValue() << '_' << intOp->getType(); - } - } else if (auto intOp = op->dyn_cast()) { - specialName << 'c' << intOp->getValue(); - } else if (auto constant = op->dyn_cast()) { - if (constant->getValue().isa()) + Attribute cst; + if (m_Constant(&cst).match(const_cast(op))) { + Type type = op->getResult(0)->getType(); + if (auto intCst = cst.dyn_cast()) { + if (type.isIndex()) { + specialName << 'c' << intCst; + } else if (type.cast().isInteger(1)) { + // i1 constants get special names. + specialName << (intCst.getInt() ? "true" : "false"); + } else { + specialName << 'c' << intCst << '_' << type; + } + } else if (cst.isa()) { specialName << 'f'; - else + } else { specialName << "cst"; + } } } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp deleted file mode 100644 index f859babbe1d..00000000000 --- a/mlir/lib/IR/BuiltinOps.cpp +++ /dev/null @@ -1,454 +0,0 @@ -//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/MathExtras.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// BuiltinDialect -//===----------------------------------------------------------------------===// - -BuiltinDialect::BuiltinDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"", context) { - addOperations(); - addTypes(); -} - -void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p) { - *p << '('; - p->printOperands(begin, begin + numDims); - *p << ')'; - - if (begin + numDims != end) { - *p << '['; - p->printOperands(begin + numDims, end); - *p << ']'; - } -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -bool mlir::parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims) { - SmallVector opInfos; - if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) - return true; - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. 
- auto affineIntTy = parser->getBuilder().getIndexType(); - if (parser->parseOperandList(opInfos, -1, - OpAsmParser::Delimiter::OptionalSquare) || - parser->resolveOperands(opInfos, affineIntTy, operands)) - return true; - return false; -} - -//===----------------------------------------------------------------------===// -// BranchOp -//===----------------------------------------------------------------------===// - -void BranchOp::build(Builder *builder, OperationState *result, Block *dest, - ArrayRef operands) { - result->addSuccessor(dest, operands); -} - -bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { - Block *dest; - SmallVector destOperands; - if (parser->parseSuccessorAndUseList(dest, destOperands)) - return true; - result->addSuccessor(dest, destOperands); - return false; -} - -void BranchOp::print(OpAsmPrinter *p) const { - *p << "br "; - p->printSuccessorAndUseList(getInstruction(), 0); -} - -Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getInstruction()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// CondBranchOp -//===----------------------------------------------------------------------===// - -namespace { -/// cond_br true, ^bb1, ^bb2 -> br ^bb1 -/// cond_br false, ^bb1, ^bb2 -> br ^bb2 -/// -struct SimplifyConstCondBranchPred : public RewritePattern { - SimplifyConstCondBranchPred(MLIRContext *context) - : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Instruction *op) const override { - auto condbr = op->cast(); - if (matchPattern(condbr->getCondition(), m_Op())) - return matchSuccess(); - - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto condbr = op->cast(); - Block *foldedDest; - SmallVector branchArgs; - - // If the condition is known to evaluate to false we fold to a branch to the - // false destination. Otherwise, we fold to a branch to the true - // destination. - if (matchPattern(condbr->getCondition(), m_Zero())) { - foldedDest = condbr->getFalseDest(); - branchArgs.assign(condbr->false_operand_begin(), - condbr->false_operand_end()); - } else { - foldedDest = condbr->getTrueDest(); - branchArgs.assign(condbr->true_operand_begin(), - condbr->true_operand_end()); - } - - rewriter.replaceOpWithNewOp(op, foldedDest, branchArgs); - } -}; -} // end anonymous namespace. - -void CondBranchOp::build(Builder *builder, OperationState *result, - Value *condition, Block *trueDest, - ArrayRef trueOperands, Block *falseDest, - ArrayRef falseOperands) { - result->addOperands(condition); - result->addSuccessor(trueDest, trueOperands); - result->addSuccessor(falseDest, falseOperands); -} - -bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser->getBuilder().getI1Type(); - if (parser->parseOperand(condInfo) || parser->parseComma() || - parser->resolveOperand(condInfo, int1Ty, result->operands)) { - return parser->emitError(parser->getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. 
- if (parser->parseSuccessorAndUseList(dest, destOperands)) - return true; - result->addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser->parseComma() || - parser->parseSuccessorAndUseList(dest, destOperands)) - return true; - result->addSuccessor(dest, destOperands); - - // Return false on success. - return false; -} - -void CondBranchOp::print(OpAsmPrinter *p) const { - *p << "cond_br "; - p->printOperand(getCondition()); - *p << ", "; - p->printSuccessorAndUseList(getInstruction(), trueIndex); - *p << ", "; - p->printSuccessorAndUseList(getInstruction(), falseIndex); -} - -bool CondBranchOp::verify() const { - if (!getCondition()->getType().isInteger(1)) - return emitOpError("expected condition type was boolean (i1)"); - return false; -} - -void CondBranchOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); -} - -Block *CondBranchOp::getTrueDest() { - return getInstruction()->getSuccessor(trueIndex); -} - -Block *CondBranchOp::getFalseDest() { - return getInstruction()->getSuccessor(falseIndex); -} - -unsigned CondBranchOp::getNumTrueOperands() const { - return getInstruction()->getNumSuccessorOperands(trueIndex); -} - -void CondBranchOp::eraseTrueOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(trueIndex, index); -} - -unsigned CondBranchOp::getNumFalseOperands() const { - return getInstruction()->getNumSuccessorOperands(falseIndex); -} - -void CondBranchOp::eraseFalseOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(falseIndex, index); -} - -//===----------------------------------------------------------------------===// -// Constant*Op -//===----------------------------------------------------------------------===// - -/// Builds a constant op with the specified attribute value and result type. -void ConstantOp::build(Builder *builder, OperationState *result, Type type, - Attribute value) { - result->addAttribute("value", value); - result->types.push_back(type); -} - -// Extracts and returns a type of an attribute if it has one. Returns a null -// type otherwise. Currently, NumericAttrs and FunctionAttrs have types. -static Type getAttributeType(Attribute attr) { - assert(attr && "expected non-null attribute"); - if (auto numericAttr = attr.dyn_cast()) - return numericAttr.getType(); - if (auto functionAttr = attr.dyn_cast()) - return functionAttr.getType(); - return {}; -} - -/// Builds a constant with the specified attribute value and type extracted -/// from the attribute. The attribute must have a type. -void ConstantOp::build(Builder *builder, OperationState *result, - Attribute value) { - Type t = getAttributeType(value); - assert(t && "expected an attribute with a type"); - return build(builder, result, t, value); -} - -void ConstantOp::print(OpAsmPrinter *p) const { - *p << "constant "; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"value"}); - - if (getAttrs().size() > 1) - *p << ' '; - *p << getValue(); - if (!getValue().isa()) - *p << " : " << getType(); -} - -bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { - Attribute valueAttr; - Type type; - - if (parser->parseOptionalAttributeDict(result->attributes) || - parser->parseAttribute(valueAttr, "value", result->attributes)) - return true; - - // 'constant' taking a function reference doesn't get a redundant type - // specifier. The attribute itself carries it. 
- if (auto fnAttr = valueAttr.dyn_cast()) - return parser->addTypeToList(fnAttr.getValue()->getType(), result->types); - - if (auto intAttr = valueAttr.dyn_cast()) { - type = intAttr.getType(); - } else if (auto fpAttr = valueAttr.dyn_cast()) { - type = fpAttr.getType(); - } else if (parser->parseColonType(type)) { - return true; - } - return parser->addTypeToList(type, result->types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. -bool ConstantOp::verify() const { - auto value = getValue(); - if (!value) - return emitOpError("requires a 'value' attribute"); - - auto type = this->getType(); - if (type.isa() || type.isIndex()) { - auto intAttr = value.dyn_cast(); - if (!intAttr) - return emitOpError( - "requires 'value' to be an integer for an integer result type"); - - // If the type has a known bitwidth we verify that the value can be - // represented with the given bitwidth. - if (!type.isIndex()) { - auto bitwidth = type.cast().getWidth(); - auto intVal = intAttr.getValue(); - if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) - return emitOpError("requires 'value' to be an integer within the range " - "of the integer result type"); - } - return false; - } - - if (type.isa()) { - if (!value.isa()) - return emitOpError("requires 'value' to be a floating point constant"); - return false; - } - - if (type.isa()) { - if (!value.isa()) - return emitOpError("requires 'value' to be a vector/tensor constant"); - return false; - } - - if (type.isa()) { - if (!value.isa()) - return emitOpError("requires 'value' to be a function reference"); - return false; - } - - auto attrType = getAttributeType(value); - if (!attrType) - return emitOpError("requires 'value' attribute to have a type"); - if (attrType != type) - return emitOpError("requires the type of the 'value' attribute to match " - "that of the operation result"); - - return emitOpError( - "requires a result type that aligns with the 'value' attribute"); -} - -Attribute ConstantOp::constantFold(ArrayRef operands, - MLIRContext *context) const { - assert(operands.empty() && "constant has no operands"); - return getValue(); -} - -void ConstantFloatOp::build(Builder *builder, OperationState *result, - const APFloat &value, FloatType type) { - ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); -} - -bool ConstantFloatOp::isClassFor(const Instruction *op) { - return ConstantOp::isClassFor(op) && - op->getResult(0)->getType().isa(); -} - -/// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::isClassFor(const Instruction *op) { - return ConstantOp::isClassFor(op) && - op->getResult(0)->getType().isa(); -} - -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, unsigned width) { - Type type = builder->getIntegerType(width); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -/// Build a constant int op producing an integer with the specified type, -/// which must be an integer type. -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, Type type) { - assert(type.isa() && "ConstantIntOp can only have integer type"); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -/// ConstantIndexOp only matches values whose result type is Index. 
-bool ConstantIndexOp::isClassFor(const Instruction *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); -} - -void ConstantIndexOp::build(Builder *builder, OperationState *result, - int64_t value) { - Type type = builder->getIndexType(); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -void ReturnOp::build(Builder *builder, OperationState *result, - ArrayRef results) { - result->addOperands(results); -} - -bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc; - return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || - (!opInfo.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(opInfo, types, loc, result->operands); -} - -void ReturnOp::print(OpAsmPrinter *p) const { - *p << "return"; - if (getNumOperands() > 0) { - *p << ' '; - p->printOperands(operand_begin(), operand_end()); - *p << " : "; - interleave( - operand_begin(), operand_end(), - [&](const Value *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); - } -} - -bool ReturnOp::verify() const { - auto *function = getInstruction()->getFunction(); - - // The operand number and types must match the function signature. - const auto &results = function->getType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError("has " + Twine(getNumOperands()) + - " operands, but enclosing function returns " + - Twine(results.size())); - - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (getOperand(i)->getType() != results[i]) - return emitError("type of return operand " + Twine(i) + - " doesn't match function result type"); - - return false; -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index fb57bfbc338..4dc5c3ef12e 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -19,7 +19,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -481,8 +481,6 @@ bool Instruction::use_empty() const { return true; } -bool Instruction::isReturn() const { return isa(); } - void Instruction::setSuccessor(Block *block, unsigned index) { assert(index < getNumSuccessors()); getBlockOperands()[index].set(block); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 10a8ec2ca59..8626337fc32 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" @@ -45,6 +45,15 @@ using namespace mlir::detail; using namespace llvm; namespace { +/// A builtin dialect to define types/etc that are necessary for the +/// validity of the IR. +struct BuiltinDialect : public Dialect { + BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { + addTypes(); + } +}; + struct AffineMapKeyInfo : DenseMapInfo { // Affine maps are uniqued based on their dim/symbol counts and affine // expressions. 
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index ee409e4982c..748635d87e0 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -21,7 +21,6 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 5a14c587172..30cc13f36e9 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 50f258c84cd..7532e6f8196 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -37,14 +36,57 @@ using namespace mlir; StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations(); } +void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin, + Instruction::const_operand_iterator end, + unsigned numDims, OpAsmPrinter *p) { + *p << '('; + p->printOperands(begin, begin + numDims); + *p << ')'; + + if (begin + numDims != end) { + *p << '['; + p->printOperands(begin + numDims, end); + *p << ']'; + } +} + +// Parses dimension and symbol list, and sets 'numDims' to the number of +// dimension operands parsed. +// Returns 'false' on success and 'true' on error. +bool mlir::parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims) { + SmallVector opInfos; + if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) + return true; + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto affineIntTy = parser->getBuilder().getIndexType(); + if (parser->parseOperandList(opInfos, -1, + OpAsmParser::Delimiter::OptionalSquare) || + parser->resolveOperands(opInfos, affineIntTy, operands)) + return true; + return false; +} + +/// Matches a ConstantIndexOp. +/// TODO: This should probably just be a general matcher that uses m_Constant +/// and checks the operation for an index type. 
+static detail::op_matcher m_ConstantIndex() { + return detail::op_matcher(); +} + //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// @@ -310,6 +352,39 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.push_back(std::make_unique(context)); } +//===----------------------------------------------------------------------===// +// BranchOp +//===----------------------------------------------------------------------===// + +void BranchOp::build(Builder *builder, OperationState *result, Block *dest, + ArrayRef operands) { + result->addSuccessor(dest, operands); +} + +bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { + Block *dest; + SmallVector destOperands; + if (parser->parseSuccessorAndUseList(dest, destOperands)) + return true; + result->addSuccessor(dest, destOperands); + return false; +} + +void BranchOp::print(OpAsmPrinter *p) const { + *p << "br "; + p->printSuccessorAndUseList(getInstruction(), 0); +} + +Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getInstruction()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getInstruction()->eraseSuccessorOperand(0, index); +} + //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -692,6 +767,300 @@ Attribute CmpIOp::constantFold(ArrayRef operands, return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val)); } +//===----------------------------------------------------------------------===// +// CondBranchOp +//===----------------------------------------------------------------------===// + +namespace { +/// cond_br true, ^bb1, ^bb2 -> br ^bb1 +/// cond_br false, ^bb1, ^bb2 -> br ^bb2 +/// +struct SimplifyConstCondBranchPred : public RewritePattern { + SimplifyConstCondBranchPred(MLIRContext *context) + : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override { + auto condbr = op->cast(); + if (matchPattern(condbr->getCondition(), m_Op())) + return matchSuccess(); + + return matchFailure(); + } + void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + auto condbr = op->cast(); + Block *foldedDest; + SmallVector branchArgs; + + // If the condition is known to evaluate to false we fold to a branch to the + // false destination. Otherwise, we fold to a branch to the true + // destination. + if (matchPattern(condbr->getCondition(), m_Zero())) { + foldedDest = condbr->getFalseDest(); + branchArgs.assign(condbr->false_operand_begin(), + condbr->false_operand_end()); + } else { + foldedDest = condbr->getTrueDest(); + branchArgs.assign(condbr->true_operand_begin(), + condbr->true_operand_end()); + } + + rewriter.replaceOpWithNewOp(op, foldedDest, branchArgs); + } +}; +} // end anonymous namespace. 
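The SimplifyConstCondBranchPred pattern defined above only takes effect once it is handed to a rewrite driver. A minimal sketch of how it could be exercised, using only APIs that appear elsewhere in this patch series (OwningRewritePatternList, getCanonicalizationPatterns, applyPatternsGreedily); the helper name runCondBranchCanonicalization is hypothetical, and the header locations are assumed from the includes used in this patch:

#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"

using namespace mlir;

// Sketch: collect the cond_br canonicalization patterns (which include
// SimplifyConstCondBranchPred) and apply them greedily to one function.
static void runCondBranchCanonicalization(Function *f) {
  MLIRContext *context = f->getContext();
  OwningRewritePatternList patterns;
  // Registers the patterns CondBranchOp advertises for canonicalization.
  CondBranchOp::getCanonicalizationPatterns(patterns, context);
  // Rewrites cond_br with a constant i1 predicate into an unconditional br.
  applyPatternsGreedily(f, std::move(patterns));
}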
+ +void CondBranchOp::build(Builder *builder, OperationState *result, + Value *condition, Block *trueDest, + ArrayRef trueOperands, Block *falseDest, + ArrayRef falseOperands) { + result->addOperands(condition); + result->addSuccessor(trueDest, trueOperands); + result->addSuccessor(falseDest, falseOperands); +} + +bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser->getBuilder().getI1Type(); + if (parser->parseOperand(condInfo) || parser->parseComma() || + parser->resolveOperand(condInfo, int1Ty, result->operands)) { + return parser->emitError(parser->getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser->parseSuccessorAndUseList(dest, destOperands)) + return true; + result->addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser->parseComma() || + parser->parseSuccessorAndUseList(dest, destOperands)) + return true; + result->addSuccessor(dest, destOperands); + + // Return false on success. + return false; +} + +void CondBranchOp::print(OpAsmPrinter *p) const { + *p << "cond_br "; + p->printOperand(getCondition()); + *p << ", "; + p->printSuccessorAndUseList(getInstruction(), trueIndex); + *p << ", "; + p->printSuccessorAndUseList(getInstruction(), falseIndex); +} + +bool CondBranchOp::verify() const { + if (!getCondition()->getType().isInteger(1)) + return emitOpError("expected condition type was boolean (i1)"); + return false; +} + +void CondBranchOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.push_back(std::make_unique(context)); +} + +Block *CondBranchOp::getTrueDest() { + return getInstruction()->getSuccessor(trueIndex); +} + +Block *CondBranchOp::getFalseDest() { + return getInstruction()->getSuccessor(falseIndex); +} + +unsigned CondBranchOp::getNumTrueOperands() const { + return getInstruction()->getNumSuccessorOperands(trueIndex); +} + +void CondBranchOp::eraseTrueOperand(unsigned index) { + getInstruction()->eraseSuccessorOperand(trueIndex, index); +} + +unsigned CondBranchOp::getNumFalseOperands() const { + return getInstruction()->getNumSuccessorOperands(falseIndex); +} + +void CondBranchOp::eraseFalseOperand(unsigned index) { + getInstruction()->eraseSuccessorOperand(falseIndex, index); +} + +//===----------------------------------------------------------------------===// +// Constant*Op +//===----------------------------------------------------------------------===// + +/// Builds a constant op with the specified attribute value and result type. +void ConstantOp::build(Builder *builder, OperationState *result, Type type, + Attribute value) { + result->addAttribute("value", value); + result->types.push_back(type); +} + +// Extracts and returns a type of an attribute if it has one. Returns a null +// type otherwise. Currently, NumericAttrs and FunctionAttrs have types. +static Type getAttributeType(Attribute attr) { + assert(attr && "expected non-null attribute"); + if (auto numericAttr = attr.dyn_cast()) + return numericAttr.getType(); + if (auto functionAttr = attr.dyn_cast()) + return functionAttr.getType(); + return {}; +} + +/// Builds a constant with the specified attribute value and type extracted +/// from the attribute. The attribute must have a type. 
+void ConstantOp::build(Builder *builder, OperationState *result, + Attribute value) { + Type t = getAttributeType(value); + assert(t && "expected an attribute with a type"); + return build(builder, result, t, value); +} + +void ConstantOp::print(OpAsmPrinter *p) const { + *p << "constant "; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"value"}); + + if (getAttrs().size() > 1) + *p << ' '; + *p << getValue(); + if (!getValue().isa()) + *p << " : " << getType(); +} + +bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { + Attribute valueAttr; + Type type; + + if (parser->parseOptionalAttributeDict(result->attributes) || + parser->parseAttribute(valueAttr, "value", result->attributes)) + return true; + + // 'constant' taking a function reference doesn't get a redundant type + // specifier. The attribute itself carries it. + if (auto fnAttr = valueAttr.dyn_cast()) + return parser->addTypeToList(fnAttr.getValue()->getType(), result->types); + + if (auto intAttr = valueAttr.dyn_cast()) { + type = intAttr.getType(); + } else if (auto fpAttr = valueAttr.dyn_cast()) { + type = fpAttr.getType(); + } else if (parser->parseColonType(type)) { + return true; + } + return parser->addTypeToList(type, result->types); +} + +/// The constant op requires an attribute, and furthermore requires that it +/// matches the return type. +bool ConstantOp::verify() const { + auto value = getValue(); + if (!value) + return emitOpError("requires a 'value' attribute"); + + auto type = this->getType(); + if (type.isa() || type.isIndex()) { + auto intAttr = value.dyn_cast(); + if (!intAttr) + return emitOpError( + "requires 'value' to be an integer for an integer result type"); + + // If the type has a known bitwidth we verify that the value can be + // represented with the given bitwidth. + if (!type.isIndex()) { + auto bitwidth = type.cast().getWidth(); + auto intVal = intAttr.getValue(); + if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) + return emitOpError("requires 'value' to be an integer within the range " + "of the integer result type"); + } + return false; + } + + if (type.isa()) { + if (!value.isa()) + return emitOpError("requires 'value' to be a floating point constant"); + return false; + } + + if (type.isa()) { + if (!value.isa()) + return emitOpError("requires 'value' to be a vector/tensor constant"); + return false; + } + + if (type.isa()) { + if (!value.isa()) + return emitOpError("requires 'value' to be a function reference"); + return false; + } + + auto attrType = getAttributeType(value); + if (!attrType) + return emitOpError("requires 'value' attribute to have a type"); + if (attrType != type) + return emitOpError("requires the type of the 'value' attribute to match " + "that of the operation result"); + + return emitOpError( + "requires a result type that aligns with the 'value' attribute"); +} + +Attribute ConstantOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +void ConstantFloatOp::build(Builder *builder, OperationState *result, + const APFloat &value, FloatType type) { + ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); +} + +bool ConstantFloatOp::isClassFor(const Instruction *op) { + return ConstantOp::isClassFor(op) && + op->getResult(0)->getType().isa(); +} + +/// ConstantIntOp only matches values whose result type is an IntegerType. 
+bool ConstantIntOp::isClassFor(const Instruction *op) { + return ConstantOp::isClassFor(op) && + op->getResult(0)->getType().isa(); +} + +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, unsigned width) { + Type type = builder->getIntegerType(width); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// Build a constant int op producing an integer with the specified type, +/// which must be an integer type. +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, Type type) { + assert(type.isa() && "ConstantIntOp can only have integer type"); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// ConstantIndexOp only matches values whose result type is Index. +bool ConstantIndexOp::isClassFor(const Instruction *op) { + return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); +} + +void ConstantIndexOp::build(Builder *builder, OperationState *result, + int64_t value) { + Type type = builder->getIndexType(); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// @@ -1380,6 +1749,55 @@ Attribute RemIUOp::constantFold(ArrayRef operands, return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue())); } +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +void ReturnOp::build(Builder *builder, OperationState *result, + ArrayRef results) { + result->addOperands(results); +} + +bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector opInfo; + SmallVector types; + llvm::SMLoc loc; + return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || + (!opInfo.empty() && parser->parseColonTypeList(types)) || + parser->resolveOperands(opInfo, types, loc, result->operands); +} + +void ReturnOp::print(OpAsmPrinter *p) const { + *p << "return"; + if (getNumOperands() > 0) { + *p << ' '; + p->printOperands(operand_begin(), operand_end()); + *p << " : "; + interleave( + operand_begin(), operand_end(), + [&](const Value *e) { p->printType(e->getType()); }, + [&]() { *p << ", "; }); + } +} + +bool ReturnOp::verify() const { + auto *function = getInstruction()->getFunction(); + + // The operand number and types must match the function signature. 
+ const auto &results = function->getType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has " + Twine(getNumOperands()) + + " operands, but enclosing function returns " + + Twine(results.size())); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i)->getType() != results[i]) + return emitError("type of return operand " + Twine(i) + + " doesn't match function result type"); + + return false; +} + //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 6274d7dc857..d3da8c17580 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index c39c0a52a80..91d53528870 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7b7c0bb22bb..8b9039d2a22 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -27,7 +27,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 2bf78ae258e..4cf65b4dc83 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -26,7 +26,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 87259497cef..6d04b8492ea 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -49,7 +49,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 9979c3736ef..3776403a931 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -23,7 +23,6 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 3990e54006d..e3ad4297ace 100644 --- 
a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -30,7 +30,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OperationSupport.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 572281c5cc3..9f4027ea4fc 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -32,7 +32,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 71de1ef1830..39e6c5fa546 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -20,9 +20,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index c13146ee2f5..89e32695a6e 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -29,7 +29,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instruction.h" #include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index c254790dbe7..fa9c4bc7e7d 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 8277d4800ab..d8e5714962e 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/EDSC/api-test.cpp b/mlir/test/EDSC/api-test.cpp index 6ce0be0556b..2e7b69321a6 100644 --- a/mlir/test/EDSC/api-test.cpp +++ b/mlir/test/EDSC/api-test.cpp @@ -21,7 +21,6 @@ #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" -- cgit v1.2.3 From eee85361bbf433a6f3f0ea0ce9d8ae7104d4f404 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 1 Mar 2019 17:42:13 -0800 Subject: Remove hidden flag from fusion CL options PiperOrigin-RevId: 236409185 --- 
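Removing llvm::cl::Hidden from the four fusion options below changes only their visibility: hidden options still parse, but are listed only by -help-hidden, whereas visible options appear in the tool's regular -help output. A standalone sketch of the same idiom, with a hypothetical option name rather than one of the fusion flags:

#include "llvm/Support/CommandLine.h"

static llvm::cl::OptionCategory clOptionsCategory("example options");

// Without llvm::cl::Hidden this option is listed by -help; adding
// llvm::cl::Hidden would keep it functional but move it to -help-hidden.
static llvm::cl::opt<bool>
    clExampleFlag("example-flag",
                  llvm::cl::desc("Illustrative boolean flag"),
                  llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return clExampleFlag ? 0 : 1;
}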
mlir/lib/Transforms/LoopFusion.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8b9039d2a22..2e84d3c3ed7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -50,26 +50,26 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); /// Disables fusion profitability check and fuses if valid. static llvm::cl::opt - clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, + clMaximalLoopFusion("fusion-maximal", llvm::cl::desc("Enables maximal loop fusion"), llvm::cl::cat(clOptionsCategory)); /// A threshold in percent of additional computation allowed when fusing. static llvm::cl::opt clFusionAddlComputeTolerance( - "fusion-compute-tolerance", llvm::cl::Hidden, + "fusion-compute-tolerance", llvm::cl::desc("Fractional increase in additional " "computation tolerated while fusing"), llvm::cl::cat(clOptionsCategory)); static llvm::cl::opt clFusionFastMemorySpace( - "fusion-fast-mem-space", llvm::cl::Hidden, + "fusion-fast-mem-space", llvm::cl::desc("Faster memory space number to promote fusion buffers to"), llvm::cl::cat(clOptionsCategory)); // A local buffer of size less than or equal to this size is promoted to fast // memory. static llvm::cl::opt clFusionLocalBufThreshold( - "fusion-local-buf-threshold", llvm::cl::Hidden, + "fusion-local-buf-threshold", llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast " "memory space"), llvm::cl::cat(clOptionsCategory)); -- cgit v1.2.3 From d42ef78a750ed524b76791467f8f23c8013c6b3f Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Mon, 4 Mar 2019 11:01:25 -0800 Subject: Handle MemRefRegion::compute return value in loop fusion pass (NFC). PiperOrigin-RevId: 236685849 --- mlir/lib/Transforms/LoopFusion.cpp | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2e84d3c3ed7..1e4e020b435 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1177,7 +1177,9 @@ static Value *createPrivateMemRef(OpPointer forOp, // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. MemRefRegion region(srcStoreOpInst->getLoc()); - region.compute(srcStoreOpInst, dstLoopDepth); + bool validRegion = region.compute(srcStoreOpInst, dstLoopDepth); + (void)validRegion; + assert(validRegion && "unexpected memref region failure"); SmallVector newShape; std::vector> lbs; SmallVector lbDivisors; @@ -1304,7 +1306,11 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. auto *srcStoreOpInst = srcNode->stores.front(); MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0); + if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for source operation\n."); + return false; + } SmallVector srcShape; // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 
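Every call site touched by this commit applies the same defensive pattern: treat MemRefRegion::compute as fallible and either assert (where failure would indicate a bug, as in createPrivateMemRef) or bail out with a debug message (where failure simply means the fusion opportunity is skipped). A condensed sketch of the bail-out form; the helper name and the DEBUG_TYPE value are illustrative, not taken from the patch:

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "loop-fusion" // assumed value; defined elsewhere in this file

// Sketch: compute the write region for a store and report failure to the
// caller instead of asserting, so the caller can simply skip fusion.
static bool computeWriteRegionOrSkip(Instruction *storeOpInst,
                                     MemRefRegion &region) {
  if (!region.compute(storeOpInst, /*loopDepth=*/0)) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion; skipping.\n");
    return false;
  }
  return true;
}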
@@ -1319,7 +1325,11 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, assert(dstStoreOps.size() == 1); auto *dstStoreOpInst = dstStoreOps[0]; MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); - dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0); + if (!dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0)) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for dest operation\n."); + return false; + } SmallVector dstShape; // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'. // by 'dstStoreOpInst' at depth 'dstLoopDepth'. @@ -1444,7 +1454,12 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute src loop nest write region size. MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0); + if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for source operation\n."); + return false; + } + Optional maybeSrcWriteRegionSizeBytes = srcWriteRegion.getRegionSize(); if (!maybeSrcWriteRegionSizeBytes.hasValue()) @@ -1528,8 +1543,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop // nest at loop depth 'i' MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); - sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1]); + if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, + &sliceStates[i - 1])) + continue; + Optional maybeSliceWriteRegionSizeBytes = sliceWriteRegion.getRegionSize(); if (!maybeSliceWriteRegionSizeBytes.hasValue() || @@ -1594,8 +1611,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, return false; } - assert(bestDstLoopDepth.hasValue() && - "expected to have a value per logic above"); + if (!bestDstLoopDepth.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); + return false; + } // Set dstLoopDepth based on best values from search. *dstLoopDepth = bestDstLoopDepth.getValue(); -- cgit v1.2.3 From 02af8c22df523d7cda4399058e0a0945d54f4972 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 5 Mar 2019 15:05:34 -0800 Subject: Change Pass:getFunction() to return pointer instead of ref - NFC - change this for consistency - everything else similar takes/returns a Function pointer - the FuncBuilder ctor, Block/Value/Instruction::getFunction(), etc. 
- saves a whole bunch of &s everywhere PiperOrigin-RevId: 236928761 --- mlir/include/mlir/Pass/Pass.h | 4 ++-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/ParallelismDetection.cpp | 2 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/Pass/Pass.cpp | 2 +- mlir/lib/Transforms/CSE.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 6 +++--- mlir/lib/Transforms/ConstantFold.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 14 ++++++++------ mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 7 ++++--- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 8 ++++---- mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 2 +- 24 files changed, 40 insertions(+), 37 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 3a7f56cf501..33ac27177e0 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -100,8 +100,8 @@ protected: virtual void runOnFunction() = 0; /// Return the current function being transformed. - Function &getFunction() { - return *getPassState().irAndPassFailed.getPointer(); + Function *getFunction() { + return getPassState().irAndPassFailed.getPointer(); } /// Returns the current pass state. diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 8edf79d6db3..b90a799b794 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { } void MemRefBoundCheck::runOnFunction() { - getFunction().walk([](Instruction *opInst) { + getFunction()->walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 0206765f880..0c2a5defe10 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -113,7 +113,7 @@ static void checkDependences(ArrayRef loadsAndStores) { void MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); - getFunction().walk([&](Instruction *inst) { + getFunction()->walk([&](Instruction *inst) { if (inst->isa() || inst->isa()) loadsAndStores.push_back(inst); }); diff --git a/mlir/lib/Analysis/ParallelismDetection.cpp b/mlir/lib/Analysis/ParallelismDetection.cpp index 4c86e0c41fd..920511e7887 100644 --- a/mlir/lib/Analysis/ParallelismDetection.cpp +++ b/mlir/lib/Analysis/ParallelismDetection.cpp @@ -42,7 +42,7 @@ FunctionPassBase *mlir::createLoopParallelismDetectionPass() { // Walks the function and marks all parallel 'for' ops with an attribute. 
void LoopParallelismDetection::runOnFunction() { - Function *f = &getFunction(); + Function *f = getFunction(); FuncBuilder b(f); f->walk([&](OpPointer forOp) { forOp->getInstruction()->setAttr("parallel", diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index f904536e71d..aaf1747cdda 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -40,7 +40,7 @@ struct LowerEDSCTestPass : public FunctionPass { #include "mlir/EDSC/reference-impl.inc" void LowerEDSCTestPass::runOnFunction() { - getFunction().walk([](Instruction *op) { + getFunction()->walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); if (!opName) { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 45a3970a274..598e4c83e55 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -221,7 +221,7 @@ namespace { /// Pass to verify a function and signal failure if necessary. class FunctionVerifier : public FunctionPass { void runOnFunction() { - if (getFunction().verify()) + if (getFunction()->verify()) signalPassFailure(); markAllAnalysesPreserved(); } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 9812decbfbf..f4818961ca6 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -218,7 +218,7 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { void CSE::runOnFunction() { simplifyBlockList(getAnalysisResult(), - getFunction().getBlockList()); + getFunction()->getBlockList()); // If no operations were erased, then we mark all analyses as preserved. if (opsToErase.empty()) { diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 17259bb19da..77244264cda 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,16 +40,16 @@ struct Canonicalizer : public FunctionPass { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto *func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - auto *context = func.getContext(); + auto *context = func->getContext(); for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); - applyPatternsGreedily(&func, std::move(patterns)); + applyPatternsGreedily(func, std::move(patterns)); } /// Create a Canonicalizer pass. diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index d3da8c17580..6bdb1bffab3 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -97,7 +97,7 @@ void ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction()->walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. 
// TODO: This is assuming that all constant foldable operations have no diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 91d53528870..44b4d9528e9 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -754,7 +754,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } void DmaGeneration::runOnFunction() { - Function *f = &getFunction(); + Function *f = getFunction(); FuncBuilder topBuilder(f); zeroIndex = topBuilder.create(f->getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1e4e020b435..7466e4968a6 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -2202,7 +2202,7 @@ void LoopFusion::runOnFunction() { } MemRefDependenceGraph g; - if (g.init(&getFunction())) + if (g.init(getFunction())) GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run(); } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 4aebbc2e856..e58de3bc136 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -255,7 +255,7 @@ getTileableBands(Function *f, void LoopTiling::runOnFunction() { std::vector, 6>> bands; - getTileableBands(&getFunction(), &bands); + getTileableBands(getFunction(), &bands); for (auto &band : bands) { // Set up tile sizes; fill missing tile sizes at the end with default tile diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 4cf65b4dc83..40da2f63f62 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -129,11 +129,13 @@ void LoopUnroll::runOnFunction() { // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). - getFunction().walkPostOrder([&](OpPointer forOp) { - Optional tripCount = getConstantTripCount(forOp); - if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) - loops.push_back(forOp); - }); + getFunction()->walkPostOrder( + [&](OpPointer forOp) { + Optional tripCount = getConstantTripCount(forOp); + if (tripCount.hasValue() && + tripCount.getValue() <= clUnrollFullThreshold) + loops.push_back(forOp); + }); for (auto forOp : loops) loopUnrollFull(forOp); return; @@ -143,7 +145,7 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function *func = &getFunction(); + Function *func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; ilg.walkPostOrder(func); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 6d04b8492ea..63fc451287b 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -91,7 +91,7 @@ void LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. 
- auto &entryBlock = getFunction().front(); + auto &entryBlock = getFunction()->front(); if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 3776403a931..31232ce7fe4 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -608,7 +608,7 @@ void LowerAffinePass::runOnFunction() { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. - getFunction().walk([&](Instruction *inst) { + getFunction()->walk([&](Instruction *inst) { if (inst->isa() || inst->isa() || inst->isa()) instsToRewrite.push_back(inst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index e3ad4297ace..9ac8583bc78 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -426,7 +426,7 @@ public: struct LowerVectorTransfersPass : public FunctionPass { void runOnFunction() { - Function *f = &getFunction(); + Function *f = getFunction(); applyMLPatternsGreedily, VectorTransferExpander>(f); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 9f4027ea4fc..0d54ead424e 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -733,7 +733,7 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = &getFunction(); + Function *f = getFunction(); if (f->getBlocks().size() != 1) return; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 4fc544074c0..f48f90923ce 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -211,8 +211,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function &f = getFunction(); - if (f.getBlocks().size() != 1) { + Function *f = getFunction(); + if (f->getBlocks().size() != 1) { markAllAnalysesPreserved(); return; } @@ -224,7 +224,8 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f.walk([&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); + f->walk( + [&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index cdce5230ba6..08115eddbe7 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -146,7 +146,7 @@ void PipelineDataTransfer::runOnFunction() { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). 
forOps.clear(); - getFunction().walkPostOrder( + getFunction()->walkPostOrder( [&](OpPointer forOp) { forOps.push_back(forOp); }); for (auto forOp : forOps) runOnAffineForOp(forOp); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 3adcbe038ea..8a8e7af1089 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -58,7 +58,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { } void SimplifyAffineStructures::runOnFunction() { - getFunction().walk([&](Instruction *opInst) { + getFunction()->walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 47244f94ac9..f8f90c0cdb1 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,12 +29,12 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function &func = getFunction(); - UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); + Function *func = getFunction(); + UnknownLoc unknownLoc = UnknownLoc::get(func->getContext()); // Strip the debug info from the function and its instructions. - func.setLoc(unknownLoc); - func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func->setLoc(unknownLoc); + func->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); } /// Creates a pass to strip debug information from a function. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index fa9c4bc7e7d..8fd1cac201c 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -264,7 +264,7 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. - Function *f = &getFunction(); + Function *f = getFunction(); if (f->getBlocks().size() != 1) return; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d8e5714962e..b084b016be3 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1262,7 +1262,7 @@ void Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. 
NestedPatternContext mlContext; - Function *f = &getFunction(); + Function *f = getFunction(); for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index b2dfe6795b6..d77e96a99d7 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -78,7 +78,7 @@ struct PrintCFGPass : public FunctionPass { const llvm::Twine &title = "") : os(os), shortNames(shortNames), title(title) {} void runOnFunction() { - mlir::writeGraph(os, &getFunction(), shortNames, title); + mlir::writeGraph(os, getFunction(), shortNames, title); } private: -- cgit v1.2.3 From c1ff9e866e37930a70d92d8cdafbd4659605a04f Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Tue, 5 Mar 2019 20:33:30 -0800 Subject: Use FlatAffineConstraints::unionBoundingBox to perform slice bounds union for loop fusion pass (WIP). Adds utility to convert slice bounds to a FlatAffineConstraints representation. Adds utility to FlatAffineConstraints to promote loop IV symbol identifiers to dim identifiers. PiperOrigin-RevId: 236973261 --- mlir/include/mlir/Analysis/AffineStructures.h | 3 + mlir/include/mlir/Analysis/Utils.h | 11 ++ mlir/lib/Analysis/AffineStructures.cpp | 26 +++- mlir/lib/Analysis/Utils.cpp | 49 +++++++ mlir/lib/Transforms/LoopFusion.cpp | 183 +++++++++++++++----------- mlir/test/Transforms/loop-fusion.mlir | 39 ++++++ 6 files changed, 229 insertions(+), 82 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 48d9383bef0..81006e661f9 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -485,6 +485,9 @@ public: /// symbols, or some of the leading symbols become dimensions. void setDimSymbolSeparation(unsigned newSymbolCount); + /// Changes all symbol identifiers which are loop IVs to dim identifiers. + void convertLoopIVSymbolsToDims(); + /// Sets the specified identifier to a constant and removes it. void setAndEliminate(unsigned pos, int64_t constVal); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index e1ee59bb4b6..93c5b734fce 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -83,6 +83,17 @@ struct ComputationSliceState { std::vector> lbOperands; // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). std::vector> ubOperands; + // Adds to 'cst' with constraints which represent the slice bounds on 'ivs' + // in 'this'. Specifically, the values in 'ivs' are added to 'cst' as dim + // identifiers and the values in 'lb/ubOperands' are added as symbols. + // Constraints are added for all loop IV bounds (dim or symbol), and + // constraints are added for slice bounds in 'lbs'/'ubs'. + // Returns true on success, false otherwise (if we cannot add loop bounds + // because of unsupported cases). + bool getAsConstraints(FlatAffineConstraints *cst); + + // Clears all bounds and operands in slice state. 
+ void clearBounds(); }; /// Computes computation slice loop bounds for the loop nest surrounding diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4cbddba140b..c6287e281c4 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -694,6 +694,20 @@ static void turnSymbolIntoDim(FlatAffineConstraints *cst, const Value &id) { } } +// Changes all symbol identifiers which are loop IVs to dim identifiers. +void FlatAffineConstraints::convertLoopIVSymbolsToDims() { + // Gather all symbols which are loop IVs. + SmallVector loopIVs; + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { + if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue())) + loopIVs.push_back(ids[i].getValue()); + } + // Turn each symbol in 'loopIVs' into a dim identifier. + for (auto *iv : loopIVs) { + turnSymbolIntoDim(this, *iv); + } +} + bool FlatAffineConstraints::addAffineForOpDomain( ConstOpPointer forOp) { unsigned pos; @@ -2704,7 +2718,7 @@ bool FlatAffineConstraints::unionBoundingBox( assert(lbDivisor > 0 && "divisor always expected to be positive"); // Compute min of lower bounds and max of upper bounds. - ArrayRef minLb, maxUb; + SmallVector minLb, maxUb; auto res = compareBounds(lb, otherLb); // Identify min. @@ -2713,12 +2727,13 @@ bool FlatAffineConstraints::unionBoundingBox( } else if (res == BoundCmpResult::Greater) { minLb = otherLb; } else { - // Uncomparable. + // Uncomparable - check for constant lower/upper bounds. auto constLb = getConstantLowerBound(d); auto constOtherLb = other.getConstantLowerBound(d); if (!constLb.hasValue() || !constOtherLb.hasValue()) return false; - minLb = std::min(constLb.getValue(), constOtherLb.getValue()); + minLb.resize(getNumSymbolIds() + 1, 0); + minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); } // Do the same for ub's but max of upper bounds. @@ -2733,12 +2748,13 @@ bool FlatAffineConstraints::unionBoundingBox( } else if (uRes == BoundCmpResult::Less) { maxUb = otherUb; } else { - // Uncomparable. + // Uncomparable - check for constant lower/upper bounds. auto constUb = getConstantUpperBound(d); auto constOtherUb = other.getConstantUpperBound(d); if (!constUb.hasValue() || !constOtherUb.hasValue()) return false; - maxUb = std::max(constUb.getValue(), constOtherUb.getValue()); + maxUb.resize(getNumSymbolIds() + 1, 0); + maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); } SmallVector newLb(getNumCols(), 0); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index f41e2c4f27d..bf2f82e29b1 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -54,6 +54,55 @@ void mlir::getLoopIVs(const Instruction &inst, std::reverse(loops->begin(), loops->end()); } +// Populates 'cst' with FlatAffineConstraints which represent slice bounds. +bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { + assert(!lbOperands.empty()); + // Adds src 'ivs' as dimension identifiers in 'cst'. + unsigned numDims = ivs.size(); + // Adds operands (dst ivs and symbols) as symbols in 'cst'. + unsigned numSymbols = lbOperands[0].size(); + + SmallVector values(ivs); + // Append 'ivs' then 'operands' to 'values'. + values.append(lbOperands[0].begin(), lbOperands[0].end()); + cst->reset(numDims, numSymbols, 0, values); + + // Add loop bound constraints for values which are loop IVs and equality + // constraints for symbols which are constants. 
+ for (const auto &value : values) { + unsigned loc; + (void)loc; + assert(cst->findId(*value, &loc)); + if (isValidSymbol(value)) { + // Check if the symbol is a constant. + if (auto *inst = value->getDefiningInst()) { + if (auto constOp = inst->dyn_cast()) { + cst->setIdToConstant(*value, constOp->getValue()); + } + } + } else { + if (auto loop = getForInductionVarOwner(value)) { + if (!cst->addAffineForOpDomain(loop)) + return false; + } + } + } + + // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]' + bool ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]); + assert(ret && "should not fail as we never have semi-affine slice maps"); + (void)ret; + return true; +} + +// Clears state bounds and operand state. +void ComputationSliceState::clearBounds() { + lbs.clear(); + ubs.clear(); + lbOperands.clear(); + ubOperands.clear(); +} + unsigned MemRefRegion::getRank() const { return memref->getType().cast().getRank(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7466e4968a6..4183cba8859 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1085,60 +1085,6 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { node->inst = loops[loopNestRootIndex]->getInstruction(); } -// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB' -// using a rectangular bounding box. -// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' -// and 'sliceStateB' are aligned. -// Specifically, when taking the union of overlapping intervals, it assumes -// that both intervals start at zero. Support needs to be added to take into -// account interval start offset when computing the union. -// TODO(andydavis) Move this function to an analysis library. -static bool getSliceUnion(const ComputationSliceState &sliceStateA, - ComputationSliceState *sliceStateB) { - assert(sliceStateA.lbs.size() == sliceStateB->lbs.size()); - assert(sliceStateA.ubs.size() == sliceStateB->ubs.size()); - - for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) { - AffineMap lbMapA = sliceStateA.lbs[i]; - AffineMap ubMapA = sliceStateA.ubs[i]; - if (lbMapA == AffineMap()) { - assert(ubMapA == AffineMap()); - continue; - } - assert(ubMapA && "expected non-null ub map"); - - AffineMap lbMapB = sliceStateB->lbs[i]; - AffineMap ubMapB = sliceStateB->ubs[i]; - if (lbMapB == AffineMap()) { - assert(ubMapB == AffineMap()); - // Union 'sliceStateB' does not have a bound for 'i' so copy from A. - sliceStateB->lbs[i] = lbMapA; - sliceStateB->ubs[i] = ubMapA; - continue; - } - - // TODO(andydavis) Change this code to take the min across all lower bounds - // and max across all upper bounds for each dimension. This code can for - // cases where a unique min or max could not be statically determined. - - // Assumption: both lower bounds are the same. - if (lbMapA != lbMapB) - return false; - - // Add bound with the largest trip count to union. - Optional tripCountA = getConstDifference(lbMapA, ubMapA); - Optional tripCountB = getConstDifference(lbMapB, ubMapB); - if (!tripCountA.hasValue() || !tripCountB.hasValue()) - return false; - - if (tripCountA.getValue() > tripCountB.getValue()) { - sliceStateB->lbs[i] = lbMapA; - sliceStateB->ubs[i] = ubMapA; - } - } - return true; -} - // TODO(mlir-team): improve/complete this when we have target data. 
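The new getSliceUnion below replaces the removed rectangular union above: instead of requiring aligned lower bounds, each per-load slice is converted into a FlatAffineConstraints system and the systems are merged with unionBoundingBox. Condensed from the LoopFusion.cpp hunk below (error handling and the trip-count bookkeeping elided; srcOpInst, dstLoadOpInsts, dstLoopDepth and numSrcLoopIVs taken as in that hunk), the flow is roughly:

  // Seed the union with the slice computed against the first destination load.
  MemRefAccess srcAccess(srcOpInst);
  ComputationSliceState sliceState;
  getBackwardComputationSliceState(srcAccess, MemRefAccess(dstLoadOpInsts[0]),
                                   dstLoopDepth, &sliceState);
  FlatAffineConstraints sliceUnionCst;
  sliceState.getAsConstraints(&sliceUnionCst);
  // Fold in the slice computed against every other destination load.
  for (unsigned i = 1, e = dstLoadOpInsts.size(); i < e; ++i) {
    ComputationSliceState tmpSliceState;
    getBackwardComputationSliceState(srcAccess, MemRefAccess(dstLoadOpInsts[i]),
                                     dstLoopDepth, &tmpSliceState);
    FlatAffineConstraints tmpSliceCst;
    tmpSliceState.getAsConstraints(&tmpSliceCst);
    sliceUnionCst.unionBoundingBox(tmpSliceCst); // min of lbs, max of ubs per dim
  }
  // Read the union back as per-IV lower/upper bound maps for the slice.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(),
                               &sliceState.lbs, &sliceState.ubs);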
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); @@ -1346,6 +1292,81 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, return true; } +// Computes the union of all slice bounds computed between 'srcOpInst' +// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns +// the union in 'sliceState'. Returns true on success, false otherwise. +// TODO(andydavis) Move this to a loop fusion utility function. +static bool getSliceUnion(Instruction *srcOpInst, + ArrayRef dstLoadOpInsts, + unsigned numSrcLoopIVs, unsigned dstLoopDepth, + ComputationSliceState *sliceState) { + MemRefAccess srcAccess(srcOpInst); + unsigned numDstLoadOpInsts = dstLoadOpInsts.size(); + assert(numDstLoadOpInsts > 0); + // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'. + if (!mlir::getBackwardComputationSliceState( + srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth, sliceState)) + return false; + // Handle the common case of one dst load without a copy. + if (numDstLoadOpInsts == 1) + return true; + + // Initialize 'sliceUnionCst' with the bounds computed in previous step. + FlatAffineConstraints sliceUnionCst; + if (!sliceState->getAsConstraints(&sliceUnionCst)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n."); + return false; + } + + // Compute the union of slice bounds between 'srcOpInst' and each load + // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'. + for (unsigned i = 1; i < numDstLoadOpInsts; ++i) { + MemRefAccess dstAccess(dstLoadOpInsts[i]); + // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'. + ComputationSliceState tmpSliceState; + if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, + dstLoopDepth, &tmpSliceState)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); + return false; + } + + // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. + FlatAffineConstraints tmpSliceCst; + if (!tmpSliceState.getAsConstraints(&tmpSliceCst)) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return false; + } + // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. + if (!sliceUnionCst.unionBoundingBox(tmpSliceCst)) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute union bounding box of slice bounds.\n."); + return false; + } + } + + // Convert any dst loop IVs which are symbol identifiers to dim identifiers. + sliceUnionCst.convertLoopIVSymbolsToDims(); + + sliceState->clearBounds(); + sliceState->lbs.resize(numSrcLoopIVs, AffineMap()); + sliceState->ubs.resize(numSrcLoopIVs, AffineMap()); + + // Get slice bounds from slice union constraints 'sliceUnionCst'. + sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(), + &sliceState->lbs, &sliceState->ubs); + // Add slice bound operands of union. + SmallVector sliceBoundOperands; + sliceUnionCst.getIdValues(numSrcLoopIVs, + sliceUnionCst.getNumDimAndSymbolIds(), + &sliceBoundOperands); + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + return true; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. 
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on @@ -1408,9 +1429,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); srcStatsCollector.collect(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. - if (srcStatsCollector.hasLoopWithNonConstTripCount) + if (srcStatsCollector.hasLoopWithNonConstTripCount) { + LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); return false; - + } // Compute cost of dst loop nest. SmallVector, 4> dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); @@ -1419,8 +1441,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); dstStatsCollector.collect(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. - if (dstStatsCollector.hasLoopWithNonConstTripCount) + if (dstStatsCollector.hasLoopWithNonConstTripCount) { + LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); return false; + } // Compute the maximum loop depth at which we can can insert the src slice // and still satisfy dest loop nest dependences, for producer-consumer fusion. @@ -1428,8 +1452,10 @@ static bool isFusionProfitable(Instruction *srcOpInst, (srcOpInst == srcStoreOpInst) ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) : dstLoopIVs.size(); - if (maxDstLoopDepth == 0) + if (maxDstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n"); return false; + } // Search for min cost value for 'dstLoopDepth'. At each value of // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice @@ -1456,7 +1482,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) { LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for source operation\n."); + << "Unable to compute MemRefRegion for source instruction\n."); return false; } @@ -1477,28 +1503,22 @@ static bool isFusionProfitable(Instruction *srcOpInst, llvm::SmallDenseMap sliceTripCountMap; DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { - MemRefAccess srcAccess(srcOpInst); - // Handle the common case of one dst load without a copy. - if (!mlir::getBackwardComputationSliceState( - srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1])) - return false; - - // Compute the union of slice bound of all ops in 'dstLoadOpInsts'. - for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) { - MemRefAccess dstAccess(dstLoadOpInsts[j]); - ComputationSliceState tmpSliceState; - if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i, - &tmpSliceState)) - return false; - // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'. - getSliceUnion(tmpSliceState, &sliceStates[i - 1]); + // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. + if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i, + &sliceStates[i - 1])) { + LLVM_DEBUG(llvm::dbgs() + << "getSliceUnion failed for loopDepth: " << i << "\n"); + continue; } + // Build trip count map for computation slice. We'll skip cases where the // trip count was non-constant. 
sliceTripCountMap.clear(); if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], - &sliceTripCountMap)) + &sliceTripCountMap)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n."); continue; + } // Checks whether a store to load forwarding will happen. int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); @@ -1544,14 +1564,22 @@ static bool isFusionProfitable(Instruction *srcOpInst, // nest at loop depth 'i' MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1])) + &sliceStates[i - 1])) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to compute slice write region at loopDepth: " << i + << "\n"); continue; + } Optional maybeSliceWriteRegionSizeBytes = sliceWriteRegion.getRegionSize(); if (!maybeSliceWriteRegionSizeBytes.hasValue() || - maybeSliceWriteRegionSizeBytes.getValue() == 0) + maybeSliceWriteRegionSizeBytes.getValue() == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to get slice write region size at loopDepth: " << i + << "\n"); continue; + } int64_t sliceWriteRegionSizeBytes = maybeSliceWriteRegionSizeBytes.getValue(); @@ -1783,6 +1811,7 @@ public: // *) Second pass fuses sibling nodes which share no dependence edges. // *) Third pass fuses any remaining producer nodes into their users. void run() { + // TODO(andydavis) Run this repeatedly until a fixed-point is reached. fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); fuseSiblingNodes(); fuseProducerConsumerNodes( diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index acf5bbb54bb..4526f01a369 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2250,3 +2250,42 @@ func @fuse_across_varying_dims_complex() { // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } + +// ----- +// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0) -> (d0 - 10) + +func @should_fuse_with_slice_union() { + %a = alloc() : memref<100xf32> + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + + for %i0 = 0 to 100 { + store %cf0, %a[%i0]: memref<100xf32> + } + + for %i1 = 10 to 20 { + %v0 = load %a[%i1]: memref<100xf32> + for %i2 = 15 to 25 { + %v1 = load %a[%i2]: memref<100xf32> + } + } + // The union of two slice bounds (calculated between the store and each of + // the loads) is computed and used in the fusion cost calculation, index + // remapping, and private memref size. The result is that the temporary + // memref is reduced from 100xf32 to 15xf32 and properly indexed by + // the fused loops based on the union calculation. +// CHECK: for %i0 = 10 to 20 { +// CHECK-NEXT: for %i1 = 10 to 25 { +// CHECK-NEXT: %1 = affine.apply [[MAP3]](%i1) +// CHECK-NEXT: store %cst, %0[%1] : memref<15xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0) +// CHECK-NEXT: %3 = load %0[%2] : memref<15xf32> +// CHECK-NEXT: for %i2 = 15 to 25 { +// CHECK-NEXT: %4 = affine.apply [[MAP3]](%i2) +// CHECK-NEXT: %5 = load %0[%4] : memref<15xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return + return +} -- cgit v1.2.3 From 1e55ae19a0e4a95f86a681a9c682192712908276 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 7 Mar 2019 22:14:47 -0800 Subject: Convert ambiguous bool returns in /Analysis to use Status instead. 
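This patch exercises the new type only through a small surface: the Status::success() and Status::failure() factories plus the failed()/succeeded() queries. mlir/Support/Status.h itself is not part of the diff, so the following is a hypothetical sketch of the interface the converted call sites assume, not the actual header:

  // Hypothetical sketch of mlir/Support/Status.h, inferred from the hunks below.
  namespace mlir {
  struct Status {
    static Status success() { return Status(/*isFailure=*/false); }
    static Status failure() { return Status(/*isFailure=*/true); }
    explicit Status(bool isFailure) : isFailure(isFailure) {}
    bool isFailure;
  };
  // Free-function queries used at call sites and in asserts.
  inline bool failed(Status status) { return status.isFailure; }
  inline bool succeeded(Status status) { return !status.isFailure; }
  } // end namespace mlir

With that in place, a converted call site reads if (failed(cst.addAffineForOpDomain(loop))) where it previously read if (!cst.addAffineForOpDomain(loop)), and asserts check succeeded(ret) instead of the raw bool.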
PiperOrigin-RevId: 237390240 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 11 ++- mlir/include/mlir/Analysis/AffineStructures.h | 79 +++++++-------- mlir/include/mlir/Analysis/Utils.h | 54 +++++------ mlir/lib/Analysis/AffineAnalysis.cpp | 46 ++++----- mlir/lib/Analysis/AffineStructures.cpp | 135 +++++++++++++------------- mlir/lib/Analysis/Utils.cpp | 79 +++++++-------- mlir/lib/Transforms/DmaGeneration.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 27 +++--- 8 files changed, 215 insertions(+), 220 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 51f0a9aaec1..d8a7b190ab5 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -24,6 +24,7 @@ #ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H #define MLIR_ANALYSIS_AFFINE_ANALYSIS_H +#include "mlir/Support/Status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" @@ -48,11 +49,11 @@ void getReachableAffineApplyOps( /// Builds a system of constraints with dimensional identifiers corresponding to /// the loop IVs of the forOps appearing in that order. Bounds of the loop are /// used to add appropriate inequalities. Any symbols founds in the bound -/// operands are added as symbols in the system. Returns false for the yet +/// operands are added as symbols in the system. Returns failure for the yet /// unimplemented cases. // TODO(bondhugula): handle non-unit strides. -bool getIndexSet(llvm::MutableArrayRef> forOps, - FlatAffineConstraints *domain); +Status getIndexSet(llvm::MutableArrayRef> forOps, + FlatAffineConstraints *domain); /// Encapsulates a memref load or store access information. struct MemRefAccess { @@ -92,8 +93,8 @@ struct DependenceComponent { /// Checks whether two accesses to the same memref access the same element. /// Each access is specified using the MemRefAccess structure, which contains /// the operation instruction, indices and memref associated with the access. -/// Returns 'false' if it can be determined conclusively that the accesses do -/// not access the same memref element. Returns 'true' otherwise. +/// Returns 'success' if it can be determined conclusively that the accesses do +/// not access the same memref element. /// If 'allowRAR' is true, will consider read-after-read dependences (typically /// used by applications trying to optimize input reuse). // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 0ba77cccf90..237ec7d40ba 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -24,6 +24,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Support/Status.h" namespace mlir { @@ -378,14 +379,13 @@ public: /// Adds constraints (lower and upper bounds) for the specified 'for' /// instruction's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Returns - /// false for the yet unimplemented/unsupported cases, and true if the - /// information is successfully added. Asserts if the Value corresponding to - /// the 'for' instruction isn't found in the constraint system. 
Any new - /// identifiers that are found in the bound operands of the 'for' instruction - /// are added as trailing identifiers (either dimensional or symbolic - /// depending on whether the operand is a valid ML Function symbol). + /// failure for the yet unimplemented/unsupported cases. Asserts if the Value + /// corresponding to the 'for' instruction isn't found in the constraint + /// system. Any new identifiers that are found in the bound operands of the + /// 'for' instruction are added as trailing identifiers (either dimensional or + /// symbolic depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. - bool addAffineForOpDomain(ConstOpPointer forOp); + Status addAffineForOpDomain(ConstOpPointer forOp); /// Computes the lower and upper bounds of the first 'num' dimensional /// identifiers as an affine map of the remaining identifiers (dimensional and @@ -403,9 +403,8 @@ public: /// operand list 'operands'. /// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'. /// Note that both lower/upper bounds use operands from 'operands'. - /// Returns true on success, returns false for unimplemented cases. - bool addSliceBounds(ArrayRef values, ArrayRef lbMaps, - ArrayRef ubMaps, ArrayRef operands); + Status addSliceBounds(ArrayRef values, ArrayRef lbMaps, + ArrayRef ubMaps, ArrayRef operands); // Adds an inequality (>= 0) from the coefficients specified in inEq. void addInequality(ArrayRef inEq); @@ -457,13 +456,13 @@ public: /// Composes the affine value map with this FlatAffineConstrains, adding the /// results of the map as dimensions at the front [0, vMap->getNumResults()) /// and with the dimensions set to the equalities specified by the value map. - /// Returns false if the composition fails (when vMap is a semi-affine map). + /// Returns failure if the composition fails (when vMap is a semi-affine map). /// The vMap's operand Value's are used to look up the right positions in /// the FlatAffineConstraints with which to associate. The dimensional and /// symbolic operands of vMap should match 1:1 (in the same order) with those /// of this constraint system, but the latter could have additional trailing /// operands. - bool composeMap(AffineValueMap *vMap); + Status composeMap(AffineValueMap *vMap); /// Projects out (aka eliminates) 'num' identifiers starting at position /// 'pos'. The resulting constraint system is the shadow along the dimensions @@ -498,8 +497,8 @@ public: /// Tries to fold the specified identifier to a constant using a trivial /// equality detection; if successful, the constant is substituted for the /// identifier everywhere in the constraint system and then removed from the - /// system. Returns true if the folding happens, false otherwise. - bool constantFoldId(unsigned pos); + /// system. + Status constantFoldId(unsigned pos); /// This method calls constantFoldId for the specified range of identifiers, /// 'num' identifiers starting at position 'pos'. @@ -524,7 +523,7 @@ public: /// 9}, output = {s0 + 1 <= d0 <= s0 + 20}. /// 3) 'this' = {0 <= d0 <= 5, 1 <= d1 <= 9}, 'other' = {2 <= d0 <= 6, 5 <= d1 /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. 
- bool unionBoundingBox(const FlatAffineConstraints &other); + Status unionBoundingBox(const FlatAffineConstraints &other); unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); @@ -663,10 +662,12 @@ private: Optional computeConstantLowerOrUpperBound(unsigned pos); // Eliminates a single identifier at 'position' from equality and inequality - // constraints. Returns 'true' if the identifier was eliminated, and false - // otherwise. - inline bool gaussianEliminateId(unsigned position) { - return gaussianEliminateIds(position, position + 1) == 1; + // constraints. Returns 'success' if the identifier was eliminated, and + // 'failure' otherwise. + inline Status gaussianEliminateId(unsigned position) { + return gaussianEliminateIds(position, position + 1) == 1 + ? Status::success() + : Status::failure(); } // Eliminates identifiers from equality and inequality constraints @@ -746,30 +747,30 @@ private: AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols); -/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false -/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). -/// 'cst' contains constraints that connect newly introduced local identifiers -/// to existing dimensional and / symbolic identifiers. See documentation for -/// AffineExprFlattener on how mod's and div's are flattened. -bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *cst = nullptr); +/// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' could not be +/// flattened (i.e., semi-affine is not yet handled). 'cst' contains constraints +/// that connect newly introduced local identifiers to existing dimensional and +/// symbolic identifiers. See documentation for AffineExprFlattener on how +/// mod's and div's are flattened. +Status getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *cst = nullptr); /// Flattens the result expressions of the map to their corresponding flattened -/// forms and set in 'flattenedExprs'. Returns true on success or false -/// if any expression in the map could not be flattened (i.e., semi-affine is -/// not yet handled). 'cst' contains constraints that connect newly introduced -/// local identifiers to existing dimensional and / symbolic identifiers. See -/// documentation for AffineExprFlattener on how mod's and div's are flattened. -/// For all affine expressions that share the same operands (like those of an -/// affine map), this method should be used instead of repeatedly calling -/// getFlattenedAffineExpr since local variables added to deal with div's and -/// mod's will be reused across expressions. -bool getFlattenedAffineExprs( +/// forms and set in 'flattenedExprs'. Returns failure if any expression in the +/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst' +/// contains constraints that connect newly introduced local identifiers to +/// existing dimensional and / symbolic identifiers. See documentation for +/// AffineExprFlattener on how mod's and div's are flattened. For all affine +/// expressions that share the same operands (like those of an affine map), this +/// method should be used instead of repeatedly calling getFlattenedAffineExpr +/// since local variables added to deal with div's and mod's will be reused +/// across expressions. 
+Status getFlattenedAffineExprs( AffineMap map, std::vector> *flattenedExprs, FlatAffineConstraints *cst = nullptr); -bool getFlattenedAffineExprs( +Status getFlattenedAffineExprs( IntegerSet set, std::vector> *flattenedExprs, FlatAffineConstraints *cst = nullptr); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 93c5b734fce..1d7859a9cbe 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -30,6 +30,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/Status.h" #include "llvm/ADT/SmallVector.h" #include @@ -46,12 +47,6 @@ template class OpPointer; class Instruction; class Value; -/// Returns true if instruction 'a' dominates instruction b. -bool dominates(const Instruction &a, const Instruction &b); - -/// Returns true if instruction 'a' properly dominates instruction b. -bool properlyDominates(const Instruction &a, const Instruction &b); - /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. @@ -88,9 +83,8 @@ struct ComputationSliceState { // identifiers and the values in 'lb/ubOperands' are added as symbols. // Constraints are added for all loop IV bounds (dim or symbol), and // constraints are added for slice bounds in 'lbs'/'ubs'. - // Returns true on success, false otherwise (if we cannot add loop bounds - // because of unsupported cases). - bool getAsConstraints(FlatAffineConstraints *cst); + // Returns failure if we cannot add loop bounds because of unsupported cases. + Status getAsConstraints(FlatAffineConstraints *cst); // Clears all bounds and operands in slice state. void clearBounds(); @@ -99,11 +93,10 @@ struct ComputationSliceState { /// Computes computation slice loop bounds for the loop nest surrounding /// 'srcAccess', where the returned loop bound AffineMaps are functions of /// loop IVs from the loop nest surrounding 'dstAccess'. -/// Returns true on success, false otherwise. -bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); +Status getBackwardComputationSliceState(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcOpInst', slices the iteration space of src loop based on slice bounds @@ -139,15 +132,14 @@ struct MemRefRegion { /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops - /// surrounding opInst. Returns false if this fails due to yet unimplemented - /// cases. The computed region's 'cst' field has exactly as many dimensional - /// identifiers as the rank of the memref, and *potentially* additional - /// symbolic identifiers which could include any of the loop IVs surrounding - /// opInst up until 'loopDepth' and another additional Function symbols - /// involved with the access (for eg., those appear in affine.apply's, loop - /// bounds, etc.). If 'sliceState' is non-null, operands from 'sliceState' - /// are added as symbols, and the following constraints are added to the - /// system: + /// surrounding opInst. 
The computed region's 'cst' field has exactly as many + /// dimensional identifiers as the rank of the memref, and *potentially* + /// additional symbolic identifiers which could include any of the loop IVs + /// surrounding opInst up until 'loopDepth' and another additional Function + /// symbols involved with the access (for eg., those appear in affine.apply's, + /// loop bounds, etc.). If 'sliceState' is non-null, operands from + /// 'sliceState' are added as symbols, and the following constraints are added + /// to the system: /// *) Inequality constraints which represent loop bounds for 'sliceState' /// operands which are loop IVS (these represent the destination loop IVs /// of the slice, and are added as symbols to MemRefRegion's constraint @@ -168,8 +160,8 @@ struct MemRefRegion { /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. /// - bool compute(Instruction *inst, unsigned loopDepth, - ComputationSliceState *sliceState = nullptr); + Status compute(Instruction *inst, unsigned loopDepth, + ComputationSliceState *sliceState = nullptr); FlatAffineConstraints *getConstraints() { return &cst; } const FlatAffineConstraints *getConstraints() const { return &cst; } @@ -204,7 +196,7 @@ struct MemRefRegion { Optional getRegionSize(); // Wrapper around FlatAffineConstraints::unionBoundingBox. - bool unionBoundingBox(const MemRefRegion &other); + Status unionBoundingBox(const MemRefRegion &other); /// Returns the rank of the memref that this region corresponds to. unsigned getRank() const; @@ -234,12 +226,12 @@ struct MemRefRegion { /// otherwise. Optional getMemRefSizeInBytes(MemRefType memRefType); -/// Checks a load or store op for an out of bound access; returns true if the -/// access is out of bounds along any of the dimensions, false otherwise. Emits -/// a diagnostic error (with location information) if emitError is true. +/// Checks a load or store op for an out of bound access; returns failure if the +/// access is out of bounds along any of the dimensions, success otherwise. +/// Emits a diagnostic error (with location information) if emitError is true. template -bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, - bool emitError = true); +Status boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, + bool emitError = true); /// Returns the number of surrounding loops common to both A and B. unsigned getNumCommonSurroundingLoops(const Instruction &A, diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 2f39899e756..fd48740598a 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -92,24 +92,24 @@ void mlir::getReachableAffineApplyOps( // Builds a system of constraints with dimensional identifiers corresponding to // the loop IVs of the forOps appearing in that order. Any symbols founds in -// the bound operands are added as symbols in the system. Returns false for the -// yet unimplemented cases. +// the bound operands are added as symbols in the system. Returns failure for +// the yet unimplemented cases. // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or // stride information in FlatAffineConstraints. 
(For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef expr, int64_t stride) -bool mlir::getIndexSet(MutableArrayRef> forOps, - FlatAffineConstraints *domain) { +Status mlir::getIndexSet(MutableArrayRef> forOps, + FlatAffineConstraints *domain) { SmallVector indices; extractForInductionVars(forOps, &indices); // Reset while associated Values in 'indices' to the domain. domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto forOp : forOps) { // Add constraints from forOp's bounds. - if (!domain->addAffineForOpDomain(forOp)) - return false; + if (failed(domain->addAffineForOpDomain(forOp))) + return Status::failure(); } - return true; + return Status::success(); } // Computes the iteration domain for 'opInst' and populates 'indexSet', which @@ -118,8 +118,8 @@ bool mlir::getIndexSet(MutableArrayRef> forOps, // 'indexSet' correspond to the loops surounding 'inst' from outermost to // innermost. // TODO(andydavis) Add support to handle IfInsts surrounding 'inst'. -static bool getInstIndexSet(const Instruction *inst, - FlatAffineConstraints *indexSet) { +static Status getInstIndexSet(const Instruction *inst, + FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. SmallVector, 4> loops; @@ -376,17 +376,17 @@ static void addDomainConstraints(const FlatAffineConstraints &srcDomain, // a0 -c0 (a1 - c1) (a1 - c2) = 0 // b0 -f0 (b1 - f1) (b1 - f2) = 0 // -// Returns false if any AffineExpr cannot be flattened (due to it being -// semi-affine). Returns true otherwise. +// Returns failure if any AffineExpr cannot be flattened (due to it being +// semi-affine). Returns success otherwise. // TODO(bondhugula): assumes that dependenceDomain doesn't have local // variables already. Fix this soon. -static bool +static Status addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, const ValuePositionMap &valuePosMap, FlatAffineConstraints *dependenceDomain) { if (dependenceDomain->getNumLocalIds() != 0) - return false; + return Status::failure(); AffineMap srcMap = srcAccessMap.getAffineMap(); AffineMap dstMap = dstAccessMap.getAffineMap(); assert(srcMap.getNumResults() == dstMap.getNumResults()); @@ -402,9 +402,9 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, std::vector> destFlatExprs; FlatAffineConstraints srcLocalVarCst, destLocalVarCst; // Get flattened expressions for the source destination maps. - if (!getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst) || - !getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)) - return false; + if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) || + failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))) + return Status::failure(); unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds(); unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds(); @@ -511,7 +511,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, dependenceDomain->addInequality(ineq); } - return true; + return Status::success(); } // Returns the number of outer loop common to 'src/dstDomain'. @@ -792,12 +792,12 @@ bool mlir::checkMemrefAccessDependence( // Get iteration domain for the 'srcAccess' instruction. 
FlatAffineConstraints srcDomain; - if (!getInstIndexSet(srcAccess.opInst, &srcDomain)) + if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain))) return false; // Get iteration domain for 'dstAccess' instruction. FlatAffineConstraints dstDomain; - if (!getInstIndexSet(dstAccess.opInst, &dstDomain)) + if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain))) return false; // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation @@ -826,10 +826,10 @@ bool mlir::checkMemrefAccessDependence( srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); // Create memref access constraint by equating src/dst access functions. - // Note that this check is conservative, and will failure in the future - // when local variables for mod/div exprs are supported. - if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, - dependenceConstraints)) + // Note that this check is conservative, and will fail in the future when + // local variables for mod/div exprs are supported. + if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, + dependenceConstraints))) return true; // Add 'src' happens before 'dst' ordering constraints. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4f07a9f818e..df6776b8184 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -74,16 +74,15 @@ private: } // end anonymous namespace -// Flattens the expressions in map. Returns true on success or false -// if 'expr' was unable to be flattened (i.e., semi-affine expressions not -// handled yet). -static bool getFlattenedAffineExprs( +// Flattens the expressions in map. Returns failure if 'expr' was unable to be +// flattened (i.e., semi-affine expressions not handled yet). +static Status getFlattenedAffineExprs( ArrayRef exprs, unsigned numDims, unsigned numSymbols, std::vector> *flattenedExprs, FlatAffineConstraints *localVarCst) { if (exprs.empty()) { localVarCst->reset(numDims, numSymbols); - return true; + return Status::success(); } AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); @@ -91,7 +90,7 @@ static bool getFlattenedAffineExprs( // local identifiers / expressions are shared. for (auto expr : exprs) { if (!expr.isPureAffine()) - return false; + return Status::failure(); flattener.walkPostOrder(expr); } @@ -105,44 +104,43 @@ static bool getFlattenedAffineExprs( localVarCst->clearAndCopyFrom(flattener.localVarCst); } - return true; + return Status::success(); } -// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false -// if 'expr' was unable to be flattened (semi-affine expressions not handled -// yet). -bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, - unsigned numSymbols, - llvm::SmallVectorImpl *flattenedExpr, - FlatAffineConstraints *localVarCst) { +// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to +// be flattened (semi-affine expressions not handled yet). +Status +mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr, + FlatAffineConstraints *localVarCst) { std::vector> flattenedExprs; - bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, - &flattenedExprs, localVarCst); + Status ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, + &flattenedExprs, localVarCst); *flattenedExpr = flattenedExprs[0]; return ret; } -/// Flattens the expressions in map. 
Returns true on success or false -/// if 'expr' was unable to be flattened (i.e., semi-affine expressions not -/// handled yet). -bool mlir::getFlattenedAffineExprs( +/// Flattens the expressions in map. Returns failure if 'expr' was unable to be +/// flattened (i.e., semi-affine expressions not handled yet). +Status mlir::getFlattenedAffineExprs( AffineMap map, std::vector> *flattenedExprs, FlatAffineConstraints *localVarCst) { if (map.getNumResults() == 0) { localVarCst->reset(map.getNumDims(), map.getNumSymbols()); - return true; + return Status::success(); } return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs, localVarCst); } -bool mlir::getFlattenedAffineExprs( +Status mlir::getFlattenedAffineExprs( IntegerSet set, std::vector> *flattenedExprs, FlatAffineConstraints *localVarCst) { if (set.getNumConstraints() == 0) { localVarCst->reset(set.getNumDims(), set.getNumSymbols()); - return true; + return Status::success(); } return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), set.getNumSymbols(), flattenedExprs, @@ -326,7 +324,7 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) // Flatten expressions and add them to the constraint system. std::vector> flatExprs; FlatAffineConstraints localVarCst; - if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) { + if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { assert(false && "flattening unimplemented for semi-affine integer sets"); return; } @@ -609,13 +607,14 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, // This routine may add additional local variables if the flattened expression // corresponding to the map has such variables due to mod's, ceildiv's, and // floordiv's in it. -bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { +Status FlatAffineConstraints::composeMap(AffineValueMap *vMap) { std::vector> flatExprs; FlatAffineConstraints localCst; - if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &localCst)) { + if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, + &localCst))) { LLVM_DEBUG(llvm::dbgs() << "composition unimplemented for semi-affine maps\n"); - return false; + return Status::failure(); } assert(flatExprs.size() == vMap->getNumResults()); @@ -674,7 +673,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { addEquality(eqToAdd); } - return true; + return Status::success(); } // Turn a dimension into a symbol. @@ -710,13 +709,13 @@ void FlatAffineConstraints::convertLoopIVSymbolsToDims() { } } -bool FlatAffineConstraints::addAffineForOpDomain( - ConstOpPointer forOp) { +Status +FlatAffineConstraints::addAffineForOpDomain(ConstOpPointer forOp) { unsigned pos; // Pre-condition for this method. if (!findId(*forOp->getInductionVar(), &pos)) { assert(0 && "Value not found"); - return false; + return Status::failure(); } if (forOp->getStep() != 1) @@ -726,7 +725,7 @@ bool FlatAffineConstraints::addAffineForOpDomain( int64_t step = forOp->getStep(); // Adds a lower or upper bound when the bounds aren't constant. - auto addLowerOrUpperBound = [&](bool lower) -> bool { + auto addLowerOrUpperBound = [&](bool lower) -> Status { auto boundMap = lower ? 
forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); auto boundOperands = @@ -736,9 +735,9 @@ bool FlatAffineConstraints::addAffineForOpDomain( FlatAffineConstraints localVarCst; std::vector> flatExprs; - if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { + if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst))) { forOp->emitError("semi-affine expressions not yet supported"); - return false; + return Status::failure(); } // Set values for localVarCst. SmallVector values; @@ -766,8 +765,8 @@ bool FlatAffineConstraints::addAffineForOpDomain( pos = getNumDimIds() - 1; if (auto loop = getForInductionVarOwner(operand)) { // Outer loop IVs could be used in forOp's bounds. - if (!this->addAffineForOpDomain(loop)) - return false; + if (failed(this->addAffineForOpDomain(loop))) + return Status::failure(); } } } @@ -808,20 +807,20 @@ bool FlatAffineConstraints::addAffineForOpDomain( : flatExpr[flatExpr.size() - 1] - step; addInequality(ineq); } - return true; + return Status::success(); }; if (forOp->hasConstantLowerBound()) { addConstantLowerBound(pos, forOp->getConstantLowerBound()); } else { // Non-constant lower bound case. - if (!addLowerOrUpperBound(/*lower=*/true)) - return false; + if (failed(addLowerOrUpperBound(/*lower=*/true))) + return Status::failure(); } if (forOp->hasConstantUpperBound()) { addConstantUpperBound(pos, forOp->getConstantUpperBound() - step); - return true; + return Status::success(); } // Non-constant upper bound case. return addLowerOrUpperBound(/*lower=*/false); @@ -1679,10 +1678,10 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, // Note that both lower/upper bounds use operands from 'operands'. // Returns true on success. Returns false for unimplemented cases such as // semi-affine expressions or expressions with mod/floordiv. -bool FlatAffineConstraints::addSliceBounds(ArrayRef values, - ArrayRef lbMaps, - ArrayRef ubMaps, - ArrayRef operands) { +Status FlatAffineConstraints::addSliceBounds(ArrayRef values, + ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands) { assert(values.size() == lbMaps.size()); assert(lbMaps.size() == ubMaps.size()); @@ -1690,7 +1689,7 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef values, // add a single equality equal to the first bound map result expr. // TODO(andydavis,bondhugula): refactor and reuse from addAffineForOpDomain. auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap, bool eq, - bool lower = true) -> bool { + bool lower = true) -> Status { assert(pos < getNumDimAndSymbolIds() && "invalid position"); // Equality follows the logic of lower bound except that we add an equality // instead of an inequality. @@ -1702,9 +1701,9 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef values, FlatAffineConstraints localVarCst; std::vector> flatExprs; - if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) { + if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst))) { LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); - return false; + return Status::failure(); } // Merge and align with localVarCst. @@ -1758,7 +1757,7 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef values, : flatExpr[flatExpr.size() - 1] - 1; eq ? 
addEquality(ineq) : addInequality(ineq); } - return true; + return Status::success(); }; for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { @@ -1775,20 +1774,20 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef values, if (lbMap && ubMap && lbMap.getNumResults() == 1 && ubMap.getNumResults() == 1 && lbMap.getResult(0) + 1 == ubMap.getResult(0)) { - if (!addLowerOrUpperBound(pos, lbMap, /*eq=*/true, /*lower=*/true)) - return false; + if (failed(addLowerOrUpperBound(pos, lbMap, /*eq=*/true, /*lower=*/true))) + return Status::failure(); continue; } if (lbMap && - !addLowerOrUpperBound(pos, lbMap, /*eq=*/false, /*lower=*/true)) - return false; + failed(addLowerOrUpperBound(pos, lbMap, /*eq=*/false, /*lower=*/true))) + return Status::failure(); if (ubMap && - !addLowerOrUpperBound(pos, ubMap, /*eq=*/false, /*lower=*/false)) - return false; + failed(addLowerOrUpperBound(pos, ubMap, /*eq=*/false, /*lower=*/false))) + return Status::failure(); } - return true; + return Status::success(); } void FlatAffineConstraints::addEquality(ArrayRef eq) { @@ -1975,22 +1974,22 @@ void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { removeId(pos); } -bool FlatAffineConstraints::constantFoldId(unsigned pos) { +Status FlatAffineConstraints::constantFoldId(unsigned pos) { assert(pos < getNumIds() && "invalid position"); int rowIdx; if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) - return false; + return Status::failure(); // atEq(rowIdx, pos) is either -1 or 1. assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); setAndEliminate(pos, constVal); - return true; + return Status::success(); } void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { - if (!constantFoldId(t)) + if (failed(constantFoldId(t))) t++; } } @@ -2408,9 +2407,9 @@ void FlatAffineConstraints::FourierMotzkinEliminate( for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { if (atEq(r, pos) != 0) { // Use Gaussian elimination here (since we have an equality). - bool ret = gaussianEliminateId(pos); + Status ret = gaussianEliminateId(pos); (void)ret; - assert(ret && "Gaussian elimination guaranteed to succeed"); + assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); LLVM_DEBUG(dump()); return; @@ -2681,8 +2680,8 @@ static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { // Computes the bounding box with respect to 'other' by finding the min of the // lower bounds and the max of the upper bounds along each of the dimensions. -bool FlatAffineConstraints::unionBoundingBox( - const FlatAffineConstraints &otherCst) { +Status +FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) { assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); assert(otherCst.getIds() .slice(0, getNumDimIds()) @@ -2718,13 +2717,13 @@ bool FlatAffineConstraints::unionBoundingBox( if (!extent.hasValue()) // TODO(bondhugula): symbolic extents when necessary. // TODO(bondhugula): handle union if a dimension is unbounded. - return false; + return Status::failure(); auto otherExtent = other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor); if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor) // TODO(bondhugula): symbolic extents when necessary. 
- return false; + return Status::failure(); assert(lbDivisor > 0 && "divisor always expected to be positive"); @@ -2739,7 +2738,7 @@ bool FlatAffineConstraints::unionBoundingBox( auto constLb = getConstantLowerBound(d); auto constOtherLb = other.getConstantLowerBound(d); if (!constLb.hasValue() || !constOtherLb.hasValue()) - return false; + return Status::failure(); std::fill(minLb.begin(), minLb.end(), 0); minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); } @@ -2761,7 +2760,7 @@ bool FlatAffineConstraints::unionBoundingBox( auto constUb = getConstantUpperBound(d); auto constOtherUb = other.getConstantUpperBound(d); if (!constUb.hasValue() || !constOtherUb.hasValue()) - return false; + return Status::failure(); std::fill(maxUb.begin(), maxUb.end(), 0); maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); } @@ -2790,5 +2789,5 @@ bool FlatAffineConstraints::unionBoundingBox( addInequality(boundingUbs[d]); } - return true; + return Status::success(); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index ba6f79d6d77..2669b6dc479 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -55,7 +55,7 @@ void mlir::getLoopIVs(const Instruction &inst, } // Populates 'cst' with FlatAffineConstraints which represent slice bounds. -bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { +Status ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { assert(!lbOperands.empty()); // Adds src 'ivs' as dimension identifiers in 'cst'. unsigned numDims = ivs.size(); @@ -80,17 +80,18 @@ bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { } } else { if (auto loop = getForInductionVarOwner(value)) { - if (!cst->addAffineForOpDomain(loop)) - return false; + if (failed(cst->addAffineForOpDomain(loop))) + return Status::failure(); } } } // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]' - bool ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]); - assert(ret && "should not fail as we never have semi-affine slice maps"); + Status ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]); + assert(succeeded(ret) && + "should not fail as we never have semi-affine slice maps"); (void)ret; - return true; + return Status::success(); } // Clears state bounds and operand state. @@ -150,15 +151,14 @@ Optional MemRefRegion::getConstantBoundingSizeAndShape( return numElements; } -bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { +Status MemRefRegion::unionBoundingBox(const MemRefRegion &other) { assert(memref == other.memref); return cst.unionBoundingBox(*other.getConstraints()); } /// Computes the memory region accessed by this memref with the region /// represented as constraints symbolic/parameteric in 'loopDepth' loops -/// surrounding opInst and any additional Function symbols. Returns false if -/// this fails due to yet unimplemented cases. +/// surrounding opInst and any additional Function symbols. // For example, the memref region for this load operation at loopDepth = 1 will // be as below: // @@ -173,8 +173,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). 
-bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, - ComputationSliceState *sliceState) { +Status MemRefRegion::compute(Instruction *inst, unsigned loopDepth, + ComputationSliceState *sliceState) { assert((inst->isa() || inst->isa()) && "load/store op expected"); @@ -194,7 +194,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, extractForInductionVars(ivs, ®ionSymbols); // A rank 0 memref has a 0-d region. cst.reset(rank, loopDepth, 0, regionSymbols); - return true; + return Status::success(); } // Build the constraints for this region. @@ -235,8 +235,8 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!cst.addAffineForOpDomain(loop)) - return false; + if (failed(cst.addAffineForOpDomain(loop))) + return Status::failure(); } else { // Has to be a valid symbol. auto *symbol = operand; @@ -269,17 +269,18 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, } } // Add upper/lower bounds from 'sliceState' to 'cst'. - bool ret = cst.addSliceBounds(sliceState->ivs, sliceState->lbs, - sliceState->ubs, sliceState->lbOperands[0]); - assert(ret && "should not fail as we never have semi-affine slice maps"); + Status ret = cst.addSliceBounds(sliceState->ivs, sliceState->lbs, + sliceState->ubs, sliceState->lbOperands[0]); + assert(succeeded(ret) && + "should not fail as we never have semi-affine slice maps"); (void)ret; } // Add access function equalities to connect loop IVs to data dimensions. - if (!cst.composeMap(&accessValueMap)) { + if (failed(cst.composeMap(&accessValueMap))) { inst->emitError("getMemRefRegion: compose affine map failed"); LLVM_DEBUG(accessValueMap.getAffineMap().dump()); - return false; + return Status::failure(); } // Set all identifiers appearing after the first 'rank' identifiers as @@ -315,7 +316,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); LLVM_DEBUG(cst.dump()); - return true; + return Status::success(); } // TODO(mlir-team): improve/complete this when we have target data. @@ -378,8 +379,8 @@ Optional mlir::getMemRefSizeInBytes(MemRefType memRefType) { } template -bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, - bool emitError) { +Status mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, + bool emitError) { static_assert( std::is_same>::value || std::is_same>::value, @@ -388,8 +389,8 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, Instruction *opInst = loadOrStoreOp->getInstruction(); MemRefRegion region(opInst->getLoc()); - if (!region.compute(opInst, /*loopDepth=*/0)) - return false; + if (failed(region.compute(opInst, /*loopDepth=*/0))) + return Status::success(); LLVM_DEBUG(llvm::dbgs() << "Memory region"); LLVM_DEBUG(region.getConstraints()->dump()); @@ -429,14 +430,14 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, "memref out of lower bound access along dimension #" + Twine(r + 1)); } } - return outOfBounds; + return outOfBounds ? Status::failure() : Status::success(); } // Explicitly instantiate the template so that the compiler knows we need them! 
-template bool mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, - bool emitError); -template bool mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, - bool emitError); +template Status mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, + bool emitError); +template Status mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, + bool emitError); // Returns in 'positions' the Block positions of 'inst' in each ancestor // Block from the Block containing instruction, stopping at 'limitBlock'. @@ -486,17 +487,16 @@ const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; // out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice // bounds in 'sliceState' which represent the src IVs in terms of the dst IVs, // symbols and constants. -bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - unsigned dstLoopDepth, - ComputationSliceState *sliceState) { +Status mlir::getBackwardComputationSliceState( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned dstLoopDepth, ComputationSliceState *sliceState) { bool readReadAccesses = srcAccess.opInst->isa() && dstAccess.opInst->isa(); FlatAffineConstraints dependenceConstraints; if (!checkMemrefAccessDependence( srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) { - return false; + return Status::failure(); } // Get loop nest surrounding src operation. SmallVector, 4> srcLoopIVs; @@ -509,7 +509,7 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, unsigned numDstLoopIVs = dstLoopIVs.size(); if (dstLoopDepth > numDstLoopIVs) { dstAccess.opInst->emitError("invalid destination loop depth"); - return false; + return Status::failure(); } // Project out dimensions other than those up to 'dstLoopDepth'. @@ -560,7 +560,7 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, break; } - return true; + return Status::success(); } /// Creates a computation slice of the loop nest surrounding 'srcOpInst', @@ -711,8 +711,9 @@ static Optional getMemoryFootprintBytes(const Block &block, // Compute the memref region symbolic in any IVs enclosing this block. auto region = std::make_unique(opInst->getLoc()); - if (!region->compute(opInst, - /*loopDepth=*/getNestingDepth(*block.begin()))) { + if (failed( + region->compute(opInst, + /*loopDepth=*/getNestingDepth(*block.begin())))) { opInst->emitError("Error obtaining memory region\n"); error = true; return; @@ -720,7 +721,7 @@ static Optional getMemoryFootprintBytes(const Block &block, auto it = regions.find(region->memref); if (it == regions.end()) { regions[region->memref] = std::move(region); - } else if (!it->second->unionBoundingBox(*region)) { + } else if (failed(it->second->unionBoundingBox(*region))) { opInst->emitError("Error performing a union on a memory region\n"); error = true; return; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 42b5b01a24f..43c709fa3ab 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -619,7 +619,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Compute the MemRefRegion accessed. 
auto region = std::make_unique(opInst->getLoc()); - if (!region->compute(opInst, dmaDepth)) { + if (failed(region->compute(opInst, dmaDepth))) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region: semi-affine maps?\n"); LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); @@ -653,7 +653,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { return false; // Perform a union with the existing region. - if (!it->second->unionBoundingBox(*region)) { + if (failed(it->second->unionBoundingBox(*region))) { LLVM_DEBUG(llvm::dbgs() << "Memory region bounding box failed; " "over-approximating to the entire memref\n"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 4183cba8859..3786177f9ec 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1123,7 +1123,7 @@ static Value *createPrivateMemRef(OpPointer forOp, // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. MemRefRegion region(srcStoreOpInst->getLoc()); - bool validRegion = region.compute(srcStoreOpInst, dstLoopDepth); + bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth)); (void)validRegion; assert(validRegion && "unexpected memref region failure"); SmallVector newShape; @@ -1252,7 +1252,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. auto *srcStoreOpInst = srcNode->stores.front(); MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) { + if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for source operation\n."); return false; @@ -1271,7 +1271,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, assert(dstStoreOps.size() == 1); auto *dstStoreOpInst = dstStoreOps[0]; MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); - if (!dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0)) { + if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for dest operation\n."); return false; @@ -1304,8 +1304,9 @@ static bool getSliceUnion(Instruction *srcOpInst, unsigned numDstLoadOpInsts = dstLoadOpInsts.size(); assert(numDstLoadOpInsts > 0); // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'. - if (!mlir::getBackwardComputationSliceState( - srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth, sliceState)) + if (failed(mlir::getBackwardComputationSliceState( + srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth, + sliceState))) return false; // Handle the common case of one dst load without a copy. if (numDstLoadOpInsts == 1) @@ -1313,7 +1314,7 @@ static bool getSliceUnion(Instruction *srcOpInst, // Initialize 'sliceUnionCst' with the bounds computed in previous step. FlatAffineConstraints sliceUnionCst; - if (!sliceState->getAsConstraints(&sliceUnionCst)) { + if (failed(sliceState->getAsConstraints(&sliceUnionCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n."); return false; } @@ -1324,21 +1325,21 @@ static bool getSliceUnion(Instruction *srcOpInst, MemRefAccess dstAccess(dstLoadOpInsts[i]); // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'. 
ComputationSliceState tmpSliceState; - if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, - dstLoopDepth, &tmpSliceState)) { + if (failed(mlir::getBackwardComputationSliceState( + srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); return false; } // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. FlatAffineConstraints tmpSliceCst; - if (!tmpSliceState.getAsConstraints(&tmpSliceCst)) { + if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n."); return false; } // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. - if (!sliceUnionCst.unionBoundingBox(tmpSliceCst)) { + if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute union bounding box of slice bounds.\n."); return false; @@ -1480,7 +1481,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute src loop nest write region size. MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) { + if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for source instruction\n."); return false; @@ -1563,8 +1564,8 @@ static bool isFusionProfitable(Instruction *srcOpInst, // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop // nest at loop depth 'i' MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); - if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1])) { + if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, + &sliceStates[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); -- cgit v1.2.3 From ce7e59536c32e49f9f3ae8cc959571eb5367c45f Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 8 Mar 2019 09:21:52 -0800 Subject: Add a basic model to set tile sizes + some cleanup - compute tile sizes based on a simple model that looks at memory footprints (instead of using the hardcoded default value) - adjust tile sizes to make them factors of trip counts based on an option - update loop fusion CL options to allow setting maximal fusion at pass creation - change an emitError to emitWarning (since it's not a hard error unless the client treats it that way, in which case, it can emit one) $ mlir-opt -debug-only=loop-tile -loop-tile test/Transforms/loop-tiling.mlir test/Transforms/loop-tiling.mlir:81:3: note: using tile sizes [4 4 5 ] for %i = 0 to 256 { for %i0 = 0 to 256 step 4 { for %i1 = 0 to 256 step 4 { for %i2 = 0 to 250 step 5 { for %i3 = #map4(%i0) to #map11(%i0) { for %i4 = #map4(%i1) to #map11(%i1) { for %i5 = #map4(%i2) to #map12(%i2) { %0 = load %arg0[%i3, %i5] : memref<8x8xvector<64xf32>> %1 = load %arg1[%i5, %i4] : memref<8x8xvector<64xf32>> %2 = load %arg2[%i3, %i4] : memref<8x8xvector<64xf32>> %3 = mulf %0, %1 : vector<64xf32> %4 = addf %2, %3 : vector<64xf32> store %4, %arg2[%i3, %i4] : memref<8x8xvector<64xf32>> } } } } } } PiperOrigin-RevId: 237461836 --- mlir/include/mlir/Transforms/Passes.h | 5 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 47 ++++++---- mlir/lib/Transforms/LoopTiling.cpp | 164 ++++++++++++++++++++++++++++++--- mlir/test/Transforms/loop-tiling.mlir | 27 ++++++ 5 files changed, 211 insertions(+), 34 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff 
--git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 32c7d93307c..d75eec74320 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -76,7 +76,8 @@ FunctionPassBase *createSimplifyAffineStructuresPass(); /// equal to `localBufSizeThreshold` are promoted to memory space /// `fastMemorySpace'. FunctionPassBase *createLoopFusionPass(unsigned fastMemorySpace = 0, - uint64_t localBufSizeThreshold = 0); + uint64_t localBufSizeThreshold = 0, + bool maximalFusion = false); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. @@ -88,7 +89,7 @@ FunctionPassBase *createPipelineDataTransferPass(); FunctionPassBase *createLowerAffinePass(); /// Creates a pass to perform tiling on loop nests. -FunctionPassBase *createLoopTilingPass(); +FunctionPassBase *createLoopTilingPass(uint64_t cacheSizeBytes); /// Promotes all accessed memref regions to the specified faster memory space /// while generating DMAs to move data. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index df6776b8184..0687526ce73 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -736,7 +736,7 @@ FlatAffineConstraints::addAffineForOpDomain(ConstOpPointer forOp) { FlatAffineConstraints localVarCst; std::vector> flatExprs; if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst))) { - forOp->emitError("semi-affine expressions not yet supported"); + forOp->emitWarning("semi-affine expressions not yet supported"); return Status::failure(); } // Set values for localVarCst. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 3786177f9ec..4e619ef36a4 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -48,7 +48,9 @@ using namespace mlir; static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); -/// Disables fusion profitability check and fuses if valid. +/// Disables fusion profitability check and fuses if valid. Ignore any +/// additional (redundant) computation tolerance threshold +/// that would have prevented fusion. static llvm::cl::opt clMaximalLoopFusion("fusion-maximal", llvm::cl::desc("Enables maximal loop fusion"), @@ -66,8 +68,8 @@ static llvm::cl::opt clFusionFastMemorySpace( llvm::cl::desc("Faster memory space number to promote fusion buffers to"), llvm::cl::cat(clOptionsCategory)); -// A local buffer of size less than or equal to this size is promoted to fast -// memory. +// A local buffer of size less than or equal to this size is automatically +// promoted to fast memory after producer-consumer fusion. static llvm::cl::opt clFusionLocalBufThreshold( "fusion-local-buf-threshold", llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast " @@ -86,9 +88,10 @@ namespace { // and add support for more general loop fusion algorithms. struct LoopFusion : public FunctionPass { - LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0) + LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, + bool maximalFusion = false) : localBufSizeThreshold(localBufSizeThreshold), - fastMemorySpace(fastMemorySpace) {} + fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {} void runOnFunction() override; @@ -96,6 +99,9 @@ struct LoopFusion : public FunctionPass { // `fastMemorySpace` if provided. 
uint64_t localBufSizeThreshold; Optional fastMemorySpace = None; + // If true, ignore any additional (redundant) computation tolerance threshold + // that would have prevented fusion. + bool maximalFusion; // The amount of additional computation that is tolerated while fusing // pair-wise as a fraction of the total computation. @@ -105,8 +111,9 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold) { - return new LoopFusion(fastMemorySpace, localBufSizeThreshold); + uint64_t localBufSizeThreshold, + bool maximalFusion) { + return new LoopFusion(fastMemorySpace, localBufSizeThreshold, maximalFusion); } namespace { @@ -1411,7 +1418,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, ArrayRef dstLoadOpInsts, ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, - unsigned *dstLoopDepth) { + unsigned *dstLoopDepth, bool maximalFusion) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between:\n"; llvm::dbgs() << " " << *srcOpInst << " and \n"; @@ -1620,7 +1627,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // (as per computeToleranceThreshold), we will simply pick the one that // reduces the intermediary size the most. if ((storageReduction > maxStorageReduction) && - (clMaximalLoopFusion || + (maximalFusion || (additionalComputeFraction < computeToleranceThreshold))) { maxStorageReduction = storageReduction; bestDstLoopDepth = i; @@ -1632,7 +1639,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // A simple cost model: fuse if it reduces the memory footprint. If // -maximal-fusion is set, fuse nevertheless. - if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) { + if (!maximalFusion && !bestDstLoopDepth.hasValue()) { LLVM_DEBUG( llvm::dbgs() << "All fusion choices involve more than the threshold amount of " @@ -1661,7 +1668,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, Optional storageReduction = None; - if (!clMaximalLoopFusion) { + if (!maximalFusion) { if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { LLVM_DEBUG( llvm::dbgs() @@ -1785,13 +1792,16 @@ public: unsigned localBufSizeThreshold; // Parameter for fast memory space. Optional fastMemorySpace; + // If true, ignore any additional (redundant) computation tolerance threshold + // that would have prevented fusion. + bool maximalFusion; using Node = MemRefDependenceGraph::Node; GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold, - Optional fastMemorySpace) + Optional fastMemorySpace, bool maximalFusion) : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold), - fastMemorySpace(fastMemorySpace) {} + fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {} // Initializes 'worklist' with nodes from 'mdg' void init() { @@ -1917,7 +1927,7 @@ public: // Check if fusion would be profitable. if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst, dstLoadOpInsts, dstStoreOpInsts, &sliceState, - &bestDstLoopDepth)) + &bestDstLoopDepth, maximalFusion)) continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. @@ -2076,7 +2086,8 @@ public: // Check if fusion would be profitable. if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, &bestDstLoopDepth)) + dstStoreOpInsts, &sliceState, &bestDstLoopDepth, + maximalFusion)) continue; // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. 
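With 'maximalFusion' threaded through isFusionProfitable and GreedyFusion above, pass creation now takes three parameters. A rough usage sketch based on the createLoopFusionPass signature in this patch (the surrounding pass-manager setup is assumed and not shown):

    // Promote fusion buffers of up to 4 KiB into memory space 2 and fuse
    // maximally, ignoring the compute tolerance threshold.
    FunctionPassBase *fusionPass =
        createLoopFusionPass(/*fastMemorySpace=*/2,
                             /*localBufSizeThreshold=*/4 * 1024,
                             /*maximalFusion=*/true);

As in runOnFunction below, the threshold is in bytes; the 'fusion-local-buf-threshold' command-line option is given in KiB and multiplied by 1024.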
@@ -2231,9 +2242,13 @@ void LoopFusion::runOnFunction() { localBufSizeThreshold = clFusionLocalBufThreshold * 1024; } + if (clMaximalLoopFusion.getNumOccurrences() > 0) + maximalFusion = clMaximalLoopFusion; + MemRefDependenceGraph g; if (g.init(getFunction())) - GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run(); + GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion) + .run(); } static PassRegistration pass("loop-fusion", "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 7e345bf15fe..4fa9c5a44d3 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -23,12 +23,15 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include using namespace mlir; @@ -36,34 +39,53 @@ using namespace mlir; static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); +static llvm::cl::opt + clCacheSizeKiB("tile-cache-size", + llvm::cl::desc("Set size of cache to tile for in KiB"), + llvm::cl::cat(clOptionsCategory)); + +// Tile size to use for all loops (overrides -tile-sizes if provided). +static llvm::cl::opt + clTileSize("tile-size", llvm::cl::desc("Use this tile size for all loops"), + llvm::cl::cat(clOptionsCategory)); + // List of tile sizes. If any of them aren't provided, they are filled with // clTileSize / kDefaultTileSize. static llvm::cl::list clTileSizes( "tile-sizes", llvm::cl::desc( - "List of tile sizes for each perfect nest (overrides -tile-size)"), + "List of tile sizes for each perfect nest (overridden by -tile-size)"), llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. struct LoopTiling : public FunctionPass { + explicit LoopTiling(uint64_t cacheSizeBytes = kDefaultCacheMemCapacity, + bool avoidMaxMinBounds = true) + : cacheSizeBytes(cacheSizeBytes), avoidMaxMinBounds(avoidMaxMinBounds) {} + void runOnFunction() override; + void getTileSizes(ArrayRef> band, + SmallVectorImpl *tileSizes); + // Default tile size if nothing is provided. constexpr static unsigned kDefaultTileSize = 4; + constexpr static uint64_t kDefaultCacheMemCapacity = 512 * 1024UL; + + // Capacity of the cache to tile for. + uint64_t cacheSizeBytes; + // If true, tile sizes are set to avoid max/min in bounds if possible. + bool avoidMaxMinBounds; }; } // end anonymous namespace -// Tile size to use for all loops (overridden by -tile-sizes if provided). -static llvm::cl::opt - clTileSize("tile-size", llvm::cl::init(LoopTiling::kDefaultTileSize), - llvm::cl::desc("Use this tile size for all loops"), - llvm::cl::cat(clOptionsCategory)); - -/// Creates a pass to perform loop tiling on all suitable loop nests of an +/// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. -FunctionPassBase *mlir::createLoopTilingPass() { return new LoopTiling(); } +FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { + return new LoopTiling(cacheSizeBytes); +} // Move the loop body of AffineForOp 'src' from 'src' into the specified // location in destination's body. 
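The tiling pass can likewise be driven programmatically via createLoopTilingPass(cacheSizeBytes) or from the command line with the 'tile-cache-size' option (in KiB) added above. A sketch matching the 512 KiB cache used by the test further down (the exact RUN lines appear in loop-tiling.mlir below):

    // In C++: tile for a 512 KiB cache.
    FunctionPassBase *tilingPass =
        createLoopTilingPass(/*cacheSizeBytes=*/512 * 1024);

    // From the command line:
    //   mlir-opt -loop-tile -tile-cache-size=512 test/Transforms/loop-tiling.mlir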
@@ -213,7 +235,7 @@ Status mlir::tileCodeGen(MutableArrayRef> band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootAffineForOp->emitError("tiled code generation unimplemented for the" + rootAffineForOp->emitError("tiled code generation unimplemented for the " "non-hyperrectangular case"); return Status::failure(); } @@ -253,18 +275,130 @@ getTileableBands(Function *f, getMaximalPerfectLoopNest(forOp); } +// Reduce each tile size to the largest divisor of the corresponding trip count +// (if the trip count is known). +static void adjustToDivisorsOfTripCounts(ArrayRef> band, + SmallVectorImpl *tileSizes) { + assert(band.size() == tileSizes->size() && "invalid tile size count"); + for (unsigned i = 0, e = band.size(); i < e; i++) { + unsigned &tSizeAdjusted = (*tileSizes)[i]; + auto mayConst = getConstantTripCount(band[i]); + if (!mayConst.hasValue()) + continue; + // Adjust the tile size to largest factor of the trip count less than + // tSize. + uint64_t constTripCount = mayConst.getValue(); + if (tSizeAdjusted > constTripCount / 2) + tSizeAdjusted = constTripCount / 2; + while (constTripCount % tSizeAdjusted != 0) + tSizeAdjusted--; + } +} + +// Returns tile sizes to use. Checks CL options; if none are specified, sets it +// based on a simple model that looks at the memory footprint and determines +// tile sizes assuming identity accesses / 1:1 tile size proportional footprint +// along each of the dimensions being tiled. +// TODO(mlir-team): evolve this model. Tile size determination is a large area +// to play with in general. +void LoopTiling::getTileSizes(ArrayRef> band, + SmallVectorImpl *tileSizes) { + if (band.empty()) + return; + + tileSizes->resize(band.size()); + + // Use clTileSize for all loops if specified. + if (clTileSize.getNumOccurrences() > 0) { + std::fill(tileSizes->begin(), tileSizes->end(), clTileSize); + return; + } + + // Use clTileSizes and fill them with default tile size if it's short. + if (!clTileSizes.empty()) { + std::fill(tileSizes->begin(), tileSizes->end(), + LoopTiling::kDefaultTileSize); + std::copy(clTileSizes.begin(), + clTileSizes.begin() + std::min(clTileSizes.size(), band.size()), + tileSizes->begin()); + return; + } + + // The first loop in the band. + auto rootForOp = band[0]; + (void)rootForOp; + + // Obtain memory footprint and set tile sizes so that a tile fits in + // the cache size. This is an approximation with the assumption that the + // footprint increases with the tile size linearly in that dimension (i.e., + // assumes one-to-one access function). + auto fp = getMemoryFootprintBytes(band[0], 0); + if (!fp.hasValue()) { + // Fill with default tile sizes if footprint is unknown. + std::fill(tileSizes->begin(), tileSizes->end(), + LoopTiling::kDefaultTileSize); + if (avoidMaxMinBounds) + adjustToDivisorsOfTripCounts(band, tileSizes); + LLVM_DEBUG( + rootForOp->emitWarning("memory footprint unknown: using default tile " + "sizes adjusted to trip count divisors")); + return; + } + + // Check how many times larger the cache size is when compared to footprint. + uint64_t excessFactor = llvm::divideCeil(fp.getValue(), cacheSizeBytes); + if (excessFactor <= 1) { + // No need of any tiling - set tile size to 1. + std::fill(tileSizes->begin(), tileSizes->end(), 1); + return; + } + + // Divide all loops equally in an attempt to reduce footprint. + // TODO(bondhugula): this is approximate. Ideally, obtain reuse factor / + // profitability along each dimension and weight tile sizes based on that as + // one possible approach. 
Or compute a polynomial in tile sizes and solve for + // it. + + // For an n-d tilable band, compute n^th root of the excess. + unsigned tSize = + static_cast(floorl(std::pow(excessFactor, 1.0 / band.size()))); + // We'll keep a running product to determine the last tile size better. + unsigned cumulProductOfTileSizes = 1; + for (unsigned i = 0, e = band.size(); i < e; i++) { + if (i < e - 1) + (*tileSizes)[i] = tSize; + else + // Set last tile size to cover the balance. + (*tileSizes)[i] = std::max(1UL, excessFactor / cumulProductOfTileSizes); + cumulProductOfTileSizes *= (*tileSizes)[i]; + } + if (avoidMaxMinBounds) + adjustToDivisorsOfTripCounts(band, tileSizes); +} + void LoopTiling::runOnFunction() { + // Override cache size if provided on command line. + if (clCacheSizeKiB.getNumOccurrences() > 0) + cacheSizeBytes = clCacheSizeKiB * 1024; + + // Bands of loops to tile. std::vector, 6>> bands; getTileableBands(getFunction(), &bands); for (auto &band : bands) { // Set up tile sizes; fill missing tile sizes at the end with default tile // size or clTileSize if one was provided. - SmallVector tileSizes(band.size(), clTileSize); - std::copy(clTileSizes.begin(), - clTileSizes.begin() + std::min(clTileSizes.size(), band.size()), - tileSizes.begin()); - + SmallVector tileSizes; + getTileSizes(band, &tileSizes); + if (llvm::DebugFlag) { + std::stringstream msg; + msg << "using tile sizes ["; + for (auto tSize : tileSizes) + msg << tSize << " "; + msg << "]\n"; + auto rootForOp = band[0]; + rootForOp->emitNote(msg.str()); + } if (failed(tileCodeGen(band, tileSizes))) return signalPassFailure(); } diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index c2fdbd4f80f..c18c0fccf4b 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -loop-tile -tile-size=32 | FileCheck %s +// RUN: mlir-opt %s -split-input-file -loop-tile -tile-cache-size=512 | FileCheck %s --check-prefix=MODEL // CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 32) // CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0) -> (d0 + 32, 50) @@ -68,3 +69,29 @@ func @loop_max_min_bound(%A : memref, %L : index, %U : index) { // CHECK-NEXT: } // CHECK-NEXT: } } + +// ----- + +// Cache size is set to 512 KiB. This loop nest accesses about 49 MiB, and the +// tile sizes chosen would be 6 x 6 x 6. However, to avoid min/max, which is +// possible here, they are adjusted to 4 x 4 x 5. 
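Concretely, the divisor adjustment from adjustToDivisorsOfTripCounts above plays out for this nest as follows (6 divides neither 256 nor 250, so each size is decremented to the nearest divisor of its trip count):

    trip counts:            256   256   250
    model tile sizes:         6     6     6
    largest divisor <= 6:     4     4     5   -> tile sizes [4 4 5]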
+ +// MODEL-LABEL: func @simple_matmul +func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector<64xf32>>, %arg2: memref<8x8xvector<64xf32>>) -> memref<8x8xvector<64xf32>> { + for %i = 0 to 256 { + for %j = 0 to 256 { + for %k = 0 to 250 { + %l = load %arg0[%i, %k] : memref<8x8xvector<64xf32>> + %r = load %arg1[%k, %j] : memref<8x8xvector<64xf32>> + %o = load %arg2[%i, %j] : memref<8x8xvector<64xf32>> + %m = mulf %l, %r : vector<64xf32> + %a = addf %o, %m : vector<64xf32> + store %a, %arg2[%i, %j] : memref<8x8xvector<64xf32>> + } + } + } + return %arg2 : memref<8x8xvector<64xf32>> +} +// MODEL: for %i0 = 0 to 256 step 4 { +// MODEL-NEXT: for %i1 = 0 to 256 step 4 { +// MODEL-NEXT: for %i2 = 0 to 250 step 5 { -- cgit v1.2.3 From a228b7d477f6db61c387fac48e53f4b6efc669ad Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 12 Mar 2019 16:09:11 -0700 Subject: Change getMemoryFootprintBytes emitError to a warning - this is really not a hard error; emit a warning instead (for inability to compute footprint due to the union failing due to unimplemented cases) - remove a misleading warning from LoopFusion.cpp PiperOrigin-RevId: 238118711 --- mlir/lib/Analysis/Utils.cpp | 4 +++- mlir/lib/Transforms/LoopFusion.cpp | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0b5e6ab7718..e330a01c6e8 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -712,7 +712,9 @@ static Optional getMemoryFootprintBytes(const Block &block, if (it == regions.end()) { regions[region->memref] = std::move(region); } else if (failed(it->second->unionBoundingBox(*region))) { - opInst->emitError("Error performing a union on a memory region\n"); + opInst->emitWarning( + "getMemoryFootprintBytes: unable to perform a union on a memory " + "region"); error = true; return; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 4e619ef36a4..e05af79caa1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1680,7 +1680,6 @@ static bool isFusionProfitable(Instruction *srcOpInst, auto dstMemSizeVal = dstMemSize.getValue(); assert(sliceMemEstimate.hasValue() && "expected value"); - // This is an inaccurate estimate since sliceMemEstimate is isaccurate. auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" -- cgit v1.2.3 From 276fae1b0d57666b991bf7e72aaa7c5d66bd5e15 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Mar 2019 10:38:44 -0700 Subject: Rename BlockList into Region NFC. This is step 1/n to specifying regions as parts of any operation. 
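In code terms, the renames in the diffs below are mechanical. A condensed before/after sketch of the most common call sites touched by this change (the types and accessors are taken from the diff; the snippet itself is illustrative, and 'visit' is a placeholder):

    // Instructions:  inst->getNumBlockLists()   ->  inst->getNumRegions()
    //                inst->getBlockList(i)      ->  inst->getRegion(i)
    // Functions:     function->getBlockList()   ->  function->getBody()
    // Printer:       p->printBlockList(blocks)  ->  p->printRegion(blocks)
    // Parser:        parser->parseBlockList()   ->  parser->parseRegion()
    // Walking nested blocks now reads:
    for (auto &region : inst->getRegions())
      for (auto &block : region)
        visit(block);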
PiperOrigin-RevId: 238472370 --- mlir/include/mlir/AffineOps/AffineOps.h | 24 +++--- mlir/include/mlir/Analysis/Dominance.h | 12 +-- mlir/include/mlir/IR/Block.h | 68 +++++++++-------- mlir/include/mlir/IR/Function.h | 54 +++++++------- mlir/include/mlir/IR/FunctionGraphTraits.h | 16 ++-- mlir/include/mlir/IR/Instruction.h | 50 ++++++------- mlir/include/mlir/IR/OpImplementation.h | 17 ++--- mlir/include/mlir/IR/OperationSupport.h | 10 +-- mlir/lib/AffineOps/AffineOps.cpp | 82 ++++++++++----------- mlir/lib/Analysis/Dominance.cpp | 32 ++++---- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/Analysis/Verifier.cpp | 8 +- mlir/lib/IR/AsmPrinter.cpp | 27 +++---- mlir/lib/IR/Block.cpp | 45 ++++++------ mlir/lib/IR/Builders.cpp | 7 +- mlir/lib/IR/Function.cpp | 8 +- mlir/lib/IR/Instruction.cpp | 61 +++++++-------- mlir/lib/Parser/Parser.cpp | 114 ++++++++++++++--------------- mlir/lib/Transforms/CSE.cpp | 29 ++++---- mlir/lib/Transforms/DialectConversion.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 4 +- mlir/test/IR/invalid.mlir | 2 +- mlir/test/IR/parser.mlir | 2 +- 28 files changed, 339 insertions(+), 357 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 487b0a35a37..a70d8102a5d 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -142,14 +142,12 @@ public: Block *createBody(); /// Get the body of the AffineForOp. - Block *getBody() { return &getBlockList().front(); } - const Block *getBody() const { return &getBlockList().front(); } + Block *getBody() { return &getRegion().front(); } + const Block *getBody() const { return &getRegion().front(); } - /// Get the blocklist containing the body. - BlockList &getBlockList() { return getInstruction()->getBlockList(0); } - const BlockList &getBlockList() const { - return getInstruction()->getBlockList(0); - } + /// Get the body region of the AffineForOp. + Region &getRegion() { return getInstruction()->getRegion(0); } + const Region &getRegion() const { return getInstruction()->getRegion(0); } /// Returns the induction variable for this loop. Value *getInductionVar(); @@ -332,15 +330,15 @@ public: IntegerSet getIntegerSet() const; void setIntegerSet(IntegerSet newSet); - /// Returns the list of 'then' blocks. - BlockList &getThenBlocks(); - const BlockList &getThenBlocks() const { + /// Returns the 'then' region. + Region &getThenBlocks(); + const Region &getThenBlocks() const { return const_cast(this)->getThenBlocks(); } - /// Returns the list of 'else' blocks. - BlockList &getElseBlocks(); - const BlockList &getElseBlocks() const { + /// Returns the 'else' blocks. + Region &getElseBlocks(); + const Region &getElseBlocks() const { return const_cast(this)->getElseBlocks(); } diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index edf5a36637a..6ab1b876972 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -43,10 +43,10 @@ public: /// Recalculate the dominance info for the provided function. void recalculate(Function *function); - /// Get the root dominance node of the given block list. 
- DominanceInfoNode *getRootNode(const BlockList *blockList) { - assert(dominanceInfos.count(blockList) != 0); - return dominanceInfos[blockList]->getRootNode(); + /// Get the root dominance node of the given region. + DominanceInfoNode *getRootNode(const Region *region) { + assert(dominanceInfos.count(region) != 0); + return dominanceInfos[region]->getRootNode(); } protected: @@ -55,8 +55,8 @@ protected: /// Return true if the specified block A properly dominates block B. bool properlyDominates(const Block *a, const Block *b); - /// A mapping of block lists to their base dominator tree. - DenseMap> dominanceInfos; + /// A mapping of regions to their base dominator tree. + DenseMap> dominanceInfos; }; } // end namespace detail diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index da02db8592b..babc1507e48 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -1,4 +1,4 @@ -//===- Block.h - MLIR Block and BlockList Classes ---------------*- C++ -*-===// +//===- Block.h - MLIR Block and Region Classes ------------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file defines Block and BlockList classes. +// This file defines Block and Region classes. // //===----------------------------------------------------------------------===// @@ -35,7 +35,7 @@ namespace llvm { namespace ilist_detail { // Explicitly define the node access for the instruction list so that we can // break the dependence on the Instruction class in this header. This allows for -// instructions to have trailing BlockLists without a circular include +// instructions to have trailing Regions without a circular include // dependence. template <> struct SpecificNodeAccess< @@ -71,7 +71,7 @@ private: namespace mlir { class BlockAndValueMapping; -class BlockList; +class Region; class Function; using BlockOperand = IROperandImpl; @@ -81,7 +81,7 @@ template class SuccessorIterator; /// `Block` represents an ordered list of `Instruction`s. class Block : public IRObjectWithUseList, - public llvm::ilist_node_with_parent { + public llvm::ilist_node_with_parent { public: explicit Block() {} ~Block(); @@ -96,8 +96,8 @@ public: instructions.pop_back(); } - /// Blocks are maintained in a list of BlockList type. - BlockList *getParent() const { return parentValidInstOrderPair.getPointer(); } + /// Blocks are maintained in a Region. + Region *getParent() const { return parentValidInstOrderPair.getPointer(); } /// Returns the closest surrounding instruction that contains this block or /// nullptr if this is a top-level block. @@ -339,8 +339,7 @@ public: private: /// Pair of the parent object that owns this block and a bit that signifies if /// the instructions within this block have a valid ordering. - llvm::PointerIntPair - parentValidInstOrderPair; + llvm::PointerIntPair parentValidInstOrderPair; /// This is the list of instructions in the block. InstListType instructions; @@ -373,7 +372,7 @@ struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> { block_iterator first, block_iterator last); private: - mlir::BlockList *getContainingBlockList(); + mlir::Region *getContainingRegion(); }; } // end namespace llvm @@ -381,20 +380,20 @@ namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it /// is part of - a Function or an operation region. 
-class BlockList { +class Region { public: - explicit BlockList(Function *container); - explicit BlockList(Instruction *container); + explicit Region(Function *container); + explicit Region(Instruction *container); - using BlockListType = llvm::iplist; - BlockListType &getBlocks() { return blocks; } - const BlockListType &getBlocks() const { return blocks; } + using RegionType = llvm::iplist; + RegionType &getBlocks() { return blocks; } + const RegionType &getBlocks() const { return blocks; } // Iteration over the block in the function. - using iterator = BlockListType::iterator; - using const_iterator = BlockListType::const_iterator; - using reverse_iterator = BlockListType::reverse_iterator; - using const_reverse_iterator = BlockListType::const_reverse_iterator; + using iterator = RegionType::iterator; + using const_iterator = RegionType::const_iterator; + using reverse_iterator = RegionType::reverse_iterator; + using const_reverse_iterator = RegionType::const_reverse_iterator; iterator begin() { return blocks.begin(); } iterator end() { return blocks.end(); } @@ -410,40 +409,39 @@ public: void push_front(Block *block) { blocks.push_front(block); } Block &back() { return blocks.back(); } - const Block &back() const { return const_cast(this)->back(); } + const Block &back() const { return const_cast(this)->back(); } Block &front() { return blocks.front(); } - const Block &front() const { return const_cast(this)->front(); } + const Block &front() const { return const_cast(this)->front(); } - /// getSublistAccess() - Returns pointer to member of block list. - static BlockListType BlockList::*getSublistAccess(Block *) { - return &BlockList::blocks; + /// getSublistAccess() - Returns pointer to member of region. + static RegionType Region::*getSublistAccess(Block *) { + return &Region::blocks; } - /// A BlockList is part of a function or an operation region. If it is - /// part of an operation region, then return the operation, otherwise return - /// null. + /// A Region is either a function body or a part of an operation. If it is + /// part of an operation, then return the operation, otherwise return null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { - return const_cast(this)->getContainingInst(); + return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a function or an operation region. If it is part - /// of a Function, then return it, otherwise return null. + /// A Region is either a function body or a part of an operation. If it is + /// a Function body, then return this function, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { - return const_cast(this)->getContainingFunction(); + return const_cast(this)->getContainingFunction(); } - /// Clone the internal blocks from this block list into dest. Any + /// Clone the internal blocks from this region into dest. Any /// cloned blocks are appended to the back of dest. If the mapper /// contains entries for block arguments, these arguments are not included /// in the respective cloned block. - void cloneInto(BlockList *dest, BlockAndValueMapping &mapper, + void cloneInto(Region *dest, BlockAndValueMapping &mapper, MLIRContext *context) const; private: - BlockListType blocks; + RegionType blocks; /// This is the object we are part of. 
llvm::PointerUnion container; diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 4a86640a41c..221c5c4e975 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -80,37 +80,37 @@ public: // Body Handling //===--------------------------------------------------------------------===// - BlockList &getBlockList() { return blocks; } - const BlockList &getBlockList() const { return blocks; } + Region &getBody() { return body; } + const Region &getBody() const { return body; } /// This is the list of blocks in the function. - using BlockListType = llvm::iplist; - BlockListType &getBlocks() { return blocks.getBlocks(); } - const BlockListType &getBlocks() const { return blocks.getBlocks(); } + using RegionType = llvm::iplist; + RegionType &getBlocks() { return body.getBlocks(); } + const RegionType &getBlocks() const { return body.getBlocks(); } // Iteration over the block in the function. - using iterator = BlockListType::iterator; - using const_iterator = BlockListType::const_iterator; - using reverse_iterator = BlockListType::reverse_iterator; - using const_reverse_iterator = BlockListType::const_reverse_iterator; - - iterator begin() { return blocks.begin(); } - iterator end() { return blocks.end(); } - const_iterator begin() const { return blocks.begin(); } - const_iterator end() const { return blocks.end(); } - reverse_iterator rbegin() { return blocks.rbegin(); } - reverse_iterator rend() { return blocks.rend(); } - const_reverse_iterator rbegin() const { return blocks.rbegin(); } - const_reverse_iterator rend() const { return blocks.rend(); } - - bool empty() const { return blocks.empty(); } - void push_back(Block *block) { blocks.push_back(block); } - void push_front(Block *block) { blocks.push_front(block); } - - Block &back() { return blocks.back(); } + using iterator = RegionType::iterator; + using const_iterator = RegionType::const_iterator; + using reverse_iterator = RegionType::reverse_iterator; + using const_reverse_iterator = RegionType::const_reverse_iterator; + + iterator begin() { return body.begin(); } + iterator end() { return body.end(); } + const_iterator begin() const { return body.begin(); } + const_iterator end() const { return body.end(); } + reverse_iterator rbegin() { return body.rbegin(); } + reverse_iterator rend() { return body.rend(); } + const_reverse_iterator rbegin() const { return body.rbegin(); } + const_reverse_iterator rend() const { return body.rend(); } + + bool empty() const { return body.empty(); } + void push_back(Block *block) { body.push_back(block); } + void push_front(Block *block) { body.push_front(block); } + + Block &back() { return body.back(); } const Block &back() const { return const_cast(this)->back(); } - Block &front() { return blocks.front(); } + Block &front() { return body.front(); } const Block &front() const { return const_cast(this)->front(); } //===--------------------------------------------------------------------===// @@ -329,8 +329,8 @@ private: /// The attributes lists for each of the function arguments. std::vector argAttrs; - /// The contents of the body. - BlockList blocks; + /// The body of the function. 
+ Region body; void operator=(const Function &) = delete; friend struct llvm::ilist_traits; diff --git a/mlir/include/mlir/IR/FunctionGraphTraits.h b/mlir/include/mlir/IR/FunctionGraphTraits.h index 6ba50e7ca9e..b8a0d7ea633 100644 --- a/mlir/include/mlir/IR/FunctionGraphTraits.h +++ b/mlir/include/mlir/IR/FunctionGraphTraits.h @@ -153,8 +153,8 @@ struct GraphTraits> }; template <> -struct GraphTraits : public GraphTraits { - using GraphType = mlir::BlockList *; +struct GraphTraits : public GraphTraits { + using GraphType = mlir::Region *; using NodeRef = mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -169,9 +169,9 @@ struct GraphTraits : public GraphTraits { }; template <> -struct GraphTraits +struct GraphTraits : public GraphTraits { - using GraphType = const mlir::BlockList *; + using GraphType = const mlir::Region *; using NodeRef = const mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -186,9 +186,9 @@ struct GraphTraits }; template <> -struct GraphTraits> +struct GraphTraits> : public GraphTraits> { - using GraphType = Inverse; + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } @@ -203,9 +203,9 @@ struct GraphTraits> }; template <> -struct GraphTraits> +struct GraphTraits> : public GraphTraits> { - using GraphType = Inverse; + using GraphType = Inverse; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index f9a3ac0b4d5..2bcbd8ac121 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -49,15 +49,15 @@ using BlockOperand = IROperandImpl; class Instruction final : public llvm::ilist_node_with_parent, private llvm::TrailingObjects { + unsigned, Region, detail::OperandStorage> { public: /// Create a new Instruction with the specific fields. - static Instruction * - create(Location location, OperationName name, ArrayRef operands, - ArrayRef resultTypes, ArrayRef attributes, - ArrayRef successors, unsigned numBlockLists, - bool resizableOperandList, MLIRContext *context); + static Instruction *create(Location location, OperationName name, + ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes, + ArrayRef successors, unsigned numRegions, + bool resizableOperandList, MLIRContext *context); /// The name of an operation is the key identifier for it. OperationName getName() const { return name; } @@ -279,24 +279,24 @@ public: // Blocks //===--------------------------------------------------------------------===// - /// Returns the number of block lists held by this operation. - unsigned getNumBlockLists() const { return numBlockLists; } + /// Returns the number of regions held by this operation. + unsigned getNumRegions() const { return numRegions; } - /// Returns the block lists held by this operation. - MutableArrayRef getBlockLists() { - return {getTrailingObjects(), numBlockLists}; + /// Returns the regions held by this operation. + MutableArrayRef getRegions() { + return {getTrailingObjects(), numRegions}; } - ArrayRef getBlockLists() const { - return const_cast(this)->getBlockLists(); + ArrayRef getRegions() const { + return const_cast(this)->getRegions(); } - /// Returns the block list held by this operation at position 'index'. 
- BlockList &getBlockList(unsigned index) { - assert(index < numBlockLists && "invalid block list index"); - return getBlockLists()[index]; + /// Returns the region held by this operation at position 'index'. + Region &getRegion(unsigned index) { + assert(index < numRegions && "invalid region index"); + return getRegions()[index]; } - const BlockList &getBlockList(unsigned index) const { - return const_cast(this)->getBlockList(index); + const Region &getRegion(unsigned index) const { + return const_cast(this)->getRegion(index); } //===--------------------------------------------------------------------===// @@ -528,7 +528,7 @@ public: protected: Instruction(Location location, OperationName name, unsigned numResults, - unsigned numSuccessors, unsigned numBlockLists, + unsigned numSuccessors, unsigned numRegions, ArrayRef attributes, MLIRContext *context); // Instructions are deleted through the destroy() member because they are @@ -558,7 +558,7 @@ private: /// O(1) local dominance checks between instructions. mutable unsigned orderIndex = 0; - const unsigned numResults, numSuccs, numBlockLists; + const unsigned numResults, numSuccs, numRegions; /// This holds the name of the operation. OperationName name; @@ -577,16 +577,14 @@ private: // This stuff is used by the TrailingObjects template. friend llvm::TrailingObjects; + Region, detail::OperandStorage>; size_t numTrailingObjects(OverloadToken) const { return numResults; } size_t numTrailingObjects(OverloadToken) const { return numSuccs; } - size_t numTrailingObjects(OverloadToken) const { - return numBlockLists; - } + size_t numTrailingObjects(OverloadToken) const { return numRegions; } size_t numTrailingObjects(OverloadToken) const { return numSuccs; } }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index c5078327700..fcdf2b6b61a 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,9 +89,9 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const Instruction *op) = 0; - /// Prints a block list. - virtual void printBlockList(const BlockList &blocks, - bool printEntryBlockArgs = true) = 0; + /// Prints a region. + virtual void printRegion(const Region &blocks, + bool printEntryBlockArgs = true) = 0; private: OpAsmPrinter(const OpAsmPrinter &) = delete; @@ -314,13 +314,12 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; - /// Parses a block list. Any parsed blocks are filled in to the - /// operation's block lists after the operation is created. - virtual bool parseBlockList() = 0; + /// Parses a region. Any parsed blocks are filled in to the operation's + /// regions after the operation is created. + virtual bool parseRegion() = 0; - /// Parses an argument for the entry block of the next block list to be - /// parsed. - virtual bool parseBlockListEntryBlockArgument(Type argType) = 0; + /// Parses an argument for the entry block of the next region to be parsed. 
+ virtual bool parseRegionEntryBlockArgument(Type argType) = 0; //===--------------------------------------------------------------------===// // Methods for interacting with the parser diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index a3e2911a3d1..5c45bd650f5 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -229,7 +229,7 @@ struct OperationState { SmallVector attributes; /// Successors of this operation and their respective operands. SmallVector successors; - unsigned numBlockLists = 0; + unsigned numRegions = 0; /// If the operation has a resizable operand list. bool resizableOperandList = false; @@ -243,14 +243,14 @@ public: OperationState(MLIRContext *context, Location location, StringRef name, ArrayRef operands, ArrayRef types, ArrayRef attributes, - ArrayRef successors = {}, unsigned numBlockLists = 0, + ArrayRef successors = {}, unsigned numRegions = 0, bool resizableOperandList = false) : context(context), location(location), name(name, context), operands(operands.begin(), operands.end()), types(types.begin(), types.end()), attributes(attributes.begin(), attributes.end()), successors(successors.begin(), successors.end()), - numBlockLists(numBlockLists) {} + numRegions(numRegions) {} void addOperands(ArrayRef newOperands) { assert(successors.empty() && @@ -279,8 +279,8 @@ public: operands.append(succOperands.begin(), succOperands.end()); } - /// Add a new block list with the specified blocks. - void reserveBlockLists(unsigned numReserved) { numBlockLists += numReserved; } + /// Reserve space for new regions. + void reserveRegions(unsigned numReserved) { numRegions += numReserved; } /// Sets the operand list of the operation as resizable. void setOperandListToResizable(bool isResizable = true) { diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index a9de42eb732..16fddb45496 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -553,8 +553,8 @@ void AffineForOp::build(Builder *builder, OperationState *result, builder->getAffineMapAttr(ubMap)); result->addOperands(ubOperands); - // Reserve a block list for the body. - result->reserveBlockLists(/*numReserved=*/1); + // Reserve a region for the body. + result->reserveRegions(/*numReserved=*/1); // Set the operands list as resizable so that we can freely modify the bounds. result->setOperandListToResizable(); @@ -568,12 +568,11 @@ void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, } bool AffineForOp::verify() const { - const auto &bodyBlockList = getInstruction()->getBlockList(0); + const auto &bodyRegion = getInstruction()->getRegion(0); - // The body block list must contain a single basic block. - if (bodyBlockList.empty() || - std::next(bodyBlockList.begin()) != bodyBlockList.end()) - return emitOpError("expected body block list to have a single block"); + // The body region must contain a single basic block. + if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end()) + return emitOpError("expected body region to have a single block"); // Check that the body defines as single block argument for the induction // variable. @@ -701,7 +700,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); // Parse the induction variable followed by '='. 
- if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) || + if (parser->parseRegionEntryBlockArgument(builder.getIndexType()) || parser->parseEqual()) return true; @@ -730,9 +729,9 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { "expected step to be representable as a positive signed integer"); } - // Parse the body block list. - result->reserveBlockLists(/*numReserved=*/1); - if (parser->parseBlockList()) + // Parse the body region. + result->reserveRegions(/*numReserved=*/1); + if (parser->parseRegion()) return true; // Parse the optional attribute list. @@ -793,8 +792,8 @@ void AffineForOp::print(OpAsmPrinter *p) const { if (getStep() != 1) *p << " step " << getStep(); - p->printBlockList(getInstruction()->getBlockList(0), - /*printEntryBlockArgs=*/false); + p->printRegion(getInstruction()->getRegion(0), + /*printEntryBlockArgs=*/false); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getLowerBoundAttrName(), getUpperBoundAttrName(), @@ -872,14 +871,14 @@ void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, } Block *AffineForOp::createBody() { - auto &bodyBlockList = getBlockList(); - assert(bodyBlockList.empty() && "expected no existing body blocks"); + auto &bodyRegion = getRegion(); + assert(bodyRegion.empty() && "expected no existing body blocks"); // Create a new block for the body, and add an argument for the induction // variable. Block *body = new Block(); body->addArgument(IndexType::get(getInstruction()->getContext())); - bodyBlockList.push_back(body); + bodyRegion.push_back(body); return body; } @@ -1040,8 +1039,8 @@ void AffineIfOp::build(Builder *builder, OperationState *result, result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); result->addOperands(conditionOperands); - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); + // Reserve 2 regions, one for the 'then' and one for the 'else' regions. + result->reserveRegions(2); } bool AffineIfOp::verify() const { @@ -1061,20 +1060,19 @@ bool AffineIfOp::verify() const { condition.getNumDims())) return true; - // Verify that the entry of each child blocklist does not have arguments. - for (const auto &blockList : getInstruction()->getBlockLists()) { - if (blockList.empty()) + // Verify that the entry of each child region does not have arguments. + for (const auto ®ion : getInstruction()->getRegions()) { + if (region.empty()) continue; // TODO(riverriddle) We currently do not allow multiple blocks in child - // block lists. - if (std::next(blockList.begin()) != blockList.end()) - return emitOpError( - "expects only one block per 'then' or 'else' block list"); - if (blockList.front().back().isKnownTerminator()) + // regions. + if (std::next(region.begin()) != region.end()) + return emitOpError("expects only one block per 'then' or 'else' regions"); + if (region.front().back().isKnownTerminator()) return emitOpError("expects region block to not have a terminator"); - for (const auto &b : blockList) + for (const auto &b : region) if (b.getNumArguments() != 0) return emitOpError( "requires that child entry blocks have no arguments"); @@ -1102,13 +1100,13 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { parser->getNameLoc(), "symbol operand count and integer set symbol count must match"); - // Parse the 'then' block list. - if (parser->parseBlockList()) + // Parse the 'then' region. 
+ if (parser->parseRegion()) return true; - // If we find an 'else' keyword then parse the else block list. + // If we find an 'else' keyword then parse the 'else' region. if (!parser->parseOptionalKeyword("else")) { - if (parser->parseBlockList()) + if (parser->parseRegion()) return true; } @@ -1116,8 +1114,8 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->parseOptionalAttributeDict(result->attributes)) return true; - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); + // Reserve 2 regions, one for the 'then' and one for the 'else' regions. + result->reserveRegions(2); return false; } @@ -1126,13 +1124,13 @@ void AffineIfOp::print(OpAsmPrinter *p) const { *p << "if " << conditionAttr; printDimAndSymbolList(operand_begin(), operand_end(), conditionAttr.getValue().getNumDims(), p); - p->printBlockList(getInstruction()->getBlockList(0)); + p->printRegion(getInstruction()->getRegion(0)); - // Print the 'else' block list if it has any blocks. - const auto &elseBlockList = getInstruction()->getBlockList(1); - if (!elseBlockList.empty()) { + // Print the 'else' regions if it has any blocks. + const auto &elseRegion = getInstruction()->getRegion(1); + if (!elseRegion.empty()) { *p << " else"; - p->printBlockList(elseBlockList); + p->printRegion(elseRegion); } // Print the attribute list. @@ -1148,11 +1146,7 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) { } /// Returns the list of 'then' blocks. -BlockList &AffineIfOp::getThenBlocks() { - return getInstruction()->getBlockList(0); -} +Region &AffineIfOp::getThenBlocks() { return getInstruction()->getRegion(0); } /// Returns the list of 'else' blocks. -BlockList &AffineIfOp::getElseBlocks() { - return getInstruction()->getBlockList(1); -} +Region &AffineIfOp::getElseBlocks() { return getInstruction()->getRegion(1); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index b3c9d822f25..a6f6845e9ef 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -41,19 +41,19 @@ void DominanceInfoBase::recalculate(Function *function) { // Build the top level function dominance. auto functionDominance = std::make_unique(); - functionDominance->recalculate(function->getBlockList()); - dominanceInfos.try_emplace(&function->getBlockList(), + functionDominance->recalculate(function->getBody()); + dominanceInfos.try_emplace(&function->getBody(), std::move(functionDominance)); - /// Build the dominance for each of the internal region block lists. + /// Build the dominance for each of the operation regions. function->walk([&](Instruction *inst) { - for (auto &blockList : inst->getBlockLists()) { + for (auto ®ion : inst->getRegions()) { // Don't compute dominance if the region is empty. 
- if (blockList.empty()) + if (region.empty()) continue; auto opDominance = std::make_unique(); - opDominance->recalculate(blockList); - dominanceInfos.try_emplace(&blockList, std::move(opDominance)); + opDominance->recalculate(region); + dominanceInfos.try_emplace(®ion, std::move(opDominance)); } }); } @@ -66,21 +66,21 @@ bool DominanceInfoBase::properlyDominates(const Block *a, if (a == b) return false; - // If both blocks are not in the same block list, 'a' properly dominates 'b' - // if 'b' is defined in an instruction region that (recursively) ends up being + // If both blocks are not in the same region, 'a' properly dominates 'b' if + // 'b' is defined in an instruction region that (recursively) ends up being // dominated by 'a'. Walk up the list of containers enclosing B. - auto *blockListA = a->getParent(), *blockListB = b->getParent(); - if (blockListA != blockListB) { + auto *regionA = a->getParent(), *regionB = b->getParent(); + if (regionA != regionB) { Instruction *bAncestor; do { - bAncestor = blockListB->getContainingInst(); + bAncestor = regionB->getContainingInst(); // If 'bAncestor' is the top level function, then 'a' is a block // that post dominates 'b'. if (!bAncestor) return IsPostDom; - blockListB = bAncestor->getBlock()->getParent(); - } while (blockListA != blockListB); + regionB = bAncestor->getBlock()->getParent(); + } while (regionA != regionB); // Check to see if the ancestor of 'b' is the same block as 'a'. b = bAncestor->getBlock(); @@ -89,8 +89,8 @@ bool DominanceInfoBase::properlyDominates(const Block *a, } // Otherwise, use the standard dominance functionality. - auto baseInfoIt = dominanceInfos.find(blockListA); - assert(baseInfoIt != dominanceInfos.end() && "block list info not found"); + auto baseInfoIt = dominanceInfos.find(regionA); + assert(baseInfoIt != dominanceInfos.end() && "region info not found"); return baseInfoIt->second->properlyDominates(a, b); } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 727ce18c2e4..0df51ebd37b 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -303,7 +303,7 @@ static bool isVectorizableLoopWithCond(ConstOpPointer loop, // No vectorization across unknown regions. auto regions = matcher::Op([](const Instruction &inst) -> bool { - return inst.getNumBlockLists() != 0 && + return inst.getNumRegions() != 0 && !(inst.isa() || inst.isa()); }); SmallVector regionsMatched; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e330a01c6e8..3399c7da88f 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -462,8 +462,8 @@ static Instruction *getInstAtPosition(ArrayRef positions, return getInstAtPosition(positions, level + 1, childAffineForOp->getBody()); - for (auto &blockList : inst.getBlockLists()) { - for (auto &b : blockList) + for (auto ®ion : inst.getRegions()) { + for (auto &b : region) if (auto *ret = getInstAtPosition(positions, level + 1, &b)) return ret; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 0f9365aad61..cab6758c51a 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -306,8 +306,8 @@ bool FuncVerifier::verifyOperation(const Instruction &op) { } // Verify that all child blocks are ok. 
- for (auto &blockList : op.getBlockLists()) - for (auto &b : blockList) + for (auto ®ion : op.getRegions()) + for (auto &b : region) if (verifyBlock(b, /*isTopLevel=*/false)) return true; @@ -338,8 +338,8 @@ bool FuncVerifier::verifyInstDominance(const Instruction &inst) { } // Verify the dominance of each of the nested blocks within this instruction. - for (auto &blockList : inst.getBlockLists()) - for (auto &block : blockList) + for (auto ®ion : inst.getRegions()) + for (auto &block : region) if (verifyDominance(block)) return true; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b88de85482b..4454f69b7fc 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1103,9 +1103,8 @@ public: void printSuccessorAndUseList(const Instruction *term, unsigned index) override; - /// Print a block list. - void printBlockList(const BlockList &blocks, - bool printEntryBlockArgs) override { + /// Print a region. + void printRegion(const Region &blocks, bool printEntryBlockArgs) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); @@ -1164,7 +1163,9 @@ FunctionPrinter::FunctionPrinter(const Function *function, numberValuesInBlock(block); } -/// Number all of the SSA values in the specified block list. +/// Number all of the SSA values in the specified block. Values get numbered +/// continuously throughout regions. In particular, we traverse the regions +/// held by operations and number values in depth-first pre-order. void FunctionPrinter::numberValuesInBlock(const Block &block) { // Each block gets a unique ID, and all of the instructions within it get // numbered as well. @@ -1178,8 +1179,8 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // result. if (inst.getNumResults() != 0) numberValueID(inst.getResult(0)); - for (auto &blockList : inst.getBlockLists()) - for (const auto &block : blockList) + for (auto ®ion : inst.getRegions()) + for (const auto &block : region) numberValuesInBlock(block); } } @@ -1219,9 +1220,9 @@ void FunctionPrinter::numberValueID(const Value *value) { // argument is to an entry block of an operation region, give it an 'i' // name. if (auto *block = cast(value)->getOwner()) { - auto *parentBlockList = block->getParent(); - if (parentBlockList && block == &parentBlockList->front()) { - if (parentBlockList->getContainingFunction()) + auto *parentRegion = block->getParent(); + if (parentRegion && block == &parentRegion->front()) { + if (parentRegion->getContainingFunction()) specialName << "arg" << nextArgumentID++; else specialName << "i" << nextRegionArgumentID++; @@ -1279,7 +1280,7 @@ void FunctionPrinter::print() { printTrailingLocation(function->getLoc()); if (!function->empty()) { - printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false); + printRegion(function->getBody(), /*printEntryBlockArgs=*/false); os << "\n"; } os << '\n'; @@ -1496,9 +1497,9 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { os << ')'; } - // Print any trailing block lists. - for (auto &blockList : op->getBlockLists()) - printBlockList(blockList, /*printEntryBlockArgs=*/true); + // Print any trailing regions. 
+ for (auto ®ion : op->getRegions()) + printRegion(region, /*printEntryBlockArgs=*/true); } void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 60c382fbd84..eddf1241c9e 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -1,4 +1,4 @@ -//===- Block.cpp - MLIR Block and BlockList Classes -----------------------===// +//===- Block.cpp - MLIR Block and Region Classes --------------------------===// // // Copyright 2019 The MLIR Authors. // @@ -66,7 +66,7 @@ Function *Block::getFunction() { void Block::insertBefore(Block *block) { assert(!getParent() && "already inserted into a block!"); assert(block->getParent() && "cannot insert before a block without a parent"); - block->getParent()->getBlocks().insert(BlockList::iterator(block), this); + block->getParent()->getBlocks().insert(Region::iterator(block), this); } /// Unlink this Block from its Function and delete it. @@ -262,26 +262,26 @@ Block *Block::splitBlock(iterator splitBefore) { } //===----------------------------------------------------------------------===// -// BlockList +// Region //===----------------------------------------------------------------------===// -BlockList::BlockList(Function *container) : container(container) {} +Region::Region(Function *container) : container(container) {} -BlockList::BlockList(Instruction *container) : container(container) {} +Region::Region(Instruction *container) : container(container) {} -Instruction *BlockList::getContainingInst() { +Instruction *Region::getContainingInst() { return container.dyn_cast(); } -Function *BlockList::getContainingFunction() { +Function *Region::getContainingFunction() { return container.dyn_cast(); } -/// Clone the internal blocks from this block list into dest. Any +/// Clone the internal blocks from this region into `dest`. Any /// cloned blocks are appended to the back of dest. -void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper, - MLIRContext *context) const { - assert(dest && "expected valid block list to clone into"); +void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper, + MLIRContext *context) const { + assert(dest && "expected valid region to clone into"); // If the list is empty there is nothing to clone. if (empty()) @@ -321,25 +321,24 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper, it->walk(remapOperands); } -BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() { +Region *llvm::ilist_traits<::mlir::Block>::getContainingRegion() { size_t Offset( - size_t(&((BlockList *)nullptr->*BlockList::getSublistAccess(nullptr)))); + size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr)))); iplist *Anchor(static_cast *>(this)); - return reinterpret_cast(reinterpret_cast(Anchor) - - Offset); + return reinterpret_cast(reinterpret_cast(Anchor) - Offset); } -/// This is a trait method invoked when a basic block is added to a function. -/// We keep the function pointer up to date. +/// This is a trait method invoked when a basic block is added to a region. +/// We keep the region pointer up to date. 
void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) { - assert(!block->getParent() && "already in a function!"); - block->parentValidInstOrderPair.setPointer(getContainingBlockList()); + assert(!block->getParent() && "already in a region!"); + block->parentValidInstOrderPair.setPointer(getContainingRegion()); } /// This is a trait method invoked when an instruction is removed from a -/// function. We keep the function pointer up to date. +/// region. We keep the region pointer up to date. void llvm::ilist_traits<::mlir::Block>::removeNodeFromList(Block *block) { - assert(block->getParent() && "not already in a function!"); + assert(block->getParent() && "not already in a region!"); block->parentValidInstOrderPair.setPointer(nullptr); } @@ -349,8 +348,8 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( ilist_traits &otherList, block_iterator first, block_iterator last) { // If we are transferring instructions within the same function, the parent // pointer doesn't need to be updated. - auto *curParent = getContainingBlockList(); - if (curParent == otherList.getContainingBlockList()) + auto *curParent = getContainingRegion(); + if (curParent == otherList.getContainingRegion()) return; // Update the 'parent' member of each Block. diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 1100a2b0fba..56d0ad059fa 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -311,10 +311,9 @@ Block *FuncBuilder::createBlock(Block *insertBefore) { /// Create an operation given the fields represented as an OperationState. Instruction *FuncBuilder::createOperation(const OperationState &state) { assert(block && "createOperation() called without setting builder's block"); - auto *op = Instruction::create(state.location, state.name, state.operands, - state.types, state.attributes, - state.successors, state.numBlockLists, - state.resizableOperandList, context); + auto *op = Instruction::create( + state.location, state.name, state.operands, state.types, state.attributes, + state.successors, state.numRegions, state.resizableOperandList, context); block->getInstructions().insert(insertPoint, op); return op; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 8ad81397efe..107f1b5888e 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -30,14 +30,14 @@ Function::Function(Location location, StringRef name, FunctionType type, ArrayRef attrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(type.getContext(), attrs), - argAttrs(type.getNumInputs()), blocks(this) {} + argAttrs(type.getNumInputs()), body(this) {} Function::Function(Location location, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(type.getContext(), attrs), argAttrs(argAttrs), - blocks(this) {} + body(this) {} Function::~Function() { // Instructions may have cyclic references, which need to be dropped before we @@ -160,8 +160,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) const { } dest->setAttrs(newAttrs.takeVector()); - // Clone the block list. - blocks.cloneInto(&dest->blocks, mapper, dest->getContext()); + // Clone the body. 
+ body.cloneInto(&dest->body, mapper, dest->getContext()); } /// Create a deep copy of this function and all of its blocks, remapping diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 36a0449ef5e..b38ea6a0fb3 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -134,12 +134,13 @@ void detail::OperandStorage::grow(ResizableStorage &resizeUtil, //===----------------------------------------------------------------------===// /// Create a new Instruction with the specific fields. -Instruction * -Instruction::create(Location location, OperationName name, - ArrayRef operands, ArrayRef resultTypes, - ArrayRef attributes, - ArrayRef successors, unsigned numBlockLists, - bool resizableOperandList, MLIRContext *context) { +Instruction *Instruction::create(Location location, OperationName name, + ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes, + ArrayRef successors, + unsigned numRegions, bool resizableOperandList, + MLIRContext *context) { unsigned numSuccessors = successors.size(); // Input operands are nullptr-separated for each successor, the null operands @@ -147,9 +148,9 @@ Instruction::create(Location location, OperationName name, unsigned numOperands = operands.size() - numSuccessors; // Compute the byte size for the instruction and the operand storage. - auto byteSize = totalSizeToAlloc( - resultTypes.size(), numSuccessors, numSuccessors, numBlockLists, + auto byteSize = totalSizeToAlloc( + resultTypes.size(), numSuccessors, numSuccessors, numRegions, /*detail::OperandStorage*/ 1); byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( numOperands, resizableOperandList), @@ -158,15 +159,15 @@ Instruction::create(Location location, OperationName name, // Create the new Instruction. auto inst = ::new (rawMem) - Instruction(location, name, resultTypes.size(), numSuccessors, - numBlockLists, attributes, context); + Instruction(location, name, resultTypes.size(), numSuccessors, numRegions, + attributes, context); assert((numSuccessors == 0 || !inst->isKnownNonTerminator()) && "unexpected successors in a non-terminator operation"); - // Initialize the block lists. - for (unsigned i = 0; i != numBlockLists; ++i) - new (&inst->getBlockList(i)) BlockList(inst); + // Initialize the regions. + for (unsigned i = 0; i != numRegions; ++i) + new (&inst->getRegion(i)) Region(inst); // Initialize the results and operands. new (&inst->getOperandStorage()) @@ -238,11 +239,11 @@ Instruction::create(Location location, OperationName name, Instruction::Instruction(Location location, OperationName name, unsigned numResults, unsigned numSuccessors, - unsigned numBlockLists, + unsigned numRegions, ArrayRef attributes, MLIRContext *context) : location(location), numResults(numResults), numSuccs(numSuccessors), - numBlockLists(numBlockLists), name(name), attrs(context, attributes) {} + numRegions(numRegions), name(name), attrs(context, attributes) {} // Instructions are deleted through the destroy() member because they are // allocated via malloc. @@ -259,9 +260,9 @@ Instruction::~Instruction() { for (auto &successor : getBlockOperands()) successor.~BlockOperand(); - // Explicitly destroy the block list. - for (auto &blockList : getBlockLists()) - blockList.~BlockList(); + // Explicitly destroy the regions. + for (auto ®ion : getRegions()) + region.~Region(); } /// Destroy this instruction or one of its subclasses. 
@@ -301,16 +302,16 @@ void Instruction::walk(const std::function &callback) { callback(this); // Visit any internal instructions. - for (auto &blockList : getBlockLists()) - for (auto &block : blockList) + for (auto ®ion : getRegions()) + for (auto &block : region) block.walk(callback); } void Instruction::walkPostOrder( const std::function &callback) { // Visit any internal instructions. - for (auto &blockList : llvm::reverse(getBlockLists())) - for (auto &block : llvm::reverse(blockList)) + for (auto ®ion : llvm::reverse(getRegions())) + for (auto &block : llvm::reverse(region)) block.walkPostOrder(callback); // Visit the current instruction. @@ -465,8 +466,8 @@ void Instruction::dropAllReferences() { for (auto &op : getInstOperands()) op.drop(); - for (auto &blockList : getBlockLists()) - for (Block &block : blockList) + for (auto ®ion : getRegions()) + for (Block &block : region) block.dropAllReferences(); for (auto &dest : getBlockOperands()) @@ -603,14 +604,14 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *result : getResults()) resultTypes.push_back(result->getType()); - unsigned numBlockLists = getNumBlockLists(); + unsigned numRegions = getNumRegions(); auto *newOp = Instruction::create(getLoc(), getName(), operands, resultTypes, - getAttrs(), successors, numBlockLists, + getAttrs(), successors, numRegions, hasResizableOperandsList(), context); - // Clone the block lists. - for (unsigned i = 0; i != numBlockLists; ++i) - getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context); + // Clone the regions. + for (unsigned i = 0; i != numRegions; ++i) + getRegion(i).cloneInto(&newOp->getRegion(i), mapper, context); // Remember the mapping of any results. for (unsigned i = 0, e = getNumResults(); i != e; ++i) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7d977ae953d..7b5050bd839 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2196,9 +2196,9 @@ public: // Block references. ParseResult - parseOperationBlockList(SmallVectorImpl &results, - ArrayRef> entryArguments); - ParseResult parseBlockListBody(SmallVectorImpl &results); + parseOperationRegion(SmallVectorImpl &results, + ArrayRef> entryArguments); + ParseResult parseRegionBody(SmallVectorImpl &results); ParseResult parseBlock(Block *&block); ParseResult parseBlockBody(Block *block); @@ -2279,7 +2279,7 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) { // Parse the remaining list of blocks. SmallVector blocks; - if (parseBlockListBody(blocks)) + if (parseRegionBody(blocks)) return ParseFailure; function->getBlocks().insert(function->end(), blocks.begin(), blocks.end()); @@ -2307,14 +2307,14 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) { /// /// block-list ::= '{' block-list-body /// -ParseResult FunctionParser::parseOperationBlockList( +ParseResult FunctionParser::parseOperationRegion( SmallVectorImpl &results, ArrayRef> entryArguments) { // Parse the '{'. - if (parseToken(Token::l_brace, "expected '{' to begin block list")) + if (parseToken(Token::l_brace, "expected '{' to begin a region")) return ParseFailure; - // Check for an empty block list. + // Check for an empty region. if (entryArguments.empty() && consumeIf(Token::r_brace)) return ParseSuccess; Block *currentBlock = builder.getInsertionBlock(); @@ -2342,9 +2342,9 @@ ParseResult FunctionParser::parseOperationBlockList( return emitError("entry block arguments were already defined"); } - // Parse the rest of the block list. 
+ // Parse the rest of the region. results.push_back(block); - if (parseBlockListBody(results)) + if (parseRegionBody(results)) return ParseFailure; // Reset insertion point to the current block. @@ -2352,13 +2352,12 @@ ParseResult FunctionParser::parseOperationBlockList( return ParseSuccess; } -/// Block list. +/// Region. /// -/// block-list-body ::= block* '}' +/// region-body ::= block* '}' /// -ParseResult -FunctionParser::parseBlockListBody(SmallVectorImpl &results) { - // Parse the block list. +ParseResult FunctionParser::parseRegionBody(SmallVectorImpl &results) { + // Parse the list of blocks. while (!consumeIf(Token::r_brace)) { Block *newBlock = nullptr; if (parseBlock(newBlock)) { @@ -2440,7 +2439,7 @@ Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) { auto name = OperationName("placeholder", getContext()); auto *inst = Instruction::create( getEncodedSourceLocation(loc), name, /*operands=*/{}, type, - /*attributes=*/{}, /*successors=*/{}, /*numBlockLists=*/0, + /*attributes=*/{}, /*successors=*/{}, /*numRegions=*/0, /*resizableOperandList=*/false, getContext()); forwardReferencePlaceholders[inst->getResult(0)] = loc; return inst->getResult(0); @@ -2888,25 +2887,25 @@ Instruction *FunctionParser::parseGenericOperation() { result.addSuccessor(successor, operands); } - // Parse the optional block lists for this operation. + // Parse the optional regions for this operation. std::vector> blocks; while (getToken().is(Token::l_brace)) { SmallVector newBlocks; - if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) { - for (auto &blockList : blocks) - cleanupInvalidBlocks(blockList); + if (parseOperationRegion(newBlocks, /*entryArguments=*/llvm::None)) { + for (auto ®ion : blocks) + cleanupInvalidBlocks(region); return nullptr; } blocks.emplace_back(newBlocks); } - result.reserveBlockLists(blocks.size()); + result.reserveRegions(blocks.size()); auto *opInst = builder.createOperation(result); - // Initialize the parsed block lists. + // Initialize the parsed regions. for (unsigned i = 0, e = blocks.size(); i != e; ++i) { - auto &blockList = opInst->getBlockList(i).getBlocks(); - blockList.insert(blockList.end(), blocks[i].begin(), blocks[i].end()); + auto ®ion = opInst->getRegion(i).getBlocks(); + region.insert(region.end(), blocks[i].begin(), blocks[i].end()); } return opInst; } @@ -2922,23 +2921,22 @@ public: if (opDefinition->parseAssembly(this, opState)) return true; - // Check that enough block lists were reserved for those that were parsed. - if (parsedBlockLists.size() > opState->numBlockLists) { + // Check that enough regions were reserved for those that were parsed. + if (parsedRegions.size() > opState->numRegions) { return emitError( nameLoc, - "parsed more block lists than those reserved in the operation state"); + "parsed more regions than those reserved in the operation state"); } // Check there were no dangling entry block arguments. - if (!parsedBlockListEntryArguments.empty()) { + if (!parsedRegionEntryArguments.empty()) { return emitError( - nameLoc, - "no block list was attached to parsed entry block arguments"); + nameLoc, "no region was attached to parsed entry block arguments"); } // Check that none of the operands of the current operation reference an - // entry block argument for any of the block lists. - for (auto *entryArg : parsedBlockListEntryArgumentPlaceholders) + // entry block argument for any of the region. 
+ for (auto *entryArg : parsedRegionEntryArgumentPlaceholders) if (llvm::is_contained(opState->operands, entryArg)) return emitError(nameLoc, "operand use before it's defined"); @@ -3144,21 +3142,19 @@ public: return result == nullptr; } - /// Parses a list of blocks. - bool parseBlockList() override { - // Parse the block list. + /// Parses a region. + bool parseRegion() override { SmallVector results; - if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments)) + if (parser.parseOperationRegion(results, parsedRegionEntryArguments)) return true; - parsedBlockListEntryArguments.clear(); - parsedBlockLists.emplace_back(results); + parsedRegionEntryArguments.clear(); + parsedRegions.emplace_back(results); return false; } - /// Parses an argument for the entry block of the next block list to be - /// parsed. - bool parseBlockListEntryBlockArgument(Type argType) override { + /// Parses an argument for the entry block of the next region to be parsed. + bool parseRegionEntryBlockArgument(Type argType) override { SmallVector argValues; OperandType operand; if (parseOperand(operand)) @@ -3168,10 +3164,10 @@ public: FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; if (auto *value = parser.resolveSSAUse(operandInfo, argType)) { - parsedBlockListEntryArguments.emplace_back(operandInfo, argType); + parsedRegionEntryArguments.emplace_back(operandInfo, argType); // Track each of the placeholders so that we can detect invalid references - // to block list arguments. - parsedBlockListEntryArgumentPlaceholders.emplace_back(value); + // to region arguments. + parsedRegionEntryArgumentPlaceholders.emplace_back(value); return false; } @@ -3199,10 +3195,10 @@ public: /// Emit a diagnostic at the specified location and return true. bool emitError(llvm::SMLoc loc, const Twine &message) override { - // If we emit an error, then cleanup any parsed block lists. - for (auto &blockList : parsedBlockLists) - parser.cleanupInvalidBlocks(blockList); - parsedBlockLists.clear(); + // If we emit an error, then cleanup any parsed regions. + for (auto ®ion : parsedRegions) + parser.cleanupInvalidBlocks(region); + parsedRegions.clear(); parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; @@ -3211,16 +3207,16 @@ public: bool didEmitError() const { return emittedError; } - /// Returns the block lists that were parsed. - MutableArrayRef> getParsedBlockLists() { - return parsedBlockLists; + /// Returns the regions that were parsed. + MutableArrayRef> getParsedRegions() { + return parsedRegions; } private: - std::vector> parsedBlockLists; + std::vector> parsedRegions; SmallVector, 2> - parsedBlockListEntryArguments; - SmallVector parsedBlockListEntryArgumentPlaceholders; + parsedRegionEntryArguments; + SmallVector parsedRegionEntryArgumentPlaceholders; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3271,12 +3267,12 @@ Instruction *FunctionParser::parseCustomOperation() { // Otherwise, we succeeded. Use the state it parsed as our op information. auto *opInst = builder.createOperation(opState); - // Resolve any parsed block lists. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { - auto &opBlockList = opInst->getBlockList(i).getBlocks(); - opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), - parsedBlockLists[i].end()); + // Resolve any parsed regions. 
+ auto parsedRegions = opAsmParser.getParsedRegions(); + for (unsigned i = 0, e = parsedRegions.size(); i != e; ++i) { + auto &opRegion = opInst->getRegion(i).getBlocks(); + opRegion.insert(opRegion.end(), parsedRegions[i].begin(), + parsedRegions[i].end()); } return opInst; } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 37dfce17bda..02fbda0f680 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -109,7 +109,7 @@ struct CSE : public FunctionPass { bool simplifyOperation(Instruction *op); void simplifyBlock(DominanceInfo &domInfo, Block *bb); - void simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList); + void simplifyRegion(DominanceInfo &domInfo, Region ®ion); void runOnFunction() override; @@ -127,7 +127,7 @@ bool CSE::simplifyOperation(Instruction *op) { // Don't simplify operations with nested blocks. We don't currently model // equality comparisons correctly among other things. It is also unclear // whether we would want to CSE such operations. - if (op->getNumBlockLists() != 0) + if (op->getNumRegions() != 0) return false; // TODO(riverriddle) We currently only eliminate non side-effecting @@ -166,25 +166,25 @@ bool CSE::simplifyOperation(Instruction *op) { void CSE::simplifyBlock(DominanceInfo &domInfo, Block *bb) { for (auto &i : *bb) { - // If the operation is simplified, we don't process any held block lists. + // If the operation is simplified, we don't process any held regions. if (simplifyOperation(&i)) continue; // Simplify any held blocks. - for (auto &blockList : i.getBlockLists()) - simplifyBlockList(domInfo, blockList); + for (auto ®ion : i.getRegions()) + simplifyRegion(domInfo, region); } } -void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { - // If the block list is empty there is nothing to do. - if (blockList.empty()) +void CSE::simplifyRegion(DominanceInfo &domInfo, Region ®ion) { + // If the region is empty there is nothing to do. + if (region.empty()) return; - // If the block list only contains one block, then simplify it directly. - if (std::next(blockList.begin()) == blockList.end()) { + // If the region only contains one block, then simplify it directly. + if (std::next(region.begin()) == region.end()) { ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(domInfo, &blockList.front()); + simplifyBlock(domInfo, ®ion.front()); return; } @@ -196,9 +196,9 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html std::deque> stack; - // Process the nodes of the dom tree for this blocklist. + // Process the nodes of the dom tree for this region. stack.emplace_back(std::make_unique( - knownValues, domInfo.getRootNode(&blockList))); + knownValues, domInfo.getRootNode(®ion))); while (!stack.empty()) { auto ¤tNode = stack.back(); @@ -223,8 +223,7 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { } void CSE::runOnFunction() { - simplifyBlockList(getAnalysis(), - getFunction()->getBlockList()); + simplifyRegion(getAnalysis(), getFunction()->getBody()); // If no operations were erased, then we mark all analyses as preserved. 
if (opsToErase.empty()) { diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a621db4e183..72a378b7d9a 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -172,7 +172,7 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder, // Iterate over ops and convert them. for (Instruction &inst : *block) { - if (inst.getNumBlockLists() != 0) { + if (inst.getNumRegions() != 0) { inst.emitError("unsupported region instruction"); return failure(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index e05af79caa1..12e52eaf4d3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -130,7 +130,7 @@ struct LoopNestStateCollector { instToWalk->walk([&](Instruction *opInst) { if (opInst->isa()) forOps.push_back(opInst->cast()); - else if (opInst->getNumBlockLists() != 0) + else if (opInst->getNumRegions() != 0) hasNonForRegion = true; else if (opInst->isa()) loadOpInsts.push_back(opInst); @@ -670,7 +670,7 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = inst.cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (inst.getNumBlockLists() != 0) { + } else if (inst.getNumRegions() != 0) { // Return false if another region is found (not currently supported). return false; } else if (inst.getNumResults() > 0 && !inst.use_empty()) { diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 4682923dc7b..b1a306ddd7f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -109,8 +109,8 @@ void LoopUnroll::runOnFunction() { } bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; - for (auto &blockList : opInst->getBlockLists()) - for (auto &block : blockList) + for (auto ®ion : opInst->getRegions()) + for (auto &block : region) hasInnerLoops |= walkPostOrder(block.begin(), block.end()); if (opInst->isa()) { if (!hasInnerLoops) diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index f1cc7c6b946..2b92b2a6422 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -132,8 +132,8 @@ LogicalResult mlir::loopUnrollJamByFactor(OpPointer forOp, // This is a linear time walk. 
void walk(Instruction *inst) { - for (auto &blockList : inst->getBlockLists()) - for (auto &block : blockList) + for (auto ®ion : inst->getRegions()) + for (auto &block : region) walk(block); } void walk(Block &block) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 0d54ead424e..804991a7b8b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -411,7 +411,7 @@ instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, "Should call the function specialized for VectorTransferReadOp"); assert(!opInst->isa() && "Should call the function specialized for VectorTransferWriteOp"); - if (opInst->getNumBlockLists() != 0) + if (opInst->getNumRegions() != 0) return nullptr; bool fail = false; @@ -553,7 +553,7 @@ static bool instantiateMaterialization(Instruction *inst, if (inst->isa()) { return false; } - if (inst->getNumBlockLists() != 0) + if (inst->getNumRegions() != 0) return inst->emitError("NYI path Op with region"); if (auto write = inst->dyn_cast()) { diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index b084b016be3..e7e72fdf461 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1087,7 +1087,7 @@ static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst, opInst->erase(); return res; } - if (opInst->getNumBlockLists() != 0) + if (opInst->getNumRegions() != 0) return nullptr; auto types = map([state](Value *v) { return getVectorType(v, *state); }, @@ -1112,7 +1112,7 @@ static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst, OperationState newOp(b->getContext(), opInst->getLoc(), opInst->getName().getStringRef(), operands, types, opInst->getAttrs(), /*successors=*/{}, - /*numBlockLists=*/0, opInst->hasResizableOperandsList()); + /*numRegions=*/0, opInst->hasResizableOperandsList()); return b->createOperation(newOp); } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index b96eaaaeca5..cebcdf6cc9e 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -222,7 +222,7 @@ func @malformed_for_to() { func @incomplete_for() { for %i = 1 to 10 step 2 -} // expected-error {{expected '{' to begin block list}} +} // expected-error {{expected '{' to begin a region}} // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 34b74621490..a9aa4774470 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -787,7 +787,7 @@ func @verbose_if(%N: index) { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index // CHECK-NEXT: } else { - } { // The else block list. + } { // The else region. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index } -- cgit v1.2.3 From 57270a9a99704e64ad77934a19def7f86a9e09a9 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 19 Mar 2019 08:45:06 -0700 Subject: Remove some statements that required >C++11, add includes and qualify names. NFC. 
PiperOrigin-RevId: 239197784 --- mlir/lib/AffineOps/AffineOps.cpp | 6 +++--- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 4 ++-- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 8 ++++---- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/IR/AffineExprDetail.h | 9 +++++++++ mlir/lib/IR/AttributeDetail.h | 28 ++++++++++++++++++++++++++++ mlir/lib/StandardOps/Ops.cpp | 22 +++++++++++----------- mlir/lib/Transforms/CSE.cpp | 4 ++-- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 1 + mlir/lib/Transforms/LoopTiling.cpp | 1 + mlir/lib/Transforms/Utils/Utils.cpp | 4 ++-- mlir/test/mlir-tblgen/one-op-one-result.td | 2 +- mlir/test/mlir-tblgen/pattern-tAttr.td | 4 ++-- mlir/tools/mlir-tblgen/RewriterGen.cpp | 4 ++-- 16 files changed, 71 insertions(+), 32 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 16fddb45496..bc8564c7931 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -505,7 +505,7 @@ PatternMatchResult SimplifyAffineApply::match(Instruction *op) const { composeAffineMapAndOperands(&map, &resultOperands); if (map != oldMap) return matchSuccess( - std::make_unique(map, resultOperands)); + llvm::make_unique(map, resultOperands)); return matchFailure(); } @@ -520,7 +520,7 @@ void SimplifyAffineApply::rewrite(Instruction *op, void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); + results.push_back(llvm::make_unique(context)); } //===----------------------------------------------------------------------===// @@ -867,7 +867,7 @@ struct AffineForLoopBoundFolder : public RewritePattern { void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); + results.push_back(llvm::make_unique(context)); } Block *AffineForOp::createBody() { diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 68fccf7762c..07ca59846a6 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -309,7 +309,7 @@ FlatAffineConstraints::FlatAffineConstraints( // Clones this object. std::unique_ptr FlatAffineConstraints::clone() const { - return std::make_unique(*this); + return llvm::make_unique(*this); } // Construct from an IntegerSet. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index a6f6845e9ef..8ccee5d4e29 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -40,7 +40,7 @@ void DominanceInfoBase::recalculate(Function *function) { dominanceInfos.clear(); // Build the top level function dominance. - auto functionDominance = std::make_unique(); + auto functionDominance = llvm::make_unique(); functionDominance->recalculate(function->getBody()); dominanceInfos.try_emplace(&function->getBody(), std::move(functionDominance)); @@ -51,7 +51,7 @@ void DominanceInfoBase::recalculate(Function *function) { // Don't compute dominance if the region is empty. 
if (region.empty()) continue; - auto opDominance = std::make_unique(); + auto opDominance = llvm::make_unique(); opDominance->recalculate(region); dominanceInfos.try_emplace(®ion, std::move(opDominance)); } diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 0c2a5defe10..87267183a5f 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -49,22 +49,22 @@ FunctionPassBase *mlir::createMemRefDependenceCheckPass() { // Returns a result string which represents the direction vector (if there was // a dependence), returns the string "false" otherwise. -static string +static std::string getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, ArrayRef dependenceComponents) { if (!ret) return "false"; if (dependenceComponents.empty() || loopNestDepth > numCommonLoops) return "true"; - string result; + std::string result; for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) { - string lbStr = "-inf"; + std::string lbStr = "-inf"; if (dependenceComponents[i].lb.hasValue() && dependenceComponents[i].lb.getValue() != std::numeric_limits::min()) lbStr = std::to_string(dependenceComponents[i].lb.getValue()); - string ubStr = "+inf"; + std::string ubStr = "+inf"; if (dependenceComponents[i].ub.hasValue() && dependenceComponents[i].ub.getValue() != std::numeric_limits::max()) diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 3399c7da88f..7d68b690a2e 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -700,7 +700,7 @@ static Optional getMemoryFootprintBytes(const Block &block, } // Compute the memref region symbolic in any IVs enclosing this block. - auto region = std::make_unique(opInst->getLoc()); + auto region = llvm::make_unique(opInst->getLoc()); if (failed( region->compute(opInst, /*loopDepth=*/getNestingDepth(*block.begin())))) { diff --git a/mlir/lib/IR/AffineExprDetail.h b/mlir/lib/IR/AffineExprDetail.h index 53172615a91..bca0957bcb8 100644 --- a/mlir/lib/IR/AffineExprDetail.h +++ b/mlir/lib/IR/AffineExprDetail.h @@ -42,6 +42,9 @@ struct AffineExprStorage { /// A binary operation appearing in an affine expression. struct AffineBinaryOpExprStorage : public AffineExprStorage { + AffineBinaryOpExprStorage(AffineExprStorage base, AffineExpr lhs, + AffineExpr rhs) + : AffineExprStorage(base), lhs(lhs), rhs(rhs) {} static AffineExpr get(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); AffineExpr lhs; AffineExpr rhs; @@ -49,18 +52,24 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage { /// A dimensional identifier appearing in an affine expression. struct AffineDimExprStorage : public AffineExprStorage { + AffineDimExprStorage(AffineExprStorage base, unsigned position) + : AffineExprStorage(base), position(position) {} /// Position of this identifier in the argument list. unsigned position; }; /// A symbolic identifier appearing in an affine expression. struct AffineSymbolExprStorage : public AffineExprStorage { + AffineSymbolExprStorage(AffineExprStorage base, unsigned position) + : AffineExprStorage(base), position(position) {} /// Position of this identifier in the symbol list. unsigned position; }; /// An integer constant appearing in affine expression. struct AffineConstantExprStorage : public AffineExprStorage { + AffineConstantExprStorage(AffineExprStorage base, int64_t constant) + : AffineExprStorage(base), constant(constant) {} // The constant. 
int64_t constant; }; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 4a4d04370f9..e3f3df9dc81 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -44,6 +44,8 @@ struct AttributeStorage { /// An attribute representing a boolean value. struct BoolAttributeStorage : public AttributeStorage { + BoolAttributeStorage(AttributeStorage base, Type type, bool value) + : AttributeStorage(base), type(type), value(value) {} const Type type; bool value; }; @@ -92,59 +94,85 @@ struct FloatAttributeStorage final /// An attribute representing a string value. struct StringAttributeStorage : public AttributeStorage { + StringAttributeStorage(AttributeStorage base, StringRef value) + : AttributeStorage(base), value(value) {} StringRef value; }; /// An attribute representing an array of other attributes. struct ArrayAttributeStorage : public AttributeStorage { + ArrayAttributeStorage(AttributeStorage base, ArrayRef value) + : AttributeStorage(base), value(value) {} ArrayRef value; }; // An attribute representing a reference to an affine map. struct AffineMapAttributeStorage : public AttributeStorage { + AffineMapAttributeStorage(AttributeStorage base, AffineMap value) + : AttributeStorage(base), value(value) {} AffineMap value; }; // An attribute representing a reference to an integer set. struct IntegerSetAttributeStorage : public AttributeStorage { + IntegerSetAttributeStorage(AttributeStorage base, IntegerSet value) + : AttributeStorage(base), value(value) {} IntegerSet value; }; /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { + TypeAttributeStorage(AttributeStorage base, Type value) + : AttributeStorage(base), value(value) {} Type value; }; /// An attribute representing a reference to a function. struct FunctionAttributeStorage : public AttributeStorage { + FunctionAttributeStorage(AttributeStorage base, Function *value) + : AttributeStorage(base), value(value) {} Function *value; }; /// A base attribute representing a reference to a vector or tensor constant. struct ElementsAttributeStorage : public AttributeStorage { + ElementsAttributeStorage(AttributeStorage base, VectorOrTensorType type) + : AttributeStorage(base), type(type) {} VectorOrTensorType type; }; /// An attribute representing a reference to a vector or tensor constant, /// inwhich all elements have the same value. struct SplatElementsAttributeStorage : public ElementsAttributeStorage { + SplatElementsAttributeStorage(ElementsAttributeStorage base, Attribute elt) + : ElementsAttributeStorage(base), elt(elt) {} Attribute elt; }; /// An attribute representing a reference to a dense vector or tensor object. struct DenseElementsAttributeStorage : public ElementsAttributeStorage { + DenseElementsAttributeStorage(ElementsAttributeStorage base, + ArrayRef data) + : ElementsAttributeStorage(base), data(data) {} ArrayRef data; }; /// An attribute representing a reference to a tensor constant with opaque /// content. struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage { + OpaqueElementsAttributeStorage(ElementsAttributeStorage base, + Dialect *dialect, StringRef bytes) + : ElementsAttributeStorage(base), dialect(dialect), bytes(bytes) {} Dialect *dialect; StringRef bytes; }; /// An attribute representing a reference to a sparse vector or tensor object. 
struct SparseElementsAttributeStorage : public ElementsAttributeStorage { + SparseElementsAttributeStorage(ElementsAttributeStorage base, + DenseIntElementsAttr indices, + DenseElementsAttr values) + : ElementsAttributeStorage(base), indices(indices), values(values) {} DenseIntElementsAttr indices; DenseElementsAttr values; }; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index ea84a3360ae..49837e0211d 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -371,8 +371,8 @@ struct SimplifyDeadAlloc : public RewritePattern { void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); - results.push_back(std::make_unique(context)); + results.push_back(llvm::make_unique(context)); + results.push_back(llvm::make_unique(context)); } //===----------------------------------------------------------------------===// @@ -578,7 +578,7 @@ bool CallIndirectOp::verify() const { void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.push_back( - std::make_unique(context)); + llvm::make_unique(context)); } //===----------------------------------------------------------------------===// @@ -887,7 +887,7 @@ bool CondBranchOp::verify() const { void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); + results.push_back(llvm::make_unique(context)); } Block *CondBranchOp::getTrueDest() { @@ -1143,8 +1143,8 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dealloc(memrefcast) -> dealloc results.push_back( - std::make_unique(getOperationName(), context)); - results.push_back(std::make_unique(context)); + llvm::make_unique(getOperationName(), context)); + results.push_back(llvm::make_unique(context)); } //===----------------------------------------------------------------------===// @@ -1424,7 +1424,7 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start results.push_back( - std::make_unique(getOperationName(), context)); + llvm::make_unique(getOperationName(), context)); } // --------------------------------------------------------------------------- @@ -1488,7 +1488,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait results.push_back( - std::make_unique(getOperationName(), context)); + llvm::make_unique(getOperationName(), context)); } //===----------------------------------------------------------------------===// @@ -1643,7 +1643,7 @@ void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load results.push_back( - std::make_unique(getOperationName(), context)); + llvm::make_unique(getOperationName(), context)); } //===----------------------------------------------------------------------===// @@ -1964,7 +1964,7 @@ void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// store(memrefcast) -> store results.push_back( - std::make_unique(getOperationName(), context)); + llvm::make_unique(getOperationName(), context)); } //===----------------------------------------------------------------------===// @@ -2013,7 +2013,7 @@ struct SimplifyXMinusX : public 
RewritePattern { void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(std::make_unique(context)); + results.push_back(llvm::make_unique(context)); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 02fbda0f680..31f4d48e4ed 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -197,7 +197,7 @@ void CSE::simplifyRegion(DominanceInfo &domInfo, Region ®ion) { std::deque> stack; // Process the nodes of the dom tree for this region. - stack.emplace_back(std::make_unique( + stack.emplace_back(llvm::make_unique( knownValues, domInfo.getRootNode(®ion))); while (!stack.empty()) { @@ -213,7 +213,7 @@ void CSE::simplifyRegion(DominanceInfo &domInfo, Region ®ion) { if (currentNode->childIterator != currentNode->node->end()) { auto *childNode = *(currentNode->childIterator++); stack.emplace_back( - std::make_unique(knownValues, childNode)); + llvm::make_unique(knownValues, childNode)); } else { // Finally, if the node and all of its children have been processed // then we delete the node. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 43c709fa3ab..23e07fc3a89 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -618,7 +618,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } // Compute the MemRefRegion accessed. - auto region = std::make_unique(opInst->getLoc()); + auto region = llvm::make_unique(opInst->getLoc()); if (failed(region->compute(opInst, dmaDepth))) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region: semi-affine maps?\n"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 12e52eaf4d3..6d4ea7206b7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include +#include #define DEBUG_TYPE "loop-fusion" diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 1a4b368c4c6..76ab91641f3 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include +#include using namespace mlir; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 8a21f273006..cbf68056eb9 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -69,11 +69,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) - domInfo = std::make_unique(domInstFilter->getFunction()); + domInfo = llvm::make_unique(domInstFilter->getFunction()); if (postDomInstFilter) postDomInfo = - std::make_unique(postDomInstFilter->getFunction()); + llvm::make_unique(postDomInstFilter->getFunction()); // The ops where memref replacement succeeds are replaced with new ones. 
SmallVector opsToErase; diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td index 7dc37b58946..324ccf7bb2c 100644 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ b/mlir/test/mlir-tblgen/one-op-one-result.td @@ -28,4 +28,4 @@ def : Pat<(X_AddOp (X_AddOp:$res $lhs, $rhs), $rrhs), (Y_AddOp $lhs, U:$rhs, T_C // CHECK: PatternRewriter &rewriter) // CHECK: rewriter.create(loc, op->getResult(0)->getType() // CHECK: void populateWithGenerated -// CHECK: patterns->push_back(std::make_unique(context)) +// CHECK: patterns->push_back(llvm::make_unique(context)) diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td index 017a676fb52..15dbd69d319 100644 --- a/mlir/test/mlir-tblgen/pattern-tAttr.td +++ b/mlir/test/mlir-tblgen/pattern-tAttr.td @@ -48,5 +48,5 @@ def : Pat<(Z_AddOp $lhs, $rhs, $attr1, $attr2), (Y_AddOp $lhs, $rhs, (T_Compose_ // CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0}); // CHECK: void populateWithGenerated -// CHECK: patterns->push_back(std::make_unique(context)) -// CHECK: patterns->push_back(std::make_unique(context)) +// CHECK: patterns->push_back(llvm::make_unique(context)) +// CHECK: patterns->push_back(llvm::make_unique(context)) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 88d0b4d2128..6a7af8005a8 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -279,7 +279,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { os << R"( PatternMatchResult match(Instruction *op0) const override { auto ctx = op0->getContext(); (void)ctx; - auto state = std::make_unique();)" + auto state = llvm::make_unique();)" << "\n"; // The rewrite pattern may specify that certain outputs should be unused in @@ -660,7 +660,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "void populateWithGenerated(MLIRContext *context, " << "OwningRewritePatternList *patterns) {\n"; for (unsigned i = 0; i != rewritePatternCount; ++i) { - os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName + os.indent(2) << "patterns->push_back(llvm::make_unique<" << baseRewriteName << i << ">(context));\n"; } os << "}\n"; -- cgit v1.2.3 From 986310a68f119c7fe60c1d80379454836d298cf5 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 23 Mar 2019 15:09:06 -0700 Subject: Remove const from Value, Instruction, Argument, and the various methods on the *Op classes. This is a net reduction by almost 400LOC. 
PiperOrigin-RevId: 239972443 --- mlir/include/mlir/AffineOps/AffineOps.h | 63 +++---- mlir/include/mlir/Analysis/AffineStructures.h | 6 +- mlir/include/mlir/Analysis/Dominance.h | 12 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 6 +- mlir/include/mlir/Analysis/NestedMatcher.h | 10 +- mlir/include/mlir/Analysis/Utils.h | 7 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 2 +- mlir/include/mlir/Dialect/Traits.h | 4 +- mlir/include/mlir/IR/Block.h | 2 +- mlir/include/mlir/IR/BlockAndValueMapping.h | 4 +- mlir/include/mlir/IR/Builders.h | 4 +- mlir/include/mlir/IR/Dialect.h | 6 +- mlir/include/mlir/IR/Instruction.h | 199 ++++++--------------- mlir/include/mlir/IR/OpDefinition.h | 173 ++++++------------ mlir/include/mlir/IR/OpImplementation.h | 9 +- mlir/include/mlir/IR/OperationSupport.h | 4 +- mlir/include/mlir/IR/UseDefLists.h | 6 +- mlir/include/mlir/IR/Value.h | 44 ++--- mlir/include/mlir/StandardOps/Ops.h | 183 +++++-------------- mlir/include/mlir/SuperVectorOps/SuperVectorOps.h | 27 ++- mlir/include/mlir/Transforms/Utils.h | 6 +- mlir/lib/AffineOps/AffineOps.cpp | 51 +++--- mlir/lib/Analysis/AffineAnalysis.cpp | 28 +-- mlir/lib/Analysis/AffineStructures.cpp | 14 +- mlir/lib/Analysis/Dominance.cpp | 10 +- mlir/lib/Analysis/LoopAnalysis.cpp | 53 +++--- mlir/lib/Analysis/NestedMatcher.cpp | 22 +-- mlir/lib/Analysis/Utils.cpp | 11 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/Analysis/Verifier.cpp | 12 +- mlir/lib/Dialect/Traits.cpp | 2 +- mlir/lib/EDSC/MLIREmitter.cpp | 4 +- mlir/lib/IR/AsmPrinter.cpp | 60 +++---- mlir/lib/IR/Block.cpp | 10 +- mlir/lib/IR/Instruction.cpp | 65 +++---- mlir/lib/IR/Operation.cpp | 46 +++-- mlir/lib/IR/Value.cpp | 6 +- mlir/lib/StandardOps/Ops.cpp | 9 +- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 29 +-- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 19 +- mlir/lib/Transforms/CSE.cpp | 7 +- mlir/lib/Transforms/DialectConversion.cpp | 9 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 36 ++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 8 +- mlir/lib/Transforms/Utils/Utils.cpp | 10 +- .../Vectorization/VectorizerTestPass.cpp | 10 +- mlir/lib/Transforms/Vectorize.cpp | 11 +- mlir/test/mlir-tblgen/op-decl.td | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 3 +- 51 files changed, 494 insertions(+), 836 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 16b4c8e4775..d1ad0a7ddec 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -35,7 +35,7 @@ class FuncBuilder; /// A utility function to check if a value is defined at the top level of a /// function. A value defined at the top level is always a valid symbol. -bool isTopLevelSymbol(const Value *value); +bool isTopLevelSymbol(Value *value); class AffineOpsDialect : public Dialect { public: @@ -64,15 +64,15 @@ public: ArrayRef operands); /// Returns the affine map to be applied by this operation. - AffineMap getAffineMap() const { + AffineMap getAffineMap() { return getAttrOfType("map").getValue(); } /// Returns true if the result of this operation can be used as dimension id. - bool isValidDim() const; + bool isValidDim(); /// Returns true if the result of this operation is a symbol. 
- bool isValidSymbol() const; + bool isValidSymbol(); static StringRef getOperationName() { return "affine.apply"; } @@ -87,7 +87,7 @@ public: private: friend class Instruction; - explicit AffineApplyOp(const Instruction *state) : Op(state) {} + explicit AffineApplyOp(Instruction *state) : Op(state) {} }; /// The "for" instruction represents an affine loop nest, defining an SSA value @@ -141,16 +141,13 @@ public: Block *createBody(); /// Get the body of the AffineForOp. - Block *getBody() const { return &getRegion().front(); } + Block *getBody() { return &getRegion().front(); } /// Get the body region of the AffineForOp. - Region &getRegion() const { return getInstruction()->getRegion(0); } + Region &getRegion() { return getInstruction()->getRegion(0); } /// Returns the induction variable for this loop. Value *getInductionVar(); - const Value *getInductionVar() const { - return const_cast(this)->getInductionVar(); - } //===--------------------------------------------------------------------===// // Bounds and step @@ -161,29 +158,27 @@ public: /// Returns operands for the lower bound map. operand_range getLowerBoundOperands(); - const_operand_range getLowerBoundOperands() const; /// Returns operands for the upper bound map. operand_range getUpperBoundOperands(); - const_operand_range getUpperBoundOperands() const; /// Returns information about the lower bound as a single object. - const AffineBound getLowerBound() const; + AffineBound getLowerBound(); /// Returns information about the upper bound as a single object. - const AffineBound getUpperBound() const; + AffineBound getUpperBound(); /// Returns loop step. - int64_t getStep() const { + int64_t getStep() { return getAttr(getStepAttrName()).cast().getInt(); } /// Returns affine map for the lower bound. - AffineMap getLowerBoundMap() const { + AffineMap getLowerBoundMap() { return getAttr(getLowerBoundAttrName()).cast().getValue(); } /// Returns affine map for the upper bound. The upper bound is exclusive. - AffineMap getUpperBoundMap() const { + AffineMap getUpperBoundMap() { return getAttr(getUpperBoundAttrName()).cast().getValue(); } @@ -209,19 +204,19 @@ public: } /// Returns true if the lower bound is constant. - bool hasConstantLowerBound() const; + bool hasConstantLowerBound(); /// Returns true if the upper bound is constant. - bool hasConstantUpperBound() const; + bool hasConstantUpperBound(); /// Returns true if both bounds are constant. - bool hasConstantBounds() const { + bool hasConstantBounds() { return hasConstantLowerBound() && hasConstantUpperBound(); } /// Returns the value of the constant lower bound. /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound() const; + int64_t getConstantLowerBound(); /// Returns the value of the constant upper bound. The upper bound is /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound() const; + int64_t getConstantUpperBound(); /// Sets the lower bound to the given constant value. void setConstantLowerBound(int64_t value); /// Sets the upper bound to the given constant value. @@ -229,19 +224,19 @@ public: /// Returns true if both the lower and upper bound have the same operand lists /// (same operands in the same order). 
- bool matchingBoundOperandList() const; + bool matchingBoundOperandList(); private: friend class Instruction; - explicit AffineForOp(const Instruction *state) : Op(state) {} + explicit AffineForOp(Instruction *state) : Op(state) {} }; /// Returns if the provided value is the induction variable of a AffineForOp. -bool isForInductionVar(const Value *val); +bool isForInductionVar(Value *val); /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -OpPointer getForInductionVarOwner(const Value *val); +OpPointer getForInductionVarOwner(Value *val); /// Extracts the induction variables from a list of AffineForOps and places them /// in the output argument `ivs`. @@ -262,7 +257,7 @@ public: AffineValueMap getAsAffineValueMap(); unsigned getNumOperands() const { return opEnd - opStart; } - const Value *getOperand(unsigned idx) const { + Value *getOperand(unsigned idx) const { return inst->getInstruction()->getOperand(opStart + idx); } @@ -323,20 +318,14 @@ public: static StringRef getOperationName() { return "if"; } static StringRef getConditionAttrName() { return "condition"; } - IntegerSet getIntegerSet() const; + IntegerSet getIntegerSet(); void setIntegerSet(IntegerSet newSet); /// Returns the 'then' region. Region &getThenBlocks(); - Region &getThenBlocks() const { - return const_cast(this)->getThenBlocks(); - } /// Returns the 'else' blocks. Region &getElseBlocks(); - Region &getElseBlocks() const { - return const_cast(this)->getElseBlocks(); - } bool verify(); static bool parse(OpAsmParser *parser, OperationState *result); @@ -344,14 +333,14 @@ public: private: friend class Instruction; - explicit AffineIfOp(const Instruction *state) : Op(state) {} + explicit AffineIfOp(Instruction *state) : Op(state) {} }; /// Returns true if the given Value can be used as a dimension id. -bool isValidDim(const Value *value); +bool isValidDim(Value *value); /// Returns true if the given Value can be used as a symbol. -bool isValidSymbol(const Value *value); +bool isValidSymbol(Value *value); /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 848c9215fe7..36b82acd1b7 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -443,16 +443,16 @@ public: /// Sets the identifier corresponding to the specified Value id to a /// constant. Asserts if the 'id' is not found. - void setIdToConstant(const Value &id, int64_t val); + void setIdToConstant(Value &id, int64_t val); /// Looks up the position of the identifier with the specified Value. Returns /// true if found (false otherwise). `pos' is set to the (column) position of /// the identifier. - bool findId(const Value &id, unsigned *pos) const; + bool findId(Value &id, unsigned *pos) const; /// Returns true if an identifier with the specified Value exists, false /// otherwise. - bool containsId(const Value &id) const; + bool containsId(Value &id) const; // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. 
The coefficient column corresponding to the added diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index d88c002a274..4aa8c0463d4 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -66,18 +66,18 @@ public: using super::super; /// Return true if instruction A properly dominates instruction B. - bool properlyDominates(const Instruction *a, const Instruction *b); + bool properlyDominates(Instruction *a, Instruction *b); /// Return true if instruction A dominates instruction B. - bool dominates(const Instruction *a, const Instruction *b) { + bool dominates(Instruction *a, Instruction *b) { return a == b || properlyDominates(a, b); } /// Return true if value A properly dominates instruction B. - bool properlyDominates(const Value *a, const Instruction *b); + bool properlyDominates(Value *a, Instruction *b); /// Return true if instruction A dominates instruction B. - bool dominates(const Value *a, const Instruction *b) { + bool dominates(Value *a, Instruction *b) { return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b); } @@ -98,10 +98,10 @@ public: using super::super; /// Return true if instruction A properly postdominates instruction B. - bool properlyPostDominates(const Instruction *a, const Instruction *b); + bool properlyPostDominates(Instruction *a, Instruction *b); /// Return true if instruction A postdominates instruction B. - bool postDominates(const Instruction *a, const Instruction *b) { + bool postDominates(Instruction *a, Instruction *b) { return a == b || properlyPostDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index a5222c58cec..7d5ebeed054 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -72,7 +72,7 @@ uint64_t getLargestDivisorOfTripCount(OpPointer forOp); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -bool isAccessInvariant(const Value &iv, const Value &index); +bool isAccessInvariant(Value &iv, Value &index); /// Given an induction variable `iv` of type AffineForOp and `indices` of type /// IndexType, returns the set of `indices` that are independent of `iv`. @@ -83,8 +83,8 @@ bool isAccessInvariant(const Value &iv, const Value &index); /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -llvm::DenseSet> -getInvariantAccesses(const Value &iv, llvm::ArrayRef indices); +llvm::DenseSet> +getInvariantAccesses(Value &iv, llvm::ArrayRef indices); /// Checks whether the loop is structurally vectorizable; i.e.: /// 1. the loop has proper dependence semantics (parallel, reduction, etc); diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 44fe4c0558a..64bdfb4f941 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -94,8 +94,8 @@ private: /// aggressive unrolling. As experience has shown, it is generally better to use /// a plain walk over instructions to match flat patterns but the current /// implementation is competitive nonetheless. 
-using FilterFunctionType = std::function; -static bool defaultFilterFunction(const Instruction &) { return true; }; +using FilterFunctionType = std::function; +static bool defaultFilterFunction(Instruction &) { return true; }; struct NestedPattern { NestedPattern(ArrayRef nested, FilterFunctionType filter = defaultFilterFunction); @@ -182,9 +182,9 @@ NestedPattern For(ArrayRef nested = {}); NestedPattern For(FilterFunctionType filter, ArrayRef nested = {}); -bool isParallelLoop(const Instruction &inst); -bool isReductionLoop(const Instruction &inst); -bool isLoadOrStore(const Instruction &inst); +bool isParallelLoop(Instruction &inst); +bool isReductionLoop(Instruction &inst); +bool isLoadOrStore(Instruction &inst); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index aa5b5c54720..0982849302c 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -48,12 +48,12 @@ class Value; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. -void getLoopIVs(const Instruction &inst, +void getLoopIVs(Instruction &inst, SmallVectorImpl> *loops); /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. -unsigned getNestingDepth(const Instruction &stmt); +unsigned getNestingDepth(Instruction &inst); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. @@ -231,8 +231,7 @@ LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); /// Returns the number of surrounding loops common to both A and B. -unsigned getNumCommonSurroundingLoops(const Instruction &A, - const Instruction &B); +unsigned getNumCommonSurroundingLoops(Instruction &A, Instruction &B); /// Gets the memory footprint of all data touched in the specified memory space /// in bytes; if the memory space is unspecified, considers all memory spaces. diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 4982481bf6c..f8ed1dd2819 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -135,7 +135,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnSuperVectors(const Instruction &inst, VectorType subVectorType); +bool operatesOnSuperVectors(Instruction &inst, VectorType subVectorType); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h index c25f2151ba3..ffaf5661769 100644 --- a/mlir/include/mlir/Dialect/Traits.h +++ b/mlir/include/mlir/Dialect/Traits.h @@ -32,7 +32,7 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. 
namespace impl { -bool verifyCompatibleOperandBroadcast(const Instruction *op); +bool verifyCompatibleOperandBroadcast(Instruction *op); } // namespace impl namespace util { @@ -78,7 +78,7 @@ template class BroadcastableTwoOperandsOneResult : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyCompatibleOperandBroadcast(op); } }; diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f373f73bf56..d964878237f 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -173,7 +173,7 @@ public: /// the latter fails. /// TODO: This is very specific functionality that should live somewhere else, /// probably in Dominance.cpp. - Instruction *findAncestorInstInBlock(const Instruction &inst); + Instruction *findAncestorInstInBlock(Instruction &inst); /// This drops all operand uses from instructions within this block, which is /// an essential step in breaking cyclic dependences between references when diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h index 2bac95bc39d..8f4c4ce651f 100644 --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -37,7 +37,7 @@ public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. void map(Block *from, Block *to) { valueMap[from] = to; } - void map(const Value *from, Value *to) { valueMap[from] = to; } + void map(Value *from, Value *to) { valueMap[from] = to; } /// Erases a mapping for 'from'. void erase(const IRObjectWithUseList *from) { valueMap.erase(from); } @@ -52,7 +52,7 @@ public: Block *lookupOrNull(Block *from) const { return lookupOrValue(from, (Block *)nullptr); } - Value *lookupOrNull(const Value *from) const { + Value *lookupOrNull(Value *from) const { return lookupOrValue(from, (Value *)nullptr); } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 44df05f7380..13b58c40ab3 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -268,12 +268,12 @@ public: /// ( leaving them alone if no entry is present). Replaces references to /// cloned sub-instructions to the corresponding instruction that is copied, /// and adds those mappings to the map. - Instruction *clone(const Instruction &inst, BlockAndValueMapping &mapper) { + Instruction *clone(Instruction &inst, BlockAndValueMapping &mapper) { Instruction *cloneInst = inst.clone(mapper, getContext()); block->getInstructions().insert(insertPoint, cloneInst); return cloneInst; } - Instruction *clone(const Instruction &inst) { + Instruction *clone(Instruction &inst) { Instruction *cloneInst = inst.clone(getContext()); block->getInstructions().insert(insertPoint, cloneInst); return cloneInst; diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index cf1a6e5cb72..beabbb11237 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -32,7 +32,7 @@ class Type; using DialectConstantDecodeHook = std::function; using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; + Instruction *, ArrayRef, SmallVectorImpl &)>; using DialectExtractElementHook = std::function)>; @@ -57,7 +57,7 @@ public: /// `results` vector. If not, this returns failure and `results` is /// unspecified. 
DialectConstantFoldHook constantFoldHook = - [](const Instruction *op, ArrayRef operands, + [](Instruction *op, ArrayRef operands, SmallVectorImpl &results) { return failure(); }; /// Registered hook to decode opaque constants associated with this @@ -117,7 +117,7 @@ public: /// Verify an attribute from this dialect on the given instruction. Returns /// true if the verification failed, false otherwise. - virtual bool verifyInstructionAttribute(const Instruction *, NamedAttribute) { + virtual bool verifyInstructionAttribute(Instruction *, NamedAttribute) { return false; } diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index ccdb3fba7da..4b164627a2a 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -68,11 +68,11 @@ public: bool resizableOperandList, MLIRContext *context); /// The name of an operation is the key identifier for it. - OperationName getName() const { return name; } + OperationName getName() { return name; } /// If this operation has a registered operation description, return it. /// Otherwise return null. - const AbstractOperation *getAbstractOperation() const { + const AbstractOperation *getAbstractOperation() { return getName().getAbstractOperation(); } @@ -84,29 +84,29 @@ public: /// them alone if no entry is present). Replaces references to cloned /// sub-instructions to the corresponding instruction that is copied, and adds /// those mappings to the map. - Instruction *clone(BlockAndValueMapping &mapper, MLIRContext *context) const; - Instruction *clone(MLIRContext *context) const; + Instruction *clone(BlockAndValueMapping &mapper, MLIRContext *context); + Instruction *clone(MLIRContext *context); /// Returns the instruction block that contains this instruction. - Block *getBlock() const { return block; } + Block *getBlock() { return block; } /// Return the context this operation is associated with. - MLIRContext *getContext() const; + MLIRContext *getContext(); /// The source location the operation was defined or derived from. - Location getLoc() const { return location; } + Location getLoc() { return location; } /// Set the source location the operation was defined or derived from. void setLoc(Location loc) { location = loc; } /// Returns the closest surrounding instruction that contains this instruction /// or nullptr if this is a top-level instruction. - Instruction *getParentInst() const; + Instruction *getParentInst(); /// Returns the function that this instruction is part of. /// The function is determined by traversing the chain of parent instructions. /// Returns nullptr if the instruction is unlinked. - Function *getFunction() const; + Function *getFunction(); /// Destroys this instruction and its subclass data. void destroy(); @@ -130,10 +130,10 @@ public: /// of the parent block. /// Note: This function has an average complexity of O(1), but worst case may /// take O(N) where N is the number of instructions within the parent block. - bool isBeforeInBlock(const Instruction *other) const; + bool isBeforeInBlock(Instruction *other); - void print(raw_ostream &os) const; - void dump() const; + void print(raw_ostream &os); + void dump(); //===--------------------------------------------------------------------===// // Operands @@ -141,9 +141,7 @@ public: /// Returns if the operation has a resizable operation list, i.e. operands can /// be added. 
- bool hasResizableOperandsList() const { - return getOperandStorage().isResizable(); - } + bool hasResizableOperandsList() { return getOperandStorage().isResizable(); } /// Replace the current operands of this operation with the ones provided in /// 'operands'. If the operands list is not resizable, the size of 'operands' @@ -152,12 +150,9 @@ public: getOperandStorage().setOperands(this, operands); } - unsigned getNumOperands() const { return getOperandStorage().size(); } + unsigned getNumOperands() { return getOperandStorage().size(); } Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } void setOperand(unsigned idx, Value *value) { return getInstOperand(idx).set(value); } @@ -172,77 +167,40 @@ public: /// Returns an iterator on the underlying Value's (Value *). operand_range getOperands(); - // Support const operand iteration. - using const_operand_iterator = - OperandIterator; - using const_operand_range = llvm::iterator_range; - - const_operand_iterator operand_begin() const; - const_operand_iterator operand_end() const; - - /// Returns a const iterator on the underlying Value's (Value *). - llvm::iterator_range getOperands() const; - - ArrayRef getInstOperands() const { - return getOperandStorage().getInstOperands(); - } MutableArrayRef getInstOperands() { return getOperandStorage().getInstOperands(); } InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } //===--------------------------------------------------------------------===// // Results //===--------------------------------------------------------------------===// /// Return true if there are no users of any results of this operation. - bool use_empty() const; + bool use_empty(); - unsigned getNumResults() const { return numResults; } + unsigned getNumResults() { return numResults; } Value *getResult(unsigned idx) { return &getInstResult(idx); } - const Value *getResult(unsigned idx) const { return &getInstResult(idx); } - // Support non-const result iteration. + // Support result iteration. using result_iterator = ResultIterator; result_iterator result_begin(); result_iterator result_end(); llvm::iterator_range getResults(); - // Support const result iteration. - using const_result_iterator = ResultIterator; - const_result_iterator result_begin() const; - - const_result_iterator result_end() const; - - llvm::iterator_range getResults() const; - - ArrayRef getInstResults() const { - return {getTrailingObjects(), numResults}; - } - MutableArrayRef getInstResults() { return {getTrailingObjects(), numResults}; } InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } - const InstResult &getInstResult(unsigned idx) const { - return getInstResults()[idx]; - } - // Support result type iteration. - using result_type_iterator = - ResultTypeIterator; - result_type_iterator result_type_begin() const; - - result_type_iterator result_type_end() const; - - llvm::iterator_range getResultTypes() const; + using result_type_iterator = ResultTypeIterator; + result_type_iterator result_type_begin(); + result_type_iterator result_type_end(); + llvm::iterator_range getResultTypes(); //===--------------------------------------------------------------------===// // Attributes @@ -253,17 +211,17 @@ public: // the lifetime of an instruction. /// Return all of the attributes on this instruction. 
- ArrayRef getAttrs() const { return attrs.getAttrs(); } + ArrayRef getAttrs() { return attrs.getAttrs(); } /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) const { return attrs.get(name); } - Attribute getAttr(StringRef name) const { return attrs.get(name); } + Attribute getAttr(Identifier name) { return attrs.get(name); } + Attribute getAttr(StringRef name) { return attrs.get(name); } - template AttrClass getAttrOfType(Identifier name) const { + template AttrClass getAttrOfType(Identifier name) { return getAttr(name).dyn_cast_or_null(); } - template AttrClass getAttrOfType(StringRef name) const { + template AttrClass getAttrOfType(StringRef name) { return getAttr(name).dyn_cast_or_null(); } @@ -287,16 +245,16 @@ public: //===--------------------------------------------------------------------===// /// Returns the number of regions held by this operation. - unsigned getNumRegions() const { return numRegions; } + unsigned getNumRegions() { return numRegions; } /// Returns the regions held by this operation. - MutableArrayRef getRegions() const { + MutableArrayRef getRegions() { auto *regions = getTrailingObjects(); - return {const_cast(regions), numRegions}; + return {regions, numRegions}; } /// Returns the region held by this operation at position 'index'. - Region &getRegion(unsigned index) const { + Region &getRegion(unsigned index) { assert(index < numRegions && "invalid region index"); return getRegions()[index]; } @@ -308,15 +266,10 @@ public: MutableArrayRef getBlockOperands() { return {getTrailingObjects(), numSuccs}; } - ArrayRef getBlockOperands() const { - return const_cast(this)->getBlockOperands(); - } /// Return the operands of this operation that are *not* successor arguments. - const_operand_range getNonSuccessorOperands() const; operand_range getNonSuccessorOperands(); - const_operand_range getSuccessorOperands(unsigned index) const; operand_range getSuccessorOperands(unsigned index); Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) { @@ -324,19 +277,15 @@ public: assert(opIndex < getNumSuccessorOperands(succIndex)); return getOperand(getSuccessorOperandIndex(succIndex) + opIndex); } - const Value *getSuccessorOperand(unsigned succIndex, unsigned index) const { - return const_cast(this)->getSuccessorOperand(succIndex, - index); - } - unsigned getNumSuccessors() const { return numSuccs; } - unsigned getNumSuccessorOperands(unsigned index) const { + unsigned getNumSuccessors() { return numSuccs; } + unsigned getNumSuccessorOperands(unsigned index) { assert(!isKnownNonTerminator() && "only terminators may have successors"); assert(index < getNumSuccessors()); return getTrailingObjects()[index]; } - Block *getSuccessor(unsigned index) const { + Block *getSuccessor(unsigned index) { assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } @@ -354,21 +303,21 @@ public: /// Get the index of the first operand of the successor at the provided /// index. - unsigned getSuccessorOperandIndex(unsigned index) const; + unsigned getSuccessorOperandIndex(unsigned index); //===--------------------------------------------------------------------===// // Accessors for various properties of operations //===--------------------------------------------------------------------===// /// Returns whether the operation is commutative. 
- bool isCommutative() const { + bool isCommutative() { if (auto *absOp = getAbstractOperation()) return absOp->hasProperty(OperationProperty::Commutative); return false; } /// Returns whether the operation has side-effects. - bool hasNoSideEffect() const { + bool hasNoSideEffect() { if (auto *absOp = getAbstractOperation()) return absOp->hasProperty(OperationProperty::NoSideEffect); return false; @@ -380,7 +329,7 @@ public: enum class TerminatorStatus { Terminator, NonTerminator, Unknown }; /// Returns the status of whether this operation is a terminator or not. - TerminatorStatus getTerminatorStatus() const { + TerminatorStatus getTerminatorStatus() { if (auto *absOp = getAbstractOperation()) { return absOp->hasProperty(OperationProperty::Terminator) ? TerminatorStatus::Terminator @@ -390,12 +339,12 @@ public: } /// Returns if the operation is known to be a terminator. - bool isKnownTerminator() const { + bool isKnownTerminator() { return getTerminatorStatus() == TerminatorStatus::Terminator; } /// Returns if the operation is known to *not* be a terminator. - bool isKnownNonTerminator() const { + bool isKnownNonTerminator() { return getTerminatorStatus() == TerminatorStatus::NonTerminator; } @@ -405,7 +354,7 @@ public: /// constant folding is successful, this fills in the `results` vector. If /// not, `results` is unspecified. LogicalResult constantFold(ArrayRef operands, - SmallVectorImpl &results) const; + SmallVectorImpl &results); /// Attempt to fold this operation using the Op's registered foldHook. LogicalResult fold(SmallVectorImpl &results); @@ -421,7 +370,7 @@ public: /// The dyn_cast methods perform a dynamic cast from an Instruction to a typed /// Op like DimOp. This returns a null OpPointer on failure. - template OpPointer dyn_cast() const { + template OpPointer dyn_cast() { if (isa()) { return cast(); } else { @@ -432,16 +381,14 @@ public: /// The cast methods perform a cast from an Instruction to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. - template OpPointer cast() const { + template OpPointer cast() { assert(isa() && "cast() argument of incompatible type!"); return OpPointer(OpClass(this)); } /// The is methods return true if the operation is a typed op (like DimOp) of /// of the given class. - template bool isa() const { - return OpClass::isClassFor(const_cast(this)); - } + template bool isa() { return OpClass::isClassFor(this); } //===--------------------------------------------------------------------===// // Instruction Walkers @@ -479,21 +426,21 @@ public: /// Emit an error with the op name prefixed, like "'dim' op " which is /// convenient for verifiers. This function always returns true. - bool emitOpError(const Twine &message) const; + bool emitOpError(const Twine &message); /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening. This function always /// returns true. NOTE: This may terminate the containing application, only /// use when the IR is in an inconsistent state. - bool emitError(const Twine &message) const; + bool emitError(const Twine &message); /// Emit a warning about this operation, reporting up to any diagnostic /// handlers that may be listening. - void emitWarning(const Twine &message) const; + void emitWarning(const Twine &message); /// Emit a note about this operation, reporting up to any diagnostic /// handlers that may be listening. 
- void emitNote(const Twine &message) const; + void emitNote(const Twine &message); private: Instruction(Location location, OperationName name, unsigned numResults, @@ -508,12 +455,9 @@ private: detail::OperandStorage &getOperandStorage() { return *getTrailingObjects(); } - const detail::OperandStorage &getOperandStorage() const { - return *getTrailingObjects(); - } // Provide a 'getParent' method for ilist_node_with_parent methods. - Block *getParent() const { return getBlock(); } + Block *getParent() { return getBlock(); } /// The instruction block that containts this instruction. Block *block = nullptr; @@ -556,7 +500,7 @@ private: size_t numTrailingObjects(OverloadToken) const { return numSuccs; } }; -inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) { +inline raw_ostream &operator<<(raw_ostream &os, Instruction &inst) { inst.print(os); return os; } @@ -573,13 +517,6 @@ public: : IndexedAccessorIterator, ObjectType, ElementType>(object, index) {} - /// Support converting to the const variant. This will be a no-op for const - /// variant. - operator OperandIterator() const { - return OperandIterator(this->object, - this->index); - } - ElementType *operator*() const { return this->object->getOperand(this->index); } @@ -598,18 +535,6 @@ inline auto Instruction::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } -inline auto Instruction::operand_begin() const -> const_operand_iterator { - return const_operand_iterator(this, 0); -} - -inline auto Instruction::operand_end() const -> const_operand_iterator { - return const_operand_iterator(this, getNumOperands()); -} - -inline auto Instruction::getOperands() const -> const_operand_range { - return {operand_begin(), operand_end()}; -} - /// This template implements the result iterators for the Instruction class /// in terms of getResult(idx). template @@ -622,13 +547,6 @@ public: : IndexedAccessorIterator, ObjectType, ElementType>(object, index) {} - /// Support converting to the const variant. This will be a no-op for const - /// variant. 
- operator ResultIterator() const { - return ResultIterator(this->object, - this->index); - } - ElementType *operator*() const { return this->object->getResult(this->index); } @@ -672,28 +590,15 @@ inline auto Instruction::getResults() -> llvm::iterator_range { return {result_begin(), result_end()}; } -inline auto Instruction::result_begin() const -> const_result_iterator { - return const_result_iterator(this, 0); -} - -inline auto Instruction::result_end() const -> const_result_iterator { - return const_result_iterator(this, getNumResults()); -} - -inline auto Instruction::getResults() const - -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto Instruction::result_type_begin() const -> result_type_iterator { +inline auto Instruction::result_type_begin() -> result_type_iterator { return result_type_iterator(this, 0); } -inline auto Instruction::result_type_end() const -> result_type_iterator { +inline auto Instruction::result_type_end() -> result_type_iterator { return result_type_iterator(this, getNumResults()); } -inline auto Instruction::getResultTypes() const +inline auto Instruction::getResultTypes() -> llvm::iterator_range { return {result_type_begin(), result_type_end()}; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index dc6ddd3df88..d21e40818d5 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -103,7 +103,7 @@ private: class OpState { public: /// Return the operation that this refers to. - const Instruction *getInstruction() const { return state; } + Instruction *getInstruction() const { return state; } Instruction *getInstruction() { return state; } /// The source location the operation was defined or derived from. @@ -176,7 +176,7 @@ protected: /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. - explicit OpState(const Instruction *state) + explicit OpState(Instruction *state) : state(const_cast(state)) {} private: @@ -327,22 +327,22 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. 
namespace impl { -bool verifyZeroOperands(const Instruction *op); -bool verifyOneOperand(const Instruction *op); -bool verifyNOperands(const Instruction *op, unsigned numOperands); -bool verifyAtLeastNOperands(const Instruction *op, unsigned numOperands); -bool verifyOperandsAreIntegerLike(const Instruction *op); -bool verifySameTypeOperands(const Instruction *op); -bool verifyZeroResult(const Instruction *op); -bool verifyOneResult(const Instruction *op); -bool verifyNResults(const Instruction *op, unsigned numOperands); -bool verifyAtLeastNResults(const Instruction *op, unsigned numOperands); -bool verifySameOperandsAndResultShape(const Instruction *op); -bool verifySameOperandsAndResultType(const Instruction *op); -bool verifyResultsAreBoolLike(const Instruction *op); -bool verifyResultsAreFloatLike(const Instruction *op); -bool verifyResultsAreIntegerLike(const Instruction *op); -bool verifyIsTerminator(const Instruction *op); +bool verifyZeroOperands(Instruction *op); +bool verifyOneOperand(Instruction *op); +bool verifyNOperands(Instruction *op, unsigned numOperands); +bool verifyAtLeastNOperands(Instruction *op, unsigned numOperands); +bool verifyOperandsAreIntegerLike(Instruction *op); +bool verifySameTypeOperands(Instruction *op); +bool verifyZeroResult(Instruction *op); +bool verifyOneResult(Instruction *op); +bool verifyNResults(Instruction *op, unsigned numOperands); +bool verifyAtLeastNResults(Instruction *op, unsigned numOperands); +bool verifySameOperandsAndResultShape(Instruction *op); +bool verifySameOperandsAndResultType(Instruction *op); +bool verifyResultsAreBoolLike(Instruction *op); +bool verifyResultsAreFloatLike(Instruction *op); +bool verifyResultsAreIntegerLike(Instruction *op); +bool verifyIsTerminator(Instruction *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -361,13 +361,13 @@ protected: auto *base = static_cast(concrete); return base->getInstruction(); } - const Instruction *getInstruction() const { + Instruction *getInstruction() const { return const_cast(this)->getInstruction(); } /// Provide default implementations of trait hooks. This allows traits to /// provide exactly the overrides they care about. 
- static bool verifyTrait(const Instruction *op) { return false; } + static bool verifyTrait(Instruction *op) { return false; } static AbstractOperation::OperationProperties getTraitProperties() { return 0; } @@ -378,7 +378,7 @@ protected: template class ZeroOperands : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyZeroOperands(op); } @@ -393,17 +393,13 @@ private: template class OneOperand : public TraitBase { public: - const Value *getOperand() const { - return this->getInstruction()->getOperand(0); - } - - Value *getOperand() { return this->getInstruction()->getOperand(0); } + Value *getOperand() const { return this->getInstruction()->getOperand(0); } void setOperand(Value *value) { this->getInstruction()->setOperand(0, value); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyOneOperand(op); } }; @@ -418,11 +414,7 @@ public: template class Impl : public TraitBase::Impl> { public: - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -430,7 +422,7 @@ public: this->getInstruction()->setOperand(i, value); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyNOperands(op, N); } }; @@ -449,11 +441,7 @@ public: unsigned getNumOperands() const { return this->getInstruction()->getNumOperands(); } - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -473,19 +461,7 @@ public: return this->getInstruction()->getOperands(); } - // Support const operand iteration. - using const_operand_iterator = Instruction::const_operand_iterator; - const_operand_iterator operand_begin() const { - return this->getInstruction()->operand_begin(); - } - const_operand_iterator operand_end() const { - return this->getInstruction()->operand_end(); - } - llvm::iterator_range getOperands() const { - return this->getInstruction()->getOperands(); - } - - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyAtLeastNOperands(op, N); } }; @@ -500,11 +476,7 @@ public: return this->getInstruction()->getNumOperands(); } - const Value *getOperand(unsigned i) const { - return this->getInstruction()->getOperand(i); - } - - Value *getOperand(unsigned i) { + Value *getOperand(unsigned i) const { return this->getInstruction()->getOperand(i); } @@ -522,19 +494,6 @@ public: return this->getInstruction()->operand_end(); } operand_range getOperands() { return this->getInstruction()->getOperands(); } - - // Support const operand iteration. 
- using const_operand_iterator = Instruction::const_operand_iterator; - using const_operand_range = Instruction::const_operand_range; - const_operand_iterator operand_begin() const { - return this->getInstruction()->operand_begin(); - } - const_operand_iterator operand_end() const { - return this->getInstruction()->operand_end(); - } - const_operand_range getOperands() const { - return this->getInstruction()->getOperands(); - } }; /// This class provides return value APIs for ops that are known to have @@ -542,7 +501,7 @@ public: template class ZeroResult : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyZeroResult(op); } }; @@ -552,10 +511,7 @@ public: template class OneResult : public TraitBase { public: - Value *getResult() { return this->getInstruction()->getResult(0); } - const Value *getResult() const { - return this->getInstruction()->getResult(0); - } + Value *getResult() const { return this->getInstruction()->getResult(0); } Type getType() const { return getResult()->getType(); } @@ -566,9 +522,7 @@ public: getResult()->replaceAllUsesWith(newValue); } - static bool verifyTrait(const Instruction *op) { - return impl::verifyOneResult(op); - } + static bool verifyTrait(Instruction *op) { return impl::verifyOneResult(op); } }; /// This class provides the API for ops that are known to have a specified @@ -583,17 +537,13 @@ public: public: static unsigned getNumResults() { return N; } - const Value *getResult(unsigned i) const { - return this->getInstruction()->getResult(i); - } - - Value *getResult(unsigned i) { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyNResults(op, N); } }; @@ -609,17 +559,13 @@ public: template class Impl : public TraitBase::Impl> { public: - const Value *getResult(unsigned i) const { - return this->getInstruction()->getResult(i); - } - - Value *getResult(unsigned i) { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -634,12 +580,10 @@ public: return this->getInstruction()->getNumResults(); } - const Value *getResult(unsigned i) const { + Value *getResult(unsigned i) const { return this->getInstruction()->getResult(i); } - Value *getResult(unsigned i) { return this->getInstruction()->getResult(i); } - void setResult(unsigned i, Value *value) { this->getInstruction()->setResult(i, value); } @@ -653,18 +597,6 @@ public: llvm::iterator_range getResults() { return this->getInstruction()->getResults(); } - - // Support const result iteration. 
- using const_result_iterator = Instruction::const_result_iterator; - const_result_iterator result_begin() const { - return this->getInstruction()->result_begin(); - } - const_result_iterator result_end() const { - return this->getInstruction()->result_end(); - } - llvm::iterator_range getResults() const { - return this->getInstruction()->getResults(); - } }; /// This class provides verification for ops that are known to have the same @@ -674,7 +606,7 @@ template class SameOperandsAndResultShape : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -689,7 +621,7 @@ template class SameOperandsAndResultType : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -699,7 +631,7 @@ public: template class ResultsAreBoolLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -710,7 +642,7 @@ template class ResultsAreFloatLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreFloatLike(op); } }; @@ -721,7 +653,7 @@ template class ResultsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyResultsAreIntegerLike(op); } }; @@ -752,7 +684,7 @@ template class OperandsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyOperandsAreIntegerLike(op); } }; @@ -762,7 +694,7 @@ public: template class SameTypeOperands : public TraitBase { public: - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifySameTypeOperands(op); } }; @@ -775,7 +707,7 @@ public: return static_cast( OperationProperty::Terminator); } - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return impl::verifyIsTerminator(op); } @@ -820,10 +752,7 @@ class Op : public OpState, Traits...>::value> { public: /// Return the operation that this refers to. - const Instruction *getInstruction() const { - return OpState::getInstruction(); - } - Instruction *getInstruction() { return OpState::getInstruction(); } + Instruction *getInstruction() const { return OpState::getInstruction(); } /// Return true if this "op class" can match against the specified operation. 
/// This hook can be overridden with a more specific implementation in @@ -875,20 +804,20 @@ public: using ConcreteOpType = ConcreteType; protected: - explicit Op(const Instruction *state) : OpState(state) {} + explicit Op(Instruction *state) : OpState(state) {} private: template struct BaseVerifier; template struct BaseVerifier { - static bool verifyTrait(const Instruction *op) { + static bool verifyTrait(Instruction *op) { return First::verifyTrait(op) || BaseVerifier::verifyTrait(op); } }; template struct BaseVerifier { - static bool verifyTrait(const Instruction *op) { return false; } + static bool verifyTrait(Instruction *op) { return false; } }; template struct BaseProperties; @@ -917,7 +846,7 @@ bool parseBinaryOp(OpAsmParser *parser, OperationState *result); // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. -void printBinaryOp(const Instruction *op, OpAsmPrinter *p); +void printBinaryOp(Instruction *op, OpAsmPrinter *p); } // namespace impl // These functions are out-of-line implementations of the methods in CastOp, @@ -926,7 +855,7 @@ namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); -void printCastOp(const Instruction *op, OpAsmPrinter *p); +void printCastOp(Instruction *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are cast operations, that have a @@ -951,7 +880,7 @@ public: } protected: - explicit CastOp(const Instruction *state) + explicit CastOp(Instruction *state) : Op(state) {} }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f8465afe0ea..ae63f485d32 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -48,7 +48,7 @@ public: virtual raw_ostream &getStream() const = 0; /// Print implementations for various things an operation contains. - virtual void printOperand(const Value *value) = 0; + virtual void printOperand(Value *value) = 0; /// Print a comma separated list of operands. template @@ -76,8 +76,7 @@ public: /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. - virtual void printSuccessorAndUseList(const Instruction *term, - unsigned index) = 0; + virtual void printSuccessorAndUseList(Instruction *term, unsigned index) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore @@ -87,7 +86,7 @@ public: ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. - virtual void printGenericOp(const Instruction *op) = 0; + virtual void printGenericOp(Instruction *op) = 0; /// Prints a region. virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true) = 0; @@ -98,7 +97,7 @@ private: }; // Make the implementations convenient to use. 
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) { p.printOperand(&value); return p; } diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 8be6b34bacc..6796131c052 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -86,8 +86,8 @@ public: /// This hook implements the AsmPrinter for this operation. void (&printAssembly)(Instruction *op, OpAsmPrinter *p); - /// This hook implements the verifier for this operation. It should emit an - /// error message and returns true if a problem is detected, or return false + /// This hook implements the verifier for this operation. It should emits an + /// error message and returns true if a problem is detected, or returns false /// if everything is ok. bool (&verifyInvariants)(Instruction *op); diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 3d4493b0d6b..0529a5a2ebd 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -97,7 +97,7 @@ public: /// Return the owner of this operand. Instruction *getOwner() { return owner; } - const Instruction *getOwner() const { return owner; } + Instruction *getOwner() const { return owner; } /// \brief Remove this use of the operand. void drop() { @@ -176,13 +176,13 @@ public: : IROperand(owner, value) {} /// Return the current value being used by this operand. - IRValueTy *get() const { return (IRValueTy *)IROperand::get(); } + IRValueTy *get() { return (IRValueTy *)IROperand::get(); } /// Set the current value being used by this operand. void set(IRValueTy *newValue) { IROperand::set(newValue); } /// Return which operand this is in the operand list of the User. - unsigned getOperandNumber() const; + unsigned getOperandNumber(); }; /// An iterator over all uses of a ValueBase. diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index fde40530596..dada49bc58f 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -48,9 +48,9 @@ public: ~Value() {} - Kind getKind() const { return typeAndKind.getInt(); } + Kind getKind() { return typeAndKind.getInt(); } - Type getType() const { return typeAndKind.getPointer(); } + Type getType() { return typeAndKind.getPointer(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns @@ -60,26 +60,23 @@ public: } /// Return the function that this Value is defined in. - Function *getFunction() const; + Function *getFunction(); /// If this value is the result of an operation, return the instruction /// that defines it. Instruction *getDefiningInst(); - const Instruction *getDefiningInst() const { - return const_cast(this)->getDefiningInst(); - } using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; - inline use_iterator use_begin() const; - inline use_iterator use_end() const; + inline use_iterator use_begin(); + inline use_iterator use_end(); /// Returns a range of all uses, which is useful for iterating over all uses. 
- inline use_range getUses() const; + inline use_range getUses(); - void print(raw_ostream &os) const; - void dump() const; + void print(raw_ostream &os); + void dump(); protected: Value(Kind kind, Type type) : typeAndKind(type, kind) {} @@ -88,21 +85,19 @@ private: const llvm::PointerIntPair typeAndKind; }; -inline raw_ostream &operator<<(raw_ostream &os, const Value &value) { +inline raw_ostream &operator<<(raw_ostream &os, Value &value) { value.print(os); return os; } // Utility functions for iterating through Value uses. -inline auto Value::use_begin() const -> use_iterator { +inline auto Value::use_begin() -> use_iterator { return use_iterator((InstOperand *)getFirstUse()); } -inline auto Value::use_end() const -> use_iterator { - return use_iterator(nullptr); -} +inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); } -inline auto Value::getUses() const -> llvm::iterator_range { +inline auto Value::getUses() -> llvm::iterator_range { return {use_begin(), use_end()}; } @@ -110,19 +105,19 @@ inline auto Value::getUses() const -> llvm::iterator_range { class BlockArgument : public Value { public: static bool classof(const Value *value) { - return value->getKind() == Kind::BlockArgument; + return const_cast(value)->getKind() == Kind::BlockArgument; } /// Return the function that this argument is defined in. - Function *getFunction() const; + Function *getFunction(); - Block *getOwner() const { return owner; } + Block *getOwner() { return owner; } /// Returns the number of this argument. - unsigned getArgNumber() const; + unsigned getArgNumber(); /// Returns if the current argument is a function argument. - bool isFunctionArgument() const; + bool isFunctionArgument(); private: friend class Block; // For access to private constructor. @@ -142,14 +137,13 @@ public: : Value(Value::Kind::InstResult, type), owner(owner) {} static bool classof(const Value *value) { - return value->getKind() == Kind::InstResult; + return const_cast(value)->getKind() == Kind::InstResult; } Instruction *getOwner() { return owner; } - const Instruction *getOwner() const { return owner; } /// Returns the number of this result. - unsigned getResultNumber() const; + unsigned getResultNumber(); private: /// The owner of this operand. diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 827335bf5a9..1cd9392b984 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -35,7 +35,7 @@ class Builder; namespace detail { /// A custom binary operation printer that omits the "std." prefix from the /// operation names. -void printStandardBinaryOp(const Instruction *op, OpAsmPrinter *p); +void printStandardBinaryOp(Instruction *op, OpAsmPrinter *p); } // namespace detail class StandardOpsDialect : public Dialect { @@ -69,9 +69,7 @@ class AllocOp : public Op { public: /// The result of an alloc is always a MemRefType. - MemRefType getType() const { - return getResult()->getType().cast(); - } + MemRefType getType() { return getResult()->getType().cast(); } static StringRef getOperationName() { return "std.alloc"; } @@ -86,7 +84,7 @@ public: private: friend class Instruction; - explicit AllocOp(const Instruction *state) : Op(state) {} + explicit AllocOp(Instruction *state) : Op(state) {} }; /// The "br" operation represents a branch instruction in a function. @@ -113,7 +111,6 @@ public: /// Return the block this branch jumps to. 
Block *getDest(); - Block *getDest() const { return const_cast(this)->getDest(); } void setDest(Block *block); /// Erase the operand at 'index' from the operand list. @@ -121,7 +118,7 @@ public: private: friend class Instruction; - explicit BranchOp(const Instruction *state) : Op(state) {} + explicit BranchOp(Instruction *state) : Op(state) {} }; /// The "call" operation represents a direct call to a function. The operands @@ -138,21 +135,15 @@ public: static void build(Builder *builder, OperationState *result, Function *callee, ArrayRef operands); - Function *getCallee() const { + Function *getCallee() { return getAttrOfType("callee").getValue(); } /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } llvm::iterator_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } - const_operand_iterator arg_operand_begin() const { return operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - operand_iterator arg_operand_begin() { return operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } @@ -163,7 +154,7 @@ public: protected: friend class Instruction; - explicit CallOp(const Instruction *state) : Op(state) {} + explicit CallOp(Instruction *state) : Op(state) {} }; /// The "call_indirect" operation represents an indirect call to a value of @@ -182,20 +173,13 @@ public: static void build(Builder *builder, OperationState *result, Value *callee, ArrayRef operands); - const Value *getCallee() const { return getOperand(0); } Value *getCallee() { return getOperand(0); } /// Get the argument operands to the called function. - llvm::iterator_range getArgOperands() const { - return {arg_operand_begin(), arg_operand_end()}; - } llvm::iterator_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } - const_operand_iterator arg_operand_begin() const { return ++operand_begin(); } - const_operand_iterator arg_operand_end() const { return operand_end(); } - operand_iterator arg_operand_begin() { return ++operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } @@ -208,7 +192,7 @@ public: protected: friend class Instruction; - explicit CallIndirectOp(const Instruction *state) : Op(state) {} + explicit CallIndirectOp(Instruction *state) : Op(state) {} }; /// The predicate indicates the type of the comparison to perform: @@ -274,7 +258,7 @@ public: private: friend class Instruction; - explicit CmpIOp(const Instruction *state) : Op(state) {} + explicit CmpIOp(Instruction *state) : Op(state) {} }; /// The "cond_br" operation represents a conditional branch instruction in a @@ -314,29 +298,20 @@ public: MLIRContext *context); // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } + Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. Block *getTrueDest(); - Block *getTrueDest() const { - return const_cast(this)->getTrueDest(); - } /// Return the destination if the condition is false. Block *getFalseDest(); - Block *getFalseDest() const { - return const_cast(this)->getFalseDest(); - } // Accessors for operands to the 'true' destination. 
Value *getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } - const Value *getTrueOperand(unsigned idx) const { - return const_cast(this)->getTrueOperand(idx); - } + void setTrueOperand(unsigned idx, Value *value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); @@ -352,16 +327,6 @@ public: return {true_operand_begin(), true_operand_end()}; } - const_operand_iterator true_operand_begin() const { - return operand_begin() + getTrueDestOperandIndex(); - } - const_operand_iterator true_operand_end() const { - return true_operand_begin() + getNumTrueOperands(); - } - llvm::iterator_range getTrueOperands() const { - return {true_operand_begin(), true_operand_end()}; - } - unsigned getNumTrueOperands() const; /// Erase the operand at 'index' from the true operand list. @@ -372,7 +337,7 @@ public: assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } - const Value *getFalseOperand(unsigned idx) const { + Value *getFalseOperand(unsigned idx) const { return const_cast(this)->getFalseOperand(idx); } void setFalseOperand(unsigned idx, Value *value) { @@ -388,16 +353,6 @@ public: return {false_operand_begin(), false_operand_end()}; } - const_operand_iterator false_operand_begin() const { - return true_operand_end(); - } - const_operand_iterator false_operand_end() const { - return false_operand_begin() + getNumFalseOperands(); - } - llvm::iterator_range getFalseOperands() const { - return {false_operand_begin(), false_operand_end()}; - } - unsigned getNumFalseOperands() const; /// Erase the operand at 'index' from the false operand list. @@ -413,7 +368,7 @@ private: } friend class Instruction; - explicit CondBranchOp(const Instruction *state) : Op(state) {} + explicit CondBranchOp(Instruction *state) : Op(state) {} }; /// The "constant" operation requires a single attribute named "value". 
@@ -445,7 +400,7 @@ public: protected: friend class Instruction; - explicit ConstantOp(const Instruction *state) : Op(state) {} + explicit ConstantOp(Instruction *state) : Op(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -467,7 +422,7 @@ public: private: friend class Instruction; - explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantFloatOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -494,7 +449,7 @@ public: private: friend class Instruction; - explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantIntOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -515,7 +470,7 @@ public: private: friend class Instruction; - explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {} + explicit ConstantIndexOp(Instruction *state) : ConstantOp(state) {} }; /// The "dealloc" operation frees the region of memory referenced by a memref @@ -531,8 +486,7 @@ private: class DeallocOp : public Op { public: - Value *getMemRef() { return getOperand(); } - const Value *getMemRef() const { return getOperand(); } + Value *getMemRef() const { return getOperand(); } void setMemRef(Value *value) { setOperand(value); } static StringRef getOperationName() { return "std.dealloc"; } @@ -547,7 +501,7 @@ public: private: friend class Instruction; - explicit DeallocOp(const Instruction *state) : Op(state) {} + explicit DeallocOp(Instruction *state) : Op(state) {} }; /// The "dim" operation takes a memref or tensor operand and returns an @@ -578,7 +532,7 @@ public: private: friend class Instruction; - explicit DimOp(const Instruction *state) : Op(state) {} + explicit DimOp(Instruction *state) : Op(state) {} }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a @@ -629,22 +583,19 @@ public: Value *elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. - const Value *getSrcMemRef() const { return getOperand(0); } + Value *getSrcMemRef() const { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() const { return getSrcMemRef()->getType().cast().getRank(); } // Returns the source memerf indices for this DMA operation. - llvm::iterator_range - getSrcIndices() const { + llvm::iterator_range getSrcIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank()}; } // Returns the destination MemRefType for this DMA operations. - const Value *getDstMemRef() const { - return getOperand(1 + getSrcMemRefRank()); - } + Value *getDstMemRef() const { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() const { return getDstMemRef()->getType().cast().getRank(); @@ -657,20 +608,19 @@ public: } // Returns the destination memref indices for this DMA operation. - llvm::iterator_range - getDstIndices() const { + llvm::iterator_range getDstIndices() { return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank()}; } // Returns the number of elements being transferred by this DMA operation. 
- const Value *getNumElements() const { + Value *getNumElements() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. - const Value *getTagMemRef() const { + Value *getTagMemRef() const { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. @@ -679,8 +629,7 @@ public: } // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { + llvm::iterator_range getTagIndices() const { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; return {getInstruction()->operand_begin() + tagIndexStartPos, @@ -725,22 +674,16 @@ public: return nullptr; return getOperand(getNumOperands() - 1 - 1); } - const Value *getStride() const { - return const_cast(this)->getStride(); - } Value *getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); } - const Value *getNumElementsPerStride() const { - return const_cast(this)->getNumElementsPerStride(); - } protected: friend class Instruction; - explicit DmaStartOp(const Instruction *state) : Op(state) {} + explicit DmaStartOp(Instruction *state) : Op(state) {} }; // DmaWaitOp blocks until the completion of a DMA operation associated with the @@ -765,12 +708,10 @@ public: static StringRef getOperationName() { return "std.dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. - const Value *getTagMemRef() const { return getOperand(0); } - Value *getTagMemRef() { return getOperand(0); } + Value *getTagMemRef() const { return getOperand(0); } // Returns the tag memref index for this DMA operation. - llvm::iterator_range - getTagIndices() const { + llvm::iterator_range getTagIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; } @@ -781,9 +722,7 @@ public: } // Returns the number of elements transferred in the associated DMA operation. - const Value *getNumElements() const { - return getOperand(1 + getTagMemRefRank()); - } + Value *getNumElements() const { return getOperand(1 + getTagMemRefRank()); } static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -792,7 +731,7 @@ public: protected: friend class Instruction; - explicit DmaWaitOp(const Instruction *state) : Op(state) {} + explicit DmaWaitOp(Instruction *state) : Op(state) {} }; /// The "extract_element" op reads a tensor or vector and returns one element @@ -813,19 +752,13 @@ public: static void build(Builder *builder, OperationState *result, Value *aggregate, ArrayRef indices = {}); - Value *getAggregate() { return getOperand(0); } - const Value *getAggregate() const { return getOperand(0); } + Value *getAggregate() const { return getOperand(0); } llvm::iterator_range getIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.extract_element"; } // Hooks to customize behavior of this op. @@ -836,7 +769,7 @@ public: private: friend class Instruction; - explicit ExtractElementOp(const Instruction *state) : Op(state) {} + explicit ExtractElementOp(Instruction *state) : Op(state) {} }; /// The "load" op reads an element from a memref specified by an index list. 
The @@ -854,8 +787,7 @@ public: static void build(Builder *builder, OperationState *result, Value *memref, ArrayRef indices = {}); - Value *getMemRef() { return getOperand(0); } - const Value *getMemRef() const { return getOperand(0); } + Value *getMemRef() const { return getOperand(0); } void setMemRef(Value *value) { setOperand(0, value); } MemRefType getMemRefType() const { return getMemRef()->getType().cast(); @@ -866,11 +798,6 @@ public: getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.load"; } bool verify(); @@ -881,7 +808,7 @@ public: private: friend class Instruction; - explicit LoadOp(const Instruction *state) : Op(state) {} + explicit LoadOp(Instruction *state) : Op(state) {} }; /// The "memref_cast" operation converts a memref from one type to an equivalent @@ -914,7 +841,7 @@ public: private: friend class Instruction; - explicit MemRefCastOp(const Instruction *state) : CastOp(state) {} + explicit MemRefCastOp(Instruction *state) : CastOp(state) {} }; /// The "return" operation represents a return instruction within a function. @@ -941,7 +868,7 @@ public: private: friend class Instruction; - explicit ReturnOp(const Instruction *state) : Op(state) {} + explicit ReturnOp(Instruction *state) : Op(state) {} }; /// The "select" operation chooses one value based on a binary condition @@ -965,18 +892,15 @@ public: void print(OpAsmPrinter *p); bool verify(); - Value *getCondition() { return getOperand(0); } - const Value *getCondition() const { return getOperand(0); } - Value *getTrueValue() { return getOperand(1); } - const Value *getTrueValue() const { return getOperand(1); } - Value *getFalseValue() { return getOperand(2); } - const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() const { return getOperand(0); } + Value *getTrueValue() const { return getOperand(1); } + Value *getFalseValue() const { return getOperand(2); } Value *fold(); private: friend class Instruction; - explicit SelectOp(const Instruction *state) : Op(state) {} + explicit SelectOp(Instruction *state) : Op(state) {} }; /// The "store" op writes an element to a memref specified by an index list. @@ -997,13 +921,11 @@ public: Value *valueToStore, Value *memref, ArrayRef indices = {}); - Value *getValueToStore() { return getOperand(0); } - const Value *getValueToStore() const { return getOperand(0); } + Value *getValueToStore() const { return getOperand(0); } Value *getMemRef() { return getOperand(1); } - const Value *getMemRef() const { return getOperand(1); } void setMemRef(Value *value) { setOperand(1, value); } - MemRefType getMemRefType() const { + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -1012,11 +934,6 @@ public: getInstruction()->operand_end()}; } - llvm::iterator_range getIndices() const { - return {getInstruction()->operand_begin() + 2, - getInstruction()->operand_end()}; - } - static StringRef getOperationName() { return "std.store"; } bool verify(); @@ -1028,7 +945,7 @@ public: private: friend class Instruction; - explicit StoreOp(const Instruction *state) : Op(state) {} + explicit StoreOp(Instruction *state) : Op(state) {} }; /// The "tensor_cast" operation converts a tensor from one type to an equivalent @@ -1046,9 +963,7 @@ public: static StringRef getOperationName() { return "std.tensor_cast"; } /// The result of a tensor_cast is always a tensor. 
- TensorType getType() const { - return getResult()->getType().cast(); - } + TensorType getType() { return getResult()->getType().cast(); } void print(OpAsmPrinter *p); @@ -1056,13 +971,13 @@ public: private: friend class Instruction; - explicit TensorCastOp(const Instruction *state) : CastOp(state) {} + explicit TensorCastOp(Instruction *state) : CastOp(state) {} }; /// Prints dimension and symbol list. -void printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p); +void printDimAndSymbolList(Instruction::operand_iterator begin, + Instruction::operand_iterator end, unsigned numDims, + OpAsmPrinter *p); /// Parses dimension and symbol list and returns true if parsing failed. bool parseDimAndSymbolList(OpAsmParser *parser, diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index 286842338d8..bb9fb8c5b66 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -102,22 +102,18 @@ public: VectorType vectorType, Value *srcMemRef, ArrayRef srcIndices, AffineMap permutationMap, Optional paddingValue = None); - VectorType getResultType() const { + VectorType getResultType() { return getResult()->getType().cast(); } Value *getVector() { return getResult(); } - const Value *getVector() const { return getResult(); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } - const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } - VectorType getVectorType() const { return getResultType(); } - MemRefType getMemRefType() const { + VectorType getVectorType() { return getResultType(); } + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; Optional getPaddingValue(); - Optional getPaddingValue() const; - AffineMap getPermutationMap() const; + AffineMap getPermutationMap(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -125,7 +121,7 @@ public: private: friend class Instruction; - explicit VectorTransferReadOp(const Instruction *state) : Op(state) {} + explicit VectorTransferReadOp(Instruction *state) : Op(state) {} }; /// VectorTransferWriteOp performs a blocking write from a super-vector to @@ -172,18 +168,15 @@ public: Value *dstMemRef, ArrayRef dstIndices, AffineMap permutationMap); Value *getVector() { return getOperand(Offsets::VectorOffset); } - const Value *getVector() const { return getOperand(Offsets::VectorOffset); } - VectorType getVectorType() const { + VectorType getVectorType() { return getVector()->getType().cast(); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); } - const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); } - MemRefType getMemRefType() const { + MemRefType getMemRefType() { return getMemRef()->getType().cast(); } llvm::iterator_range getIndices(); - llvm::iterator_range getIndices() const; - AffineMap getPermutationMap() const; + AffineMap getPermutationMap(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); @@ -191,7 +184,7 @@ public: private: friend class Instruction; - explicit VectorTransferWriteOp(const Instruction *state) : Op(state) {} + explicit VectorTransferWriteOp(Instruction *state) : Op(state) {} }; /// VectorTypeCastOp performs a conversion from a memref with scalar element to @@ -215,7 +208,7 @@ 
public: private: friend class Instruction; - explicit VectorTypeCastOp(const Instruction *state) : Op(state) {} + explicit VectorTypeCastOp(Instruction *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 78968ae2a7d..0fc076d1a65 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -69,12 +69,12 @@ class Function; // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). // TODO(bondhugula): allow extraIndices to be added at any position. -bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, +bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - const Instruction *domInstFilter = nullptr, - const Instruction *postDomInstFilter = nullptr); + Instruction *domInstFilter = nullptr, + Instruction *postDomInstFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 6fe6f1d63a7..9cb74187cc1 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -42,7 +42,7 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context) /// A utility function to check if a value is defined at the top level of a /// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(const Value *value) { +bool mlir::isTopLevelSymbol(Value *value) { if (auto *arg = dyn_cast(value)) return arg->getOwner()->getParent()->getContainingFunction(); return value->getDefiningInst()->getParentInst() == nullptr; @@ -51,7 +51,7 @@ bool mlir::isTopLevelSymbol(const Value *value) { // Value can be used as a dimension id if it is valid as a symbol, or // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. -bool mlir::isValidDim(const Value *value) { +bool mlir::isValidDim(Value *value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -76,7 +76,7 @@ bool mlir::isValidDim(const Value *value) { // Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. -bool mlir::isValidSymbol(const Value *value) { +bool mlir::isValidSymbol(Value *value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -105,10 +105,9 @@ bool mlir::isValidSymbol(const Value *value) { /// was an invalid operand. An operation is provided to emit any necessary /// errors. template -static bool -verifyDimAndSymbolIdentifiers(const OpTy &op, - Instruction::const_operand_range operands, - unsigned numDims) { +static bool verifyDimAndSymbolIdentifiers(OpTy &op, + Instruction::operand_range operands, + unsigned numDims) { unsigned opIt = 0; for (auto *operand : operands) { if (opIt++ < numDims) { @@ -189,16 +188,16 @@ bool AffineApplyOp::verify() { // The result of the affine apply operation can be used as a dimension id if it // is a CFG value or if it is an Value, and all the operands are valid // dimension ids. 
-bool AffineApplyOp::isValidDim() const { +bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), - [](const Value *op) { return mlir::isValidDim(op); }); + [](Value *op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a symbol if it is // a CFG value or if it is an Value, and all the operands are symbols. -bool AffineApplyOp::isValidSymbol() const { +bool AffineApplyOp::isValidSymbol() { return llvm::all_of(getOperands(), - [](const Value *op) { return mlir::isValidSymbol(op); }); + [](Value *op) { return mlir::isValidSymbol(op); }); } Attribute AffineApplyOp::constantFold(ArrayRef operands, @@ -1069,13 +1068,13 @@ Block *AffineForOp::createBody() { return body; } -const AffineBound AffineForOp::getLowerBound() const { +AffineBound AffineForOp::getLowerBound() { auto lbMap = getLowerBoundMap(); return AffineBound(OpPointer(*this), 0, lbMap.getNumInputs(), lbMap); } -const AffineBound AffineForOp::getUpperBound() const { +AffineBound AffineForOp::getUpperBound() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); return AffineBound(OpPointer(*this), lbMap.getNumInputs(), @@ -1124,19 +1123,19 @@ void AffineForOp::setUpperBoundMap(AffineMap map) { setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); } -bool AffineForOp::hasConstantLowerBound() const { +bool AffineForOp::hasConstantLowerBound() { return getLowerBoundMap().isSingleConstant(); } -bool AffineForOp::hasConstantUpperBound() const { +bool AffineForOp::hasConstantUpperBound() { return getUpperBoundMap().isSingleConstant(); } -int64_t AffineForOp::getConstantLowerBound() const { +int64_t AffineForOp::getConstantLowerBound() { return getLowerBoundMap().getSingleConstantResult(); } -int64_t AffineForOp::getConstantUpperBound() const { +int64_t AffineForOp::getConstantUpperBound() { return getUpperBoundMap().getSingleConstantResult(); } @@ -1154,19 +1153,11 @@ AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; } -AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; } -AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool AffineForOp::matchingBoundOperandList() const { +bool AffineForOp::matchingBoundOperandList() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); if (lbMap.getNumDims() != ubMap.getNumDims() || @@ -1186,14 +1177,14 @@ bool AffineForOp::matchingBoundOperandList() const { Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } /// Returns if the provided value is the induction variable of a AffineForOp. -bool mlir::isForInductionVar(const Value *val) { +bool mlir::isForInductionVar(Value *val) { return getForInductionVarOwner(val) != nullptr; } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. 
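/// A usage sketch (editorial illustration, not part of this patch), relying
/// only on accessors that appear elsewhere in this change:
///   auto forOp = getForInductionVarOwner(iv);        // `iv` is a Value *.
///   if (forOp != nullptr && forOp->hasConstantLowerBound())
///     (void)forOp->getConstantLowerBound();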
-OpPointer mlir::getForInductionVarOwner(const Value *val) { - const BlockArgument *ivArg = dyn_cast(val); +OpPointer mlir::getForInductionVarOwner(Value *val) { + auto *ivArg = dyn_cast(val); if (!ivArg || !ivArg->getOwner()) return OpPointer(); auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); @@ -1320,7 +1311,7 @@ void AffineIfOp::print(OpAsmPrinter *p) { /*elidedAttrs=*/getConditionAttrName()); } -IntegerSet AffineIfOp::getIntegerSet() const { +IntegerSet AffineIfOp::getIntegerSet() { return getAttrOfType(getConditionAttrName()).getValue(); } void AffineIfOp::setIntegerSet(IntegerSet newSet) { diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index c24a7688a4d..0b7d9f831e4 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -118,7 +118,7 @@ LogicalResult mlir::getIndexSet(MutableArrayRef> forOps, // 'indexSet' correspond to the loops surounding 'inst' from outermost to // innermost. // TODO(andydavis) Add support to handle IfInsts surrounding 'inst'. -static LogicalResult getInstIndexSet(const Instruction *inst, +static LogicalResult getInstIndexSet(Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. @@ -147,25 +147,25 @@ static LogicalResult getInstIndexSet(const Instruction *inst, // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". class ValuePositionMap { public: - void addSrcValue(const Value *value) { + void addSrcValue(Value *value) { if (addValueAt(value, &srcDimPosMap, numSrcDims)) ++numSrcDims; } - void addDstValue(const Value *value) { + void addDstValue(Value *value) { if (addValueAt(value, &dstDimPosMap, numDstDims)) ++numDstDims; } - void addSymbolValue(const Value *value) { + void addSymbolValue(Value *value) { if (addValueAt(value, &symbolPosMap, numSymbols)) ++numSymbols; } - unsigned getSrcDimOrSymPos(const Value *value) const { + unsigned getSrcDimOrSymPos(Value *value) const { return getDimOrSymPos(value, srcDimPosMap, 0); } - unsigned getDstDimOrSymPos(const Value *value) const { + unsigned getDstDimOrSymPos(Value *value) const { return getDimOrSymPos(value, dstDimPosMap, numSrcDims); } - unsigned getSymPos(const Value *value) const { + unsigned getSymPos(Value *value) const { auto it = symbolPosMap.find(value); assert(it != symbolPosMap.end()); return numSrcDims + numDstDims + it->second; @@ -177,7 +177,7 @@ public: unsigned getNumSymbols() const { return numSymbols; } private: - bool addValueAt(const Value *value, DenseMap *posMap, + bool addValueAt(Value *value, DenseMap *posMap, unsigned position) { auto it = posMap->find(value); if (it == posMap->end()) { @@ -186,8 +186,8 @@ private: } return false; } - unsigned getDimOrSymPos(const Value *value, - const DenseMap &dimPosMap, + unsigned getDimOrSymPos(Value *value, + const DenseMap &dimPosMap, unsigned dimPosOffset) const { auto it = dimPosMap.find(value); if (it != dimPosMap.end()) { @@ -201,9 +201,9 @@ private: unsigned numSrcDims = 0; unsigned numDstDims = 0; unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; }; // Builds a map from Value to identifier position in a new merged identifier @@ -451,7 +451,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Add equality constraints for any operands that are defined by 
constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { + auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (isForInductionVar(operands[i])) continue; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index bc4c751dd77..3de26589b12 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -677,7 +677,7 @@ LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) { } // Turn a dimension into a symbol. -static void turnDimIntoSymbol(FlatAffineConstraints *cst, const Value &id) { +static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) { unsigned pos; if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { swapId(cst, pos, cst->getNumDimIds() - 1); @@ -686,7 +686,7 @@ static void turnDimIntoSymbol(FlatAffineConstraints *cst, const Value &id) { } // Turn a symbol into a dimension. -static void turnSymbolIntoDim(FlatAffineConstraints *cst, const Value &id) { +static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) { unsigned pos; if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && pos < cst->getNumDimAndSymbolIds()) { @@ -1669,7 +1669,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, if (localVarCst.getNumLocalIds() > 0) { // Set values for localVarCst. localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); - for (const auto *operand : operands) { + for (auto *operand : operands) { unsigned pos; if (findId(*operand, &pos)) { if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { @@ -1689,7 +1689,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // this here since the constraint system changes after a bound is added. SmallVector positions; unsigned numOperands = operands.size(); - for (const auto *operand : operands) { + for (auto *operand : operands) { unsigned pos; if (!findId(*operand, &pos)) assert(0 && "expected to be found"); @@ -1859,7 +1859,7 @@ void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, addInequality(bound); } -bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { +bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const { unsigned i = 0; for (const auto &mayBeId : ids) { if (mayBeId.hasValue() && mayBeId.getValue() == &id) { @@ -1871,7 +1871,7 @@ bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { return false; } -bool FlatAffineConstraints::containsId(const Value &id) const { +bool FlatAffineConstraints::containsId(Value &id) const { return llvm::any_of(ids, [&](const Optional &mayBeId) { return mayBeId.hasValue() && mayBeId.getValue() == &id; }); @@ -1896,7 +1896,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { /// Sets the specified identifer to a constant value; asserts if the id is not /// found. -void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) { +void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) { unsigned pos; if (!findId(id, &pos)) // This is a pre-condition for this method. 
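As a minimal sketch of how a call site adapts to the non-const FlatAffineConstraints signatures above; `pinIvToConstant` is a hypothetical helper named here for illustration only and is not part of this patch.

// Illustration only: fix a loop IV to a known constant value in a constraint
// system, using findId(Value &, unsigned *) and setIdToConstant(Value &, int64_t)
// as changed above.
static void pinIvToConstant(FlatAffineConstraints &cst, Value &iv, int64_t val) {
  unsigned pos;
  if (cst.findId(iv, &pos))
    cst.setIdToConstant(iv, val);
}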
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 50fb2586f7d..84d0782f7d6 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -101,8 +101,7 @@ template class mlir::detail::DominanceInfoBase; //===----------------------------------------------------------------------===// /// Return true if instruction A properly dominates instruction B. -bool DominanceInfo::properlyDominates(const Instruction *a, - const Instruction *b) { +bool DominanceInfo::properlyDominates(Instruction *a, Instruction *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, then check if b is before a in the block. @@ -122,7 +121,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a, } /// Return true if value A properly dominates instruction B. -bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { +bool DominanceInfo::properlyDominates(Value *a, Instruction *b) { if (auto *aInst = a->getDefiningInst()) return properlyDominates(aInst, b); @@ -136,8 +135,7 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { //===----------------------------------------------------------------------===// /// Returns true if statement 'a' properly postdominates statement b. -bool PostDominanceInfo::properlyPostDominates(const Instruction *a, - const Instruction *b) { +bool PostDominanceInfo::properlyPostDominates(Instruction *a, Instruction *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, check if b is before a in the block. @@ -145,7 +143,7 @@ bool PostDominanceInfo::properlyPostDominates(const Instruction *a, return b->isBeforeInBlock(a); // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b)) + if (auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b)) // Since we already know that aBlock != bBlock, here bAncestor != b. // a and bAncestor are in the same block; check if 'a' postdominates // bAncestor. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 521dc5151e7..28b0f75909c 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -179,7 +179,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(OpPointer forOp) { return gcd.getValue(); } -bool mlir::isAccessInvariant(const Value &iv, const Value &index) { +bool mlir::isAccessInvariant(Value &iv, Value &index) { assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; @@ -203,10 +203,9 @@ bool mlir::isAccessInvariant(const Value &iv, const Value &index) { return !(AffineValueMap(composeOp).isFunctionOf(0, const_cast(&iv))); } -llvm::DenseSet -mlir::getInvariantAccesses(const Value &iv, - llvm::ArrayRef indices) { - llvm::DenseSet res; +llvm::DenseSet +mlir::getInvariantAccesses(Value &iv, llvm::ArrayRef indices) { + llvm::DenseSet res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { auto *val = indices[idx]; if (isAccessInvariant(iv, *val)) { @@ -236,29 +235,29 @@ mlir::getInvariantAccesses(const Value &iv, /// // TODO(ntv): check strides. 
template -static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp, +static bool isContiguousAccess(Value &iv, OpPointer memoryOp, unsigned fastestVaryingDim) { static_assert(std::is_same::value || std::is_same::value, "Must be called on either const LoadOp & or const StoreOp &"); - auto memRefType = memoryOp.getMemRefType(); + auto memRefType = memoryOp->getMemRefType(); if (fastestVaryingDim >= memRefType.getRank()) { - memoryOp.emitError("fastest varying dim out of bounds"); + memoryOp->emitError("fastest varying dim out of bounds"); return false; } auto layoutMap = memRefType.getAffineMaps(); // TODO(ntv): remove dependence on Builder once we support non-identity // layout map. - Builder b(memoryOp.getInstruction()->getContext()); + Builder b(memoryOp->getInstruction()->getContext()); if (layoutMap.size() >= 2 || (layoutMap.size() == 1 && !(layoutMap[0] == b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) { - return memoryOp.emitError("NYI: non-trivial layoutMap"), false; + return memoryOp->emitError("NYI: non-trivial layoutMap"), false; } - auto indices = memoryOp.getIndices(); + auto indices = memoryOp->getIndices(); auto numIndices = llvm::size(indices); unsigned d = 0; for (auto index : indices) { @@ -278,12 +277,12 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa(); } -static bool isVectorTransferReadOrWrite(const Instruction &inst) { +static bool isVectorTransferReadOrWrite(Instruction &inst) { return inst.isa() || inst.isa(); } using VectorizableInstFun = - std::function, const Instruction &)>; + std::function, Instruction &)>; static bool isVectorizableLoopWithCond(OpPointer loop, VectorizableInstFun isVectorizableInst) { @@ -302,7 +301,7 @@ static bool isVectorizableLoopWithCond(OpPointer loop, } // No vectorization across unknown regions. - auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto regions = matcher::Op([](Instruction &inst) -> bool { return inst.getNumRegions() != 0 && !(inst.isa() || inst.isa()); }); @@ -342,22 +341,22 @@ static bool isVectorizableLoopWithCond(OpPointer loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( OpPointer loop, unsigned fastestVaryingDim) { - VectorizableInstFun fun([fastestVaryingDim](OpPointer loop, - const Instruction &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); - return load ? isContiguousAccess(*loop->getInductionVar(), *load, - fastestVaryingDim) - : isContiguousAccess(*loop->getInductionVar(), *store, - fastestVaryingDim); - }); + VectorizableInstFun fun( + [fastestVaryingDim](OpPointer loop, Instruction &op) { + auto load = op.dyn_cast(); + auto store = op.dyn_cast(); + return load ? isContiguousAccess(*loop->getInductionVar(), load, + fastestVaryingDim) + : isContiguousAccess(*loop->getInductionVar(), store, + fastestVaryingDim); + }); return isVectorizableLoopWithCond(loop, fun); } bool mlir::isVectorizableLoop(OpPointer loop) { VectorizableInstFun fun( // TODO: implement me - [](OpPointer loop, const Instruction &op) { return true; }); + [](OpPointer loop, Instruction &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } @@ -373,9 +372,9 @@ bool mlir::isInstwiseShiftValid(OpPointer forOp, // Work backwards over the body of the block so that the shift of a use's // ancestor instruction in the block gets recorded before it's looked up. 
- DenseMap forBodyShift; + DenseMap forBodyShift; for (auto it : llvm::enumerate(llvm::reverse(forBody->getInstructions()))) { - const auto &inst = it.value(); + auto &inst = it.value(); // Get the index of the current instruction, note that we are iterating in // reverse so we need to fix it up. @@ -387,7 +386,7 @@ bool mlir::isInstwiseShiftValid(OpPointer forOp, // Validate the results of this instruction if it were to be shifted. for (unsigned i = 0, e = inst.getNumResults(); i < e; ++i) { - const Value *result = inst.getResult(i); + Value *result = inst.getResult(i); for (const InstOperand &use : result->getUses()) { // If an ancestor instruction doesn't lie in the block of forOp, // there is no shift to check. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 3e55291972b..83b3591ce5c 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -110,13 +110,9 @@ void NestedPattern::matchOne(Instruction *inst, } } -static bool isAffineForOp(const Instruction &inst) { - return inst.isa(); -} +static bool isAffineForOp(Instruction &inst) { return inst.isa(); } -static bool isAffineIfOp(const Instruction &inst) { - return inst.isa(); -} +static bool isAffineIfOp(Instruction &inst) { return inst.isa(); } namespace mlir { namespace matcher { @@ -129,7 +125,7 @@ NestedPattern If(NestedPattern child) { return NestedPattern(child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [filter](const Instruction &inst) { + return NestedPattern(child, [filter](Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } @@ -137,7 +133,7 @@ NestedPattern If(ArrayRef nested) { return NestedPattern(nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [filter](const Instruction &inst) { + return NestedPattern(nested, [filter](Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } @@ -146,7 +142,7 @@ NestedPattern For(NestedPattern child) { return NestedPattern(child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [=](const Instruction &inst) { + return NestedPattern(child, [=](Instruction &inst) { return isAffineForOp(inst) && filter(inst); }); } @@ -154,24 +150,24 @@ NestedPattern For(ArrayRef nested) { return NestedPattern(nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [=](const Instruction &inst) { + return NestedPattern(nested, [=](Instruction &inst) { return isAffineForOp(inst) && filter(inst); }); } // TODO(ntv): parallel annotation on loops. -bool isParallelLoop(const Instruction &inst) { +bool isParallelLoop(Instruction &inst) { auto loop = inst.cast(); return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. 
-bool isReductionLoop(const Instruction &inst) { +bool isReductionLoop(Instruction &inst) { auto loop = inst.cast(); return loop || true; // loop->isReduction(); }; -bool isLoadOrStore(const Instruction &inst) { +bool isLoadOrStore(Instruction &inst) { return inst.isa() || inst.isa(); }; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 8918dd03f80..2cd0a83296b 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -39,7 +39,7 @@ using llvm::SmallDenseMap; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. -void mlir::getLoopIVs(const Instruction &inst, +void mlir::getLoopIVs(Instruction &inst, SmallVectorImpl> *loops) { auto *currInst = inst.getParentInst(); OpPointer currAffineForOp; @@ -431,7 +431,7 @@ template LogicalResult mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, // Returns in 'positions' the Block positions of 'inst' in each ancestor // Block from the Block containing instruction, stopping at 'limitBlock'. -static void findInstPosition(const Instruction *inst, Block *limitBlock, +static void findInstPosition(Instruction *inst, Block *limitBlock, SmallVectorImpl *positions) { Block *block = inst->getBlock(); while (block != limitBlock) { @@ -653,8 +653,8 @@ bool MemRefAccess::isStore() const { return opInst->isa(); } /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. -unsigned mlir::getNestingDepth(const Instruction &inst) { - const Instruction *currInst = &inst; +unsigned mlir::getNestingDepth(Instruction &inst) { + Instruction *currInst = &inst; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { if (currInst->isa()) @@ -665,8 +665,7 @@ unsigned mlir::getNestingDepth(const Instruction &inst) { /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', /// where each lists loops from outer-most to inner-most in loop nest. -unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, - const Instruction &B) { +unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { SmallVector, 4> loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 5ca3a829cbd..5df31affe31 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -180,7 +180,7 @@ AffineMap mlir::makePermutationMap( enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnSuperVectors(const Instruction &opInst, +bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. 
vector_transfer_read, diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index d92aaedad17..b72731ed5cb 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -52,7 +52,7 @@ namespace { /// class FuncVerifier { public: - bool failure(const Twine &message, const Instruction &value) { + bool failure(const Twine &message, Instruction &value) { return value.emitError(message); } @@ -108,9 +108,9 @@ public: bool verify(); bool verifyBlock(Block &block, bool isTopLevel); - bool verifyOperation(const Instruction &op); + bool verifyOperation(Instruction &op); bool verifyDominance(Block &block); - bool verifyInstDominance(const Instruction &inst); + bool verifyInstDominance(Instruction &inst); explicit FuncVerifier(Function &fn) : fn(fn), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} @@ -270,12 +270,12 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { } /// Check the invariants of the specified operation. -bool FuncVerifier::verifyOperation(const Instruction &op) { +bool FuncVerifier::verifyOperation(Instruction &op) { if (op.getFunction() != &fn) return failure("operation in the wrong function", op); // Check that operands are non-nil and structurally ok. - for (const auto *operand : op.getOperands()) { + for (auto *operand : op.getOperands()) { if (!operand) return failure("null operand found", op); @@ -322,7 +322,7 @@ bool FuncVerifier::verifyDominance(Block &block) { return false; } -bool FuncVerifier::verifyInstDominance(const Instruction &inst) { +bool FuncVerifier::verifyInstDominance(Instruction &inst) { // Check that operands properly dominate this use. for (unsigned operandNo = 0, e = inst.getNumOperands(); operandNo != e; ++operandNo) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index af172fcb542..685a7a07a69 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -184,7 +184,7 @@ static bool isSameShapedVectorOrTensor(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifyCompatibleOperandBroadcast(const Instruction *op) { +bool OpTrait::impl::verifyCompatibleOperandBroadcast(Instruction *op) { assert(op->getNumOperands() == 2 && "only support broadcast check on two operands"); assert(op->getNumResults() == 1 && diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 202d254aee0..6430796bcc1 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -45,8 +45,8 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::detail; -static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { - const auto *inst = v.getDefiningInst(); +static void printDefininingStatement(llvm::raw_ostream &os, Value &v) { + auto *inst = v.getDefiningInst(); if (inst) { inst->print(os); return; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f4b49497cb2..b62a279fa29 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -138,7 +138,7 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } // Visit functions. - void visitInstruction(const Instruction *inst); + void visitInstruction(Instruction *inst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -189,7 +189,7 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitInstruction(const Instruction *inst) { +void ModuleState::visitInstruction(Instruction *inst) { // Visit all the types used in the operation. 
for (auto *operand : inst->getOperands()) visitType(operand->getType()); @@ -1060,11 +1060,11 @@ public: void printFunctionSignature(); // Methods to print instructions. - void print(const Instruction *inst); + void print(Instruction *inst); void print(Block *block, bool printBlockArgs = true); - void printOperation(const Instruction *op); - void printGenericOp(const Instruction *op); + void printOperation(Instruction *op); + void printGenericOp(Instruction *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -1085,7 +1085,7 @@ public: void printFunctionReference(Function *func) { return ModulePrinter::printFunctionReference(func); } - void printOperand(const Value *value) { printValueID(value); } + void printOperand(Value *value) { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) { @@ -1107,8 +1107,7 @@ public: return it != blockIDs.end() ? it->second : ~0U; } - void printSuccessorAndUseList(const Instruction *term, - unsigned index) override; + void printSuccessorAndUseList(Instruction *term, unsigned index) override; /// Print a region. void printRegion(Region &blocks, bool printEntryBlockArgs) override { @@ -1127,17 +1126,17 @@ public: const static unsigned indentWidth = 2; protected: - void numberValueID(const Value *value); + void numberValueID(Value *value); void numberValuesInBlock(Block &block); - void printValueID(const Value *value, bool printResultNo = true) const; + void printValueID(Value *value, bool printResultNo = true) const; private: Function *function; /// This is the value ID for each SSA value in the current function. If this /// returns ~0, then the valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; + DenseMap valueIDs; + DenseMap valueNames; /// This is the block ID for each block in the current function. 
DenseMap blockIDs; @@ -1191,7 +1190,7 @@ void FunctionPrinter::numberValuesInBlock(Block &block) { } } -void FunctionPrinter::numberValueID(const Value *value) { +void FunctionPrinter::numberValueID(Value *value) { assert(!valueIDs.count(value) && "Value numbered multiple times"); SmallString<32> specialNameBuffer; @@ -1389,14 +1388,13 @@ void FunctionPrinter::print(Block *block, bool printBlockArgs) { currentIndent -= indentWidth; } -void FunctionPrinter::print(const Instruction *inst) { +void FunctionPrinter::print(Instruction *inst) { os.indent(currentIndent); printOperation(inst); printTrailingLocation(inst->getLoc()); } -void FunctionPrinter::printValueID(const Value *value, - bool printResultNo) const { +void FunctionPrinter::printValueID(Value *value, bool printResultNo) const { int resultNo = -1; auto lookupValue = value; @@ -1434,7 +1432,7 @@ void FunctionPrinter::printValueID(const Value *value, os << '#' << resultNo; } -void FunctionPrinter::printOperation(const Instruction *op) { +void FunctionPrinter::printOperation(Instruction *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1454,7 +1452,7 @@ void FunctionPrinter::printOperation(const Instruction *op) { printGenericOp(op); } -void FunctionPrinter::printGenericOp(const Instruction *op) { +void FunctionPrinter::printGenericOp(Instruction *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1465,11 +1463,10 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( + SmallVector properOperands( op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - interleaveComma(properOperands, - [&](const Value *value) { printValueID(value); }); + interleaveComma(properOperands, [&](Value *value) { printValueID(value); }); os << ')'; @@ -1490,7 +1487,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { // Print the type signature of the operation. 
os << " : ("; interleaveComma(properOperands, - [&](const Value *value) { printType(value->getType()); }); + [&](Value *value) { printType(value->getType()); }); os << ") -> "; if (op->getNumResults() == 1 && @@ -1499,7 +1496,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { } else { os << '('; interleaveComma(op->getResults(), - [&](const Value *result) { printType(result->getType()); }); + [&](Value *result) { printType(result->getType()); }); os << ')'; } @@ -1508,7 +1505,7 @@ void FunctionPrinter::printGenericOp(const Instruction *op) { printRegion(region, /*printEntryBlockArgs=*/true); } -void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, +void FunctionPrinter::printSuccessorAndUseList(Instruction *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -1518,11 +1515,10 @@ void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, os << '('; interleaveComma(succOperands, - [this](const Value *operand) { printValueID(operand); }); + [this](Value *operand) { printValueID(operand); }); os << " : "; - interleaveComma(succOperands, [this](const Value *operand) { - printType(operand->getType()); - }); + interleaveComma(succOperands, + [this](Value *operand) { printType(operand->getType()); }); os << ')'; } @@ -1585,7 +1581,7 @@ void IntegerSet::print(raw_ostream &os) const { ModulePrinter(os, state).printIntegerSet(*this); } -void Value::print(raw_ostream &os) const { +void Value::print(raw_ostream &os) { switch (getKind()) { case Value::Kind::BlockArgument: // TODO: Improve this. @@ -1596,9 +1592,9 @@ void Value::print(raw_ostream &os) const { } } -void Value::dump() const { print(llvm::errs()); } +void Value::dump() { print(llvm::errs()); } -void Instruction::print(raw_ostream &os) const { +void Instruction::print(raw_ostream &os) { auto *function = getFunction(); if (!function) { os << "<>\n"; @@ -1610,7 +1606,7 @@ void Instruction::print(raw_ostream &os) const { FunctionPrinter(function, modulePrinter).print(this); } -void Instruction::dump() const { +void Instruction::dump() { print(llvm::errs()); llvm::errs() << "\n"; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 0470eb5e13b..4782f92c508 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -26,7 +26,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Returns the number of this argument. -unsigned BlockArgument::getArgNumber() const { +unsigned BlockArgument::getArgNumber() { // Arguments are not stored in place, so we have to find it within the list. auto argList = getOwner()->getArguments(); return std::distance(argList.begin(), llvm::find(argList, this)); @@ -78,7 +78,7 @@ void Block::eraseFromFunction() { /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. -Instruction *Block::findAncestorInstInBlock(const Instruction &inst) { +Instruction *Block::findAncestorInstInBlock(Instruction &inst) { // Traverse up the instruction hierarchy starting from the owner of operand to // find the ancestor instruction that resides in the block of 'forInst'. 
auto *currInst = const_cast(&inst); @@ -109,7 +109,7 @@ bool Block::verifyInstOrder() { std::next(instructions.begin()) == instructions.end()) return false; - const Instruction *prev = nullptr; + Instruction *prev = nullptr; for (auto &i : *this) { // The previous instruction must have a smaller order index than the next as // it appears earlier in the list. @@ -306,12 +306,12 @@ void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper, // Clone the block arguments. The user might be deleting arguments to the // block by specifying them in the mapper. If so, we don't add the // argument to the cloned block. - for (const auto *arg : block.getArguments()) + for (auto *arg : block.getArguments()) if (!mapper.contains(arg)) mapper.map(arg, newBlock->addArgument(arg->getType())); // Clone and remap the instructions within this block. - for (const auto &inst : block) + for (auto &inst : block) newBlock->push_back(inst.clone(mapper, context)); dest->push_back(newBlock); diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 4ebf2a798a2..698b4cd1926 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -33,7 +33,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Return the result number of this result. -unsigned InstResult::getResultNumber() const { +unsigned InstResult::getResultNumber() { // Results are always stored consecutively, so use pointer subtraction to // figure out what number this is. return this - &getOwner()->getInstResults()[0]; @@ -44,7 +44,7 @@ unsigned InstResult::getResultNumber() const { //===----------------------------------------------------------------------===// /// Return which operand this is in the operand list. -template <> unsigned InstOperand::getOperandNumber() const { +template <> unsigned InstOperand::getOperandNumber() { return this - &getOwner()->getInstOperands()[0]; } @@ -53,7 +53,7 @@ template <> unsigned InstOperand::getOperandNumber() const { //===----------------------------------------------------------------------===// /// Return which operand this is in the operand list. -template <> unsigned BlockOperand::getOperandNumber() const { +template <> unsigned BlockOperand::getOperandNumber() { return this - &getOwner()->getBlockOperands()[0]; } @@ -287,7 +287,7 @@ void Instruction::destroy() { } /// Return the context this operation is associated with. -MLIRContext *Instruction::getContext() const { +MLIRContext *Instruction::getContext() { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) @@ -300,11 +300,11 @@ MLIRContext *Instruction::getContext() const { return getFunction()->getContext(); } -Instruction *Instruction::getParentInst() const { +Instruction *Instruction::getParentInst() { return block ? block->getContainingInst() : nullptr; } -Function *Instruction::getFunction() const { +Function *Instruction::getFunction() { return block ? block->getFunction() : nullptr; } @@ -339,14 +339,14 @@ void Instruction::walkPostOrder( /// Emit a note about this instruction, reporting up to any diagnostic /// handlers that may be listening. -void Instruction::emitNote(const Twine &message) const { +void Instruction::emitNote(const Twine &message) { getContext()->emitDiagnostic(getLoc(), message, MLIRContext::DiagnosticKind::Note); } /// Emit a warning about this instruction, reporting up to any diagnostic /// handlers that may be listening. 
-void Instruction::emitWarning(const Twine &message) const { +void Instruction::emitWarning(const Twine &message) { getContext()->emitDiagnostic(getLoc(), message, MLIRContext::DiagnosticKind::Warning); } @@ -355,7 +355,7 @@ void Instruction::emitWarning(const Twine &message) const { /// any diagnostic handlers that may be listening. This function always /// returns true. NOTE: This may terminate the containing application, only /// use when the IR is in an inconsistent state. -bool Instruction::emitError(const Twine &message) const { +bool Instruction::emitError(const Twine &message) { return getContext()->emitError(getLoc(), message); } @@ -364,7 +364,7 @@ bool Instruction::emitError(const Twine &message) const { /// of the parent block. /// Note: This function has an average complexity of O(1), but worst case may /// take O(N) where N is the number of instructions within the parent block. -bool Instruction::isBeforeInBlock(const Instruction *other) const { +bool Instruction::isBeforeInBlock(Instruction *other) { assert(block && "Instructions without parent blocks have no order."); assert(other && other->block == block && "Expected other instruction to have the same parent block."); @@ -490,7 +490,7 @@ void Instruction::dropAllReferences() { } /// Return true if there are no users of any results of this operation. -bool Instruction::use_empty() const { +bool Instruction::use_empty() { for (auto *result : getResults()) if (!result->use_empty()) return false; @@ -502,10 +502,6 @@ void Instruction::setSuccessor(Block *block, unsigned index) { getBlockOperands()[index].set(block); } -auto Instruction::getNonSuccessorOperands() const -> const_operand_range { - return {const_operand_iterator(this, 0), - const_operand_iterator(this, getSuccessorOperandIndex(0))}; -} auto Instruction::getNonSuccessorOperands() -> operand_range { return {operand_iterator(this, 0), operand_iterator(this, getSuccessorOperandIndex(0))}; @@ -513,7 +509,7 @@ auto Instruction::getNonSuccessorOperands() -> operand_range { /// Get the index of the first operand of the successor at the provided /// index. -unsigned Instruction::getSuccessorOperandIndex(unsigned index) const { +unsigned Instruction::getSuccessorOperandIndex(unsigned index) { assert(!isKnownNonTerminator() && "only terminators may have successors"); assert(index < getNumSuccessors()); @@ -527,13 +523,6 @@ unsigned Instruction::getSuccessorOperandIndex(unsigned index) const { return getNumOperands() - postSuccessorOpCount; } -auto Instruction::getSuccessorOperands(unsigned index) const - -> const_operand_range { - unsigned succOperandIndex = getSuccessorOperandIndex(index); - return {const_operand_iterator(this, succOperandIndex), - const_operand_iterator(this, succOperandIndex + - getNumSuccessorOperands(index))}; -} auto Instruction::getSuccessorOperands(unsigned index) -> operand_range { unsigned succOperandIndex = getSuccessorOperandIndex(index); return {operand_iterator(this, succOperandIndex), @@ -544,19 +533,16 @@ auto Instruction::getSuccessorOperands(unsigned index) -> operand_range { /// Attempt to constant fold this operation with the specified constant /// operand values. If successful, this fills in the results vector. If not, /// results is unspecified. 
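/// A minimal illustration (editorial sketch, not part of this patch), assuming
/// `operandConstants` already holds one Attribute per operand of `inst`:
///   SmallVector<Attribute, 8> folded;
///   if (succeeded(inst->constantFold(operandConstants, folded)))
///     // use folded[i] in place of inst->getResult(i)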
-LogicalResult -Instruction::constantFold(ArrayRef operands, - SmallVectorImpl &results) const { - auto *inst = const_cast(this); - +LogicalResult Instruction::constantFold(ArrayRef operands, + SmallVectorImpl &results) { if (auto *abstractOp = getAbstractOperation()) { // If we have a registered operation definition matching this one, use it to // try to constant fold the operation. - if (succeeded(abstractOp->constantFoldHook(inst, operands, results))) + if (succeeded(abstractOp->constantFoldHook(this, operands, results))) return success(); // Otherwise, fall back on the dialect hook to handle it. - return abstractOp->dialect.constantFoldHook(inst, operands, results); + return abstractOp->dialect.constantFoldHook(this, operands, results); } // If this operation hasn't been registered or doesn't have abstract @@ -564,7 +550,7 @@ Instruction::constantFold(ArrayRef operands, auto opName = getName().getStringRef(); auto dialectPrefix = opName.split('.').first; if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) - return dialect->constantFoldHook(inst, operands, results); + return dialect->constantFoldHook(this, operands, results); return failure(); } @@ -582,7 +568,7 @@ LogicalResult Instruction::fold(SmallVectorImpl &results) { /// Emit an error with the op name prefixed, like "'dim' op " which is /// convenient for verifiers. -bool Instruction::emitOpError(const Twine &message) const { +bool Instruction::emitOpError(const Twine &message) { return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); } @@ -596,7 +582,7 @@ bool Instruction::emitOpError(const Twine &message) const { /// sub-instructions to the corresponding instruction that is copied, and adds /// those mappings to the map. Instruction *Instruction::clone(BlockAndValueMapping &mapper, - MLIRContext *context) const { + MLIRContext *context) { SmallVector operands; SmallVector successors; @@ -605,7 +591,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, if (getNumSuccessors() == 0) { // Non-branching operations can just add all the operands. for (auto *opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); + operands.push_back(mapper.lookupOrDefault(opValue)); } else { // We add the operands separated by nullptr's for each successor. unsigned firstSuccOperand = @@ -614,21 +600,18 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, unsigned i = 0; for (; i != firstSuccOperand; ++i) - operands.push_back( - mapper.lookupOrDefault(const_cast(InstOperands[i].get()))); + operands.push_back(mapper.lookupOrDefault(InstOperands[i].get())); successors.reserve(getNumSuccessors()); for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) { - successors.push_back( - mapper.lookupOrDefault(const_cast(getSuccessor(succ)))); + successors.push_back(mapper.lookupOrDefault(getSuccessor(succ))); // Add sentinel to delineate successor operands. operands.push_back(nullptr); // Remap the successors operands. 
for (auto *operand : getSuccessorOperands(succ)) - operands.push_back( - mapper.lookupOrDefault(const_cast(operand))); + operands.push_back(mapper.lookupOrDefault(operand)); } } @@ -652,7 +635,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, return newOp; } -Instruction *Instruction::clone(MLIRContext *context) const { +Instruction *Instruction::clone(MLIRContext *context) { BlockAndValueMapping mapper; return clone(mapper, context); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index a2605a9f910..78cc18480cf 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -93,20 +93,19 @@ void OpState::emitNote(const Twine &message) const { // Op Trait implementations //===----------------------------------------------------------------------===// -bool OpTrait::impl::verifyZeroOperands(const Instruction *op) { +bool OpTrait::impl::verifyZeroOperands(Instruction *op) { if (op->getNumOperands() != 0) return op->emitOpError("requires zero operands"); return false; } -bool OpTrait::impl::verifyOneOperand(const Instruction *op) { +bool OpTrait::impl::verifyOneOperand(Instruction *op) { if (op->getNumOperands() != 1) return op->emitOpError("requires a single operand"); return false; } -bool OpTrait::impl::verifyNOperands(const Instruction *op, - unsigned numOperands) { +bool OpTrait::impl::verifyNOperands(Instruction *op, unsigned numOperands) { if (op->getNumOperands() != numOperands) { return op->emitOpError("expected " + Twine(numOperands) + " operands, but found " + @@ -115,7 +114,7 @@ bool OpTrait::impl::verifyNOperands(const Instruction *op, return false; } -bool OpTrait::impl::verifyAtLeastNOperands(const Instruction *op, +bool OpTrait::impl::verifyAtLeastNOperands(Instruction *op, unsigned numOperands) { if (op->getNumOperands() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -135,7 +134,7 @@ static Type getTensorOrVectorElementType(Type type) { return type; } -bool OpTrait::impl::verifyOperandsAreIntegerLike(const Instruction *op) { +bool OpTrait::impl::verifyOperandsAreIntegerLike(Instruction *op) { for (auto *operand : op->getOperands()) { auto type = getTensorOrVectorElementType(operand->getType()); if (!type.isIntOrIndex()) @@ -144,7 +143,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifySameTypeOperands(const Instruction *op) { +bool OpTrait::impl::verifySameTypeOperands(Instruction *op) { // Zero or one operand always have the "same" type. 
unsigned nOperands = op->getNumOperands(); if (nOperands < 2) @@ -158,26 +157,25 @@ bool OpTrait::impl::verifySameTypeOperands(const Instruction *op) { return false; } -bool OpTrait::impl::verifyZeroResult(const Instruction *op) { +bool OpTrait::impl::verifyZeroResult(Instruction *op) { if (op->getNumResults() != 0) return op->emitOpError("requires zero results"); return false; } -bool OpTrait::impl::verifyOneResult(const Instruction *op) { +bool OpTrait::impl::verifyOneResult(Instruction *op) { if (op->getNumResults() != 1) return op->emitOpError("requires one result"); return false; } -bool OpTrait::impl::verifyNResults(const Instruction *op, - unsigned numOperands) { +bool OpTrait::impl::verifyNResults(Instruction *op, unsigned numOperands) { if (op->getNumResults() != numOperands) return op->emitOpError("expected " + Twine(numOperands) + " results"); return false; } -bool OpTrait::impl::verifyAtLeastNResults(const Instruction *op, +bool OpTrait::impl::verifyAtLeastNResults(Instruction *op, unsigned numOperands) { if (op->getNumResults() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -206,7 +204,7 @@ static bool verifyShapeMatch(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultShape(const Instruction *op) { +bool OpTrait::impl::verifySameOperandsAndResultShape(Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -224,7 +222,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const Instruction *op) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { +bool OpTrait::impl::verifySameOperandsAndResultType(Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -242,9 +240,9 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { return false; } -static bool verifyBBArguments( - llvm::iterator_range operands, - Block *destBB, const Instruction *op) { +static bool +verifyBBArguments(llvm::iterator_range operands, + Block *destBB, Instruction *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -260,7 +258,7 @@ static bool verifyBBArguments( return false; } -static bool verifyTerminatorSuccessors(const Instruction *op) { +static bool verifyTerminatorSuccessors(Instruction *op) { // Verify that the operands lines up with the BB arguments in the successor. Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { @@ -273,7 +271,7 @@ static bool verifyTerminatorSuccessors(const Instruction *op) { return false; } -bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { +bool OpTrait::impl::verifyIsTerminator(Instruction *op) { Block *block = op->getBlock(); // Verify that the operation is at the end of the respective parent block. 
if (!block || &block->back() != op) @@ -285,7 +283,7 @@ bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreBoolLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreBoolLike(Instruction *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); bool isBoolType = elementType.isInteger(1); @@ -296,7 +294,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreFloatLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreFloatLike(Instruction *op) { for (auto *result : op->getResults()) { if (!getTensorOrVectorElementType(result->getType()).isa()) return op->emitOpError("requires a floating point type"); @@ -305,7 +303,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Instruction *op) { return false; } -bool OpTrait::impl::verifyResultsAreIntegerLike(const Instruction *op) { +bool OpTrait::impl::verifyResultsAreIntegerLike(Instruction *op) { for (auto *result : op->getResults()) { auto type = getTensorOrVectorElementType(result->getType()); if (!type.isIntOrIndex()) @@ -338,7 +336,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types); } -void impl::printBinaryOp(const Instruction *op, OpAsmPrinter *p) { +void impl::printBinaryOp(Instruction *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -377,7 +375,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(dstType, result->types); } -void impl::printCastOp(const Instruction *op, OpAsmPrinter *p) { +void impl::printCastOp(Instruction *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 2b6eea80a4a..6ac1711229c 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -29,7 +29,7 @@ Instruction *Value::getDefiningInst() { } /// Return the function that this Value is defined in. -Function *Value::getFunction() const { +Function *Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); @@ -64,14 +64,14 @@ void IRObjectWithUseList::dropAllUses() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -Function *BlockArgument::getFunction() const { +Function *BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; } /// Returns if the current argument is a function argument. -bool BlockArgument::isFunctionArgument() const { +bool BlockArgument::isFunctionArgument() { auto *containingFn = getFunction(); return containingFn && &containingFn->front() == getOwner(); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 55da9c6ed6b..963362871a2 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -36,7 +36,7 @@ using namespace mlir; /// A custom binary operation printer that omits the "std." prefix from the /// operation names. 
-void detail::printStandardBinaryOp(const Instruction *op, OpAsmPrinter *p) { +void detail::printStandardBinaryOp(Instruction *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -68,8 +68,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) >(); } -void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin, - Instruction::const_operand_iterator end, +void mlir::printDimAndSymbolList(Instruction::operand_iterator begin, + Instruction::operand_iterator end, unsigned numDims, OpAsmPrinter *p) { *p << '('; p->printOperands(begin, begin + numDims); @@ -1803,8 +1803,7 @@ void ReturnOp::print(OpAsmPrinter *p) { *p << " : "; interleave( operand_begin(), operand_end(), - [&](const Value *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); + [&](Value *e) { p->printType(e->getType()); }, [&]() { *p << ", "; }); } } diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 0320e782324..1e0c01a5df1 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -92,13 +92,6 @@ VectorTransferReadOp::getIndices() { return {begin, end}; } -llvm::iterator_range -VectorTransferReadOp::getIndices() const { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - Optional VectorTransferReadOp::getPaddingValue() { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { @@ -107,16 +100,7 @@ Optional VectorTransferReadOp::getPaddingValue() { return Optional(getOperand(Offsets::FirstIndexOffset + memRefRank)); } -Optional VectorTransferReadOp::getPaddingValue() const { - auto memRefRank = getMemRefType().getRank(); - if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { - return None; - } - return Optional( - getOperand(Offsets::FirstIndexOffset + memRefRank)); -} - -AffineMap VectorTransferReadOp::getPermutationMap() const { +AffineMap VectorTransferReadOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } @@ -134,7 +118,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) { // Construct the FunctionType and print it. llvm::SmallVector inputs{getMemRefType()}; // Must have at least one actual index, see verify. 
- const Value *firstIndex = *(getIndices().begin()); + Value *firstIndex = *getIndices().begin(); Type indexType = firstIndex->getType(); inputs.append(getMemRefType().getRank(), indexType); if (optionalPaddingValue) { @@ -309,14 +293,7 @@ VectorTransferWriteOp::getIndices() { return {begin, end}; } -llvm::iterator_range -VectorTransferWriteOp::getIndices() const { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - -AffineMap VectorTransferWriteOp::getPermutationMap() const { +AffineMap VectorTransferWriteOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 7c74c2fb2f6..76d484ac402 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -59,7 +59,7 @@ private: bool convertOneFunction(Function &func); void connectPHINodes(Function &func); bool convertBlock(Block &bb, bool ignoreArguments); - bool convertInstruction(const Instruction &inst, llvm::IRBuilder<> &builder); + bool convertInstruction(Instruction &inst, llvm::IRBuilder<> &builder); template SmallVector lookupValues(Range &&values); @@ -73,7 +73,7 @@ private: // Mappings between original and translated values, used for lookups. llvm::DenseMap functionMapping; - llvm::DenseMap valueMapping; + llvm::DenseMap valueMapping; llvm::DenseMap blockMapping; }; } // end anonymous namespace @@ -185,7 +185,7 @@ template SmallVector ModuleTranslation::lookupValues(Range &&values) { SmallVector remapped; remapped.reserve(llvm::size(values)); - for (const Value *v : values) { + for (Value *v : values) { remapped.push_back(valueMapping.lookup(v)); } return remapped; @@ -195,7 +195,7 @@ SmallVector ModuleTranslation::lookupValues(Range &&values) { // using the `builder`. LLVM IR Builder does not have a generic interface so // this has to be a long chain of `if`s calling different functions with a // different number of arguments. -bool ModuleTranslation::convertInstruction(const Instruction &inst, +bool ModuleTranslation::convertInstruction(Instruction &inst, llvm::IRBuilder<> &builder) { auto extractPosition = [](ArrayAttr attr) { SmallVector position; @@ -212,8 +212,7 @@ bool ModuleTranslation::convertInstruction(const Instruction &inst, // itself. Otherwise, this is an indirect call and the callee is the first // operand, look it up as a normal value. Return the llvm::Value representing // the function result, which may be of llvm::VoidTy type. - auto convertCall = [this, - &builder](const Instruction &inst) -> llvm::Value * { + auto convertCall = [this, &builder](Instruction &inst) -> llvm::Value * { auto operands = lookupValues(inst.getOperands()); ArrayRef operandsRef(operands); if (auto attr = inst.getAttrOfType("callee")) { @@ -270,7 +269,7 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { auto predecessors = bb.getPredecessors(); unsigned numPredecessors = std::distance(predecessors.begin(), predecessors.end()); - for (const auto *arg : bb.getArguments()) { + for (auto *arg : bb.getArguments()) { auto wrappedType = arg->getType().dyn_cast(); if (!wrappedType) { arg->getType().getContext()->emitError( @@ -284,7 +283,7 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { } // Traverse instructions. 
- for (const auto &inst : bb) { + for (auto &inst : bb) { if (convertInstruction(inst, builder)) return true; } @@ -294,8 +293,8 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const Value *getPHISourceValue(Block *current, Block *pred, - unsigned numArguments, unsigned index) { +static Value *getPHISourceValue(Block *current, Block *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa()) { return terminator.getOperand(index); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 31f4d48e4ed..05760f18761 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -39,7 +39,8 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Instruction *op) { + static unsigned getHashValue(const Instruction *opC) { + auto *op = const_cast(opC); // Hash the operations based upon their: // - Instruction Name // - Attributes @@ -50,7 +51,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const Instruction *lhs, const Instruction *rhs) { + static bool isEqual(const Instruction *lhsC, const Instruction *rhsC) { + auto *lhs = const_cast(lhsC); + auto *rhs = const_cast(rhsC); if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b033dadfe51..a659b2e480b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -49,9 +49,8 @@ private: // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. - SmallVector - lookupValues(const llvm::iterator_range - &operands); + SmallVector lookupValues( + const llvm::iterator_range &operands); // Converts the given function to the dialect using hooks defined in // `dialectConversion`. Returns the converted function or `nullptr` on error. @@ -102,10 +101,10 @@ private: } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues( - const llvm::iterator_range &operands) { + const llvm::iterator_range &operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); - for (const Value *operand : operands) { + for (Value *operand : operands) { Value *value = mapping.lookupOrNull(operand); if (!value) return {}; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index d97538734d1..954135d2a4f 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -477,7 +477,7 @@ bool DmaGeneration::runOnBlock(Block *block) { // Get to the first load, store, or for op. 
auto curBegin = - std::find_if(block->begin(), block->end(), [&](const Instruction &inst) { + std::find_if(block->begin(), block->end(), [&](Instruction &inst) { return inst.isa() || inst.isa() || inst.isa(); }); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6d4ea7206b7..95bdc3ca2d2 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -142,7 +142,7 @@ struct LoopNestStateCollector { }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(const Instruction &op) { +static bool isMemRefDereferencingOp(Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 804991a7b8b..6208eee5d62 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -192,7 +192,7 @@ struct MaterializationState { VectorType superVectorType; VectorType hwVectorType; SmallVector hwVectorInstance; - DenseMap *substitutionsMap; + DenseMap *substitutionsMap; }; struct MaterializeVectorsPass : public FunctionPass { @@ -239,9 +239,9 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static Instruction * -instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, - DenseMap *substitutionsMap); +static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately /// enclosing loop. @@ -253,7 +253,7 @@ instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, /// /// If substitution fails, returns nullptr. static Value *substitute(Value *v, VectorType hwVectorType, - DenseMap *substitutionsMap) { + DenseMap *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { auto *opInst = v->getDefiningInst(); @@ -404,9 +404,9 @@ materializeAttributes(Instruction *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static Instruction * -instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap) { assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); assert(!opInst->isa() && @@ -481,10 +481,10 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction * -instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, - ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp *read, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), read->getIndices()); auto affineIndices = @@ -505,10 +505,10 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. 
-static Instruction * -instantiate(FuncBuilder *b, VectorTransferWriteOp *write, - VectorType hwVectorType, ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp *write, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), write->getIndices()); auto affineIndices = @@ -624,7 +624,7 @@ static bool emitSlice(MaterializationState *state, // Fresh RAII instanceIndices and substitutionsMap. MaterializationState scopedState = *state; scopedState.hwVectorInstance = delinearize(idx, *ratio); - DenseMap substitutionMap; + DenseMap substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. for (auto *inst : *slice) { @@ -749,7 +749,7 @@ void MaterializeVectorsPass::runOnFunction() { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. - auto filter = [subVectorType](const Instruction &inst) { + auto filter = [subVectorType](Instruction &inst) { if (!inst.isa()) { return false; } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 97532fdbe94..1dfc4e7dc17 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -56,7 +56,7 @@ FunctionPassBase *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const Instruction &dmaInst) { +static unsigned getTagMemRefPos(Instruction &dmaInst) { assert(dmaInst.isa() || dmaInst.isa()); if (dmaInst.isa()) { // Second to last operand. @@ -323,7 +323,7 @@ void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. - DenseMap instShiftMap; + DenseMap instShiftMap; for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; assert(dmaStartInst->isa()); @@ -341,13 +341,13 @@ void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { SmallVector affineApplyInsts; SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); - for (const auto *inst : affineApplyInsts) { + for (auto *inst : affineApplyInsts) { instShiftMap[inst] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index cbf68056eb9..2f10b898502 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,19 +37,19 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. 
TODO(b/117228571) -static bool isMemRefDereferencingOp(const Instruction &op) { +static bool isMemRefDereferencingOp(Instruction &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; return false; } -bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, +bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - const Instruction *domInstFilter, - const Instruction *postDomInstFilter) { + Instruction *domInstFilter, + Instruction *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -167,7 +167,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // Result types don't change. Both memref's are of the same elemental type. state.types.reserve(opInst->getNumResults()); - for (const auto *result : opInst->getResults()) + for (auto *result : opInst->getResults()) state.types.push_back(result->getType()); // Attributes also do not change. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index af6fc581cfd..9c9f8593f31 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -105,7 +105,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { VectorType::get(shape, FloatType::getF32(f->getContext())); // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. - auto filter = [subVectorType](const Instruction &inst) { + auto filter = [subVectorType](Instruction &inst) { assert(subVectorType.getElementType() == FloatType::getF32(subVectorType.getContext()) && "Only f32 supported for now"); @@ -150,7 +150,7 @@ static NestedPattern patternTestSlicingOps() { using functional::map; using matcher::Op; // Match all OpInstructions with the kTestSlicingOpName name. - auto filter = [](const Instruction &inst) { + auto filter = [](Instruction &inst) { return inst.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); @@ -199,7 +199,7 @@ void VectorizerTestPass::testSlicing(Function *f) { } } -static bool customOpWithAffineMapAttribute(const Instruction &inst) { +static bool customOpWithAffineMapAttribute(Instruction &inst) { return inst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -225,11 +225,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { simplifyAffineMap(res).print(outs() << "\nComposed map: "); } -static bool affineApplyOp(const Instruction &inst) { +static bool affineApplyOp(Instruction &inst) { return inst.isa(); } -static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) { +static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { auto app = inst.dyn_cast(); return app && app->use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5c5045b668d..1834b2db0cf 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -734,7 +734,7 @@ struct VectorizationState { // Map of old scalar Instruction to new vectorized Instruction. DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. - DenseMap replacementMap; + DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. 
const VectorizationStrategy *strategy; // Use-def roots. These represent the starting points for the worklist in the @@ -755,7 +755,7 @@ struct VectorizationState { void registerTerminal(Instruction *inst); private: - void registerReplacement(const Value *key, Value *value); + void registerReplacement(Value *key, Value *value); }; } // end namespace @@ -796,7 +796,7 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(const Value *key, Value *value) { +void VectorizationState::registerReplacement(Value *key, Value *value) { assert(replacementMap.count(key) == 0 && "replacement already registered"); replacementMap.insert(std::make_pair(key, value)); } @@ -858,8 +858,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp *loop, int64_t step, using namespace functional; loop->setStep(step); - FilterFunctionType notVectorizedThisPattern = [state]( - const Instruction &inst) { + FilterFunctionType notVectorizedThisPattern = [state](Instruction &inst) { if (!matcher::isLoadOrStore(inst)) { return false; } @@ -893,7 +892,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp *loop, int64_t step, /// we can build a cost model and a search procedure. static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](const Instruction &forInst) { + return [fastestVaryingMemRefDimension](Instruction &forInst) { auto loop = forInst.cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 0003069117e..015a889beb2 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -47,5 +47,5 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> { // CHECK: bool fold(SmallVectorImpl &results); // CHECK: private: // CHECK: friend class ::mlir::Instruction; -// CHECK: explicit AOp(const Instruction *state) : Op(state) {} +// CHECK: explicit AOp(Instruction *state) : Op(state) {} // CHECK: }; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 93418cff9a0..7d31dde9156 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -321,8 +321,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const { } os << "\nprivate:\n" << " friend class ::mlir::Instruction;\n"; - os << " explicit " << className - << "(const Instruction *state) : Op(state) {}\n" + os << " explicit " << className << "(Instruction *state) : Op(state) {}\n" << "};"; } -- cgit v1.2.3 From d9b5bc8f5598d74814909ae8a79cadfff72fbb0a Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 24 Mar 2019 19:53:05 -0700 Subject: Remove OpPointer, cleaning up a ton of code. This also moves Ops to using inherited constructors, which is cleaner and means you can now use DimOp() to get a null op, instead of having to use Instruction::getNull(). This removes another 200 lines of code. 
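For illustration only (this before/after sketch is not part of the patch, and the call site with an `Instruction *inst` is hypothetical), the user-facing effect of dropping OpPointer is roughly:

    // Before this change (removed API, shown for comparison only):
    //   OpPointer<AffineForOp> forOp = inst->dyn_cast<AffineForOp>();
    //   if (!forOp)
    //     forOp = Instruction::getNull<AffineForOp>();

    // After this change: dyn_cast returns the op class directly, a
    // default-constructed op is the null value, and ops compare with ==/!=.
    AffineForOp forOp = inst->dyn_cast<AffineForOp>();
    if (!forOp)
      forOp = AffineForOp();

Since the bool conversion and the ==/!= comparisons are defined once on OpState, this works uniformly for every op class without per-op boilerplate.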
PiperOrigin-RevId: 240068113 --- mlir/include/mlir/AffineOps/AffineOps.h | 36 +++---- mlir/include/mlir/Analysis/AffineAnalysis.h | 3 +- mlir/include/mlir/Analysis/AffineStructures.h | 4 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 14 ++- mlir/include/mlir/Analysis/Utils.h | 18 ++-- mlir/include/mlir/EDSC/MLIREmitter.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Function.h | 6 +- mlir/include/mlir/IR/Instruction.h | 25 ++--- mlir/include/mlir/IR/OpDefinition.h | 76 +++++--------- mlir/include/mlir/IR/PatternMatch.h | 6 +- mlir/include/mlir/StandardOps/Ops.h | 122 +++++++--------------- mlir/include/mlir/SuperVectorOps/SuperVectorOps.h | 18 ++-- mlir/include/mlir/Transforms/LoopUtils.h | 40 +++---- mlir/include/mlir/Transforms/Passes.h | 4 +- mlir/include/mlir/Transforms/Utils.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 25 +++-- mlir/lib/Analysis/AffineAnalysis.cpp | 6 +- mlir/lib/Analysis/AffineStructures.cpp | 19 ++-- mlir/lib/Analysis/LoopAnalysis.cpp | 35 +++---- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 49 ++++----- mlir/lib/EDSC/Builders.cpp | 9 +- mlir/lib/EDSC/MLIREmitter.cpp | 24 ++--- mlir/lib/StandardOps/Ops.cpp | 4 +- mlir/lib/Transforms/DmaGeneration.cpp | 8 +- mlir/lib/Transforms/LoopFusion.cpp | 28 ++--- mlir/lib/Transforms/LoopTiling.cpp | 39 ++++--- mlir/lib/Transforms/LoopUnroll.cpp | 31 +++--- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 8 +- mlir/lib/Transforms/LowerAffine.cpp | 12 +-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 12 +-- mlir/lib/Transforms/MaterializeVectors.cpp | 6 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 7 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 29 +++-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 84 +++++++-------- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 10 +- mlir/test/mlir-tblgen/op-decl.td | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 6 +- 40 files changed, 350 insertions(+), 489 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index ee13828ef78..3f586c4e35d 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -59,6 +59,8 @@ public: class AffineApplyOp : public Op { public: + using Op::Op; + /// Builds an affine apply op with the specified map and operands. static void build(Builder *builder, OperationState *result, AffineMap map, ArrayRef operands); @@ -84,10 +86,6 @@ public: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -private: - friend class Instruction; - explicit AffineApplyOp(Instruction *state) : Op(state) {} }; /// The "for" instruction represents an affine loop nest, defining an SSA value @@ -117,6 +115,8 @@ private: class AffineForOp : public Op { public: + using Op::Op; + // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, ArrayRef lbOperands, AffineMap lbMap, @@ -225,10 +225,6 @@ public: /// Returns true if both the lower and upper bound have the same operand lists /// (same operands in the same order). bool matchingBoundOperandList(); - -private: - friend class Instruction; - explicit AffineForOp(Instruction *state) : Op(state) {} }; /// Returns if the provided value is the induction variable of a AffineForOp. @@ -236,21 +232,20 @@ bool isForInductionVar(Value *val); /// Returns the loop parent of an induction variable. 
If the provided value is /// not an induction variable, then return nullptr. -OpPointer getForInductionVarOwner(Value *val); +AffineForOp getForInductionVarOwner(Value *val); /// Extracts the induction variables from a list of AffineForOps and places them /// in the output argument `ivs`. -void extractForInductionVars(ArrayRef> forInsts, +void extractForInductionVars(ArrayRef forInsts, SmallVectorImpl *ivs); - /// AffineBound represents a lower or upper bound in the for instruction. /// This class does not own the underlying operands. Instead, it refers /// to the operands stored in the AffineForOp. Its life span should not exceed /// that of the for instruction it refers to. class AffineBound { public: - OpPointer getAffineForOp() { return inst; } + AffineForOp getAffineForOp() { return inst; } AffineMap getMap() { return map; } /// Returns an AffineValueMap representing this bound. @@ -274,15 +269,14 @@ public: private: // 'for' instruction that contains this bound. - OpPointer inst; + AffineForOp inst; // Start and end positions of this affine bound operands in the list of // the containing 'for' instruction operands. unsigned opStart, opEnd; // Affine map for this bound. AffineMap map; - AffineBound(OpPointer inst, unsigned opStart, unsigned opEnd, - AffineMap map) + AffineBound(AffineForOp inst, unsigned opStart, unsigned opEnd, AffineMap map) : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {} friend class AffineForOp; @@ -309,6 +303,8 @@ private: class AffineIfOp : public Op { public: + using Op::Op; + // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, IntegerSet condition, ArrayRef conditionOperands); @@ -328,10 +324,6 @@ public: bool verify(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); - -private: - friend class Instruction; - explicit AffineIfOp(Instruction *state) : Op(state) {} }; /// Returns true if the given Value can be used as a dimension id. @@ -349,9 +341,9 @@ void canonicalizeMapAndOperands(AffineMap *map, /// Returns a composed AffineApplyOp by composing `map` and `operands` with /// other AffineApplyOps supplying those operands. The operands of the resulting /// AffineApplyOp do not change the length of AffineApplyOp chains. -OpPointer -makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, - llvm::ArrayRef operands); +AffineApplyOp makeComposedAffineApply(FuncBuilder *b, Location loc, + AffineMap map, + llvm::ArrayRef operands); /// Given an affine map `map` and its input `operands`, this method composes /// into `map`, maps of AffineApplyOps whose results are the values in diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 466bd00b471..be44f1226e0 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -36,7 +36,6 @@ class AffineForOp; class AffineValueMap; class FlatAffineConstraints; class Instruction; -template class OpPointer; class Value; /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp @@ -52,7 +51,7 @@ void getReachableAffineApplyOps( /// operands are added as symbols in the system. Returns failure for the yet /// unimplemented cases. // TODO(bondhugula): handle non-unit strides. -LogicalResult getIndexSet(llvm::MutableArrayRef> forOps, +LogicalResult getIndexSet(llvm::MutableArrayRef forOps, FlatAffineConstraints *domain); /// Encapsulates a memref load or store access information. 
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 36b82acd1b7..92c809326e3 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -131,7 +131,7 @@ public: AffineValueMap(AffineMap map, ArrayRef operands, ArrayRef results = llvm::None); - explicit AffineValueMap(OpPointer applyOp); + explicit AffineValueMap(AffineApplyOp applyOp); explicit AffineValueMap(AffineBound bound); ~AffineValueMap(); @@ -385,7 +385,7 @@ public: /// instruction are added as trailing identifiers (either dimensional or /// symbolic depending on whether the operand is a valid ML Function symbol). // TODO(bondhugula): add support for non-unit strides. - LogicalResult addAffineForOpDomain(OpPointer forOp); + LogicalResult addAffineForOpDomain(AffineForOp forOp); /// Adds a lower or an upper bound for the identifier at the specified /// position with constraints being drawn from the specified bound map and diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 7d5ebeed054..11168dc9219 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -31,7 +31,6 @@ namespace mlir { class AffineExpr; class AffineForOp; class AffineMap; -template class OpPointer; class Instruction; class MemRefType; class Value; @@ -44,18 +43,18 @@ class Value; /// bounds before computing the trip count expressions // TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a // pure analysis method relying on FlatAffineConstraints -void buildTripCountMapAndOperands(OpPointer forOp, AffineMap *map, +void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map, SmallVectorImpl *operands); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count /// in non-trivial cases. -llvm::Optional getConstantTripCount(OpPointer forOp); +llvm::Optional getConstantTripCount(AffineForOp forOp); /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t getLargestDivisorOfTripCount(OpPointer forOp); +uint64_t getLargestDivisorOfTripCount(AffineForOp forOp); /// Given an induction variable `iv` of type AffineForOp and an `index` of type /// IndexType, returns `true` if `index` is independent of `iv` and false @@ -92,13 +91,13 @@ getInvariantAccesses(Value &iv, llvm::ArrayRef indices); /// 3. all nested load/stores are to scalar MemRefs. /// TODO(ntv): implement dependence semantics /// TODO(ntv): relax the no-conditionals restriction -bool isVectorizableLoop(OpPointer loop); +bool isVectorizableLoop(AffineForOp loop); /// Checks whether the loop is structurally vectorizable and that all the LoadOp /// and StoreOp matched have access indexing functions that are are either: /// 1. invariant along the loop induction variable created by 'loop'; /// 2. varying along the 'fastestVaryingDim' memory dimension. 
-bool isVectorizableLoopAlongFastestVaryingMemRefDim(OpPointer loop, +bool isVectorizableLoopAlongFastestVaryingMemRefDim(AffineForOp loop, unsigned fastestVaryingDim); /// Checks where SSA dominance would be violated if a for inst's body @@ -106,8 +105,7 @@ bool isVectorizableLoopAlongFastestVaryingMemRefDim(OpPointer loop, /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool isInstwiseShiftValid(OpPointer forOp, - llvm::ArrayRef shifts); +bool isInstwiseShiftValid(AffineForOp forOp, llvm::ArrayRef shifts); } // end namespace mlir #endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 0982849302c..84e388361c9 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -41,15 +41,13 @@ class FlatAffineConstraints; class Instruction; class Location; class MemRefAccess; -template class OpPointer; class Instruction; class Value; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. // TODO(bondhugula): handle 'if' inst's. -void getLoopIVs(Instruction &inst, - SmallVectorImpl> *loops); +void getLoopIVs(Instruction &inst, SmallVectorImpl *loops); /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. @@ -57,7 +55,7 @@ unsigned getNestingDepth(Instruction &inst); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. -void getSequentialLoops(OpPointer forOp, +void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops); /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their @@ -105,10 +103,10 @@ LogicalResult getBackwardComputationSliceState( // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. -OpPointer -insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, - unsigned dstLoopDepth, - ComputationSliceState *sliceState); +AffineForOp insertBackwardComputationSlice(Instruction *srcOpInst, + Instruction *dstOpInst, + unsigned dstLoopDepth, + ComputationSliceState *sliceState); /// A region of a memref's data space; this is typically constructed by /// analyzing load/store op's on this memref and the index space of loops @@ -235,11 +233,11 @@ unsigned getNumCommonSurroundingLoops(Instruction &A, Instruction &B); /// Gets the memory footprint of all data touched in the specified memory space /// in bytes; if the memory space is unspecified, considers all memory spaces. -Optional getMemoryFootprintBytes(OpPointer forOp, +Optional getMemoryFootprintBytes(AffineForOp forOp, int memorySpace = -1); /// Returns true if `forOp' is a parallel loop. -bool isLoopParallel(OpPointer forOp); +bool isLoopParallel(AffineForOp forOp); } // end namespace mlir diff --git a/mlir/include/mlir/EDSC/MLIREmitter.h b/mlir/include/mlir/EDSC/MLIREmitter.h index 1ee608d80f0..93ee8813001 100644 --- a/mlir/include/mlir/EDSC/MLIREmitter.h +++ b/mlir/include/mlir/EDSC/MLIREmitter.h @@ -165,7 +165,7 @@ struct MLIREmitter { } return res; } - OpPointer getAffineForOp(Expr e); + AffineForOp getAffineForOp(Expr e); private: /// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point. 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 13b58c40ab3..baf71879afd 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -254,7 +254,7 @@ public: /// Create operation of specific op type at the current insertion point. template - OpPointer create(Location location, Args... args) { + OpTy create(Location location, Args... args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *inst = createOperation(state); diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index bc36a068f45..0b21d90f336 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -37,7 +37,6 @@ class FunctionType; class MLIRContext; class Module; class ArgumentIterator; -template class OpPointer; /// This is the base class for all of the MLIR function types. class Function : public llvm::ilist_node_with_parent { @@ -110,8 +109,7 @@ public: void walk(const std::function &callback); /// Specialization of walk to only visit operations of 'OpTy'. - template - void walk(std::function)> callback) { + template void walk(std::function callback) { walk([&](Instruction *inst) { if (auto op = inst->dyn_cast()) callback(op); @@ -124,7 +122,7 @@ public: /// Specialization of walkPostOrder to only visit operations of 'OpTy'. template - void walkPostOrder(std::function)> callback) { + void walkPostOrder(std::function callback) { walkPostOrder([&](Instruction *inst) { if (auto op = inst->dyn_cast()) callback(op); diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index f317a7870ba..4d0ebdb1840 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -33,7 +33,6 @@ namespace mlir { class BlockAndValueMapping; class Location; class MLIRContext; -template class OpPointer; class OperandIterator; class ResultIterator; class ResultTypeIterator; @@ -363,27 +362,20 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// - // Return a null OpPointer for the specified type. - template static OpPointer getNull() { - return OpPointer(OpClass(nullptr)); - } - /// The dyn_cast methods perform a dynamic cast from an Instruction to a typed - /// Op like DimOp. This returns a null OpPointer on failure. - template OpPointer dyn_cast() { - if (isa()) { + /// Op like DimOp. This returns a null Op on failure. + template OpClass dyn_cast() { + if (isa()) return cast(); - } else { - return OpPointer(OpClass(nullptr)); - } + return OpClass(); } /// The cast methods perform a cast from an Instruction to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. - template OpPointer cast() { + template OpClass cast() { assert(isa() && "cast() argument of incompatible type!"); - return OpPointer(OpClass(this)); + return OpClass(this); } /// The is methods return true if the operation is a typed op (like DimOp) of @@ -399,8 +391,7 @@ public: void walk(const std::function &callback); /// Specialization of walk to only visit operations of 'OpTy'. - template - void walk(std::function)> callback) { + template void walk(std::function callback) { walk([&](Instruction *inst) { if (auto op = inst->dyn_cast()) callback(op); @@ -413,7 +404,7 @@ public: /// Specialization of walkPostOrder to only visit operations of 'OpTy'. 
template - void walkPostOrder(std::function)> callback) { + void walkPostOrder(std::function callback) { walkPostOrder([&](Instruction *inst) { if (auto op = inst->dyn_cast()) callback(op); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index d742919cd9d..ecf6368bb17 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -54,48 +54,6 @@ template struct IsSingleResult { OpType *, OpTrait::OneResult *>::value; }; -/// This pointer represents a notional "Instruction*" but where the actual -/// storage of the pointer is maintained in the templated "OpType" class. -template -class OpPointer { -public: - explicit OpPointer() : value(Instruction::getNull().value) {} - explicit OpPointer(OpType value) : value(value) {} - - OpType &operator*() { return value; } - - OpType *operator->() { return &value; } - - explicit operator bool() { return value.getInstruction(); } - - bool operator==(OpPointer rhs) { - return value.getInstruction() == rhs.value.getInstruction(); - } - bool operator!=(OpPointer rhs) { return !(*this == rhs); } - - /// OpPointer can be implicitly converted to OpType*. - /// Return `nullptr` if there is no associated Instruction*. - operator OpType *() { - if (!value.getInstruction()) - return nullptr; - return &value; - } - - operator OpType() { return value; } - - /// If the OpType operation includes the OneResult trait, then OpPointer can - /// be implicitly converted to an Value*. This yields the value of the - /// only result. - template - operator typename std::enable_if::value, - Value *>::type() { - return value.getResult(); - } - -private: - OpType value; -}; - /// This is the concrete base class that holds the operation pointer and has /// non-generic methods that only depend on State (to avoid having them /// instantiated on template types that don't affect them. @@ -104,6 +62,12 @@ private: /// they aren't customized. class OpState { public: + /// Ops are pointer-like, so we allow implicit conversion to bool. + operator bool() { return getInstruction() != nullptr; } + + /// This implicitly converts to Instruction*. + operator Instruction *() const { return state; } + /// Return the operation that this refers to. Instruction *getInstruction() { return state; } @@ -186,6 +150,14 @@ private: Instruction *state; }; +// Allow comparing operators. +inline bool operator==(OpState lhs, OpState rhs) { + return lhs.getInstruction() == rhs.getInstruction(); +} +inline bool operator!=(OpState lhs, OpState rhs) { + return lhs.getInstruction() != rhs.getInstruction(); +} + /// This template defines the constantFoldHook and foldHook as used by /// AbstractOperation. /// @@ -257,6 +229,12 @@ template class FoldingHook::type> { public: + /// If the operation returns a single value, then the Op can be implicitly + /// converted to an Value*. This yields the value of the only result. + operator Value *() { + return static_cast(this)->getInstruction()->getResult(0); + } + /// This is an implementation detail of the constant folder hook for /// AbstractOperation. static LogicalResult constantFoldHook(Instruction *op, @@ -801,8 +779,14 @@ public: /// to introspect traits on this operation. using ConcreteOpType = ConcreteType; + /// This is a public constructor. Any op can be initialized to null. + explicit Op() : OpState(nullptr) {} + protected: + /// This is a private constructor only accessible through the + /// Instruction::cast family of methods. 
explicit Op(Instruction *state) : OpState(state) {} + friend class Instruction; private: template struct BaseVerifier; @@ -866,6 +850,9 @@ template class... Traits> class CastOp : public Op { public: + using Op::Op; + static void build(Builder *builder, OperationState *result, Value *source, Type destType) { impl::buildCastOp(builder, result, source, destType); @@ -876,11 +863,6 @@ public: void print(OpAsmPrinter *p) { return impl::printCastOp(this->getInstruction(), p); } - -protected: - explicit CastOp(Instruction *state) - : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b9cc12460ff..5d5b1c3a1ff 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -185,7 +185,7 @@ public: /// Create operation of specific op type at the current insertion point /// without verifying to see if it is valid. template - OpPointer create(Location location, Args... args) { + OpTy create(Location location, Args... args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); @@ -198,7 +198,7 @@ public: /// If the result is an invalid op (the verifier hook fails), emit an error /// and return null. template - OpPointer createChecked(Location location, Args... args) { + OpTy createChecked(Location location, Args... args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); @@ -213,7 +213,7 @@ public: // Otherwise, the error message got emitted. Just remove the instruction // we made. op->erase(); - return OpPointer(); + return OpTy(); } /// This method performs the final replacement for a pattern, where the diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 1502ddb4a9f..b3e0f9daccf 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -68,6 +68,8 @@ public: class AllocOp : public Op { public: + using Op::Op; + /// The result of an alloc is always a MemRefType. MemRefType getType() { return getResult()->getType().cast(); } @@ -81,10 +83,6 @@ public: void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -private: - friend class Instruction; - explicit AllocOp(Instruction *state) : Op(state) {} }; /// The "br" operation represents a branch instruction in a function. @@ -100,6 +98,8 @@ private: class BranchOp : public Op { public: + using Op::Op; + static StringRef getOperationName() { return "std.br"; } static void build(Builder *builder, OperationState *result, Block *dest, @@ -115,10 +115,6 @@ public: /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); - -private: - friend class Instruction; - explicit BranchOp(Instruction *state) : Op(state) {} }; /// The "call" operation represents a direct call to a function. 
The operands @@ -130,6 +126,8 @@ private: class CallOp : public Op { public: + using Op::Op; + static StringRef getOperationName() { return "std.call"; } static void build(Builder *builder, OperationState *result, Function *callee, @@ -151,10 +149,6 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); bool verify(); - -protected: - friend class Instruction; - explicit CallOp(Instruction *state) : Op(state) {} }; /// The "call_indirect" operation represents an indirect call to a value of @@ -168,6 +162,7 @@ protected: class CallIndirectOp : public Op { public: + using Op::Op; static StringRef getOperationName() { return "std.call_indirect"; } static void build(Builder *builder, OperationState *result, Value *callee, @@ -189,10 +184,6 @@ public: bool verify(); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -protected: - friend class Instruction; - explicit CallIndirectOp(Instruction *state) : Op(state) {} }; /// The predicate indicates the type of the comparison to perform: @@ -240,6 +231,8 @@ class CmpIOp OpTrait::OneResult, OpTrait::ResultsAreBoolLike, OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { public: + using Op::Op; + CmpIPredicate getPredicate() { return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) .getInt(); @@ -255,10 +248,6 @@ public: void print(OpAsmPrinter *p); bool verify(); Attribute constantFold(ArrayRef operands, MLIRContext *context); - -private: - friend class Instruction; - explicit CmpIOp(Instruction *state) : Op(state) {} }; /// The "cond_br" operation represents a conditional branch instruction in a @@ -283,6 +272,8 @@ class CondBranchOp : public Op::Impl, /// follows: /// { condition, [true_operands], [false_operands] } public: + using Op::Op; + static StringRef getOperationName() { return "std.cond_br"; } static void build(Builder *builder, OperationState *result, Value *condition, @@ -363,9 +354,6 @@ private: unsigned getFalseDestOperandIndex() { return getTrueDestOperandIndex() + getNumTrueOperands(); } - - friend class Instruction; - explicit CondBranchOp(Instruction *state) : Op(state) {} }; /// The "constant" operation requires a single attribute named "value". @@ -377,6 +365,8 @@ private: class ConstantOp : public Op { public: + using Op::Op; + /// Builds a constant op with the specified attribute value and result type. static void build(Builder *builder, OperationState *result, Type type, Attribute value); @@ -394,10 +384,6 @@ public: void print(OpAsmPrinter *p); bool verify(); Attribute constantFold(ArrayRef operands, MLIRContext *context); - -protected: - friend class Instruction; - explicit ConstantOp(Instruction *state) : Op(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -407,6 +393,8 @@ protected: /// class ConstantFloatOp : public ConstantOp { public: + using ConstantOp::ConstantOp; + /// Builds a constant float op producing a float of the specified type. 
static void build(Builder *builder, OperationState *result, const APFloat &value, FloatType type); @@ -414,10 +402,6 @@ public: APFloat getValue() { return getAttrOfType("value").getValue(); } static bool isClassFor(Instruction *op); - -private: - friend class Instruction; - explicit ConstantFloatOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -427,6 +411,7 @@ private: /// class ConstantIntOp : public ConstantOp { public: + using ConstantOp::ConstantOp; /// Build a constant int op producing an integer of the specified width. static void build(Builder *builder, OperationState *result, int64_t value, unsigned width); @@ -439,10 +424,6 @@ public: int64_t getValue() { return getAttrOfType("value").getInt(); } static bool isClassFor(Instruction *op); - -private: - friend class Instruction; - explicit ConstantIntOp(Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -452,16 +433,14 @@ private: /// class ConstantIndexOp : public ConstantOp { public: + using ConstantOp::ConstantOp; + /// Build a constant int op producing an index. static void build(Builder *builder, OperationState *result, int64_t value); int64_t getValue() { return getAttrOfType("value").getInt(); } static bool isClassFor(Instruction *op); - -private: - friend class Instruction; - explicit ConstantIndexOp(Instruction *state) : ConstantOp(state) {} }; /// The "dealloc" operation frees the region of memory referenced by a memref @@ -477,6 +456,8 @@ private: class DeallocOp : public Op { public: + using Op::Op; + Value *getMemRef() { return getOperand(); } void setMemRef(Value *value) { setOperand(value); } @@ -489,10 +470,6 @@ public: void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -private: - friend class Instruction; - explicit DeallocOp(Instruction *state) : Op(state) {} }; /// The "dim" operation takes a memref or tensor operand and returns an @@ -504,6 +481,8 @@ private: class DimOp : public Op { public: + using Op::Op; + static void build(Builder *builder, OperationState *result, Value *memrefOrTensor, unsigned index); @@ -520,10 +499,6 @@ public: bool verify(); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); - -private: - friend class Instruction; - explicit DimOp(Instruction *state) : Op(state) {} }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a @@ -566,6 +541,8 @@ private: class DmaStartOp : public Op { public: + using Op::Op; + static void build(Builder *builder, OperationState *result, Value *srcMemRef, ArrayRef srcIndices, Value *destMemRef, ArrayRef destIndices, Value *numElements, @@ -671,10 +648,6 @@ public: return nullptr; return getOperand(getNumOperands() - 1); } - -protected: - friend class Instruction; - explicit DmaStartOp(Instruction *state) : Op(state) {} }; // DmaWaitOp blocks until the completion of a DMA operation associated with the @@ -693,6 +666,8 @@ protected: class DmaWaitOp : public Op { public: + using Op::Op; + static void build(Builder *builder, OperationState *result, Value *tagMemRef, ArrayRef tagIndices, Value *numElements); @@ -719,10 +694,6 @@ public: void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -protected: - friend class Instruction; - explicit DmaWaitOp(Instruction *state) : Op(state) {} }; /// The 
"extract_element" op reads a tensor or vector and returns one element @@ -740,6 +711,8 @@ class ExtractElementOp : public Op { public: + using Op::Op; + static void build(Builder *builder, OperationState *result, Value *aggregate, ArrayRef indices = {}); @@ -757,10 +730,6 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); Attribute constantFold(ArrayRef operands, MLIRContext *context); - -private: - friend class Instruction; - explicit ExtractElementOp(Instruction *state) : Op(state) {} }; /// The "load" op reads an element from a memref specified by an index list. The @@ -774,6 +743,8 @@ private: class LoadOp : public Op { public: + using Op::Op; + // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, Value *memref, ArrayRef indices = {}); @@ -796,10 +767,6 @@ public: void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -private: - friend class Instruction; - explicit LoadOp(Instruction *state) : Op(state) {} }; /// The "memref_cast" operation converts a memref from one type to an equivalent @@ -819,6 +786,7 @@ private: /// class MemRefCastOp : public CastOp { public: + using CastOp::CastOp; static StringRef getOperationName() { return "std.memref_cast"; } /// The result of a memref_cast is always a memref. @@ -827,10 +795,6 @@ public: void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit MemRefCastOp(Instruction *state) : CastOp(state) {} }; /// The "return" operation represents a return instruction within a function. @@ -845,6 +809,8 @@ private: class ReturnOp : public Op { public: + using Op::Op; + static StringRef getOperationName() { return "std.return"; } static void build(Builder *builder, OperationState *result, @@ -854,10 +820,6 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit ReturnOp(Instruction *state) : Op(state) {} }; /// The "select" operation chooses one value based on a binary condition @@ -874,6 +836,8 @@ private: class SelectOp : public Op::Impl, OpTrait::OneResult, OpTrait::HasNoSideEffect> { public: + using Op::Op; + static StringRef getOperationName() { return "std.select"; } static void build(Builder *builder, OperationState *result, Value *condition, Value *trueValue, Value *falseValue); @@ -886,10 +850,6 @@ public: Value *getFalseValue() { return getOperand(2); } Value *fold(); - -private: - friend class Instruction; - explicit SelectOp(Instruction *state) : Op(state) {} }; /// The "store" op writes an element to a memref specified by an index list. @@ -905,6 +865,8 @@ private: class StoreOp : public Op { public: + using Op::Op; + // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, Value *valueToStore, Value *memref, @@ -931,10 +893,6 @@ public: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - -private: - friend class Instruction; - explicit StoreOp(Instruction *state) : Op(state) {} }; /// The "tensor_cast" operation converts a tensor from one type to an equivalent @@ -949,6 +907,8 @@ private: /// class TensorCastOp : public CastOp { public: + using CastOp::CastOp; + static StringRef getOperationName() { return "std.tensor_cast"; } /// The result of a tensor_cast is always a tensor. 
@@ -957,10 +917,6 @@ public: void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit TensorCastOp(Instruction *state) : CastOp(state) {} }; /// Prints dimension and symbol list. diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index bb9fb8c5b66..b2e384157b9 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -96,6 +96,8 @@ class VectorTransferReadOp enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 }; public: + using Op::Op; + static StringRef getOperationName() { return "vector_transfer_read"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; } static void build(Builder *builder, OperationState *result, @@ -118,10 +120,6 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit VectorTransferReadOp(Instruction *state) : Op(state) {} }; /// VectorTransferWriteOp performs a blocking write from a super-vector to @@ -162,6 +160,8 @@ class VectorTransferWriteOp }; public: + using Op::Op; + static StringRef getOperationName() { return "vector_transfer_write"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; } static void build(Builder *builder, OperationState *result, Value *srcVector, @@ -181,10 +181,6 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit VectorTransferWriteOp(Instruction *state) : Op(state) {} }; /// VectorTypeCastOp performs a conversion from a memref with scalar element to @@ -199,16 +195,14 @@ private: class VectorTypeCastOp : public Op { public: + using Op::Op; + static StringRef getOperationName() { return "vector_type_cast"; } static void build(Builder *builder, OperationState *result, Value *srcVector, Type dstType); static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); bool verify(); - -private: - friend class Instruction; - explicit VectorTypeCastOp(Instruction *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index a7addae3d94..0404ab74244 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -32,35 +32,32 @@ class AffineMap; class AffineForOp; class Function; class FuncBuilder; -template class OpPointer; class Value; /// Unrolls this for instruction completely if the trip count is known to be /// constant. Returns failure otherwise. -LogicalResult loopUnrollFull(OpPointer forOp); +LogicalResult loopUnrollFull(AffineForOp forOp); /// Unrolls this for instruction by the specified unroll factor. Returns failure /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. -LogicalResult loopUnrollByFactor(OpPointer forOp, - uint64_t unrollFactor); +LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor); /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. -LogicalResult loopUnrollUpToFactor(OpPointer forOp, - uint64_t unrollFactor); +LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); /// Unrolls and jams this loop by the specified factor. Returns success if the /// loop is successfully unroll-jammed. 
-LogicalResult loopUnrollJamByFactor(OpPointer forOp, +LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor); /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant), whichever is lower. -LogicalResult loopUnrollJamUpToFactor(OpPointer forOp, +LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp, uint64_t unrollJamFactor); /// Promotes the loop body of a AffineForOp to its containing block if the /// AffineForOp was known to have a single iteration. -LogicalResult promoteIfSingleIteration(OpPointer forOp); +LogicalResult promoteIfSingleIteration(AffineForOp forOp); /// Promotes all single iteration AffineForOp's in the Function, i.e., moves /// their body into the containing Block. @@ -71,8 +68,8 @@ void promoteSingleIterationLoops(Function *f); /// part of the unrolled loop. Computes the bound as an AffineMap with its /// operands or a null map when the trip count can't be expressed as an affine /// expression. -void getCleanupLoopLowerBound(OpPointer forOp, - unsigned unrollFactor, AffineMap *map, +void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, + AffineMap *map, SmallVectorImpl *operands, FuncBuilder *builder); @@ -80,42 +77,39 @@ void getCleanupLoopLowerBound(OpPointer forOp, /// instruction-wise shifts. The shifts are with respect to the original /// execution order, and are multiplied by the loop 'step' before being applied. LLVM_NODISCARD -LogicalResult instBodySkew(OpPointer forOp, - ArrayRef shifts, +LogicalResult instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. LLVM_NODISCARD -LogicalResult tileCodeGen(MutableArrayRef> band, +LogicalResult tileCodeGen(MutableArrayRef band, ArrayRef tileSizes); /// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA' /// and 'forOpB' are part of a perfectly nested sequence of loops. -void interchangeLoops(OpPointer forOpA, - OpPointer forOpB); +void interchangeLoops(AffineForOp forOpA, AffineForOp forOpB); /// Sinks 'forOp' by 'loopDepth' levels by performing a series of loop /// interchanges. Requires that 'forOp' is part of a perfect nest with /// 'loopDepth' AffineForOps consecutively nested under it. -void sinkLoop(OpPointer forOp, unsigned loopDepth); +void sinkLoop(AffineForOp forOp, unsigned loopDepth); /// Performs tiling fo imperfectly nested loops (with interchange) by /// strip-mining the `forOps` by `sizes` and sinking them, in their order of /// occurrence in `forOps`, under each of the `targets`. /// Returns the new AffineForOps, one per each of (`forOps`, `targets`) pair, /// nested immediately under each of `targets`. -SmallVector, 8>, 8> -tile(ArrayRef> forOps, ArrayRef sizes, - ArrayRef> targets); +SmallVector, 8> tile(ArrayRef forOps, + ArrayRef sizes, + ArrayRef targets); /// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes` /// and sinking them, in their order of occurrence in `forOps`, under `target`. /// Returns the new AffineForOps, one per `forOps`, nested immediately under /// `target`. 
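Call sites of these utilities now pass the loop handle by value rather than wrapping it in OpPointer. A sketch of what callers look like against the updated declarations (both helper functions are illustrative, and a non-empty, perfectly nested band is assumed):

// Illustrative only.
LogicalResult unrollInnermost(MutableArrayRef<AffineForOp> band) {
  // Utilities take AffineForOp by value; no OpPointer wrapping needed.
  return loopUnrollByFactor(band.back(), /*unrollFactor=*/4);
}

LogicalResult tileBandBy32(MutableArrayRef<AffineForOp> band) {
  SmallVector<unsigned, 4> tileSizes(band.size(), 32);
  return tileCodeGen(band, tileSizes);
}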
-SmallVector, 8> -tile(ArrayRef> forOps, ArrayRef sizes, - OpPointer target); +SmallVector tile(ArrayRef forOps, + ArrayRef sizes, AffineForOp target); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 83a6e24a7f4..3a75a2619f4 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -30,7 +30,6 @@ namespace mlir { class AffineForOp; -template class OpPointer; class FunctionPassBase; class ModulePassBase; @@ -62,8 +61,7 @@ FunctionPassBase *createMaterializeVectorsPass(); /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). FunctionPassBase *createLoopUnrollPass( int unrollFactor = -1, int unrollFull = -1, - const std::function)> &getUnrollFactor = - nullptr); + const std::function &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 0fc076d1a65..ab5660be871 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -116,8 +116,8 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// all the affine.apply op's supplying operands to this opInst did not have any /// uses other than those in this opInst. The method otherwise returns the list /// of affine.apply operations created in output argument `sliceOps`. -void createAffineComputationSlice( - Instruction *opInst, SmallVectorImpl> *sliceOps); +void createAffineComputationSlice(Instruction *opInst, + SmallVectorImpl *sliceOps); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 9cb74187cc1..e02d2590154 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -485,7 +485,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, auto *t = operands[i]; auto affineApply = t->getDefiningInst() ? t->getDefiningInst()->dyn_cast() - : OpPointer(); + : AffineApplyOp(); if (affineApply) { // a. Compose affine.apply instructions. 
LLVM_DEBUG(affineApply->getInstruction()->print( @@ -567,9 +567,9 @@ void mlir::fullyComposeAffineMapAndOperands( } } -OpPointer -mlir::makeComposedAffineApply(FuncBuilder *b, Location loc, AffineMap map, - ArrayRef operands) { +AffineApplyOp mlir::makeComposedAffineApply(FuncBuilder *b, Location loc, + AffineMap map, + ArrayRef operands) { AffineMap normalizedMap = map; SmallVector normalizedOperands(operands.begin(), operands.end()); composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); @@ -1070,15 +1070,14 @@ Block *AffineForOp::createBody() { AffineBound AffineForOp::getLowerBound() { auto lbMap = getLowerBoundMap(); - return AffineBound(OpPointer(*this), 0, lbMap.getNumInputs(), - lbMap); + return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); } AffineBound AffineForOp::getUpperBound() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); - return AffineBound(OpPointer(*this), lbMap.getNumInputs(), - getNumOperands(), ubMap); + return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(), + ubMap); } void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { @@ -1178,24 +1177,24 @@ Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } /// Returns if the provided value is the induction variable of a AffineForOp. bool mlir::isForInductionVar(Value *val) { - return getForInductionVarOwner(val) != nullptr; + return getForInductionVarOwner(val) != AffineForOp(); } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -OpPointer mlir::getForInductionVarOwner(Value *val) { +AffineForOp mlir::getForInductionVarOwner(Value *val) { auto *ivArg = dyn_cast(val); if (!ivArg || !ivArg->getOwner()) - return OpPointer(); + return AffineForOp(); auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); if (!containingInst) - return OpPointer(); + return AffineForOp(); return containingInst->dyn_cast(); } /// Extracts the induction variables from a list of AffineForOps and returns /// them. -void mlir::extractForInductionVars(ArrayRef> forInsts, +void mlir::extractForInductionVars(ArrayRef forInsts, SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 0b7d9f831e4..f786731e88a 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -98,7 +98,7 @@ void mlir::getReachableAffineApplyOps( // stride information in FlatAffineConstraints. (For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef expr, int64_t stride) -LogicalResult mlir::getIndexSet(MutableArrayRef> forOps, +LogicalResult mlir::getIndexSet(MutableArrayRef forOps, FlatAffineConstraints *domain) { SmallVector indices; extractForInductionVars(forOps, &indices); @@ -122,7 +122,7 @@ static LogicalResult getInstIndexSet(Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. 
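Since getForInductionVarOwner now returns the op class itself, a missing owner is represented by a default-constructed AffineForOp rather than a null pointer, as the updated isForInductionVar shows. A small sketch of the resulting idiom (illustrative only):

// Illustrative only.
void inspectInductionVar(Value *val) {
  AffineForOp owner = getForInductionVarOwner(val);
  if (!owner)                  // same as owner == AffineForOp()
    return;                    // 'val' is not a loop induction variable
  int64_t step = owner->getStep();
  (void)step;
}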
- SmallVector, 4> loops; + SmallVector loops; getLoopIVs(*inst, &loops); return getIndexSet(loops, indexSet); } @@ -461,7 +461,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, if (auto *opInst = symbol->getDefiningInst()) { if (auto constOp = opInst->dyn_cast()) { dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), - constOp->getValue()); + constOp.getValue()); } } } diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 3de26589b12..64d1809922c 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -217,7 +217,7 @@ AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, : map(map), operands(operands.begin(), operands.end()), results(results.begin(), results.end()) {} -AffineValueMap::AffineValueMap(OpPointer applyOp) +AffineValueMap::AffineValueMap(AffineApplyOp applyOp) : map(applyOp->getAffineMap()), operands(applyOp->operand_begin(), applyOp->operand_end()) { results.push_back(applyOp->getResult()); @@ -729,13 +729,12 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { // Check if the symbol is a constant. if (auto *opInst = id->getDefiningInst()) { if (auto constOp = opInst->dyn_cast()) { - setIdToConstant(*id, constOp->getValue()); + setIdToConstant(*id, constOp.getValue()); } } } -LogicalResult -FlatAffineConstraints::addAffineForOpDomain(OpPointer forOp) { +LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { unsigned pos; // Pre-condition for this method. if (!findId(*forOp->getInductionVar(), &pos)) { @@ -772,10 +771,8 @@ FlatAffineConstraints::addAffineForOpDomain(OpPointer forOp) { addConstantLowerBound(pos, forOp->getConstantLowerBound()); } else { // Non-constant lower bound case. - OpPointer ncForOp = - *reinterpret_cast *>(&forOp); - SmallVector lbOperands(ncForOp->getLowerBoundOperands().begin(), - ncForOp->getLowerBoundOperands().end()); + SmallVector lbOperands(forOp->getLowerBoundOperands().begin(), + forOp->getLowerBoundOperands().end()); if (failed(addLowerOrUpperBound(pos, forOp->getLowerBoundMap(), lbOperands, /*eq=*/false, /*lower=*/true))) return failure(); @@ -786,10 +783,8 @@ FlatAffineConstraints::addAffineForOpDomain(OpPointer forOp) { return success(); } // Non-constant upper bound case. - OpPointer ncForOp = - *reinterpret_cast *>(&forOp); - SmallVector ubOperands(ncForOp->getUpperBoundOperands().begin(), - ncForOp->getUpperBoundOperands().end()); + SmallVector ubOperands(forOp->getUpperBoundOperands().begin(), + forOp->getUpperBoundOperands().end()); return addLowerOrUpperBound(pos, forOp->getUpperBoundMap(), ubOperands, /*eq=*/false, /*lower=*/false); } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 28b0f75909c..651d9b7491a 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -49,16 +49,12 @@ using namespace mlir; // pure analysis method relying on FlatAffineConstraints; the latter will also // be more powerful (since both inequalities and equalities will be considered). void mlir::buildTripCountMapAndOperands( - OpPointer forOp, AffineMap *map, + AffineForOp forOp, AffineMap *map, SmallVectorImpl *tripCountOperands) { int64_t loopSpan; int64_t step = forOp->getStep(); - - // We need to get operands; we aren't changing them here. 
- auto ncForOp = *reinterpret_cast *>(&forOp); - - FuncBuilder b(ncForOp->getInstruction()); + FuncBuilder b(forOp->getInstruction()); if (forOp->hasConstantBounds()) { int64_t lb = forOp->getConstantLowerBound(); @@ -76,8 +72,8 @@ void mlir::buildTripCountMapAndOperands( *map = AffineMap(); return; } - SmallVector lbOperands(ncForOp->getLowerBoundOperands()); - SmallVector ubOperands(ncForOp->getUpperBoundOperands()); + SmallVector lbOperands(forOp->getLowerBoundOperands()); + SmallVector ubOperands(forOp->getUpperBoundOperands()); auto lb = b.create(forOp->getLoc(), lbMap, lbOperands); SmallVector ubs; ubs.reserve(ubMap.getNumResults()); @@ -117,8 +113,7 @@ void mlir::buildTripCountMapAndOperands( // being an analysis utility, it shouldn't. Replace with a version that just // works with analysis structures (FlatAffineConstraints) and thus doesn't // update the IR. -llvm::Optional -mlir::getConstantTripCount(OpPointer forOp) { +llvm::Optional mlir::getConstantTripCount(AffineForOp forOp) { SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -144,7 +139,7 @@ mlir::getConstantTripCount(OpPointer forOp) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t mlir::getLargestDivisorOfTripCount(OpPointer forOp) { +uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -235,7 +230,7 @@ mlir::getInvariantAccesses(Value &iv, llvm::ArrayRef indices) { /// // TODO(ntv): check strides. template -static bool isContiguousAccess(Value &iv, OpPointer memoryOp, +static bool isContiguousAccess(Value &iv, LoadOrStoreOp memoryOp, unsigned fastestVaryingDim) { static_assert(std::is_same::value || std::is_same::value, @@ -281,10 +276,9 @@ static bool isVectorTransferReadOrWrite(Instruction &inst) { return inst.isa() || inst.isa(); } -using VectorizableInstFun = - std::function, Instruction &)>; +using VectorizableInstFun = std::function; -static bool isVectorizableLoopWithCond(OpPointer loop, +static bool isVectorizableLoopWithCond(AffineForOp loop, VectorizableInstFun isVectorizableInst) { auto *forInst = const_cast(loop->getInstruction()); if (!matcher::isParallelLoop(*forInst) && @@ -340,9 +334,9 @@ static bool isVectorizableLoopWithCond(OpPointer loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - OpPointer loop, unsigned fastestVaryingDim) { + AffineForOp loop, unsigned fastestVaryingDim) { VectorizableInstFun fun( - [fastestVaryingDim](OpPointer loop, Instruction &op) { + [fastestVaryingDim](AffineForOp loop, Instruction &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); return load ? isContiguousAccess(*loop->getInductionVar(), load, @@ -353,10 +347,10 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(OpPointer loop) { +bool mlir::isVectorizableLoop(AffineForOp loop) { VectorizableInstFun fun( // TODO: implement me - [](OpPointer loop, Instruction &op) { return true; }); + [](AffineForOp loop, Instruction &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } @@ -365,8 +359,7 @@ bool mlir::isVectorizableLoop(OpPointer loop) { /// 'def' and all its uses have the same shift factor. 
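Because the loop is now received by value, the old reinterpret_cast trick for shedding constness before reading the bound operands disappears; the operands can be read directly off the op. A sketch of the simplified pattern (the helper name is illustrative, mirroring the changes to addAffineForOpDomain and buildTripCountMapAndOperands above):

// Illustrative helper only.
static void collectBoundOperands(AffineForOp forOp,
                                 SmallVectorImpl<Value *> *operands) {
  operands->append(forOp->getLowerBoundOperands().begin(),
                   forOp->getLowerBoundOperands().end());
  operands->append(forOp->getUpperBoundOperands().begin(),
                   forOp->getUpperBoundOperands().end());
}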
// TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isInstwiseShiftValid(OpPointer forOp, - ArrayRef shifts) { +bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { auto *forBody = forOp->getBody(); assert(shifts.size() == forBody->getInstructions().size()); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index efe38391c48..b954f0e67d9 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -44,7 +44,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { void TestParallelismDetection::runOnFunction() { Function *f = getFunction(); FuncBuilder b(f); - f->walk([&](OpPointer forOp) { + f->walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) forOp->emitNote("parallel loop"); }); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2cd0a83296b..57de7407248 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -39,10 +39,9 @@ using llvm::SmallDenseMap; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. -void mlir::getLoopIVs(Instruction &inst, - SmallVectorImpl> *loops) { +void mlir::getLoopIVs(Instruction &inst, SmallVectorImpl *loops) { auto *currInst = inst.getParentInst(); - OpPointer currAffineForOp; + AffineForOp currAffineForOp; // Traverse up the hierarchy collecing all 'for' instruction while // skipping over 'if' instructions. while (currInst && ((currAffineForOp = currInst->dyn_cast()) || @@ -76,7 +75,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { // Check if the symbol is a constant. if (auto *inst = value->getDefiningInst()) { if (auto constOp = inst->dyn_cast()) { - cst->setIdToConstant(*value, constOp->getValue()); + cst->setIdToConstant(*value, constOp.getValue()); } } } else { @@ -189,7 +188,7 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, << "depth: " << loopDepth << "\n";); if (rank == 0) { - SmallVector, 4> ivs; + SmallVector ivs; getLoopIVs(*inst, &ivs); SmallVector regionSymbols; extractForInductionVars(ivs, ®ionSymbols); @@ -245,7 +244,7 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // Check if the symbol is a constant. if (auto *inst = symbol->getDefiningInst()) { if (auto constOp = inst->dyn_cast()) { - cst.setIdToConstant(*symbol, constOp->getValue()); + cst.setIdToConstant(*symbol, constOp.getValue()); } } } @@ -280,14 +279,14 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. 
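Note the switch from constOp->getValue() to constOp.getValue() at these call sites, presumably because single-result ops such as ConstantIndexOp now convert implicitly to their result Value *. The recurring pattern, condensed into one illustrative helper:

// Illustrative only.
void addConstantSymbol(FlatAffineConstraints *cst, Value *symbol) {
  if (auto *inst = symbol->getDefiningInst())
    if (auto constOp = inst->dyn_cast<ConstantIndexOp>())
      cst->setIdToConstant(*symbol, constOp.getValue());
}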
- SmallVector, 4> enclosingIVs; + SmallVector enclosingIVs; getLoopIVs(*inst, &enclosingIVs); assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); enclosingIVs.resize(loopDepth); SmallVector ids; cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids); for (auto *id : ids) { - OpPointer iv; + AffineForOp iv; if ((iv = getForInductionVarOwner(id)) && llvm::is_contained(enclosingIVs, iv) == false) { cst.projectOut(id); @@ -371,10 +370,9 @@ Optional mlir::getMemRefSizeInBytes(MemRefType memRefType) { template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError) { - static_assert( - std::is_same>::value || - std::is_same>::value, - "argument should be either a LoadOp or a StoreOp"); + static_assert(std::is_same::value || + std::is_same::value, + "argument should be either a LoadOp or a StoreOp"); Instruction *opInst = loadOrStoreOp->getInstruction(); @@ -424,9 +422,9 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, } // Explicitly instantiate the template so that the compiler knows we need them! -template LogicalResult mlir::boundCheckLoadOrStoreOp(OpPointer loadOp, +template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOp loadOp, bool emitError); -template LogicalResult mlir::boundCheckLoadOrStoreOp(OpPointer storeOp, +template LogicalResult mlir::boundCheckLoadOrStoreOp(StoreOp storeOp, bool emitError); // Returns in 'positions' the Block positions of 'inst' in each ancestor @@ -490,12 +488,12 @@ LogicalResult mlir::getBackwardComputationSliceState( return failure(); } // Get loop nest surrounding src operation. - SmallVector, 4> srcLoopIVs; + SmallVector srcLoopIVs; getLoopIVs(*srcAccess.opInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector, 4> dstLoopIVs; + SmallVector dstLoopIVs; getLoopIVs(*dstAccess.opInst, &dstLoopIVs); unsigned numDstLoopIVs = dstLoopIVs.size(); if (dstLoopDepth > numDstLoopIVs) { @@ -566,21 +564,21 @@ LogicalResult mlir::getBackwardComputationSliceState( // entire destination index set. Subtract out the dependent destination // iterations from destination index set and check for emptiness --- this is one // solution. -OpPointer mlir::insertBackwardComputationSlice( +AffineForOp mlir::insertBackwardComputationSlice( Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. - SmallVector, 4> srcLoopIVs; + SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector, 4> dstLoopIVs; + SmallVector dstLoopIVs; getLoopIVs(*dstOpInst, &dstLoopIVs); unsigned dstLoopIVsSize = dstLoopIVs.size(); if (dstLoopDepth > dstLoopIVsSize) { dstOpInst->emitError("invalid destination loop depth"); - return OpPointer(); + return AffineForOp(); } // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. @@ -599,7 +597,7 @@ OpPointer mlir::insertBackwardComputationSlice( Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceInst'. - SmallVector, 4> sliceSurroundingLoops; + SmallVector sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); // Sanity check. 
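With the OpPointer return type gone, slice insertion reports failure by returning a default-constructed AffineForOp, which callers test through the op's boolean conversion. A sketch of the caller side (the function name and loop depth are illustrative):

// Illustrative only.
void fuseIfPossible(Instruction *srcOpInst, Instruction *dstOpInst,
                    ComputationSliceState *sliceState) {
  AffineForOp sliceLoopNest = insertBackwardComputationSlice(
      srcOpInst, dstOpInst, /*dstLoopDepth=*/1, sliceState);
  if (!sliceLoopNest)
    return;  // insertion failed; an error was already emitted
  // ... continue working with sliceLoopNest ...
}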
@@ -666,7 +664,7 @@ unsigned mlir::getNestingDepth(Instruction &inst) { /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', /// where each lists loops from outer-most to inner-most in loop nest. unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { - SmallVector, 4> loopsA, loopsB; + SmallVector loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); @@ -728,7 +726,7 @@ static Optional getMemoryFootprintBytes(Block &block, return totalSizeInBytes; } -Optional mlir::getMemoryFootprintBytes(OpPointer forOp, +Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, int memorySpace) { auto *forInst = forOp->getInstruction(); return ::getMemoryFootprintBytes( @@ -739,8 +737,7 @@ Optional mlir::getMemoryFootprintBytes(OpPointer forOp, /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. void mlir::getSequentialLoops( - OpPointer forOp, - llvm::SmallDenseSet *sequentialLoops) { + AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { forOp->getInstruction()->walk([&](Instruction *inst) { if (auto innerFor = inst->dyn_cast()) if (!isLoopParallel(innerFor)) @@ -749,7 +746,7 @@ void mlir::getSequentialLoops( } /// Returns true if 'forOp' is parallel. -bool mlir::isLoopParallel(OpPointer forOp) { +bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; forOp->getInstruction()->walk([&](Instruction *opInst) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index ac7014923a7..595141af84e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -155,8 +155,8 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, if (!lbConst || !ubConst) return llvm::Optional(); - return ValueHandle::create(lbConst->getValue(), - ubConst->getValue(), step); + return ValueHandle::create(lbConst.getValue(), + ubConst.getValue(), step); } mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, @@ -268,10 +268,9 @@ categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, AffineExpr d; Value *resultVal = nullptr; auto *inst = val->getDefiningInst(); - auto constant = - inst ? inst->dyn_cast() : OpPointer(); + auto constant = inst ? 
inst->dyn_cast() : ConstantIndexOp(); if (constant) { - d = getAffineConstantExpr(constant->getValue(), context); + d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { d = getAffineSymbolExpr(numSymbols++, context); resultVal = val; diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 6430796bcc1..1196748a0af 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -94,25 +94,24 @@ static void checkAffineProvenance(ArrayRef values) { } } -static OpPointer emitStaticFor(FuncBuilder &builder, Location loc, - ArrayRef lbs, - ArrayRef ubs, - uint64_t step) { +static AffineForOp emitStaticFor(FuncBuilder &builder, Location loc, + ArrayRef lbs, ArrayRef ubs, + uint64_t step) { if (lbs.size() != 1 || ubs.size() != 1) - return OpPointer(); + return AffineForOp(); auto *lbDef = lbs.front()->getDefiningInst(); auto *ubDef = ubs.front()->getDefiningInst(); if (!lbDef || !ubDef) - return OpPointer(); + return AffineForOp(); auto lbConst = lbDef->dyn_cast(); auto ubConst = ubDef->dyn_cast(); if (!lbConst || !ubConst) - return OpPointer(); + return AffineForOp(); - return builder.create(loc, lbConst->getValue(), - ubConst->getValue(), step); + return builder.create(loc, lbConst.getValue(), + ubConst.getValue(), step); } Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { @@ -166,11 +165,10 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { // Step must be a static constant. auto step = - stepExpr->getDefiningInst()->cast()->getValue(); + stepExpr->getDefiningInst()->cast().getValue(); // Special case with more concise emitted code for static bounds. - OpPointer forOp = - emitStaticFor(*builder, location, lbs, ubs, step); + AffineForOp forOp = emitStaticFor(*builder, location, lbs, ubs, step); // General case. if (!forOp) @@ -387,7 +385,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Expr boundMemRef) { return makeBoundMemRefView(v); } -OpPointer mlir::edsc::MLIREmitter::getAffineForOp(Expr e) { +AffineForOp mlir::edsc::MLIREmitter::getAffineForOp(Expr e) { auto *value = ssaBindings.lookup(e); assert(value && "Expr not bound"); return getForInductionVarOwner(value); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index ffa29b194af..67c51d54789 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -319,10 +319,10 @@ struct SimplifyAllocConst : public RewritePattern { continue; } auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); - OpPointer constantIndexOp; + ConstantIndexOp constantIndexOp; if (defOp && (constantIndexOp = defOp->dyn_cast())) { // Dynamic shape dimension will be folded. - newShapeConstants.push_back(constantIndexOp->getValue()); + newShapeConstants.push_back(constantIndexOp.getValue()); // Record to check for zero uses later below. droppedOperands.push_back(constantIndexOp); } else { diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 954135d2a4f..1616cd3472d 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -187,7 +187,7 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, // Just get the first numSymbols IVs, which the memref region is parametric // on. 
- SmallVector, 4> ivs; + SmallVector ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); SmallVector symbols; @@ -485,7 +485,7 @@ bool DmaGeneration::runOnBlock(Block *block) { for (auto it = curBegin; it != block->end(); ++it) { if (auto forOp = it->dyn_cast()) { // Returns true if the footprint is known to exceed capacity. - auto exceedsCapacity = [&](OpPointer forOp) { + auto exceedsCapacity = [&](AffineForOp forOp) { Optional footprint = getMemoryFootprintBytes(forOp, /*memorySpace=*/0); @@ -553,7 +553,7 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, SmallVector symbols; cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols); - SmallVector, 4> enclosingFors; + SmallVector enclosingFors; getLoopIVs(*block.begin(), &enclosingFors); // Walk up loop parents till we find an IV on which this region is // symbolic/variant. @@ -733,7 +733,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // For a range of operation instructions, a note will be emitted at the // caller. - OpPointer forOp; + AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { forOp->emitNote( diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 95bdc3ca2d2..8e1fc505348 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -122,7 +122,7 @@ namespace { // LoopNestStateCollector walks loop nests and collects load and store // operations, and whether or not an IfInst was encountered in the loop nest. struct LoopNestStateCollector { - SmallVector, 4> forOps; + SmallVector forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; bool hasNonForRegion = false; @@ -691,7 +691,7 @@ bool MemRefDependenceGraph::init(Function *f) { auto *opInst = node.inst; for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { - SmallVector, 4> loops; + SmallVector loops; getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; @@ -727,7 +727,7 @@ namespace { // and operation count) for a loop nest up until the innermost loop body. struct LoopNestStats { // Map from AffineForOp to immediate child AffineForOps in its loop body. - DenseMap, 2>> loopMap; + DenseMap> loopMap; // Map from AffineForOp to count of operations in its loop body. DenseMap opCountMap; // Map from AffineForOp to its constant trip count. 
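Loop nests are likewise collected directly as vectors of op values. A minimal sketch of the getLoopIVs pattern used throughout these files (illustrative only):

// Illustrative only.
unsigned surroundingLoopDepth(Instruction &inst) {
  SmallVector<AffineForOp, 4> loops;
  getLoopIVs(inst, &loops);  // outermost to innermost enclosing AffineForOps
  return loops.size();
}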
@@ -743,7 +743,7 @@ struct LoopNestStatsCollector { LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} void collect(Instruction *inst) { - inst->walk([&](OpPointer forOp) { + inst->walk([&](AffineForOp forOp) { auto *forInst = forOp->getInstruction(); auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { @@ -844,7 +844,7 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { static bool buildSliceTripCountMap( Instruction *srcOpInst, ComputationSliceState *sliceState, llvm::SmallDenseMap *tripCountMap) { - SmallVector, 4> srcLoopIVs; + SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Populate map from AffineForOp -> trip count @@ -892,7 +892,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); - std::vector, 4>> loops(numOps); + std::vector> loops(numOps); unsigned loopDepthLimit = std::numeric_limits::max(); for (unsigned i = 0; i < numOps; ++i) { getLoopIVs(*ops[i], &loops[i]); @@ -1056,8 +1056,8 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(node->inst->isa()); // Get perfectly nested sequence of loops starting at root of loop nest. // TODO(andydavis,bondhugula) Share this with similar code in loop tiling. - SmallVector, 4> loops; - OpPointer curr = node->inst->cast(); + SmallVector loops; + AffineForOp curr = node->inst->cast(); loops.push_back(curr); auto *currBody = curr->getBody(); while (!currBody->empty() && @@ -1113,7 +1113,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. -static Value *createPrivateMemRef(OpPointer forOp, +static Value *createPrivateMemRef(AffineForOp forOp, Instruction *srcStoreOpInst, unsigned dstLoopDepth, Optional fastMemorySpace, @@ -1429,7 +1429,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, }); // Compute cost of sliced and unsliced src loop nest. - SmallVector, 4> srcLoopIVs; + SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); @@ -1443,7 +1443,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, return false; } // Compute cost of dst loop nest. - SmallVector, 4> dstLoopIVs; + SmallVector dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; @@ -1933,7 +1933,7 @@ public: // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest != nullptr) { + if (sliceLoopNest) { LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" << *sliceLoopNest->getInstruction() << "\n"); @@ -2182,8 +2182,8 @@ public: return false; } - void updateStateAfterSiblingFusion(OpPointer sliceLoopNest, - Node *sibNode, Node *dstNode) { + void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode, + Node *dstNode) { // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. 
mdg->updateEdges(sibNode->id, dstNode->id); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 9c97c1b6c74..0b629531df0 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -67,7 +67,7 @@ struct LoopTiling : public FunctionPass { : cacheSizeBytes(cacheSizeBytes), avoidMaxMinBounds(avoidMaxMinBounds) {} void runOnFunction() override; - void getTileSizes(ArrayRef> band, + void getTileSizes(ArrayRef band, SmallVectorImpl *tileSizes); // Default tile size if nothing is provided. @@ -90,7 +90,7 @@ FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { // Move the loop body of AffineForOp 'src' from 'src' into the specified // location in destination's body. -static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest, +static inline void moveLoopBody(AffineForOp src, AffineForOp dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); @@ -98,7 +98,7 @@ static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest, // Move the loop body of AffineForOp 'src' from 'src' to the start of dest's // body. -static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) { +static inline void moveLoopBody(AffineForOp src, AffineForOp dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -107,10 +107,10 @@ static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect( - MutableArrayRef> origLoops, - MutableArrayRef> newLoops, - ArrayRef tileSizes) { +static void +constructTiledIndexSetHyperRect(MutableArrayRef origLoops, + MutableArrayRef newLoops, + ArrayRef tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); @@ -174,7 +174,7 @@ static void constructTiledIndexSetHyperRect( /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -LogicalResult mlir::tileCodeGen(MutableArrayRef> band, +LogicalResult mlir::tileCodeGen(MutableArrayRef band, ArrayRef tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes"); @@ -187,13 +187,13 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef> band, auto origLoops = band; - OpPointer rootAffineForOp = origLoops[0]; + AffineForOp rootAffineForOp = origLoops[0]; auto loc = rootAffineForOp->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector, 12> newLoops(2 * width); - OpPointer innermostPointLoop; + SmallVector newLoops(2 * width); + AffineForOp innermostPointLoop; // The outermost among the loops as we add more.. auto *topLoop = rootAffineForOp->getInstruction(); @@ -256,13 +256,12 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef> band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. 
-static void -getTileableBands(Function *f, - std::vector, 6>> *bands) { +static void getTileableBands(Function *f, + std::vector> *bands) { // Get maximal perfect nest of 'for' insts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](OpPointer root) { - SmallVector, 6> band; - OpPointer currInst = root; + auto getMaximalPerfectLoopNest = [&](AffineForOp root) { + SmallVector band; + AffineForOp currInst = root; do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && @@ -278,7 +277,7 @@ getTileableBands(Function *f, // Reduce each tile size to the largest divisor of the corresponding trip count // (if the trip count is known). -static void adjustToDivisorsOfTripCounts(ArrayRef> band, +static void adjustToDivisorsOfTripCounts(ArrayRef band, SmallVectorImpl *tileSizes) { assert(band.size() == tileSizes->size() && "invalid tile size count"); for (unsigned i = 0, e = band.size(); i < e; i++) { @@ -302,7 +301,7 @@ static void adjustToDivisorsOfTripCounts(ArrayRef> band, // along each of the dimensions being tiled. // TODO(mlir-team): evolve this model. Tile size determination is a large area // to play with in general. -void LoopTiling::getTileSizes(ArrayRef> band, +void LoopTiling::getTileSizes(ArrayRef band, SmallVectorImpl *tileSizes) { if (band.empty()) return; @@ -383,7 +382,7 @@ void LoopTiling::runOnFunction() { cacheSizeBytes = clCacheSizeKiB * 1024; // Bands of loops to tile. - std::vector, 6>> bands; + std::vector> bands; getTileableBands(getFunction(), &bands); for (auto &band : bands) { diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index b3ee63ff1fa..a16237e6452 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -69,19 +69,18 @@ struct LoopUnroll : public FunctionPass { const Optional unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function)> getUnrollFactor; + const std::function getUnrollFactor; - explicit LoopUnroll(Optional unrollFactor = None, - Optional unrollFull = None, - const std::function)> - &getUnrollFactor = nullptr) + explicit LoopUnroll( + Optional unrollFactor = None, Optional unrollFull = None, + const std::function &getUnrollFactor = nullptr) : unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} void runOnFunction() override; /// Unroll this for inst. Returns failure if nothing was done. - LogicalResult runOnAffineForOp(OpPointer forOp); + LogicalResult runOnAffineForOp(AffineForOp forOp); static const unsigned kDefaultUnrollFactor = 4; }; @@ -91,7 +90,7 @@ void LoopUnroll::runOnFunction() { // Gathers all innermost loops through a post order pruned walk. struct InnermostLoopGatherer { // Store innermost loops as we walk. - std::vector> loops; + std::vector loops; void walkPostOrder(Function *f) { for (auto &b : *f) @@ -124,18 +123,16 @@ void LoopUnroll::runOnFunction() { if (clUnrollFull.getNumOccurrences() > 0 && clUnrollFullThreshold.getNumOccurrences() > 0) { // Store short loops as we walk. - std::vector> loops; + std::vector loops; // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). 
- getFunction()->walkPostOrder( - [&](OpPointer forOp) { - Optional tripCount = getConstantTripCount(forOp); - if (tripCount.hasValue() && - tripCount.getValue() <= clUnrollFullThreshold) - loops.push_back(forOp); - }); + getFunction()->walkPostOrder([&](AffineForOp forOp) { + Optional tripCount = getConstantTripCount(forOp); + if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) + loops.push_back(forOp); + }); for (auto forOp : loops) loopUnrollFull(forOp); return; @@ -163,7 +160,7 @@ void LoopUnroll::runOnFunction() { /// Unrolls a 'for' inst. Returns success if the loop was unrolled, failure /// otherwise. The default unroll factor is 4. -LogicalResult LoopUnroll::runOnAffineForOp(OpPointer forOp) { +LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); @@ -185,7 +182,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(OpPointer forOp) { FunctionPassBase *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function)> &getUnrollFactor) { + const std::function &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 2b92b2a6422..03c06b4b450 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -78,7 +78,7 @@ struct LoopUnrollAndJam : public FunctionPass { : unrollJamFactor(unrollJamFactor) {} void runOnFunction() override; - LogicalResult runOnAffineForOp(OpPointer forOp); + LogicalResult runOnAffineForOp(AffineForOp forOp); }; } // end anonymous namespace @@ -98,7 +98,7 @@ void LoopUnrollAndJam::runOnFunction() { /// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return failure if nothing was done. -LogicalResult LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { +LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); @@ -110,7 +110,7 @@ LogicalResult LoopUnrollAndJam::runOnAffineForOp(OpPointer forOp) { return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -LogicalResult mlir::loopUnrollJamUpToFactor(OpPointer forOp, +LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp, uint64_t unrollJamFactor) { Optional mayBeConstantTripCount = getConstantTripCount(forOp); @@ -121,7 +121,7 @@ LogicalResult mlir::loopUnrollJamUpToFactor(OpPointer forOp, } /// Unrolls and jams this loop by the specified factor. 
-LogicalResult mlir::loopUnrollJamByFactor(OpPointer forOp, +LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 31232ce7fe4..f80b4426b6d 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -244,9 +244,9 @@ namespace { struct LowerAffinePass : public FunctionPass { void runOnFunction() override; - bool lowerAffineFor(OpPointer forOp); - bool lowerAffineIf(AffineIfOp *ifOp); - bool lowerAffineApply(AffineApplyOp *op); + bool lowerAffineFor(AffineForOp forOp); + bool lowerAffineIf(AffineIfOp ifOp); + bool lowerAffineApply(AffineApplyOp op); }; } // end anonymous namespace @@ -319,7 +319,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // | | // +--------------------------------+ // -bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { +bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { auto loc = forOp->getLoc(); auto *forInst = forOp->getInstruction(); @@ -452,7 +452,7 @@ bool LowerAffinePass::lowerAffineFor(OpPointer forOp) { // | | // +--------------------------------+ // -bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); @@ -568,7 +568,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { // Convert an "affine.apply" operation into a sequence of arithmetic // instructions using the StandardOps dialect. Return true on error. -bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { +bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { FuncBuilder builder(op->getInstruction()); auto maybeExpandedMap = expandAffineMap(&builder, op->getLoc(), op->getAffineMap(), diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index f0c204f9840..cde28c6517d 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -102,7 +102,7 @@ namespace { /// a VectorTransferWriteOp is rewritten. template class VectorTransferRewriter { public: - VectorTransferRewriter(VectorTransferOpTy *transfer, + VectorTransferRewriter(VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter, MLFuncGlobalLoweringState *state); @@ -121,7 +121,7 @@ public: void rewrite(); private: - VectorTransferOpTy *transfer; + VectorTransferOpTy transfer; MLFuncLoweringRewriter *rewriter; MLFuncGlobalLoweringState *state; }; @@ -132,7 +132,7 @@ private: /// `pivs` and `vectorView` are swapped so that the invocation of /// LoopNestBuilder captures it in the innermost loop. template -void coalesceCopy(VectorTransferOpTy *transfer, +void coalesceCopy(VectorTransferOpTy transfer, SmallVectorImpl *pivs, edsc::VectorView *vectorView) { // rank of the remote memory access, coalescing behavior occurs on the @@ -166,7 +166,7 @@ void coalesceCopy(VectorTransferOpTy *transfer, /// MemRef. 
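The lowering entry points also take their ops by value now, so dispatching over an instruction's kind reads naturally through dyn_cast. Roughly how the pass is expected to dispatch (condensed and illustrative; 'instsToRewrite' is a hypothetical worklist and error handling is elided):

// Illustrative only, written as if inside LowerAffinePass.
for (auto *inst : instsToRewrite) {
  if (auto forOp = inst->dyn_cast<AffineForOp>())
    lowerAffineFor(forOp);
  else if (auto ifOp = inst->dyn_cast<AffineIfOp>())
    lowerAffineIf(ifOp);
  else if (auto applyOp = inst->dyn_cast<AffineApplyOp>())
    lowerAffineApply(applyOp);
}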
template static llvm::SmallVector -clip(VectorTransferOpTy *transfer, edsc::MemRefView &view, +clip(VectorTransferOpTy transfer, edsc::MemRefView &view, ArrayRef ivs) { using namespace mlir::edsc; using namespace edsc::op; @@ -216,7 +216,7 @@ clip(VectorTransferOpTy *transfer, edsc::MemRefView &view, template VectorTransferRewriter::VectorTransferRewriter( - VectorTransferOpTy *transfer, MLFuncLoweringRewriter *rewriter, + VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter, MLFuncGlobalLoweringState *state) : transfer(transfer), rewriter(rewriter), state(state){}; @@ -368,7 +368,7 @@ public: std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { VectorTransferRewriter( - &*op->dyn_cast(), rewriter, funcWiseState) + op->dyn_cast(), rewriter, funcWiseState) .rewrite(); } }; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6208eee5d62..d1374e92dc8 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -441,7 +441,7 @@ static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, /// In particular, if a dimension is fully instantiated (i.e. unrolled) then it /// is projected out in the final result. template -static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, +static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, VectorType hwVectorType) { static_assert( std::is_same::value || @@ -481,7 +481,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` in the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp *read, +static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp read, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -505,7 +505,7 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp *read, /// `hwVectorType` in the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp *write, +static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index f443a34e169..4088cde4185 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -72,7 +72,7 @@ namespace { struct MemRefDataFlowOpt : public FunctionPass { void runOnFunction() override; - void forwardStoreToLoad(OpPointer loadOp); + void forwardStoreToLoad(LoadOp loadOp); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; @@ -93,7 +93,7 @@ FunctionPassBase *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer loadOp) { +void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { Instruction *lastWriteStoreOp = nullptr; Instruction *loadOpInst = loadOp->getInstruction(); @@ -224,8 +224,7 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding.
- f->walk( - [&](OpPointer loadOp) { forwardStoreToLoad(loadOp); }); + f->walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 1dfc4e7dc17..9809a146072 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -40,9 +40,9 @@ namespace { struct PipelineDataTransfer : public FunctionPass { void runOnFunction() override; - void runOnAffineForOp(OpPointer forOp); + void runOnAffineForOp(AffineForOp forOp); - std::vector> forOps; + std::vector forOps; }; } // end anonymous namespace @@ -71,7 +71,7 @@ static unsigned getTagMemRefPos(Instruction &dmaInst) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' instruction modulo 2. Returns false if /// such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { +static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -145,14 +145,13 @@ void PipelineDataTransfer::runOnFunction() { // epilogue). forOps.clear(); getFunction()->walkPostOrder( - [&](OpPointer forOp) { forOps.push_back(forOp); }); + [&](AffineForOp forOp) { forOps.push_back(forOp); }); for (auto forOp : forOps) runOnAffineForOp(forOp); } // Check if tags of the dma start op and dma wait op match. -static bool checkTagMatch(OpPointer startOp, - OpPointer waitOp) { +static bool checkTagMatch(DmaStartOp startOp, DmaWaitOp waitOp) { if (startOp->getTagMemRef() != waitOp->getTagMemRef()) return false; auto startIndices = startOp->getTagIndices(); @@ -176,15 +175,14 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( - OpPointer forOp, + AffineForOp forOp, SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. - SmallVector, 4> outgoingDmaOps; + SmallVector outgoingDmaOps; for (auto &inst : *forOp->getBody()) { - OpPointer dmaStartOp; - if ((dmaStartOp = inst.dyn_cast()) && - dmaStartOp->isSrcMemorySpaceFaster()) + auto dmaStartOp = inst.dyn_cast(); + if (dmaStartOp && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -195,9 +193,10 @@ static void findMatchingStartFinishInsts( dmaFinishInsts.push_back(&inst); continue; } - OpPointer dmaStartOp; - if (!(dmaStartOp = inst.dyn_cast())) + auto dmaStartOp = inst.dyn_cast(); + if (!dmaStartOp) continue; + // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. if (!dmaStartOp->isDestMemorySpaceFaster()) @@ -247,7 +246,7 @@ static void findMatchingStartFinishInsts( /// Overlap DMA transfers with computation in this loop. If successful, /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. 
-void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { +void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG( @@ -329,7 +328,7 @@ void PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { assert(dmaStartInst->isa()); instShiftMap[dmaStartInst] = 0; // Set shifts for DMA start inst's affine operand computation slices to 0. - SmallVector, 4> sliceOps; + SmallVector sliceOps; mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); if (!sliceOps.empty()) { for (auto sliceOp : sliceOps) { diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 5a58c06dd42..e5f1fef990f 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -43,8 +43,8 @@ using namespace mlir; /// part of the unrolled loop. Computes the bound as an AffineMap with its /// operands or a null map when the trip count can't be expressed as an affine /// expression. -void mlir::getCleanupLoopLowerBound(OpPointer forOp, - unsigned unrollFactor, AffineMap *map, +void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, + AffineMap *map, SmallVectorImpl *operands, FuncBuilder *b) { auto lbMap = forOp->getLowerBoundMap(); @@ -67,11 +67,8 @@ void mlir::getCleanupLoopLowerBound(OpPointer forOp, unsigned step = forOp->getStep(); - // We need to get non-const operands; we aren't changing them here. - auto ncForOp = *reinterpret_cast *>(&forOp); - - SmallVector lbOperands(ncForOp->getLowerBoundOperands()); - auto lb = b->create(ncForOp->getLoc(), lbMap, lbOperands); + SmallVector lbOperands(forOp->getLowerBoundOperands()); + auto lb = b->create(forOp->getLoc(), lbMap, lbOperands); // For each upper bound expr, get the range. // Eg: for %i = lb to min (ub1, ub2), @@ -115,7 +112,7 @@ void mlir::getCleanupLoopLowerBound(OpPointer forOp, /// Promotes the loop body of a forOp to its containing block if the forOp /// was known to have a single iteration. // TODO(bondhugula): extend this for arbitrary affine bounds. -LogicalResult mlir::promoteIfSingleIteration(OpPointer forOp) { +LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Optional tripCount = getConstantTripCount(forOp); if (!tripCount.hasValue() || tripCount.getValue() != 1) return failure(); @@ -161,7 +158,7 @@ LogicalResult mlir::promoteIfSingleIteration(OpPointer forOp) { void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. f->walkPostOrder( - [](OpPointer forOp) { promoteIfSingleIteration(forOp); }); + [](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); } /// Generates a 'for' inst with the specified lower and upper bounds while @@ -171,12 +168,11 @@ void mlir::promoteSingleIterationLoops(Function *f) { /// the pair specifies the shift applied to that group of instructions; note /// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. 
-static OpPointer +static AffineForOp generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, - unsigned offset, OpPointer srcForInst, - FuncBuilder *b) { + unsigned offset, AffineForOp srcForInst, FuncBuilder *b) { SmallVector lbOperands(srcForInst->getLowerBoundOperands()); SmallVector ubOperands(srcForInst->getUpperBoundOperands()); @@ -216,7 +212,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, } } if (succeeded(promoteIfSingleIteration(loopChunk))) - return OpPointer(); + return AffineForOp(); return loopChunk; } @@ -235,8 +231,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this // method. -LogicalResult mlir::instBodySkew(OpPointer forOp, - ArrayRef shifts, +LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue) { if (forOp->getBody()->empty()) return success(); @@ -285,8 +280,8 @@ LogicalResult mlir::instBodySkew(OpPointer forOp, // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first // loop generated as the prologue and the last as epilogue and unroll these // fully. - OpPointer prologue; - OpPointer epilogue; + AffineForOp prologue; + AffineForOp epilogue; // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block @@ -306,7 +301,7 @@ LogicalResult mlir::instBodySkew(OpPointer forOp, // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the // loop needs to have all instructions in instQueue in that order. - OpPointer res; + AffineForOp res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), @@ -357,7 +352,7 @@ LogicalResult mlir::instBodySkew(OpPointer forOp, } /// Unrolls this loop completely. -LogicalResult mlir::loopUnrollFull(OpPointer forOp) { +LogicalResult mlir::loopUnrollFull(AffineForOp forOp) { Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue()) { uint64_t tripCount = mayBeConstantTripCount.getValue(); @@ -371,7 +366,7 @@ LogicalResult mlir::loopUnrollFull(OpPointer forOp) { /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant) whichever is lower. -LogicalResult mlir::loopUnrollUpToFactor(OpPointer forOp, +LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor) { Optional mayBeConstantTripCount = getConstantTripCount(forOp); @@ -383,7 +378,7 @@ LogicalResult mlir::loopUnrollUpToFactor(OpPointer forOp, /// Unrolls this loop by the specified factor. Returns success if the loop /// is successfully unrolled. -LogicalResult mlir::loopUnrollByFactor(OpPointer forOp, +LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); @@ -471,8 +466,7 @@ LogicalResult mlir::loopUnrollByFactor(OpPointer forOp, /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is /// nested within 'forOpA' as the only instruction in its block. 
-void mlir::interchangeLoops(OpPointer forOpA, - OpPointer forOpB) { +void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { auto *forOpAInst = forOpA->getInstruction(); // 1) Slice forOpA's instruction list (which is just forOpB) just before // forOpA (in forOpA's parent's block) this should leave 'forOpA's @@ -492,11 +486,10 @@ void mlir::interchangeLoops(OpPointer forOpA, /// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels /// deeper in the loop nest. -void mlir::sinkLoop(OpPointer forOp, unsigned loopDepth) { +void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { for (unsigned i = 0; i < loopDepth; ++i) { assert(forOp->getBody()->front().isa()); - OpPointer nextForOp = - forOp->getBody()->front().cast(); + AffineForOp nextForOp = forOp->getBody()->front().cast(); interchangeLoops(forOp, nextForOp); } } @@ -525,8 +518,8 @@ static void augmentMapAndBounds(FuncBuilder *b, Value *iv, AffineMap *map, // substituting `oldIv` in place of // `forOp.getInductionVariable()`. // Note: `newForOp` may be nested under `forOp`. -static void cloneLoopBodyInto(OpPointer forOp, Value *oldIv, - OpPointer newForOp) { +static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, + AffineForOp newForOp) { BlockAndValueMapping map; map.map(oldIv, newForOp->getInductionVar()); FuncBuilder b(newForOp->getBody(), newForOp->getBody()->end()); @@ -554,9 +547,9 @@ static void cloneLoopBodyInto(OpPointer forOp, Value *oldIv, // responsibility to specify `targets` that are dominated by `forOp`. // Returns the new AffineForOps, one per `targets`, nested immediately under // each of the `targets`. -static SmallVector, 8> -stripmineSink(OpPointer forOp, uint64_t factor, - ArrayRef> targets) { +static SmallVector +stripmineSink(AffineForOp forOp, uint64_t factor, + ArrayRef targets) { // TODO(ntv): Use cheap structural assertions that targets are nested under // forOp and that targets are not nested under each other when DominanceInfo // exposes the capability. It seems overkill to construct a whole function @@ -579,7 +572,7 @@ stripmineSink(OpPointer forOp, uint64_t factor, augmentMapAndBounds(&b, forOp->getInductionVar(), &ubMap, &ubOperands, /*offset=*/scaledStep); - SmallVector, 8> innerLoops; + SmallVector innerLoops; for (auto t : targets) { // Insert newForOp at the end of `t`. FuncBuilder b(t->getBody(), t->getBody()->end()); @@ -601,21 +594,18 @@ stripmineSink(OpPointer forOp, uint64_t factor, // Stripmines a `forOp` by `factor` and sinks it under a single `target`. // Returns the new AffineForOps, nested immediately under `target`. 
-OpPointer stripmineSink(OpPointer forOp, - uint64_t factor, - OpPointer target) { - auto res = - stripmineSink(forOp, factor, ArrayRef>{target}); +AffineForOp stripmineSink(AffineForOp forOp, uint64_t factor, + AffineForOp target) { + auto res = stripmineSink(forOp, factor, ArrayRef{target}); assert(res.size() == 1 && "Expected 1 inner forOp"); return res[0]; } -SmallVector, 8>, 8> -mlir::tile(ArrayRef> forOps, ArrayRef sizes, - ArrayRef> targets) { - SmallVector, 8>, 8> res; - SmallVector, 8> currentTargets(targets.begin(), - targets.end()); +SmallVector, 8> +mlir::tile(ArrayRef forOps, ArrayRef sizes, + ArrayRef targets) { + SmallVector, 8> res; + SmallVector currentTargets(targets.begin(), targets.end()); for (auto it : llvm::zip(forOps, sizes)) { auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets); res.push_back(step); @@ -624,8 +614,8 @@ mlir::tile(ArrayRef> forOps, ArrayRef sizes, return res; } -SmallVector, 8> -mlir::tile(ArrayRef> forOps, ArrayRef sizes, - OpPointer target) { - return tile(forOps, sizes, ArrayRef>{target})[0]; +SmallVector mlir::tile(ArrayRef forOps, + ArrayRef sizes, + AffineForOp target) { + return tile(forOps, sizes, ArrayRef{target})[0]; } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 2f10b898502..7bf9993b7c8 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -221,7 +221,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, /// uses besides this opInst; otherwise returns the list of affine.apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( - Instruction *opInst, SmallVectorImpl> *sliceOps) { + Instruction *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 195202a281f..955e38f4b39 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -853,7 +853,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, /// Coarsens the loops bounds and transforms all remaining load and store /// operations into the appropriate vector_transfer. -static LogicalResult vectorizeAffineForOp(AffineForOp *loop, int64_t step, +static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, VectorizationState *state) { using namespace functional; loop->setStep(step); @@ -936,7 +936,7 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize << " : "); LLVM_DEBUG(loopInst->print(dbgs())); - return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state); + return vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state); } /// Tries to transform a scalar constant into a vector splat of that constant. @@ -1012,7 +1012,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, // 3. vectorize constant. if (auto constant = operand->getDefiningInst()->dyn_cast()) { return vectorizeConstant( - inst, *constant, + inst, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); } // 4. currently non-vectorizable. 
@@ -1178,8 +1178,8 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, clonedLoop->erase(); return mlir::success(); } - OpPointer loop; - OpPointer clonedLoop; + AffineForOp loop; + AffineForOp clonedLoop; } guard{loop, clonedLoop}; ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 015a889beb2..24dd2dc4fd6 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -30,6 +30,7 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> { // CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl> { // CHECK: public: +// CHECK: using Op::Op; // CHECK: static StringRef getOperationName(); // CHECK: Value *a(); // CHECK: Instruction::operand_range b(); @@ -45,7 +46,4 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> { // CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); // CHECK: LogicalResult constantFold(ArrayRef operands, SmallVectorImpl &results, MLIRContext *context); // CHECK: bool fold(SmallVectorImpl &results); -// CHECK: private: -// CHECK: friend class ::mlir::Instruction; -// CHECK: explicit AOp(Instruction *state) : Op(state) {} // CHECK: }; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index fdbac3053a9..1cb92f5390c 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -315,14 +315,12 @@ void OpClass::writeDeclTo(raw_ostream &os) const { for (const auto &trait : traits) os << ", " << trait; os << "> {\npublic:\n"; + os << " using Op::Op;\n"; for (const auto &method : methods) { method.writeDeclTo(os); os << "\n"; } - os << "\nprivate:\n" - << " friend class ::mlir::Instruction;\n"; - os << " explicit " << className << "(Instruction *state) : Op(state) {}\n" - << "};"; + os << "};"; } void OpClass::writeDefTo(raw_ostream &os) const { -- cgit v1.2.3 From 832567b3799f763ec3ba9480e1628c5a3de7fa6e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 25 Mar 2019 10:14:34 -0700 Subject: NFC: Rename the 'for' operation in the AffineOps dialect to 'affine.for' and set the namespace of the AffineOps dialect to 'affine'. 
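As a minimal sketch of what the rename means for the textual IR (the function name, memref shape, and loop body below are placeholders, not taken from this change): loops that were previously written with the bare `for` keyword are now spelled `affine.for`, and in the generic form the operation name likewise becomes `"affine.for"`.

```mlir
// Illustrative only: before this change the loop was written as
//   for %i = 0 to 10 { ... }
func @example(%A: memref<10xf32>) {
  affine.for %i = 0 to 10 {
    %0 = load %A[%i] : memref<10xf32>
  }
  return
}
```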
PiperOrigin-RevId: 240165792 --- mlir/bindings/python/test/test_py2and3.py | 18 +- mlir/g3doc/Dialects/Affine.md | 26 +- mlir/g3doc/Dialects/SuperVector.md | 18 +- mlir/g3doc/LangRef.md | 35 +- mlir/g3doc/Passes.md | 34 +- mlir/g3doc/Rationale.md | 42 +- mlir/g3doc/RationaleSimplifiedPolyhedralForm.md | 14 +- mlir/include/mlir/AffineOps/AffineOps.h | 24 +- mlir/include/mlir/Analysis/AffineStructures.h | 13 +- mlir/include/mlir/Analysis/Utils.h | 10 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 16 +- mlir/include/mlir/EDSC/Types.h | 5 +- mlir/include/mlir/Transforms/LoopUtils.h | 7 +- mlir/include/mlir/Transforms/Utils.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 6 +- mlir/lib/Analysis/AffineAnalysis.cpp | 8 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 5 +- mlir/lib/Analysis/Utils.cpp | 10 +- mlir/lib/EDSC/MLIREmitter.cpp | 3 +- mlir/lib/EDSC/Types.cpp | 6 +- mlir/lib/Transforms/DmaGeneration.cpp | 11 +- mlir/lib/Transforms/LoopFusion.cpp | 3 +- mlir/lib/Transforms/LoopTiling.cpp | 3 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 7 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 10 +- mlir/lib/Transforms/MaterializeVectors.cpp | 24 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 8 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 18 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 18 +- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 50 +- mlir/test/AffineOps/canonicalize.mlir | 40 +- mlir/test/AffineOps/invalid.mlir | 26 +- mlir/test/AffineOps/ops.mlir | 4 +- mlir/test/EDSC/api-test.cpp | 36 +- mlir/test/EDSC/builder-api-test.cpp | 20 +- mlir/test/IR/invalid.mlir | 62 +- mlir/test/IR/locations.mlir | 2 +- mlir/test/IR/parser.mlir | 76 +-- mlir/test/IR/pretty-locations.mlir | 2 +- .../Vectorize/lower_vector_transfers.mlir | 58 +- mlir/test/Transforms/Vectorize/materialize.mlir | 16 +- .../Vectorize/materialize_vectors_1d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_1d.mlir | 24 +- .../Vectorize/materialize_vectors_2d_to_2d.mlir | 24 +- mlir/test/Transforms/Vectorize/normalize_maps.mlir | 24 +- mlir/test/Transforms/Vectorize/vectorize_1d.mlir | 62 +- mlir/test/Transforms/Vectorize/vectorize_2d.mlir | 30 +- mlir/test/Transforms/Vectorize/vectorize_3d.mlir | 20 +- .../Vectorize/vectorize_outer_loop_2d.mlir | 18 +- .../vectorize_outer_loop_transpose_2d.mlir | 42 +- .../Vectorize/vectorize_transpose_2d.mlir | 42 +- mlir/test/Transforms/canonicalize.mlir | 12 +- mlir/test/Transforms/constant-fold.mlir | 4 +- mlir/test/Transforms/cse.mlir | 8 +- mlir/test/Transforms/dma-generate.mlir | 154 ++--- mlir/test/Transforms/loop-fusion.mlir | 732 ++++++++++----------- mlir/test/Transforms/loop-tiling.mlir | 48 +- mlir/test/Transforms/lower-affine.mlir | 28 +- mlir/test/Transforms/memref-bound-check.mlir | 50 +- mlir/test/Transforms/memref-dataflow-opt.mlir | 62 +- mlir/test/Transforms/memref-dependence-check.mlir | 86 +-- mlir/test/Transforms/parallelism-detection.mlir | 6 +- mlir/test/Transforms/pipeline-data-transfer.mlir | 50 +- .../Transforms/simplify-affine-structures.mlir | 38 +- mlir/test/Transforms/strip-debuginfo.mlir | 2 +- mlir/test/Transforms/unroll-jam.mlir | 42 +- mlir/test/Transforms/unroll.mlir | 158 ++--- 70 files changed, 1304 insertions(+), 1294 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index e2cae843b6a..e1f0d96b635 100644 --- 
a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -56,11 +56,11 @@ class EdscTest(unittest.TestCase): code = str(fun) # TODO(zinenko,ntv): use FileCheck for these tests self.assertIn( - ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n', + ' "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n', code) self.assertIn(" ^bb1(%i0: index):", code) self.assertIn( - ' "for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n', + ' "affine.for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n', code) self.assertIn(" ^bb2(%i1: index):", code) self.assertIn( @@ -76,19 +76,19 @@ class EdscTest(unittest.TestCase): code = str(fun) self.assertIn( - ' "for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n', + ' "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n', code) self.assertIn(" ^bb1(%i0: index):", code) self.assertIn( - ' "for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n', + ' "affine.for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n', code) self.assertIn(" ^bb2(%i1: index):", code) self.assertIn( - ' "for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n', + ' "affine.for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n', code) self.assertIn(" ^bb3(%i2: index):", code) self.assertIn( - ' "for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n', + ' "affine.for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n', code) self.assertIn(" ^bb4(%i3: index):", code) self.assertIn( @@ -342,10 +342,10 @@ class EdscTest(unittest.TestCase): code = str(fun) self.assertIn( - '"for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}', + '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}', code) self.assertIn( - '"for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}', + '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}', code) self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code) self.assertIn("%1 = addf %0, %cst : f32", code) @@ -367,7 +367,7 @@ class EdscTest(unittest.TestCase): code = str(fun) self.assertIn( - '"for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()', + '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()', code) self.assertIn("%0 = load %arg0[%i0, %i2] : memref<32x32xf32>", code) self.assertIn("%1 = load %arg1[%i2, %i1] : memref<32x32xf32>", code) diff --git a/mlir/g3doc/Dialects/Affine.md b/mlir/g3doc/Dialects/Affine.md index 55d26f0d956..0c69c60cbe9 100644 --- a/mlir/g3doc/Dialects/Affine.md +++ b/mlir/g3doc/Dialects/Affine.md @@ -15,7 +15,7 @@ loops and if instructions), the result of a [`affine.apply` operation](#'affine.apply'-operation) that recursively takes as arguments any symbolic identifiers. 
Dimensions may be bound not only to anything that a symbol is bound to, but also to induction variables of enclosing -[`for` operations](#'for'-operation), and the result of an +[`affine.for` operations](#'affine.for'-operation), and the result of an [`affine.apply` operation](#'affine.apply'-operation) (which recursively may use other dimensions and symbols). @@ -47,12 +47,12 @@ Example: %2 = affine.apply (i)[s0] -> (i+s0) (%42)[%n] ``` -#### 'for' operation {#'for'-operation} +#### 'affine.for' operation {#'affine.for'-operation} Syntax: ``` {.ebnf} -operation ::= `for` ssa-id `=` lower-bound `to` upper-bound +operation ::= `affine.for` ssa-id `=` lower-bound `to` upper-bound (`step` integer-literal)? `{` inst* `}` lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound @@ -60,17 +60,17 @@ upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal ``` -The `for` operation represents an affine loop nest, defining an SSA value for -its induction variable. This SSA value always has type +The `affine.for` operation represents an affine loop nest, defining an SSA value +for its induction variable. This SSA value always has type [`index`](LangRef.md#index-type), which is the size of the machine word. -The `for` operation executes its body a number of times iterating from a lower -bound to an upper bound by a stride. The stride, represented by `step`, is a -positive constant integer which defaults to "1" if not present. The lower and +The `affine.for` operation executes its body a number of times iterating from a +lower bound to an upper bound by a stride. The stride, represented by `step`, is +a positive constant integer which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. -The lower and upper bounds of a `for` operation are represented as an +The lower and upper bounds of a `affine.for` operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA values as for all bindings of SSA values to dimensions and symbols. 
@@ -94,8 +94,8 @@ Example showing reverse iteration of the inner loop: func @simple_example(%A: memref, %B: memref) { %N = dim %A, 0 : memref - for %i = 0 to %N step 1 { - for %j = 0 to %N { // implicitly steps by 1 + affine.for %i = 0 to %N step 1 { + affine.for %j = 0 to %N { // implicitly steps by 1 %0 = affine.apply #map57(%j)[%N] %tmp = call @F1(%A, %i, %0) : (memref, index, index)->(f32) call @F2(%tmp, %B, %i, %0) : (f32, memref, index, index)->() @@ -130,8 +130,8 @@ Example: #set = (d0, d1)[s0]: (d0 - 10 >= 0, s0 - d0 - 9 >= 0, d1 - 10 >= 0, s0 - d1 - 9 >= 0) func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { - for %i = 0 to %N { - for %j = 0 to %N { + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { %0 = affine.apply #map42(%j) %tmp = call @S1(%X, %i, %0) affine.if #set(%i, %j)[%N] { diff --git a/mlir/g3doc/Dialects/SuperVector.md b/mlir/g3doc/Dialects/SuperVector.md index 09beb950e37..640325306c1 100644 --- a/mlir/g3doc/Dialects/SuperVector.md +++ b/mlir/g3doc/Dialects/SuperVector.md @@ -23,8 +23,8 @@ Examples: // pad with %f0 to handle the boundary case: %f0 = constant 0.0f : f32 for %i0 = 0 to %0 { - for %i1 = 0 to %1 step 256 { - for %i2 = 0 to %2 step 32 { + affine.for %i1 = 0 to %1 step 256 { + affine.for %i2 = 0 to %2 step 32 { %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 {permutation_map: (d0, d1, d2) -> (d2, d1)} : (memref, index, index, f32) -> vector<32x256xf32> @@ -34,7 +34,7 @@ for %i0 = 0 to %0 { // vector<128xf32>. The underlying implementation will require a 1-D vector // broadcast: for %i0 = 0 to %0 { - for %i1 = 0 to %1 { + affine.for %i1 = 0 to %1 { %3 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (0)} : (memref, index, index) -> vector<128xf32> @@ -81,8 +81,8 @@ A notional lowering of vector_transfer_read could generate code resembling: %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> for %i = 0 to 3 { - for %j = 0 to 4 { - for %k = 0 to 5 { + affine.for %j = 0 to 4 { + affine.for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, %j, %k] : vector<3x4x5xf32> }}} @@ -102,7 +102,7 @@ lowered code would resemble: %tmp = alloc() : vector<3x4x5xf32> %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> for %i = 0 to 3 { - for %k = 0 to 5 { + affine.for %k = 0 to 5 { %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref store %tmp[%i, 0, %k] : vector<3x4x5xf32> }} @@ -130,9 +130,9 @@ Examples: ```mlir {.mlir} // write vector<16x32x64xf32> into the slice `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`: for %i0 = 0 to %0 { - for %i1 = 0 to %1 step 32 { - for %i2 = 0 to %2 step 64 { - for %i3 = 0 to %3 step 16 { + affine.for %i1 = 0 to %1 step 32 { + affine.for %i2 = 0 to %2 step 64 { + affine.for %i3 = 0 to %3 step 16 { %val = `ssa-value` : vector<16x32x64xf32> vector_transfer_write %val, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 9c248a924b9..13ab016a6ec 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -40,10 +40,10 @@ which means that values are defined before use and have scope defined by their dominance relations. Operations may produce zero or more results, and each is a distinct SSA value with its own type defined by the [type system](#type-system). 
-MLIR incorporates polyhedral compiler concepts, including `for` and `affine.if` -operations defined by the [affine dialect](Dialects/Affine.md), which model -affine loops and affine conditionals. It also includes affine maps integrated -into the type system - they are key to the representation of data and +MLIR incorporates polyhedral compiler concepts, including `affine.for` and +`affine.if` operations defined by the [affine dialect](Dialects/Affine.md), +which model affine loops and affine conditionals. It also includes affine maps +integrated into the type system - they are key to the representation of data and [MemRefs](#memref-type), which are the representation for tensors in addressable memory. MLIR also supports a first-class Tensor type allowing it to concisely represent operations on N-dimensional arrays. @@ -99,10 +99,10 @@ func @multiply(%A: memref<100x?xf32>, %B: memref) %C = alloc() : memref<100x50xf32> // Multiplication loop nest. - for %i = 0 to 100 { - for %j = 0 to 50 { + affine.for %i = 0 to 100 { + affine.for %j = 0 to 50 { store 0 to %C[%i, %j] : memref<100x50xf32> - for %k = 0 to %n { + affine.for %k = 0 to %n { %a_v = load %A[%i, %k] : memref<100x?xf32> %b_v = load %B[%k, %j] : memref %prod = mulf %a_v, %b_v : f32 @@ -1697,8 +1697,8 @@ The arity of indices is the rank of the memref (i.e., if the memref loaded from is of rank 3, then 3 indices are required for the load following the memref identifier). -In an `affine.if` or `for` body, the indices of a load are restricted to SSA -values bound to surrounding loop induction variables, +In an `affine.if` or `affine.for` body, the indices of a load are restricted to +SSA values bound to surrounding loop induction variables, [symbols](#dimensions-and-symbols), results of a [`constant` operation](#'constant'-operation), or the result of an `affine.apply` operation that can in turn take as arguments all of the @@ -1719,10 +1719,10 @@ Example: **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in affine `affine.if` -and `for` instructions) the compiler can follow use-def chains (e.g. through -[`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to -precisely analyze references at compile-time using polyhedral techniques. This -is possible because of the +and `affine.for` instructions) the compiler can follow use-def chains (e.g. +through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) +operations) to precisely analyze references at compile-time using polyhedral +techniques. This is possible because of the [restrictions on dimensions and symbols](Dialects/Affine.md#restrictions-on-dimensions-and-symbols) in these contexts. @@ -1755,10 +1755,11 @@ store %100, %A[%1, 1023] : memref<4x?xf32, #layout, hbm> **Context:** The `load` and `store` instructions are specifically crafted to fully resolve a reference to an element of a memref, and (in polyhedral -`affine.if` and `for` instructions) the compiler can follow use-def chains (e.g. -through [`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) -operations) to precisely analyze references at compile-time using polyhedral -techniques. This is possible because of the +`affine.if` and `affine.for` instructions) the compiler can follow use-def +chains (e.g. through +[`affine.apply`](Dialects/Affine.md#'affine.apply'-operation) operations) to +precisely analyze references at compile-time using polyhedral techniques. 
This +is possible because of the [restrictions on dimensions and symbols](Dialect/Affine.md#restrictions-on-dimensions-and-symbols) in these contexts. diff --git a/mlir/g3doc/Passes.md b/mlir/g3doc/Passes.md index 525918aa429..8e5926aff3d 100644 --- a/mlir/g3doc/Passes.md +++ b/mlir/g3doc/Passes.md @@ -39,9 +39,9 @@ These restrictions may be lifted in the future. ### Output IR -Functions with `for` and `affine.if` instructions eliminated. These functions -may contain operations from the Standard dialect in addition to those already -present before the pass. +Functions with `affine.for` and `affine.if` instructions eliminated. These +functions may contain operations from the Standard dialect in addition to those +already present before the pass. ### Invariants @@ -95,10 +95,10 @@ Input ```mlir func @loop_nest_tiled() -> memref<256x1024xf32> { %0 = alloc() : memref<256x1024xf32> - for %i0 = 0 to 256 step 32 { - for %i1 = 0 to 1024 step 32 { - for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { - for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { + affine.for %i0 = 0 to 256 step 32 { + affine.for %i1 = 0 to 1024 step 32 { + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { + affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { %1 = load %0[%i2, %i3] : memref<256x1024xf32> } } @@ -119,16 +119,16 @@ func @loop_nest_tiled() -> memref<256x1024xf32> { %c32 = constant 32 : index %c0 = constant 0 : index %0 = alloc() : memref<256x1024xf32> - for %i0 = 0 to 256 step 32 { - for %i1 = 0 to 1024 step 32 { + affine.for %i0 = 0 to 256 step 32 { + affine.for %i1 = 0 to 1024 step 32 { %1 = affine.apply #map1(%i0) %2 = affine.apply #map1(%i1) %3 = alloc() : memref<32x32xf32, 1> %4 = alloc() : memref<1xi32> dma_start %0[%1, %2], %3[%c0, %c0], %c1024, %4[%c0], %c1024, %c32 : memref<256x1024xf32>, memref<32x32xf32, 1>, memref<1xi32> dma_wait %4[%c0], %c1024 : memref<1xi32> - for %i2 = #map1(%i0) to #map2(%i0) { - for %i3 = #map1(%i1) to #map2(%i1) { + affine.for %i2 = #map1(%i0) to #map2(%i0) { + affine.for %i3 = #map1(%i1) to #map2(%i1) { %5 = affine.apply #map3(%i0, %i2) %6 = affine.apply #map3(%i1, %i3) %7 = load %3[%5, %6] : memref<32x32xf32, 1> @@ -194,8 +194,8 @@ Input func @store_load_affine_apply() -> memref<10x10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10x10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %t0 = affine.apply (d0, d1) -> (d1 + 1)(%i0, %i1) %t1 = affine.apply (d0, d1) -> (d0)(%i0, %i1) %idx0 = affine.apply (d0, d1) -> (d1) (%t0, %t1) @@ -217,8 +217,8 @@ Output func @store_load_affine_apply() -> memref<10x10xf32> { %cst = constant 7.000000e+00 : f32 %0 = alloc() : memref<10x10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %3 = affine.apply #map1(%1, %2) %4 = affine.apply #map2(%1, %2) store %cst, %0[%3, %4] : memref<10x10xf32> @@ -258,7 +258,7 @@ Input %2 = alloc() : memref<1xf32> %c0 = constant 0 : index %c128 = constant 128 : index - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { dma_start %0[%i0], %1[%i0], %c128, %2[%c0] : memref<256xf32>, memref<32xf32, 1>, memref<1xf32> dma_wait %2[%c0], %c128 : memref<1xf32> %3 = load %1[%i0] : memref<32xf32, 1> @@ -282,7 +282,7 @@ Output %1 = alloc() : memref<2x32xf32, 1> %2 = alloc() : memref<2x1xf32> dma_start %0[%c0], %1[%c0, %c0], %c128, %2[%c0, %c0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> - for %i0 = 1 to 8 { + affine.for %i0 = 1 to 8 { %3 = affine.apply #map2(%i0) 
%4 = affine.apply #map2(%i0) dma_start %0[%i0], %1[%3, %i0], %c128, %2[%4, %c0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index bc2b14e289b..91b215fc311 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -150,8 +150,8 @@ func bar(%A : memref<8x?xf32, #lmap>) { // dynamically using dim instruction. %N = dim %A, 1 : memref<8x?xf32, #lmap> - for %i = 0 to 8 { - for %j = 0 to %N { + affine.for %i = 0 to 8 { + affine.for %j = 0 to %N { // A[i,j] += 1 %s1 = load %A [%i, %j] : memref<8x?xf32, #lmap> %s2 = add %s1, 1 @@ -548,7 +548,7 @@ nested in an outer function that using affine loops. func @search(memref %S, i32 %key) { %ni = dim %A, 0 : memref // This loop can be parallelized - for %i = 0 to %ni { + affine.for %i = 0 to %ni { call @search_body (%A, %S, %i) : (memref, memref, i32) } return @@ -582,9 +582,9 @@ func @search_body(%A: memref, %S: memref, %key: i32) { As per the [MLIR spec](LangRef.md), the restrictions on dimensions and symbol identifiers to be used with the affine.apply instruction only apply to accesses -inside `for` and `affine.if` instructions. However, an analysis of accesses -inside the called function (`@search_body`) is necessary to determine if the -`%i` loop could be parallelized: such function access analysis is calling +inside `affine.for` and `affine.if` instructions. However, an analysis of +accesses inside the called function (`@search_body`) is necessary to determine +if the `%i` loop could be parallelized: such function access analysis is calling context sensitive. ### Non-affine loop bounds {#non-affine-loop-bounds} @@ -604,8 +604,8 @@ for (i=0; i i32 { - for %k = 0 to %m { - for %l = 0 to %n { + affine.for %k = 0 to %m { + affine.for %l = 0 to %n { ... } } @@ -663,13 +663,13 @@ in a dilated convolution. func @conv2d(memref<16x1024x1024x3xf32, #lm0, vmem> %input, memref<5x5x3x32xf32, #lm0, vmem> %kernel, memref<16x512x512x32xf32, #lm0, vmem> %output) { - for %b = 0 to %batch { - for %oh = 0 to %output_height { - for %ow = 0 to %output_width { - for %of = 0 to %output_feature { - for %kh = 0 to %kernel_height { - for %kw = 0 to %kernel_width { - for %if = 0 to %input_feature { + affine.for %b = 0 to %batch { + affine.for %oh = 0 to %output_height { + affine.for %ow = 0 to %output_width { + affine.for %of = 0 to %output_feature { + affine.for %kh = 0 to %kernel_height { + affine.for %kw = 0 to %kernel_width { + affine.for %if = 0 to %input_feature { // Calculate input indices. %1_0 = affine.apply #map1_0 (%0#1, %0#2, %0#4, %0#5) [%h_stride, %w_stride, %h_kernel_dilation, %w_kernel_dilation, @@ -913,10 +913,10 @@ func @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a, representation. 2(b) requires no change, but impacts how cost models look at index and layout maps. -### `affine.if` and `for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} +### `affine.if` and `affine.for` Extensions for "Escaping Scalars" {#extensions-for-"escaping-scalars"} We considered providing a representation for SSA values that are live out of -`if/else` conditional bodies and loop carried in `for` loops. We +`if/else` conditional bodies and loop carried in `affine.for` loops. We ultimately abandoned this approach due to its complexity. In the current design of MLIR, scalar variables cannot escape for loops or if instructions. 
In situations, where escaping is necessary, we use zero-dimensional tensors and @@ -948,7 +948,7 @@ Example: // Return sum of elements in 1-dimensional mref A func int32 @sum(%A : memref, %N : i32) -> (i32) { %init = 0 - %result = for %i = 0 to N with %tmp(%init) { + %result = affine.for %i = 0 to N with %tmp(%init) { %value = load %A[%i] %sum = %value + %tmp yield %sum @@ -978,7 +978,7 @@ Example: // Compute sum of half of the array func int32 @sum_half(%A, %N) { %s0 = 0 - %s1 = for %i = 1 ... N step 1 with %s2 (%s0) { + %s1 = affine.for %i = 1 ... N step 1 with %s2 (%s0) { %s3 = if (%i >= %N / 2) { %v0 = load %A[%i] %s4 = %s2 + %v0 diff --git a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md index f51eff45633..b40f6708d0d 100644 --- a/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md +++ b/mlir/g3doc/RationaleSimplifiedPolyhedralForm.md @@ -184,8 +184,8 @@ Our simple example above would be represented as: ```mlir mlfunc @simple_example(... %N) { - for %i = 0 ... %N step 1 { - for %j = 0 ... %N step 1 { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affine.apply #57(%i, %j) @@ -203,8 +203,8 @@ The example with the reduced domain would be represented with an if instruction: ```mlir mlfunc @reduced_domain_example(... %N) { - for %i = 0 ... %N step 1 { - for %j = 0 ... %N step 1 { + affine.for %i = 0 ... %N step 1 { + affine.for %j = 0 ... %N step 1 { // identity noop in this case, but can exist in general. %0,%1 = affinecall #57(%i, %j) @@ -233,8 +233,8 @@ that transformations call into): ```mlir mlfunc @skewed_domain_example(... %N) { - for %t1 = 0 ... 2*N-2 step 1 { - for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { + affine.for %t1 = 0 ... 2*N-2 step 1 { + affine.for %t2 = max(0, t1-N+1) ... min(N, t1) step 1 { (%i, %j) = (%t1-%t2, %t2) ... } @@ -373,7 +373,7 @@ mlfunc's (if we support them) will also have to have domains. ### Lack of redundancy in IR The traditional form has multiple encodings for the same sorts of behavior: you -end up having bits on `for` loops to specify whether codegen should use +end up having bits on `affine.for` loops to specify whether codegen should use "atomic/separate" policies, unroll loops, etc. Instructions can be split or can generate multiple copies of their instruction because of overlapping domains, etc. diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 2620db1407a..d8e34dc7248 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -88,15 +88,15 @@ public: MLIRContext *context); }; -/// The "for" instruction represents an affine loop nest, defining an SSA value -/// for its induction variable. The induction variable is represented as a +/// The "affine.for" instruction represents an affine loop nest, defining an SSA +/// value for its induction variable. The induction variable is represented as a /// BlockArgument to the entry block of the body. The body and induction -/// variable can be created automatically for new "for" ops with 'createBody'. -/// This SSA value always has type index, which is the size of the machine word. -/// The stride, represented by step, is a positive constant integer which -/// defaults to "1" if not present. The lower and upper bounds specify a -/// half-open range: the range includes the lower bound but does not include the -/// upper bound. 
+/// variable can be created automatically for new "affine.for" ops with +/// 'createBody'. This SSA value always has type index, which is the size of the +/// machine word. The stride, represented by step, is a positive constant +/// integer which defaults to "1" if not present. The lower and upper bounds +/// specify a half-open range: the range includes the lower bound but does not +/// include the upper bound. /// /// The lower and upper bounds of a for operation are represented as an /// application of an affine mapping to a list of SSA values passed to the map. @@ -108,7 +108,7 @@ public: /// /// Example: /// -/// for %i = 1 to 10 { +/// affine.for %i = 1 to 10 { /// ... /// } /// @@ -131,7 +131,7 @@ public: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); - static StringRef getOperationName() { return "for"; } + static StringRef getOperationName() { return "affine.for"; } static StringRef getStepAttrName() { return "step"; } static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getUpperBoundAttrName() { return "upper_bound"; } @@ -268,10 +268,10 @@ public: operand_range getOperands() { return {operand_begin(), operand_end()}; } private: - // 'for' instruction that contains this bound. + // 'affine.for' instruction that contains this bound. AffineForOp inst; // Start and end positions of this affine bound operands in the list of - // the containing 'for' instruction operands. + // the containing 'affine.for' instruction operands. unsigned opStart, opEnd; // Affine map for this bound. AffineMap map; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 92c809326e3..f9ea873d0f7 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -376,14 +376,15 @@ public: AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); - /// Adds constraints (lower and upper bounds) for the specified 'for' + /// Adds constraints (lower and upper bounds) for the specified 'affine.for' /// instruction's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Asserts if the - /// Value corresponding to the 'for' instruction isn't found in the constraint - /// system. Returns failure for the yet unimplemented/unsupported cases. Any - /// new identifiers that are found in the bound operands of the 'for' - /// instruction are added as trailing identifiers (either dimensional or - /// symbolic depending on whether the operand is a valid ML Function symbol). + /// Value corresponding to the 'affine.for' instruction isn't found in the + /// constraint system. Returns failure for the yet unimplemented/unsupported + /// cases. Any new identifiers that are found in the bound operands of the + /// 'affine.for' instruction are added as trailing identifiers (either + /// dimensional or symbolic depending on whether the operand is a valid ML + /// Function symbol). // TODO(bondhugula): add support for non-unit strides. LogicalResult addAffineForOpDomain(AffineForOp forOp); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 96e73166ca4..382ff825995 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -45,7 +45,7 @@ class Instruction; class Value; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'for' instruction to the innermost one. 
+/// the outermost 'affine.for' instruction to the innermost one. // TODO(bondhugula): handle 'affine.if' inst's. void getLoopIVs(Instruction &inst, SmallVectorImpl *loops); @@ -113,8 +113,8 @@ AffineForOp insertBackwardComputationSlice(Instruction *srcOpInst, /// surrounding such op's. // For example, the memref region for a load operation at loop depth = 1: // -// for %i = 0 to 32 { -// for %ii = %i to (d0) -> (d0 + 8) (%i) { +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -146,8 +146,8 @@ struct MemRefRegion { /// For example, the memref region for this operation at loopDepth = 1 will /// be: /// - /// for %i = 0 to 32 { - /// for %ii = %i to (d0) -> (d0 + 8) (%i) { + /// affine.for %i = 0 to 32 { + /// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { /// load %A[%ii] /// } /// } diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index f8ed1dd2819..ffe4ea70332 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -76,9 +76,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// The following MLIR snippet: /// /// ```mlir -/// for %i3 = 0 to %0 { -/// for %i4 = 0 to %1 { -/// for %i5 = 0 to %2 { +/// affine.for %i3 = 0 to %0 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 { /// %a5 = load %arg0[%i4, %i5, %i3] : memref /// }}} /// ``` @@ -86,9 +86,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into: /// /// ```mlir -/// for %i3 = 0 to %0 step 32 { -/// for %i4 = 0 to %1 { -/// for %i5 = 0 to %2 step 256 { +/// affine.for %i3 = 0 to %0 step 32 { +/// affine.for %i4 = 0 to %1 { +/// affine.for %i5 = 0 to %2 step 256 { /// %4 = vector_transfer_read %arg0, %i4, %i5, %i3 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index) -> vector<32x256xf32> @@ -103,7 +103,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// /// ```mlir /// %cst0 = constant 0 : index -/// for %i0 = 0 to %0 { +/// affine.for %i0 = 0 to %0 { /// %a0 = load %arg0[%cst0, %cst0] : memref /// } /// ``` @@ -111,7 +111,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// may vectorize with {permutation_map: (d0) -> (0)} into: /// /// ```mlir -/// for %i0 = 0 to %0 step 128 { +/// affine.for %i0 = 0 to %0 step 128 { /// %3 = vector_transfer_read %arg0, %c0_0, %c0_0 /// {permutation_map: (d0, d1) -> (0)} : /// (memref, index, index) -> vector<128xf32> diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index 35216684169..f0ebbed1959 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -341,7 +341,8 @@ protected: /// (e.g. vectorValue = load(vectorView, zero)). /// /// Only ExprKind::StmtBlockLikeExpr have `enclosedStmts`, these comprise: -/// 1. `For`-loops for which the `lhs` binds to the induction variable, `rhs` +/// 1. `affine.for`-loops for which the `lhs` binds to the induction variable, +/// `rhs` /// binds to an Expr of kind `ExprKind::For` with lower-bound, upper-bound and /// step respectively. // TODO(zinenko): this StmtBlockLikeExpr should be retired in favor of Expr @@ -647,7 +648,7 @@ Stmt For(llvm::ArrayRef indices, llvm::ArrayRef lbs, llvm::ArrayRef ubs, llvm::ArrayRef steps, llvm::ArrayRef enclosedStmts); -/// Define a 'for' loop from with multi-valued bounds. 
+/// Define a 'affine.for' loop from with multi-valued bounds. /// /// for max(lbs...) to min(ubs...) {} /// diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 0404ab74244..1d5203e77d5 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -73,9 +73,10 @@ void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, SmallVectorImpl *operands, FuncBuilder *builder); -/// Skew the instructions in the body of a 'for' instruction with the specified -/// instruction-wise shifts. The shifts are with respect to the original -/// execution order, and are multiplied by the loop 'step' before being applied. +/// Skew the instructions in the body of a 'affine.for' instruction with the +/// specified instruction-wise shifts. The shifts are with respect to the +/// original execution order, and are multiplied by the loop 'step' before being +/// applied. LLVM_NODISCARD LogicalResult instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue = false); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index ab5660be871..b8976669f97 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -95,14 +95,14 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// /// Before /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %v = "compute"(%idx, ...) /// /// After /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// send %A[%idx], ... /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 4badde9012b..92035489e21 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -36,7 +36,7 @@ using llvm::dbgs; //===----------------------------------------------------------------------===// AffineOpsDialect::AffineOpsDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"", context) { + : Dialect(/*namePrefix=*/"affine", context) { addOperations(); } @@ -69,7 +69,7 @@ bool mlir::isValidDim(Value *value) { return isTopLevelSymbol(dimOp->getOperand()); return false; } - // This value is a block argument (which also includes 'for' loop IVs). + // This value is a block argument (which also includes 'affine.for' loop IVs). return true; } @@ -969,7 +969,7 @@ static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { } void AffineForOp::print(OpAsmPrinter *p) { - *p << "for "; + *p << "affine.for "; p->printOperand(getBody()->getArgument(0)); *p << " = "; printBound(getLowerBound(), "max", p); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index f786731e88a..e2e9ef68b17 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -708,8 +708,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // For example, given the following MLIR code with with "source" and // "destination" accesses to the same memref labled, and symbols %M, %N, %K: // -// for %i0 = 0 to 100 { -// for %i1 = 0 to 50 { +// affine.for %i0 = 0 to 100 { +// affine.for %i1 = 0 to 50 { // %a0 = affine.apply // (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N] // // Source memref access. 
@@ -717,8 +717,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // } // } // -// for %i2 = 0 to 100 { -// for %i3 = 0 to 50 { +// affine.for %i2 = 0 to 100 { +// affine.for %i3 = 0 to 50 { // %a1 = affine.apply // (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M] // // Destination memref access. diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index b954f0e67d9..7ed59b403cd 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file implements a pass to detect parallel affine 'for' ops. +// This file implements a pass to detect parallel affine 'affine.for' ops. // //===----------------------------------------------------------------------===// @@ -40,7 +40,8 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { return new TestParallelismDetection(); } -// Walks the function and emits a note for all 'for' ops detected as parallel. +// Walks the function and emits a note for all 'affine.for' ops detected as +// parallel. void TestParallelismDetection::runOnFunction() { Function *f = getFunction(); FuncBuilder b(f); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 5a6e1f84b35..6bc395c46bd 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -38,11 +38,11 @@ using namespace mlir; using llvm::SmallDenseMap; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'for' instruction to the innermost one. +/// the outermost 'affine.for' instruction to the innermost one. void mlir::getLoopIVs(Instruction &inst, SmallVectorImpl *loops) { auto *currInst = inst.getParentInst(); AffineForOp currAffineForOp; - // Traverse up the hierarchy collecing all 'for' instruction while + // Traverse up the hierarchy collecing all 'affine.for' instruction while // skipping over 'affine.if' instructions. while (currInst && ((currAffineForOp = currInst->dyn_cast()) || currInst->isa())) { @@ -162,8 +162,8 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // For example, the memref region for this load operation at loopDepth = 1 will // be as below: // -// for %i = 0 to 32 { -// for %ii = %i to (d0) -> (d0 + 8) (%i) { +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { // load %A[%ii] // } // } @@ -683,7 +683,7 @@ static Optional getMemoryFootprintBytes(Block &block, int memorySpace) { SmallDenseMap, 4> regions; - // Walk this 'for' instruction to gather all memory regions. + // Walk this 'affine.for' instruction to gather all memory regions. 
bool error = false; block.walk(start, end, [&](Instruction *opInst) { if (!opInst->isa() && !opInst->isa()) { diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 1196748a0af..89c66b08941 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -146,7 +146,8 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { if (auto expr = e.dyn_cast()) { if (expr.getKind() == ExprKind::For) { auto exprGroups = expr.getAllArgumentGroups(); - assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`"); + assert(exprGroups.size() == 3 && + "expected 3 expr groups in `affine.for`"); assert(!exprGroups[0].empty() && "expected at least one lower bound"); assert(!exprGroups[1].empty() && "expected at least one upper bound"); assert(exprGroups[2].size() == 1 && diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 72c453718e7..ac8b98e38c3 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -526,8 +526,8 @@ Stmt mlir::edsc::For(llvm::ArrayRef idxs, llvm::ArrayRef lbs, Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef lbs, ArrayRef ubs, Expr step, ArrayRef enclosedStmts) { - assert(!lbs.empty() && "'for' loop must have lower bounds"); - assert(!ubs.empty() && "'for' loop must have upper bounds"); + assert(!lbs.empty() && "'affine.for' loop must have lower bounds"); + assert(!ubs.empty() && "'affine.for' loop must have upper bounds"); // Use a null expression as a sentinel between lower and upper bound // expressions in the list of children. @@ -964,7 +964,7 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { } else if (auto stmtLikeExpr = this->dyn_cast()) { switch (stmtLikeExpr.getKind()) { // We only print the lb, ub and step here, which are the StmtBlockLike - // part of the `for` StmtBlockLikeExpr. + // part of the `affine.for` StmtBlockLikeExpr. case ExprKind::For: { auto exprGroups = stmtLikeExpr.getAllArgumentGroups(); assert(exprGroups.size() == 3 && diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 7f8c7e411e8..4fa040d73eb 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -343,7 +343,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, auto fastMemRefType = top.getMemRefType( fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace); - // Create the fast memory space buffer just before the 'for' + // Create the fast memory space buffer just before the 'affine.for' // instruction. fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); // Record it. @@ -472,7 +472,7 @@ bool DmaGeneration::runOnBlock(Block *block) { // approach is conservative in some cases at the moment, we do a check later // and report an error with location info. // TODO(bondhugula): An 'affine.if' instruction is being treated similar to an - // operation instruction. 'affine.if''s could have 'for's in them; + // operation instruction. 'affine.if''s could have 'affine.for's in them; // treat them separately. // Get to the first load, store, or for op. @@ -494,7 +494,7 @@ bool DmaGeneration::runOnBlock(Block *block) { fastMemCapacityBytes); }; - // If the memory footprint of the 'for' loop is higher than fast + // If the memory footprint of the 'affine.for' loop is higher than fast // memory capacity (when provided), we recurse to DMA at an inner level // until we find a depth at which footprint fits in fast mem capacity. If // the footprint can't be calculated, we assume for now it fits. 
Recurse @@ -507,7 +507,7 @@ bool DmaGeneration::runOnBlock(Block *block) { runOnBlock(/*begin=*/curBegin, /*end=*/it); // Recurse onto the body of this loop. runOnBlock(forOp->getBody()); - // The next region starts right after the 'for' instruction. + // The next region starts right after the 'affine.for' instruction. curBegin = std::next(it); } else { // We have enough capacity, i.e., DMAs will be computed for the portion @@ -698,7 +698,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { [&](const SmallMapVector, 4> ®ions) { for (const auto ®ionEntry : regions) { - // For each region, hoist DMA transfer past all invariant 'for's. + // For each region, hoist DMA transfer past all invariant + // 'affine.for's. Block::iterator dmaPlacementReadStart, dmaPlacementWriteStart; Block *dmaPlacementBlock; findHighestBlockForPlacement( diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8e1fc505348..84644bf11a0 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -641,7 +641,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&inst); - // Return false if a non 'for' region was found (not currently supported). + // Return false if a non 'affine.for' region was found (not currently + // supported). if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 0b629531df0..314864d3f3c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -258,7 +258,8 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Returns all maximal outermost perfect loop nests to tile. static void getTileableBands(Function *f, std::vector> *bands) { - // Get maximal perfect nest of 'for' insts starting from root (inclusive). + // Get maximal perfect nest of 'affine.for' insts starting from root + // (inclusive). auto getMaximalPerfectLoopNest = [&](AffineForOp root) { SmallVector band; AffineForOp currInst = root; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index a16237e6452..173a171e589 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -158,8 +158,8 @@ void LoopUnroll::runOnFunction() { } } -/// Unrolls a 'for' inst. Returns success if the loop was unrolled, failure -/// otherwise. The default unroll factor is 4. +/// Unrolls a 'affine.for' inst. Returns success if the loop was unrolled, +/// failure otherwise. The default unroll factor is 4. LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 03c06b4b450..240f3960488 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -96,7 +96,7 @@ void LoopUnrollAndJam::runOnFunction() { runOnAffineForOp(forOp); } -/// Unroll and jam a 'for' inst. Default unroll jam factor is +/// Unroll and jam a 'affine.for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return failure if nothing was done. LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) { // Unroll and jam by the factor that was passed if any. 
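The LoopUnroll and LoopUnrollAndJam hunks above mention a default unroll factor of 4. The sketch below shows roughly what unrolling an `affine.for` loop by that factor looks like in the IR; the memref, trip count, and the generic `"use"` op are invented for illustration, and the exact IR produced by the pass may differ in detail.

```mlir
// Before unrolling: trip count 32, unit step.
func @unroll_before(%A : memref<32xf32>) {
  affine.for %i = 0 to 32 {
    %v = load %A[%i] : memref<32xf32>
    "use"(%v) : (f32) -> ()
  }
  return
}

// After unrolling by the default factor of 4: the step becomes 4 and the body
// is replicated, with shifted IV values produced by affine.apply.
func @unroll_after(%A : memref<32xf32>) {
  affine.for %i = 0 to 32 step 4 {
    %v0 = load %A[%i] : memref<32xf32>
    "use"(%v0) : (f32) -> ()
    %i1 = affine.apply (d0) -> (d0 + 1) (%i)
    %v1 = load %A[%i1] : memref<32xf32>
    "use"(%v1) : (f32) -> ()
    %i2 = affine.apply (d0) -> (d0 + 2) (%i)
    %v2 = load %A[%i2] : memref<32xf32>
    "use"(%v2) : (f32) -> ()
    %i3 = affine.apply (d0) -> (d0 + 3) (%i)
    %v3 = load %A[%i3] : memref<32xf32>
    "use"(%v3) : (f32) -> ()
  }
  return
}
```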
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 3061bcd254d..cb65720cee3 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -276,7 +276,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, return value; } -// Convert a "for" loop to a flow of blocks. Return `false` on success. +// Convert a "affine.for" loop to a flow of blocks. Return `false` on success. // // Create an SESE region for the loop (including its body) and append it to the // end of the current region. The loop region consists of the initialization @@ -323,8 +323,9 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { auto loc = forOp->getLoc(); auto *forInst = forOp->getInstruction(); - // Start by splitting the block containing the 'for' into two parts. The part - // before will get the init code, the part after will be the end point. + // Start by splitting the block containing the 'affine.for' into two parts. + // The part before will get the init code, the part after will be the end + // point. auto *initBlock = forInst->getBlock(); auto *endBlock = initBlock->splitBlock(forInst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index cde28c6517d..7f6be358189 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -56,9 +56,9 @@ /// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into /// // vector<32x256xf32> and pad with %f0 to handle the boundary case: /// %f0 = constant 0.0f : f32 -/// for %i0 = 0 to %0 { -/// for %i1 = 0 to %1 step 256 { -/// for %i2 = 0 to %2 step 32 { +/// affine.for %i0 = 0 to %0 { +/// affine.for %i1 = 0 to %1 step 256 { +/// affine.for %i2 = 0 to %2 step 32 { /// %v = vector_transfer_read %A, %i0, %i1, %i2, %f0 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} : /// (memref, index, index, f32) -> vector<32x256xf32> @@ -70,8 +70,8 @@ /// abstraction): /// /// ```mlir {.mlir} -/// for %d2 = 0 to 256 { -/// for %d1 = 0 to 32 { +/// affine.for %d2 = 0 to 256 { +/// affine.for %d1 = 0 to 32 { /// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 /// %tmp[%d2, %d1] = %s /// } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 0a7eaabbb09..ebdb0c8e83e 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -100,10 +100,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : -/// vector<4x4x4xf32> for %i0 = 0 to %M step 4 { -/// for %i1 = 0 to %N step 4 { -/// for %i2 = 0 to %O { -/// for %i3 = 0 to %P step 4 { +/// vector<4x4x4xf32> affine.for %i0 = 0 to %M step 4 { +/// affine.for %i1 = 0 to %N step 4 { +/// affine.for %i2 = 0 to %O { +/// affine.for %i3 = 0 to %P step 4 { /// vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 /// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : /// vector<4x4x4xf32>, memref, @@ -119,10 +119,10 @@ /// mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) { /// %A = alloc (%M, %N, %O, %P) : memref /// %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> -/// for %i0 = 0 to %arg0 step 4 { -/// for %i1 = 0 to %arg1 step 4 { -/// for %i2 = 0 to %arg2 { -/// for %i3 = 0 to %arg3 step 4 { +/// affine.for %i0 = 0 to %arg0 step 4 { +/// affine.for %i1 = 0 to %arg1 step 4 { +/// affine.for %i2 = 0 to %arg2 { +/// 
affine.for %i3 = 0 to %arg3 step 4 { /// %1 = affine.apply (d0, d1, d2, d3) -> (d0, d1, d2, d3) /// (%i0, %i1, %i2, %i3) /// vector_transfer_write f1, %0, %1#0, %1#1, %1#2, %1#3 @@ -286,10 +286,10 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// super-vectorization has been applied: /// /// ```mlir -/// for %i0 = 0 to %M { -/// for %i1 = 0 to %N step 3 { -/// for %i2 = 0 to %O { -/// for %i3 = 0 to %P step 32 { +/// affine.for %i0 = 0 to %M { +/// affine.for %i1 = 0 to %N step 3 { +/// affine.for %i2 = 0 to %O { +/// affine.for %i3 = 0 to %P step 32 { /// %r = vector_transfer_read(%A, map(%i..)#0, map(%i..)#1, map(%i..)#2) /// -> vector<3x32xf32> /// ... diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index a35a159443d..a7045b3b541 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -19,7 +19,7 @@ // potentially getting rid of intermediate memref's entirely. // TODO(mlir-team): In the future, similar techniques could be used to eliminate // dead memref store's and perform more complex forwarding when support for -// SSA scalars live out of 'for'/'affine.if' statements is available. +// SSA scalars live out of 'affine.for'/'affine.if' statements is available. //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" @@ -55,7 +55,7 @@ namespace { // // (* A dependence being satisfied at a block: a dependence that is satisfied by // virtue of the destination instruction appearing textually / lexically after -// the source instruction within the body of a 'for' instruction; thus, a +// the source instruction within the body of a 'affine.for' instruction; thus, a // dependence is always either satisfied by a loop or by a block). // // The above conditions are simple to check, sufficient, and powerful for most @@ -139,8 +139,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { // Check if this store is a candidate for forwarding; we only forward if // the dependence from the store is carried by the *body* of innermost // common surrounding loop. As an example this filters out cases like: - // for %i0 - // for %i1 + // affine.for %i0 + // affine.for %i1 // %idx = affine.apply (d0) -> (d0 + 1) (%i0) // store %A[%idx] // load %A[%i0] diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 9809a146072..b59071aa9fe 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -66,11 +66,11 @@ static unsigned getTagMemRefPos(Instruction &dmaInst) { return 0; } -/// Doubles the buffer of the supplied memref on the specified 'for' instruction -/// by adding a leading dimension of size two to the memref. Replaces all uses -/// of the old memref by the new one while indexing the newly added dimension by -/// the loop IV of the specified 'for' instruction modulo 2. Returns false if -/// such a replacement cannot be performed. +/// Doubles the buffer of the supplied memref on the specified 'affine.for' +/// instruction by adding a leading dimension of size two to the memref. +/// Replaces all uses of the old memref by the new one while indexing the newly +/// added dimension by the loop IV of the specified 'affine.for' instruction +/// modulo 2. Returns false if such a replacement cannot be performed. 
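The rewritten doubleBuffer comment above describes the transformation abstractly; a rough sketch of the memref rewrite it performs follows. The buffer shape, memory space, and trip count are invented for illustration, and the DMA operations themselves are elided as comments.

```mlir
// Before: one buffer in fast memory (space 1), reused by every iteration.
%buf = alloc() : memref<256xf32, 1>
affine.for %i = 0 to 16 {
  // ... DMA into %buf, then compute on %buf ...
}

// After doubleBuffer: a leading dimension of size two is added, and each use
// of the old memref indexes that dimension with the loop IV modulo 2.
%buf2 = alloc() : memref<2x256xf32, 1>
affine.for %i = 0 to 16 {
  %idx = affine.apply (d0) -> (d0 mod 2) (%i)
  // ... DMA into %buf2[%idx, ...], compute on %buf2[%idx, ...] ...
}
```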
static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); @@ -104,7 +104,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { dynamicDimCount++)); } - // Create and place the alloc right before the 'for' instruction. + // Create and place the alloc right before the 'affine.for' instruction. Value *newMemRef = bOuter.create(forInst->getLoc(), newMemRefType, allocOperands); @@ -139,9 +139,9 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { /// Returns success if the IR is in a valid state. void PipelineDataTransfer::runOnFunction() { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'for' instructions nested within would otherwise become - // invalid (erased) when the outer loop is pipelined (the pipelined one gets - // deleted and replaced by a prologue, a new steady-state loop and an + // necessary since 'affine.for' instructions nested within would otherwise + // become invalid (erased) when the outer loop is pipelined (the pipelined one + // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); getFunction()->walkPostOrder( diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index e5f1fef990f..bf0c3ced2e2 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -71,7 +71,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, auto lb = b->create(forOp->getLoc(), lbMap, lbOperands); // For each upper bound expr, get the range. - // Eg: for %i = lb to min (ub1, ub2), + // Eg: affine.for %i = lb to min (ub1, ub2), // where tripCountExprs yield (tr1, tr2), we create affine.apply's: // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all // these affine.apply's make up the cleanup loop lower bound. @@ -161,8 +161,8 @@ void mlir::promoteSingleIterationLoops(Function *f) { [](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); } -/// Generates a 'for' inst with the specified lower and upper bounds while -/// generating the right IV remappings for the shifted instructions. The +/// Generates a 'affine.for' inst with the specified lower and upper bounds +/// while generating the right IV remappings for the shifted instructions. The /// instruction blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of /// the pair specifies the shift applied to that group of instructions; note @@ -216,10 +216,10 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the instructions in the body of a 'for' instruction with the specified -/// instruction-wise shifts. The shifts are with respect to the original -/// execution order, and are multiplied by the loop 'step' before being applied. -/// A shift of zero for each instruction will lead to no change. +/// Skew the instructions in the body of a 'affine.for' instruction with the +/// specified instruction-wise shifts. The shifts are with respect to the +/// original execution order, and are multiplied by the loop 'step' before being +/// applied. A shift of zero for each instruction will lead to no change. 
// The skewing of instructions with respect to one another can be used for // example to allow overlap of asynchronous operations (such as DMA // communication) with computation, or just relative shifting of instructions @@ -267,7 +267,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // An array of instruction groups sorted by shift amount; each group has all // instructions with the same shift in the order in which they appear in the - // body of the 'for' inst. + // body of the 'affine.for' inst. std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; for (auto &inst : *forOp->getBody()) { @@ -499,7 +499,7 @@ void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { // bounds, the resulting IR resembles: // // ```mlir -// for %i = max (`iv, ...) to min (`iv` + `offset`) { +// affine.for %i = max (`iv, ...) to min (`iv` + `offset`) { // ... // } // ``` diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 7bf9993b7c8..7a44a6277a6 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -199,14 +199,14 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, /// /// Before /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// "compute"(%idx) /// /// After /// -/// for %i = 0 to #map(%N) +/// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 955e38f4b39..a52129ed0d6 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -113,7 +113,7 @@ using namespace mlir; /// /// At a high level, a vectorized load in a loop will resemble: /// ```mlir -/// for %i = ? to ? step ? { +/// affine.for %i = ? to ? step ? { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -309,7 +309,7 @@ using namespace mlir; /// ```mlir /// mlfunc @fill(%A : memref<128xf32>) -> () { /// %f1 = constant 1.0 : f32 -/// for %i0 = 0 to 32 { +/// affine.for %i0 = 0 to 32 { /// store %f1, %A[%i0] : memref<128xf32, 0> /// } /// return @@ -322,7 +322,7 @@ using namespace mlir; /// is still subject to exploratory tradeoffs. In particular, say we want to /// vectorize by a factor 128, we want to transform the following input: /// ```mlir -/// for %i = %M to %N { +/// affine.for %i = %M to %N { /// %a = load A[%i] : memref /// } /// ``` @@ -331,8 +331,8 @@ using namespace mlir; /// memory promotion etc) say after stripmining (and potentially unrolling in /// the case of LLVM's SLP vectorizer): /// ```mlir -/// for %i = floor(%M, 128) to ceil(%N, 128) { -/// for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { +/// affine.for %i = floor(%M, 128) to ceil(%N, 128) { +/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { /// %a = load A[%ii] : memref /// } /// } @@ -341,7 +341,7 @@ using namespace mlir; /// Instead, we seek to vectorize early and freeze vector types before /// scheduling, so we want to generate a pattern that resembles: /// ```mlir -/// for %i = ? to ? step ? { +/// affine.for %i = ? to ? step ? 
{ /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -362,7 +362,7 @@ using namespace mlir; /// For the simple strawman example above, vectorizing for a 1-D vector /// abstraction of size 128 returns code similar to: /// ```mlir -/// for %i = %M to %N step 128 { +/// affine.for %i = %M to %N step 128 { /// %v_a = "vector_transfer_read" (A, %i) : (memref, index) -> /// vector<128xf32> /// } @@ -391,20 +391,20 @@ using namespace mlir; /// %C = alloc (%M, %N) : memref /// %f1 = constant 1.0 : f32 /// %f2 = constant 2.0 : f32 -/// for %i0 = 0 to %M { -/// for %i1 = 0 to %N { +/// affine.for %i0 = 0 to %M { +/// affine.for %i1 = 0 to %N { /// // non-scoped %f1 /// store %f1, %A[%i0, %i1] : memref /// } /// } -/// for %i2 = 0 to %M { -/// for %i3 = 0 to %N { +/// affine.for %i2 = 0 to %M { +/// affine.for %i3 = 0 to %N { /// // non-scoped %f2 /// store %f2, %B[%i2, %i3] : memref /// } /// } -/// for %i4 = 0 to %M { -/// for %i5 = 0 to %N { +/// affine.for %i4 = 0 to %M { +/// affine.for %i5 = 0 to %N { /// %a5 = load %A[%i4, %i5] : memref /// %b5 = load %B[%i4, %i5] : memref /// %s5 = addf %a5, %b5 : f32 @@ -438,24 +438,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// for %i0 = 0 to %arg0 { -/// for %i1 = 0 to %arg1 step 256 { +/// affine.for %i0 = 0 to %arg0 { +/// affine.for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// for %i2 = 0 to %arg0 { -/// for %i3 = 0 to %arg1 step 256 { +/// affine.for %i2 = 0 to %arg0 { +/// affine.for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<256xf32>, memref, index, index) -> () /// } /// } -/// for %i4 = 0 to %arg0 { -/// for %i5 = 0 to %arg1 step 256 { +/// affine.for %i4 = 0 to %arg0 { +/// affine.for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : @@ -494,24 +494,24 @@ using namespace mlir; /// %2 = alloc(%arg0, %arg1) : memref /// %cst = constant 1.0 : f32 /// %cst_0 = constant 2.0 : f32 -/// for %i0 = 0 to %arg0 step 32 { -/// for %i1 = 0 to %arg1 step 256 { +/// affine.for %i0 = 0 to %arg0 step 32 { +/// affine.for %i1 = 0 to %arg1 step 256 { /// %cst_1 = constant splat, 1.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_1, %0, %i0, %i1) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// for %i2 = 0 to %arg0 step 32 { -/// for %i3 = 0 to %arg1 step 256 { +/// affine.for %i2 = 0 to %arg0 step 32 { +/// affine.for %i3 = 0 to %arg1 step 256 { /// %cst_2 = constant splat, 2.0> : /// vector<32x256xf32> /// "vector_transfer_write"(%cst_2, %1, %i2, %i3) : /// (vector<32x256xf32>, memref, index, index) -> () /// } /// } -/// for %i4 = 0 to %arg0 step 32 { -/// for %i5 = 0 to %arg1 step 256 { +/// affine.for %i4 = 0 to %arg0 step 32 { +/// affine.for %i5 = 0 to %arg1 step 256 { /// %3 = "vector_transfer_read"(%0, %i4, %i5) : /// (memref, index, index) -> vector<32x256xf32> /// %4 = "vector_transfer_read"(%1, %i4, %i5) : diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir index 4fd44be8538..90f6aede0d5 100644 --- a/mlir/test/AffineOps/canonicalize.mlir +++ b/mlir/test/AffineOps/canonicalize.mlir 
@@ -47,7 +47,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { // Test load[%x, %x] %x0 = affine.apply (d0) -> (d0 - 1) (%i0) @@ -93,7 +93,7 @@ func @compose_affine_maps_1dto2d_no_symbols() { func @compose_affine_maps_1dto2d_with_symbols() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { // Test load[%x0, %x0] with symbol %c4 %c4 = constant 4 : index %x0 = affine.apply (d0)[s0] -> (d0 - s0) (%i0)[%c4] @@ -134,13 +134,13 @@ func @compose_affine_maps_2d_tile() { %c4 = constant 4 : index %c8 = constant 8 : index - for %i0 = 0 to 3 { + affine.for %i0 = 0 to 3 { %x0 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i0)[%c4] - for %i1 = 0 to 3 { + affine.for %i1 = 0 to 3 { %x1 = affine.apply (d0)[s0] -> (d0 ceildiv s0) (%i1)[%c8] - for %i2 = 0 to 3 { + affine.for %i2 = 0 to 3 { %x2 = affine.apply (d0)[s0] -> (d0 mod s0) (%i2)[%c4] - for %i3 = 0 to 3 { + affine.for %i3 = 0 to 3 { %x3 = affine.apply (d0)[s0] -> (d0 mod s0) (%i3)[%c8] %x40 = affine.apply (d0, d1, d2, d3)[s0, s1] -> @@ -166,9 +166,9 @@ func @compose_affine_maps_dependent_loads() { %0 = alloc() : memref<16x32xf32> %1 = alloc() : memref<16x32xf32> - for %i0 = 0 to 3 { - for %i1 = 0 to 3 { - for %i2 = 0 to 3 { + affine.for %i0 = 0 to 3 { + affine.for %i1 = 0 to 3 { + affine.for %i2 = 0 to 3 { %c3 = constant 3 : index %c7 = constant 7 : index @@ -212,7 +212,7 @@ func @compose_affine_maps_dependent_loads() { func @compose_affine_maps_diamond_dependency() { %0 = alloc() : memref<4x4xf32> - for %i0 = 0 to 15 { + affine.for %i0 = 0 to 15 { %a = affine.apply (d0) -> (d0 - 1) (%i0) %b = affine.apply (d0) -> (d0 + 7) (%a) %c = affine.apply (d0) -> (d0 * 4) (%a) @@ -232,8 +232,8 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) { %c9 = constant 9 : index %1 = alloc() : memref<100x100xf32, 1> %2 = alloc() : memref<1xi32> - for %i0 = 0 to 100 { - for %i1 = 0 to 100 { + affine.for %i0 = 0 to 100 { + affine.for %i1 = 0 to 100 { %3 = affine.apply (d0, d1)[s0, s1] -> (d1 + s0 + s1) (%i0, %i1)[%arg1, %c9] %4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1)) @@ -253,7 +253,7 @@ func @trivial_maps() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = affine.apply ()[s0] -> (s0)()[%c0] store %cst, %0[%1] : memref<10xf32> %2 = load %0[%c0] : memref<10xf32> @@ -380,7 +380,7 @@ func @mix_dims_and_symbols_g(%M: index, %N: index) -> (index, index, index) { // CHECK-LABEL: func @symbolic_semi_affine(%arg0: index, %arg1: index, %arg2: memref) { func @symbolic_semi_affine(%M: index, %N: index, %A: memref) { %f1 = constant 1.0 : f32 - for %i0 = 1 to 100 { + affine.for %i0 = 1 to 100 { %1 = affine.apply ()[s0] -> (s0 + 1) ()[%M] %2 = affine.apply (d0)[s0] -> (d0 floordiv s0) (%i0)[%1] // CHECK-DAG: {{.*}} = affine.apply [[symbolic_semi_affine]](%i0)[%arg0] @@ -404,20 +404,20 @@ func @constant_fold_bounds(%N : index) { %c3 = affine.apply (d0, d1) -> (d0 + d1) (%c1, %c2) %l = "foo"() : () -> index - // CHECK: for %i0 = 5 to 7 { - for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { + // CHECK: affine.for %i0 = 5 to 7 { + affine.for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) { "foo"(%i, %c3) : (index, index) -> () } // Bound takes a non-constant argument but can still be folded. 
- // CHECK: for %i1 = 1 to 7 { - for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { + // CHECK: affine.for %i1 = 1 to 7 { + affine.for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) { "foo"(%j, %c3) : (index, index) -> () } // None of the bounds can be folded. - // CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { - for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { + // CHECK: affine.for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] { + affine.for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] { "foo"(%k, %c3) : (index, index) -> () } return diff --git a/mlir/test/AffineOps/invalid.mlir b/mlir/test/AffineOps/invalid.mlir index b9093c756b7..69260a7fce6 100644 --- a/mlir/test/AffineOps/invalid.mlir +++ b/mlir/test/AffineOps/invalid.mlir @@ -5,7 +5,7 @@ #map = (d0)[s0] -> (d0 + s0) func @affine_apply_invalid_dim(%arg : index) { - for %n0 = 0 to 7 { + affine.for %n0 = 0 to 7 { %dim = addi %arg, %arg : index // expected-error@+1 {{operand cannot be used as a dimension id}} @@ -19,7 +19,7 @@ func @affine_apply_invalid_dim(%arg : index) { #map0 = (d0)[s0] -> (d0 + s0) func @affine_apply_invalid_sym() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { // expected-error@+1 {{operand cannot be used as a symbol}} %0 = affine.apply #map0(%i0)[%i0] } @@ -31,11 +31,11 @@ func @affine_apply_invalid_sym() { #map = (d0)[s0] -> (d0 + s0) func @affine_for_lower_bound_invalid_dim(%arg : index) { - for %n0 = 0 to 7 { + affine.for %n0 = 0 to 7 { %dim = addi %arg, %arg : index // expected-error@+1 {{operand cannot be used as a dimension id}} - for %n1 = 0 to #map(%dim)[%arg] { + affine.for %n1 = 0 to #map(%dim)[%arg] { } } return @@ -46,11 +46,11 @@ func @affine_for_lower_bound_invalid_dim(%arg : index) { #map = (d0)[s0] -> (d0 + s0) func @affine_for_upper_bound_invalid_dim(%arg : index) { - for %n0 = 0 to 7 { + affine.for %n0 = 0 to 7 { %dim = addi %arg, %arg : index // expected-error@+1 {{operand cannot be used as a dimension id}} - for %n1 = #map(%dim)[%arg] to 7 { + affine.for %n1 = #map(%dim)[%arg] to 7 { } } return @@ -61,9 +61,9 @@ func @affine_for_upper_bound_invalid_dim(%arg : index) { #map0 = (d0)[s0] -> (d0 + s0) func @affine_for_lower_bound_invalid_sym() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { // expected-error@+1 {{operand cannot be used as a symbol}} - for %n0 = #map0(%i0)[%i0] to 7 { + affine.for %n0 = #map0(%i0)[%i0] to 7 { } } return @@ -74,9 +74,9 @@ func @affine_for_lower_bound_invalid_sym() { #map0 = (d0)[s0] -> (d0 + s0) func @affine_for_upper_bound_invalid_sym() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { // expected-error@+1 {{operand cannot be used as a symbol}} - for %n0 = 0 to #map0(%i0)[%i0] { + affine.for %n0 = 0 to #map0(%i0)[%i0] { } } return @@ -87,7 +87,7 @@ func @affine_for_upper_bound_invalid_sym() { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @affine_if_invalid_dim(%arg : index) { - for %n0 = 0 to 7 { + affine.for %n0 = 0 to 7 { %dim = addi %arg, %arg : index // expected-error@+1 {{operand cannot be used as a dimension id}} @@ -101,7 +101,7 @@ func @affine_if_invalid_dim(%arg : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @affine_if_invalid_sym() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { // expected-error@+1 {{operand cannot be used as a symbol}} affine.if #set0(%i0)[%i0] {} } @@ -113,7 +113,7 @@ func @affine_if_invalid_sym() { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @affine_if_invalid_dimop_dim(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { 
- for %n0 = 0 to 7 { + affine.for %n0 = 0 to 7 { %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref %dim = dim %0, 0 : memref diff --git a/mlir/test/AffineOps/ops.mlir b/mlir/test/AffineOps/ops.mlir index e265c6be3a4..6e60c624c72 100644 --- a/mlir/test/AffineOps/ops.mlir +++ b/mlir/test/AffineOps/ops.mlir @@ -2,9 +2,9 @@ // Check that the attributes for the affine operations are round-tripped. func @attributes() { - // CHECK: for %i + // CHECK: affine.for %i // CHECK-NEXT: } {some_attr: true} - for %i = 0 to 10 { + affine.for %i = 0 to 10 { } {some_attr: true} // CHECK: if diff --git a/mlir/test/EDSC/api-test.cpp b/mlir/test/EDSC/api-test.cpp index 4446fbc9cbc..8d39af520df 100644 --- a/mlir/test/EDSC/api-test.cpp +++ b/mlir/test/EDSC/api-test.cpp @@ -143,7 +143,7 @@ TEST_FUNC(cond_branch) { f->print(llvm::outs()); } -// Inject a EDSC-constructed `for` loop with bounds coming from function +// Inject a EDSC-constructed `affine.for` loop with bounds coming from function // arguments. TEST_FUNC(dynamic_for_func_args) { auto indexType = IndexType::get(&globalContext()); @@ -164,7 +164,7 @@ TEST_FUNC(dynamic_for_func_args) { // clang-format off // CHECK-LABEL: func @dynamic_for_func_args(%arg0: index, %arg1: index) { - // CHECK: for %i0 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 3 { + // CHECK: affine.for %i0 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 3 { // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 * 3)()[%arg0] // CHECK: {{.*}} = affine.apply ()[s0, s1] -> (s1 + s0 * 3)()[%arg0, %arg1] // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 + 3)()[%arg0] @@ -172,7 +172,7 @@ TEST_FUNC(dynamic_for_func_args) { f->print(llvm::outs()); } -// Inject a EDSC-constructed `for` loop with non-constant bounds that are +// Inject a EDSC-constructed `affine.for` loop with non-constant bounds that are // obtained from AffineApplyOp (also constructed using EDSC operator // overloads). 
TEST_FUNC(dynamic_for) { @@ -200,12 +200,12 @@ TEST_FUNC(dynamic_for) { // CHECK-LABEL: func @dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { // CHECK: %0 = affine.apply ()[s0, s1] -> (s0 - s1)()[%arg0, %arg1] // CHECK-NEXT: %1 = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3] - // CHECK-NEXT: for %i0 = (d0) -> (d0)(%0) to (d0) -> (d0)(%1) step 2 { + // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)(%0) to (d0) -> (d0)(%1) step 2 { // clang-format on f->print(llvm::outs()); } -// Inject a EDSC-constructed empty `for` loop with max/min bounds that +// Inject a EDSC-constructed empty `affine.for` loop with max/min bounds that // corresponds to // // for max(%arg0, %arg1) to (%arg2, %arg3) step 1 @@ -234,7 +234,7 @@ TEST_FUNC(max_min_for) { // clang-format off // CHECK-LABEL: func @max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { - // CHECK: for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { + // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { // clang-format on f->print(llvm::outs()); } @@ -334,7 +334,7 @@ TEST_FUNC(assignments_1) { // clang-format off // CHECK-LABEL: func @assignments(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) { - // CHECK: for %[[iv:.*]] = 0 to 4 { + // CHECK: affine.for %[[iv:.*]] = 0 to 4 { // CHECK: %[[a:.*]] = load %arg0[%[[iv]]] : memref<4xf32> // CHECK: %[[b:.*]] = load %arg1[%[[iv]]] : memref<4xf32> // CHECK: %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32 @@ -348,7 +348,7 @@ TEST_FUNC(assignments_2) { // clang-format off // CHECK-LABEL: func @assignments(%arg0: memref, %arg1: memref, %arg2: memref) { - // CHECK: for %[[iv:.*]] = {{.*}} to {{.*}} { + // CHECK: affine.for %[[iv:.*]] = {{.*}} to {{.*}} { // CHECK: %[[a:.*]] = load %arg0[%[[iv]]] : memref // CHECK: %[[b:.*]] = load %arg1[%[[iv]]] : memref // CHECK: %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32 @@ -405,13 +405,13 @@ TEST_FUNC(tile_2d) { // CHECK: %[[M:[0-9]+]] = dim %arg0, 0 : memref // CHECK-NEXT: %[[N:[0-9]+]] = dim %arg0, 1 : memref // CHECK-NEXT: %[[P:[0-9]+]] = dim %arg0, 2 : memref - // CHECK: for %i0 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[M]]) step 512 { - // CHECK-NEXT: for %i1 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[N]]) step 1024 { - // CHECK-NEXT: for %i2 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[P]]) { - // CHECK-NEXT: for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 512)(%i0)[%[[M]]] step 16 { - // CHECK-NEXT: for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 1024)(%i1)[%[[N]]] step 32 { - // CHECK-NEXT: for %i5 = max (d0, d1)[s0] -> (s0, d0, d1)(%i1, %i4)[%[[ZERO]]] to min (d0, d1)[s0] -> (s0, d0 + 1024, d1 + 32)(%i1, %i4)[%[[N]]] { - // CHECK-NEXT: for %i6 = max (d0, d1)[s0] -> (s0, d0, d1)(%i0, %i3)[%[[ZERO]]] to min (d0, d1)[s0] -> (s0, d0 + 512, d1 + 16)(%i0, %i3)[%[[M]]] { + // CHECK: affine.for %i0 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[M]]) step 512 { + // CHECK-NEXT: affine.for %i1 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[N]]) step 1024 { + // CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[P]]) { + // CHECK-NEXT: affine.for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 512)(%i0)[%[[M]]] step 16 { + // CHECK-NEXT: affine.for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 1024)(%i1)[%[[N]]] step 32 { + // CHECK-NEXT: affine.for %i5 = max (d0, d1)[s0] -> (s0, d0, 
d1)(%i1, %i4)[%[[ZERO]]] to min (d0, d1)[s0] -> (s0, d0 + 1024, d1 + 32)(%i1, %i4)[%[[N]]] { + // CHECK-NEXT: affine.for %i6 = max (d0, d1)[s0] -> (s0, d0, d1)(%i0, %i3)[%[[ZERO]]] to min (d0, d1)[s0] -> (s0, d0 + 512, d1 + 16)(%i0, %i3)[%[[M]]] { // CHECK-NEXT: {{.*}} = load {{.*}}[%i6, %i5, %i2] : memref // CHECK-NEXT: {{.*}} = load {{.*}}[%i6, %i5, %i2] : memref // CHECK-NEXT: {{.*}} = addf {{.*}}, {{.*}} : f32 @@ -421,9 +421,9 @@ TEST_FUNC(tile_2d) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } - // CHECK-NEXT: for %i7 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[P]]) { - // CHECK-NEXT: for %i8 = max (d0)[s0] -> (s0, d0)(%i0)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 512)(%i0)[%[[M]]] { - // CHECK-NEXT: for %i9 = max (d0)[s0] -> (s0, d0)(%i1)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 1024)(%i1)[%[[N]]] { + // CHECK-NEXT: affine.for %i7 = (d0) -> (d0)(%[[ZERO]]) to (d0) -> (d0)(%[[P]]) { + // CHECK-NEXT: affine.for %i8 = max (d0)[s0] -> (s0, d0)(%i0)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 512)(%i0)[%[[M]]] { + // CHECK-NEXT: affine.for %i9 = max (d0)[s0] -> (s0, d0)(%i1)[%[[ZERO]]] to min (d0)[s0] -> (s0, d0 + 1024)(%i1)[%[[N]]] { // CHECK-NEXT: {{.*}} = load {{.*}}[%i8, %i9, %i7] : memref // CHECK-NEXT: {{.*}} = load {{.*}}[%i8, %i9, %i7] : memref // CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32 diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index ec6d12a0876..5d7000e8950 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -80,11 +80,11 @@ TEST_FUNC(builder_dynamic_for_func_args) { // clang-format off // CHECK-LABEL: func @builder_dynamic_for_func_args(%arg0: index, %arg1: index) { - // CHECK: for %i0 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 3 { + // CHECK: affine.for %i0 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 3 { // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 * 3)()[%arg0] // CHECK: {{.*}} = affine.apply ()[s0, s1] -> (s1 + s0 * 3)()[%arg0, %arg1] // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 + 3)()[%arg0] - // CHECK: for %i1 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 2 { + // CHECK: affine.for %i1 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 2 { // CHECK: {{.*}} = affine.apply (d0, d1) -> ((d0 + d1 * 3) floordiv 32)(%i0, %i1) // CHECK: {{.*}} = affine.apply (d0, d1) -> (((d0 + d1 * 3) floordiv 32) * 31)(%i0, %i1) // CHECK: {{.*}} = affine.apply (d0, d1) -> ((((d0 + d1 * 3) floordiv 32) * 31) ceildiv 32)(%i0, %i1) @@ -119,7 +119,7 @@ TEST_FUNC(builder_dynamic_for) { // CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { // CHECK: %0 = affine.apply ()[s0, s1] -> (s0 - s1)()[%arg0, %arg1] // CHECK-NEXT: %1 = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3] - // CHECK-NEXT: for %i0 = (d0) -> (d0)(%0) to (d0) -> (d0)(%1) step 2 { + // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)(%0) to (d0) -> (d0)(%1) step 2 { // clang-format on f->print(llvm::outs()); } @@ -140,7 +140,7 @@ TEST_FUNC(builder_max_min_for) { // clang-format off // CHECK-LABEL: func @builder_max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { - // CHECK: for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { + // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { // CHECK: return // clang-format on f->print(llvm::outs()); @@ -344,16 +344,16 @@ TEST_FUNC(builder_helpers) { }); // CHECK-LABEL: @builder_helpers - // CHECK: for %i0 = (d0) -> (d0)({{.*}}) to 
(d0) -> (d0)({{.*}}) { - // CHECK-NEXT: for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { - // CHECK-NEXT: for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK: affine.for %i0 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK-NEXT: affine.for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK-NEXT: affine.for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { // CHECK-NEXT: [[a:%.*]] = load %arg0[%i0, %i1, %i2] : memref // CHECK-NEXT: [[b:%.*]] = addf {{.*}}, [[a]] : f32 // CHECK-NEXT: [[c:%.*]] = load %arg1[%i0, %i1, %i2] : memref // CHECK-NEXT: [[d:%.*]] = addf [[b]], [[c]] : f32 // CHECK-NEXT: store [[d]], %arg2[%i0, %i1, %i2] : memref // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = (d0) -> (d0)(%c0_1) to (d0) -> (d0)(%2) { + // CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%c0_1) to (d0) -> (d0)(%2) { // CHECK-NEXT: [[a:%.*]] = load %arg1[%i0, %i1, %i3] : memref // CHECK-NEXT: [[b:%.*]] = load %arg0[%i0, %i1, %i3] : memref // CHECK-NEXT: [[c:%.*]] = addf [[b]], [[a]] : f32 @@ -392,8 +392,8 @@ TEST_FUNC(custom_ops) { }); // CHECK-LABEL: @custom_ops - // CHECK: for %i0 {{.*}} - // CHECK: for %i1 {{.*}} + // CHECK: affine.for %i0 {{.*}} + // CHECK: affine.for %i1 {{.*}} // CHECK: {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index // CHECK: "my_custom_inst_0"{{.*}} : (index, index) -> () // CHECK: [[TWO:%[a-z0-9]+]] = "my_custom_inst_2"{{.*}} : (index, index) -> (index, index) diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index d6319028491..650f023c185 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,24 +204,24 @@ func @illegaltype(i0) // expected-error {{invalid integer width}} // ----- func @malformed_for_percent() { - for i = 1 to 10 { // expected-error {{expected SSA operand}} + affine.for i = 1 to 10 { // expected-error {{expected SSA operand}} // ----- func @malformed_for_equal() { - for %i 1 to 10 { // expected-error {{expected '='}} + affine.for %i 1 to 10 { // expected-error {{expected '='}} // ----- func @malformed_for_to() { - for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} + affine.for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} } } // ----- func @incomplete_for() { - for %i = 1 to 10 step 2 + affine.for %i = 1 to 10 step 2 } // expected-error {{expected '{' to begin a region}} // ----- @@ -230,19 +230,19 @@ func @incomplete_for() { func @reference_to_iv_in_bound() { // expected-error@+1 {{operand use before it's defined}} - for %i0 = #map0(%i0) to 10 { + affine.for %i0 = #map0(%i0) to 10 { } } // ----- func @nonconstant_step(%1 : i32) { - for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} + affine.for %2 = 1 to 5 step %1 { // expected-error {{expected non-function type}} // ----- func @for_negative_stride() { - for %i = 1 to 10 step -1 + affine.for %i = 1 to 10 step -1 } // expected-error@-1 {{expected step to be representable as a positive signed integer}} // ----- @@ -254,7 +254,7 @@ func @non_instruction() { // ----- func @invalid_if_conditional2() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -262,7 +262,7 @@ func @invalid_if_conditional2() { // ----- func @invalid_if_conditional3() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i)[N] : (i == 1) // expected-error {{expected '0' after '=='}} } } @@ -270,7 +270,7 @@ func @invalid_if_conditional3() { // ----- func 
@invalid_if_conditional4() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i)[N] : (i >= 2) // expected-error {{expected '0' after '>='}} } } @@ -278,7 +278,7 @@ func @invalid_if_conditional4() { // ----- func @invalid_if_conditional5() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i)[N] : (i <= 0 ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -286,7 +286,7 @@ func @invalid_if_conditional5() { // ----- func @invalid_if_conditional6() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i) : (i) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -294,7 +294,7 @@ func @invalid_if_conditional6() { // ----- // TODO (support affine.if (1)? func @invalid_if_conditional7() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if (i) : (1) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} } } @@ -440,8 +440,8 @@ func @undef() { // ----- func @duplicate_induction_var() { - for %i = 1 to 10 { // expected-error {{previously defined here}} - for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} + affine.for %i = 1 to 10 { // expected-error {{previously defined here}} + affine.for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}} } } return @@ -450,7 +450,7 @@ func @duplicate_induction_var() { // ----- func @dominance_failure() { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { } "xxx"(%i) : (index)->() // expected-error {{operand #0 does not dominate this use}} return @@ -477,7 +477,7 @@ func @return_type_mismatch() -> i32 { // ----- func @return_inside_loop() { - for %i = 1 to 100 { + affine.for %i = 1 to 100 { // expected-error@-1 {{op expects body block to not have a terminator}} return } @@ -522,7 +522,7 @@ func @referer() { #map1 = (i)[j] -> (i+j) func @bound_symbol_mismatch(%N : index) { - for %i = #map1(%N) to 100 { + affine.for %i = #map1(%N) to 100 { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} } return @@ -533,7 +533,7 @@ func @bound_symbol_mismatch(%N : index) { #map1 = (i)[j] -> (i+j) func @bound_dim_mismatch(%N : index) { - for %i = #map1(%N, %N)[%N] to 100 { + affine.for %i = #map1(%N, %N)[%N] to 100 { // expected-error@-1 {{dim operand count and integer set dim count must match}} } return @@ -542,7 +542,7 @@ func @bound_dim_mismatch(%N : index) { // ----- func @large_bound() { - for %i = 1 to 9223372036854775810 { + affine.for %i = 1 to 9223372036854775810 { // expected-error@-1 {{integer constant out of range for attribute}} } return @@ -551,7 +551,7 @@ func @large_bound() { // ----- func @max_in_upper_bound(%N : index) { - for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} + affine.for %i = 1 to max (i)->(N, 100) { //expected-error {{expected non-function type}} } return } @@ -559,7 +559,7 @@ func @max_in_upper_bound(%N : index) { // ----- func @step_typo() { - for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} + affine.for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}} } return } @@ -567,7 +567,7 @@ func @step_typo() { // ----- func @invalid_bound_map(%N : i32) { - for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} + affine.for %i = 1 to (i)->(j)(%N) { //expected-error {{use of undeclared identifier}} } return } @@ -580,7 +580,7 @@ func @invalid_bound_map(%N : i32) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands1(%N : index) { - for %i = 1 to 10 
{ + affine.for %i = 1 to 10 { affine.if #set0(%i) { // expected-error@-1 {{symbol operand count and integer set symbol count must match}} @@ -588,7 +588,7 @@ func @invalid_if_operands1(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands2(%N : index) { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if #set0()[%N] { // expected-error@-1 {{dim operand count and integer set dim count must match}} @@ -596,7 +596,7 @@ func @invalid_if_operands2(%N : index) { #set0 = (i)[N] : (i >= 0, N - i >= 0) func @invalid_if_operands3(%N : index) { - for %i = 1 to 10 { + affine.for %i = 1 to 10 { affine.if #set0(%i)[%i] { // expected-error@-1 {{operand cannot be used as a symbol}} } @@ -751,11 +751,11 @@ func @f(f32) { // ----- func @f(%m : memref) { - for %i0 = 0 to 42 { + affine.for %i0 = 0 to 42 { // expected-error@+1 {{operand #2 does not dominate this use}} %x = load %m[%i0, %i1] : memref } - for %i1 = 0 to 42 { + affine.for %i1 = 0 to 42 { } return } @@ -805,7 +805,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t // Check ill-formed opaque tensor. func @complex_loops() { - for %i1 = 1 to 100 { + affine.for %i1 = 1 to 100 { // expected-error @+1 {{expected '"' in string literal}} "opaqueIntTensor"(){bar: opaque<"", tensor<2x1x4xi32>, "0x686]>} : () -> () @@ -839,7 +839,7 @@ func @invalid_affine_structure() { func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{lower loop bound affine map with multiple results requires 'max' prefix}} - for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { + affine.for %i0 = ()[s]->(0,s-1)()[%arg0] to %arg1 { } return } @@ -848,7 +848,7 @@ func @missing_for_max(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { // expected-error @+1 {{upper loop bound affine map with multiple results requires 'min' prefix}} - for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { + affine.for %i0 = %arg0 to ()[s]->(100,s+1)()[%arg1] { } return } diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 3b27301cfae..ac4925e3e52 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -13,7 +13,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8)) // CHECK: } loc(fused["foo", "mysource.cc":10:8]) - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(fused<"myPass">["foo", "foo2"]) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 92fbc0e19f8..c66c6c0614b 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -208,8 +208,8 @@ func @identity_functor(%a : () -> ()) -> (() -> ()) { func @func_ops_in_loop() { // CHECK: %0 = "foo"() : () -> i64 %a = "foo"() : ()->i64 - // CHECK: for %i0 = 1 to 10 { - for %i = 1 to 10 { + // CHECK: affine.for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK: %1 = "doo"() : () -> f32 %b = "doo"() : ()->f32 // CHECK: "bar"(%0, %1) : (i64, f32) -> () @@ -224,10 +224,10 @@ func @func_ops_in_loop() { // CHECK-LABEL: func @loops() { func @loops() { - // CHECK: for %i0 = 1 to 100 step 2 { - for %i = 1 to 100 step 2 { - // CHECK: for %i1 = 1 to 200 { - for %j = 1 to 200 { + // CHECK: affine.for %i0 = 1 to 100 step 2 { + affine.for %i = 1 to 100 step 2 { + // CHECK: affine.for %i1 = 1 to 200 { + affine.for %j = 1 to 200 { } // CHECK: } } // CHECK: } return // CHECK: return @@ -235,14 +235,14 
@@ func @loops() { // CHECK-LABEL: func @complex_loops() { func @complex_loops() { - for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 { - for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 { + affine.for %i1 = 1 to 100 { // CHECK: affine.for %i0 = 1 to 100 { + affine.for %j1 = 1 to 100 { // CHECK: affine.for %i1 = 1 to 100 { // CHECK: "foo"(%i0, %i1) : (index, index) -> () "foo"(%i1, %j1) : (index,index) -> () } // CHECK: } "boo"() : () -> () // CHECK: "boo"() : () -> () - for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 { - for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 { + affine.for %j2 = 1 to 10 { // CHECK: affine.for %i2 = 1 to 10 { + affine.for %k2 = 1 to 10 { // CHECK: affine.for %i3 = 1 to 10 { "goo"() : () -> () // CHECK: "goo"() : () -> () } // CHECK: } } // CHECK: } @@ -253,8 +253,8 @@ func @complex_loops() { // CHECK: func @triang_loop(%arg0: index, %arg1: memref) { func @triang_loop(%arg0: index, %arg1: memref) { %c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32 - for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 { - for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { + affine.for %i0 = 1 to %arg0 { // CHECK: affine.for %i0 = 1 to %arg0 { + affine.for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to %arg0 { store %c, %arg1[%i0, %i1] : memref // CHECK: store %c0_i32, %arg1[%i0, %i1] } // CHECK: } } // CHECK: } @@ -263,8 +263,8 @@ func @triang_loop(%arg0: index, %arg1: memref) { // CHECK: func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { - // CHECK: for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { - for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { + // CHECK: affine.for %i0 = max #map{{.*}}()[%arg0] to min #map{{.*}}()[%arg1] { + affine.for %i0 = max()[s]->(0,s-1)()[%arg0] to min()[s]->(100,s+1)()[%arg1] { // CHECK: "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () "foo"(%arg2, %i0) : (memref<100xf32>, index) -> () } // CHECK: } @@ -275,24 +275,24 @@ func @minmax_loop(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { func @loop_bounds(%N : index) { // CHECK: %0 = "foo"(%arg0) : (index) -> index %s = "foo"(%N) : (index) -> index - // CHECK: for %i0 = %0 to %arg0 - for %i = %s to %N { - // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0 - for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { + // CHECK: affine.for %i0 = %0 to %arg0 + affine.for %i = %s to %N { + // CHECK: affine.for %i1 = #map{{[0-9]+}}(%i0) to 0 + affine.for %j = (d0)[]->(d0)(%i)[] to 0 step 1 { // CHECK: %1 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w1 = affine.apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s] // CHECK: %2 = affine.apply #map{{.*}}(%i0, %i1)[%0] %w2 = affine.apply(d0, d1)[s0] -> (s0+1) (%i, %j) [%s] - // CHECK: for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { - for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { + // CHECK: affine.for %i2 = #map{{.*}}(%1, %i0)[%arg0] to #map{{.*}}(%2, %i1)[%0] { + affine.for %k = #bound_map1 (%w1, %i)[%N] to (i, j)[s] -> (i + j + s) (%w2, %j)[%s] { // CHECK: "foo"(%i0, %i1, %i2) : (index, index, index) -> () "foo"(%i, %j, %k) : (index, index, index)->() // CHECK: %c30 = constant 30 : index %c = constant 30 : index // CHECK: %3 = affine.apply #map{{.*}}(%arg0, %c30) %u = affine.apply (d0, d1)->(d0+d1) (%N, %c) - // CHECK: for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { - for %l = max #bound_map2(%i)[%u] to 
min #bound_map2(%k)[%c] { + // CHECK: affine.for %i3 = max #map{{.*}}(%i0)[%3] to min #map{{.*}}(%i2)[%c30] { + affine.for %l = max #bound_map2(%i)[%u] to min #bound_map2(%k)[%c] { // CHECK: "bar"(%i3) : (index) -> () "bar"(%l) : (index) -> () } // CHECK: } @@ -305,7 +305,7 @@ func @loop_bounds(%N : index) { // CHECK-LABEL: func @ifinst(%arg0: index) { func @ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { affine.if #set0(%i)[%N, %c] { // CHECK affine.if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -328,7 +328,7 @@ func @ifinst(%N: index) { // CHECK-LABEL: func @simple_ifinst(%arg0: index) { func @simple_ifinst(%N: index) { %c = constant 200 : index // CHECK %c200 = constant 200 - for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 { + affine.for %i = 1 to 10 { // CHECK affine.for %i0 = 1 to 10 { affine.if #set0(%i)[%N, %c] { // CHECK affine.if #set0(%i0)[%arg0, %c200] { %x = constant 1 : i32 // CHECK: %c1_i32 = constant 1 : i32 @@ -549,18 +549,18 @@ func @funcattrwithblock() -> () #map_non_simple2 = ()[s0, s1] -> (s0 + s1) #map_non_simple3 = ()[s0] -> (s0 + 3) func @funcsimplemap(%arg0: index, %arg1: index) -> () { - for %i0 = 0 to #map_simple0()[] { - // CHECK: for %i0 = 0 to 10 { - for %i1 = 0 to #map_simple1()[%arg1] { - // CHECK: for %i1 = 0 to %arg1 { - for %i2 = 0 to #map_non_simple0(%i0)[] { - // CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { - for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { - // CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { - for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { - // CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { - for %i5 = 0 to #map_non_simple3()[%arg0] { - // CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { + affine.for %i0 = 0 to #map_simple0()[] { + // CHECK: affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to #map_simple1()[%arg1] { + // CHECK: affine.for %i1 = 0 to %arg1 { + affine.for %i2 = 0 to #map_non_simple0(%i0)[] { + // CHECK: affine.for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) { + affine.for %i3 = 0 to #map_non_simple1(%i0)[%arg1] { + // CHECK: affine.for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] { + affine.for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] { + // CHECK: affine.for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] { + affine.for %i5 = 0 to #map_non_simple3()[%arg0] { + // CHECK: affine.for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] { %c42_i32 = constant 42 : i32 } } @@ -745,9 +745,9 @@ func @sparsevectorattr() -> () { // CHECK-LABEL: func @loops_with_blockids() { func @loops_with_blockids() { ^block0: - for %i = 1 to 100 step 2 { + affine.for %i = 1 to 100 step 2 { ^block1: - for %j = 1 to 200 { + affine.for %j = 1 to 200 { ^block2: } } diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index bc5a319c99e..defde9e9c70 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -18,7 +18,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10)))) // CHECK: } ["foo", mysource.cc:10:8] - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } <"myPass">["foo", "foo2"] diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir 
index 013f6351a17..a55c79f1141 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -6,8 +6,8 @@ // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { %A = alloc () : memref<7x42xf32> - for %i0 = 0 to 7 step 4 { - for %i1 = 0 to 42 step 4 { + affine.for %i0 = 0 to 7 step 4 { + affine.for %i1 = 0 to 42 step 4 { %f1 = vector_transfer_read %A, %i0, %i1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> %ip1 = affine.apply (d0) -> (d0 + 1) (%i1) %f2 = vector_transfer_read %A, %i0, %ip1 {permutation_map: (d0, d1) -> (d0)} : (memref<7x42xf32>, index, index) -> vector<4xf32> @@ -29,11 +29,11 @@ func @materialize_read_1d() { // CHECK-LABEL: func @materialize_read_1d_partially_specialized func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %dyn4 : index) { %A = alloc (%dyn1, %dyn2, %dyn4) : memref<7x?x?x42x?xf32> - for %i0 = 0 to 7 { - for %i1 = 0 to %dyn1 { - for %i2 = 0 to %dyn2 { - for %i3 = 0 to 42 step 2 { - for %i4 = 0 to %dyn4 { + affine.for %i0 = 0 to 7 { + affine.for %i1 = 0 to %dyn1 { + affine.for %i2 = 0 to %dyn2 { + affine.for %i3 = 0 to 42 step 2 { + affine.for %i4 = 0 to %dyn4 { %f1 = vector_transfer_read %A, %i0, %i1, %i2, %i3, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> %i3p1 = affine.apply (d0) -> (d0 + 1) (%i3) %f2 = vector_transfer_read %A, %i0, %i1, %i2, %i3p1, %i4 {permutation_map: (d0, d1, d2, d3, d4) -> (d3)} : ( memref<7x?x?x42x?xf32>, index, index, index, index, index) -> vector<4xf32> @@ -54,19 +54,19 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d // CHECK-LABEL: func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref - // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 { - // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { - // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 { + // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK: %[[D0:.*]] = dim %0, 0 : memref // CHECK-NEXT: %[[D1:.*]] = dim %0, 1 : memref // CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref // CHECK-NEXT: %[[D3:.*]] = dim %0, 3 : memref // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast %[[ALLOC]] : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) @@ -117,10 +117,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // Check that I3 + I6 (of size 5) read from last index load(..., L3) and write into first index store(I6, ...) // Other dimensions are just accessed with I1, I2 resp. 
%A = alloc (%M, %N, %O, %P) : memref - for %i0 = 0 to %M step 3 { - for %i1 = 0 to %N { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 5 { + affine.for %i0 = 0 to %M step 3 { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 5 { %f = vector_transfer_read %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, 0, d0)} : (memref, index, index, index, index) -> vector<5x4x3xf32> } } @@ -133,10 +133,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: %cst = constant splat, 1.000000e+00> : vector<5x4x3xf32> - // CHECK-NEXT: for %[[I0:.*]] = 0 to %arg0 step 3 { - // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 step 4 { - // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { - // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { + // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 { + // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 { + // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %arg3 step 5 { // CHECK: %[[D0:.*]] = dim %0, 0 : memref // CHECK-NEXT: %[[D1:.*]] = dim %0, 1 : memref // CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref @@ -144,9 +144,9 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast {{.*}} : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> // CHECK: store %cst, {{.*}} : memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { - // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { - // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { + // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) @@ -201,10 +201,10 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // Other dimension is just accessed with I2. 
%A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<5x4x3xf32> - for %i0 = 0 to %M step 3 { - for %i1 = 0 to %N step 4 { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 5 { + affine.for %i0 = 0 to %M step 3 { + affine.for %i1 = 0 to %N step 4 { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 5 { vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : vector<5x4x3xf32>, memref, index, index, index, index } } diff --git a/mlir/test/Transforms/Vectorize/materialize.mlir b/mlir/test/Transforms/Vectorize/materialize.mlir index 80458c75333..ce445ec75bb 100644 --- a/mlir/test/Transforms/Vectorize/materialize.mlir +++ b/mlir/test/Transforms/Vectorize/materialize.mlir @@ -10,10 +10,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<4x4x4xf32> - // CHECK: for %i0 = 0 to %arg0 step 4 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 4 { - // CHECK-NEXT: for %i2 = 0 to %arg2 { - // CHECK-NEXT: for %i3 = 0 to %arg3 step 4 { + // CHECK: affine.for %i0 = 0 to %arg0 step 4 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 4 { + // CHECK-NEXT: affine.for %i2 = 0 to %arg2 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg3 step 4 { // CHECK-NEXT: %[[a:[0-9]+]] = {{.*}}[[ID1]](%i0) // CHECK-NEXT: %[[b:[0-9]+]] = {{.*}}[[ID1]](%i1) // CHECK-NEXT: %[[c:[0-9]+]] = {{.*}}[[ID1]](%i2) @@ -25,10 +25,10 @@ func @materialize(%M : index, %N : index, %O : index, %P : index) { // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b2]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index // CHECK: %[[b3:[0-9]+]] = {{.*}}[[D0P3]](%i1) // CHECK: vector_transfer_write {{.*}}, %0, {{.*}}, %[[b3]], {{.*}} {permutation_map: #[[D0D1D2D3TOD1D0]]} : vector<4x4xf32>, memref, index, index, index, index - for %i0 = 0 to %M step 4 { - for %i1 = 0 to %N step 4 { - for %i2 = 0 to %O { - for %i3 = 0 to %P step 4 { + affine.for %i0 = 0 to %M step 4 { + affine.for %i1 = 0 to %N step 4 { + affine.for %i2 = 0 to %O { + affine.for %i3 = 0 to %P step 4 { "vector_transfer_write"(%f1, %A, %i0, %i1, %i2, %i3) {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : (vector<4x4x4xf32>, memref, index, index, index, index) -> () } } diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir index b5f771d7e62..71c442b965e 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 4x unroll (jammed by construction). 
- // CHECK: for %i0 = 0 to %arg0 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { + // CHECK: affine.for %i0 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 1.000000e+00> : vector<8xf32> @@ -34,15 +34,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 4x unroll (jammed by construction). - // CHECK: for %i2 = 0 to %arg0 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { + // CHECK: affine.for %i2 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: [[CST0:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST1:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> // CHECK-NEXT: [[CST2:%.*]] = constant splat, 2.000000e+00> : vector<8xf32> @@ -60,15 +60,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL31:%.*]] = affine.apply [[D0P24]]{{.*}} // CHECK-NEXT: vector_transfer_write [[CST3]], {{.*}}, [[VAL30]], [[VAL31]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> // - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 4x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { + // CHECK: affine.for %i4 = 0 to %arg0 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -110,8 +110,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir index 92df49fa8fa..62149c323b6 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir @@ -15,8 +15,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // (3x2)x unroll (jammed by construction). 
- // CHECK: for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 16 { + // CHECK: affine.for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<8xf32> @@ -41,26 +41,26 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL50:%.*]] = affine.apply [[D0P2]](%i0) // CHECK-NEXT: [[VAL51:%.*]] = affine.apply [[D0P8]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL50]], [[VAL51]] {permutation_map: [[D0D1TOD1]]} : vector<8xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 16 { + // CHECK: affine.for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 16 { // ..... - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 // CHECK does (3x4)x unrolling. store %f2, %B[%i2, %i3] : memref } } // (3x2)x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 16 { + // CHECK: affine.for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 16 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -122,8 +122,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir index 36ec96e30b4..59705eca69e 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir @@ -13,8 +13,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 // 2x unroll (jammed by construction). - // CHECK: for %i0 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i1 = 0 to %arg1 step 32 { + // CHECK: affine.for %i0 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i1 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 1.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i0) @@ -24,15 +24,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i1) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } // 2x unroll (jammed by construction). 
- // CHECK: for %i2 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i3 = 0 to %arg1 step 32 { + // CHECK: affine.for %i2 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i3 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: {{.*}} = constant splat, 2.000000e+00> : vector<3x16xf32> // CHECK-NEXT: [[VAL00:%.*]] = affine.apply [[ID1]](%i2) @@ -42,15 +42,15 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: [[VAL11:%.*]] = affine.apply [[D0P16]](%i3) // CHECK-NEXT: vector_transfer_write {{.*}}, {{.*}}, [[VAL10]], [[VAL11]] {permutation_map: [[ID2]]} : vector<3x16xf32> // - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } // 2x unroll (jammed by construction). - // CHECK: for %i4 = 0 to %arg0 step 3 { - // CHECK-NEXT: for %i5 = 0 to %arg1 step 32 { + // CHECK: affine.for %i4 = 0 to %arg0 step 3 { + // CHECK-NEXT: affine.for %i5 = 0 to %arg1 step 32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: {{.*}} = vector_transfer_read @@ -72,8 +72,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-NEXT: {{.*}} = affine.apply // CHECK-NEXT: vector_transfer_write // - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { %a5 = load %A[%i4, %i5] : memref %b5 = load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 diff --git a/mlir/test/Transforms/Vectorize/normalize_maps.mlir b/mlir/test/Transforms/Vectorize/normalize_maps.mlir index 9569dbe07fe..076d2c75633 100644 --- a/mlir/test/Transforms/Vectorize/normalize_maps.mlir +++ b/mlir/test/Transforms/Vectorize/normalize_maps.mlir @@ -9,19 +9,19 @@ // CHECK-LABEL: func @simple() func @simple() { - for %i0 = 0 to 7 { + affine.for %i0 = 0 to 7 { %0 = affine.apply (d0) -> (d0) (%i0) %1 = affine.apply (d0) -> (d0) (%0) %2 = affine.apply (d0, d1) -> (d0 + d1) (%0, %0) %3 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) } - // CHECK-NEXT: for %i0 = 0 to 7 + // CHECK-NEXT: affine.for %i0 = 0 to 7 // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[D0TIMES2]](%i0) // CHECK-NEXT: {{.*}} affine.apply #[[ZERO]]() - for %i1 = 0 to 7 { - for %i2 = 0 to 42 { + affine.for %i1 = 0 to 7 { + affine.for %i2 = 0 to 42 { %20 = affine.apply (d0, d1) -> (d1) (%i1, %i2) %21 = affine.apply (d0, d1) -> (d0) (%i1, %i2) %22 = affine.apply (d0, d1) -> (d0 + d1) (%20, %21) @@ -29,15 +29,15 @@ func @simple() { %24 = affine.apply (d0, d1) -> (-d0 + d1) (%20, %21) } } - // CHECK: for %i1 = 0 to 7 - // CHECK-NEXT: for %i2 = 0 to 42 + // CHECK: affine.for %i1 = 0 to 7 + // CHECK-NEXT: affine.for %i2 = 0 to 42 // CHECK-NEXT: {{.*}} affine.apply #[[D0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[MINSD0PLUSD1]](%i1, %i2) // CHECK-NEXT: {{.*}} affine.apply #[[D0MINUSD1]](%i1, %i2) - for %i3 = 0 to 16 { - for %i4 = 0 to 47 step 2 { - for %i5 = 0 to 78 step 16 { + affine.for %i3 = 0 to 16 { + affine.for %i4 = 0 to 47 step 2 { + affine.for %i5 = 0 to 78 step 16 { %50 = affine.apply (d0) -> (d0) (%i3) %51 = affine.apply (d0) -> (d0) (%i4) %52 = affine.apply (d0) -> (d0) (%i5) @@ -47,9 +47,9 @@ func @simple() { } } } - // CHECK: for %i3 = 0 to 16 - // CHECK-NEXT: for %i4 = 0 to 47 step 2 - // CHECK-NEXT: for %i5 = 0 to 78 step 16 + // CHECK: affine.for %i3 = 0 to 16 + // CHECK-NEXT: affine.for %i4 = 0 to 47 step 2 + // CHECK-NEXT: affine.for %i5 = 0 to 78 step 16 // CHECK-NEXT: 
{{.*}} affine.apply #[[ID1]](%i3) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i4) // CHECK-NEXT: {{.*}} affine.apply #[[ID1]](%i5) diff --git a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir index 05e31dbdea5..c812db2d498 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir @@ -23,17 +23,17 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for {{.*}} step 128 // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : (memref, index, index) -> vector<128xf32> - for %i0 = 0 to %M { // vectorized due to scalar -> vector + affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector %a0 = load %A[%cst0, %cst0] : memref } // // CHECK:for {{.*}} [[ARG_M]] { - for %i1 = 0 to %M { // not vectorized + affine.for %i1 = 0 to %M { // not vectorized %a1 = load %A[%i1, %i1] : memref } // -// CHECK: for %i{{[0-9]*}} = 0 to [[ARG_M]] { - for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 +// CHECK: affine.for %i{{[0-9]*}} = 0 to [[ARG_M]] { + affine.for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 %r2 = affine.apply (d0) -> (d0) (%i2) %a2 = load %A[%r2#0, %cst0] : memref } @@ -41,7 +41,7 @@ func @vec1d(%A : memref, %B : memref) { // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK-NEXT: [[APP3:%[a-zA-Z0-9]+]] = affine.apply {{.*}}[[IV3]] // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i3 = 0 to %M { // vectorized + affine.for %i3 = 0 to %M { // vectorized %r3 = affine.apply (d0) -> (d0) (%i3) %a3 = load %A[%cst0, %r3#0] : memref } @@ -51,8 +51,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP50:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: [[APP51:%[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP50]], [[APP51]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i4 = 0 to %M { // vectorized - for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 + affine.for %i4 = 0 to %M { // vectorized + affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 %r50 = affine.apply (d0, d1) -> (d1) (%i4, %i5) %r51 = affine.apply (d0, d1) -> (d0) (%i4, %i5) %a5 = load %A[%r50, %r51] : memref @@ -61,8 +61,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV6:%[i0-9]*]] = 0 to [[ARG_M]] { // CHECK-NEXT: for [[IV7:%[i0-9]*]] = 0 to [[ARG_N]] { - for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 - for %i7 = 0 to %N { // not vectorized, can never vectorize + affine.for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 + affine.for %i7 = 0 to %N { // not vectorized, can never vectorize %r70 = affine.apply (d0, d1) -> (d1 + d0) (%i6, %i7) %r71 = affine.apply (d0, d1) -> (d0) (%i6, %i7) %a7 = load %A[%r70, %r71] : memref @@ -74,8 +74,8 @@ func @vec1d(%A : memref, %B : memref) { // CHECK-NEXT: [[APP9_0:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: [[APP9_1:%[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) // CHECK-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP9_0]], [[APP9_1]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> - for %i8 = 0 to %M { // vectorized - for %i9 = 0 to 
%N { + affine.for %i8 = 0 to %M { // vectorized + affine.for %i9 = 0 to %N { %r90 = affine.apply (d0, d1) -> (d1) (%i8, %i9) %r91 = affine.apply (d0, d1) -> (d0 + d1) (%i8, %i9) %a9 = load %A[%r90, %r91] : memref @@ -84,8 +84,8 @@ func @vec1d(%A : memref, %B : memref) { // // CHECK: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} { - for %i10 = 0 to %M { // not vectorized, need per load transposes - for %i11 = 0 to %N { // not vectorized, need per load transposes + affine.for %i10 = 0 to %M { // not vectorized, need per load transposes + affine.for %i11 = 0 to %N { // not vectorized, need per load transposes %r11_0 = affine.apply (d0, d1) -> (d0) (%i10, %i11) %r11_1 = affine.apply (d0, d1) -> (d1) (%i10, %i11) %a11 = load %A[%r11_0, %r11_1] : memref @@ -98,9 +98,9 @@ func @vec1d(%A : memref, %B : memref) { // CHECK: for [[IV12:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV13:%[i0-9]*]] = 0 to %{{[0-9]*}} { // CHECK: for [[IV14:%[i0-9]+]] = 0 to [[ARG_P]] step 128 - for %i12 = 0 to %M { // not vectorized, can never vectorize - for %i13 = 0 to %N { // not vectorized, can never vectorize - for %i14 = 0 to %P { // vectorized + affine.for %i12 = 0 to %M { // not vectorized, can never vectorize + affine.for %i13 = 0 to %N { // not vectorized, can never vectorize + affine.for %i14 = 0 to %P { // vectorized %r14_0 = affine.apply (d0, d1, d2) -> (d1) (%i12, %i13, %i14) %r14_1 = affine.apply (d0, d1, d2) -> (d0 + d1) (%i12, %i13, %i14) %r14_2 = affine.apply (d0, d1, d2) -> (d0 + d2) (%i12, %i13, %i14) @@ -109,24 +109,24 @@ func @vec1d(%A : memref, %B : memref) { } } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - for %i15 = 0 to %M { // not vectorized due to condition below +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + affine.for %i15 = 0 to %M { // not vectorized due to condition below affine.if #set0(%i15) { %a15 = load %A[%cst0, %cst0] : memref } } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { - for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { + affine.for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load %a16 = alloc(%M) : memref> %l16 = load %a16[%i16] : memref> } // -// CHECK: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { +// CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_0]]} : {{.*}} -> vector<128xf32> - for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 - for %i18 = 0 to %M { // vectorized due to scalar -> vector + affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 + affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector %a18 = load %A[%cst0, %cst0] : memref } } @@ -139,24 +139,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 
0 to %M { + affine.for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<128xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref, index, index) -> vector<128xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<128xf32> @@ -188,10 +188,10 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK-LABEL: @vec_rejected func @vec_rejected(%A : memref, %C : memref) { %N = dim %A, 0 : memref - for %i = 0 to %N { + affine.for %i = 0 to %N { // CHECK-NOT: vector %a = load %A[%i, %i] : memref // not vectorized - for %j = 0 to %N { + affine.for %j = 0 to %N { %b = load %A[%i, %j] : memref // may be vectorized // CHECK-NOT: vector %c = addf %a, %b : f32 // not vectorized because %a wasn't diff --git a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir index d847f6bb5ce..59c7483749b 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir @@ -11,13 +11,13 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %1 step 32 // CHECK: for {{.*}} = 0 to %2 step 256 // Example: - // for %i0 = 0 to %0 { - // for %i1 = 0 to %1 step 32 { - // for %i2 = 0 to %2 step 256 { + // affine.for %i0 = 0 to %0 { + // affine.for %i1 = 0 to %1 step 32 { + // affine.for %i2 = 0 to %2 step 256 { // %3 = "vector_transfer_read"(%arg0, %i0, %i1, %i2) : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -27,9 +27,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=1 --test-fastest-varying=0 no // vectorization happens because of loop nesting order . 
- for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -43,24 +43,24 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %C = alloc (%M, %N) : memref %f1 = constant 1.0 : f32 %f2 = constant 2.0 : f32 - for %i0 = 0 to %M { - for %i1 = 0 to %N { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { // CHECK: [[C1:%.*]] = constant splat, 1.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { // CHECK: [[C3:%.*]] = constant splat, 2.000000e+00> : vector<32x256xf32> // CHECK: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref } } - for %i4 = 0 to %M { - for %i5 = 0 to %N { + affine.for %i4 = 0 to %M { + affine.for %i5 = 0 to %N { // CHECK: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref, index, index) -> vector<32x256xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32> diff --git a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir index 1a6bee585ee..08ca27dbeee 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir @@ -7,17 +7,17 @@ func @vec3d(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 { - // CHECK: for %i1 = 0 to %0 { - // CHECK: for %i2 = 0 to %0 step 32 { - // CHECK: for %i3 = 0 to %1 step 64 { - // CHECK: for %i4 = 0 to %2 step 256 { + // CHECK: affine.for %i0 = 0 to %0 { + // CHECK: affine.for %i1 = 0 to %0 { + // CHECK: affine.for %i2 = 0 to %0 step 32 { + // CHECK: affine.for %i3 = 0 to %1 step 64 { + // CHECK: affine.for %i4 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i3, %i4 {permutation_map: #[[map_proj_d0d1d2_d0d1d2]]} : (memref, index, index, index) -> vector<32x64x256xf32> - for %t0 = 0 to %0 { - for %t1 = 0 to %0 { - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %t0 = 0 to %0 { + affine.for %t1 = 0 to %0 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i0, %i1, %i2] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir index 4654ab810df..d00b99f1716 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir @@ -7,13 +7,13 @@ func @vec2d(%A : memref) { %M = dim %A, 0 : memref %N = dim %A, 1 : memref %P = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 - // CHECK: for %i1 = 0 to %1 { - // CHECK: for %i2 = 0 to %2 step 256 + // CHECK: affine.for %i0 = 0 to %0 step 32 + // CHECK: affine.for %i1 = 0 to %1 { + // CHECK: affine.for %i2 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i0, %i1, %i2 {permutation_map: #[[map_proj_d0d1d2_d0d2]]} : (memref, 
index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } @@ -23,9 +23,9 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=2 --test-fastest-varying=0 no // vectorization happens because of loop nesting order - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir index 0eebf816535..a8a8d5d7790 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=2 no // vectorization happens because of loop nesting order. - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: for %i3 = 0 to %0 step 32 - // CHECK: for %i4 = 0 to %1 step 256 - // CHECK: for %i5 = 0 to %2 { + // CHECK: affine.for %i3 = 0 to %0 step 32 + // CHECK: affine.for %i4 = 0 to %1 step 256 + // CHECK: affine.for %i5 = 0 to %2 { // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 { - // CHECK: for %i1 = 0 to %1 { - // CHECK: for %i2 = 0 to %2 step 256 { + // CHECK: affine.for %i0 = 0 to %0 step 32 { + // CHECK: affine.for %i1 = 0 to %1 { + // CHECK: affine.for %i2 = 0 to %2 step 256 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i3 = 0 to %1 step 256 { - // CHECK: for %i4 = 0 to %2 { + // CHECK: affine.for %i3 = 0 to %1 step 256 { + // CHECK: affine.for %i4 = 0 to %2 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i5 = 0 to %2 { + // CHECK: affine.for %i5 = 0 to %2 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d0]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - for %i3 = 0 to %1 { - for %i4 = 0 to %2 { + affine.for %i3 = 0 to %1 { + affine.for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - for %i5 = 0 to %2 { + affine.for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git 
a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir index 1ba563b3442..b8e4e075890 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir @@ -12,20 +12,20 @@ func @vec2d(%A : memref) { // CHECK: for {{.*}} = 0 to %2 { // For the case: --test-fastest-varying=0 --test-fastest-varying=1 no // vectorization happens because of loop nesting order. - for %i0 = 0 to %M { - for %i1 = 0 to %N { - for %i2 = 0 to %P { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to %N { + affine.for %i2 = 0 to %P { %a2 = load %A[%i0, %i1, %i2] : memref } } } - // CHECK: for %i3 = 0 to %0 step 32 - // CHECK: for %i4 = 0 to %1 { - // CHECK: for %i5 = 0 to %2 step 256 + // CHECK: affine.for %i3 = 0 to %0 step 32 + // CHECK: affine.for %i4 = 0 to %1 { + // CHECK: affine.for %i5 = 0 to %2 step 256 // CHECK: {{.*}} = vector_transfer_read %arg0, %i4, %i5, %i3 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i3 = 0 to %M { - for %i4 = 0 to %N { - for %i5 = 0 to %P { + affine.for %i3 = 0 to %M { + affine.for %i4 = 0 to %N { + affine.for %i5 = 0 to %P { %a5 = load %A[%i4, %i5, %i3] : memref } } @@ -37,26 +37,26 @@ func @vec2d_imperfectly_nested(%A : memref) { %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %A, 2 : memref - // CHECK: for %i0 = 0 to %0 step 32 { - // CHECK: for %i1 = 0 to %1 step 256 { - // CHECK: for %i2 = 0 to %2 { + // CHECK: affine.for %i0 = 0 to %0 step 32 { + // CHECK: affine.for %i1 = 0 to %1 step 256 { + // CHECK: affine.for %i2 = 0 to %2 { // CHECK: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i3 = 0 to %1 { - // CHECK: for %i4 = 0 to %2 step 256 { + // CHECK: affine.for %i3 = 0 to %1 { + // CHECK: affine.for %i4 = 0 to %2 step 256 { // CHECK: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - // CHECK: for %i5 = 0 to %2 step 256 { + // CHECK: affine.for %i5 = 0 to %2 step 256 { // CHECK: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d2d1]]} : (memref, index, index, index) -> vector<32x256xf32> - for %i0 = 0 to %0 { - for %i1 = 0 to %1 { - for %i2 = 0 to %2 { + affine.for %i0 = 0 to %0 { + affine.for %i1 = 0 to %1 { + affine.for %i2 = 0 to %2 { %a2 = load %A[%i2, %i1, %i0] : memref } } - for %i3 = 0 to %1 { - for %i4 = 0 to %2 { + affine.for %i3 = 0 to %1 { + affine.for %i4 = 0 to %2 { %a4 = load %A[%i3, %i4, %i0] : memref } - for %i5 = 0 to %2 { + affine.for %i5 = 0 to %2 { %a5 = load %A[%i3, %i5, %i0] : memref } } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 142770f71b6..94edd91004b 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -205,10 +205,10 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref, memref %c = alloc(%K, %N) : memref - // CHECK: for %i0 = - for %i = 0 to %L { - // CHECK-NEXT: for %i1 = - for %j = 0 to 10 { + // CHECK: affine.for %i0 = + affine.for %i = 0 to %L { + // CHECK-NEXT: affine.for %i1 = + affine.for %j = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref // CHECK-NEXT: store %4, %1[%c0, %c0, %i0, %i1, %c0] : memref<4x1024x8x512x?xf32> %v = load %a[%i, %j] : memref @@ -234,8 +234,8 @@ func 
@merge_constants() -> (index, index) { // CHECK-LABEL: func @hoist_constant func @hoist_constant(%arg0: memref<8xi32>) { // CHECK-NEXT: %c42_i32 = constant 42 : i32 - // CHECK-NEXT: for %i0 = 0 to 8 { - for %i0 = 0 to 8 { + // CHECK-NEXT: affine.for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { // CHECK-NEXT: store %c42_i32, %arg0[%i0] %c42_i32 = constant 42 : i32 store %c42_i32, %arg0[%i0] : memref<8xi32> diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index b9197b967ce..b40daa1df6f 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: @test(%arg0: memref) { func @test(%p : memref) { - for %i0 = 0 to 128 { - for %i1 = 0 to 8 { // CHECK: for %i1 = 0 to 8 { + affine.for %i0 = 0 to 128 { + affine.for %i1 = 0 to 8 { // CHECK: affine.for %i1 = 0 to 8 { %0 = constant 4.5 : f32 %1 = constant 1.5 : f32 diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 38d95a8abec..617bd800fed 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -113,8 +113,8 @@ func @down_propagate_for() { // CHECK: %c1_i32 = constant 1 : i32 %0 = constant 1 : i32 - // CHECK-NEXT: for %i0 = 0 to 4 { - for %i = 0 to 4 { + // CHECK-NEXT: affine.for %i0 = 0 to 4 { + affine.for %i = 0 to 4 { // CHECK-NEXT: "foo"(%c1_i32, %c1_i32) : (i32, i32) -> () %1 = constant 1 : i32 "foo"(%0, %1) : (i32, i32) -> () @@ -145,8 +145,8 @@ func @down_propagate() -> i32 { /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_for func @up_propagate_for() -> i32 { - // CHECK: for %i0 = 0 to 4 { - for %i = 0 to 4 { + // CHECK: affine.for %i0 = 0 to 4 { + affine.for %i = 0 to 4 { // CHECK-NEXT: %c1_i32 = constant 1 : i32 // CHECK-NEXT: "foo"(%c1_i32) : (i32) -> () %0 = constant 1 : i32 diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 1b3d35e1154..dfdfb7a14c3 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -42,7 +42,7 @@ func @loop_nest_1d() { // Second DMA transfer. // CHECK: dma_start %1[%c256], %5[%c0], %c256_0, %6[%c0] : memref<512xf32>, memref<256xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait %6[%c0], %c256_0 : memref<1xi32> - // CHECK: for %i0 = 0 to 256 { + // CHECK: affine.for %i0 = 0 to 256 { // CHECK-NEXT: %7 = load %3[%i0] : memref<256xf32, 2> // CHECK: %8 = affine.apply [[MAP_PLUS_256]](%i0) // CHECK: %9 = affine.apply [[MAP_MINUS_256]](%8) @@ -55,7 +55,7 @@ func @loop_nest_1d() { // CHECK-NEXT: dealloc %4 : memref<1xi32> // CHECK-NEXT: dealloc %3 : memref<256xf32, 2> // CHECK-NEXT: return - for %i = 0 to 256 { + affine.for %i = 0 to 256 { load %A[%i] : memref<256 x f32> %idx = affine.apply (d0) -> (d0 + 256)(%i) load %B[%idx] : memref<512 x f32> @@ -82,20 +82,20 @@ func @loop_nest_1d() { // INCOMING DMA for C. 
// CHECK-DAG: dma_start %arg2[%c0, %c0], [[BUFC]][%c0, %c0], %c16384_0, [[TAGC]][%c0] : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> // CHECK-DAG: dma_wait [[TAGC]][%c0], %c16384_0 : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 32 { -// CHECK-NEXT: for %i1 = 0 to 32 { -// CHECK-NEXT: for %i2 = 0 to 32 { -// CHECK-NEXT: for %i3 = 0 to 16 { +// CHECK-NEXT: affine.for %i0 = 0 to 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 32 { +// CHECK-NEXT: affine.for %i2 = 0 to 32 { +// CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %7 = affine.apply #map{{[0-9]+}}(%i1, %i3) // CHECK-NEXT: %8 = load [[BUFB]][%7, %i0] : memref<512x32xf32, 2> // CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: for %i4 = 0 to 16 { +// CHECK-NEXT: affine.for %i4 = 0 to 16 { // CHECK-NEXT: %9 = affine.apply #map{{[0-9]+}}(%i2, %i4) // CHECK-NEXT: %10 = load [[BUFA]][%9, %i1] : memref<512x32xf32, 2> // CHECK-NEXT: "bar"(%10) : (f32) -> () // CHECK-NEXT: } -// CHECK-NEXT: for %i5 = 0 to 16 { +// CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %11 = "abc_compute"() : () -> f32 // CHECK-NEXT: %12 = affine.apply #map{{[0-9]+}}(%i2, %i5) // CHECK-NEXT: %13 = load [[BUFC]][%12, %i0] : memref<512x32xf32, 2> @@ -123,20 +123,20 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // DMAs will be performed at this level (jT is the first loop without a stride). // A and B are read, while C is both read and written. A total of three new buffers // are allocated and existing load's/store's are replaced by accesses to those buffers. - for %jT = 0 to 32 { - for %kT = 0 to 32 { - for %iT = 0 to 32 { - for %kk = 0 to 16 { // k intratile + affine.for %jT = 0 to 32 { + affine.for %kT = 0 to 32 { + affine.for %iT = 0 to 32 { + affine.for %kk = 0 to 16 { // k intratile %k = affine.apply (d0, d1) -> (16*d0 + d1) (%kT, %kk) %v0 = load %B[%k, %jT] : memref<512 x 32 x f32> "foo"(%v0) : (f32) -> () } - for %ii = 0 to 16 { // i intratile. + affine.for %ii = 0 to 16 { // i intratile. %i = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii) %v1 = load %A[%i, %kT] : memref<512 x 32 x f32> "bar"(%v1) : (f32) -> () } - for %ii_ = 0 to 16 { // i intratile. + affine.for %ii_ = 0 to 16 { // i intratile. %v2 = "abc_compute"() : () -> f32 %i_ = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii_) %v3 = load %C[%i_, %jT] : memref<512 x 32 x f32> @@ -155,13 +155,13 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // // CHECK-LABEL: func @loop_nest_modulo() { // CHECK: %0 = alloc() : memref<256x8xf32> -// CHECK-NEXT: for %i0 = 0 to 32 step 4 { +// CHECK-NEXT: affine.for %i0 = 0 to 32 step 4 { // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // CHECK-NEXT: %2 = alloc() : memref<1x2xf32, 2> // CHECK-NEXT: %3 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%1, %c0], %2[%c0, %c0], %c2, %3[%c0] : memref<256x8xf32>, memref<1x2xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait %3[%c0], %c2 : memref<1xi32> -// CHECK-NEXT: for %i1 = 0 to 8 { +// CHECK-NEXT: affine.for %i1 = 0 to 8 { // ... // ... // CHECK: } @@ -171,9 +171,9 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // CHECK-NEXT: return func @loop_nest_modulo() { %A = alloc() : memref<256 x 8 x f32> - for %i = 0 to 32 step 4 { + affine.for %i = 0 to 32 step 4 { // DMAs will be performed at this level (%j is the first unit stride loop) - for %j = 0 to 8 { + affine.for %j = 0 to 8 { %idx = affine.apply (d0) -> (d0 mod 2) (%j) // A buffer of size 32 x 2 will be allocated (original buffer was 256 x 8). 
%v = load %A[%i, %idx] : memref<256 x 8 x f32> @@ -187,17 +187,17 @@ func @loop_nest_modulo() { // CHECK-LABEL: func @loop_nest_tiled() -> memref<256x1024xf32> { func @loop_nest_tiled() -> memref<256x1024xf32> { %0 = alloc() : memref<256x1024xf32> - for %i0 = 0 to 256 step 32 { - for %i1 = 0 to 1024 step 32 { + affine.for %i0 = 0 to 256 step 32 { + affine.for %i1 = 0 to 1024 step 32 { // CHECK: %3 = alloc() : memref<32x32xf32, 2> // CHECK-NEXT: %4 = alloc() : memref<1xi32> // Strided DMA here: 32 x 32 tile in a 256 x 1024 memref. // CHECK-NEXT: dma_start %0[%1, %2], %3[%c0, %c0], %c1024, %4[%c0], %c1024_0, %c32 : memref<256x1024xf32>, memref<32x32xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait -// CHECK-NEXT: for %i2 = #map -// CHECK-NEXT: for %i3 = #map - for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { - for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { +// CHECK-NEXT: affine.for %i2 = #map +// CHECK-NEXT: affine.for %i3 = #map + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { + affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { // CHECK-NEXT: %5 = affine.apply [[MAP_INDEX_DIFF_EVEN]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP_INDEX_DIFF_ODD]](%i0, %i1, %i2, %i3) // CHECK-NEXT: %7 = load %3[%5, %6] : memref<32x32xf32, 2> @@ -218,8 +218,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // No strided DMA needed here. // CHECK: dma_start %arg0[%c1, %c0], %0[%c0, %c0], %c100, %1[%c0] : memref<100x100xf32>, memref<1x100xf32, 2>, // CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32> - for %i = 0 to 100 { - for %j = 0 to ()[s0] -> (s0) ()[%N] { + affine.for %i = 0 to 100 { + affine.for %j = 0 to ()[s0] -> (s0) ()[%N] { // CHECK: %2 = affine.apply [[MAP_D0_MINUS_ONE]](%c1_0, %i1) // CHECK: %3 = affine.apply [[MAP_D1]](%c1_0, %i1) // CHECK-NEXT: %4 = load %0[%2, %3] : memref<1x100xf32, 2> @@ -232,8 +232,8 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // CHECK-LABEL: func @dma_with_symbolic_accesses func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { %N = constant 9 : index - for %i = 0 to 100 { - for %j = 0 to 100 { + affine.for %i = 0 to 100 { + affine.for %j = 0 to 100 { %idy = affine.apply (d0, d1) [s0, s1] -> (d1 + s0 + s1)(%i, %j)[%M, %N] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -243,8 +243,8 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { // CHECK-NEXT: %2 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %0], %1[%c0, %c0], %c10000, %2[%c0] // CHECK-NEXT: dma_wait %2[%c0], %c10000 -// CHECK-NEXT: for %i0 = 0 to 100 { -// CHECK-NEXT: for %i1 = 0 to 100 { +// CHECK-NEXT: affine.for %i0 = 0 to 100 { +// CHECK-NEXT: affine.for %i1 = 0 to 100 { // CHECK-NEXT: %3 = affine.apply [[MAP_SYM_SHIFT]](%i0, %i1)[%arg1, %c9] // CHECK-NEXT: %4 = affine.apply [[MAP_3D_D1]](%arg1, %i0, %3) // CHECK-NEXT: %5 = affine.apply [[MAP_SUB_OFFSET]](%arg1, %i0, %3) @@ -263,8 +263,8 @@ func @dma_with_symbolic_loop_bounds(%A : memref<100x100xf32>, %M : index, %N: in // CHECK-NEXT: %1 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %arg0[%c0, %c0], %0[%c0, %c0], %c10000, %1[%c0] : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait %1[%c0], %c10000 : memref<1xi32> - for %i = 0 to 100 { - for %j = %M to %N { + affine.for %i = 0 to 100 { + affine.for %j = %M to %N { %idy = affine.apply (d1) [s0] -> (d1 + s0)(%j)[%K] load %A[%i, %idy] : memref<100 x 100 x f32> } @@ -278,8 +278,8 @@ func @dma_with_symbolic_loop_bounds(%A : 
memref<100x100xf32>, %M : index, %N: in func @dma_unknown_size(%arg0: memref) { %M = dim %arg0, 0 : memref %N = dim %arg0, 0 : memref - for %i = 0 to %M { - for %j = 0 to %N { + affine.for %i = 0 to %M { + affine.for %j = 0 to %N { // If this loop nest isn't tiled, the access requires a non-constant DMA // size -- not yet implemented. // CHECK: %2 = load %arg0[%i0, %i1] : memref @@ -294,9 +294,9 @@ func @dma_unknown_size(%arg0: memref) { // CHECK-LABEL: func @dma_memref_3d func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { - for %i = 0 to 1024 { - for %j = 0 to 1024 { - for %k = 0 to 1024 { + affine.for %i = 0 to 1024 { + affine.for %j = 0 to 1024 { + affine.for %k = 0 to 1024 { %idx = affine.apply (d0) -> (d0 mod 128)(%i) %idy = affine.apply (d0) -> (d0 mod 128)(%j) %idz = affine.apply (d0) -> (d0 mod 128)(%k) @@ -330,8 +330,8 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // CHECK-LABEL: func @multi_load_store_union() { func @multi_load_store_union() { %A = alloc() : memref<512 x 512 x f32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx = affine.apply (d0) -> (d0 + 64)(%i) %idy = affine.apply (d0) -> (d0 + 128)(%j) %ishift = affine.apply (d0) -> (d0 + 2)(%i) @@ -355,8 +355,8 @@ func @multi_load_store_union() { // CHECK-NEXT: dma_start %0[%c2_1, %c2_2], %1[%c0, %c0], %c170372_3, %2[%c0], %c512_4, %c446_5 : memref<512x512xf32>, memref<382x446xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait %2[%c0], %c170372_3 : memref<1xi32> // CHECK-NEXT: %3 = alloc() : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 256 { -// CHECK-NEXT: for %i1 = 0 to 256 { +// CHECK-NEXT: affine.for %i0 = 0 to 256 { +// CHECK-NEXT: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %4 = affine.apply [[MAP_PLUS_64]](%i0) // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_128]](%i1) // CHECK-NEXT: %6 = affine.apply [[MAP_PLUS_2]](%i0) @@ -395,7 +395,7 @@ func @dma_loop_straightline_interspersed() { %c255 = constant 255 : index %A = alloc() : memref<256 x f32> %v = load %A[%c0] : memref<256 x f32> - for %i = 1 to 255 { + affine.for %i = 1 to 255 { load %A[%i] : memref<256 x f32> } %l = load %A[%c255] : memref<256 x f32> @@ -416,7 +416,7 @@ func @dma_loop_straightline_interspersed() { // CHECK-NEXT: %5 = alloc() : memref<1xi32> // CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 2>, memref<1xi32> // CHECK-NEXT: dma_wait %5[%c0], %c254 : memref<1xi32> -// CHECK-NEXT: for %i0 = 1 to 255 { +// CHECK-NEXT: affine.for %i0 = 1 to 255 { // CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0) // CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 2> // CHECK-NEXT: } @@ -442,10 +442,10 @@ func @dma_loop_straightline_interspersed() { func @dma_mixed_loop_blocks() { %c0 = constant 0 : index %A = alloc() : memref<256 x 256 x vector<8 x f32>> - for %i = 0 to 256 { + affine.for %i = 0 to 256 { %v = load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> "foo"(%v) : (vector<8 x f32>) -> () - for %j = 0 to 256 { + affine.for %j = 0 to 256 { %w = load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> "bar"(%w) : (vector<8 x f32>) -> () } @@ -457,17 +457,17 @@ func @dma_mixed_loop_blocks() { // CHECK-DAG: [[TAG:%[0-9]+]] = alloc() : memref<1xi32> // CHECK: dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], %c65536, [[TAG]][%c0] : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 2>, memref<1xi32> // CHECK-NEXT: dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 256 { +// CHECK-NEXT: affine.for %i0 = 0 to 256 { 
// CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 2> -// CHECK: for %i1 = 0 to 256 { +// CHECK: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 2> // ----- // CHECK-LABEL: func @relative_loop_bounds func @relative_loop_bounds(%arg0: memref<1027xf32>) { - for %i0 = 0 to 1024 { - for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { + affine.for %i0 = 0 to 1024 { + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { %0 = constant 0.0 : f32 store %0, %arg0[%i2] : memref<1027xf32> } @@ -476,8 +476,8 @@ func @relative_loop_bounds(%arg0: memref<1027xf32>) { } // CHECK: [[BUF:%[0-9]+]] = alloc() : memref<1027xf32, 2> // CHECK-NEXT: [[MEM:%[0-9]+]] = alloc() : memref<1xi32> -// CHECK-NEXT: for %i0 = 0 to 1024 { -// CHECK-NEXT: for %i1 = {{#map[0-9]+}}(%i0) to {{#map[0-9]+}}(%i0) { +// CHECK-NEXT: affine.for %i0 = 0 to 1024 { +// CHECK-NEXT: affine.for %i1 = {{#map[0-9]+}}(%i0) to {{#map[0-9]+}}(%i0) { // CHECK-NEXT: %cst = constant 0.000000e+00 : f32 // CHECK-NEXT: store %cst, [[BUF]][%i1] : memref<1027xf32, 2> // CHECK-NEXT: } @@ -487,7 +487,7 @@ func @relative_loop_bounds(%arg0: memref<1027xf32>) { // ---- -// This should create a buffer of size 2 for %arg2. +// This should create a buffer of size 2 affine.for %arg2. #map_lb = (d0) -> (d0) #map_ub = (d0) -> (d0 + 3) @@ -498,9 +498,9 @@ func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, %0 = alloc() : memref<64x1xf32> %1 = alloc() : memref<144x4xf32> %2 = constant 0.0 : f32 - for %i8 = 0 to 9 step 3 { - for %i9 = #map_lb(%i8) to #map_ub(%i8) { - for %i17 = 0 to 64 { + affine.for %i8 = 0 to 9 step 3 { + affine.for %i9 = #map_lb(%i8) to #map_ub(%i8) { + affine.for %i17 = 0 to 64 { %23 = affine.apply #map_acc(%i9) %25 = load %arg2[%23] : memref<2xf32> %26 = affine.apply #map_lb(%i17) @@ -511,11 +511,11 @@ func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, } return %arg1, %arg2 : memref<144x9xf32>, memref<2xf32> } -// CHECK: for %i0 = 0 to 9 step 3 { +// CHECK: affine.for %i0 = 0 to 9 step 3 { // CHECK: [[BUF:%[0-9]+]] = alloc() : memref<2xf32, 2> // CHECK: dma_start %arg2[%4], [[BUF]] // CHECK: dma_wait %6[%c0], %c2_0 : memref<1xi32> -// CHECK: for %i1 = +// CHECK: affine.for %i1 = // ----- @@ -524,17 +524,17 @@ func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, // FAST-MEM-16KB-LABEL: func @load_store_same_memref func @load_store_same_memref(%arg0: memref<256x1024xf32>) { - // FAST-MEM-16KB: for %i0 = 0 to 256 step 4 - for %i0 = 0 to 256 step 4 { + // FAST-MEM-16KB: affine.for %i0 = 0 to 256 step 4 + affine.for %i0 = 0 to 256 step 4 { // FAST-MEM-16KB: [[BUF:%[0-9]+]] = alloc() : memref<4x1024xf32, 2> // FAST-MEM-16KB: dma_start %arg0 // FAST-MEM-16KB-NEXT: dma_wait - // FAST-MEM-16KB: for %i1 - for %i1 = 0 to 1024 step 4 { - // FAST-MEM-16KB: for %i2 - for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { - // FAST-MEM-16KB: for %i3 - for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 4)(%i1) { + // FAST-MEM-16KB: affine.for %i1 + affine.for %i1 = 0 to 1024 step 4 { + // FAST-MEM-16KB: affine.for %i2 + affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { + // FAST-MEM-16KB: affine.for %i3 + affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 4)(%i1) { %3 = load %arg0[%i2, %i3] : memref<256x1024xf32> %4 = mulf %3, %3 : f32 store %4, %arg0[%i2, %i3] : memref<256x1024xf32> @@ -560,12 +560,12 @@ func @load_store_same_memref(%arg0: memref<256x1024xf32>) { #map1 = (d0) 
-> (d0 + 4) // FAST-MEM-16KB-LABEL: func @simple_matmul func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector<64xf32>>, %arg2: memref<8x8xvector<64xf32>>) -> memref<8x8xvector<64xf32>> { - for %i = 0 to 8 step 4 { - for %j = 0 to 8 step 4 { - for %k = 0 to 8 step 4 { - for %ii = #map0(%i) to #map1(%i) { - for %jj = #map0(%j) to #map1(%j) { - for %kk = #map0(%k) to #map1(%k) { + affine.for %i = 0 to 8 step 4 { + affine.for %j = 0 to 8 step 4 { + affine.for %k = 0 to 8 step 4 { + affine.for %ii = #map0(%i) to #map1(%i) { + affine.for %jj = #map0(%j) to #map1(%j) { + affine.for %kk = #map0(%k) to #map1(%k) { %5 = load %arg0[%ii, %kk] : memref<8x8xvector<64xf32>> %6 = load %arg1[%kk, %jj] : memref<8x8xvector<64xf32>> %7 = load %arg2[%ii, %jj] : memref<8x8xvector<64xf32>> @@ -580,18 +580,18 @@ func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector< } return %arg2 : memref<8x8xvector<64xf32>> } -// FAST-MEM-16KB: for %i0 = 0 to 8 step 4 { -// FAST-MEM-16KB: for %i1 = 0 to 8 step 4 { +// FAST-MEM-16KB: affine.for %i0 = 0 to 8 step 4 { +// FAST-MEM-16KB: affine.for %i1 = 0 to 8 step 4 { // FAST-MEM-16KB: dma_start %arg2 // FAST-MEM-16KB: dma_wait -// FAST-MEM-16KB: for %i2 = 0 to 8 step 4 { +// FAST-MEM-16KB: affine.for %i2 = 0 to 8 step 4 { // FAST-MEM-16KB: dma_start %arg0 // FAST-MEM-16KB: dma_wait // FAST-MEM-16KB: dma_start %arg1 // FAST-MEM-16KB: dma_wait -// FAST-MEM-16KB: for %i3 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) { -// FAST-MEM-16KB-NEXT: for %i4 = #map{{[0-9]+}}(%i1) to #map{{[0-9]+}}(%i1) { -// FAST-MEM-16KB-NEXT: for %i5 = #map{{[0-9]+}}(%i2) to #map{{[0-9]+}}(%i2) { +// FAST-MEM-16KB: affine.for %i3 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) { +// FAST-MEM-16KB-NEXT: affine.for %i4 = #map{{[0-9]+}}(%i1) to #map{{[0-9]+}}(%i1) { +// FAST-MEM-16KB-NEXT: affine.for %i5 = #map{{[0-9]+}}(%i2) to #map{{[0-9]+}}(%i2) { // FAST-MEM-16KB: } // FAST-MEM-16KB: } // FAST-MEM-16KB: } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 0e67e1178f8..4d21d006ff1 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -17,13 +17,13 @@ func @should_fuse_raw_dep_for_locality() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -45,23 +45,23 @@ func @should_fuse_reduction_to_pointwise() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x10xf32> %v3 = addf %v0, %v1 : f32 store %v3, %b[%i0] : memref<10xf32> } } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v4 = load %b[%i2] : memref<10xf32> store %v4, %c[%i2] : memref<10xf32> } // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: %5 = load %1[%i0, %i1] : memref<10x10xf32> @@ -89,15 +89,15 @@ func @should_fuse_loop_nests_with_shifts() { %a = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 9 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 9 { + affine.for %i1 = 0 to 9 { %idx = affine.apply (d0) -> (d0 + 1) (%i0) %idy = affine.apply (d0) -> (d0 + 1) (%i1) store %cf7, %a[%idx, %idy] : memref<10x10xf32> } } - for %i2 = 1 to 10 { - for %i3 = 1 to 10 { + affine.for %i2 = 1 to 10 { + affine.for %i3 = 1 to 10 { %v0 = load %a[%i2, %i3] : memref<10x10xf32> } } @@ -110,8 +110,8 @@ func @should_fuse_loop_nests_with_shifts() { // *) Fifth affine apply shifts the loads access function by '-1', because // of the offset induced by reducing the memref shape from 10x10 to 9x9. // NOTE: Should create a private memref with reduced shape 9x9xf32. - // CHECK: for %i0 = 1 to 10 { - // CHECK-NEXT: for %i1 = 1 to 10 { + // CHECK: affine.for %i0 = 1 to 10 { + // CHECK-NEXT: affine.for %i1 = 1 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i0) // CHECK-NEXT: %2 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) // CHECK-NEXT: %3 = affine.apply [[MAP_SHIFT_BY_ONE]](%1) @@ -139,27 +139,27 @@ func @should_fuse_loop_nest() { %b = alloc() : memref<10x10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i0, %i1] : memref<10x10xf32> } } - for %i2 = 0 to 10 { - for %i3 = 0 to 10 { + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 10 { %v0 = load %a[%i3, %i2] : memref<10x10xf32> store %v0, %b[%i2, %i3] : memref<10x10xf32> } } - for %i4 = 0 to 10 { - for %i5 = 0 to 10 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 10 { %v1 = load %b[%i4, %i5] : memref<10x10xf32> } } // Expecting private memref for '%a' first, then private memref for '%b'. // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> - // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<1x1xf32> @@ -190,23 +190,23 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Should fuse first loop (past second loop with no dependences) into third. // Note that fusion creates a private memref '%2' for the fused loop nest. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> @@ -228,13 +228,13 @@ func @should_fuse_all_loops() { %cf7 = constant 7.0 : f32 // Set up flow dependences from first and second loops to third. - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v0 = load %a[%i2] : memref<10xf32> %v1 = load %b[%i2] : memref<10xf32> } @@ -243,7 +243,7 @@ func @should_fuse_all_loops() { // Expecting private memref for '%a' first, then private memref for '%b'. // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) @@ -269,27 +269,27 @@ func @should_fuse_first_and_second_loops() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %c[%i2] : memref<10xf32> } // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %6 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -311,28 +311,28 @@ func @should_not_fuse_would_create_cycle() { // 1) loop0 -> loop1 on memref '%a' // 2) loop0 -> loop2 on memref '%b' // 3) loop1 -> loop2 on memref '%c' - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %cf7, %c[%i2] : memref<10xf32> } // Should not fuse: fusing loop first loop into last would create a cycle. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { + // CHECK: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -347,23 +347,23 @@ func @should_not_fuse_across_waw_dep() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %m[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %m[%i2] : memref<10xf32> } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i2 = 0 to 10 { + // CHECK: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -380,27 +380,27 @@ func @should_fuse_and_move_to_preserve_war_dep() { %b = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %b[%i2] : memref<10xf32> } // Loops '%i1' and '%i2' have no dependences. We can fuse a slice of '%i0' // into '%i2' if we move the fused loop nest before '%i1', which preserves // the WAR dependence from load '%a' in '%i0' to the store '%a' in loop '%i1'. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %2 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %2, %0[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -416,20 +416,20 @@ func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index %v1 = load %m[%c0] : memref<10xf32> // Top-level load to '%m' should prevent fusion. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) @@ -447,13 +447,13 @@ func @should_fuse_no_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) @@ -472,20 +472,20 @@ func @should_not_fuse_if_inst_at_top_level() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index affine.if #set0(%c0) { } // Top-level IfOp should prevent fusion. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } return @@ -501,20 +501,20 @@ func @should_not_fuse_if_inst_in_loop_nest() { %cf7 = constant 7.0 : f32 %c4 = constant 4 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { affine.if #set0(%c4) { } %v0 = load %m[%i1] : memref<10xf32> } // IfOp in ForInst should prevent fusion. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: for %i1 = 0 to 10 { + // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%c4) { // CHECK-NEXT: } // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> @@ -533,24 +533,24 @@ func @permute_and_fuse() { %m = alloc() : memref<10x20x30xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 20 { - for %i2 = 0 to 30 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 20 { + affine.for %i2 = 0 to 30 { store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> } } } - for %i3 = 0 to 30 { - for %i4 = 0 to 10 { - for %i5 = 0 to 20 { + affine.for %i3 = 0 to 30 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 20 { %v0 = load %m[%i4, %i5, %i3] : memref<10x20x30xf32> "foo"(%v0) : (f32) -> () } } } -// CHECK: for %i0 = 0 to 30 { -// CHECK-NEXT: for %i1 = 0 to 10 { -// CHECK-NEXT: for %i2 = 0 to 20 { +// CHECK: affine.for %i0 = 0 to 30 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { +// CHECK-NEXT: affine.for %i2 = 0 to 20 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) @@ -579,22 +579,22 @@ func @permute_and_fuse() { func @fuse_reshape_64_16_4(%in : memref<64xf32>) { %out = alloc() : memref<16x4xf32> - for %i0 = 0 to 64 { + affine.for %i0 = 0 to 64 { %v = load %in[%i0] : memref<64xf32> %idx = affine.apply (d0) -> (d0 floordiv 4) (%i0) %idy = affine.apply (d0) -> (d0 mod 4) (%i0) store %v, %out[%idx, %idy] : memref<16x4xf32> } - for %i1 = 0 to 16 { - for %i2 = 0 to 4 { + affine.for %i1 = 0 to 16 { + affine.for %i2 = 0 to 4 { %w = load %out[%i1, %i2] : memref<16x4xf32> "foo"(%w) : (f32) -> () } } return - // CHECK: for %i0 = - // CHECK-NEXT: for %i1 = + // CHECK: affine.for %i0 = + // CHECK-NEXT: affine.for %i1 = // CHECK-NOT: for // CHECK: } // CHECK-NEXT: } @@ -613,19 +613,19 @@ func @fuse_reshape_16_4_64() { %in = alloc() : memref<16x4xf32> %out = alloc() : memref<64xf32> - for %i0 = 0 to 16 { - for %i1 = 0 to 4 { + affine.for %i0 = 0 to 16 { + affine.for %i1 = 0 to 4 { %v = load %in[%i0, %i1] : memref<16x4xf32> %idx = affine.apply (d0, d1) -> (4*d0 + d1) (%i0, %i1) store %v, %out[%idx] : memref<64xf32> } } - for %i2 = 0 to 64 { + affine.for %i2 = 0 to 64 { %w = load %out[%i2] : memref<64xf32> "foo"(%w) : (f32) -> () } -// CHECK: for %i0 = 0 to 64 { +// CHECK: affine.for %i0 = 0 to 64 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0) // CHECK-NEXT: %3 = affine.apply [[MAP1]](%i0) // CHECK-NEXT: %4 = load %1[%2, %3] : memref<16x4xf32> @@ -651,12 +651,12 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { %live_out = alloc() : memref<64x9xi32> // Initialize input. - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 1 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 1 { %val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32 store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> } @@ -666,8 +666,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { // Convert output coordinates to linear index. 
%a0 = affine.apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj) %0 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 16 * 1))(%a0) @@ -681,8 +681,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { } } - for %i = 0 to 64 { - for %j = 0 to 9 { + affine.for %i = 0 to 64 { + affine.for %j = 0 to 9 { %a = load %out[%i, %j] : memref<64x9xi32> %b = muli %a, %a : i32 store %b, %live_out[%i, %j] : memref<64x9xi32> @@ -718,8 +718,8 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK: %0 = alloc() : memref<1x2x3x3x16x1xi32> // CHECK: %1 = alloc() : memref<1x1xi32> // CHECK: %2 = alloc() : memref<64x9xi32> -// CHECK-NEXT: for %i0 = 0 to 64 { -// CHECK-NEXT: for %i1 = 0 to 9 { +// CHECK-NEXT: affine.for %i0 = 0 to 64 { +// CHECK-NEXT: affine.for %i1 = 0 to 9 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1) @@ -769,14 +769,14 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { %c0 = constant 0.0 : f32 %s = constant 5 : index - for %i0 = 0 to %M { - for %i1 = 0 to (d0) -> (d0 + 5) (%N) { + affine.for %i0 = 0 to %M { + affine.for %i1 = 0 to (d0) -> (d0 + 5) (%N) { store %c0, %m[%i0, %i1] : memref } } - for %i2 = 0 to %M { - for %i3 = 0 to %N { + affine.for %i2 = 0 to %M { + affine.for %i3 = 0 to %N { %idy = affine.apply (d0)[s0] -> (d0 + s0) (%i3)[%s] %v = load %m[%i2, %idy] : memref } @@ -793,16 +793,16 @@ func @should_fuse_reduction_at_depth1() { %a = alloc() : memref<10x100xf32> %b = alloc() : memref<10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 100 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 100 { %v0 = load %b[%i0] : memref<10xf32> %v1 = load %a[%i0, %i1] : memref<10x100xf32> %v2 = "maxf"(%v0, %v1) : (f32, f32) -> f32 store %v2, %b[%i0] : memref<10xf32> } } - for %i2 = 0 to 10 { - for %i3 = 0 to 100 { + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 100 { %v3 = load %b[%i2] : memref<10xf32> %v4 = load %a[%i2, %i3] : memref<10x100xf32> %v5 = subf %v4, %v3 : f32 @@ -813,8 +813,8 @@ func @should_fuse_reduction_at_depth1() { // loop nest, which improves locality and enables subsequence passes to // decrease the reduction memref size and possibly place it in a faster // memory space. 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 100 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 100 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: %4 = load %1[%i0, %i1] : memref<10x100xf32> @@ -822,7 +822,7 @@ func @should_fuse_reduction_at_depth1() { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %5, %0[%6] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 100 { + // CHECK-NEXT: affine.for %i2 = 0 to 100 { // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> // CHECK-NEXT: %9 = load %1[%i0, %i2] : memref<10x100xf32> @@ -844,19 +844,19 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { %a = alloc() : memref<100x16xf32> %b = alloc() : memref<100x16xf32> - for %i0 = 0 to 100 { - for %i1 = 0 to 16 { + affine.for %i0 = 0 to 100 { + affine.for %i1 = 0 to 16 { %v0 = load %a[%i0, %i1] : memref<100x16xf32> "op0"(%v0) : (f32) -> () } - for %i2 = 0 to 16 { + affine.for %i2 = 0 to 16 { %v1 = "op1"() : () -> (f32) store %v1, %b[%i0, %i2] : memref<100x16xf32> } } - for %i3 = 0 to 100 { - for %i4 = 0 to 16 { + affine.for %i3 = 0 to 100 { + affine.for %i4 = 0 to 16 { %v2 = load %b[%i3, %i4] : memref<100x16xf32> "op2"(%v2) : (f32) -> () } @@ -866,18 +866,18 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // destination loop nest at depth2 causes extra computation. Instead, // the fusion algorithm should detect that the source loop should be sliced // at depth 1 and the slice should be inserted at depth 1. - // CHECK: for %i0 = 0 to 100 { - // CHECK-NEXT: for %i1 = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 100 { + // CHECK-NEXT: affine.for %i1 = 0 to 16 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<100x16xf32> // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %3, %0[%4, %5] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %7 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %8 = load %0[%6, %7] : memref<1x16xf32> @@ -897,20 +897,20 @@ func @should_fuse_src_depth1_at_dst_depth2() { %a = alloc() : memref<100xf32> %c0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %c0, %a[%i0] : memref<100xf32> } - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { %a0 = affine.apply (d0, d1) -> (d0 * 10 + d1) (%i1, %i2) %v0 = load %a[%a0] : memref<100xf32> } } // The source loop nest slice loop bound is a function of both destination // loop IVs, so we should slice at depth 1 and insert the slice at depth 2. 
- // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %1) // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> @@ -930,16 +930,16 @@ func @fusion_at_depth0_not_currently_supported() { %0 = alloc() : memref<10xf32> %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = load %0[%c0] : memref<10xf32> } // NOTE: Should shrink memref size to 1 element access by load in dst loop // nest, and make the store in the slice store to the same element. // CHECK-DAG: %0 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%c0] : memref<1xf32> // CHECK-NEXT: %1 = load %0[%c0_0] : memref<1xf32> // CHECK-NEXT: } @@ -965,18 +965,18 @@ func @should_fuse_deep_loop_nests() { %c1 = constant 1 : index %c1_0 = constant 1 : index %cst = constant 0.000000e+00 : f32 - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 10 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 10 { %3 = load %0[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> } } - for %i6 = 0 to 16 { - for %i7 = 0 to 10 { + affine.for %i6 = 0 to 16 { + affine.for %i7 = 0 to 10 { store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> } @@ -985,22 +985,22 @@ func @should_fuse_deep_loop_nests() { } } } - for %i8 = 0 to 3 { - for %i9 = 0 to 3 { - for %i10 = 0 to 2 { - for %i11 = 0 to 2 { - for %i12 = 0 to 3 { - for %i13 = 0 to 3 { - for %i14 = 0 to 2 { - for %i15 = 0 to 2 { - for %i16 = 0 to 16 { - for %i17 = 0 to 10 { + affine.for %i8 = 0 to 3 { + affine.for %i9 = 0 to 3 { + affine.for %i10 = 0 to 2 { + affine.for %i11 = 0 to 2 { + affine.for %i12 = 0 to 3 { + affine.for %i13 = 0 to 3 { + affine.for %i14 = 0 to 2 { + affine.for %i15 = 0 to 2 { + affine.for %i16 = 0 to 16 { + affine.for %i17 = 0 to 10 { %5 = load %0[%i14, %i15, %i12, %i13, %i16, %i17] : memref<2x2x3x3x16x10xf32, 2> } } - for %i18 = 0 to 16 { - for %i19 = 0 to 10 { + affine.for %i18 = 0 to 16 { + affine.for %i19 = 0 to 10 { %6 = load %1[%i10, %i11, %i8, %i9, %i18, %i19] : memref<2x2x3x3x16x10xf32, 2> } @@ -1018,19 +1018,19 @@ func @should_fuse_deep_loop_nests() { // where the destination loops nests have been interchanged. 
// CHECK-DAG: %0 = alloc() : memref<1x1x1x1x16x10xf32, 2> -// CHECK: for %i0 = 0 to 3 { -// CHECK-NEXT: for %i1 = 0 to 3 { -// CHECK-NEXT: for %i2 = 0 to 2 { -// CHECK-NEXT: for %i3 = 0 to 2 { -// CHECK-NEXT: for %i4 = 0 to 3 { -// CHECK-NEXT: for %i5 = 0 to 3 { -// CHECK-NEXT: for %i6 = 0 to 16 { -// CHECK-NEXT: for %i7 = 0 to 10 { +// CHECK: affine.for %i0 = 0 to 3 { +// CHECK-NEXT: affine.for %i1 = 0 to 3 { +// CHECK-NEXT: affine.for %i2 = 0 to 2 { +// CHECK-NEXT: affine.for %i3 = 0 to 2 { +// CHECK-NEXT: affine.for %i4 = 0 to 3 { +// CHECK-NEXT: affine.for %i5 = 0 to 3 { +// CHECK-NEXT: affine.for %i6 = 0 to 16 { +// CHECK-NEXT: affine.for %i7 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 16 { -// CHECK-NEXT: for %i9 = 0 to 10 { +// CHECK-NEXT: affine.for %i8 = 0 to 16 { +// CHECK-NEXT: affine.for %i9 = 0 to 10 { // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) // CHECK-NEXT: %6 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) @@ -1040,15 +1040,15 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: store %cst, %0[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i10 = 0 to 2 { -// CHECK-NEXT: for %i11 = 0 to 2 { -// CHECK-NEXT: for %i12 = 0 to 16 { -// CHECK-NEXT: for %i13 = 0 to 10 { +// CHECK-NEXT: affine.for %i10 = 0 to 2 { +// CHECK-NEXT: affine.for %i11 = 0 to 2 { +// CHECK-NEXT: affine.for %i12 = 0 to 16 { +// CHECK-NEXT: affine.for %i13 = 0 to 10 { // CHECK-NEXT: %10 = load %1[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i14 = 0 to 16 { -// CHECK-NEXT: for %i15 = 0 to 10 { +// CHECK-NEXT: affine.for %i14 = 0 to 16 { +// CHECK-NEXT: affine.for %i15 = 0 to 10 { // CHECK-NEXT: %11 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %12 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) // CHECK-NEXT: %13 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) @@ -1082,17 +1082,17 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 4 { - for %i1 = 0 to 256 { + affine.for %i0 = 0 to 4 { + affine.for %i1 = 0 to 256 { %v0 = load %b[%i0, %i1] : memref<4x256xf32> } - for %i2 = 0 to 256 { + affine.for %i2 = 0 to 256 { store %cf0, %a[%i0, %i2] : memref<4x256xf32> } } - for %d0 = 0 to 4 { - for %d1 = 0 to 16 { + affine.for %d0 = 0 to 4 { + affine.for %d1 = 0 to 16 { %v1 = load %a[%d0, %d1] : memref<4x256xf32> } } @@ -1106,16 +1106,16 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // is reduced from the original shape from 4x256 to 4x16 because of the // data accessed by the load. 
// CHECK-DAG: %0 = alloc() : memref<1x16xf32> - // CHECK: for %i0 = 0 to 4 { - // CHECK-NEXT: for %i1 = 0 to 256 { + // CHECK: affine.for %i0 = 0 to 4 { + // CHECK-NEXT: affine.for %i1 = 0 to 256 { // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0, %i2) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i0, %i2) // CHECK-NEXT: store %cst, %0[%3, %4] : memref<1x16xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i0, %i0, %i3) // CHECK-NEXT: %6 = affine.apply [[MAP1]](%i0, %i0, %i3) // CHECK-NEXT: %7 = load %0[%5, %6] : memref<1x16xf32> @@ -1133,31 +1133,31 @@ func @should_fuse_at_depth1_with_trip_count_20() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - for %i1 = 0 to 5 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 5 { + affine.for %i2 = 0 to 10 { %v0 = load %a[%i2]: memref<100xf32> } - for %i3 = 0 to 10 { - for %i4 = 0 to 20 { + affine.for %i3 = 0 to 10 { + affine.for %i4 = 0 to 20 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 20xf32 // CHECK-DAG: %0 = alloc() : memref<20xf32> - // CHECK: for %i0 = 0 to 5 { - // CHECK-NEXT: for %i1 = 0 to 20 { + // CHECK: affine.for %i0 = 0 to 5 { + // CHECK-NEXT: affine.for %i1 = 0 to 20 { // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i2] : memref<20xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 10 { - // CHECK-NEXT: for %i4 = 0 to 20 { + // CHECK-NEXT: affine.for %i3 = 0 to 10 { + // CHECK-NEXT: affine.for %i4 = 0 to 20 { // CHECK-NEXT: %2 = load %0[%i4] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1174,31 +1174,31 @@ func @should_fuse_at_depth1_with_trip_count_19() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - for %i1 = 0 to 5 { - for %i2 = 0 to 19 { + affine.for %i1 = 0 to 5 { + affine.for %i2 = 0 to 19 { %v0 = load %a[%i2]: memref<100xf32> } - for %i3 = 0 to 10 { - for %i4 = 0 to 10 { + affine.for %i3 = 0 to 10 { + affine.for %i4 = 0 to 10 { %v1 = load %a[%i4]: memref<100xf32> } } } // NOTE: The size of the private memref created for fusion is shrunk to 19xf32 // CHECK-DAG: %0 = alloc() : memref<19xf32> - // CHECK: for %i0 = 0 to 5 { - // CHECK-NEXT: for %i1 = 0 to 19 { + // CHECK: affine.for %i0 = 0 to 5 { + // CHECK-NEXT: affine.for %i1 = 0 to 19 { // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 19 { + // CHECK-NEXT: affine.for %i2 = 0 to 19 { // CHECK-NEXT: %1 = load %0[%i2] : memref<19xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i3 = 0 to 10 { - // CHECK-NEXT: for %i4 = 0 to 10 { + // CHECK-NEXT: affine.for %i3 = 0 to 10 { + // CHECK-NEXT: affine.for %i4 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i4] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1216,26 +1216,26 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { %m = alloc() : memref<100xf32> %cf7 = constant 7.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf7, %m[%i0] : memref<100xf32> } - for %i1 = 0 to 17 { + affine.for %i1 = 0 to 17 { %v0 = 
load %m[%i1] : memref<100xf32> } - for %i2 = 0 to 82 { + affine.for %i2 = 0 to 82 { %v1 = load %m[%i2] : memref<100xf32> } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 17 { + // CHECK: affine.for %i0 = 0 to 17 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 82 { + // CHECK-NEXT: affine.for %i1 = 0 to 82 { // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) @@ -1251,18 +1251,18 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %arg0[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %arg0[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0'. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1275,19 +1275,19 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
// CHECK-DAG: %0 = alloc() : memref<10xf32> - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> @@ -1302,17 +1302,17 @@ func @R3_to_R2_reshape() { %c0 = constant 0 : index - for %i0 = 0 to 2 { - for %i1 = 0 to 3 { - for %i2 = 0 to 16 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 3 { + affine.for %i2 = 0 to 16 { %val = "foo"(%i0, %i1, %i2) : (index, index, index) -> i32 store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> } } } - for %ii = 0 to 32 { - for %jj = 0 to 3 { + affine.for %ii = 0 to 32 { + affine.for %jj = 0 to 3 { %a0 = affine.apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) %idx = affine.apply (d0) -> (d0 floordiv (3 * 16)) (%a0) %v = load %in[%idx, %jj, %c0] @@ -1330,8 +1330,8 @@ func @R3_to_R2_reshape() { // CHECK-LABEL: func @R3_to_R2_reshape() // CHECK-DAG: %0 = alloc() : memref<1x1x1xi32> -// CHECK: for %i0 = 0 to 32 { -// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK: affine.for %i0 = 0 to 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = "foo"(%1, %i1, %c0) : (index, index, index) -> i32 // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i0, %i1, %1, %i1, %c0) @@ -1357,19 +1357,19 @@ func @should_not_fuse_multi_output_producer() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> store %cf7, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> } - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1386,30 +1386,30 @@ func @fusion_preventing_deps_on_middle_loop() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %a[%i0] : memref<10xf32> store %v0, %b[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %a[%i1] : memref<10xf32> %v1 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %b[%i2] : memref<10xf32> store %v2, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%b', because of the WAR dep from '%i0' to '%i1' on memref '%a' and // because of the WAR dep from '%i1' to '%i2' on memref '%c'. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> // CHECK-NEXT: store %3, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> // CHECK-NEXT: store %5, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1429,17 +1429,17 @@ func @should_fuse_and_move_to_preserve_war_dep() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %v0, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 3 { + affine.for %i1 = 0 to 3 { %v2 = load %c[%i1] : memref<10xf32> } - for %i2 = 0 to 5 { + affine.for %i2 = 0 to 5 { store %cf7, %b[%i2] : memref<10xf32> } - for %i3 = 0 to 10 { + affine.for %i3 = 0 to 10 { %v1 = load %a[%i3] : memref<10xf32> store %cf7, %c[%i3] : memref<10xf32> } @@ -1458,10 +1458,10 @@ func @should_fuse_and_move_to_preserve_war_dep() { // if the fused loop nest is inserted between loops '%i1' and '%i2'. // CHECK-DAG: %0 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 3 { + // CHECK: affine.for %i0 = 0 to 3 { // CHECK-NEXT: %3 = load %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %4 = load %1[%i1] : memref<10xf32> // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %4, %0[%5] : memref<1xf32> @@ -1469,7 +1469,7 @@ func @should_fuse_and_move_to_preserve_war_dep() { // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> // CHECK-NEXT: store %cst, %2[%i1] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 5 { + // CHECK-NEXT: affine.for %i2 = 0 to 5 { // CHECK-NEXT: store %cst, %1[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1486,30 +1486,30 @@ func @fusion_preventing_dep_on_constant() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } %cf11 = constant 11.0 : f32 - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%a', because of the WAR dep from '%i0' to '%i1' on memref '%b' and // because of the SSA value dep from '%cf11' def to use in '%i2'. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: for %i2 = 0 to 10 { + // CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %4 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: store %cst_0, %2[%i2] : memref<10xf32> // CHECK-NEXT: } @@ -1529,14 +1529,14 @@ func @should_fuse_and_preserve_dep_on_constant() { %cf7 = constant 7.0 : f32 %cf11 = constant 11.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %v0 = load %b[%i0] : memref<10xf32> store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v2 = load %a[%i2] : memref<10xf32> store %cf11, %c[%i2] : memref<10xf32> } @@ -1546,7 +1546,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // the SSA value dep from '%cf11' def to use in '%i2'. // CHECK: %cst_0 = constant 1.100000e+01 : f32 - // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%4] : memref<1xf32> @@ -1554,7 +1554,7 @@ func @should_fuse_and_preserve_dep_on_constant() { // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> // CHECK-NEXT: store %cst_0, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1572,25 +1572,25 @@ func @should_fuse_and_preserve_dep_on_constant() { func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) { %out = alloc() : memref<64x4xf32> %0 = constant 0.0 : f32 - for %i0 = 0 to 64 { - for %i1 = 0 to 4 { + affine.for %i0 = 0 to 64 { + affine.for %i1 = 0 to 4 { store %0, %out[%i0, %i1] : memref<64x4xf32> } } - for %i2 = 0 to 4 { - for %i3 = 0 to 4 { - for %i4 = 0 to 16 { + affine.for %i2 = 0 to 4 { + affine.for %i3 = 0 to 4 { + affine.for %i4 = 0 to 16 { %1 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i3, %i4) %2 = load %arg1[%1, %i2] : memref<64x4xf32> "op0"(%2) : (f32) -> () } - for %i5 = 0 to 4 { - for %i6 = 0 to 16 { + affine.for %i5 = 0 to 4 { + affine.for %i6 = 0 to 16 { %3 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i5, %i6) %4 = load %arg0[%3, %i3] : memref<64x4xf32> "op1"(%4) : (f32) -> () } - for %i7 = 0 to 16 { + affine.for %i7 = 0 to 16 { %5 = "op2"() : () -> (f32) %6 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i5, %i7) %7 = load %out[%6, %i2] : memref<64x4xf32> @@ -1610,25 +1610,25 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // memref size can be reduced to 128x1xf32. 
// CHECK: %0 = alloc() : memref<64x1xf32> - // CHECK: for %i0 = 0 to 4 { - // CHECK-NEXT: for %i1 = 0 to 64 { + // CHECK: affine.for %i0 = 0 to 4 { + // CHECK-NEXT: affine.for %i1 = 0 to 64 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0) // CHECK-NEXT: store %cst, %0[%1, %2] : memref<64x1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i2 = 0 to 4 { - // CHECK-NEXT: for %i3 = 0 to 16 { + // CHECK-NEXT: affine.for %i2 = 0 to 4 { + // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i2, %i3) // CHECK-NEXT: %4 = load %arg1[%3, %i0] : memref<64x4xf32> // CHECK-NEXT: "op0"(%4) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i4 = 0 to 4 { - // CHECK-NEXT: for %i5 = 0 to 16 { + // CHECK-NEXT: affine.for %i4 = 0 to 4 { + // CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i4, %i5) // CHECK-NEXT: %6 = load %arg0[%5, %i2] : memref<64x4xf32> // CHECK-NEXT: "op1"(%6) : (f32) -> () // CHECK-NEXT: } - // CHECK-NEXT: for %i6 = 0 to 16 { + // CHECK-NEXT: affine.for %i6 = 0 to 16 { // CHECK-NEXT: %7 = "op2"() : () -> f32 // CHECK-NEXT: %8 = affine.apply [[MAP3]](%i4, %i6) // CHECK-NEXT: %9 = affine.apply [[MAP0]](%i0, %8, %i0) @@ -1657,14 +1657,14 @@ func @should_fuse_after_private_memref_creation() { %cf7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %a[%i1] : memref<10xf32> store %v0, %b[%i1] : memref<10xf32> } - for %i2 = 0 to 10 { + affine.for %i2 = 0 to 10 { %v1 = load %a[%i2] : memref<10xf32> store %v1, %b[%i2] : memref<10xf32> } @@ -1675,14 +1675,14 @@ func @should_fuse_after_private_memref_creation() { // private memref, the dependence between '%i0' and '%i1' on memref '%a' no // longer exists, so '%i0' can now be fused into '%i2'. - // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i1, %i1) @@ -1702,12 +1702,12 @@ func @should_fuse_after_one_loop_interchange() { %a = alloc() : memref<10xf32> %cf0 = constant 0.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf0, %a[%i0] : memref<10xf32> } - for %i1 = 0 to 5 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 5 { + affine.for %i2 = 0 to 10 { %v0 = load %a[%i2] : memref<10xf32> store %v0, %a[%i2] : memref<10xf32> } @@ -1719,10 +1719,10 @@ func @should_fuse_after_one_loop_interchange() { // at loop depth 1, because the loop carrying the dependence has been // interchanged and is now at depth 2. 
- // CHECK: for %i0 = 0 to 10 { + // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> - // CHECK-NEXT: for %i1 = 0 to 5 { + // CHECK-NEXT: affine.for %i1 = 0 to 5 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) @@ -1743,16 +1743,16 @@ func @should_fuse_after_two_loop_interchanges() { %a = alloc() : memref<6x8xf32> %cf0 = constant 0.0 : f32 - for %i0 = 0 to 6 { - for %i1 = 0 to 8 { + affine.for %i0 = 0 to 6 { + affine.for %i1 = 0 to 8 { store %cf0, %a[%i0, %i1] : memref<6x8xf32> } } - for %i2 = 0 to 4 { - for %i3 = 0 to 6 { - for %i4 = 0 to 2 { - for %i5 = 0 to 8 { + affine.for %i2 = 0 to 4 { + affine.for %i3 = 0 to 6 { + affine.for %i4 = 0 to 2 { + affine.for %i5 = 0 to 8 { %v0 = load %a[%i3, %i5] : memref<6x8xf32> %v1 = addf %v0, %v0 : f32 store %v1, %a[%i3, %i5] : memref<6x8xf32> @@ -1768,13 +1768,13 @@ func @should_fuse_after_two_loop_interchanges() { // '%i5', then loop '%i0' can be fused at loop depth 2, because the loop // carring the dependences have been interchanged with loops at depth > 2. - // CHECK: for %i0 = 0 to 6 { - // CHECK-NEXT: for %i1 = 0 to 8 { + // CHECK: affine.for %i0 = 0 to 6 { + // CHECK-NEXT: affine.for %i1 = 0 to 8 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) // CHECK-NEXT: store %cst, %0[%1, %2] : memref<1x1xf32> - // CHECK-NEXT: for %i2 = 0 to 4 { - // CHECK-NEXT: for %i3 = 0 to 2 { + // CHECK-NEXT: affine.for %i2 = 0 to 4 { + // CHECK-NEXT: affine.for %i3 = 0 to 2 { // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) // CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32> @@ -1794,17 +1794,17 @@ func @should_fuse_after_two_loop_interchanges() { func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { %cst = constant 0.000000e+00 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cst, %arg0[%i0] : memref<10xf32> } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = load %arg0[%i1] : memref<10xf32> store %1, %arg0[%i1] : memref<10xf32> } return %arg0 : memref<10xf32> // CHECK: %cst = constant 0.000000e+00 : f32 - // CHECK-NEXT: for %i0 = 0 to 10 { + // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32> // CHECK-NEXT: store %0, %arg0[%i0] : memref<10xf32> @@ -1823,20 +1823,20 @@ func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { // CHECK-LABEL: slice_tile func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> { - for %i0 = 0 to 32 { - for %i1 = 0 to 8 { + affine.for %i0 = 0 to 32 { + affine.for %i1 = 0 to 8 { store %0, %arg1[%i0, %i1] : memref<32x8xf32> } } - for %i = 0 to 2 { - for %j = 0 to 8 { - for %k = 0 to 8 { - for %kk = 0 to 16 { + affine.for %i = 0 to 2 { + affine.for %j = 0 to 8 { + affine.for %k = 0 to 8 { + affine.for %kk = 0 to 16 { %1 = affine.apply #map(%k, %kk) %2 = load %arg0[%1, %j] : memref<128x8xf32> %3 = "foo"(%2) : (f32) -> f32 } - for %ii = 0 to 16 { + affine.for %ii = 0 to 16 { %6 = affine.apply #map(%i, %ii) %7 = load %arg1[%6, %j] : memref<32x8xf32> %8 = addf %7, %7 : f32 @@ -1847,18 +1847,18 @@ func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> } return %arg1 : 
memref<32x8xf32> } -// CHECK: for %i0 = 0 to 2 { -// CHECK-NEXT: for %i1 = 0 to 8 { -// CHECK-NEXT: for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) { +// CHECK: affine.for %i0 = 0 to 2 { +// CHECK-NEXT: affine.for %i1 = 0 to 8 { +// CHECK-NEXT: affine.for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) { // CHECK-NEXT: store %arg2, %arg1[%i2, %i1] : memref<32x8xf32> // CHECK-NEXT: } -// CHECK-NEXT: for %i3 = 0 to 8 { -// CHECK-NEXT: for %i4 = 0 to 16 { +// CHECK-NEXT: affine.for %i3 = 0 to 8 { +// CHECK-NEXT: affine.for %i4 = 0 to 16 { // CHECK-NEXT: %0 = affine.apply #map{{[0-9]+}}(%i3, %i4) // CHECK-NEXT: %1 = load %arg0[%0, %i1] : memref<128x8xf32> // CHECK-NEXT: %2 = "foo"(%1) : (f32) -> f32 // CHECK-NEXT: } -// CHECK-NEXT: for %i5 = 0 to 16 { +// CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply #map{{[0-9]+}}(%i0, %i5) // CHECK-NEXT: %4 = load %arg1[%3, %i1] : memref<32x8xf32> // CHECK-NEXT: %5 = addf %4, %4 : f32 @@ -1879,9 +1879,9 @@ func @test_add_slice_bounds() { %cf7 = constant 7.0 : f32 %c0 = constant 0 : index - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) %a1 = affine.apply (d0) -> (d0) (%i0) %a2 = affine.apply (d0, d1) -> (d0 - d1) (%a0, %a1) @@ -1889,17 +1889,17 @@ func @test_add_slice_bounds() { } } } - for %i3 = 0 to 10 { - for %i4 = 0 to 10 { - for %i5 = 0 to 10 { + affine.for %i3 = 0 to 10 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 10 { %v0 = load %a[%c0] : memref<10xf32> } } } -// CHECK: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to 10 { -// CHECK-NEXT: for %i2 = 0 to 10 { +// CHECK: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { +// CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %2 = affine.apply #map2(%i0) // CHECK-NEXT: %3 = affine.apply #map2(%i0) // CHECK-NEXT: %4 = affine.apply #map3(%2, %3) @@ -1907,9 +1907,9 @@ func @test_add_slice_bounds() { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i3 = 0 to 10 { -// CHECK-NEXT: for %i4 = 0 to 10 { -// CHECK-NEXT: for %i5 = 0 to 10 { +// CHECK-NEXT: affine.for %i3 = 0 to 10 { +// CHECK-NEXT: affine.for %i4 = 0 to 10 { +// CHECK-NEXT: affine.for %i5 = 0 to 10 { // CHECK-NEXT: %5 = load %0[%c0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -1926,31 +1926,31 @@ func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf %cst = constant 0.000000e+00 : f32 %cst_0 = constant 1.000000e+00 : f32 %cst_1 = constant 7.000000e+00 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cst_1, %0[%i0, %i1] : memref<10x10xf32> } } - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { store %cst, %arg0[%i2, %i3] : memref<10x10xf32> } } - for %i4 = 0 to 3 { - for %i5 = 0 to 3 { + affine.for %i4 = 0 to 3 { + affine.for %i5 = 0 to 3 { %1 = load %0[%i4, %i5] : memref<10x10xf32> %2 = load %arg0[%i4, %i5] : memref<10x10xf32> %3 = mulf %1, %2 : f32 store %3, %arg0[%i4, %i5] : memref<10x10xf32> } } - for %i6 = 0 to 3 { - for %i7 = 0 to 3 { + affine.for %i6 = 0 to 3 { + affine.for %i7 = 0 to 3 { store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32> } } - for %i8 = 0 to 3 { - for %i9 = 0 to 3 { + affine.for %i8 = 0 to 3 { + affine.for %i9 = 0 to 3 { %4 = load %0[%i8, %i9] : memref<10x10xf32> %5 = load %arg1[%i8, %i9] : memref<10x10xf32> %6 = addf %4, %5 : f32 @@ -1966,8 +1966,8 @@ func 
@should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf // that loop nest '%i0' now has a single user after Pass 2 fused its // two users together). -// CHECK: for %i0 = 0 to 3 { -// CHECK-NEXT: for %i1 = 0 to 3 { +// CHECK: affine.for %i0 = 0 to 3 { +// CHECK-NEXT: affine.for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) // CHECK-NEXT: store %cst_1, %0[%1, %2] : memref<1x1xf32> @@ -2005,14 +2005,14 @@ func @two_matrix_vector_products() { %cf7 = constant 7.0 : f32 // Populate input matrix. - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32> } } // out_vec0 = in_matrix x in_vec0 - for %i2 = 0 to 10 { - for %i3 = 0 to 10 { + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 10 { %v0 = load %in_matrix[%i2, %i3] : memref<10x10xf32> %v1 = load %in_vec0[%i3] : memref<10xf32> %v2 = mulf %v0, %v1 : f32 @@ -2022,8 +2022,8 @@ func @two_matrix_vector_products() { } } // out_vec1 = in_matrix x in_vec1 - for %i4 = 0 to 10 { - for %i5 = 0 to 10 { + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 10 { %v5 = load %in_matrix[%i4, %i5] : memref<10x10xf32> %v6 = load %in_vec1[%i5] : memref<10xf32> %v7 = mulf %v5, %v6 : f32 @@ -2033,13 +2033,13 @@ func @two_matrix_vector_products() { } } -// CHECK: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1, %i0) // CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1, %i0) // CHECK-NEXT: store %cst, %0[%5, %6] : memref<10x1xf32> // CHECK-NEXT: } -// CHECK-NEXT: for %i2 = 0 to 10 { +// CHECK-NEXT: affine.for %i2 = 0 to 10 { // CHECK-NEXT: %7 = affine.apply [[MAP2]](%i0, %i2, %i0) // CHECK-NEXT: %8 = affine.apply [[MAP3]](%i0, %i2, %i0) // CHECK-NEXT: %9 = load %0[%7, %8] : memref<10x1xf32> @@ -2049,7 +2049,7 @@ func @two_matrix_vector_products() { // CHECK-NEXT: %13 = addf %11, %12 : f32 // CHECK-NEXT: store %13, %3[%i0] : memref<10xf32> // CHECK-NEXT: } -// CHECK-NEXT: for %i3 = 0 to 10 { +// CHECK-NEXT: affine.for %i3 = 0 to 10 { // CHECK-NEXT: %14 = affine.apply [[MAP2]](%i0, %i3, %i0) // CHECK-NEXT: %15 = affine.apply [[MAP3]](%i0, %i3, %i0) // CHECK-NEXT: %16 = load %0[%14, %15] : memref<10x1xf32> @@ -2070,28 +2070,28 @@ func @two_matrix_vector_products() { func @should_not_slice_past_slice_barrier() { %0 = alloc() : memref<100x16xf32> - for %i0 = 0 to 100 { - for %i1 = 0 to 16 { + affine.for %i0 = 0 to 100 { + affine.for %i1 = 0 to 16 { %1 = "op1"() : () -> f32 store %1, %0[%i0, %i1] : memref<100x16xf32> } {slice_fusion_barrier: true} } - for %i2 = 0 to 100 { - for %i3 = 0 to 16 { + affine.for %i2 = 0 to 100 { + affine.for %i3 = 0 to 16 { %2 = load %0[%i2, %i3] : memref<100x16xf32> "op2"(%2) : (f32) -> () } } // The 'slice_fusion_barrier' attribute on '%i1' prevents slicing the // iteration space of '%i1' and any enclosing loop nests. 
-// CHECK: for %i0 = 0 to 100 { -// CHECK-NEXT: for %i1 = 0 to 16 { +// CHECK: affine.for %i0 = 0 to 100 { +// CHECK-NEXT: affine.for %i1 = 0 to 16 { // CHECK-NEXT: %1 = "op1"() : () -> f32 // CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0, %i0, %i1) // CHECK-NEXT: %3 = affine.apply [[MAP4]](%i0, %i0, %i1) // CHECK-NEXT: store %1, %0[%2, %3] : memref<1x16xf32> // CHECK-NEXT: } {slice_fusion_barrier: true} -// CHECK-NEXT: for %i2 = 0 to 16 { +// CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %4 = affine.apply [[MAP3]](%i0, %i0, %i2) // CHECK-NEXT: %5 = affine.apply [[MAP4]](%i0, %i0, %i2) // CHECK-NEXT: %6 = load %0[%4, %5] : memref<1x16xf32> @@ -2107,18 +2107,18 @@ func @should_not_slice_past_slice_barrier() { func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, %arg2: memref<9xf32>) { %1 = alloc() : memref<144x4xf32> %2 = constant 0.0 : f32 - for %i2 = 0 to 9 { - for %i3 = 0 to 4 { - for %i5 = 0 to 16 { + affine.for %i2 = 0 to 9 { + affine.for %i3 = 0 to 4 { + affine.for %i5 = 0 to 16 { %7 = affine.apply #map0(%i2, %i5) store %2, %1[%7, %i3] : memref<144x4xf32> } } } - for %i6 = 0 to 9 { - for %i7 = 0 to 9 { - for %i8 = 0 to 4 { - for %i10 = 0 to 16 { + affine.for %i6 = 0 to 9 { + affine.for %i7 = 0 to 9 { + affine.for %i8 = 0 to 4 { + affine.for %i10 = 0 to 16 { %10 = affine.apply #map0(%i6, %i10) %11 = load %1[%10, %i8] : memref<144x4xf32> } @@ -2132,10 +2132,10 @@ func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9x // MAXIMAL-NEXT: #map6 = (d0, d1, d2, d3, d4) -> (-d2 + d4) // MAXIMAL-LABEL: func @fuse_across_dim_mismatch // MAXIMAL: %0 = alloc() : memref<1x1xf32> -// MAXIMAL: for %i0 = 0 to 9 { -// MAXIMAL-NEXT: for %i1 = 0 to 9 { -// MAXIMAL-NEXT: for %i2 = 0 to 4 { -// MAXIMAL-NEXT: for %i3 = 0 to 16 { +// MAXIMAL: affine.for %i0 = 0 to 9 { +// MAXIMAL-NEXT: affine.for %i1 = 0 to 9 { +// MAXIMAL-NEXT: affine.for %i2 = 0 to 4 { +// MAXIMAL-NEXT: affine.for %i3 = 0 to 16 { // MAXIMAL-NEXT: %1 = affine.apply #map4(%i0, %i3) // MAXIMAL-NEXT: %2 = affine.apply #map5(%i0, %i3, %i2, %1, %i2) // MAXIMAL-NEXT: %3 = affine.apply #map6(%i0, %i3, %i2, %1, %i2) @@ -2164,8 +2164,8 @@ func @fuse_across_varying_dims_complex() { %0 = alloc() : memref<2x2x3x3x16x1xf32> %1 = alloc() : memref<64x9xf32> %2 = alloc() : memref<144x4xf32> - for %i0 = 0 to 64 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 64 { + affine.for %i1 = 0 to 9 { %4 = affine.apply #map3(%i0, %i1) %5 = affine.apply #map4(%i0, %i1) %6 = affine.apply #map5(%i0, %i1) @@ -2175,23 +2175,23 @@ func @fuse_across_varying_dims_complex() { store %9, %1[%i0, %i1] : memref<64x9xf32> } } - for %i2 = 0 to 9 { - for %i3 = 0 to 4 { - for %i4 = 0 to 16 { + affine.for %i2 = 0 to 9 { + affine.for %i3 = 0 to 4 { + affine.for %i4 = 0 to 16 { %10 = affine.apply #map10(%i3, %i4) %11 = load %1[%10, %i2] : memref<64x9xf32> } - for %i5 = 0 to 16 { + affine.for %i5 = 0 to 16 { %13 = "bar"() : () -> f32 %14 = affine.apply #map11(%i2, %i5) store %13, %2[%14, %i3] : memref<144x4xf32> } } } - for %i6 = 0 to 9 { - for %i7 = 0 to 9 { - for %i8 = 0 to 4 { - for %i9 = 0 to 16 { + affine.for %i6 = 0 to 9 { + affine.for %i7 = 0 to 9 { + affine.for %i8 = 0 to 4 { + affine.for %i9 = 0 to 16 { %15 = affine.apply #map12(%i8, %i9) %16 = load %1[%15, %i7] : memref<64x9xf32> } @@ -2214,11 +2214,11 @@ func @fuse_across_varying_dims_complex() { // MAXIMAL-NEXT: %c0 = constant 0 : index // MAXIMAL-NEXT: %1 = alloc() : memref<2x2x3x3x16x1xf32> // MAXIMAL-NEXT: %2 = alloc() : memref<144x4xf32> -// MAXIMAL-NEXT: for 
%i0 = 0 to 9 { -// MAXIMAL-NEXT: for %i1 = 0 to 9 { -// MAXIMAL-NEXT: for %i2 = 0 to 4 { -// MAXIMAL-NEXT: for %i3 = 0 to 16 { -// MAXIMAL-NEXT: for %i4 = 0 to 64 { +// MAXIMAL-NEXT: affine.for %i0 = 0 to 9 { +// MAXIMAL-NEXT: affine.for %i1 = 0 to 9 { +// MAXIMAL-NEXT: affine.for %i2 = 0 to 4 { +// MAXIMAL-NEXT: affine.for %i3 = 0 to 16 { +// MAXIMAL-NEXT: affine.for %i4 = 0 to 64 { // MAXIMAL-NEXT: %3 = affine.apply #map5(%i4, %i0) // MAXIMAL-NEXT: %4 = affine.apply #map6(%i4, %i0) // MAXIMAL-NEXT: %5 = affine.apply #map7(%i4, %i0) @@ -2229,14 +2229,14 @@ func @fuse_across_varying_dims_complex() { // MAXIMAL-NEXT: %10 = affine.apply #map11(%i0, %i4, %i0) // MAXIMAL-NEXT: store %8, %0[%9, %10] : memref<64x1xf32> // MAXIMAL-NEXT: } -// MAXIMAL-NEXT: for %i5 = 0 to 4 { -// MAXIMAL-NEXT: for %i6 = 0 to 16 { +// MAXIMAL-NEXT: affine.for %i5 = 0 to 4 { +// MAXIMAL-NEXT: affine.for %i6 = 0 to 16 { // MAXIMAL-NEXT: %11 = affine.apply #map12(%i5, %i6) // MAXIMAL-NEXT: %12 = affine.apply #map10(%i0, %11, %i0) // MAXIMAL-NEXT: %13 = affine.apply #map11(%i0, %11, %i0) // MAXIMAL-NEXT: %14 = load %0[%12, %13] : memref<64x1xf32> // MAXIMAL-NEXT: } -// MAXIMAL-NEXT: for %i7 = 0 to 16 { +// MAXIMAL-NEXT: affine.for %i7 = 0 to 16 { // MAXIMAL-NEXT: %15 = "bar"() : () -> f32 // MAXIMAL-NEXT: %16 = affine.apply #map12(%i0, %i7) // MAXIMAL-NEXT: store %15, %2[%16, %i5] : memref<144x4xf32> @@ -2259,13 +2259,13 @@ func @should_fuse_with_slice_union() { %c0 = constant 0 : index %cf0 = constant 0.0 : f32 - for %i0 = 0 to 100 { + affine.for %i0 = 0 to 100 { store %cf0, %a[%i0]: memref<100xf32> } - for %i1 = 10 to 20 { + affine.for %i1 = 10 to 20 { %v0 = load %a[%i1]: memref<100xf32> - for %i2 = 15 to 25 { + affine.for %i2 = 15 to 25 { %v1 = load %a[%i2]: memref<100xf32> } } @@ -2274,14 +2274,14 @@ func @should_fuse_with_slice_union() { // remapping, and private memref size. The result is that the temporary // memref is reduced from 100xf32 to 15xf32 and properly indexed by // the fused loops based on the union calculation. 
-// CHECK: for %i0 = 10 to 20 { -// CHECK-NEXT: for %i1 = 10 to 25 { +// CHECK: affine.for %i0 = 10 to 20 { +// CHECK-NEXT: affine.for %i1 = 10 to 25 { // CHECK-NEXT: %1 = affine.apply [[MAP3]](%i1) // CHECK-NEXT: store %cst, %0[%1] : memref<15xf32> // CHECK-NEXT: } // CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0) // CHECK-NEXT: %3 = load %0[%2] : memref<15xf32> -// CHECK-NEXT: for %i2 = 15 to 25 { +// CHECK-NEXT: affine.for %i2 = 15 to 25 { // CHECK-NEXT: %4 = affine.apply [[MAP3]](%i2) // CHECK-NEXT: %5 = load %0[%4] : memref<15xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index c18c0fccf4b..ff1fd30ce20 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -9,12 +9,12 @@ // CHECK-DAG: [[UB_INTRA_TILE:#map[0-9]+]] = (d0, d1, d2) -> (d2 + 32, s0, 4096 floordiv s1) // CHECK-LABEL: func @loop_tiling() -// CHECK-NEXT: for %i0 = 0 to 256 step 32 { -// CHECK-NEXT: for %i1 = 0 to 512 step 32 { -// CHECK-NEXT: for %i2 = 0 to 1024 step 32 { -// CHECK-NEXT: for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { -// CHECK-NEXT: for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { -// CHECK-NEXT: for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { +// CHECK-NEXT: affine.for %i0 = 0 to 256 step 32 { +// CHECK-NEXT: affine.for %i1 = 0 to 512 step 32 { +// CHECK-NEXT: affine.for %i2 = 0 to 1024 step 32 { +// CHECK-NEXT: affine.for %i3 = [[IDENTITY]](%i0) to [[MAP0]](%i0) { +// CHECK-NEXT: affine.for %i4 = [[IDENTITY]](%i1) to [[MAP0]](%i1) { +// CHECK-NEXT: affine.for %i5 = [[IDENTITY]](%i2) to [[MAP0]](%i2) { // CHECK-NEXT: "foo"(%i3, %i4, %i5) : (index, index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -22,32 +22,32 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i6 = 0 to 50 step 32 { -// CHECK-NEXT: for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { +// CHECK-NEXT: affine.for %i6 = 0 to 50 step 32 { +// CHECK-NEXT: affine.for %i7 = [[IDENTITY]](%i6) to min [[MAP1]](%i6) { // CHECK-NEXT: "bar"(%i7, %i7) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 21 step 32 { -// CHECK-NEXT: for %i9 = [[IDENTITY]](%i8) to 21 { +// CHECK-NEXT: affine.for %i8 = 0 to 21 step 32 { +// CHECK-NEXT: affine.for %i9 = [[IDENTITY]](%i8) to 21 { // CHECK-NEXT: "foobar"(%i9) : (index) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return func @loop_tiling() { - for %i = 0 to 256 { - for %j = 0 to 512 { - for %k = 0 to 1024 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 512 { + affine.for %k = 0 to 1024 { "foo"(%i, %j, %k) : (index, index, index) -> () } } } - for %x = 0 to 50 { + affine.for %x = 0 to 50 { "bar"(%x, %x) : (index, index) -> () } // Intra-tile loop won't need a min expression. 
- for %y = 0 to 21 { + affine.for %y = 0 to 21 { "foobar"(%y) : (index) -> () } @@ -59,12 +59,12 @@ func @loop_tiling() { // CHECK-LABEL: func @loop_max_min_bound(%arg0: memref, %arg1: index, %arg2: index) { func @loop_max_min_bound(%A : memref, %L : index, %U : index) { %M = dim %A, 0 : memref - for %iTT = max #lb()[%L] to min #ub()[%M, %U] { + affine.for %iTT = max #lb()[%L] to min #ub()[%M, %U] { %out = affine.apply (d0) -> (d0) (%iTT) } return -// CHECK: for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { -// CHECK-NEXT: for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { +// CHECK: affine.for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 { +// CHECK-NEXT: affine.for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) { // CHECK-NEXT: %1 = affine.apply [[IDENTITY]](%i1) // CHECK-NEXT: } // CHECK-NEXT: } @@ -78,9 +78,9 @@ func @loop_max_min_bound(%A : memref, %L : index, %U : index) { // MODEL-LABEL: func @simple_matmul func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector<64xf32>>, %arg2: memref<8x8xvector<64xf32>>) -> memref<8x8xvector<64xf32>> { - for %i = 0 to 256 { - for %j = 0 to 256 { - for %k = 0 to 250 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { + affine.for %k = 0 to 250 { %l = load %arg0[%i, %k] : memref<8x8xvector<64xf32>> %r = load %arg1[%k, %j] : memref<8x8xvector<64xf32>> %o = load %arg2[%i, %j] : memref<8x8xvector<64xf32>> @@ -92,6 +92,6 @@ func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector< } return %arg2 : memref<8x8xvector<64xf32>> } -// MODEL: for %i0 = 0 to 256 step 4 { -// MODEL-NEXT: for %i1 = 0 to 256 step 4 { -// MODEL-NEXT: for %i2 = 0 to 250 step 5 { +// MODEL: affine.for %i0 = 0 to 256 step 4 { +// MODEL-NEXT: affine.for %i1 = 0 to 256 step 4 { +// MODEL-NEXT: affine.for %i2 = 0 to 250 step 5 { diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir index ac55afdf7c2..0cbc6d012c5 100644 --- a/mlir/test/Transforms/lower-affine.mlir +++ b/mlir/test/Transforms/lower-affine.mlir @@ -24,7 +24,7 @@ func @body(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @simple_loop() { - for %i = 1 to 42 { + affine.for %i = 1 to 42 { call @body(%i) : (index) -> () } return @@ -65,9 +65,9 @@ func @post(index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @imperfectly_nested_loops() { - for %i = 0 to 42 { + affine.for %i = 0 to 42 { call @pre(%i) : (index) -> () - for %j = 7 to 56 step 2 { + affine.for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @post(%i) : (index) -> () @@ -122,13 +122,13 @@ func @body3(index, index) -> () // CHECK-NEXT: return // CHECK-NEXT: } func @more_imperfectly_nested_loops() { - for %i = 0 to 42 { + affine.for %i = 0 to 42 { call @pre(%i) : (index) -> () - for %j = 7 to 56 step 2 { + affine.for %j = 7 to 56 step 2 { call @body2(%i, %j) : (index, index) -> () } call @mid(%i) : (index) -> () - for %k = 18 to 37 step 3 { + affine.for %k = 18 to 37 step 3 { call @body3(%i, %k) : (index, index) -> () } call @post(%i) : (index) -> () @@ -161,8 +161,8 @@ func @more_imperfectly_nested_loops() { // CHECK-NEXT: return // CHECK-NEXT: } func @affine_apply_loops_shorthand(%N : index) { - for %i = 0 to %N { - for %j = (d0)[]->(d0)(%i)[] to 42 { + affine.for %i = 0 to %N { + affine.for %j = (d0)[]->(d0)(%i)[] to 42 { call @body2(%i, %j) : (index, index) -> () } } @@ -360,7 +360,7 @@ func @if_for() { // CHECK-NEXT: [[outerEndBB]]: // CHECK-NEXT: br 
[[outerLoopInit:\^bb[0-9]+]] affine.if #set1(%i) { - for %j = 0 to 42 { + affine.for %j = 0 to 42 { affine.if #set2(%j) { call @body2(%i, %j) : (index, index) -> () } @@ -397,9 +397,9 @@ func @if_for() { // CHECK-NEXT: %c1_9 = constant 1 : index // CHECK-NEXT: %16 = addi %9, %c1_9 : index // CHECK-NEXT: br [[outerLoopCond]](%16 : index) - for %k = 0 to 42 { + affine.for %k = 0 to 42 { affine.if #set2(%k) { - for %l = 0 to 42 { + affine.for %l = 0 to 42 { call @body3(%k, %l) : (index, index) -> () } } @@ -446,8 +446,8 @@ func @if_for() { // CHECK-NEXT: return // CHECK-NEXT: } func @loop_min_max(%N : index) { - for %i = 0 to 42 { - for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { + affine.for %i = 0 to 42 { + affine.for %j = max #lbMultiMap(%i)[%N] to min #ubMultiMap(%i)[%N] { call @body2(%i, %j) : (index, index) -> () } } @@ -486,7 +486,7 @@ func @loop_min_max(%N : index) { // CHECK-NEXT: return // CHECK-NEXT: } func @min_reduction_tree(%v : index) { - for %i = 0 to min #map_7_values(%v)[] { + affine.for %i = 0 to min #map_7_values(%v)[] { call @body(%i) : (index) -> () } return diff --git a/mlir/test/Transforms/memref-bound-check.mlir b/mlir/test/Transforms/memref-bound-check.mlir index 8a276d6763d..41f56672135 100644 --- a/mlir/test/Transforms/memref-bound-check.mlir +++ b/mlir/test/Transforms/memref-bound-check.mlir @@ -11,8 +11,8 @@ func @test() { %A = alloc() : memref<9 x 9 x i32> %B = alloc() : memref<111 x i32> - for %i = -1 to 10 { - for %j = -1 to 10 { + affine.for %i = -1 to 10 { + affine.for %j = -1 to 10 { %idx0 = affine.apply (d0, d1) -> (d0)(%i, %j) %idx1 = affine.apply (d0, d1) -> (d1)(%i, %j) // Out of bound access. @@ -27,7 +27,7 @@ func @test() { } } - for %k = 0 to 10 { + affine.for %k = 0 to 10 { // In bound. %u = load %B[%zero] : memref<111 x i32> // Out of bounds. @@ -43,8 +43,8 @@ func @test_mod_floordiv_ceildiv() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -69,8 +69,8 @@ func @test_no_out_of_bounds() { %C = alloc() : memref<257 x i32> %B = alloc() : memref<1 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { // All of these accesses are in bound; check that no errors are emitted. 
// CHECK: %3 = affine.apply {{#map.*}}(%i0, %i1) // CHECK-NEXT: %4 = load %0[%3, %c0] : memref<257x256xi32> @@ -93,8 +93,8 @@ func @mod_div() { %zero = constant 0 : index %A = alloc() : memref<128 x 64 x 64 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) @@ -115,8 +115,8 @@ func @mod_div() { // CHECK-LABEL: func @mod_floordiv_nested() { func @mod_floordiv_nested() { %A = alloc() : memref<256 x 256 x i32> - for %i = 0 to 256 { - for %j = 0 to 256 { + affine.for %i = 0 to 256 { + affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1) -> ((d0 mod 1024) floordiv 4)(%i, %j) %idx1 = affine.apply (d0, d1) -> ((((d1 mod 128) mod 32) ceildiv 4) * 32)(%i, %j) load %A[%idx0, %idx1] : memref<256 x 256 x i32> // expected-error {{'std.load' op memref out of upper bound access along dimension #2}} @@ -128,7 +128,7 @@ func @mod_floordiv_nested() { // CHECK-LABEL: func @test_semi_affine_bailout func @test_semi_affine_bailout(%N : index) { %B = alloc() : memref<10 x i32> - for %i = 0 to 10 { + affine.for %i = 0 to 10 { %idx = affine.apply (d0)[s0] -> (d0 * s0)(%i)[%N] %y = load %B[%idx] : memref<10 x i32> // expected-error@-1 {{getMemRefRegion: compose affine map failed}} @@ -139,7 +139,7 @@ func @test_semi_affine_bailout(%N : index) { // CHECK-LABEL: func @multi_mod_floordiv func @multi_mod_floordiv() { %A = alloc() : memref<2x2xi32> - for %ii = 0 to 64 { + affine.for %ii = 0 to 64 { %idx0 = affine.apply (d0) -> ((d0 mod 147456) floordiv 1152) (%ii) %idx1 = affine.apply (d0) -> (((d0 mod 147456) mod 1152) floordiv 384) (%ii) %v = load %A[%idx0, %idx1] : memref<2x2xi32> @@ -154,8 +154,8 @@ func @delinearize_mod_floordiv() { %out = alloc() : memref<64x9xi32> // Reshape '%in' into '%out'. 
- for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { %a0 = affine.apply (d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) @@ -190,7 +190,7 @@ func @out_of_bounds() { %in = alloc() : memref<1xi32> %c9 = constant 9 : i32 - for %i0 = 10 to 11 { + affine.for %i0 = 10 to 11 { %idy = affine.apply (d0) -> (100 * d0 floordiv 1000) (%i0) store %c9, %in[%idy] : memref<1xi32> // expected-error {{'std.store' op memref out of upper bound access along dimension #1}} } @@ -210,8 +210,8 @@ func @out_of_bounds() { func @test_complex_mod_floordiv(%arg0: memref<4x4x16x1xf32>) { %c0 = constant 0 : index %0 = alloc() : memref<1x2x3x3x16x1xf32> - for %i0 = 0 to 64 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 64 { + affine.for %i1 = 0 to 9 { %2 = affine.apply #map3(%i0, %i1) %3 = affine.apply #map4(%i0, %i1) %4 = affine.apply #map5(%i0, %i1) @@ -231,8 +231,8 @@ func @test_complex_mod_floordiv(%arg0: memref<4x4x16x1xf32>) { func @test_mod_bound() { %0 = alloc() : memref<7 x f32> %1 = alloc() : memref<6 x f32> - for %i0 = 0 to 4096 { - for %i1 = #map0(%i0) to #map1(%i0) { + affine.for %i0 = 0 to 4096 { + affine.for %i1 = #map0(%i0) to #map1(%i0) { load %0[%i1] : memref<7 x f32> load %1[%i1] : memref<6 x f32> // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} @@ -253,13 +253,13 @@ func @test_floordiv_bound() { %1 = alloc() : memref<1026 x f32> %2 = alloc() : memref<4096 x f32> %N = constant 2048 : index - for %i0 = 0 to 4096 { - for %i1 = #map0(%i0) to #map1(%i0) { + affine.for %i0 = 0 to 4096 { + affine.for %i1 = #map0(%i0) to #map1(%i0) { load %0[%i1] : memref<1027 x f32> load %1[%i1] : memref<1026 x f32> // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} } - for %i2 = 0 to #map2(%N) { + affine.for %i2 = 0 to #map2(%N) { // Within bounds. 
%v = load %2[%i2] : memref<4096 x f32> } @@ -277,9 +277,9 @@ func @test_floordiv_bound() { // CHECK-LABEL: func @non_composed_bound_operand func @non_composed_bound_operand(%arg0: memref<1024xf32>) { - for %i0 = 4 to 1028 step 4 { + affine.for %i0 = 4 to 1028 step 4 { %i1 = affine.apply (d0) -> (d0 - 4) (%i0) - for %i2 = #map_lb(%i1) to #map_ub(%i1) { + affine.for %i2 = #map_lb(%i1) to #map_ub(%i1) { %0 = load %arg0[%i2] : memref<1024xf32> } } diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index 710d14c1cf9..ed39d71eefd 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -10,14 +10,14 @@ func @simple_store_load() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: return @@ -30,7 +30,7 @@ func @multi_store_load() { %cf8 = constant 8.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -45,7 +45,7 @@ func @multi_store_load() { // CHECK-NEXT: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %cst_1 = constant 9.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: %1 = mulf %cst_1, %cst_1 : f32 // CHECK-NEXT: } @@ -59,8 +59,8 @@ func @multi_store_load() { func @store_load_affine_apply() -> memref<10x10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10x10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %t0 = affine.apply (d0, d1) -> (d1 + 1)(%i0, %i1) %t1 = affine.apply (d0, d1) -> (d0)(%i0, %i1) %idx0 = affine.apply (d0, d1) -> (d1) (%t0, %t1) @@ -75,8 +75,8 @@ func @store_load_affine_apply() -> memref<10x10xf32> { return %m : memref<10x10xf32> // CHECK: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %0 = alloc() : memref<10x10xf32> -// CHECK-NEXT: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%1, %2) @@ -92,17 +92,17 @@ func @store_load_affine_apply() -> memref<10x10xf32> { func @store_load_nested(%N : index) { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } return // CHECK: %cst = constant 7.000000e+00 : f32 -// CHECK-NEXT: for %i0 = 0 to 10 { -// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { // CHECK-NEXT: %0 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: } @@ -117,12 +117,12 @@ func @multi_store_load_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf8 = constant 8.0 : f32 %m = alloc() : 
memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - for %i2 = 0 to %N { + affine.for %i2 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -138,9 +138,9 @@ func @store_load_store_nested_no_fwd(%N : index) { %cf7 = constant 7.0 : f32 %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -159,16 +159,16 @@ func @multi_store_load_nested_fwd(%N : index) { %cf9 = constant 9.0 : f32 %cf10 = constant 10.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { store %cf8, %m[%i1] : memref<10xf32> } - for %i2 = 0 to %N { + affine.for %i2 = 0 to %N { store %cf9, %m[%i2] : memref<10xf32> } store %cf10, %m[%i0] : memref<10xf32> - for %i3 = 0 to %N { + affine.for %i3 = 0 to %N { // CHECK-NOT: %{{[0-9]+}} = load %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -182,10 +182,10 @@ func @multi_store_load_nested_fwd(%N : index) { func @store_load_no_fwd() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { // CHECK: load %{{[0-9]+}} %v0 = load %m[%i2] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -202,9 +202,9 @@ func @store_load_fwd() { %c0 = constant 0 : index %m = alloc() : memref<10xf32> store %cf7, %m[%c0] : memref<10xf32> - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { - for %i2 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { // CHECK-NOT: load %{{[0-9]}}+ %v0 = load %m[%c0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 @@ -223,9 +223,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %c0 = constant 0 : index %c1 = constant 1 : index %m = alloc() : memref<10xf32> - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> - for %i1 = 0 to %N { + affine.for %i1 = 0 to %N { %v0 = load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 %idx = affine.apply (d0) -> (d0 + 1) (%i0) @@ -236,9 +236,9 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %v3 = load %m[%c1] : memref<10xf32> return %v3 : f32 // CHECK: %0 = alloc() : memref<10xf32> -// CHECK-NEXT: for %i0 = 0 to 10 { +// CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> -// CHECK-NEXT: for %i1 = 0 to %arg0 { +// CHECK-NEXT: affine.for %i1 = 0 to %arg0 { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: %2 = affine.apply [[MAP4]](%i0) // CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32> diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 0accc30630b..00d0e730098 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -13,14 +13,14 @@ func @store_may_execute_before_load() { // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" 
conditionally execute before the load. affine.if #set0(%c0) { - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %v0 = load %m[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -37,13 +37,13 @@ func @dependent_loops() { %cst = constant 7.000000e+00 : f32 // There is a dependence from 0 to 1 at depth 1 (common surrounding loops 0) // because the first loop with the store dominates the second loop. - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { store %cst, %0[%i0] : memref<10xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-note@-3 {{dependence from 0 to 1 at depth 1 = true}} } - for %i1 = 0 to 10 { + affine.for %i1 = 0 to 10 { %1 = load %0[%i1] : memref<10xf32> // expected-note@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 1 to 1 at depth 2 = false}} @@ -231,7 +231,7 @@ func @store_range_load_after_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -254,7 +254,7 @@ func @store_load_func_symbol(%arg0: index, %arg1: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to %arg1 { + affine.for %i0 = 0 to %arg1 { %a0 = affine.apply (d0) -> (d0) (%arg0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, +inf]}} @@ -277,7 +277,7 @@ func @store_range_load_last_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c10 = constant 10 : index - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) // For dependence from 0 to 1, we do not have a loop carried dependence // because only the final write in the loop accesses the same element as the @@ -305,7 +305,7 @@ func @store_range_load_before_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -328,7 +328,7 @@ func @store_range_load_first_in_range() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %c0 = constant 0 : index - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) // Dependence from 0 to 1 at depth 1 is a range because all loads at // constant index zero are reads after first store at index zero during @@ -353,7 +353,7 @@ func @store_range_load_first_in_range() { func @store_plus_3() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 1 to 11 { + affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0 + 3) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -375,7 +375,7 @@ func @store_plus_3() { func @load_minus_2() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for 
%i0 = 2 to 11 { + affine.for %i0 = 2 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) store %c7, %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -397,8 +397,8 @@ func @load_minus_2() { func @perfectly_nested_loops_loop_independent() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 11 { - for %i1 = 0 to 11 { + affine.for %i0 = 0 to 11 { + affine.for %i1 = 0 to 11 { // Dependence from access 0 to 1 is loop independent at depth = 3. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -428,8 +428,8 @@ func @perfectly_nested_loops_loop_independent() { func @perfectly_nested_loops_loop_carried_at_depth1() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 9 { - for %i1 = 0 to 9 { + affine.for %i0 = 0 to 9 { + affine.for %i1 = 0 to 9 { // Dependence from access 0 to 1 is loop carried at depth 1. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -459,8 +459,8 @@ func @perfectly_nested_loops_loop_carried_at_depth1() { func @perfectly_nested_loops_loop_carried_at_depth2() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { // Dependence from access 0 to 1 is loop carried at depth 2. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) @@ -491,8 +491,8 @@ func @one_common_loop() { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 // There is a loop-independent dependence from access 0 to 1 at depth 2. - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) store %c7, %m[%a00, %a01] : memref<10x10xf32> @@ -502,7 +502,7 @@ func @one_common_loop() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = false}} // expected-note@-5 {{dependence from 0 to 1 at depth 2 = true}} } - for %i2 = 0 to 9 { + affine.for %i2 = 0 to 9 { %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i2) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i2) %v0 = load %m[%a10, %a11] : memref<10x10xf32> @@ -525,7 +525,7 @@ func @dependence_cycle() { // Dependences: // *) loop-independent dependence from access 1 to 2 at depth 2. // *) loop-carried dependence from access 3 to 0 at depth 1. 
- for %i0 = 0 to 9 { + affine.for %i0 = 0 to 9 { %a0 = affine.apply (d0) -> (d0) (%i0) %v0 = load %m.a[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -575,8 +575,8 @@ func @dependence_cycle() { func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { %m = alloc() : memref<10x10xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to %arg0 { - for %i1 = 0 to %arg1 { + affine.for %i0 = 0 to %arg0 { + affine.for %i1 = 0 to %arg1 { %a00 = affine.apply (d0, d1) -> (d0 - 1) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1 + 1) (%i0, %i1) %v0 = load %m[%a00, %a01] : memref<10x10xf32> @@ -605,8 +605,8 @@ func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { func @war_raw_waw_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { - for %i1 = 0 to 10 { + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 + 1) (%i1) %v0 = load %m[%a0] : memref<100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} @@ -633,7 +633,7 @@ func @war_raw_waw_deps() { func @mod_deps() { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 10 { + affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 mod 2) (%i0) // Results are conservative here since we currently don't have a way to // represent strided sets in FlatAffineConstraints. @@ -658,8 +658,8 @@ func @loop_nest_depth() { %0 = alloc() : memref<100x100xf32> %c7 = constant 7.0 : f32 - for %i0 = 0 to 128 { - for %i1 = 0 to 8 { + affine.for %i0 = 0 to 128 { + affine.for %i1 = 0 to 8 { store %c7, %0[%i0, %i1] : memref<100x100xf32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -667,10 +667,10 @@ func @loop_nest_depth() { // expected-note@-4 {{dependence from 0 to 1 at depth 1 = true}} } } - for %i2 = 0 to 8 { - for %i3 = 0 to 8 { - for %i4 = 0 to 8 { - for %i5 = 0 to 16 { + affine.for %i2 = 0 to 8 { + affine.for %i3 = 0 to 8 { + affine.for %i4 = 0 to 8 { + affine.for %i5 = 0 to 16 { %8 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i4, %i5) %9 = load %0[%8, %i3] : memref<100x100xf32> // expected-note@-1 {{dependence from 1 to 0 at depth 1 = false}} @@ -693,9 +693,9 @@ func @loop_nest_depth() { func @mod_div_3d() { %M = alloc() : memref<2x2x2xi32> %c0 = constant 0 : i32 - for %i0 = 0 to 8 { - for %i1 = 0 to 8 { - for %i2 = 0 to 8 { + affine.for %i0 = 0 to 8 { + affine.for %i1 = 0 to 8 { + affine.for %i2 = 0 to 8 { %idx0 = affine.apply (d0, d1, d2) -> (d0 floordiv 4) (%i0, %i1, %i2) %idx1 = affine.apply (d0, d1, d2) -> (d1 mod 2) (%i0, %i1, %i2) %idx2 = affine.apply (d0, d1, d2) -> (d2 floordiv 4) (%i0, %i1, %i2) @@ -719,12 +719,12 @@ func @delinearize_mod_floordiv() { %in = alloc() : memref<2x2x3x3x16x1xi32> %out = alloc() : memref<64x9xi32> - for %i0 = 0 to 2 { - for %i1 = 0 to 2 { - for %i2 = 0 to 3 { - for %i3 = 0 to 3 { - for %i4 = 0 to 16 { - for %i5 = 0 to 1 { + affine.for %i0 = 0 to 2 { + affine.for %i1 = 0 to 2 { + affine.for %i2 = 0 to 3 { + affine.for %i3 = 0 to 3 { + affine.for %i4 = 0 to 16 { + affine.for %i5 = 0 to 1 { store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}} @@ -742,8 +742,8 @@ func @delinearize_mod_floordiv() { } } - for %ii = 0 to 64 { - for %jj = 0 to 9 { + affine.for %ii = 0 to 64 { + affine.for %jj = 0 to 9 { %a0 = affine.apply 
(d0, d1) -> (d0 * (9 * 1024) + d1 * 128) (%ii, %jj) %a10 = affine.apply (d0) -> (d0 floordiv (2 * 3 * 3 * 128 * 128)) (%a0) diff --git a/mlir/test/Transforms/parallelism-detection.mlir b/mlir/test/Transforms/parallelism-detection.mlir index 91f8f92c1dc..2d76b2649df 100644 --- a/mlir/test/Transforms/parallelism-detection.mlir +++ b/mlir/test/Transforms/parallelism-detection.mlir @@ -5,11 +5,11 @@ func @loop_nest_3d_outer_two_parallel(%N : index) { %0 = alloc() : memref<1024 x 1024 x vector<64xf32>> %1 = alloc() : memref<1024 x 1024 x vector<64xf32>> %2 = alloc() : memref<1024 x 1024 x vector<64xf32>> - for %i = 0 to %N { + affine.for %i = 0 to %N { // expected-note@-1 {{parallel loop}} - for %j = 0 to %N { + affine.for %j = 0 to %N { // expected-note@-1 {{parallel loop}} - for %k = 0 to %N { + affine.for %k = 0 to %N { %5 = load %0[%i, %k] : memref<1024x1024xvector<64xf32>> %6 = load %1[%k, %j] : memref<1024x1024xvector<64xf32>> %7 = load %2[%i, %j] : memref<1024x1024xvector<64xf32>> diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 1f9383b68ab..d7ae69c7bb0 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -16,13 +16,13 @@ func @loop_nest_dma() { %zero = constant 0 : index %num_elts = constant 128 : index - for %i = 0 to 8 { + affine.for %i = 0 to 8 { dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> dma_wait %tag[%zero], %num_elts : memref<1 x f32> %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> - for %j = 0 to 128 { + affine.for %j = 0 to 128 { "do_more_compute"(%i, %j) : (index, index) -> () } } @@ -34,7 +34,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %3 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: %4 = affine.apply [[MOD_2]](%c0) // CHECK-NEXT: dma_start %0[%c0], %1[%3, %c0], %c128, %2[%4, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> -// CHECK-NEXT: for %i0 = 1 to 8 { +// CHECK-NEXT: affine.for %i0 = 1 to 8 { // CHECK-NEXT: %5 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: %6 = affine.apply [[MOD_2]](%i0) // CHECK-NEXT: dma_start %0[%i0], %1[%5, %i0], %c128, %2[%6, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> @@ -45,7 +45,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1> // CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32 // CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1> -// CHECK-NEXT: for %i1 = 0 to 128 { +// CHECK-NEXT: affine.for %i1 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -56,7 +56,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1> // CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32 // CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1> -// CHECK-NEXT: for %i2 = 0 to 128 { +// CHECK-NEXT: affine.for %i2 = 0 to 128 { // CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: dealloc %2 : memref<2x1xf32> @@ -70,7 +70,7 @@ func @loop_step(%arg0: memref<512xf32>, %arg1: memref<512xf32>) { %c0 = constant 0 : index %c4 = constant 4 : index - for %i0 = 0 to 512 step 4 { + affine.for %i0 = 0 to 512 step 4 { %1 = alloc() : memref<4xf32, 1> %2 = alloc() : memref<1xi32> dma_start %arg0[%i0], %1[%c0], %c4, %2[%c0] @@ -84,7 +84,7 @@ func @loop_step(%arg0: memref<512xf32>, 
// CHECK: %2 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK: %3 = affine.apply [[FLOOR_MOD_2]](%c0) // CHECK-NEXT: dma_start %arg0[%c0], %0[%2, %c0_0], %c4, [[TAG]][%3, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> -// CHECK-NEXT: for %i0 = 4 to 512 step 4 { +// CHECK-NEXT: affine.for %i0 = 4 to 512 step 4 { // CHECK-NEXT: %4 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: %5 = affine.apply [[FLOOR_MOD_2]](%i0) // CHECK-NEXT: dma_start %arg0[%i0], %0[%4, %c0_0], %c4, [[TAG]][%5, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> @@ -117,8 +117,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK-DAG: [[BUF_ARG2:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK-DAG: [[TAG_ARG2:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg2[ - // CHECK: for %i0 = 1 to 8 { - for %i0 = 0 to 8 { + // CHECK: affine.for %i0 = 1 to 8 { + affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -132,8 +132,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK-NEXT for %i1 = 1 to 8 { - for %i1 = 0 to 8 { + // CHECK-NEXT affine.for %i1 = 1 to 8 { + affine.for %i1 = 0 to 8 { %7 = affine.apply #map1(%i0, %i1) %8 = affine.apply #map2(%i1) dma_start %arg0[%7, %c0], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> @@ -145,8 +145,8 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0]] // CHECK: dma_wait [[TAG_ARG1]] - // CHECK-NEXT: for %i2 = 0 to 4 { - for %i2 = 0 to 4 { + // CHECK-NEXT: affine.for %i2 = 0 to 4 { + affine.for %i2 = 0 to 4 { "foo"() : () -> () } } @@ -166,16 +166,16 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: [[TAG_ARG1_NESTED:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ - // CHECK: for %i4 = 1 to 8 { + // CHECK: affine.for %i4 = 1 to 8 { // CHECK: dma_start %arg0[ // CHECK: dma_start %arg1[ // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: for %i5 = 0 to 4 { + // CHECK: affine.for %i5 = 0 to 4 { // CHECK: "foo"() : () -> () // CHECK: dma_wait [[TAG_ARG0_NESTED]] // CHECK: dma_wait [[TAG_ARG1_NESTED]] - // CHECK: for %i6 = 0 to 4 { + // CHECK: affine.for %i6 = 0 to 4 { } return // CHECK: } @@ -202,8 +202,8 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) { // The two DMAs below are dependent (incoming and outgoing on the same // memref) in the same iteration; so no pipelining here. 
// CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 8 { - for %i0 = 0 to 8 { + // CHECK: affine.for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32> dma_wait %5[%c0], %num_elts : memref<2xi32> @@ -223,8 +223,8 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 16 { - for %kTT = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> @@ -247,14 +247,14 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { %tag = alloc() : memref<1 x i32> // CHECK-NOT: dma_start - // CHECK: for %i0 = 0 to 16 { - for %kTT = 0 to 16 { + // CHECK: affine.for %i0 = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } - // Use live out of 'for' inst; no DMA pipelining will be done. + // Use live out of 'affine.for' inst; no DMA pipelining will be done. %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> return %v : f32 // CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> @@ -278,14 +278,14 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { // CHECK: %5 = affine.apply [[MOD_2]](%c0) // CHECK: %6 = affine.apply [[MOD_2]](%c0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%5, %c0_0, %c0_0], %c512, %4[%6, %c0_0] - for %kTT = 0 to 16 { + affine.for %kTT = 0 to 16 { dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : memref<512 x 32 x f32>, memref, memref<1 x i32> dma_wait %tag[%zero], %num_elt : memref<1 x i32> } return -// CHECK-NEXT: for %i0 = 1 to 16 { +// CHECK-NEXT: affine.for %i0 = 1 to 16 { // CHECK: %7 = affine.apply [[MOD_2]](%i0) // CHECK: %8 = affine.apply [[MOD_2]](%i0) // CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%7, %c0_0, %c0_0], %c512, %4[%8, %c0_0] diff --git a/mlir/test/Transforms/simplify-affine-structures.mlir b/mlir/test/Transforms/simplify-affine-structures.mlir index 2459604f369..feb3a99b70b 100644 --- a/mlir/test/Transforms/simplify-affine-structures.mlir +++ b/mlir/test/Transforms/simplify-affine-structures.mlir @@ -73,8 +73,8 @@ // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() { func @test_gaussian_elimination_empty_set0() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (2 == 0)(%i0, %i1) { } @@ -85,8 +85,8 @@ func @test_gaussian_elimination_empty_set0() { // CHECK-LABEL: func @test_gaussian_elimination_empty_set1() { func @test_gaussian_elimination_empty_set1() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (1 >= 0, -1 >= 0) (%i0, %i1) { } @@ -97,8 +97,8 @@ func @test_gaussian_elimination_empty_set1() { // CHECK-LABEL: func @test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_non_empty_set2() { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set1(%i0, %i1) affine.if #set2(%i0, %i1) { } @@ -111,8 +111,8 @@ func 
@test_gaussian_elimination_non_empty_set2() { func @test_gaussian_elimination_empty_set3() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] affine.if #set3(%i0, %i1)[%c7, %c11] { } @@ -125,8 +125,8 @@ func @test_gaussian_elimination_empty_set3() { func @test_gaussian_elimination_non_empty_set4() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set3(%i0, %i1)[%c7, %c11] affine.if #set4(%i0, %i1)[%c7, %c11] { } @@ -139,8 +139,8 @@ func @test_gaussian_elimination_non_empty_set4() { func @test_gaussian_elimination_empty_set5() { %c7 = constant 7 : index %c11 = constant 11 : index - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { // CHECK: #set2(%i0, %i1)[%c7, %c11] affine.if #set5(%i0, %i1)[%c7, %c11] { } @@ -151,8 +151,8 @@ func @test_gaussian_elimination_empty_set5() { // CHECK-LABEL: func @test_fuzz_explosion func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) { - for %i0 = 1 to 10 { - for %i1 = 1 to 100 { + affine.for %i0 = 1 to 10 { + affine.for %i1 = 1 to 100 { affine.if #set_fuzz_virus(%i0, %i1, %arg0, %arg1, %arg2, %arg3) { } } @@ -163,8 +163,8 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i // CHECK-LABEL: func @test_empty_set(%arg0: index) { func @test_empty_set(%N : index) { - for %i = 0 to 10 { - for %j = 0 to 10 { + affine.for %i = 0 to 10 { + affine.for %j = 0 to 10 { // CHECK: affine.if [[SET_EMPTY_2D]](%i0, %i1) affine.if (d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)(%i, %j) { "foo"() : () -> () @@ -198,8 +198,8 @@ func @test_empty_set(%N : index) { } } // The tests below test GCDTightenInequalities(). - for %k = 0 to 10 { - for %l = 0 to 10 { + affine.for %k = 0 to 10 { + affine.for %l = 0 to 10 { // Empty because no multiple of 8 lies between 4 and 7. 
// CHECK: affine.if [[SET_EMPTY_1D]](%i2) affine.if (d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)(%k) { @@ -226,7 +226,7 @@ func @test_empty_set(%N : index) { } } - for %m = 0 to 10 { + affine.for %m = 0 to 10 { // CHECK: affine.if [[SET_EMPTY_1D]](%i{{[0-9]+}}) affine.if (d0) : (d0 mod 2 - 3 == 0) (%m) { "foo"() : () -> () diff --git a/mlir/test/Transforms/strip-debuginfo.mlir b/mlir/test/Transforms/strip-debuginfo.mlir index fdabd5d12e0..181481279d0 100644 --- a/mlir/test/Transforms/strip-debuginfo.mlir +++ b/mlir/test/Transforms/strip-debuginfo.mlir @@ -10,7 +10,7 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { %1 = "foo"() : () -> i32 loc("foo") // CHECK: } loc(unknown) - for %i0 = 0 to 8 { + affine.for %i0 = 0 to 8 { } loc(fused["foo", "mysource.cc":10:8]) // CHECK: } loc(unknown) diff --git a/mlir/test/Transforms/unroll-jam.mlir b/mlir/test/Transforms/unroll-jam.mlir index b872cb687fc..44feeee8f18 100644 --- a/mlir/test/Transforms/unroll-jam.mlir +++ b/mlir/test/Transforms/unroll-jam.mlir @@ -8,13 +8,13 @@ // CHECK-LABEL: func @unroll_jam_imperfect_nest() { func @unroll_jam_imperfect_nest() { // CHECK: %c100 = constant 100 : index - // CHECK-NEXT: for %i0 = 0 to 100 step 2 { - for %i = 0 to 101 { + // CHECK-NEXT: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 101 { // CHECK: %0 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %1 = affine.apply [[MAP_PLUS_1]](%i0) // CHECK-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 %x = "addi32"(%i, %i) : (index, index) -> i32 - for %j = 0 to 17 { + affine.for %j = 0 to 17 { // CHECK: %3 = "addi32"(%i0, %i0) : (index, index) -> i32 // CHECK-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32 // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_1]](%i0) @@ -30,7 +30,7 @@ func @unroll_jam_imperfect_nest() { } // CHECK } // cleanup loop (single iteration) // CHECK: %11 = "addi32"(%c100, %c100) : (index, index) -> i32 - // CHECK-NEXT: for %i2 = 0 to 17 { + // CHECK-NEXT: affine.for %i2 = 0 to 17 { // CHECK-NEXT: %12 = "addi32"(%c100, %c100) : (index, index) -> i32 // CHECK-NEXT: %13 = "addi32"(%12, %12) : (i32, i32) -> i32 // CHECK-NEXT: } @@ -40,20 +40,20 @@ func @unroll_jam_imperfect_nest() { // CHECK-LABEL: func @loop_nest_unknown_count_1(%arg0: index) { func @loop_nest_unknown_count_1(%N : index) { - // CHECK-NEXT: for %i0 = 1 to [[MAP_DIV_OFFSET]]()[%arg0] step 2 { - // CHECK-NEXT: for %i1 = 1 to 100 { + // CHECK-NEXT: affine.for %i0 = 1 to [[MAP_DIV_OFFSET]]()[%arg0] step 2 { + // CHECK-NEXT: affine.for %i1 = 1 to 100 { // CHECK-NEXT: %0 = "foo"() : () -> i32 // CHECK-NEXT: %1 = "foo"() : () -> i32 // CHECK-NEXT: } // CHECK-NEXT: } // A cleanup loop should be generated here. 
- // CHECK-NEXT: for %i2 = [[MAP_DIV_OFFSET]]()[%arg0] to %arg0 { - // CHECK-NEXT: for %i3 = 1 to 100 { + // CHECK-NEXT: affine.for %i2 = [[MAP_DIV_OFFSET]]()[%arg0] to %arg0 { + // CHECK-NEXT: affine.for %i3 = 1 to 100 { // CHECK-NEXT: %2 = "foo"() : () -> i32 // CHECK_NEXT: } // CHECK_NEXT: } - for %i = 1 to %N { - for %j = 1 to 100 { + affine.for %i = 1 to %N { + affine.for %j = 1 to 100 { %x = "foo"() : () -> i32 } } @@ -62,8 +62,8 @@ func @loop_nest_unknown_count_1(%N : index) { // CHECK-LABEL: func @loop_nest_unknown_count_2(%arg0: index) { func @loop_nest_unknown_count_2(%arg : index) { - // CHECK-NEXT: for %i0 = %arg0 to [[M1]]()[%arg0] step 2 { - // CHECK-NEXT: for %i1 = 1 to 100 { + // CHECK-NEXT: affine.for %i0 = %arg0 to [[M1]]()[%arg0] step 2 { + // CHECK-NEXT: affine.for %i1 = 1 to 100 { // CHECK-NEXT: %0 = "foo"(%i0) : (index) -> i32 // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // CHECK-NEXT: %2 = "foo"(%1) : (index) -> i32 @@ -71,11 +71,11 @@ func @loop_nest_unknown_count_2(%arg : index) { // CHECK-NEXT: } // The cleanup loop is a single iteration one and is promoted. // CHECK-NEXT: %3 = affine.apply [[M1]]()[%arg0] - // CHECK-NEXT: for %i2 = 1 to 100 { + // CHECK-NEXT: affine.for %i2 = 1 to 100 { // CHECK-NEXT: %4 = "foo"(%3) : (index) -> i32 // CHECK_NEXT: } - for %i = %arg to ()[s0] -> (s0+9) ()[%arg] { - for %j = 1 to 100 { + affine.for %i = %arg to ()[s0] -> (s0+9) ()[%arg] { + affine.for %j = 1 to 100 { %x = "foo"(%i) : (index) -> i32 } } @@ -84,22 +84,22 @@ func @loop_nest_unknown_count_2(%arg : index) { // CHECK-LABEL: func @loop_nest_symbolic_and_min_upper_bound func @loop_nest_symbolic_and_min_upper_bound(%M : index, %N : index, %K : index) { - for %i = 0 to min ()[s0, s1] -> (s0, s1, 1024)()[%M, %N] { - for %j = 0 to %K { + affine.for %i = 0 to min ()[s0, s1] -> (s0, s1, 1024)()[%M, %N] { + affine.for %j = 0 to %K { "foo"(%i, %j) : (index, index) -> () } } return } -// CHECK-NEXT: for %i0 = 0 to min [[MAP_MULTI_RES]]()[%arg0, %arg1] step 2 { -// CHECK-NEXT: for %i1 = 0 to %arg2 { +// CHECK-NEXT: affine.for %i0 = 0 to min [[MAP_MULTI_RES]]()[%arg0, %arg1] step 2 { +// CHECK-NEXT: affine.for %i1 = 0 to %arg2 { // CHECK-NEXT: "foo"(%i0, %i1) : (index, index) -> () // CHECK-NEXT: %0 = affine.apply #map2(%i0) // CHECK-NEXT: "foo"(%0, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: for %i2 = max [[MAP_MULTI_RES]]()[%arg0, %arg1] to min #map9()[%arg0, %arg1] { -// CHECK-NEXT: for %i3 = 0 to %arg2 { +// CHECK-NEXT: affine.for %i2 = max [[MAP_MULTI_RES]]()[%arg0, %arg1] to min #map9()[%arg0, %arg1] { +// CHECK-NEXT: affine.for %i3 = 0 to %arg2 { // CHECK-NEXT: "foo"(%i2, %i3) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index f676023ee1c..5bbf3b8ce1e 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -25,13 +25,13 @@ // UNROLL-FULL-LABEL: func @loop_nest_simplest() { func @loop_nest_simplest() { - // UNROLL-FULL: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // UNROLL-FULL: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // UNROLL-FULL: %c1_i32 = constant 1 : i32 // UNROLL-FULL-NEXT: %c1_i32_0 = constant 1 : i32 // UNROLL-FULL-NEXT: %c1_i32_1 = constant 1 : i32 // UNROLL-FULL-NEXT: %c1_i32_2 = constant 1 : i32 - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = constant 1 : i32 } } // UNROLL-FULL: } @@ -41,8 +41,8 @@ func @loop_nest_simplest() { // UNROLL-FULL-LABEL: func 
@loop_nest_simple_iv_use() { func @loop_nest_simple_iv_use() { // UNROLL-FULL: %c0 = constant 0 : index - // UNROLL-FULL-NEXT: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // UNROLL-FULL-NEXT: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // UNROLL-FULL: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // UNROLL-FULL: %1 = affine.apply [[MAP0]](%c0) // UNROLL-FULL-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -50,7 +50,7 @@ func @loop_nest_simple_iv_use() { // UNROLL-FULL-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> i32 // UNROLL-FULL: %5 = affine.apply [[MAP2]](%c0) // UNROLL-FULL-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "addi32"(%j, %j) : (index, index) -> i32 } } // UNROLL-FULL: } @@ -61,8 +61,8 @@ func @loop_nest_simple_iv_use() { // UNROLL-FULL-LABEL: func @loop_nest_body_def_use() { func @loop_nest_body_def_use() { // UNROLL-FULL: %c0 = constant 0 : index - // UNROLL-FULL-NEXT: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { + // UNROLL-FULL-NEXT: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { // UNROLL-FULL: %c0_0 = constant 0 : index %c0 = constant 0 : index // UNROLL-FULL: %0 = affine.apply [[MAP0]](%c0) @@ -76,7 +76,7 @@ func @loop_nest_body_def_use() { // UNROLL-FULL-NEXT: %8 = affine.apply [[MAP2]](%c0) // UNROLL-FULL-NEXT: %9 = affine.apply [[MAP0]](%8) // UNROLL-FULL-NEXT: %10 = "addi32"(%9, %c0_0) : (index, index) -> index - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %c0) : (index, index) -> index @@ -89,14 +89,14 @@ func @loop_nest_body_def_use() { func @loop_nest_strided() { // UNROLL-FULL: %c2 = constant 2 : index // UNROLL-FULL-NEXT: %c2_0 = constant 2 : index - // UNROLL-FULL-NEXT: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-FULL-NEXT: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-FULL: %0 = affine.apply [[MAP0]](%c2_0) // UNROLL-FULL-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // UNROLL-FULL-NEXT: %2 = affine.apply [[MAP1]](%c2_0) // UNROLL-FULL-NEXT: %3 = affine.apply [[MAP0]](%2) // UNROLL-FULL-NEXT: %4 = "addi32"(%3, %3) : (index, index) -> index - for %j = 2 to 6 step 2 { + affine.for %j = 2 to 6 step 2 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -109,7 +109,7 @@ func @loop_nest_strided() { // UNROLL-FULL-NEXT: %10 = affine.apply [[MAP3]](%c2) // UNROLL-FULL-NEXT: %11 = affine.apply [[MAP0]](%10) // UNROLL-FULL-NEXT: %12 = "addi32"(%11, %11) : (index, index) -> index - for %k = 2 to 7 step 2 { + affine.for %k = 2 to 7 step 2 { %z = "affine.apply" (%k) { map: (d0) -> (d0 + 1) } : (index) -> (index) %w = "addi32"(%z, %z) : (index, index) -> index @@ -121,8 +121,8 @@ func @loop_nest_strided() { // UNROLL-FULL-LABEL: func @loop_nest_multiple_results() { func @loop_nest_multiple_results() { // UNROLL-FULL: %c0 = constant 0 : index - // UNROLL-FULL-NEXT: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-FULL-NEXT: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-FULL: %0 = affine.apply [[MAP4]](%i0, %c0) // UNROLL-FULL-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // UNROLL-FULL-NEXT: %2 = affine.apply #map{{.*}}(%i0, %c0) @@ -132,7 +132,7 @@ func @loop_nest_multiple_results() { // UNROLL-FULL-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> index // UNROLL-FULL-NEXT: %7 = affine.apply 
#map{{.*}}(%i0, %4) // UNROLL-FULL-NEXT: %8 = "fma"(%7, %5, %5) : (index, index, index) -> (index, index) - for %j = 0 to 2 step 1 { + affine.for %j = 0 to 2 step 1 { %x = affine.apply (d0, d1) -> (d0 + 1) (%i, %j) %y = "addi32"(%x, %x) : (index, index) -> index %z = affine.apply (d0, d1) -> (d0 + 3) (%i, %j) @@ -149,8 +149,8 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // UNROLL-FULL: %c0 = constant 0 : index // UNROLL-FULL-NEXT: %c128 = constant 128 : index %c128 = constant 128 : index - // UNROLL-FULL: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-FULL: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-FULL: %0 = "vld"(%i0) : (index) -> i32 %ld = "vld"(%i) : (index) -> i32 // UNROLL-FULL: %1 = affine.apply [[MAP0]](%c0) @@ -168,7 +168,7 @@ func @loop_nest_seq_imperfect(%a : memref<128x128xf32>) { // UNROLL-FULL-NEXT: %13 = affine.apply [[MAP0]](%12) // UNROLL-FULL-NEXT: %14 = "vmulf"(%12, %13) : (index, index) -> index // UNROLL-FULL-NEXT: %15 = "vaddf"(%14, %14) : (index, index) -> index - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "vmulf"(%j, %x) : (index, index) -> index @@ -197,7 +197,7 @@ func @loop_nest_seq_multiple() { // UNROLL-FULL-NEXT: %5 = affine.apply [[MAP2]](%c0_0) // UNROLL-FULL-NEXT: %6 = affine.apply [[MAP0]](%5) // UNROLL-FULL-NEXT: "mul"(%6, %6) : (index, index) -> () - for %j = 0 to 4 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) "mul"(%x, %x) : (index, index) -> () @@ -205,8 +205,8 @@ func @loop_nest_seq_multiple() { // UNROLL-FULL: %c99 = constant 99 : index %k = constant 99 : index - // UNROLL-FULL: for %i0 = 0 to 100 step 2 { - for %m = 0 to 100 step 2 { + // UNROLL-FULL: affine.for %i0 = 0 to 100 step 2 { + affine.for %m = 0 to 100 step 2 { // UNROLL-FULL: %7 = affine.apply [[MAP0]](%c0) // UNROLL-FULL-NEXT: %8 = affine.apply [[MAP6]](%c0)[%c99] // UNROLL-FULL-NEXT: %9 = affine.apply [[MAP0]](%c0) @@ -218,7 +218,7 @@ func @loop_nest_seq_multiple() { // UNROLL-FULL-NEXT: %15 = affine.apply [[MAP2]](%c0) // UNROLL-FULL-NEXT: %16 = affine.apply [[MAP0]](%15) // UNROLL-FULL-NEXT: %17 = affine.apply [[MAP6]](%15)[%c99] - for %n = 0 to 4 { + affine.for %n = 0 to 4 { %y = "affine.apply" (%n) { map: (d0) -> (d0 + 1) } : (index) -> (index) %z = "affine.apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } : @@ -233,7 +233,7 @@ func @loop_nest_unroll_full() { // UNROLL-FULL-NEXT: %0 = "foo"() : () -> i32 // UNROLL-FULL-NEXT: %1 = "bar"() : () -> i32 // UNROLL-FULL-NEXT: return - for %i = 0 to 1 { + affine.for %i = 0 to 1 { %x = "foo"() : () -> i32 %y = "bar"() : () -> i32 } @@ -242,16 +242,16 @@ func @loop_nest_unroll_full() { // SHORT-LABEL: func @loop_nest_outer_unroll() { func @loop_nest_outer_unroll() { - // SHORT: for %i0 = 0 to 4 { + // SHORT: affine.for %i0 = 0 to 4 { // SHORT-NEXT: %0 = affine.apply [[MAP0]](%i0) // SHORT-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index // SHORT-NEXT: } - // SHORT-NEXT: for %i1 = 0 to 4 { + // SHORT-NEXT: affine.for %i1 = 0 to 4 { // SHORT-NEXT: %2 = affine.apply [[MAP0]](%i1) // SHORT-NEXT: %3 = "addi32"(%2, %2) : (index, index) -> index // SHORT-NEXT: } - for %i = 0 to 2 { - for %j = 0 to 4 { + affine.for %i = 0 to 2 { + affine.for %j = 0 to 4 { %x = "affine.apply" (%j) { map: (d0) -> (d0 + 1) } : (index) -> (index) %y = "addi32"(%x, %x) : (index, index) -> index @@ -275,33 +275,33 @@ func @loop_nest_seq_long() -> i32 { %zero_idx = constant 0 : index - 
// CHECK: for %i0 = 0 to 512 - for %n0 = 0 to 512 { - // CHECK: for %i1 = 0 to 8 - for %n1 = 0 to 8 { + // CHECK: affine.for %i0 = 0 to 512 + affine.for %n0 = 0 to 512 { + // CHECK: affine.for %i1 = 0 to 8 + affine.for %n1 = 0 to 8 { store %one, %A[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %two, %B[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> store %zero, %C[%n0, %n1] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> } } - for %x = 0 to 2 { - for %y = 0 to 2 { - // CHECK: for %i2 - for %i2 = 0 to 8 { - // CHECK-NOT: for %i3 + affine.for %x = 0 to 2 { + affine.for %y = 0 to 2 { + // CHECK: affine.for %i2 + affine.for %i2 = 0 to 8 { + // CHECK-NOT: affine.for %i3 // CHECK: %{{[0-9]+}} = affine.apply %b2 = "affine.apply" (%y, %i2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %z = load %B[%x, %b2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op1"(%z) : (i32) -> () } - for %j1 = 0 to 8 { - for %j2 = 0 to 8 { + affine.for %j1 = 0 to 8 { + affine.for %j2 = 0 to 8 { %a2 = "affine.apply" (%y, %j2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %v203 = load %A[%j1, %a2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> "op2"(%v203) : (i32) -> () } - for %k2 = 0 to 8 { + affine.for %k2 = 0 to 8 { %s0 = "op3"() : () -> i32 %c2 = "affine.apply" (%x, %k2) {map: (d0, d1) -> (16*d0 + d1)} : (index, index) -> index %s1 = load %C[%j1, %c2] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2> @@ -318,8 +318,8 @@ func @loop_nest_seq_long() -> i32 { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_no_cleanup() { func @unroll_unit_stride_no_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 0 to 8 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -333,13 +333,13 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 0 to 8 { + affine.for %j = 0 to 8 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } // empty loop - // UNROLL-BY-4: for %i2 = 0 to 8 { - for %k = 0 to 8 { + // UNROLL-BY-4: affine.for %i2 = 0 to 8 { + affine.for %k = 0 to 8 { } } return @@ -347,8 +347,8 @@ func @unroll_unit_stride_no_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_unit_stride_cleanup() { func @unroll_unit_stride_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 0 to 8 step 4 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -366,7 +366,7 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 0 to 10 { + affine.for %j = 0 to 10 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -376,8 +376,8 @@ func @unroll_unit_stride_cleanup() { // UNROLL-BY-4-LABEL: func @unroll_non_unit_stride_cleanup() { func @unroll_non_unit_stride_cleanup() { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { + // UNROLL-BY-4: 
affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 2 to 42 step 20 { // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32 @@ -395,7 +395,7 @@ func @unroll_non_unit_stride_cleanup() { // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32 // UNROLL-BY-4-NEXT: } - for %j = 2 to 48 step 5 { + affine.for %j = 2 to 48 step 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 %y = "addi32"(%x, %x) : (i32, i32) -> i32 } @@ -408,8 +408,8 @@ func @unroll_non_unit_stride_cleanup() { func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4: %c0 = constant 0 : index // UNROLL-BY-4: %c4 = constant 4 : index - // UNROLL-BY-4: for %i0 = 0 to %arg0 { - for %i = 0 to %N { + // UNROLL-BY-4: affine.for %i0 = 0 to %arg0 { + affine.for %i = 0 to %N { // UNROLL-BY-4: %0 = "addi32"(%c0, %c0) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %1 = affine.apply [[MAP0]](%c0) // UNROLL-BY-4-NEXT: %2 = "addi32"(%1, %1) : (index, index) -> i32 @@ -419,7 +419,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> i32 // UNROLL-BY-4-NEXT: %7 = "addi32"(%c4, %c4) : (index, index) -> i32 // UNROLL-BY-4-NOT: for - for %j = 0 to 5 { + affine.for %j = 0 to 5 { %x = "addi32"(%j, %j) : (index, index) -> i32 } // UNROLL-BY-4-NOT: } } // UNROLL-BY-4: } @@ -431,8 +431,8 @@ func @loop_nest_single_iteration_after_unroll(%N: index) { // No cleanup will be generated here. // UNROLL-BY-4-LABEL: func @loop_nest_operand1() { func @loop_nest_operand1() { -// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = 0 to #map{{[0-9]+}}(%i0) step 4 +// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: affine.for %i1 = 0 to #map{{[0-9]+}}(%i0) step 4 // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -440,8 +440,8 @@ func @loop_nest_operand1() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - for %i = 0 to 100 step 2 { - for %j = 0 to (d0) -> (d0 - d0 mod 4) (%i) { + affine.for %i = 0 to 100 step 2 { + affine.for %j = 0 to (d0) -> (d0 - d0 mod 4) (%i) { %x = "foo"() : () -> i32 } } @@ -451,8 +451,8 @@ func @loop_nest_operand1() { // No cleanup will be generated here. // UNROLL-BY-4-LABEL: func @loop_nest_operand2() { func @loop_nest_operand2() { -// UNROLL-BY-4: for %i0 = 0 to 100 step 2 { -// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { +// UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { +// UNROLL-BY-4-NEXT: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 @@ -460,8 +460,8 @@ func @loop_nest_operand2() { // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: return - for %i = 0 to 100 step 2 { - for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { + affine.for %i = 0 to 100 step 2 { + affine.for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 4) (%i) { %x = "foo"() : () -> i32 } } @@ -472,16 +472,16 @@ func @loop_nest_operand2() { // factor. The cleanup loop happens to be a single iteration one and is promoted. 
// UNROLL-BY-4-LABEL: func @loop_nest_operand3() { func @loop_nest_operand3() { - // UNROLL-BY-4: for %i0 = 0 to 100 step 2 { - for %i = 0 to 100 step 2 { - // UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 step 2 { + affine.for %i = 0 to 100 step 2 { + // UNROLL-BY-4: affine.for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 { // UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 - for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { + affine.for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 9) (%i) { %x = "foo"() : () -> i32 } } // UNROLL-BY-4: } @@ -490,19 +490,19 @@ func @loop_nest_operand3() { // UNROLL-BY-4-LABEL: func @loop_nest_symbolic_bound(%arg0: index) { func @loop_nest_symbolic_bound(%N : index) { - // UNROLL-BY-4: for %i0 = 0 to 100 { - for %i = 0 to 100 { - // UNROLL-BY-4: for %i1 = 0 to #map{{[0-9]+}}()[%arg0] step 4 { + // UNROLL-BY-4: affine.for %i0 = 0 to 100 { + affine.for %i = 0 to 100 { + // UNROLL-BY-4: affine.for %i1 = 0 to #map{{[0-9]+}}()[%arg0] step 4 { // UNROLL-BY-4: %0 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32 // UNROLL-BY-4-NEXT: } // A cleanup loop will be generated here. - // UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { + // UNROLL-BY-4-NEXT: affine.for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 { // UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32 // UNROLL-BY-4_NEXT: } - for %j = 0 to %N { + affine.for %j = 0 to %N { %x = "foo"() : () -> i32 } } @@ -511,18 +511,18 @@ func @loop_nest_symbolic_bound(%N : index) { // UNROLL-BY-4-LABEL: func @loop_nest_symbolic_and_min_upper_bound func @loop_nest_symbolic_and_min_upper_bound(%M : index, %N : index, %K : index) { - for %i = %M to min ()[s0, s1] -> (s0, s1, 1024)()[%N, %K] { + affine.for %i = %M to min ()[s0, s1] -> (s0, s1, 1024)()[%N, %K] { "foo"() : () -> () } return } -// CHECK-NEXT: for %i0 = %arg0 to min [[MAP_TRIP_COUNT_MULTIPLE_FOUR]]()[%arg0, %arg1, %arg2] step 4 { +// CHECK-NEXT: affine.for %i0 = %arg0 to min [[MAP_TRIP_COUNT_MULTIPLE_FOUR]]()[%arg0, %arg1, %arg2] step 4 { // CHECK-NEXT: "foo"() : () -> () // CHECK-NEXT: "foo"() : () -> () // CHECK-NEXT: "foo"() : () -> () // CHECK-NEXT: "foo"() : () -> () // CHECK-NEXT: } -// CHECK-NEXT: for %i1 = max [[MAP_TRIP_COUNT_MULTIPLE_FOUR]]()[%arg0, %arg1, %arg2] to min #map28()[%arg1, %arg2] { +// CHECK-NEXT: affine.for %i1 = max [[MAP_TRIP_COUNT_MULTIPLE_FOUR]]()[%arg0, %arg1, %arg2] to min #map28()[%arg1, %arg2] { // CHECK-NEXT: "foo"() : () -> () // CHECK-NEXT: } // CHECK-NEXT: return @@ -533,22 +533,22 @@ func @loop_nest_symbolic_and_min_upper_bound(%M : index, %N : index, %K : index) func @loop_nest_non_trivial_multiple_unroll_factor(%M : index, %N : index) { %T = affine.apply (d0) -> (4*d0 + 1)(%M) %K = affine.apply (d0) -> (d0 - 1) (%T) - for %i = 0 to min (d0, d1) -> (4 * d0, d1, 1024)(%N, %K) { + affine.for %i = 0 to min (d0, d1) -> (4 * d0, d1, 1024)(%N, %K) { "foo"() : () -> () } return } -// UNROLL-BY-4: for %i0 = 0 to min +// UNROLL-BY-4: affine.for %i0 = 0 to min // UNROLL-BY-4-NOT: for // UNROLL-BY-4: return // UNROLL-BY-4-LABEL: func @loop_nest_non_trivial_multiple_unroll_factor_2 func @loop_nest_non_trivial_multiple_unroll_factor_2(%M : index, %N : 
index) { %K = affine.apply (d0) -> (4*d0) (%M) - for %i = 0 to min ()[s0, s1] -> (4 * s0, s1, 1024)()[%N, %K] { + affine.for %i = 0 to min ()[s0, s1] -> (4 * s0, s1, 1024)()[%N, %K] { "foo"() : () -> () } - // UNROLL-BY-4: for %i0 = 0 to min + // UNROLL-BY-4: affine.for %i0 = 0 to min // UNROLL-BY-4-NEXT: "foo" // UNROLL-BY-4-NEXT: "foo" // UNROLL-BY-4-NEXT: "foo" @@ -560,7 +560,7 @@ func @loop_nest_non_trivial_multiple_unroll_factor_2(%M : index, %N : index) { // UNROLL-BY-1-LABEL: func @unroll_by_one_should_promote_single_iteration_loop() func @unroll_by_one_should_promote_single_iteration_loop() { - for %i = 0 to 1 { + affine.for %i = 0 to 1 { %x = "foo"(%i) : (index) -> i32 } return -- cgit v1.2.3 From af1abcc80b6eb6de2049b6cc79bbeac92f134e58 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 25 Mar 2019 11:13:31 -0700 Subject: Replace usages of "operator->" with "." for the AffineOps. Note: The "operator->" method is a temporary helper for the de-const transition and is gradually being phased out. PiperOrigin-RevId: 240179439 --- mlir/include/mlir/AffineOps/AffineOps.h | 10 +- mlir/include/mlir/EDSC/Builders.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 31 ++--- mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/AffineStructures.cpp | 38 +++--- mlir/lib/Analysis/LoopAnalysis.cpp | 34 ++--- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 32 ++--- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/EDSC/Builders.cpp | 8 +- mlir/lib/EDSC/MLIREmitter.cpp | 8 +- mlir/lib/EDSC/Types.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 19 ++- mlir/lib/Transforms/LoopFusion.cpp | 79 ++++++----- mlir/lib/Transforms/LoopTiling.cpp | 76 +++++------ mlir/lib/Transforms/LoopUnrollAndJam.cpp | 16 +-- mlir/lib/Transforms/LowerAffine.cpp | 48 +++---- mlir/lib/Transforms/PipelineDataTransfer.cpp | 41 +++--- mlir/lib/Transforms/Utils/LoopUtils.cpp | 152 ++++++++++----------- .../Vectorization/VectorizerTestPass.cpp | 6 +- mlir/lib/Transforms/Vectorize.cpp | 17 ++- 22 files changed, 310 insertions(+), 319 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index d8e34dc7248..4f949231674 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -253,18 +253,14 @@ public: unsigned getNumOperands() { return opEnd - opStart; } Value *getOperand(unsigned idx) { - return inst->getInstruction()->getOperand(opStart + idx); + return inst.getInstruction()->getOperand(opStart + idx); } using operand_iterator = AffineForOp::operand_iterator; using operand_range = AffineForOp::operand_range; - operand_iterator operand_begin() { - return inst->getInstruction()->operand_begin() + opStart; - } - operand_iterator operand_end() { - return inst->getInstruction()->operand_begin() + opEnd; - } + operand_iterator operand_begin() { return inst.operand_begin() + opStart; } + operand_iterator operand_end() { return inst.operand_begin() + opEnd; } operand_range getOperands() { return {operand_begin(), operand_end()}; } private: diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 38d3bf32dbc..8a186c28476 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -433,8 +433,8 @@ ValueHandle ValueHandle::create(Args... 
args) { return ValueHandle(inst->getResult(0)); } else if (inst->getNumResults() == 0) { if (auto f = inst->dyn_cast()) { - f->createBody(); - return ValueHandle(f->getInductionVar()); + f.createBody(); + return ValueHandle(f.getInductionVar()); } } llvm_unreachable("unsupported instruction, use an InstructionHandle instead"); diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 92035489e21..2901d815032 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -62,7 +62,7 @@ bool mlir::isValidDim(Value *value) { return true; // Affine apply operation is ok if all of its operands are ok. if (auto op = inst->dyn_cast()) - return op->isValidDim(); + return op.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = inst->dyn_cast()) @@ -488,12 +488,11 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, : AffineApplyOp(); if (affineApply) { // a. Compose affine.apply instructions. - LLVM_DEBUG(affineApply->getInstruction()->print( + LLVM_DEBUG(affineApply.getInstruction()->print( dbgs() << "\nCompose AffineApplyOp recursively: ")); - AffineMap affineApplyMap = affineApply->getAffineMap(); + AffineMap affineApplyMap = affineApply.getAffineMap(); SmallVector affineApplyOperands( - affineApply->getOperands().begin(), - affineApply->getOperands().end()); + affineApply.getOperands().begin(), affineApply.getOperands().end()); AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); LLVM_DEBUG(normalizer.affineMap.print( @@ -684,10 +683,10 @@ void mlir::canonicalizeMapAndOperands( PatternMatchResult SimplifyAffineApply::match(Instruction *op) const { auto apply = op->cast(); - auto map = apply->getAffineMap(); + auto map = apply.getAffineMap(); AffineMap oldMap = map; - SmallVector resultOperands(apply->getOperands()); + SmallVector resultOperands(apply.getOperands()); composeAffineMapAndOperands(&map, &resultOperands); if (map != oldMap) return matchSuccess( @@ -997,7 +996,7 @@ struct AffineForLoopBoundFolder : public RewritePattern { auto forOp = op->cast(); // If the loop has non-constant bounds, it may be foldable. - if (!forOp->hasConstantBounds()) + if (!forOp.hasConstantBounds()) return matchSuccess(); return matchFailure(); @@ -1009,8 +1008,8 @@ struct AffineForLoopBoundFolder : public RewritePattern { // Check to see if each of the operands is the result of a constant. If // so, get the value. If not, ignore it. SmallVector operandConstants; - auto boundOperands = lower ? forOp->getLowerBoundOperands() - : forOp->getUpperBoundOperands(); + auto boundOperands = + lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); for (auto *operand : boundOperands) { Attribute operandCst; matchPattern(operand, m_Constant(&operandCst)); @@ -1018,7 +1017,7 @@ struct AffineForLoopBoundFolder : public RewritePattern { } AffineMap boundMap = - lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); + lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); assert(boundMap.getNumResults() >= 1 && "bound maps should have at least one result"); SmallVector foldedResults; @@ -1034,16 +1033,16 @@ struct AffineForLoopBoundFolder : public RewritePattern { maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } - lower ? forOp->setConstantLowerBound(maxOrMin.getSExtValue()) - : forOp->setConstantUpperBound(maxOrMin.getSExtValue()); + lower ? 
forOp.setConstantLowerBound(maxOrMin.getSExtValue()) + : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); }; // Try to fold the lower bound. - if (!forOp->hasConstantLowerBound()) + if (!forOp.hasConstantLowerBound()) foldLowerOrUpperBound(/*lower=*/true); // Try to fold the upper bound. - if (!forOp->hasConstantUpperBound()) + if (!forOp.hasConstantUpperBound()) foldLowerOrUpperBound(/*lower=*/false); rewriter.updatedRootInPlace(op); @@ -1196,7 +1195,7 @@ void mlir::extractForInductionVars(ArrayRef forInsts, SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) - ivs->push_back(forInst->getInductionVar()); + ivs->push_back(forInst.getInductionVar()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index e2e9ef68b17..a2d511cf965 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -546,7 +546,7 @@ static Block *getCommonBlock(const MemRefAccess &srcAccess, auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); auto forOp = getForInductionVarOwner(commonForValue); assert(forOp && "commonForValue was not an induction variable"); - return forOp->getBody(); + return forOp.getBody(); } // Returns true if the ancestor operation instruction of 'srcAccess' appears diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 64d1809922c..64f18e325de 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -218,9 +218,9 @@ AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, results(results.begin(), results.end()) {} AffineValueMap::AffineValueMap(AffineApplyOp applyOp) - : map(applyOp->getAffineMap()), - operands(applyOp->operand_begin(), applyOp->operand_end()) { - results.push_back(applyOp->getResult()); + : map(applyOp.getAffineMap()), + operands(applyOp.operand_begin(), applyOp.operand_end()) { + results.push_back(applyOp.getResult()); } AffineValueMap::AffineValueMap(AffineBound bound) @@ -721,7 +721,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { addDimId(getNumDimIds(), id); if (failed(this->addAffineForOpDomain(loop))) LLVM_DEBUG( - loop->emitWarning("failed to add domain info to constraint system")); + loop.emitWarning("failed to add domain info to constraint system")); return; } // Add top level symbol. @@ -737,15 +737,15 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { unsigned pos; // Pre-condition for this method. - if (!findId(*forOp->getInductionVar(), &pos)) { + if (!findId(*forOp.getInductionVar(), &pos)) { assert(false && "Value not found"); return failure(); } - int64_t step = forOp->getStep(); + int64_t step = forOp.getStep(); if (step != 1) { - if (!forOp->hasConstantLowerBound()) - forOp->emitWarning("domain conservatively approximated"); + if (!forOp.hasConstantLowerBound()) + forOp.emitWarning("domain conservatively approximated"); else { // Add constraints for the stride. // (iv - lb) % step = 0 can be written as: @@ -753,7 +753,7 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { // Add local variable 'q' and add the above equality. 
// The first constraint is q = (iv - lb) floordiv step SmallVector dividend(getNumCols(), 0); - int64_t lb = forOp->getConstantLowerBound(); + int64_t lb = forOp.getConstantLowerBound(); dividend[pos] = 1; dividend.back() -= lb; addLocalFloorDiv(dividend, step); @@ -767,25 +767,25 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { } } - if (forOp->hasConstantLowerBound()) { - addConstantLowerBound(pos, forOp->getConstantLowerBound()); + if (forOp.hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp.getConstantLowerBound()); } else { // Non-constant lower bound case. - SmallVector lbOperands(forOp->getLowerBoundOperands().begin(), - forOp->getLowerBoundOperands().end()); - if (failed(addLowerOrUpperBound(pos, forOp->getLowerBoundMap(), lbOperands, + SmallVector lbOperands(forOp.getLowerBoundOperands().begin(), + forOp.getLowerBoundOperands().end()); + if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands, /*eq=*/false, /*lower=*/true))) return failure(); } - if (forOp->hasConstantUpperBound()) { - addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); + if (forOp.hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1); return success(); } // Non-constant upper bound case. - SmallVector ubOperands(forOp->getUpperBoundOperands().begin(), - forOp->getUpperBoundOperands().end()); - return addLowerOrUpperBound(pos, forOp->getUpperBoundMap(), ubOperands, + SmallVector ubOperands(forOp.getUpperBoundOperands().begin(), + forOp.getUpperBoundOperands().end()); + return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands, /*eq=*/false, /*lower=*/false); } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index ba9a29177fe..bf8e265dbb8 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -53,12 +53,12 @@ void mlir::buildTripCountMapAndOperands( SmallVectorImpl *tripCountOperands) { int64_t loopSpan; - int64_t step = forOp->getStep(); - FuncBuilder b(forOp->getInstruction()); + int64_t step = forOp.getStep(); + FuncBuilder b(forOp.getInstruction()); - if (forOp->hasConstantBounds()) { - int64_t lb = forOp->getConstantLowerBound(); - int64_t ub = forOp->getConstantUpperBound(); + if (forOp.hasConstantBounds()) { + int64_t lb = forOp.getConstantLowerBound(); + int64_t ub = forOp.getConstantUpperBound(); loopSpan = ub - lb; if (loopSpan < 0) loopSpan = 0; @@ -66,20 +66,20 @@ void mlir::buildTripCountMapAndOperands( tripCountOperands->clear(); return; } - auto lbMap = forOp->getLowerBoundMap(); - auto ubMap = forOp->getUpperBoundMap(); + auto lbMap = forOp.getLowerBoundMap(); + auto ubMap = forOp.getUpperBoundMap(); if (lbMap.getNumResults() != 1) { *map = AffineMap(); return; } - SmallVector lbOperands(forOp->getLowerBoundOperands()); - SmallVector ubOperands(forOp->getUpperBoundOperands()); - auto lb = b.create(forOp->getLoc(), lbMap, lbOperands); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); + auto lb = b.create(forOp.getLoc(), lbMap, lbOperands); SmallVector ubs; ubs.reserve(ubMap.getNumResults()); for (auto ubExpr : ubMap.getResults()) ubs.push_back(b.create( - forOp->getLoc(), + forOp.getLoc(), b.getAffineMap(ubMap.getNumDims(), ubMap.getNumSymbols(), {ubExpr}, {}), ubOperands)); @@ -102,8 +102,8 @@ void mlir::buildTripCountMapAndOperands( for (auto *v : ubs) if (v->use_empty()) v->getDefiningInst()->erase(); - if (lb->use_empty()) - lb->erase(); + 
if (lb.use_empty()) + lb.erase(); } /// Returns the trip count of the loop if it's a constant, None otherwise. This @@ -280,7 +280,7 @@ using VectorizableInstFun = std::function; static bool isVectorizableLoopWithCond(AffineForOp loop, VectorizableInstFun isVectorizableInst) { - auto *forInst = loop->getInstruction(); + auto *forInst = loop.getInstruction(); if (!matcher::isParallelLoop(*forInst) && !matcher::isReductionLoop(*forInst)) { return false; @@ -339,9 +339,9 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( [fastestVaryingDim](AffineForOp loop, Instruction &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); - return load ? isContiguousAccess(*loop->getInductionVar(), load, + return load ? isContiguousAccess(*loop.getInductionVar(), load, fastestVaryingDim) - : isContiguousAccess(*loop->getInductionVar(), store, + : isContiguousAccess(*loop.getInductionVar(), store, fastestVaryingDim); }); return isVectorizableLoopWithCond(loop, fun); @@ -360,7 +360,7 @@ bool mlir::isVectorizableLoop(AffineForOp loop) { // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { - auto *forBody = forOp->getBody(); + auto *forBody = forOp.getBody(); assert(shifts.size() == forBody->getInstructions().size()); // Work backwards over the body of the block so that the shift of a use's diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 4b599c4d4df..878b713ded1 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -53,7 +53,7 @@ static void getForwardSliceImpl(Instruction *inst, } if (auto forOp = inst->dyn_cast()) { - for (auto &u : forOp->getInductionVar()->getUses()) { + for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { getForwardSliceImpl(ownerInst, forwardSlice, filter); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index 7ed59b403cd..af112e5b02c 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -47,7 +47,7 @@ void TestParallelismDetection::runOnFunction() { FuncBuilder b(f); f->walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) - forOp->emitNote("parallel loop"); + forOp.emitNote("parallel loop"); }); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 6bc395c46bd..2ac4ee9000f 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -374,7 +374,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same::value, "argument should be either a LoadOp or a StoreOp"); - Instruction *opInst = loadOrStoreOp->getInstruction(); + Instruction *opInst = loadOrStoreOp.getInstruction(); MemRefRegion region(opInst->getLoc()); if (failed(region.compute(opInst, /*loopDepth=*/0))) @@ -458,7 +458,7 @@ static Instruction *getInstAtPosition(ArrayRef positions, return &inst; if (auto childAffineForOp = inst.dyn_cast()) return getInstAtPosition(positions, level + 1, - childAffineForOp->getBody()); + childAffineForOp.getBody()); for (auto ®ion : inst.getRegions()) { for (auto &b : region) @@ -537,9 +537,9 @@ LogicalResult mlir::getBackwardComputationSliceState( // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of // using 'kSliceFusionBarrierAttrName'. 
for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - Value *iv = srcLoopIVs[i]->getInductionVar(); + Value *iv = srcLoopIVs[i].getInductionVar(); if (sequentialLoops.count(iv) == 0 && - srcLoopIVs[i]->getAttr(kSliceFusionBarrierAttrName) == nullptr) + srcLoopIVs[i].getAttr(kSliceFusionBarrierAttrName) == nullptr) continue; for (unsigned j = i; j < numSrcLoopIVs; ++j) { sliceState->lbs[j] = AffineMap(); @@ -583,18 +583,18 @@ AffineForOp mlir::insertBackwardComputationSlice( // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. - findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(), + findInstPosition(srcOpInst, srcLoopIVs[0].getInstruction()->getBlock(), &positions); // Clone src loop nest and insert it at the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; - FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin()); + FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); auto sliceLoopNest = - b.clone(*srcLoopIVs[0]->getInstruction())->cast(); + b.clone(*srcLoopIVs[0].getInstruction())->cast(); Instruction *sliceInst = - getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); + getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody()); // Get loop nest surrounding 'sliceInst'. SmallVector sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); @@ -611,9 +611,9 @@ AffineForOp mlir::insertBackwardComputationSlice( for (unsigned i = 0; i < numSrcLoopIVs; ++i) { auto forOp = sliceSurroundingLoops[dstLoopDepth + i]; if (AffineMap lbMap = sliceState->lbs[i]) - forOp->setLowerBound(sliceState->lbOperands[i], lbMap); + forOp.setLowerBound(sliceState->lbOperands[i], lbMap); if (AffineMap ubMap = sliceState->ubs[i]) - forOp->setUpperBound(sliceState->ubOperands[i], ubMap); + forOp.setUpperBound(sliceState->ubOperands[i], ubMap); } return sliceLoopNest; } @@ -670,7 +670,7 @@ unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction()) + if (loopsA[i].getInstruction() != loopsB[i].getInstruction()) break; ++numCommonLoops; } @@ -727,7 +727,7 @@ static Optional getMemoryFootprintBytes(Block &block, Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, int memorySpace) { - auto *forInst = forOp->getInstruction(); + auto *forInst = forOp.getInstruction(); return ::getMemoryFootprintBytes( *forInst->getBlock(), Block::iterator(forInst), std::next(Block::iterator(forInst)), memorySpace); @@ -737,10 +737,10 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, /// at 'forOp'. void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { - forOp->getInstruction()->walk([&](Instruction *inst) { + forOp.getInstruction()->walk([&](Instruction *inst) { if (auto innerFor = inst->dyn_cast()) if (!isLoopParallel(innerFor)) - sequentialLoops->insert(innerFor->getInductionVar()); + sequentialLoops->insert(innerFor.getInductionVar()); }); } @@ -748,13 +748,13 @@ void mlir::getSequentialLoops( bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. 
SmallVector loadAndStoreOpInsts; - forOp->getInstruction()->walk([&](Instruction *opInst) { + forOp.getInstruction()->walk([&](Instruction *opInst) { if (opInst->isa() || opInst->isa()) loadAndStoreOpInsts.push_back(opInst); }); // Dep check depth would be number of enclosing loops + 1. - unsigned depth = getNestingDepth(*forOp->getInstruction()) + 1; + unsigned depth = getNestingDepth(*forOp.getInstruction()) + 1; // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'. for (auto *srcOpInst : loadAndStoreOpInsts) { diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 5df31affe31..32543c8d975 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -115,7 +115,7 @@ static AffineMap makePermutationMap( for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); auto invariants = getInvariantAccesses( - *kvp.first->cast()->getInductionVar(), unwrappedIndices); + *kvp.first->cast().getInductionVar(), unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 595141af84e..5cf5cb6cfff 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -87,7 +87,7 @@ mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, Instruction *inst = makeComposedAffineApply(ScopedContext::getBuilder(), ScopedContext::getLocation(), map, operands) - ->getInstruction(); + .getInstruction(); assert(inst->getNumResults() == 1 && "Not a single result AffineApply"); return ValueHandle(inst->getResult(0)); } @@ -103,8 +103,8 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, if (auto f = inst->dyn_cast()) { // Immediately create the loop body so we can just insert instructions right // away. - f->createBody(); - return ValueHandle(f->getInductionVar()); + f.createBody(); + return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported instruction, use an InstructionHandle instead"); } @@ -173,7 +173,7 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, ubs, ScopedContext::getBuilder()->getMultiDimIdentityMap(ubs.size()), step); } - auto *body = getForInductionVarOwner(iv->getValue())->getBody(); + auto *body = getForInductionVarOwner(iv->getValue()).getBody(); enter(body); } diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 89c66b08941..97f8bd75b36 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -52,7 +52,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, Value &v) { return; } if (auto forInst = getForInductionVarOwner(&v)) { - forInst->getInstruction()->print(os); + forInst.getInstruction()->print(os); } else if (auto *bbArg = dyn_cast(&v)) { os << "block_argument"; } else { @@ -176,8 +176,8 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { forOp = builder->create( location, lbs, builder->getMultiDimIdentityMap(lbs.size()), ubs, builder->getMultiDimIdentityMap(ubs.size()), step); - forOp->createBody(); - res = forOp->getInductionVar(); + forOp.createBody(); + res = forOp.getInductionVar(); } } @@ -236,7 +236,7 @@ mlir::edsc::MLIREmitter &mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) { bind(Bindable(stmt.getLHS()), val); if (stmt.getRHS().getKind() == ExprKind::For) { // Step into the loop. 
- builder->setInsertionPointToStart(getForInductionVarOwner(val)->getBody()); + builder->setInsertionPointToStart(getForInductionVarOwner(val).getBody()); } emitStmts(stmt.getEnclosedStmts()); builder->setInsertionPoint(block, ip); diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index ac8b98e38c3..a516f9617ac 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -209,7 +209,7 @@ Expr::build(FuncBuilder &b, const llvm::DenseMap &ssaBindings, auto affInstr = makeComposedAffineApply( &b, b.getUnknownLoc(), getAttribute("map").cast().getValue(), operandValues); - return {affInstr->getResult()}; + return {affInstr.getResult()}; } auto state = OperationState(b.getContext(), b.getUnknownLoc(), getName()); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 4fa040d73eb..5a1af03d299 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -403,7 +403,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, zeroIndex, stride, numEltPerStride); // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. - *nEnd = Block::iterator(op->getInstruction()); + *nEnd = Block::iterator(op.getInstruction()); } // Matching DMA wait to block on completion; tag always has a 0 index. @@ -414,7 +414,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, if (*nEnd == end) // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. - *nEnd = Block::iterator(tagDeallocOp->getInstruction()); + *nEnd = Block::iterator(tagDeallocOp.getInstruction()); // Generate dealloc for the DMA buffer. if (!existingBuf) @@ -500,13 +500,13 @@ bool DmaGeneration::runOnBlock(Block *block) { // the footprint can't be calculated, we assume for now it fits. Recurse // inside if footprint for 'forOp' exceeds capacity, or when // clSkipNonUnitStrideLoop is set and the step size is not one. - bool recurseInner = clSkipNonUnitStrideLoop ? forOp->getStep() != 1 + bool recurseInner = clSkipNonUnitStrideLoop ? forOp.getStep() != 1 : exceedsCapacity(forOp); if (recurseInner) { // We'll recurse and do the DMAs at an inner level for 'forInst'. runOnBlock(/*begin=*/curBegin, /*end=*/it); // Recurse onto the body of this loop. - runOnBlock(forOp->getBody()); + runOnBlock(forOp.getBody()); // The next region starts right after the 'affine.for' instruction. curBegin = std::next(it); } else { @@ -561,15 +561,15 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, for (auto e = enclosingFors.rend(); it != e; ++it) { // TODO(bondhugula): also need to be checking this for regions symbols that // aren't loop IVs, whether we are within their resp. defs' dominance scope. 
- if (llvm::is_contained(symbols, (*it)->getInductionVar())) + if (llvm::is_contained(symbols, it->getInductionVar())) break; } if (it != enclosingFors.rbegin()) { auto lastInvariantIV = *std::prev(it); - *dmaPlacementReadStart = Block::iterator(lastInvariantIV->getInstruction()); + *dmaPlacementReadStart = Block::iterator(lastInvariantIV.getInstruction()); *dmaPlacementWriteStart = std::next(*dmaPlacementReadStart); - *dmaPlacementBlock = lastInvariantIV->getInstruction()->getBlock(); + *dmaPlacementBlock = lastInvariantIV.getInstruction()->getBlock(); } else { *dmaPlacementReadStart = begin; *dmaPlacementWriteStart = end; @@ -737,9 +737,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { - forOp->emitNote( - Twine(sizeInKib) + - " KiB of DMA buffers in fast memory space for this block\n"); + forOp.emitNote(Twine(sizeInKib) + + " KiB of DMA buffers in fast memory space for this block\n"); } if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) { diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 84644bf11a0..c757ea8e58b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -696,8 +696,8 @@ bool MemRefDependenceGraph::init(Function *f) { getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()]; + assert(forToNodeMap.count(loops[0].getInstruction()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0].getInstruction()]; addEdge(node.id, userLoopNestId, value); } } @@ -745,8 +745,8 @@ struct LoopNestStatsCollector { void collect(Instruction *inst) { inst->walk([&](AffineForOp forOp) { - auto *forInst = forOp->getInstruction(); - auto *parentInst = forOp->getInstruction()->getParentInst(); + auto *forInst = forOp.getInstruction(); + auto *parentInst = forOp.getInstruction()->getParentInst(); if (parentInst != nullptr) { assert(parentInst->isa() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. @@ -756,7 +756,7 @@ struct LoopNestStatsCollector { // Record the number of op instructions in the body of 'forOp'. unsigned count = 0; stats->opCountMap[forInst] = 0; - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { if (!inst.isa() && !inst.isa()) ++count; } @@ -796,7 +796,7 @@ static int64_t getComputeCost( int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { for (auto childForOp : stats->loopMap[forInst]) { - opCount += getComputeCost(childForOp->getInstruction(), stats, + opCount += getComputeCost(childForOp.getInstruction(), stats, tripCountOverrideMap, computeCostMap); } } @@ -854,11 +854,11 @@ static bool buildSliceTripCountMap( AffineMap ubMap = sliceState->ubs[i]; if (lbMap == AffineMap() || ubMap == AffineMap()) { // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. 
- if (srcLoopIVs[i]->hasConstantLowerBound() && - srcLoopIVs[i]->hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = - srcLoopIVs[i]->getConstantUpperBound() - - srcLoopIVs[i]->getConstantLowerBound(); + if (srcLoopIVs[i].hasConstantLowerBound() && + srcLoopIVs[i].hasConstantUpperBound()) { + (*tripCountMap)[srcLoopIVs[i].getInstruction()] = + srcLoopIVs[i].getConstantUpperBound() - + srcLoopIVs[i].getConstantLowerBound(); continue; } return false; @@ -866,7 +866,7 @@ static bool buildSliceTripCountMap( Optional tripCount = getConstDifference(lbMap, ubMap); if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue(); + (*tripCountMap)[srcLoopIVs[i].getInstruction()] = tripCount.getValue(); } return true; } @@ -1060,12 +1060,12 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { SmallVector loops; AffineForOp curr = node->inst->cast(); loops.push_back(curr); - auto *currBody = curr->getBody(); + auto *currBody = curr.getBody(); while (!currBody->empty() && std::next(currBody->begin()) == currBody->end() && - (curr = curr->getBody()->front().dyn_cast())) { + (curr = curr.getBody()->front().dyn_cast())) { loops.push_back(curr); - currBody = curr->getBody(); + currBody = curr.getBody(); } if (loops.size() < 2) return; @@ -1091,7 +1091,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { } } assert(loopNestRootIndex != -1 && "invalid root index"); - node->inst = loops[loopNestRootIndex]->getInstruction(); + node->inst = loops[loopNestRootIndex].getInstruction(); } // TODO(mlir-team): improve/complete this when we have target data. @@ -1119,7 +1119,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, unsigned dstLoopDepth, Optional fastMemorySpace, uint64_t localBufSizeThreshold) { - auto *forInst = forOp->getInstruction(); + auto *forInst = forOp.getInstruction(); // Create builder to insert alloc op just before 'forOp'. FuncBuilder b(forInst); @@ -1187,7 +1187,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) allocOperands.push_back( - top.create(forOp->getLoc(), oldMemRef, dynamicDimCount++)); + top.create(forOp.getLoc(), oldMemRef, dynamicDimCount++)); } // Create new private memref for fused loop 'forOp'. @@ -1196,7 +1196,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, // at the beginning of the function, because loop nests can be reordered // during the fusion pass. Value *newMemRef = - top.create(forOp->getLoc(), newMemRefType, allocOperands); + top.create(forOp.getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector remapExprs; @@ -1220,7 +1220,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forOp->getBody()->begin()); + /*domInstFilter=*/&*forOp.getBody()->begin()); assert(ret && "replaceAllMemrefUsesWith should always succeed here"); (void)ret; return newMemRef; @@ -1437,7 +1437,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.collect(srcLoopIVs[0]->getInstruction()); + srcStatsCollector.collect(srcLoopIVs[0].getInstruction()); // Currently only constant trip count loop nests are supported. 
if (srcStatsCollector.hasLoopWithNonConstTripCount) { LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); @@ -1449,7 +1449,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.collect(dstLoopIVs[0]->getInstruction()); + dstStatsCollector.collect(dstLoopIVs[0].getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) { LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); @@ -1484,7 +1484,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest without iteration slicing. uint64_t srcLoopNestCost = - getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, + getComputeCost(srcLoopIVs[0].getInstruction(), &srcLoopNestStats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); @@ -1504,7 +1504,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest. uint64_t dstLoopNestCost = - getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, + getComputeCost(dstLoopIVs[0].getInstruction(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); @@ -1543,7 +1543,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // TODO(andydavis) Add load coalescing to memref data flow opt pass. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. - computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { auto *parentInst = loadOp->getParentInst(); if (parentInst && parentInst->isa()) @@ -1553,15 +1553,15 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, + getComputeCost(srcLoopIVs[0].getInstruction(), &srcLoopNestStats, /*tripCountOverrideMap=*/&sliceTripCountMap, /*computeCostMap=*/&computeCostMap); // Compute cost of fusion for this depth. - computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost; + computeCostMap[dstLoopIVs[i - 1].getInstruction()] = sliceComputeCost; int64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, + getComputeCost(dstLoopIVs[0].getInstruction(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); double additionalComputeFraction = @@ -1935,20 +1935,19 @@ public: auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest) { - LLVM_DEBUG(llvm::dbgs() - << "\tslice loop nest:\n" - << *sliceLoopNest->getInstruction() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" + << *sliceLoopNest.getInstruction() << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. auto dstAffineForOp = dstNode->inst->cast(); - if (insertPointInst != dstAffineForOp->getInstruction()) { - dstAffineForOp->getInstruction()->moveBefore(insertPointInst); + if (insertPointInst != dstAffineForOp.getInstruction()) { + dstAffineForOp.getInstruction()->moveBefore(insertPointInst); } // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, memref); // Collect slice loop stats. 
LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest->getInstruction()); + sliceCollector.collect(sliceLoopNest.getInstruction()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); } @@ -1974,7 +1973,7 @@ public: // Collect dst loop stats after memref privatization transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp->getInstruction()); + dstLoopCollector.collect(dstAffineForOp.getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. @@ -2097,8 +2096,8 @@ public: if (sliceLoopNest != nullptr) { auto dstForInst = dstNode->inst->cast(); // Update instruction position of fused loop nest (if needed). - if (insertPointInst != dstForInst->getInstruction()) { - dstForInst->getInstruction()->moveBefore(insertPointInst); + if (insertPointInst != dstForInst.getInstruction()) { + dstForInst.getInstruction()->moveBefore(insertPointInst); } // Update data dependence graph state post fusion. updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); @@ -2190,7 +2189,7 @@ public: // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest->getInstruction()); + sliceCollector.collect(sliceLoopNest.getInstruction()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); } @@ -2199,7 +2198,7 @@ public: // Collect dst loop stats after memref privatization transformation. auto dstForInst = dstNode->inst->cast(); LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstForInst->getInstruction()); + dstLoopCollector.collect(dstForInst.getInstruction()); // Clear and add back loads and stores mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, @@ -2209,7 +2208,7 @@ public: // function. if (mdg->getOutEdgeCount(sibNode->id) == 0) { mdg->removeNode(sibNode->id); - sibNode->inst->cast()->erase(); + sibNode->inst->cast().erase(); } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 314864d3f3c..2dbdf689f02 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -92,14 +92,14 @@ FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { // location in destination's body. static inline void moveLoopBody(AffineForOp src, AffineForOp dest, Block::iterator loc) { - dest->getBody()->getInstructions().splice(loc, - src->getBody()->getInstructions()); + dest.getBody()->getInstructions().splice(loc, + src.getBody()->getInstructions()); } // Move the loop body of AffineForOp 'src' from 'src' to the start of dest's // body. static inline void moveLoopBody(AffineForOp src, AffineForOp dest) { - moveLoopBody(src, dest, dest->getBody()->begin()); + moveLoopBody(src, dest, dest.getBody()->begin()); } /// Constructs and sets new loop bounds after tiling for the case of @@ -114,18 +114,18 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0]->getInstruction()); + FuncBuilder b(origLoops[0].getInstruction()); unsigned width = origLoops.size(); // Bounds for tile space loops. 
for (unsigned i = 0; i < width; i++) { - auto lbOperands = origLoops[i]->getLowerBoundOperands(); - auto ubOperands = origLoops[i]->getUpperBoundOperands(); + auto lbOperands = origLoops[i].getLowerBoundOperands(); + auto ubOperands = origLoops[i].getUpperBoundOperands(); SmallVector newLbOperands(lbOperands); SmallVector newUbOperands(ubOperands); - newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap()); - newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap()); - newLoops[i]->setStep(tileSizes[i]); + newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap()); + newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap()); + newLoops[i].setStep(tileSizes[i]); } // Bounds for intra-tile loops. for (unsigned i = 0; i < width; i++) { @@ -133,24 +133,24 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, auto mayBeConstantCount = getConstantTripCount(origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); - newLoops[width + i]->setLowerBound( - /*operands=*/newLoops[i]->getInductionVar(), lbMap); + newLoops[width + i].setLowerBound( + /*operands=*/newLoops[i].getInductionVar(), lbMap); // Set the upper bound. if (mayBeConstantCount.hasValue() && mayBeConstantCount.getValue() < tileSizes[i]) { // Trip count is less than tile size; upper bound is the trip count. auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue()); - newLoops[width + i]->setUpperBoundMap(ubMap); + newLoops[width + i].setUpperBoundMap(ubMap); } else if (largestDiv % tileSizes[i] != 0) { // Intra-tile loop ii goes from i to min(i + tileSize, ub_i). // Construct the upper bound map; the operands are the original operands // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. - SmallVector ubOperands(origLoops[i]->getUpperBoundOperands()); - ubOperands.push_back(newLoops[i]->getInductionVar()); + SmallVector ubOperands(origLoops[i].getUpperBoundOperands()); + ubOperands.push_back(newLoops[i].getInductionVar()); - auto origUbMap = origLoops[i]->getUpperBoundMap(); + auto origUbMap = origLoops[i].getUpperBoundMap(); SmallVector boundExprs; boundExprs.reserve(1 + origUbMap.getNumResults()); auto dim = b.getAffineDimExpr(origUbMap.getNumInputs()); @@ -161,12 +161,12 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, origUbMap.getResults().end()); auto ubMap = b.getAffineMap(origUbMap.getNumInputs() + 1, 0, boundExprs, {}); - newLoops[width + i]->setUpperBound(/*operands=*/ubOperands, ubMap); + newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap); } else { // No need of the min expression. auto dim = b.getAffineDimExpr(0); auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {}); - newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap); + newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap); } } } @@ -181,14 +181,14 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Check if the supplied for inst's are all successively nested. 
for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getInstruction()->getParentInst() == - band[i - 1]->getInstruction()); + assert(band[i].getInstruction()->getParentInst() == + band[i - 1].getInstruction()); } auto origLoops = band; AffineForOp rootAffineForOp = origLoops[0]; - auto loc = rootAffineForOp->getLoc(); + auto loc = rootAffineForOp.getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); @@ -196,19 +196,19 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, AffineForOp innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootAffineForOp->getInstruction(); + auto *topLoop = rootAffineForOp.getInstruction(); // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. auto pointLoop = b.create(loc, 0, 0); - pointLoop->createBody(); - pointLoop->getBody()->getInstructions().splice( - pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), + pointLoop.createBody(); + pointLoop.getBody()->getInstructions().splice( + pointLoop.getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; - topLoop = pointLoop->getInstruction(); + topLoop = pointLoop.getInstruction(); if (i == 0) innermostPointLoop = pointLoop; } @@ -218,12 +218,12 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto tileSpaceLoop = b.create(loc, 0, 0); - tileSpaceLoop->createBody(); - tileSpaceLoop->getBody()->getInstructions().splice( - tileSpaceLoop->getBody()->begin(), + tileSpaceLoop.createBody(); + tileSpaceLoop.getBody()->getInstructions().splice( + tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; - topLoop = tileSpaceLoop->getInstruction(); + topLoop = tileSpaceLoop.getInstruction(); } // Move the loop body of the original nest to the new one. @@ -236,19 +236,19 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootAffineForOp->emitError("tiled code generation unimplemented for the " - "non-hyperrectangular case"); + rootAffineForOp.emitError("tiled code generation unimplemented for the " + "non-hyperrectangular case"); return failure(); } constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes); // In this case, the point loop IVs just replace the original ones. for (unsigned i = 0; i < width; i++) { - origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar()); + origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width].getInductionVar()); } // Erase the old loop nest. 
- rootAffineForOp->erase(); + rootAffineForOp.erase(); return success(); } @@ -265,8 +265,8 @@ static void getTileableBands(Function *f, AffineForOp currInst = root; do { band.push_back(currInst); - } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = currInst->getBody()->front().dyn_cast())); + } while (currInst.getBody()->getInstructions().size() == 1 && + (currInst = currInst.getBody()->front().dyn_cast())); bands->push_back(band); }; @@ -341,8 +341,8 @@ void LoopTiling::getTileSizes(ArrayRef band, if (avoidMaxMinBounds) adjustToDivisorsOfTripCounts(band, tileSizes); LLVM_DEBUG( - rootForOp->emitWarning("memory footprint unknown: using default tile " - "sizes adjusted to trip count divisors")); + rootForOp.emitWarning("memory footprint unknown: using default tile " + "sizes adjusted to trip count divisors")); return; } @@ -398,7 +398,7 @@ void LoopTiling::runOnFunction() { msg << tSize << " "; msg << "]\n"; auto rootForOp = band[0]; - rootForOp->emitNote(msg.str()); + rootForOp.emitNote(msg.str()); } if (failed(tileCodeGen(band, tileSizes))) return signalPassFailure(); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 240f3960488..0822ddf37e3 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -155,7 +155,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, if (unrollJamFactor == 1) return promoteIfSingleIteration(forOp); - if (forOp->getBody()->empty()) + if (forOp.getBody()->empty()) return failure(); // Loops where both lower and upper bounds are multi-result maps won't be @@ -164,7 +164,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, // TODO(mlir-team): this may not be common, but we could support the case // where the lower bound is a multi-result map and the ub is a single result // one. - if (forOp->getLowerBoundMap().getNumResults() != 1) + if (forOp.getLowerBoundMap().getNumResults() != 1) return failure(); Optional mayBeConstantTripCount = getConstantTripCount(forOp); @@ -173,7 +173,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, mayBeConstantTripCount.getValue() < unrollJamFactor) return failure(); - auto *forInst = forOp->getInstruction(); + auto *forInst = forOp.getInstruction(); // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; @@ -193,21 +193,21 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, &cleanupOperands, &builder); - cleanupAffineForOp->setLowerBound(cleanupOperands, cleanupMap); + cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); // Promote the cleanup loop if it has turned into a single iteration loop. promoteIfSingleIteration(cleanupAffineForOp); // Adjust the upper bound of the original loop - it will be the same as the // cleanup loop's lower bound. Its lower bound remains unchanged. - forOp->setUpperBound(cleanupOperands, cleanupMap); + forOp.setUpperBound(cleanupOperands, cleanupMap); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forOp->getStep(); - forOp->setStep(step * unrollJamFactor); + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollJamFactor); - auto *forOpIV = forOp->getInductionVar(); + auto *forOpIV = forOp.getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. 
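// Illustrative sketch (not part of the patch): the unroll-and-jam rewrite
// above scales the outer step, clones each sub-block with a bumped IV, and
// relies on a cleanup loop for the remainder. Plain C++ for a factor of 2,
// with hypothetical arrays A and B:
void unrollJamByTwo(float *A, const float *B, int N, int M) {
  int i = 0;
  for (; i + 1 < N; i += 2) {        // outer step scaled by the unroll-jam factor
    for (int j = 0; j < M; ++j) {    // jammed sub-block
      A[(i + 0) * M + j] += B[j];    // original body
      A[(i + 1) * M + j] += B[j];    // clone with the IV bumped by the old step
    }
  }
  for (; i < N; ++i)                 // cleanup loop for the leftover iterations
    for (int j = 0; j < M; ++j)
      A[i * M + j] += B[j];
}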
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index cb65720cee3..93197c30cb2 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -51,7 +51,7 @@ public: if (!lhs || !rhs) return nullptr; auto op = builder.create(loc, lhs, rhs); - return op->getResult(); + return op.getResult(); } Value *visitAddExpr(AffineBinaryOpExpr expr) { @@ -189,7 +189,7 @@ public: builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); auto op = builder.create(loc, builder.getIndexType(), valueAttr); - return op->getResult(); + return op.getResult(); } Value *visitDimExpr(AffineDimExpr expr) { @@ -270,7 +270,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, Value *value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); - value = builder.create(loc, cmpOp->getResult(), value, *valueIt); + value = builder.create(loc, cmpOp.getResult(), value, *valueIt); } return value; @@ -320,8 +320,8 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // +--------------------------------+ // bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { - auto loc = forOp->getLoc(); - auto *forInst = forOp->getInstruction(); + auto loc = forOp.getLoc(); + auto *forInst = forOp.getInstruction(); // Start by splitting the block containing the 'affine.for' into two parts. // The part before will get the init code, the part after will be the end @@ -339,19 +339,19 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); - auto *oldBody = forOp->getBody(); + auto *oldBody = forOp.getBody(); bodyBlock->getInstructions().splice(bodyBlock->begin(), oldBody->getInstructions(), oldBody->begin(), oldBody->end()); // The code in the body of the forOp now uses 'iv' as its indvar. - forOp->getInductionVar()->replaceAllUsesWith(iv); + forOp.getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. FuncBuilder builder(bodyBlock); - auto affStep = builder.getAffineConstantExpr(forOp->getStep()); + auto affStep = builder.getAffineConstantExpr(forOp.getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); if (!stepped) @@ -364,18 +364,18 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { builder.setInsertionPointToEnd(initBlock); // Compute loop bounds. 
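// Illustrative sketch (not part of the patch): lowerAffineFor above builds an
// init block, a condition block, a body block, and an end block; multi-result
// bound maps are first reduced with buildMinMaxReductionSeq (max for lower
// bounds, min for upper bounds). The resulting control flow is equivalent to
// this plain C++ skeleton, with lb, ub, and step standing in for the expanded
// bound values:
void loweredAffineFor(int lb, int ub, int step) {
  int iv = lb;          // init block: expand the bound maps, branch to 'condition'
condition:
  if (!(iv < ub))       // condition block: compare the IV against the upper bound
    goto end;
  // body block: the original 'affine.for' body, now reading 'iv'
  iv += step;           // stepping logic appended to the body block
  goto condition;
end:
  return;               // end block: everything that followed the loop
}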
- SmallVector operands(forOp->getLowerBoundOperands()); + SmallVector operands(forOp.getLowerBoundOperands()); auto lbValues = expandAffineMap(&builder, forInst->getLoc(), - forOp->getLowerBoundMap(), operands); + forOp.getLowerBoundMap(), operands); if (!lbValues) return true; Value *lowerBound = buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder); - operands.assign(forOp->getUpperBoundOperands().begin(), - forOp->getUpperBoundOperands().end()); + operands.assign(forOp.getUpperBoundOperands().begin(), + forOp.getUpperBoundOperands().end()); auto ubValues = expandAffineMap(&builder, forInst->getLoc(), - forOp->getUpperBoundMap(), operands); + forOp.getUpperBoundMap(), operands); if (!ubValues) return true; Value *upperBound = @@ -390,7 +390,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { endBlock, ArrayRef()); // Ok, we're done! - forOp->erase(); + forOp.erase(); return false; } @@ -454,7 +454,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { // +--------------------------------+ // bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { - auto *ifInst = ifOp->getInstruction(); + auto *ifInst = ifOp.getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'affine.if' into two parts. The @@ -470,7 +470,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { thenBlock->insertBefore(continueBlock); // If the 'then' block is not empty, then splice the instructions. - auto &oldThenBlocks = ifOp->getThenBlocks(); + auto &oldThenBlocks = ifOp.getThenBlocks(); if (!oldThenBlocks.empty()) { // We currently only handle one 'then' block. if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) @@ -489,7 +489,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - auto &oldElseBlocks = ifOp->getElseBlocks(); + auto &oldElseBlocks = ifOp.getElseBlocks(); if (!oldElseBlocks.empty()) { // We currently only handle one 'else' block. if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) @@ -507,7 +507,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifOp->getIntegerSet(); + auto integerSet = ifOp.getIntegerSet(); // Implement short-circuit logic. For each affine expression in the // 'affine.if' condition, convert it into an affine map and call @@ -545,7 +545,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { auto comparisonOp = builder.create( loc, isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); - builder.create(loc, comparisonOp->getResult(), nextBlock, + builder.create(loc, comparisonOp.getResult(), nextBlock, /*trueArgs*/ ArrayRef(), elseBlock, /*falseArgs*/ ArrayRef()); builder.setInsertionPointToEnd(nextBlock); @@ -570,19 +570,19 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { // Convert an "affine.apply" operation into a sequence of arithmetic // instructions using the StandardOps dialect. Return true on error. 
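// Illustrative sketch (not part of the patch): the short-circuit logic in
// lowerAffineIf above evaluates one integer-set constraint per block and
// branches to the 'else' successor on the first failure. In scalar terms,
// with hypothetical constraint values c0..c2 (two inequalities and one
// equality), the condition behaves like:
bool affineIfCondition(int c0, int c1, int c2) {
  if (!(c0 >= 0)) return false;   // first constraint block
  if (!(c1 >= 0)) return false;   // second constraint block
  if (!(c2 == 0)) return false;   // equality constraint block
  return true;                    // all constraints hold: branch to the 'then' block
}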
bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { - FuncBuilder builder(op->getInstruction()); + FuncBuilder builder(op.getInstruction()); auto maybeExpandedMap = - expandAffineMap(&builder, op->getLoc(), op->getAffineMap(), - llvm::to_vector<8>(op->getOperands())); + expandAffineMap(&builder, op.getLoc(), op.getAffineMap(), + llvm::to_vector<8>(op.getOperands())); if (!maybeExpandedMap) return true; - Value *original = op->getResult(); + Value *original = op.getResult(); Value *expanded = (*maybeExpandedMap)[0]; if (!expanded) return true; original->replaceAllUsesWith(expanded); - op->erase(); + op.erase(); return false; } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index b59071aa9fe..a92e2d5960c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -72,7 +72,7 @@ static unsigned getTagMemRefPos(Instruction &dmaInst) { /// added dimension by the loop IV of the specified 'affine.for' instruction /// modulo 2. Returns false if such a replacement cannot be performed. static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { - auto *forBody = forOp->getBody(); + auto *forBody = forOp.getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -93,7 +93,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto newMemRefType = doubleShape(oldMemRefType); // The double buffer is allocated right before 'forInst'. - auto *forInst = forOp->getInstruction(); + auto *forInst = forOp.getInstruction(); FuncBuilder bOuter(forInst); // Put together alloc operands for any dynamic dimensions of the memref. SmallVector allocOperands; @@ -110,21 +110,21 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); - int64_t step = forOp->getStep(); + int64_t step = forOp.getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = bInner.create(forOp->getLoc(), modTwoMap, - forOp->getInductionVar()); + auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, + forOp.getInductionVar()); // replaceAllMemRefUsesWith will always succeed unless the forOp body has // non-deferencing uses of the memref (dealloc's are fine though). - if (!replaceAllMemRefUsesWith( - oldMemRef, newMemRef, /*extraIndices=*/{ivModTwoOp}, - /*indexRemap=*/AffineMap(), - /*extraOperands=*/{}, - /*domInstFilter=*/&*forOp->getBody()->begin())) { + if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, + /*extraIndices=*/{ivModTwoOp}, + /*indexRemap=*/AffineMap(), + /*extraOperands=*/{}, + /*domInstFilter=*/&*forOp.getBody()->begin())) { LLVM_DEBUG( - forOp->emitError("memref replacement for double buffering failed")); + forOp.emitError("memref replacement for double buffering failed")); ivModTwoOp->getInstruction()->erase(); return false; } @@ -180,14 +180,14 @@ static void findMatchingStartFinishInsts( // Collect outgoing DMA instructions - needed to check for dependences below. 
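// Illustrative sketch (not part of the patch): doubleBuffer() above gives the
// memref a leading dimension of two and indexes it with (iv floordiv step)
// mod 2, so successive iterations alternate between the two halves. Plain C++
// picture with hypothetical buf2 (laid out as [2][N]), tripCount, and step:
#include <vector>

void doubleBuffered(std::vector<float> &buf2, int N, int tripCount, int step) {
  for (int iv = 0; iv < tripCount * step; iv += step) {
    int slot = (iv / step) % 2;      // the 'iv mod 2' leading index
    float *cur = &buf2[slot * N];    // half of the buffer used this iteration
    cur[0] += 1.0f;                  // stand-in for the DMA/compute on 'cur'
  }
}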
SmallVector outgoingDmaOps; - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { auto dmaStartOp = inst.dyn_cast(); if (dmaStartOp && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } SmallVector dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { // Collect DMA finish instructions. if (inst.isa()) { dmaFinishInsts.push_back(&inst); @@ -220,7 +220,7 @@ static void findMatchingStartFinishInsts( // We can double buffer regardless of dealloc's outside the loop. if (use.getOwner()->isa()) continue; - if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp.getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -249,8 +249,7 @@ static void findMatchingStartFinishInsts( void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { - LLVM_DEBUG( - forOp->emitNote("won't pipeline due to unknown trip count loop")); + LLVM_DEBUG(forOp.emitNote("won't pipeline due to unknown trip count loop")); return; } @@ -258,7 +257,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { - LLVM_DEBUG(forOp->emitNote("No dma start/finish pairs\n")); + LLVM_DEBUG(forOp.emitNote("No dma start/finish pairs\n")); return; } @@ -332,7 +331,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); if (!sliceOps.empty()) { for (auto sliceOp : sliceOps) { - instShiftMap[sliceOp->getInstruction()] = 0; + instShiftMap[sliceOp.getInstruction()] = 0; } } else { // If a slice wasn't created, the reachable affine.apply op's from its @@ -346,16 +345,16 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector shifts(forOp->getBody()->getInstructions().size()); + std::vector shifts(forOp.getBody()->getInstructions().size()); unsigned s = 0; - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index bf0c3ced2e2..918bd5b9e21 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -47,7 +47,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, SmallVectorImpl *operands, FuncBuilder *b) { - auto lbMap = forOp->getLowerBoundMap(); + auto lbMap = forOp.getLowerBoundMap(); // Single result lower bound map only. 
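// Illustrative sketch (not part of the patch): for a constant-trip-count
// loop, the cleanup lower bound being built in getCleanupLoopLowerBound above
// reduces to simple arithmetic; the same value later becomes the new upper
// bound of the unrolled loop. All parameters here are hypothetical.
int cleanupLowerBound(int lb, int step, int tripCount, int unrollFactor) {
  // Iterations covered by the unrolled body, rounded down to a multiple of
  // the unroll factor; the cleanup loop starts right after them.
  int unrolledIters = (tripCount / unrollFactor) * unrollFactor;
  return lb + unrolledIters * step;
}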
if (lbMap.getNumResults() != 1) { @@ -65,10 +65,10 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, return; } - unsigned step = forOp->getStep(); + unsigned step = forOp.getStep(); - SmallVector lbOperands(forOp->getLowerBoundOperands()); - auto lb = b->create(forOp->getLoc(), lbMap, lbOperands); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + auto lb = b->create(forOp.getLoc(), lbMap, lbOperands); // For each upper bound expr, get the range. // Eg: affine.for %i = lb to min (ub1, ub2), @@ -84,7 +84,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, b->getAffineMap(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), bumpExprs[i], {}); bumpValues[i] = - b->create(forOp->getLoc(), bumpMap, tripCountOperands); + b->create(forOp.getLoc(), bumpMap, tripCountOperands); } SmallVector newUbExprs(tripCountMap.getNumResults()); @@ -105,8 +105,8 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, v->getDefiningInst()->erase(); } } - if (lb->use_empty()) - lb->erase(); + if (lb.use_empty()) + lb.erase(); } /// Promotes the loop body of a forOp to its containing block if the forOp @@ -118,21 +118,21 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { return failure(); // TODO(mlir-team): there is no builder for a max. - if (forOp->getLowerBoundMap().getNumResults() != 1) + if (forOp.getLowerBoundMap().getNumResults() != 1) return failure(); // Replaces all IV uses to its single iteration value. - auto *iv = forOp->getInductionVar(); - Instruction *forInst = forOp->getInstruction(); + auto *iv = forOp.getInductionVar(); + Instruction *forInst = forOp.getInstruction(); if (!iv->use_empty()) { - if (forOp->hasConstantLowerBound()) { + if (forOp.hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( - forOp->getLoc(), forOp->getConstantLowerBound()); + forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); } else { - AffineBound lb = forOp->getLowerBound(); + AffineBound lb = forOp.getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); if (lb.getMap() == builder.getDimIdentityMap()) { @@ -148,8 +148,8 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { // Move the loop body instructions to the loop's containing block. 
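// Illustrative sketch (not part of the patch): promoteIfSingleIteration,
// touched above, inlines the body of a loop that executes exactly once and
// rewrites every use of the induction variable to the lower bound. Plain C++
// terms, with a hypothetical body:
void beforePromotion(int lb, int step, float *A) {
  for (int i = lb; i < lb + step; i += step)   // trip count is exactly 1
    A[i] = 0.0f;
}
void afterPromotion(int lb, float *A) {
  A[lb] = 0.0f;                                // body promoted, IV replaced by lb
}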
auto *block = forInst->getBlock(); block->getInstructions().splice(Block::iterator(forInst), - forOp->getBody()->getInstructions()); - forOp->erase(); + forOp.getBody()->getInstructions()); + forOp.erase(); return success(); } @@ -173,18 +173,18 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, unsigned offset, AffineForOp srcForInst, FuncBuilder *b) { - SmallVector lbOperands(srcForInst->getLowerBoundOperands()); - SmallVector ubOperands(srcForInst->getUpperBoundOperands()); + SmallVector lbOperands(srcForInst.getLowerBoundOperands()); + SmallVector ubOperands(srcForInst.getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); auto loopChunk = - b->create(srcForInst->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForInst->getStep()); - loopChunk->createBody(); - auto *loopChunkIV = loopChunk->getInductionVar(); - auto *srcIV = srcForInst->getInductionVar(); + b->create(srcForInst.getLoc(), lbOperands, lbMap, ubOperands, + ubMap, srcForInst.getStep()); + loopChunk.createBody(); + auto *loopChunkIV = loopChunk.getInductionVar(); + auto *srcIV = srcForInst.getInductionVar(); BlockAndValueMapping operandMap; @@ -197,18 +197,18 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. if (!srcIV->use_empty() && shift != 0) { - FuncBuilder b(loopChunk->getBody()); + FuncBuilder b(loopChunk.getBody()); auto ivRemap = b.create( - srcForInst->getLoc(), + srcForInst.getLoc(), b.getSingleDimShiftAffineMap( - -static_cast(srcForInst->getStep() * shift)), + -static_cast(srcForInst.getStep() * shift)), loopChunkIV); operandMap.map(srcIV, ivRemap); } else { operandMap.map(srcIV, loopChunkIV); } for (auto *inst : insts) { - loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext())); + loopChunk.getBody()->push_back(inst->clone(operandMap, b->getContext())); } } if (succeeded(promoteIfSingleIteration(loopChunk))) @@ -233,7 +233,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // method. LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue) { - if (forOp->getBody()->empty()) + if (forOp.getBody()->empty()) return success(); // If the trip counts aren't constant, we would need versioning and @@ -242,7 +242,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // constant trip count "full tiles" before applying this. auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { - LLVM_DEBUG(forOp->emitNote("non-constant trip count loop not handled")); + LLVM_DEBUG(forOp.emitNote("non-constant trip count loop not handled")); return success(); } uint64_t tripCount = mayBeConstTripCount.getValue(); @@ -250,9 +250,9 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, assert(isInstwiseShiftValid(forOp, shifts) && "shifts will lead to an invalid transformation\n"); - int64_t step = forOp->getStep(); + int64_t step = forOp.getStep(); - unsigned numChildInsts = forOp->getBody()->getInstructions().size(); + unsigned numChildInsts = forOp.getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -261,7 +261,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, } // Such large shifts are not the typical use case. 
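// Illustrative sketch (not part of the patch): instruction-body skewing
// (instBodySkew) delays each group of instructions by its shift and emits
// prologue and epilogue loops around a steady state. For two hypothetical
// statements S0 (shift 0) and S1 (shift 1) over a trip count N >= 1:
void skewedBody(int N) {
  auto S0 = [](int i) { (void)i; };  // shift-0 group
  auto S1 = [](int i) { (void)i; };  // shift-1 group
  S0(0);                             // prologue: only the shift-0 group runs
  for (int i = 1; i < N; ++i) {      // steady state: S1 lags S0 by one iteration
    S0(i);
    S1(i - 1);
  }
  S1(N - 1);                         // epilogue: drain the shift-1 group
}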
if (maxShift >= numChildInsts) { - forOp->emitWarning("not shifting because shifts are unrealistically large"); + forOp.emitWarning("not shifting because shifts are unrealistically large"); return success(); } @@ -270,7 +270,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // body of the 'affine.for' inst. std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &inst : *forOp->getBody()) { + for (auto &inst : *forOp.getBody()) { auto shift = shifts[pos++]; sortedInstGroups[shift].push_back(&inst); } @@ -288,9 +288,9 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // of instructions is paired with its shift. std::vector>> instGroupQueue; - auto origLbMap = forOp->getLowerBoundMap(); + auto origLbMap = forOp.getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forOp->getInstruction()); + FuncBuilder b(forOp.getInstruction()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -340,12 +340,12 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, } // Erase the original for inst. - forOp->erase(); + forOp.erase(); if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); if (unrollPrologueEpilogue && !epilogue && - epilogue->getInstruction() != prologue->getInstruction()) + epilogue.getInstruction() != prologue.getInstruction()) loopUnrollFull(epilogue); return success(); @@ -385,7 +385,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, if (unrollFactor == 1) return promoteIfSingleIteration(forOp); - if (forOp->getBody()->empty()) + if (forOp.getBody()->empty()) return failure(); // Loops where the lower bound is a max expression isn't supported for @@ -393,7 +393,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, // both the lower bound and the upper bound are multi-result maps. However, // one meaningful way to do such unrolling would be to specialize the loop for // the 'hotspot' case and unroll that hotspot. - if (forOp->getLowerBoundMap().getNumResults() != 1) + if (forOp.getLowerBoundMap().getNumResults() != 1) return failure(); // If the trip count is lower than the unroll factor, no unrolled body. @@ -404,7 +404,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, return failure(); // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - Instruction *forInst = forOp->getInstruction(); + Instruction *forInst = forOp.getInstruction(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); auto cleanupForInst = builder.clone(*forInst)->cast(); @@ -415,29 +415,29 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, assert(cleanupMap && "cleanup loop lower bound map for single result lower bound maps " "can always be determined"); - cleanupForInst->setLowerBound(cleanupOperands, cleanupMap); + cleanupForInst.setLowerBound(cleanupOperands, cleanupMap); // Promote the loop body up if this has turned into a single iteration loop. promoteIfSingleIteration(cleanupForInst); // Adjust upper bound of the original loop; this is the same as the lower // bound of the cleanup loop. - forOp->setUpperBound(cleanupOperands, cleanupMap); + forOp.setUpperBound(cleanupOperands, cleanupMap); } // Scale the step of loop being unrolled by unroll factor. 
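// Illustrative sketch (not part of the patch): loopUnrollByFactor, continued
// in the next hunks, scales the step by the unroll factor, clones the body
// with the IV bumped by i * step for each copy, and uses the cleanup loop
// created above for the remainder. Plain C++ for a factor of 4 with a
// hypothetical body:
void unrolledByFour(float *A, int N) {
  int i = 0;
  for (; i + 3 < N; i += 4) {   // step scaled by the unroll factor
    A[i] += 1.0f;               // original body
    A[i + 1] += 1.0f;           // cloned bodies with bumped IVs
    A[i + 2] += 1.0f;
    A[i + 3] += 1.0f;
  }
  for (; i < N; ++i)            // cleanup loop
    A[i] += 1.0f;
}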
- int64_t step = forOp->getStep(); - forOp->setStep(step * unrollFactor); + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollFactor); // Builder to insert unrolled bodies right after the last instruction in the // body of 'forOp'. - FuncBuilder builder(forOp->getBody(), forOp->getBody()->end()); + FuncBuilder builder(forOp.getBody(), forOp.getBody()->end()); // Keep a pointer to the last instruction in the original block so that we // know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end()); // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). - auto *forOpIV = forOp->getInductionVar(); + auto *forOpIV = forOp.getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; @@ -448,12 +448,12 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto ivUnroll = - builder.create(forOp->getLoc(), bumpMap, forOpIV); + builder.create(forOp.getLoc(), bumpMap, forOpIV); operandMap.map(forOpIV, ivUnroll); } // Clone the original body of 'forOp'. - for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd); + for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd); it++) { builder.clone(*it, operandMap); } @@ -467,20 +467,20 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is /// nested within 'forOpA' as the only instruction in its block. void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { - auto *forOpAInst = forOpA->getInstruction(); + auto *forOpAInst = forOpA.getInstruction(); // 1) Slice forOpA's instruction list (which is just forOpB) just before // forOpA (in forOpA's parent's block) this should leave 'forOpA's // instruction list empty (because its perfectly nested). - assert(&*forOpA->getBody()->begin() == forOpB->getInstruction()); + assert(&*forOpA.getBody()->begin() == forOpB.getInstruction()); forOpAInst->getBlock()->getInstructions().splice( - Block::iterator(forOpAInst), forOpA->getBody()->getInstructions()); + Block::iterator(forOpAInst), forOpA.getBody()->getInstructions()); // 2) Slice forOpB's instruction list into forOpA's instruction list (this // leaves forOpB's instruction list empty). - forOpA->getBody()->getInstructions().splice( - forOpA->getBody()->begin(), forOpB->getBody()->getInstructions()); + forOpA.getBody()->getInstructions().splice( + forOpA.getBody()->begin(), forOpB.getBody()->getInstructions()); // 3) Slice forOpA into forOpB's instruction list. - forOpB->getBody()->getInstructions().splice( - forOpB->getBody()->begin(), forOpAInst->getBlock()->getInstructions(), + forOpB.getBody()->getInstructions().splice( + forOpB.getBody()->begin(), forOpAInst->getBlock()->getInstructions(), Block::iterator(forOpAInst)); } @@ -488,8 +488,8 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { /// deeper in the loop nest. 
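// Illustrative sketch (not part of the patch): interchangeLoops above swaps a
// perfectly nested pair of loops purely by splicing instruction lists; only
// the loop headers trade places, the body is untouched. Plain C++ with a
// hypothetical array A:
void interchanged(float *A, int N, int M) {
  for (int j = 0; j < M; ++j)      // was the inner loop
    for (int i = 0; i < N; ++i)    // was the outer loop
      A[i * M + j] += 1.0f;
}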
void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { for (unsigned i = 0; i < loopDepth; ++i) { - assert(forOp->getBody()->front().isa()); - AffineForOp nextForOp = forOp->getBody()->front().cast(); + assert(forOp.getBody()->front().isa()); + AffineForOp nextForOp = forOp.getBody()->front().cast(); interchangeLoops(forOp, nextForOp); } } @@ -521,12 +521,12 @@ static void augmentMapAndBounds(FuncBuilder *b, Value *iv, AffineMap *map, static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, AffineForOp newForOp) { BlockAndValueMapping map; - map.map(oldIv, newForOp->getInductionVar()); - FuncBuilder b(newForOp->getBody(), newForOp->getBody()->end()); - for (auto it = forOp->getBody()->begin(), end = forOp->getBody()->end(); + map.map(oldIv, newForOp.getInductionVar()); + FuncBuilder b(newForOp.getBody(), newForOp.getBody()->end()); + for (auto it = forOp.getBody()->begin(), end = forOp.getBody()->end(); it != end; ++it) { // Step over newForOp in case it is nested under forOp. - if (&*it == newForOp->getInstruction()) { + if (&*it == newForOp.getInstruction()) { continue; } auto *inst = b.clone(*it, map); @@ -554,35 +554,35 @@ stripmineSink(AffineForOp forOp, uint64_t factor, // forOp and that targets are not nested under each other when DominanceInfo // exposes the capability. It seems overkill to construct a whole function // dominance tree at this point. - auto originalStep = forOp->getStep(); + auto originalStep = forOp.getStep(); auto scaledStep = originalStep * factor; - forOp->setStep(scaledStep); + forOp.setStep(scaledStep); - auto *forInst = forOp->getInstruction(); + auto *forInst = forOp.getInstruction(); FuncBuilder b(forInst->getBlock(), ++Block::iterator(forInst)); // Lower-bound map creation. - auto lbMap = forOp->getLowerBoundMap(); - SmallVector lbOperands(forOp->getLowerBoundOperands()); - augmentMapAndBounds(&b, forOp->getInductionVar(), &lbMap, &lbOperands); + auto lbMap = forOp.getLowerBoundMap(); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + augmentMapAndBounds(&b, forOp.getInductionVar(), &lbMap, &lbOperands); // Upper-bound map creation. - auto ubMap = forOp->getUpperBoundMap(); - SmallVector ubOperands(forOp->getUpperBoundOperands()); - augmentMapAndBounds(&b, forOp->getInductionVar(), &ubMap, &ubOperands, + auto ubMap = forOp.getUpperBoundMap(); + SmallVector ubOperands(forOp.getUpperBoundOperands()); + augmentMapAndBounds(&b, forOp.getInductionVar(), &ubMap, &ubOperands, /*offset=*/scaledStep); SmallVector innerLoops; for (auto t : targets) { // Insert newForOp at the end of `t`. - FuncBuilder b(t->getBody(), t->getBody()->end()); - auto newForOp = b.create(t->getLoc(), lbOperands, lbMap, + FuncBuilder b(t.getBody(), t.getBody()->end()); + auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, ubOperands, ubMap, originalStep); - newForOp->createBody(); - cloneLoopBodyInto(t, forOp->getInductionVar(), newForOp); + newForOp.createBody(); + cloneLoopBodyInto(t, forOp.getInductionVar(), newForOp); // Remove all instructions from `t` except `newForOp`. 
- auto rit = ++newForOp->getInstruction()->getReverseIterator(); - auto re = t->getBody()->rend(); + auto rit = ++newForOp.getInstruction()->getReverseIterator(); + auto re = t.getBody()->rend(); for (auto &inst : llvm::make_early_inc_range(llvm::make_range(rit, re))) { inst.erase(); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 9c9f8593f31..f57a53d3670 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -231,7 +231,7 @@ static bool affineApplyOp(Instruction &inst) { static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { auto app = inst.dyn_cast(); - return app && app->use_empty(); + return app && app.use_empty(); } void VectorizerTestPass::testNormalizeMaps(Function *f) { @@ -249,8 +249,8 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { for (auto m : matches) { auto app = m.getMatchedInstruction()->cast(); FuncBuilder b(m.getMatchedInstruction()); - SmallVector operands(app->getOperands()); - makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); + SmallVector operands(app.getOperands()); + makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands); } } // We should now be able to erase everything in reverse order in this test. diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index a52129ed0d6..362cad352fb 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -856,7 +856,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, VectorizationState *state) { using namespace functional; - loop->setStep(step); + loop.setStep(step); FilterFunctionType notVectorizedThisPattern = [state](Instruction &inst) { if (!matcher::isLoadOrStore(inst)) { @@ -868,15 +868,15 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; - loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); + loadAndStores.match(loop.getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedInstruction(); auto load = opInst->dyn_cast(); auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = - load ? vectorizeRootOrTerminal(loop->getInductionVar(), load, state) - : vectorizeRootOrTerminal(loop->getInductionVar(), store, state); + load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) + : vectorizeRootOrTerminal(loop.getInductionVar(), store, state); if (failed(result)) { return failure(); } @@ -1164,18 +1164,17 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Sets up error handling for this root loop. This is how the root match /// maintains a clone for handling failure and restores the proper state via /// RAII. 
- auto *loopInst = loop->getInstruction(); + auto *loopInst = loop.getInstruction(); FuncBuilder builder(loopInst); auto clonedLoop = builder.clone(*loopInst)->cast(); struct Guard { LogicalResult failure() { - loop->getInductionVar()->replaceAllUsesWith( - clonedLoop->getInductionVar()); - loop->erase(); + loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar()); + loop.erase(); return mlir::failure(); } LogicalResult success() { - clonedLoop->erase(); + clonedLoop.erase(); return mlir::success(); } AffineForOp loop; -- cgit v1.2.3 From 96ebde9cfd0dccb672fae02b44faf97355b1ac1b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 25 Mar 2019 13:02:06 -0700 Subject: Replace usages of "Op::operator->" with ".". This is step 2/N of removing the temporary operator-> method as part of the de-const transition. PiperOrigin-RevId: 240200792 --- mlir/include/mlir/EDSC/Builders.h | 4 +- mlir/include/mlir/IR/OpDefinition.h | 14 +++---- mlir/include/mlir/IR/PatternMatch.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 6 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 10 ++--- mlir/lib/Analysis/Utils.cpp | 20 +++++----- mlir/lib/Analysis/VectorAnalysis.cpp | 8 ++-- mlir/lib/EDSC/Builders.cpp | 4 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 32 +++++++-------- mlir/lib/StandardOps/Ops.cpp | 45 +++++++++++----------- mlir/lib/Transforms/ConstantFold.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 16 ++++---- mlir/lib/Transforms/LoopFusion.cpp | 38 +++++++++--------- mlir/lib/Transforms/LowerVectorTransfers.cpp | 40 +++++++++---------- mlir/lib/Transforms/MaterializeVectors.cpp | 28 +++++++------- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 14 +++---- mlir/lib/Transforms/PipelineDataTransfer.cpp | 18 ++++----- .../Utils/GreedyPatternRewriteDriver.cpp | 14 +++---- mlir/lib/Transforms/Vectorize.cpp | 19 +++++---- mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 10 ++--- 20 files changed, 172 insertions(+), 174 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 8a186c28476..b7b3b1f1844 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -421,14 +421,14 @@ InstructionHandle InstructionHandle::create(Args... args) { return InstructionHandle( ScopedContext::getBuilder() ->create(ScopedContext::getLocation(), args...) - ->getInstruction()); + .getInstruction()); } template ValueHandle ValueHandle::create(Args... args) { Instruction *inst = ScopedContext::getBuilder() ->create(ScopedContext::getLocation(), args...) - ->getInstruction(); + .getInstruction(); if (inst->getNumResults() == 1) { return ValueHandle(inst->getResult(0)); } else if (inst->getNumResults() == 0) { diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b03ef08f37a..85f9a88963e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -171,8 +171,8 @@ public: static LogicalResult constantFoldHook(Instruction *op, ArrayRef operands, SmallVectorImpl &results) { - return op->cast()->constantFold(operands, results, - op->getContext()); + return op->cast().constantFold(operands, results, + op->getContext()); } /// Op implementations can implement this hook. It should attempt to constant @@ -193,7 +193,7 @@ public: /// This is an implementation detail of the folder hook for AbstractOperation. 
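// Illustrative sketch (not part of the patch): the mechanical change in this
// commit replaces the temporary Op::operator-> spelling with plain member
// access. A minimal standalone analogue (not the real MLIR classes) of why
// both spellings compiled during the transition:
struct FakeOp {
  int getStep() const { return step; }
  FakeOp *operator->() { return this; }   // temporary hook being phased out
  int step = 1;
};
void spellingExample() {
  FakeOp op;
  int oldStyle = op->getStep();   // old spelling, via operator->
  int newStyle = op.getStep();    // spelling this commit standardizes on
  (void)oldStyle;
  (void)newStyle;
}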
static LogicalResult foldHook(Instruction *op, SmallVectorImpl &results) { - return op->cast()->fold(results); + return op->cast().fold(results); } /// This hook implements a generalized folder for this operation. Operations @@ -241,7 +241,7 @@ public: ArrayRef operands, SmallVectorImpl &results) { auto result = - op->cast()->constantFold(operands, op->getContext()); + op->cast().constantFold(operands, op->getContext()); if (!result) return failure(); @@ -265,7 +265,7 @@ public: /// This is an implementation detail of the folder hook for AbstractOperation. static LogicalResult foldHook(Instruction *op, SmallVectorImpl &results) { - auto *result = op->cast()->fold(); + auto *result = op->cast().fold(); if (!result) return failure(); if (result != op->getResult(0)) @@ -752,7 +752,7 @@ public: auto opPointer = op->dyn_cast(); assert(opPointer && "op's name does not match name of concrete type instantiated with"); - opPointer->print(p); + opPointer.print(p); } /// This is the hook that checks whether or not this instruction is well @@ -764,7 +764,7 @@ public: /// diagnostic subsystem and returns true. static bool verifyInvariants(Instruction *op) { return BaseVerifier...>::verifyTrait(op) || - op->cast()->verify(); + op->cast().verify(); } // Returns the properties of an operation by combining the properties of the diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index a65c8dd7c8a..e6b9551339e 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -256,7 +256,7 @@ public: template void replaceOpWithNewOp(Instruction *op, Args... args) { auto newOp = create(op->getLoc(), args...); - replaceOpWithResultsOfAnotherOp(op, newOp->getInstruction(), {}); + replaceOpWithResultsOfAnotherOp(op, newOp.getInstruction(), {}); } /// Replaces the result op with a new op that is created without verification. @@ -267,7 +267,7 @@ public: ArrayRef valuesToRemoveIfDead, Args... args) { auto newOp = create(op->getLoc(), args...); - replaceOpWithResultsOfAnotherOp(op, newOp->getInstruction(), + replaceOpWithResultsOfAnotherOp(op, newOp.getInstruction(), valuesToRemoveIfDead); } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 2901d815032..139c40745b6 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -66,7 +66,7 @@ bool mlir::isValidDim(Value *value) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = inst->dyn_cast()) - return isTopLevelSymbol(dimOp->getOperand()); + return isTopLevelSymbol(dimOp.getOperand()); return false; } // This value is a block argument (which also includes 'affine.for' loop IVs). @@ -87,11 +87,11 @@ bool mlir::isValidSymbol(Value *value) { return true; // Affine apply operation is ok if all of its operands are ok. if (auto op = inst->dyn_cast()) - return op->isValidSymbol(); + return op.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = inst->dyn_cast()) - return isTopLevelSymbol(dimOp->getOperand()); + return isTopLevelSymbol(dimOp.getOperand()); return false; } // Otherwise, the only valid symbol is a top level block argument. 
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index bf8e265dbb8..0f587701036 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -235,9 +235,9 @@ static bool isContiguousAccess(Value &iv, LoadOrStoreOp memoryOp, static_assert(std::is_same::value || std::is_same::value, "Must be called on either const LoadOp & or const StoreOp &"); - auto memRefType = memoryOp->getMemRefType(); + auto memRefType = memoryOp.getMemRefType(); if (fastestVaryingDim >= memRefType.getRank()) { - memoryOp->emitError("fastest varying dim out of bounds"); + memoryOp.emitError("fastest varying dim out of bounds"); return false; } @@ -249,10 +249,10 @@ static bool isContiguousAccess(Value &iv, LoadOrStoreOp memoryOp, (layoutMap.size() == 1 && !(layoutMap[0] == b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) { - return memoryOp->emitError("NYI: non-trivial layoutMap"), false; + return memoryOp.emitError("NYI: non-trivial layoutMap"), false; } - auto indices = memoryOp->getIndices(); + auto indices = memoryOp.getIndices(); auto numIndices = llvm::size(indices); unsigned d = 0; for (auto index : indices) { @@ -268,7 +268,7 @@ static bool isContiguousAccess(Value &iv, LoadOrStoreOp memoryOp, template static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { - auto memRefType = memoryOp->getMemRefType(); + auto memRefType = memoryOp.getMemRefType(); return memRefType.getElementType().template isa(); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2ac4ee9000f..a564592b2dd 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -384,7 +384,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, LLVM_DEBUG(region.getConstraints()->dump()); bool outOfBounds = false; - unsigned rank = loadOrStoreOp->getMemRefType().getRank(); + unsigned rank = loadOrStoreOp.getMemRefType().getRank(); // For each dimension, check for out of bounds. for (unsigned r = 0; r < rank; r++) { @@ -394,7 +394,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, // of upper and out of lower), and check if the constraint system is // feasible. If it is, there is at least one point out of bounds. SmallVector ineq(rank + 1, 0); - int64_t dimSize = loadOrStoreOp->getMemRefType().getDimSize(r); + int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r); // TODO(bondhugula): handle dynamic dim sizes. if (dimSize == -1) continue; @@ -403,7 +403,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, ucst.addConstantLowerBound(r, dimSize); outOfBounds = !ucst.isEmpty(); if (outOfBounds && emitError) { - loadOrStoreOp->emitOpError( + loadOrStoreOp.emitOpError( "memref out of upper bound access along dimension #" + Twine(r + 1)); } @@ -414,7 +414,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, lcst.addConstantUpperBound(r, -1); outOfBounds = !lcst.isEmpty(); if (outOfBounds && emitError) { - loadOrStoreOp->emitOpError( + loadOrStoreOp.emitOpError( "memref out of lower bound access along dimension #" + Twine(r + 1)); } } @@ -622,21 +622,21 @@ AffineForOp mlir::insertBackwardComputationSlice( // opinst from 'loadOrStoreOpInst'. 
MemRefAccess::MemRefAccess(Instruction *loadOrStoreOpInst) { if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { - memref = loadOp->getMemRef(); + memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; - auto loadMemrefType = loadOp->getMemRefType(); + auto loadMemrefType = loadOp.getMemRefType(); indices.reserve(loadMemrefType.getRank()); - for (auto *index : loadOp->getIndices()) { + for (auto *index : loadOp.getIndices()) { indices.push_back(index); } } else { assert(loadOrStoreOpInst->isa() && "load/store op expected"); auto storeOp = loadOrStoreOpInst->dyn_cast(); opInst = loadOrStoreOpInst; - memref = storeOp->getMemRef(); - auto storeMemrefType = storeOp->getMemRefType(); + memref = storeOp.getMemRef(); + auto storeMemrefType = storeOp.getMemRefType(); indices.reserve(storeMemrefType.getRank()); - for (auto *index : storeOp->getIndices()) { + for (auto *index : storeOp.getIndices()) { indices.push_back(index); } } diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 32543c8d975..d167da38b7b 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -171,12 +171,12 @@ AffineMap mlir::makePermutationMap( } if (auto load = opInst->dyn_cast()) { - return ::makePermutationMap(opInst->getContext(), load->getIndices(), + return ::makePermutationMap(opInst->getContext(), load.getIndices(), enclosingLoopToVectorDim); } auto store = opInst->cast(); - return ::makePermutationMap(opInst->getContext(), store->getIndices(), + return ::makePermutationMap(opInst->getContext(), store.getIndices(), enclosingLoopToVectorDim); } @@ -194,10 +194,10 @@ bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, bool mustDivide = false; VectorType superVectorType; if (auto read = opInst.dyn_cast()) { - superVectorType = read->getResultType(); + superVectorType = read.getResultType(); mustDivide = true; } else if (auto write = opInst.dyn_cast()) { - superVectorType = write->getVectorType(); + superVectorType = write.getVectorType(); mustDivide = true; } else if (opInst.getNumResults() == 0) { if (!opInst.isa()) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 5cf5cb6cfff..191b789dec6 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -69,7 +69,7 @@ MLIRContext *mlir::edsc::ScopedContext::getContext() { mlir::edsc::ValueHandle::ValueHandle(index_t cst) { auto *b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - v = b->create(loc, cst.v)->getResult(); + v = b->create(loc, cst.v).getResult(); t = v->getType(); } @@ -393,7 +393,7 @@ static ValueHandle createComparisonExpr(CmpIPredicate predicate, auto op = ScopedContext::getBuilder()->create( ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); - return ValueHandle(op->getResult()); + return ValueHandle(op.getResult()); } ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index e9605fa2bfe..6bf460a5d24 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -450,7 +450,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { if (numResults == 0) return {}; if (numResults == 1) - return {newOp->getInstruction()->getResult(0)}; + return {newOp.getInstruction()->getResult(0)}; // Otherwise, it had been converted to an operation producing a structure. 
// Extract individual results from the structure and return them as list. @@ -460,7 +460,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { auto type = TypeConverter::convert(op->getResult(i)->getType(), this->dialect.getLLVMModule()); results.push_back(rewriter.create( - op->getLoc(), type, newOp->getInstruction()->getResult(0), + op->getLoc(), type, newOp.getInstruction()->getResult(0), this->getIntegerArrayAttr(rewriter, i))); } return results; @@ -546,19 +546,19 @@ struct AllocOpLowering : public LLVMLegalizationPattern { if (!LLVMLegalizationPattern::match(op)) return matchFailure(); auto allocOp = op->cast(); - MemRefType type = allocOp->getType(); + MemRefType type = allocOp.getType(); return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); } SmallVector rewrite(Instruction *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto allocOp = op->cast(); - MemRefType type = allocOp->getType(); + MemRefType type = allocOp.getType(); // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. SmallVector sizes; - auto numOperands = allocOp->getNumOperands(); + auto numOperands = allocOp.getNumOperands(); sizes.reserve(numOperands); unsigned i = 0; for (int64_t s : type.getShape()) @@ -607,7 +607,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { .create(op->getLoc(), getVoidPtrType(), rewriter.getFunctionAttr(mallocFunc), cumulativeSize) - ->getResult(0); + .getResult(0); auto structElementType = TypeConverter::convert(elementType, getModule()); auto elementPtrType = LLVM::LLVMType::get( op->getContext(), structElementType.cast() @@ -688,8 +688,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { return matchFailure(); auto memRefCastOp = op->cast(); MemRefType sourceType = - memRefCastOp->getOperand()->getType().cast(); - MemRefType targetType = memRefCastOp->getType(); + memRefCastOp.getOperand()->getType().cast(); + MemRefType targetType = memRefCastOp.getType(); return (isSupportedMemRefType(targetType) && isSupportedMemRefType(sourceType)) ? matchSuccess() @@ -699,8 +699,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { SmallVector rewrite(Instruction *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto memRefCastOp = op->cast(); - auto targetType = memRefCastOp->getType(); - auto sourceType = memRefCastOp->getOperand()->getType().cast(); + auto targetType = memRefCastOp.getType(); + auto sourceType = memRefCastOp.getOperand()->getType().cast(); // Copy the data buffer pointer. auto elementTypePtr = @@ -767,7 +767,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { if (!LLVMLegalizationPattern::match(op)) return this->matchFailure(); auto dimOp = op->cast(); - MemRefType type = dimOp->getOperand()->getType().cast(); + MemRefType type = dimOp.getOperand()->getType().cast(); return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); } @@ -775,11 +775,11 @@ struct DimOpLowering : public LLVMLegalizationPattern { FuncBuilder &rewriter) const override { assert(operands.size() == 1 && "expected exactly one operand"); auto dimOp = op->cast(); - MemRefType type = dimOp->getOperand()->getType().cast(); + MemRefType type = dimOp.getOperand()->getType().cast(); SmallVector results; auto shape = type.getShape(); - uint64_t index = dimOp->getIndex(); + uint64_t index = dimOp.getIndex(); // Extract dynamic size from the memref descriptor and define static size // as a constant. 
if (shape[index] == -1) { @@ -814,7 +814,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { if (!LLVMLegalizationPattern::match(op)) return this->matchFailure(); auto loadOp = op->cast(); - MemRefType type = loadOp->getMemRefType(); + MemRefType type = loadOp.getMemRefType(); return isSupportedMemRefType(type) ? this->matchSuccess() : this->matchFailure(); } @@ -918,7 +918,7 @@ struct LoadOpLowering : public LoadStoreOpLowering { SmallVector rewrite(Instruction *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto loadOp = op->cast(); - auto type = loadOp->getMemRefType(); + auto type = loadOp.getMemRefType(); Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(), operands.drop_front(), rewriter, getModule()); @@ -940,7 +940,7 @@ struct StoreOpLowering : public LoadStoreOpLowering { SmallVector rewrite(Instruction *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto storeOp = op->cast(); - auto type = storeOp->getMemRefType(); + auto type = storeOp.getMemRefType(); Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1], operands.drop_front(2), rewriter, getModule()); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 50db72faea1..228a4a5acc4 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -135,7 +135,7 @@ struct MemRefCastFolder : public RewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningInst()) if (auto cast = memref->dyn_cast()) - op->setOperand(i, cast->getOperand()); + op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } }; @@ -293,7 +293,7 @@ struct SimplifyAllocConst : public RewritePattern { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. - for (auto *operand : alloc->getOperands()) + for (auto *operand : alloc.getOperands()) if (matchPattern(operand, m_ConstantIndex())) return matchSuccess(); return matchFailure(); @@ -301,7 +301,7 @@ struct SimplifyAllocConst : public RewritePattern { void rewrite(Instruction *op, PatternRewriter &rewriter) const override { auto allocOp = op->cast(); - auto memrefType = allocOp->getType(); + auto memrefType = allocOp.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. @@ -318,7 +318,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); + auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningInst(); ConstantIndexOp constantIndexOp; if (defOp && (constantIndexOp = defOp->dyn_cast())) { // Dynamic shape dimension will be folded. @@ -328,7 +328,7 @@ struct SimplifyAllocConst : public RewritePattern { } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); - newOperands.push_back(allocOp->getOperand(dynamicDimPos)); + newOperands.push_back(allocOp.getOperand(dynamicDimPos)); } dynamicDimPos++; } @@ -341,10 +341,10 @@ struct SimplifyAllocConst : public RewritePattern { // Create and insert the alloc op for the new memref. auto newAlloc = - rewriter.create(allocOp->getLoc(), newMemRefType, newOperands); + rewriter.create(allocOp.getLoc(), newMemRefType, newOperands); // Insert a cast so we have the same type as the old alloc. 
- auto resultCast = rewriter.create(allocOp->getLoc(), newAlloc, - allocOp->getType()); + auto resultCast = rewriter.create(allocOp.getLoc(), newAlloc, + allocOp.getType()); rewriter.replaceOp(op, {resultCast}, droppedOperands); } @@ -360,7 +360,7 @@ struct SimplifyDeadAlloc : public RewritePattern { PatternRewriter &rewriter) const override { // Check if the alloc'ed value has any uses. auto alloc = op->cast(); - if (!alloc->use_empty()) + if (!alloc.use_empty()) return matchFailure(); // If it doesn't, we can eliminate it. @@ -493,7 +493,7 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { // Check that the callee is a constant operation. Attribute callee; - if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee))) + if (!matchPattern(indirectCall.getCallee(), m_Constant(&callee))) return matchFailure(); // Check that the constant callee is a function. @@ -502,7 +502,7 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { return matchFailure(); // Replace with a direct call. - SmallVector callOperands(indirectCall->getArgOperands()); + SmallVector callOperands(indirectCall.getArgOperands()); rewriter.replaceOpWithNewOp(op, calledFn.getValue(), callOperands); return matchSuccess(); } @@ -803,7 +803,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { auto condbr = op->cast(); // Check that the condition is a constant. - if (!matchPattern(condbr->getCondition(), m_Op())) + if (!matchPattern(condbr.getCondition(), m_Op())) return matchFailure(); Block *foldedDest; @@ -812,14 +812,13 @@ struct SimplifyConstCondBranchPred : public RewritePattern { // If the condition is known to evaluate to false we fold to a branch to the // false destination. Otherwise, we fold to a branch to the true // destination. - if (matchPattern(condbr->getCondition(), m_Zero())) { - foldedDest = condbr->getFalseDest(); - branchArgs.assign(condbr->false_operand_begin(), - condbr->false_operand_end()); + if (matchPattern(condbr.getCondition(), m_Zero())) { + foldedDest = condbr.getFalseDest(); + branchArgs.assign(condbr.false_operand_begin(), + condbr.false_operand_end()); } else { - foldedDest = condbr->getTrueDest(); - branchArgs.assign(condbr->true_operand_begin(), - condbr->true_operand_end()); + foldedDest = condbr.getTrueDest(); + branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); } rewriter.replaceOpWithNewOp(op, foldedDest, branchArgs); @@ -1095,7 +1094,7 @@ struct SimplifyDeadDealloc : public RewritePattern { auto dealloc = op->cast(); // Check that the memref operand's defining instruction is an AllocOp. 
- Value *memref = dealloc->getMemRef(); + Value *memref = dealloc.getMemRef(); Instruction *defOp = memref->getDefiningInst(); if (!defOp || !defOp->isa()) return matchFailure(); @@ -1986,15 +1985,15 @@ namespace { /// struct SimplifyXMinusX : public RewritePattern { SimplifyXMinusX(MLIRContext *context) - : RewritePattern(SubIOp::getOperationName(), 10, context) {} + : RewritePattern(SubIOp::getOperationName(), 1, context) {} PatternMatchResult matchAndRewrite(Instruction *op, PatternRewriter &rewriter) const override { auto subi = op->cast(); - if (subi->getOperand(0) != subi->getOperand(1)) + if (subi.getOperand(0) != subi.getOperand(1)) return matchFailure(); - rewriter.replaceOpWithNewOp(op, 0, subi->getType()); + rewriter.replaceOpWithNewOp(op, 0, subi.getType()); return matchSuccess(); } }; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index ef063d036c2..8c4423a9a06 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -56,7 +56,7 @@ void ConstantFold::foldInstruction(Instruction *op) { Attribute operandCst = nullptr; if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast()) - operandCst = operandConstantOp->getValue(); + operandCst = operandConstantOp.getValue(); } operandConstants.push_back(operandCst); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5a1af03d299..c1aa77ed5bd 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -168,12 +168,12 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; if (auto loadOp = opInst->dyn_cast()) { - rank = loadOp->getMemRefType().getRank(); - region->memref = loadOp->getMemRef(); + rank = loadOp.getMemRefType().getRank(); + region->memref = loadOp.getMemRef(); region->setWrite(false); } else if (auto storeOp = opInst->dyn_cast()) { - rank = storeOp->getMemRefType().getRank(); - region->memref = storeOp->getMemRef(); + rank = storeOp.getMemRefType().getRank(); + region->memref = storeOp.getMemRef(); region->setWrite(true); } else { assert(false && "expected load or store op"); @@ -317,7 +317,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, memIndices.push_back(zeroIndex); } else { memIndices.push_back( - top.create(loc, indexVal)->getResult()); + top.create(loc, indexVal).getResult()); } } else { // The coordinate for the start location is just the lower bound along the @@ -345,7 +345,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, // Create the fast memory space buffer just before the 'affine.for' // instruction. - fastMemRef = prologue.create(loc, fastMemRefType)->getResult(); + fastMemRef = prologue.create(loc, fastMemRefType).getResult(); // Record it. fastBufferMap[memref] = fastMemRef; // fastMemRefType is a constant shaped memref. @@ -608,10 +608,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { block->walk(begin, end, [&](Instruction *opInst) { // Gather regions to allocate to buffers in faster memory space. 
if (auto loadOp = opInst->dyn_cast()) { - if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) + if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else if (auto storeOp = opInst->dyn_cast()) { - if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace) + if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { // Neither load nor a store op. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c757ea8e58b..0e0e002c9ad 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -174,7 +174,7 @@ public: unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { - if (memref == loadOpInst->cast()->getMemRef()) + if (memref == loadOpInst->cast().getMemRef()) ++loadOpCount; } return loadOpCount; @@ -184,7 +184,7 @@ public: unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { - if (memref == storeOpInst->cast()->getMemRef()) + if (memref == storeOpInst->cast().getMemRef()) ++storeOpCount; } return storeOpCount; @@ -194,7 +194,7 @@ public: void getStoreOpsForMemref(Value *memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { - if (memref == storeOpInst->cast()->getMemRef()) + if (memref == storeOpInst->cast().getMemRef()) storeOps->push_back(storeOpInst); } } @@ -203,7 +203,7 @@ public: void getLoadOpsForMemref(Value *memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { - if (memref == loadOpInst->cast()->getMemRef()) + if (memref == loadOpInst->cast().getMemRef()) loadOps->push_back(loadOpInst); } } @@ -213,10 +213,10 @@ public: void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { - loadMemrefs.insert(loadOpInst->cast()->getMemRef()); + loadMemrefs.insert(loadOpInst->cast().getMemRef()); } for (auto *storeOpInst : stores) { - auto *memref = storeOpInst->cast()->getMemRef(); + auto *memref = storeOpInst->cast().getMemRef(); if (loadMemrefs.count(memref) > 0) loadAndStoreMemrefSet->insert(memref); } @@ -300,7 +300,7 @@ public: bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { - auto *memref = storeOpInst->cast()->getMemRef(); + auto *memref = storeOpInst->cast().getMemRef(); auto *inst = memref->getDefiningInst(); // Return true if 'memref' is a block argument. if (!inst) @@ -325,7 +325,7 @@ public: Node *node = getNode(id); for (auto *storeOpInst : node->stores) { // Return false if there exist out edges from 'id' on 'memref'. - if (getOutEdgeCount(id, storeOpInst->cast()->getMemRef()) > 0) + if (getOutEdgeCount(id, storeOpInst->cast().getMemRef()) > 0) return false; } return true; @@ -648,12 +648,12 @@ bool MemRefDependenceGraph::init(Function *f) { Node node(nextNodeId++, &inst); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); + auto *memref = opInst->cast().getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); - auto *memref = opInst->cast()->getMemRef(); + auto *memref = opInst->cast().getMemRef(); memrefAccesses[memref].insert(node.id); } forToNodeMap[&inst] = node.id; @@ -662,14 +662,14 @@ bool MemRefDependenceGraph::init(Function *f) { // Create graph node for top-level load op. 
Node node(nextNodeId++, &inst); node.loads.push_back(&inst); - auto *memref = inst.cast()->getMemRef(); + auto *memref = inst.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (auto storeOp = inst.dyn_cast()) { // Create graph node for top-level store op. Node node(nextNodeId++, &inst); node.stores.push_back(&inst); - auto *memref = inst.cast()->getMemRef(); + auto *memref = inst.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (inst.getNumRegions() != 0) { @@ -880,7 +880,7 @@ moveLoadsAccessingMemrefTo(Value *memref, dstLoads->clear(); SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { - if (load->cast()->getMemRef() == memref) + if (load->cast().getMemRef() == memref) dstLoads->push_back(load); else srcLoadsToKeep.push_back(load); @@ -1126,7 +1126,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, // Builder to create constants at the top level. FuncBuilder top(forInst->getFunction()); // Create new memref type based on slice bounds. - auto *oldMemRef = srcStoreOpInst->cast()->getMemRef(); + auto *oldMemRef = srcStoreOpInst->cast().getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); unsigned rank = oldMemRefType.getRank(); @@ -1857,7 +1857,7 @@ public: DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. - auto *memref = loads.back()->cast()->getMemRef(); + auto *memref = loads.back()->cast().getMemRef(); if (visitedMemrefs.count(memref) > 0) continue; visitedMemrefs.insert(memref); @@ -1920,7 +1920,7 @@ public: // Gather 'dstNode' store ops to 'memref'. SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) - if (storeOpInst->cast()->getMemRef() == memref) + if (storeOpInst->cast().getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); unsigned bestDstLoopDepth; @@ -1956,7 +1956,7 @@ public: // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (storeOpInst->cast()->getMemRef() == memref) + if (storeOpInst->cast().getMemRef() == memref) storesForMemref.push_back(storeOpInst); } assert(storesForMemref.size() == 1); @@ -1978,7 +1978,7 @@ public: // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - auto *loadMemRef = loadOpInst->cast()->getMemRef(); + auto *loadMemRef = loadOpInst->cast().getMemRef(); if (visitedMemrefs.count(loadMemRef) == 0) loads.push_back(loadOpInst); } @@ -2163,7 +2163,7 @@ public: // Check that all stores are to the same memref. DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(storeOpInst->cast()->getMemRef()); + storeMemrefs.insert(storeOpInst->cast().getMemRef()); } if (storeMemrefs.size() != 1) return; diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 7f6be358189..860d4f3c2de 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -108,14 +108,14 @@ public: /// Used for staging the transfer in a local scalar buffer. MemRefType tmpMemRefType() { - auto vectorType = transfer->getVectorType(); + auto vectorType = transfer.getVectorType(); return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 0); } /// View of tmpMemRefType as one vector, used in vector load/store to tmp /// buffer. 
MemRefType vectorMemRefType() { - return MemRefType::get({1}, transfer->getVectorType(), {}, 0); + return MemRefType::get({1}, transfer.getVectorType(), {}, 0); } /// Performs the rewrite. void rewrite(); @@ -137,12 +137,12 @@ void coalesceCopy(VectorTransferOpTy transfer, edsc::VectorView *vectorView) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. - auto remoteRank = transfer->getMemRefType().getRank(); + auto remoteRank = transfer.getMemRefType().getRank(); // Iterate over the results expressions of the permutation map to determine // the loop order for creating pointwise copies between remote and local // memories. int coalescedIdx = -1; - auto exprs = transfer->getPermutationMap().getResults(); + auto exprs = transfer.getPermutationMap().getResults(); for (auto en : llvm::enumerate(exprs)) { auto dim = en.value().template dyn_cast(); if (!dim) { @@ -173,7 +173,7 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view, using edsc::intrinsics::select; IndexHandle zero(index_t(0)), one(index_t(1)); - llvm::SmallVector memRefAccess(transfer->getIndices()); + llvm::SmallVector memRefAccess(transfer.getIndices()); llvm::SmallVector clippedScalarAccessExprs( memRefAccess.size(), edsc::IndexHandle()); @@ -183,7 +183,7 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view, ++memRefDim) { // Linear search on a small number of entries. int loopIndex = -1; - auto exprs = transfer->getPermutationMap().getResults(); + auto exprs = transfer.getPermutationMap().getResults(); for (auto en : llvm::enumerate(exprs)) { auto expr = en.value(); auto dim = expr.template dyn_cast(); @@ -267,11 +267,11 @@ template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc::intrinsics; // 1. Setup all the captures. - ScopedContext scope(FuncBuilder(transfer->getInstruction()), - transfer->getLoc()); - IndexedValue remote(transfer->getMemRef()); - MemRefView view(transfer->getMemRef()); - VectorView vectorView(transfer->getVector()); + ScopedContext scope(FuncBuilder(transfer.getInstruction()), + transfer.getLoc()); + IndexedValue remote(transfer.getMemRef()); + MemRefView view(transfer.getMemRef()); + VectorView vectorView(transfer.getVector()); SmallVector ivs = IndexHandle::makeIndexHandles(vectorView.rank()); SmallVector pivs = @@ -294,8 +294,8 @@ template <> void VectorTransferRewriter::rewrite() { (dealloc(tmp)); // vexing parse // 3. Propagate. - transfer->replaceAllUsesWith(vectorValue.getValue()); - transfer->erase(); + transfer.replaceAllUsesWith(vectorValue.getValue()); + transfer.erase(); } /// Lowers VectorTransferWriteOp into a combination of: @@ -322,12 +322,12 @@ template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc::intrinsics; // 1. Setup all the captures. - ScopedContext scope(FuncBuilder(transfer->getInstruction()), - transfer->getLoc()); - IndexedValue remote(transfer->getMemRef()); - MemRefView view(transfer->getMemRef()); - ValueHandle vectorValue(transfer->getVector()); - VectorView vectorView(transfer->getVector()); + ScopedContext scope(FuncBuilder(transfer.getInstruction()), + transfer.getLoc()); + IndexedValue remote(transfer.getMemRef()); + MemRefView view(transfer.getMemRef()); + ValueHandle vectorValue(transfer.getVector()); + VectorView vectorView(transfer.getVector()); SmallVector ivs = IndexHandle::makeIndexHandles(vectorView.rank()); SmallVector pivs = @@ -349,7 +349,7 @@ template <> void VectorTransferRewriter::rewrite() { }); (dealloc(tmp)); // vexing parse... 
- transfer->erase(); + transfer.erase(); } namespace { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index ebdb0c8e83e..cca0c889daa 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -447,7 +447,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, std::is_same::value || std::is_same::value, "Must be called on a VectorTransferOp"); - auto superVectorType = transfer->getVectorType(); + auto superVectorType = transfer.getVectorType(); auto optionalRatio = shapeRatio(superVectorType, hwVectorType); assert(optionalRatio && (optionalRatio->size() == superVectorType.getShape().size()) && @@ -465,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, ++dim; }, superVectorType.getShape(), *optionalRatio); - auto permutationMap = transfer->getPermutationMap(); + auto permutationMap = transfer.getPermutationMap(); LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: ")); if (keep.empty()) { return permutationMap; @@ -486,17 +486,17 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp read, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { SmallVector indices = - map(makePtrDynCaster(), read->getIndices()); + map(makePtrDynCaster(), read.getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto map = projectedPermutationMap(read, hwVectorType); if (!map) { return nullptr; } - auto cloned = b->create( - read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, map, - read->getPaddingValue()); - return cloned->getInstruction(); + auto cloned = b->create(read.getLoc(), hwVectorType, + read.getMemRef(), affineIndices, + map, read.getPaddingValue()); + return cloned.getInstruction(); } /// Creates an instantiated version of `write` for the instance of @@ -510,15 +510,15 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { SmallVector indices = - map(makePtrDynCaster(), write->getIndices()); + map(makePtrDynCaster(), write.getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b->create( - write->getLoc(), - substitute(write->getVector(), hwVectorType, substitutionsMap), - write->getMemRef(), affineIndices, + write.getLoc(), + substitute(write.getVector(), hwVectorType, substitutionsMap), + write.getMemRef(), affineIndices, projectedPermutationMap(write, hwVectorType)); - return cloned->getInstruction(); + return cloned.getInstruction(); } /// Returns `true` if inst instance is properly cloned and inserted, false @@ -568,7 +568,7 @@ static bool instantiateMaterialization(Instruction *inst, return true; } state->substitutionsMap->insert( - std::make_pair(read->getResult(), clone->getResult(0))); + std::make_pair(read.getResult(), clone->getResult(0))); return false; } // The only op with 0 results reaching this point must, by construction, be @@ -712,7 +712,7 @@ static bool materialize(Function *f, // Emit the current slice. // Set scoped super-vector and corresponding hw vector types. 
- state->superVectorType = terminator->getVectorType(); + state->superVectorType = terminator.getVectorType(); assert((state->superVectorType.getElementType() == FloatType::getF32(term->getContext())) && "Only f32 supported for now"); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index a7045b3b541..0356032b46a 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -95,18 +95,18 @@ FunctionPassBase *mlir::createMemRefDataFlowOptPass() { // this in the future if needed. void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { Instruction *lastWriteStoreOp = nullptr; - Instruction *loadOpInst = loadOp->getInstruction(); + Instruction *loadOpInst = loadOp.getInstruction(); // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); - for (InstOperand &use : loadOp->getMemRef()->getUses()) { + for (InstOperand &use : loadOp.getMemRef()->getUses()) { auto storeOp = use.getOwner()->dyn_cast(); if (!storeOp) continue; - auto *storeOpInst = storeOp->getInstruction(); + auto *storeOpInst = storeOp.getInstruction(); unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); minSurroundingLoops = std::min(nsLoops, minSurroundingLoops); storeOps.push_back(storeOpInst); @@ -169,7 +169,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { MemRefRegion region(loadOpInst->getLoc()); region.compute(loadOpInst, nsLoops); if (!region.getConstraints()->isRangeOneToOne( - /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank())) + /*start=*/0, /*limit=*/loadOp.getMemRefType().getRank())) break; } @@ -201,10 +201,10 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { return; // Perform the actual store to load forwarding. - Value *storeVal = lastWriteStoreOp->cast()->getValueToStore(); - loadOp->getResult()->replaceAllUsesWith(storeVal); + Value *storeVal = lastWriteStoreOp->cast().getValueToStore(); + loadOp.getResult()->replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. - memrefsToErase.insert(loadOp->getMemRef()); + memrefsToErase.insert(loadOp.getMemRef()); // Record this to erase later. loadOpsToErase.push_back(loadOpInst); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index a92e2d5960c..520b9e69744 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -125,7 +125,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { /*domInstFilter=*/&*forOp.getBody()->begin())) { LLVM_DEBUG( forOp.emitError("memref replacement for double buffering failed")); - ivModTwoOp->getInstruction()->erase(); + ivModTwoOp.erase(); return false; } // Insert the dealloc op right after the for loop. @@ -152,10 +152,10 @@ void PipelineDataTransfer::runOnFunction() { // Check if tags of the dma start op and dma wait op match. 
static bool checkTagMatch(DmaStartOp startOp, DmaWaitOp waitOp) { - if (startOp->getTagMemRef() != waitOp->getTagMemRef()) + if (startOp.getTagMemRef() != waitOp.getTagMemRef()) return false; - auto startIndices = startOp->getTagIndices(); - auto waitIndices = waitOp->getTagIndices(); + auto startIndices = startOp.getTagIndices(); + auto waitIndices = waitOp.getTagIndices(); // Both of these have the same number of indices since they correspond to the // same tag memref. for (auto it = startIndices.begin(), wIt = waitIndices.begin(), @@ -182,7 +182,7 @@ static void findMatchingStartFinishInsts( SmallVector outgoingDmaOps; for (auto &inst : *forOp.getBody()) { auto dmaStartOp = inst.dyn_cast(); - if (dmaStartOp && dmaStartOp->isSrcMemorySpaceFaster()) + if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -199,7 +199,7 @@ static void findMatchingStartFinishInsts( // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. - if (!dmaStartOp->isDestMemorySpaceFaster()) + if (!dmaStartOp.isDestMemorySpaceFaster()) continue; // Check for dependence with outgoing DMAs. Doing this conservatively. @@ -207,14 +207,14 @@ static void findMatchingStartFinishInsts( // dependences between an incoming and outgoing DMA in the same iteration. auto it = outgoingDmaOps.begin(); for (; it != outgoingDmaOps.end(); ++it) { - if ((*it)->getDstMemRef() == dmaStartOp->getSrcMemRef()) + if (it->getDstMemRef() == dmaStartOp.getSrcMemRef()) break; } if (it != outgoingDmaOps.end()) continue; // We only double buffer if the buffer is not live out of loop. - auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); + auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { // We can double buffer regardless of dealloc's outside the loop. @@ -272,7 +272,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( - dmaStartInst->cast()->getFasterMemPos()); + dmaStartInst->cast().getFasterMemPos()); if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a2d6f392c32..fd5a5843d5b 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -158,18 +158,18 @@ void GreedyPatternRewriteDriver::simplifyFunction() { if (auto constant = op->dyn_cast()) { // If this constant is dead, remove it, being careful to keep // uniquedConstants up to date. 
- if (constant->use_empty()) { + if (constant.use_empty()) { auto it = - uniquedConstants.find({constant->getValue(), constant->getType()}); + uniquedConstants.find({constant.getValue(), constant.getType()}); if (it != uniquedConstants.end() && it->second == op) uniquedConstants.erase(it); - constant->erase(); + constant.erase(); continue; } // Check to see if we already have a constant with this type and value: - auto &entry = uniquedConstants[std::make_pair(constant->getValue(), - constant->getType())]; + auto &entry = uniquedConstants[std::make_pair(constant.getValue(), + constant.getType())]; if (entry) { // If this constant is already our uniqued one, then leave it alone. if (entry == op) @@ -178,8 +178,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { // Otherwise replace this redundant constant with the uniqued one. We // know this is safe because we move constants to the top of the // function when they are uniqued, so we know they dominate all uses. - constant->replaceAllUsesWith(entry->getResult(0)); - constant->erase(); + constant.replaceAllUsesWith(entry->getResult(0)); + constant.erase(); continue; } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 362cad352fb..2d12fe66d4f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -819,8 +819,7 @@ template static LogicalResult vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, VectorizationState *state) { - auto memRefType = - memoryOp->getMemRef()->getType().template cast(); + auto memRefType = memoryOp.getMemRef()->getType().template cast(); auto elementType = memRefType.getElementType(); // TODO(ntv): ponder whether we want to further vectorize a vector value. @@ -829,7 +828,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. - auto *opInst = memoryOp->getInstruction(); + auto *opInst = memoryOp.getInstruction(); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. 
// TODO(ntv): increase the expressiveness power of vector_transfer operations @@ -841,9 +840,9 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, LLVM_DEBUG(permutationMap.print(dbgs())); FuncBuilder b(opInst); auto transfer = b.create( - opInst->getLoc(), vectorType, memoryOp->getMemRef(), - map(makePtrDynCaster(), memoryOp->getIndices()), permutationMap); - state->registerReplacement(opInst, transfer->getInstruction()); + opInst->getLoc(), vectorType, memoryOp.getMemRef(), + map(makePtrDynCaster(), memoryOp.getIndices()), permutationMap); + state->registerReplacement(opInst, transfer.getInstruction()); } else { state->registerTerminal(opInst); } @@ -1041,10 +1040,10 @@ static Instruction *vectorizeOneInstruction(Instruction *opInst, "vector_transfer_write cannot be further vectorized"); if (auto store = opInst->dyn_cast()) { - auto *memRef = store->getMemRef(); - auto *value = store->getValueToStore(); + auto *memRef = store.getMemRef(); + auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); - auto indices = map(makePtrDynCaster(), store->getIndices()); + auto indices = map(makePtrDynCaster(), store.getIndices()); FuncBuilder b(opInst); auto permutationMap = makePermutationMap(opInst, state->strategy->loopToVectorDim); @@ -1052,7 +1051,7 @@ static Instruction *vectorizeOneInstruction(Instruction *opInst, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( opInst->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = transfer->getInstruction(); + auto *res = transfer.getInstruction(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminals" (i.e. StoreOps) are erased on the spot. opInst->erase(); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index dfe27741490..cba0b67a2ce 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -130,15 +130,15 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { bool isVariadicArg = isVariadicArgumentName(op, name); if (isOperandName(op, name)) { auto result = isVariadicArg - ? formatv("lookupValues(op->{0}())", name) - : formatv("valueMapping.lookup(op->{0}())", name); + ? formatv("lookupValues(op.{0}())", name) + : formatv("valueMapping.lookup(op.{0}())", name); bs << result; } else if (isAttributeName(op, name)) { - bs << formatv("op->{0}()", name); + bs << formatv("op.{0}()", name); } else if (isResultName(op, name)) { - bs << formatv("valueMapping[op->{0}()]", name); + bs << formatv("valueMapping[op.{0}()]", name); } else if (name == "_resultType") { - bs << "op->getResult()->getType().cast()." + bs << "op.getResult()->getType().cast()." "getUnderlyingType()"; } else if (name == "_hasResult") { bs << "inst.getNumResults() == 1"; -- cgit v1.2.3 From 46ade282c8d98558d0d1b8e79d2eee3ae00086f1 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Mon, 25 Mar 2019 18:02:49 -0700 Subject: Make FunctionPass::getFunction() return a reference to the function, instead of a pointer. This makes it consistent with all the other methods in FunctionPass, as well as with ModulePass::getModule(). NFC. 
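As a rough illustration only (this sketch is not part of the patch: the pass name ExamplePass and its body are invented, and FunctionPass is assumed to keep the CRTP form used elsewhere in the tree), a pass body reads as follows before and after the signature change:

// Hedged sketch: "ExamplePass" is hypothetical and only illustrates the
// getFunction() change described in the commit message above.
struct ExamplePass : public FunctionPass<ExamplePass> {
  void runOnFunction() override {
    // Before: getFunction() returned a Function *, so callers dereferenced it:
    //   Function *f = getFunction();
    //   f->walk([](Instruction *inst) { /* ... */ });
    // After: getFunction() returns a Function &, matching ModulePass::getModule().
    Function &f = getFunction();
    f.walk([](Instruction *inst) { /* ... */ });
  }
};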
PiperOrigin-RevId: 240257910 --- mlir/include/mlir/IR/Builders.h | 5 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/Pass/Pass.h | 4 +- mlir/include/mlir/Transforms/ViewFunctionGraph.h | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 4 +- mlir/lib/EDSC/LowerEDSCTestPass.cpp | 2 +- mlir/lib/Pass/Pass.cpp | 2 +- mlir/lib/Transforms/CSE.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 4 +- mlir/lib/Transforms/ConstantFold.cpp | 2 +- mlir/lib/Transforms/DmaGeneration.cpp | 6 +- mlir/lib/Transforms/LoopFusion.cpp | 8 +-- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnroll.cpp | 6 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 5 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 6 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 8 +-- .../Utils/GreedyPatternRewriteDriver.cpp | 10 +-- .../Vectorization/VectorizerTestPass.cpp | 73 ++++++++++++---------- mlir/lib/Transforms/Vectorize.cpp | 6 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +- 28 files changed, 95 insertions(+), 84 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index baf71879afd..fbb8ff9cd62 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -177,13 +177,16 @@ class FuncBuilder : public Builder { public: /// Create a function builder and set the insertion point to the start of /// the function. - FuncBuilder(Function *func) : Builder(func->getContext()), function(func) { + explicit FuncBuilder(Function *func) + : Builder(func->getContext()), function(func) { if (!func->empty()) setInsertionPoint(&func->front(), func->front().begin()); else clearInsertionPoint(); } + explicit FuncBuilder(Function &func) : FuncBuilder(&func) {} + /// Create a function builder and set insertion point to the given /// instruction, which will cause subsequent insertions to go right before it. FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e6b9551339e..2e8aba2aedd 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -350,7 +350,7 @@ private: /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. /// -void applyPatternsGreedily(Function *fn, OwningRewritePatternList &&patterns); +void applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index f4fc6b80eff..53629e0f127 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -104,8 +104,8 @@ protected: virtual void runOnFunction() = 0; /// Return the current function being transformed. - Function *getFunction() { - return getPassState().irAndPassFailed.getPointer(); + Function &getFunction() { + return *getPassState().irAndPassFailed.getPointer(); } /// Returns the current pass state. 
diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h index f56003b2939..c1da5ef9638 100644 --- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h +++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h @@ -37,7 +37,7 @@ void viewGraph(Function &function, const Twine &name, bool shortNames = false, const Twine &title = "", llvm::GraphProgram::Name program = llvm::GraphProgram::DOT); -llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function *function, +llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index b90a799b794..8edf79d6db3 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { } void MemRefBoundCheck::runOnFunction() { - getFunction()->walk([](Instruction *opInst) { + getFunction().walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 87267183a5f..8e438108bce 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -113,7 +113,7 @@ static void checkDependences(ArrayRef loadsAndStores) { void MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); - getFunction()->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa()) loadsAndStores.push_back(inst); }); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index af112e5b02c..701ef6ab348 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,9 +43,9 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. void TestParallelismDetection::runOnFunction() { - Function *f = getFunction(); + Function &f = getFunction(); FuncBuilder b(f); - f->walk([&](AffineForOp forOp) { + f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) forOp.emitNote("parallel loop"); }); diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 94e94bf48f9..8604de1f4b8 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -40,7 +40,7 @@ struct LowerEDSCTestPass : public FunctionPass { #include "mlir/EDSC/reference-impl.inc" void LowerEDSCTestPass::runOnFunction() { - getFunction()->walk([](Instruction *op) { + getFunction().walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType("op"); if (!opName) { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index fe114f09d77..71b060dd95d 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -153,7 +153,7 @@ namespace { /// Pass to verify a function and signal failure if necessary. 
class FunctionVerifier : public FunctionPass { void runOnFunction() { - if (getFunction()->verify()) + if (getFunction().verify()) signalPassFailure(); markAllAnalysesPreserved(); } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 05760f18761..ee0a10b2f5d 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -226,7 +226,7 @@ void CSE::simplifyRegion(DominanceInfo &domInfo, Region ®ion) { } void CSE::runOnFunction() { - simplifyRegion(getAnalysis(), getFunction()->getBody()); + simplifyRegion(getAnalysis(), getFunction().getBody()); // If no operations were erased, then we mark all analyses as preserved. if (opsToErase.empty()) { diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 77244264cda..54579759058 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,12 +40,12 @@ struct Canonicalizer : public FunctionPass { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto *func = getFunction(); + auto &func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - auto *context = func->getContext(); + auto *context = func.getContext(); for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 8c4423a9a06..ece87ce6b6c 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -97,7 +97,7 @@ void ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - getFunction()->walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index c1aa77ed5bd..e20472770ae 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -754,16 +754,16 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } void DmaGeneration::runOnFunction() { - Function *f = getFunction(); + Function &f = getFunction(); FuncBuilder topBuilder(f); - zeroIndex = topBuilder.create(f->getLoc(), 0); + zeroIndex = topBuilder.create(f.getLoc(), 0); // Override default is a command line option is provided. if (clFastMemoryCapacity.getNumOccurrences() > 0) { fastMemCapacityBytes = clFastMemoryCapacity * 1024; } - for (auto &block : *f) + for (auto &block : f) runOnBlock(&block); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0e0e002c9ad..df5005bc7b1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function *f); + bool init(Function &f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -627,15 +627,15 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. 
-bool MemRefDependenceGraph::init(Function *f) { +bool MemRefDependenceGraph::init(Function &f) { DenseMap> memrefAccesses; // TODO: support multi-block functions. - if (f->getBlocks().size() != 1) + if (f.getBlocks().size() != 1) return false; DenseMap forToNodeMap; - for (auto &inst : f->front()) { + for (auto &inst : f.front()) { if (auto forOp = inst.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2dbdf689f02..eafa7bca4d4 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -256,7 +256,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function *f, +static void getTileableBands(Function &f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). @@ -270,7 +270,7 @@ static void getTileableBands(Function *f, bands->push_back(band); }; - for (auto &block : *f) + for (auto &block : f) for (auto &inst : block) if (auto forOp = inst.dyn_cast()) getMaximalPerfectLoopNest(forOp); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 173a171e589..5687c6126d1 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -128,7 +128,7 @@ void LoopUnroll::runOnFunction() { // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). - getFunction()->walkPostOrder([&](AffineForOp forOp) { + getFunction().walkPostOrder([&](AffineForOp forOp) { Optional tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); @@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function *func = getFunction(); + Function &func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(func); + ilg.walkPostOrder(&func); auto &loops = ilg.loops; if (loops.empty()) break; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 0822ddf37e3..174f93e4d2d 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -91,7 +91,7 @@ void LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. - auto &entryBlock = getFunction()->front(); + auto &entryBlock = getFunction().front(); if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 93197c30cb2..162eed00b6c 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -609,7 +609,7 @@ void LowerAffinePass::runOnFunction() { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. 
// We do this as a prepass to avoid invalidating the walker with our rewrite. - getFunction()->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa() || inst->isa() || inst->isa()) instsToRewrite.push_back(inst); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 860d4f3c2de..e6b1950c222 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -43,7 +43,6 @@ #include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" -/// /// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a /// proper abstraction for the hardware. /// @@ -376,9 +375,9 @@ public: struct LowerVectorTransfersPass : public FunctionPass { void runOnFunction() { - Function *f = getFunction(); + auto &f = getFunction(); applyMLPatternsGreedily, - VectorTransferExpander>(f); + VectorTransferExpander>(&f); } }; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index cca0c889daa..a4deba26d83 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -733,7 +733,7 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = getFunction(); + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 0356032b46a..e1e253d1869 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -211,8 +211,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function *f = getFunction(); - if (f->getBlocks().size() != 1) { + Function &f = getFunction(); + if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; } @@ -224,7 +224,7 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f->walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); + f.walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 520b9e69744..051ac733c14 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -144,7 +144,7 @@ void PipelineDataTransfer::runOnFunction() { // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). 
forOps.clear(); - getFunction()->walkPostOrder( + getFunction().walkPostOrder( [&](AffineForOp forOp) { forOps.push_back(forOp); }); for (auto forOp : forOps) runOnAffineForOp(forOp); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 47d68461fa5..ab83ede303c 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -93,7 +93,7 @@ FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { void SimplifyAffineStructures::runOnFunction() { simplifiedAttributes.clear(); - getFunction()->walk([&](Instruction *opInst) { + getFunction().walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index f8f90c0cdb1..47244f94ac9 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,12 +29,12 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function *func = getFunction(); - UnknownLoc unknownLoc = UnknownLoc::get(func->getContext()); + Function &func = getFunction(); + UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); // Strip the debug info from the function and its instructions. - func->setLoc(unknownLoc); - func->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func.setLoc(unknownLoc); + func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); } /// Creates a pass to strip debug information from a function. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fd5a5843d5b..e8dce29729d 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -32,14 +32,14 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function *fn, + explicit GreedyPatternRewriteDriver(Function &fn, OwningRewritePatternList &&patterns) - : PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this), - builder(fn) { + : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), + builder(&fn) { worklist.reserve(64); // Add all operations to the worklist. - fn->walk([&](Instruction *inst) { addToWorklist(inst); }); + fn.walk([&](Instruction *inst) { addToWorklist(inst); }); } /// Perform the rewrites. @@ -299,7 +299,7 @@ void GreedyPatternRewriteDriver::simplifyFunction() { /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. 
/// -void mlir::applyPatternsGreedily(Function *fn, +void mlir::applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); driver.simplifyFunction(); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index f57a53d3670..b5109a20ba9 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -87,17 +87,18 @@ struct VectorizerTestPass : public FunctionPass { static constexpr auto kTestAffineMapAttrName = "affine_map"; void runOnFunction() override; - void testVectorShapeRatio(Function *f); - void testForwardSlicing(Function *f); - void testBackwardSlicing(Function *f); - void testSlicing(Function *f); - void testComposeMaps(Function *f); - void testNormalizeMaps(Function *f); + void testVectorShapeRatio(); + void testForwardSlicing(); + void testBackwardSlicing(); + void testSlicing(); + void testComposeMaps(); + void testNormalizeMaps(); }; } // end anonymous namespace -void VectorizerTestPass::testVectorShapeRatio(Function *f) { +void VectorizerTestPass::testVectorShapeRatio() { + auto *f = &getFunction(); using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); @@ -156,7 +157,9 @@ static NestedPattern patternTestSlicingOps() { return Op(filter); } -void VectorizerTestPass::testBackwardSlicing(Function *f) { +void VectorizerTestPass::testBackwardSlicing() { + auto *f = &getFunction(); + SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -171,7 +174,8 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) { } } -void VectorizerTestPass::testForwardSlicing(Function *f) { +void VectorizerTestPass::testForwardSlicing() { + auto *f = &getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -186,7 +190,9 @@ void VectorizerTestPass::testForwardSlicing(Function *f) { } } -void VectorizerTestPass::testSlicing(Function *f) { +void VectorizerTestPass::testSlicing() { + auto *f = &getFunction(); + SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -204,7 +210,9 @@ static bool customOpWithAffineMapAttribute(Instruction &inst) { VectorizerTestPass::kTestAffineMapOpName; } -void VectorizerTestPass::testComposeMaps(Function *f) { +void VectorizerTestPass::testComposeMaps() { + auto *f = &getFunction(); + using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); SmallVector matches; @@ -234,9 +242,11 @@ static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { return app && app.use_empty(); } -void VectorizerTestPass::testNormalizeMaps(Function *f) { +void VectorizerTestPass::testNormalizeMaps() { using matcher::Op; + auto *f = &getFunction(); + // Save matched AffineApplyOp that all need to be erased in the end. auto pattern = Op(affineApplyOp); SmallVector toErase; @@ -264,28 +274,27 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. 
- Function *f = getFunction(); - if (f->getBlocks().size() != 1) + Function &f = getFunction(); + if (f.getBlocks().size() != 1) return; - if (!clTestVectorShapeRatio.empty()) { - testVectorShapeRatio(f); - } - if (clTestForwardSlicingAnalysis) { - testForwardSlicing(f); - } - if (clTestBackwardSlicingAnalysis) { - testBackwardSlicing(f); - } - if (clTestSlicingAnalysis) { - testSlicing(f); - } - if (clTestComposeMaps) { - testComposeMaps(f); - } - if (clTestNormalizeMaps) { - testNormalizeMaps(f); - } + if (!clTestVectorShapeRatio.empty()) + testVectorShapeRatio(); + + if (clTestForwardSlicingAnalysis) + testForwardSlicing(); + + if (clTestBackwardSlicingAnalysis) + testBackwardSlicing(); + + if (clTestSlicingAnalysis) + testSlicing(); + + if (clTestComposeMaps) + testComposeMaps(); + + if (clTestNormalizeMaps) + testNormalizeMaps(); } FunctionPassBase *mlir::createVectorizerTestPass() { diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 2d12fe66d4f..0e0ac1bf2a3 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1227,16 +1227,16 @@ void Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; - Function *f = getFunction(); + Function &f = getFunction(); for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); unsigned patternDepth = pat.getDepth(); SmallVector matches; - pat.match(f, &matches); + pat.match(&f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 834424951bf..46e47a4ab1b 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -61,9 +61,9 @@ void mlir::viewGraph(Function &function, const llvm::Twine &name, llvm::ViewGraph(&function, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function *function, +llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function, bool shortNames, const llvm::Twine &title) { - return llvm::WriteGraph(os, function, shortNames, title); + return llvm::WriteGraph(os, &function, shortNames, title); } void mlir::Function::viewGraph() { -- cgit v1.2.3 From f9d91531df58a561a2c6e197dfb7cb796f7e44e3 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 26 Mar 2019 17:05:09 -0700 Subject: Replace usages of Instruction with Operation in the /IR directory. This is step 2/N to renaming Instruction to Operation. 
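Before the diff body, a minimal sketch of what the rename means for client code (this example is not part of the patch: the helper walkBlockOperands is invented and only uses accessors that appear in the diff below, such as Block::getOperations() and Value::getDefiningOp()):

// Hedged sketch: illustrates the Instruction -> Operation renaming only.
static void walkBlockOperands(Block &block) {
  // Block now exposes an Operation list: getOperations() rather than
  // getInstructions().
  for (Operation &op : block.getOperations()) {
    for (Value *operand : op.getOperands()) {
      // getDefiningOp() replaces getDefiningInst(); it is null when the
      // operand is a block argument.
      if (Operation *def = operand->getDefiningOp())
        (void)def; // placeholder for real processing
    }
  }
}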
PiperOrigin-RevId: 240459216 --- mlir/include/mlir/AffineOps/AffineOps.h | 4 +- mlir/include/mlir/Analysis/Dominance.h | 2 +- mlir/include/mlir/EDSC/Builders.h | 6 +- mlir/include/mlir/IR/Block.h | 134 ++++++------- mlir/include/mlir/IR/Builders.h | 40 ++-- mlir/include/mlir/IR/Dialect.h | 8 +- mlir/include/mlir/IR/Function.h | 18 +- mlir/include/mlir/IR/IntegerSet.h | 4 +- mlir/include/mlir/IR/MLIRContext.h | 4 +- mlir/include/mlir/IR/Matchers.h | 10 +- mlir/include/mlir/IR/OpDefinition.h | 214 ++++++++++----------- mlir/include/mlir/IR/OpImplementation.h | 4 +- mlir/include/mlir/IR/Operation.h | 16 +- mlir/include/mlir/IR/PatternMatch.h | 41 ++-- mlir/include/mlir/IR/UseDefLists.h | 26 ++- mlir/include/mlir/IR/Value.h | 23 ++- mlir/include/mlir/StandardOps/Ops.h | 25 ++- mlir/include/mlir/StandardOps/Ops.td | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 36 ++-- mlir/lib/Analysis/AffineAnalysis.cpp | 10 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 4 +- mlir/lib/Analysis/LoopAnalysis.cpp | 10 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 22 +-- mlir/lib/Analysis/Verifier.cpp | 6 +- mlir/lib/EDSC/Builders.cpp | 8 +- mlir/lib/EDSC/MLIREmitter.cpp | 13 +- mlir/lib/IR/AsmPrinter.cpp | 40 ++-- mlir/lib/IR/Block.cpp | 119 ++++++------ mlir/lib/IR/Builders.cpp | 12 +- mlir/lib/IR/Function.cpp | 9 +- mlir/lib/IR/MLIRContext.cpp | 2 +- mlir/lib/IR/Operation.cpp | 39 ++-- mlir/lib/IR/PatternMatch.cpp | 18 +- mlir/lib/IR/Value.cpp | 12 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 4 +- mlir/lib/Parser/Parser.cpp | 4 +- mlir/lib/StandardOps/Ops.cpp | 36 ++-- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 4 +- mlir/lib/Transforms/ConstantFold.cpp | 4 +- mlir/lib/Transforms/DmaGeneration.cpp | 12 +- mlir/lib/Transforms/LoopFusion.cpp | 58 +++--- mlir/lib/Transforms/LoopTiling.cpp | 27 ++- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 24 +-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 6 +- mlir/lib/Transforms/MaterializeVectors.cpp | 8 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 6 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 12 +- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 38 ++-- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 16 +- mlir/test/EDSC/builder-api-test.cpp | 4 +- mlir/test/mlir-tblgen/op-operand.td | 2 +- mlir/test/mlir-tblgen/op-result.td | 2 +- mlir/test/mlir-tblgen/predicate.td | 8 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 10 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 6 +- 60 files changed, 606 insertions(+), 638 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 4f949231674..65467677bd6 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -144,7 +144,7 @@ public: Block *getBody() { return &getRegion().front(); } /// Get the body region of the AffineForOp. - Region &getRegion() { return getInstruction()->getRegion(0); } + Region &getRegion() { return getOperation()->getRegion(0); } /// Returns the induction variable for this loop. 
Value *getInductionVar(); @@ -253,7 +253,7 @@ public: unsigned getNumOperands() { return opEnd - opStart; } Value *getOperand(unsigned idx) { - return inst.getInstruction()->getOperand(opStart + idx); + return inst.getOperation()->getOperand(opStart + idx); } using operand_iterator = AffineForOp::operand_iterator; diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 4aa8c0463d4..1c3ca02e41c 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -78,7 +78,7 @@ public: /// Return true if instruction A dominates instruction B. bool dominates(Value *a, Instruction *b) { - return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b); + return (Instruction *)a->getDefiningOp() == b || properlyDominates(a, b); } /// Return true if the specified block A dominates block B. diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index b7b3b1f1844..4b0e5b938b5 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -359,7 +359,7 @@ struct InstructionHandle : public CapturableHandle { ArrayRef attributes = {}); operator Instruction *() { return inst; } - Instruction *getInstruction() { return inst; } + Instruction *getOperation() { return inst; } private: Instruction *inst; @@ -421,14 +421,14 @@ InstructionHandle InstructionHandle::create(Args... args) { return InstructionHandle( ScopedContext::getBuilder() ->create(ScopedContext::getLocation(), args...) - .getInstruction()); + .getOperation()); } template ValueHandle ValueHandle::create(Args... args) { Instruction *inst = ScopedContext::getBuilder() ->create(ScopedContext::getLocation(), args...) - .getInstruction(); + .getOperation(); if (inst->getNumResults() == 1) { return ValueHandle(inst->getResult(0)); } else if (inst->getNumResults() == 0) { diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 57a61606719..fd318525e2b 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -28,20 +28,20 @@ #include "llvm/ADT/ilist_node.h" //===----------------------------------------------------------------------===// -// ilist_traits for Instruction +// ilist_traits for Operation //===----------------------------------------------------------------------===// namespace llvm { namespace ilist_detail { -// Explicitly define the node access for the instruction list so that we can -// break the dependence on the Instruction class in this header. This allows for -// instructions to have trailing Regions without a circular include +// Explicitly define the node access for the operation list so that we can +// break the dependence on the Operation class in this header. This allows for +// operations to have trailing Regions without a circular include // dependence. 
template <> struct SpecificNodeAccess< - typename compute_node_options<::mlir::Instruction>::type> : NodeAccess { + typename compute_node_options<::mlir::Operation>::type> : NodeAccess { protected: - using OptionsT = typename compute_node_options::type; + using OptionsT = typename compute_node_options::type; using pointer = typename OptionsT::pointer; using const_pointer = typename OptionsT::const_pointer; using node_type = ilist_node_impl; @@ -54,15 +54,15 @@ protected: }; } // end namespace ilist_detail -template <> struct ilist_traits<::mlir::Instruction> { - using Instruction = ::mlir::Instruction; - using inst_iterator = simple_ilist::iterator; +template <> struct ilist_traits<::mlir::Operation> { + using Operation = ::mlir::Operation; + using op_iterator = simple_ilist::iterator; - static void deleteNode(Instruction *inst); - void addNodeToList(Instruction *inst); - void removeNodeFromList(Instruction *inst); - void transferNodesFromList(ilist_traits &otherList, - inst_iterator first, inst_iterator last); + static void deleteNode(Operation *op); + void addNodeToList(Operation *op); + void removeNodeFromList(Operation *op); + void transferNodesFromList(ilist_traits &otherList, + op_iterator first, op_iterator last); private: mlir::Block *getContainingBlock(); @@ -79,7 +79,7 @@ using BlockOperand = IROperandImpl; class PredecessorIterator; class SuccessorIterator; -/// `Block` represents an ordered list of `Instruction`s. +/// `Block` represents an ordered list of `Operation`s. class Block : public IRObjectWithUseList, public llvm::ilist_node_with_parent { public: @@ -90,18 +90,18 @@ public: // Drop all references from within this block. dropAllReferences(); - // Clear instructions in the reverse order so that uses are destroyed + // Clear operations in the reverse order so that uses are destroyed // before their defs. while (!empty()) - instructions.pop_back(); + operations.pop_back(); } /// Blocks are maintained in a Region. Region *getParent() { return parentValidInstOrderPair.getPointer(); } - /// Returns the closest surrounding instruction that contains this block or + /// Returns the closest surrounding operation that contains this block or /// nullptr if this is a top-level block. - Instruction *getContainingInst(); + Operation *getContainingOp(); /// Returns the function that this block is part of, even if the block is /// nested under an operation region. @@ -145,37 +145,37 @@ public: BlockArgument *getArgument(unsigned i) { return arguments[i]; } //===--------------------------------------------------------------------===// - // Instruction list management + // Operation list management //===--------------------------------------------------------------------===// - /// This is the list of instructions in the block. - using InstListType = llvm::iplist; - InstListType &getInstructions() { return instructions; } + /// This is the list of operations in the block. + using InstListType = llvm::iplist; + InstListType &getOperations() { return operations; } - // Iteration over the instructions in the block. + // Iteration over the operations in the block. 
using iterator = InstListType::iterator; using reverse_iterator = InstListType::reverse_iterator; - iterator begin() { return instructions.begin(); } - iterator end() { return instructions.end(); } - reverse_iterator rbegin() { return instructions.rbegin(); } - reverse_iterator rend() { return instructions.rend(); } + iterator begin() { return operations.begin(); } + iterator end() { return operations.end(); } + reverse_iterator rbegin() { return operations.rbegin(); } + reverse_iterator rend() { return operations.rend(); } - bool empty() { return instructions.empty(); } - void push_back(Instruction *inst) { instructions.push_back(inst); } - void push_front(Instruction *inst) { instructions.push_front(inst); } + bool empty() { return operations.empty(); } + void push_back(Operation *op) { operations.push_back(op); } + void push_front(Operation *op) { operations.push_front(op); } - Instruction &back() { return instructions.back(); } - Instruction &front() { return instructions.front(); } + Operation &back() { return operations.back(); } + Operation &front() { return operations.front(); } - /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the - /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if + /// Returns 'op' if 'op' lies in this block, or otherwise finds the + /// ancestor operation of 'op' that lies in this block. Returns nullptr if /// the latter fails. /// TODO: This is very specific functionality that should live somewhere else, /// probably in Dominance.cpp. - Instruction *findAncestorInstInBlock(Instruction &inst); + Operation *findAncestorInstInBlock(Operation &op); - /// This drops all operand uses from instructions within this block, which is + /// This drops all operand uses from operations within this block, which is /// an essential step in breaking cyclic dependences between references when /// they are to be deleted. void dropAllReferences(); @@ -184,31 +184,31 @@ public: /// nested regions wherever the uses are located. void dropAllDefinedValueUses(); - /// Returns true if the ordering of the child instructions is valid, false + /// Returns true if the ordering of the child operations is valid, false /// otherwise. bool isInstOrderValid() { return parentValidInstOrderPair.getInt(); } - /// Invalidates the current ordering of instructions. + /// Invalidates the current ordering of operations. void invalidateInstOrder() { // Validate the current ordering. assert(!verifyInstOrder()); parentValidInstOrderPair.setInt(false); } - /// Verifies the current ordering of child instructions matches the + /// Verifies the current ordering of child operations matches the /// validInstOrder flag. Returns false if the order is valid, true otherwise. bool verifyInstOrder(); - /// Recomputes the ordering of child instructions within the block. + /// Recomputes the ordering of child operations within the block. void recomputeInstOrder(); //===--------------------------------------------------------------------===// // Terminator management //===--------------------------------------------------------------------===// - /// Get the terminator instruction of this block. This function asserts that - /// the block has a valid terminator instruction. - Instruction *getTerminator(); + /// Get the terminator operation of this block. This function asserts that + /// the block has a valid terminator operation. + Operation *getTerminator(); //===--------------------------------------------------------------------===// // Predecessors and successors. 
@@ -242,49 +242,49 @@ public: llvm::iterator_range getSuccessors(); //===--------------------------------------------------------------------===// - // Instruction Walkers + // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the instructions of this block in preorder, calling the callback for + /// Walk the operations of this block in preorder, calling the callback for /// each operation. - void walk(const std::function &callback); + void walk(const std::function &callback); - /// Walk the instructions in the specified [begin, end) range of + /// Walk the operations in the specified [begin, end) range of /// this block, calling the callback for each operation. void walk(Block::iterator begin, Block::iterator end, - const std::function &callback); + const std::function &callback); - /// Walk the instructions in this block in postorder, calling the callback for + /// Walk the operations in this block in postorder, calling the callback for /// each operation. - void walkPostOrder(const std::function &callback); + void walkPostOrder(const std::function &callback); - /// Walk the instructions in the specified [begin, end) range of this block + /// Walk the operations in the specified [begin, end) range of this block /// in postorder, calling the callback for each operation. void walkPostOrder(Block::iterator begin, Block::iterator end, - const std::function &callback); + const std::function &callback); //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// - /// Split the block into two blocks before the specified instruction or + /// Split the block into two blocks before the specified operation or /// iterator. /// - /// Note that all instructions BEFORE the specified iterator stay as part of - /// the original basic block, and the rest of the instructions in the original + /// Note that all operations BEFORE the specified iterator stay as part of + /// the original basic block, and the rest of the operations in the original /// block are moved to the new block, including the old terminator. The /// original block is left without a terminator. /// /// The newly formed Block is returned, and the specified iterator is /// invalidated. Block *splitBlock(iterator splitBefore); - Block *splitBlock(Instruction *splitBeforeInst) { + Block *splitBlock(Operation *splitBeforeInst) { return splitBlock(iterator(splitBeforeInst)); } - /// Returns pointer to member of instruction list. - static InstListType Block::*getSublistAccess(Instruction *) { - return &Block::instructions; + /// Returns pointer to member of operation list. + static InstListType Block::*getSublistAccess(Operation *) { + return &Block::operations; } void print(raw_ostream &os); @@ -297,11 +297,11 @@ public: private: /// Pair of the parent object that owns this block and a bit that signifies if - /// the instructions within this block have a valid ordering. + /// the operations within this block have a valid ordering. llvm::PointerIntPair parentValidInstOrderPair; - /// This is the list of instructions in the block. - InstListType instructions; + /// This is the list of operations in the block. + InstListType operations; /// This is the list of arguments to the block. 
std::vector arguments; @@ -342,7 +342,7 @@ namespace mlir { class Region { public: explicit Region(Function *container = nullptr); - explicit Region(Instruction *container); + explicit Region(Operation *container); ~Region(); using RegionType = llvm::iplist; @@ -371,7 +371,7 @@ public: /// A Region is either a function body or a part of an operation. If it is /// part of an operation, then return the operation, otherwise return null. - Instruction *getContainingInst(); + Operation *getContainingOp(); /// A Region is either a function body or a part of an operation. If it is /// a Function body, then return this function, otherwise return null. @@ -395,7 +395,7 @@ private: RegionType blocks; /// This is the object we are part of. - llvm::PointerUnion container; + llvm::PointerUnion container; }; //===----------------------------------------------------------------------===// @@ -404,7 +404,7 @@ private: /// Implement a predecessor iterator as a forward iterator. This works by /// walking the use lists of the blocks. The entries on this list are the -/// BlockOperands that are embedded into terminator instructions. From the +/// BlockOperands that are embedded into terminator operations. From the /// operand, we can get the terminator that contains it, and it's parent block /// is the predecessor. class PredecessorIterator diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1cf18cc2f4c..65f986b1257 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -171,7 +171,7 @@ protected: MLIRContext *context; }; -/// This class helps build a Function. Instructions that are created are +/// This class helps build a Function. Operations that are created are /// automatically inserted at an insertion point. The builder is copyable. class FuncBuilder : public Builder { public: @@ -188,9 +188,9 @@ public: explicit FuncBuilder(Function &func) : FuncBuilder(&func) {} /// Create a function builder and set insertion point to the given - /// instruction, which will cause subsequent insertions to go right before it. - FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) { - setInsertionPoint(inst); + /// operation, which will cause subsequent insertions to go right before it. + FuncBuilder(Operation *op) : FuncBuilder(op->getFunction()) { + setInsertionPoint(op); } FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) { @@ -222,8 +222,8 @@ public: /// Sets the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. - void setInsertionPoint(Instruction *inst) { - setInsertionPoint(inst->getBlock(), Block::iterator(inst)); + void setInsertionPoint(Operation *op) { + setInsertionPoint(op->getBlock(), Block::iterator(op)); } /// Sets the insertion point to the start of the specified block. @@ -253,33 +253,33 @@ public: Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. - Instruction *createOperation(const OperationState &state); + Operation *createOperation(const OperationState &state); /// Create operation of specific op type at the current insertion point. template OpTy create(Location location, Args... 
args) { OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); - auto *inst = createOperation(state); - auto result = inst->dyn_cast(); + auto *op = createOperation(state); + auto result = op->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } - /// Creates a deep copy of the specified instruction, remapping any operands - /// that use values outside of the instruction using the map that is provided + /// Creates a deep copy of the specified operation, remapping any operands + /// that use values outside of the operation using the map that is provided /// ( leaving them alone if no entry is present). Replaces references to - /// cloned sub-instructions to the corresponding instruction that is copied, + /// cloned sub-operations to the corresponding operation that is copied, /// and adds those mappings to the map. - Instruction *clone(Instruction &inst, BlockAndValueMapping &mapper) { - Instruction *cloneInst = inst.clone(mapper, getContext()); - block->getInstructions().insert(insertPoint, cloneInst); - return cloneInst; + Operation *clone(Operation &op, BlockAndValueMapping &mapper) { + Operation *cloneOp = op.clone(mapper, getContext()); + block->getOperations().insert(insertPoint, cloneOp); + return cloneOp; } - Instruction *clone(Instruction &inst) { - Instruction *cloneInst = inst.clone(getContext()); - block->getInstructions().insert(insertPoint, cloneInst); - return cloneInst; + Operation *clone(Operation &op) { + Operation *cloneOp = op.clone(getContext()); + block->getOperations().insert(insertPoint, cloneOp); + return cloneOp; } private: diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index fe2c79a3e36..b129395d012 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -32,7 +32,7 @@ class Type; using DialectConstantDecodeHook = std::function; using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; + Operation *, ArrayRef, SmallVectorImpl &)>; using DialectExtractElementHook = std::function)>; @@ -57,7 +57,7 @@ public: /// `results` vector. If not, this returns failure and `results` is /// unspecified. DialectConstantFoldHook constantFoldHook = - [](Instruction *op, ArrayRef operands, + [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return failure(); }; /// Registered hook to decode opaque constants associated with this @@ -115,9 +115,9 @@ public: return false; } - /// Verify an attribute from this dialect on the given instruction. Returns + /// Verify an attribute from this dialect on the given operation. Returns /// true if the verification failed, false otherwise. - virtual bool verifyInstructionAttribute(Instruction *, NamedAttribute) { + virtual bool verifyOperationAttribute(Operation *, NamedAttribute) { return false; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 67f9762e7fb..41435add404 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -101,29 +101,29 @@ public: Block &front() { return body.front(); } //===--------------------------------------------------------------------===// - // Instruction Walkers + // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the instructions in the function in preorder, calling the callback - /// for each instruction. 
- void walk(const std::function &callback); + /// Walk the operations in the function in preorder, calling the callback + /// for each operation. + void walk(const std::function &callback); /// Specialization of walk to only visit operations of 'OpTy'. template void walk(std::function callback) { - walk([&](Instruction *inst) { + walk([&](Operation *inst) { if (auto op = inst->dyn_cast()) callback(op); }); } - /// Walk the instructions in the function in postorder, calling the callback - /// for each instruction. - void walkPostOrder(const std::function &callback); + /// Walk the operations in the function in postorder, calling the callback + /// for each operation. + void walkPostOrder(const std::function &callback); /// Specialization of walkPostOrder to only visit operations of 'OpTy'. template void walkPostOrder(std::function callback) { - walkPostOrder([&](Instruction *inst) { + walkPostOrder([&](Operation *inst) { if (auto op = inst->dyn_cast()) callback(op); }); diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index db417db69c9..b7662f095a5 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -17,8 +17,8 @@ // // Integer sets are sets of points from the integer lattice constrained by // affine equality/inequality constraints. This class is meant to represent -// integer sets in the IR - for 'affine.if' instructions and as attributes of -// other instructions. It is typically expected to contain only a handful of +// integer sets in the IR - for 'affine.if' operations and as attributes of +// other operations. It is typically expected to contain only a handful of // affine constraints, and is immutable like an affine map. Integer sets are not // unique'd - although affine expressions that make up its equalities and // inequalites are themselves unique. diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 304970166c6..85eb62d229f 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -61,7 +61,7 @@ public: // Diagnostic handler registration and use. MLIR supports the ability for the // IR to carry arbitrary metadata about operation location information. If an // problem is detected by the compiler, it can invoke the emitError / - // emitWarning / emitNote method on an Instruction and have it get reported + // emitWarning / emitNote method on an Operation and have it get reported // through this interface. // // Tools using MLIR are encouraged to register error handlers and define a @@ -81,7 +81,7 @@ public: /// Emit a diagnostic using the registered issue handle if present, or with /// the default behavior if not. The MLIR compiler should not generally - /// interact with this, it should use methods on Instruction instead. + /// interact with this, it should use methods on Operation instead. void emitDiagnostic(Location location, const Twine &message, DiagnosticKind kind); diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 9e91e9a8842..2ac334bec94 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -69,7 +69,7 @@ struct constant_op_binder { /// bind_value if match succeeds. 
constant_op_binder(Attribute *bind_value) : bind_value(bind_value) {} - bool match(Instruction *op) { + bool match(Operation *op) { if (op->getNumOperands() > 0 || op->getNumResults() != 1) return false; SmallVector foldedAttr; @@ -89,7 +89,7 @@ struct constant_int_op_binder { /// Creates a matcher instance that binds the value to bv if match succeeds. constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} - bool match(Instruction *op) { + bool match(Operation *op) { Attribute attr; if (!constant_op_binder(&attr).match(op)) return false; @@ -111,7 +111,7 @@ struct constant_int_op_binder { // The matcher that matches a given target constant scalar / vector splat / // tensor splat integer value. template struct constant_int_value_matcher { - bool match(Instruction *op) { + bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; @@ -120,7 +120,7 @@ template struct constant_int_value_matcher { /// The matcher that matches a certain kind of op. template struct op_matcher { - bool match(Instruction *op) { return op->isa(); } + bool match(Operation *op) { return op->isa(); } }; } // end namespace detail @@ -129,7 +129,7 @@ template struct op_matcher { template inline bool matchPattern(Value *value, const Pattern &pattern) { // TODO: handle other cases - if (auto *op = value->getDefiningInst()) + if (auto *op = value->getDefiningOp()) return const_cast(pattern).match(op); return false; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index fabbf9d5767..24fc7fd2740 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -63,16 +63,16 @@ template struct IsSingleResult { class OpState { public: /// Ops are pointer-like, so we allow implicit conversion to bool. - operator bool() { return getInstruction() != nullptr; } + operator bool() { return getOperation() != nullptr; } - /// This implicitly converts to Instruction*. - operator Instruction *() const { return state; } + /// This implicitly converts to Operation*. + operator Operation *() const { return state; } /// Return the operation that this refers to. - Instruction *getInstruction() { return state; } + Operation *getOperation() { return state; } /// Return the context this operation belongs to. - MLIRContext *getContext() { return getInstruction()->getContext(); } + MLIRContext *getContext() { return getOperation()->getContext(); } /// The source location the operation was defined or derived from. Location getLoc() { return state->getLoc(); } @@ -144,18 +144,18 @@ protected: /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. - explicit OpState(Instruction *state) : state(state) {} + explicit OpState(Operation *state) : state(state) {} private: - Instruction *state; + Operation *state; }; // Allow comparing operators. inline bool operator==(OpState lhs, OpState rhs) { - return lhs.getInstruction() == rhs.getInstruction(); + return lhs.getOperation() == rhs.getOperation(); } inline bool operator!=(OpState lhs, OpState rhs) { - return lhs.getInstruction() != rhs.getInstruction(); + return lhs.getOperation() != rhs.getOperation(); } /// This template defines the constantFoldHook and foldHook as used by @@ -168,7 +168,7 @@ class FoldingHook { public: /// This is an implementation detail of the constant folder hook for /// AbstractOperation. 
- static LogicalResult constantFoldHook(Instruction *op, + static LogicalResult constantFoldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { return op->cast().constantFold(operands, results, @@ -191,7 +191,7 @@ public: } /// This is an implementation detail of the folder hook for AbstractOperation. - static LogicalResult foldHook(Instruction *op, + static LogicalResult foldHook(Operation *op, SmallVectorImpl &results) { return op->cast().fold(results); } @@ -232,12 +232,12 @@ public: /// If the operation returns a single value, then the Op can be implicitly /// converted to an Value*. This yields the value of the only result. operator Value *() { - return static_cast(this)->getInstruction()->getResult(0); + return static_cast(this)->getOperation()->getResult(0); } /// This is an implementation detail of the constant folder hook for /// AbstractOperation. - static LogicalResult constantFoldHook(Instruction *op, + static LogicalResult constantFoldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { auto result = @@ -263,7 +263,7 @@ public: } /// This is an implementation detail of the folder hook for AbstractOperation. - static LogicalResult foldHook(Instruction *op, + static LogicalResult foldHook(Operation *op, SmallVectorImpl &results) { auto *result = op->cast().fold(); if (!result) @@ -299,7 +299,7 @@ public: }; //===----------------------------------------------------------------------===// -// Instruction Trait Types +// Operation Trait Types //===----------------------------------------------------------------------===// namespace OpTrait { @@ -308,22 +308,22 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { -bool verifyZeroOperands(Instruction *op); -bool verifyOneOperand(Instruction *op); -bool verifyNOperands(Instruction *op, unsigned numOperands); -bool verifyAtLeastNOperands(Instruction *op, unsigned numOperands); -bool verifyOperandsAreIntegerLike(Instruction *op); -bool verifySameTypeOperands(Instruction *op); -bool verifyZeroResult(Instruction *op); -bool verifyOneResult(Instruction *op); -bool verifyNResults(Instruction *op, unsigned numOperands); -bool verifyAtLeastNResults(Instruction *op, unsigned numOperands); -bool verifySameOperandsAndResultShape(Instruction *op); -bool verifySameOperandsAndResultType(Instruction *op); -bool verifyResultsAreBoolLike(Instruction *op); -bool verifyResultsAreFloatLike(Instruction *op); -bool verifyResultsAreIntegerLike(Instruction *op); -bool verifyIsTerminator(Instruction *op); +bool verifyZeroOperands(Operation *op); +bool verifyOneOperand(Operation *op); +bool verifyNOperands(Operation *op, unsigned numOperands); +bool verifyAtLeastNOperands(Operation *op, unsigned numOperands); +bool verifyOperandsAreIntegerLike(Operation *op); +bool verifySameTypeOperands(Operation *op); +bool verifyZeroResult(Operation *op); +bool verifyOneResult(Operation *op); +bool verifyNResults(Operation *op, unsigned numOperands); +bool verifyAtLeastNResults(Operation *op, unsigned numOperands); +bool verifySameOperandsAndResultShape(Operation *op); +bool verifySameOperandsAndResultType(Operation *op); +bool verifyResultsAreBoolLike(Operation *op); +bool verifyResultsAreFloatLike(Operation *op); +bool verifyResultsAreIntegerLike(Operation *op); +bool verifyIsTerminator(Operation *op); } // namespace impl /// Helper class for implementing traits. 
Clients are not expected to interact @@ -331,8 +331,8 @@ bool verifyIsTerminator(Instruction *op); template class TraitType> class TraitBase { protected: - /// Return the ultimate Instruction being worked on. - Instruction *getInstruction() { + /// Return the ultimate Operation being worked on. + Operation *getOperation() { // We have to cast up to the trait type, then to the concrete type, then to // the BaseState class in explicit hops because the concrete type will // multiply derive from the (content free) TraitBase class, and we need to @@ -340,12 +340,12 @@ protected: auto *trait = static_cast *>(this); auto *concrete = static_cast(trait); auto *base = static_cast(concrete); - return base->getInstruction(); + return base->getOperation(); } /// Provide default implementations of trait hooks. This allows traits to /// provide exactly the overrides they care about. - static bool verifyTrait(Instruction *op) { return false; } + static bool verifyTrait(Operation *op) { return false; } static AbstractOperation::OperationProperties getTraitProperties() { return 0; } @@ -356,7 +356,7 @@ protected: template class ZeroOperands : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyZeroOperands(op); } @@ -371,15 +371,11 @@ private: template class OneOperand : public TraitBase { public: - Value *getOperand() { return this->getInstruction()->getOperand(0); } + Value *getOperand() { return this->getOperation()->getOperand(0); } - void setOperand(Value *value) { - this->getInstruction()->setOperand(0, value); - } + void setOperand(Value *value) { this->getOperation()->setOperand(0, value); } - static bool verifyTrait(Instruction *op) { - return impl::verifyOneOperand(op); - } + static bool verifyTrait(Operation *op) { return impl::verifyOneOperand(op); } }; /// This class provides the API for ops that are known to have a specified @@ -393,14 +389,14 @@ public: class Impl : public TraitBase::Impl> { public: Value *getOperand(unsigned i) { - return this->getInstruction()->getOperand(i); + return this->getOperation()->getOperand(i); } void setOperand(unsigned i, Value *value) { - this->getInstruction()->setOperand(i, value); + this->getOperation()->setOperand(i, value); } - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyNOperands(op, N); } }; @@ -416,30 +412,28 @@ public: template class Impl : public TraitBase::Impl> { public: - unsigned getNumOperands() { - return this->getInstruction()->getNumOperands(); - } + unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } Value *getOperand(unsigned i) { - return this->getInstruction()->getOperand(i); + return this->getOperation()->getOperand(i); } void setOperand(unsigned i, Value *value) { - this->getInstruction()->setOperand(i, value); + this->getOperation()->setOperand(i, value); } - using operand_iterator = Instruction::operand_iterator; + using operand_iterator = Operation::operand_iterator; operand_iterator operand_begin() { - return this->getInstruction()->operand_begin(); + return this->getOperation()->operand_begin(); } operand_iterator operand_end() { - return this->getInstruction()->operand_end(); + return this->getOperation()->operand_end(); } llvm::iterator_range getOperands() { - return this->getInstruction()->getOperands(); + return this->getOperation()->getOperands(); } - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return 
impl::verifyAtLeastNOperands(op, N); } }; @@ -450,26 +444,22 @@ public: template class VariadicOperands : public TraitBase { public: - unsigned getNumOperands() { return this->getInstruction()->getNumOperands(); } + unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } - Value *getOperand(unsigned i) { - return this->getInstruction()->getOperand(i); - } + Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); } void setOperand(unsigned i, Value *value) { - this->getInstruction()->setOperand(i, value); + this->getOperation()->setOperand(i, value); } // Support operand iteration. - using operand_iterator = Instruction::operand_iterator; - using operand_range = Instruction::operand_range; + using operand_iterator = Operation::operand_iterator; + using operand_range = Operation::operand_range; operand_iterator operand_begin() { - return this->getInstruction()->operand_begin(); + return this->getOperation()->operand_begin(); } - operand_iterator operand_end() { - return this->getInstruction()->operand_end(); - } - operand_range getOperands() { return this->getInstruction()->getOperands(); } + operand_iterator operand_end() { return this->getOperation()->operand_end(); } + operand_range getOperands() { return this->getOperation()->getOperands(); } }; /// This class provides return value APIs for ops that are known to have @@ -477,9 +467,7 @@ public: template class ZeroResult : public TraitBase { public: - static bool verifyTrait(Instruction *op) { - return impl::verifyZeroResult(op); - } + static bool verifyTrait(Operation *op) { return impl::verifyZeroResult(op); } }; /// This class provides return value APIs for ops that are known to have a @@ -487,7 +475,7 @@ public: template class OneResult : public TraitBase { public: - Value *getResult() { return this->getInstruction()->getResult(0); } + Value *getResult() { return this->getOperation()->getResult(0); } Type getType() { return getResult()->getType(); } @@ -498,7 +486,7 @@ public: getResult()->replaceAllUsesWith(newValue); } - static bool verifyTrait(Instruction *op) { return impl::verifyOneResult(op); } + static bool verifyTrait(Operation *op) { return impl::verifyOneResult(op); } }; /// This class provides the API for ops that are known to have a specified @@ -513,13 +501,11 @@ public: public: static unsigned getNumResults() { return N; } - Value *getResult(unsigned i) { - return this->getInstruction()->getResult(i); - } + Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } Type getType(unsigned i) { return getResult(i)->getType(); } - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyNResults(op, N); } }; @@ -535,13 +521,11 @@ public: template class Impl : public TraitBase::Impl> { public: - Value *getResult(unsigned i) { - return this->getInstruction()->getResult(i); - } + Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } Type getType(unsigned i) { return getResult(i)->getType(); } - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -552,22 +536,22 @@ public: template class VariadicResults : public TraitBase { public: - unsigned getNumResults() { return this->getInstruction()->getNumResults(); } + unsigned getNumResults() { return this->getOperation()->getNumResults(); } - Value *getResult(unsigned i) { return this->getInstruction()->getResult(i); } + Value *getResult(unsigned i) { return 
this->getOperation()->getResult(i); } void setResult(unsigned i, Value *value) { - this->getInstruction()->setResult(i, value); + this->getOperation()->setResult(i, value); } // Support result iteration. - using result_iterator = Instruction::result_iterator; + using result_iterator = Operation::result_iterator; result_iterator result_begin() { - return this->getInstruction()->result_begin(); + return this->getOperation()->result_begin(); } - result_iterator result_end() { return this->getInstruction()->result_end(); } + result_iterator result_end() { return this->getOperation()->result_end(); } llvm::iterator_range getResults() { - return this->getInstruction()->getResults(); + return this->getOperation()->getResults(); } }; @@ -578,7 +562,7 @@ template class SameOperandsAndResultShape : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -593,7 +577,7 @@ template class SameOperandsAndResultType : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -603,7 +587,7 @@ public: template class ResultsAreBoolLike : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -614,7 +598,7 @@ template class ResultsAreFloatLike : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyResultsAreFloatLike(op); } }; @@ -625,7 +609,7 @@ template class ResultsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyResultsAreIntegerLike(op); } }; @@ -656,7 +640,7 @@ template class OperandsAreIntegerLike : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyOperandsAreIntegerLike(op); } }; @@ -666,7 +650,7 @@ public: template class SameTypeOperands : public TraitBase { public: - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifySameTypeOperands(op); } }; @@ -679,37 +663,37 @@ public: return static_cast( OperationProperty::Terminator); } - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return impl::verifyIsTerminator(op); } unsigned getNumSuccessors() { - return this->getInstruction()->getNumSuccessors(); + return this->getOperation()->getNumSuccessors(); } unsigned getNumSuccessorOperands(unsigned index) { - return this->getInstruction()->getNumSuccessorOperands(index); + return this->getOperation()->getNumSuccessorOperands(index); } Block *getSuccessor(unsigned index) { - return this->getInstruction()->getSuccessor(index); + return this->getOperation()->getSuccessor(index); } void setSuccessor(Block *block, unsigned index) { - return this->getInstruction()->setSuccessor(block, index); + return this->getOperation()->setSuccessor(block, index); } void addSuccessorOperand(unsigned index, Value *value) { - return this->getInstruction()->addSuccessorOperand(index, value); + return this->getOperation()->addSuccessorOperand(index, value); } void addSuccessorOperands(unsigned index, ArrayRef values) { - return this->getInstruction()->addSuccessorOperand(index, values); + return this->getOperation()->addSuccessorOperand(index, 
values); } }; } // end namespace OpTrait //===----------------------------------------------------------------------===// -// Instruction Definition classes +// Operation Definition classes //===----------------------------------------------------------------------===// /// This provides public APIs that all operations should have. The template @@ -724,13 +708,13 @@ class Op : public OpState, Traits...>::value> { public: /// Return the operation that this refers to. - Instruction *getInstruction() { return OpState::getInstruction(); } + Operation *getOperation() { return OpState::getOperation(); } /// Return true if this "op class" can match against the specified operation. /// This hook can be overridden with a more specific implementation in /// the subclass of Base. /// - static bool isClassFor(Instruction *op) { + static bool isClassFor(Operation *op) { return op->getName().getStringRef() == ConcreteType::getOperationName(); } @@ -744,21 +728,21 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. - static void printAssembly(Instruction *op, OpAsmPrinter *p) { + static void printAssembly(Operation *op, OpAsmPrinter *p) { auto opPointer = op->dyn_cast(); assert(opPointer && "op's name does not match name of concrete type instantiated with"); opPointer.print(p); } - /// This is the hook that checks whether or not this instruction is well + /// This is the hook that checks whether or not this operation is well /// formed according to the invariants of its opcode. It delegates to the /// Traits for their policy implementations, and allows the user to specify /// their own verify() method. /// /// On success this returns false; on failure it emits an error to the /// diagnostic subsystem and returns true. - static bool verifyInvariants(Instruction *op) { + static bool verifyInvariants(Operation *op) { return BaseVerifier...>::verifyTrait(op) || op->cast().verify(); } @@ -780,8 +764,8 @@ public: protected: /// This is a private constructor only accessible through the - /// Instruction::cast family of methods. - explicit Op(Instruction *state) : OpState(state) {} + /// Operation::cast family of methods. + explicit Op(Operation *state) : OpState(state) {} friend class Operation; private: @@ -789,13 +773,13 @@ private: template struct BaseVerifier { - static bool verifyTrait(Instruction *op) { + static bool verifyTrait(Operation *op) { return First::verifyTrait(op) || BaseVerifier::verifyTrait(op); } }; template struct BaseVerifier { - static bool verifyTrait(Instruction *op) { return false; } + static bool verifyTrait(Operation *op) { return false; } }; template struct BaseProperties; @@ -824,7 +808,7 @@ bool parseBinaryOp(OpAsmParser *parser, OperationState *result); // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. 
-void printBinaryOp(Instruction *op, OpAsmPrinter *p); +void printBinaryOp(Operation *op, OpAsmPrinter *p); } // namespace impl // These functions are out-of-line implementations of the methods in CastOp, @@ -833,7 +817,7 @@ namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); -void printCastOp(Instruction *op, OpAsmPrinter *p); +void printCastOp(Operation *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are cast operations, that have a @@ -857,7 +841,7 @@ public: return impl::parseCastOp(parser, result); } void print(OpAsmPrinter *p) { - return impl::printCastOp(this->getInstruction(), p); + return impl::printCastOp(this->getOperation(), p); } }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index ffa26828e8b..eeb35b2d51a 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -76,7 +76,7 @@ public: /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. - virtual void printSuccessorAndUseList(Instruction *term, unsigned index) = 0; + virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore @@ -86,7 +86,7 @@ public: ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. - virtual void printGenericOp(Instruction *op) = 0; + virtual void printGenericOp(Operation *op) = 0; /// Prints a region. virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true) = 0; diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 0baceb3fe0d..0fae3fc495f 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -46,8 +46,8 @@ using BlockOperand = IROperandImpl; /// class. class Operation final : public llvm::ilist_node_with_parent, - private llvm::TrailingObjects { + private llvm::TrailingObjects { public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, @@ -188,7 +188,7 @@ public: unsigned getNumResults() { return numResults; } - Value *getResult(unsigned idx) { return &getInstResult(idx); } + Value *getResult(unsigned idx) { return &getOpResult(idx); } // Support result iteration. using result_iterator = ResultIterator; @@ -196,11 +196,11 @@ public: result_iterator result_end(); llvm::iterator_range getResults(); - MutableArrayRef getInstResults() { - return {getTrailingObjects(), numResults}; + MutableArrayRef getOpResults() { + return {getTrailingObjects(), numResults}; } - InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } + OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; } // Support result type iteration. using result_type_iterator = ResultTypeIterator; @@ -486,9 +486,9 @@ private: friend class llvm::ilist_node_with_parent; // This stuff is used by the TrailingObjects template. 
- friend llvm::TrailingObjects; - size_t numTrailingObjects(OverloadToken) const { + size_t numTrailingObjects(OverloadToken) const { return numResults; } size_t numTrailingObjects(OverloadToken) const { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 2e8aba2aedd..0b35bb32a68 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -109,7 +109,7 @@ public: /// which is the same operation code as getRootKind(). On failure, this /// returns a None value. On success it returns a (possibly null) /// pattern-specific state wrapped in an Optional. - virtual PatternMatchResult match(Instruction *op) const = 0; + virtual PatternMatchResult match(Operation *op) const = 0; virtual ~Pattern() {} @@ -155,7 +155,7 @@ public: /// rewriter. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(Instruction *op, std::unique_ptr state, + virtual void rewrite(Operation *op, std::unique_ptr state, PatternRewriter &rewriter) const; /// Rewrite the IR rooted at the specified operation with the result of @@ -163,19 +163,19 @@ public: /// builder. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(Instruction *op, PatternRewriter &rewriter) const; + virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). On failure, this /// returns a None value. On success, it returns a (possibly null) /// pattern-specific state wrapped in an Optional. This state is passed back /// into the rewrite function if this match is selected. - PatternMatchResult match(Instruction *op) const override; + PatternMatchResult match(Operation *op) const override; /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. - virtual PatternMatchResult matchAndRewrite(Instruction *op, + virtual PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { if (auto matchResult = match(op)) { rewrite(op, std::move(*matchResult), rewriter); @@ -229,14 +229,14 @@ public: OpTy::build(this, &state, args...); auto *op = createOperation(state); - // If the Instruction we produce is valid, return it. + // If the Operation we produce is valid, return it. if (!OpTy::verifyInvariants(op)) { auto result = op->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } - // Otherwise, the error message got emitted. Just remove the instruction + // Otherwise, the error message got emitted. Just remove the operation // we made. op->erase(); return OpTy(); @@ -248,38 +248,37 @@ public: /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. - void replaceOp(Instruction *op, ArrayRef newValues, + void replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead = {}); /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. 
template - void replaceOpWithNewOp(Instruction *op, Args... args) { + void replaceOpWithNewOp(Operation *op, Args... args) { auto newOp = create(op->getLoc(), args...); - replaceOpWithResultsOfAnotherOp(op, newOp.getInstruction(), {}); + replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {}); } /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template - void replaceOpWithNewOp(Instruction *op, - ArrayRef valuesToRemoveIfDead, + void replaceOpWithNewOp(Operation *op, ArrayRef valuesToRemoveIfDead, Args... args) { auto newOp = create(op->getLoc(), args...); - replaceOpWithResultsOfAnotherOp(op, newOp.getInstruction(), + replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), valuesToRemoveIfDead); } /// This method is used as the final notification hook for patterns that end /// up modifying the pattern root in place, by changing its operands. This is - /// a minor efficiency win (it avoids creating a new instruction and removing + /// a minor efficiency win (it avoids creating a new operation and removing /// the old one) but also often allows simpler code in the client. /// /// The valuesToRemoveIfDead list is an optional list of values that the /// rewriter should remove if they are dead at this point. /// - void updatedRootInPlace(Instruction *op, + void updatedRootInPlace(Operation *op, ArrayRef valuesToRemoveIfDead = {}); protected: @@ -291,26 +290,26 @@ protected: /// This is implemented to create the specified operations and serves as a /// notification hook for rewriters that want to know about new operations. - virtual Instruction *createOperation(const OperationState &state) = 0; + virtual Operation *createOperation(const OperationState &state) = 0; /// Notify the pattern rewriter that the specified operation has been mutated /// in place. This is called after the mutation is done. - virtual void notifyRootUpdated(Instruction *op) {} + virtual void notifyRootUpdated(Operation *op) {} /// Notify the pattern rewriter that the specified operation is about to be /// replaced with another set of operations. This is called before the uses /// of the operation have been changed. - virtual void notifyRootReplaced(Instruction *op) {} + virtual void notifyRootReplaced(Operation *op) {} /// This is called on an operation that a pattern match is removing, right /// before the operation is deleted. At this point, the operation has zero /// uses. - virtual void notifyOperationRemoved(Instruction *op) {} + virtual void notifyOperationRemoved(Operation *op) {} private: /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp - void replaceOpWithResultsOfAnotherOp(Instruction *op, Instruction *newOp, + void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, ArrayRef valuesToRemoveIfDead); }; @@ -333,7 +332,7 @@ public: PatternRewriter &rewriter); /// Try to match the given operation to a pattern and rewrite it. 
- void matchAndRewrite(Instruction *op); + void matchAndRewrite(Operation *op); private: RewritePatternMatcher(const RewritePatternMatcher &) = delete; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 761cd6fa45d..623fd9cd64e 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -30,7 +30,6 @@ namespace mlir { class IROperand; class Operation; -using Instruction = Operation; template class ValueUseIterator; class IRObjectWithUseList { @@ -75,11 +74,11 @@ private: IROperand *firstUse = nullptr; }; -/// A reference to a value, suitable for use as an operand of an instruction. +/// A reference to a value, suitable for use as an operand of an operation. class IROperand { public: - IROperand(Instruction *owner) : owner(owner) {} - IROperand(Instruction *owner, IRObjectWithUseList *value) + IROperand(Operation *owner) : owner(owner) {} + IROperand(Operation *owner, IRObjectWithUseList *value) : value(value), owner(owner) { insertIntoCurrent(); } @@ -97,8 +96,8 @@ public: } /// Return the owner of this operand. - Instruction *getOwner() { return owner; } - Instruction *getOwner() const { return owner; } + Operation *getOwner() { return owner; } + Operation *getOwner() const { return owner; } /// \brief Remove this use of the operand. void drop() { @@ -143,8 +142,8 @@ private: /// This points to the previous link in the use-chain. IROperand **back = nullptr; - /// The instruction owner of this operand. - Instruction *const owner; + /// The operation owner of this operand. + Operation *const owner; /// Operands are not copyable or assignable. IROperand(const IROperand &use) = delete; @@ -167,14 +166,13 @@ private: } }; -/// A reference to a value, suitable for use as an operand of an instruction, -/// instruction, etc. IRValueTy is the root type to use for values this tracks, +/// A reference to a value, suitable for use as an operand of an operation, +/// operation, etc. IRValueTy is the root type to use for values this tracks, /// and SSAUserTy is the type that will contain operands. template class IROperandImpl : public IROperand { public: - IROperandImpl(Instruction *owner) : IROperand(owner) {} - IROperandImpl(Instruction *owner, IRValueTy *value) - : IROperand(owner, value) {} + IROperandImpl(Operation *owner) : IROperand(owner) {} + IROperandImpl(Operation *owner, IRValueTy *value) : IROperand(owner, value) {} /// Return the current value being used by this operand. IRValueTy *get() { return (IRValueTy *)IROperand::get(); } @@ -196,7 +194,7 @@ public: OperandType *operator->() const { return current; } OperandType &operator*() const { return *current; } - Instruction *getUser() const { return current->getOwner(); } + Operation *getUser() const { return current->getOwner(); } ValueUseIterator &operator++() { assert(current && "incrementing past end()!"); diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index ef3d7c47ce8..04ba9cd2788 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -30,7 +30,6 @@ namespace mlir { class Block; class Function; class Operation; -using Instruction = Operation; class Value; /// Operands contain a Value. @@ -44,7 +43,7 @@ public: /// This enumerates all of the SSA value kinds in the MLIR system. enum class Kind { BlockArgument, // block argument - InstResult, // operation instruction result + OpResult, // operation result }; ~Value() {} @@ -63,9 +62,9 @@ public: /// Return the function that this Value is defined in. 
Function *getFunction(); - /// If this value is the result of an operation, return the instruction - /// that defines it. - Instruction *getDefiningInst(); + /// If this value is the result of an operation, return the operation that + /// defines it. + Operation *getDefiningOp(); using use_iterator = ValueUseIterator; using use_range = llvm::iterator_range; @@ -131,17 +130,17 @@ private: Block *const owner; }; -/// This is a value defined by a result of an operation instruction. -class InstResult : public Value { +/// This is a value defined by a result of an operation. +class OpResult : public Value { public: - InstResult(Type type, Instruction *owner) - : Value(Value::Kind::InstResult, type), owner(owner) {} + OpResult(Type type, Operation *owner) + : Value(Value::Kind::OpResult, type), owner(owner) {} static bool classof(const Value *value) { - return const_cast(value)->getKind() == Kind::InstResult; + return const_cast(value)->getKind() == Kind::OpResult; } - Instruction *getOwner() { return owner; } + Operation *getOwner() { return owner; } /// Returns the number of this result. unsigned getResultNumber(); @@ -150,7 +149,7 @@ private: /// The owner of this operand. /// TODO: can encode this more efficiently to avoid the space hit of this /// through bitpacking shenanigans. - Instruction *const owner; + Operation *const owner; }; /// This is a helper template used to implement an iterator that contains a diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index b3e0f9daccf..eb0a3ec644f 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -558,8 +558,8 @@ public: } // Returns the source memerf indices for this DMA operation. llvm::iterator_range getSrcIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_begin() + 1 + getSrcMemRefRank()}; + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; } // Returns the destination MemRefType for this DMA operations. @@ -577,8 +577,8 @@ public: // Returns the destination memref indices for this DMA operation. llvm::iterator_range getDstIndices() { - return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1, - getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 + + return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank()}; } @@ -600,8 +600,8 @@ public: llvm::iterator_range getTagIndices() { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; - return {getInstruction()->operand_begin() + tagIndexStartPos, - getInstruction()->operand_begin() + tagIndexStartPos + + return {getOperation()->operand_begin() + tagIndexStartPos, + getOperation()->operand_begin() + tagIndexStartPos + getTagMemRefRank()}; } @@ -678,8 +678,8 @@ public: // Returns the tag memref index for this DMA operation. llvm::iterator_range getTagIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getTagMemRefRank()}; } // Returns the rank (number of indices) of the tag memref. 
@@ -719,8 +719,7 @@ public: Value *getAggregate() { return getOperand(0); } llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; + return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } static StringRef getOperationName() { return "std.extract_element"; } @@ -756,8 +755,7 @@ public: } llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 1, - getInstruction()->operand_end()}; + return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; } static StringRef getOperationName() { return "std.load"; } @@ -881,8 +879,7 @@ public: } llvm::iterator_range getIndices() { - return {getInstruction()->operand_begin() + 2, - getInstruction()->operand_end()}; + return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; } static StringRef getOperationName() { return "std.store"; } diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 5024999fc71..fd1e87fd0d8 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -42,7 +42,7 @@ class ArithmeticOp traits = []> : }]; let printer = [{ - return detail::printStandardBinaryOp(this->getInstruction(), p); + return detail::printStandardBinaryOp(this->getOperation(), p); }]; } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index a6dd19a418f..c1f9606eba3 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -45,7 +45,7 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context) bool mlir::isTopLevelSymbol(Value *value) { if (auto *arg = dyn_cast(value)) return arg->getOwner()->getParent()->getContainingFunction(); - return value->getDefiningInst()->getParentInst() == nullptr; + return value->getDefiningOp()->getParentInst() == nullptr; } // Value can be used as a dimension id if it is valid as a symbol, or @@ -56,7 +56,7 @@ bool mlir::isValidDim(Value *value) { if (!value->getType().isIndex()) return false; - if (auto *inst = value->getDefiningInst()) { + if (auto *inst = value->getDefiningOp()) { // Top level instruction or constant operation is ok. if (inst->getParentInst() == nullptr || inst->isa()) return true; @@ -81,7 +81,7 @@ bool mlir::isValidSymbol(Value *value) { if (!value->getType().isIndex()) return false; - if (auto *inst = value->getDefiningInst()) { + if (auto *inst = value->getDefiningOp()) { // Top level instruction or constant operation is ok. if (inst->getParentInst() == nullptr || inst->isa()) return true; @@ -317,7 +317,7 @@ indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; for (auto en : llvm::enumerate(operands)) { auto *t = en.value(); - if (t->getDefiningInst() && t->getDefiningInst()->isa()) { + if (t->getDefiningOp() && t->getDefiningOp()->isa()) { res.insert(en.index()); } } @@ -458,12 +458,12 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { auto *t = operands[i]; - auto affineApply = t->getDefiningInst() - ? t->getDefiningInst()->dyn_cast() + auto affineApply = t->getDefiningOp() + ? t->getDefiningOp()->dyn_cast() : AffineApplyOp(); if (affineApply) { // a. Compose affine.apply instructions. 
- LLVM_DEBUG(affineApply.getInstruction()->print( + LLVM_DEBUG(affineApply.getOperation()->print( dbgs() << "\nCompose AffineApplyOp recursively: ")); AffineMap affineApplyMap = affineApply.getAffineMap(); SmallVector affineApplyOperands( @@ -535,7 +535,7 @@ static void composeAffineMapAndOperands(AffineMap *map, void mlir::fullyComposeAffineMapAndOperands( AffineMap *map, SmallVectorImpl *operands) { while (llvm::any_of(*operands, [](Value *v) { - return v->getDefiningInst() && v->getDefiningInst()->isa(); + return v->getDefiningOp() && v->getDefiningOp()->isa(); })) { composeAffineMapAndOperands(map, operands); } @@ -731,7 +731,7 @@ void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, } bool AffineForOp::verify() { - auto &bodyRegion = getInstruction()->getRegion(0); + auto &bodyRegion = getOperation()->getRegion(0); // The body region must contain a single basic block. if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end()) @@ -955,7 +955,7 @@ void AffineForOp::print(OpAsmPrinter *p) { if (getStep() != 1) *p << " step " << getStep(); - p->printRegion(getInstruction()->getRegion(0), + p->printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getLowerBoundAttrName(), @@ -1062,7 +1062,7 @@ void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { auto ubOperands = getUpperBoundOperands(); newOperands.append(ubOperands.begin(), ubOperands.end()); - getInstruction()->setOperands(newOperands); + getOperation()->setOperands(newOperands); setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); } @@ -1073,7 +1073,7 @@ void AffineForOp::setUpperBound(ArrayRef ubOperands, AffineMap map) { SmallVector newOperands(getLowerBoundOperands()); newOperands.append(ubOperands.begin(), ubOperands.end()); - getInstruction()->setOperands(newOperands); + getOperation()->setOperands(newOperands); setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); } @@ -1158,7 +1158,7 @@ AffineForOp mlir::getForInductionVarOwner(Value *val) { auto *ivArg = dyn_cast(val); if (!ivArg || !ivArg->getOwner()) return AffineForOp(); - auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); + auto *containingInst = ivArg->getOwner()->getParent()->getContainingOp(); if (!containingInst) return AffineForOp(); return containingInst->dyn_cast(); @@ -1207,7 +1207,7 @@ bool AffineIfOp::verify() { return true; // Verify that the entry of each child region does not have arguments. - for (auto ®ion : getInstruction()->getRegions()) { + for (auto ®ion : getOperation()->getRegions()) { if (region.empty()) continue; @@ -1273,10 +1273,10 @@ void AffineIfOp::print(OpAsmPrinter *p) { *p << "affine.if " << conditionAttr; printDimAndSymbolList(operand_begin(), operand_end(), conditionAttr.getValue().getNumDims(), p); - p->printRegion(getInstruction()->getRegion(0)); + p->printRegion(getOperation()->getRegion(0)); // Print the 'else' regions if it has any blocks. - auto &elseRegion = getInstruction()->getRegion(1); + auto &elseRegion = getOperation()->getRegion(1); if (!elseRegion.empty()) { *p << " else"; p->printRegion(elseRegion); @@ -1295,7 +1295,7 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) { } /// Returns the list of 'then' blocks. -Region &AffineIfOp::getThenBlocks() { return getInstruction()->getRegion(0); } +Region &AffineIfOp::getThenBlocks() { return getOperation()->getRegion(0); } /// Returns the list of 'else' blocks. 
-Region &AffineIfOp::getElseBlocks() { return getInstruction()->getRegion(1); } +Region &AffineIfOp::getElseBlocks() { return getOperation()->getRegion(1); } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 6b865c40638..b3548f96b29 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -62,8 +62,8 @@ void mlir::getReachableAffineApplyOps( while (!worklist.empty()) { State &state = worklist.back(); - auto *opInst = state.value->getDefiningInst(); - // Note: getDefiningInst will return nullptr if the operand is not an + auto *opInst = state.value->getDefiningOp(); + // Note: getDefiningOp will return nullptr if the operand is not an // Instruction (i.e. AffineForOp), which is a terminator for the search. if (opInst == nullptr || !opInst->isa()) { worklist.pop_back(); @@ -458,7 +458,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, auto *symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. - if (auto *opInst = symbol->getDefiningInst()) { + if (auto *opInst = symbol->getDefiningOp()) { if (auto constOp = opInst->dyn_cast()) { dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), constOp.getValue()); @@ -538,8 +538,8 @@ static Block *getCommonBlock(const MemRefAccess &srcAccess, unsigned numCommonLoops) { if (numCommonLoops == 0) { auto *block = srcAccess.opInst->getBlock(); - while (block->getContainingInst()) { - block = block->getContainingInst()->getBlock(); + while (block->getContainingOp()) { + block = block->getContainingOp()->getBlock(); } return block; } diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index dd254b466e1..483e69f7b1d 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -727,7 +727,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { // Add top level symbol. addSymbolId(getNumSymbolIds(), id); // Check if the symbol is a constant. - if (auto *opInst = id->getDefiningInst()) { + if (auto *opInst = id->getDefiningOp()) { if (auto constOp = opInst->dyn_cast()) { setIdToConstant(*id, constOp.getValue()); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 24828a71204..b8a9e1c0218 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -72,7 +72,7 @@ bool DominanceInfoBase::properlyDominates(Block *a, Block *b) { if (regionA != regionB) { Instruction *bAncestor; do { - bAncestor = regionB->getContainingInst(); + bAncestor = regionB->getContainingOp(); // If 'bAncestor' is the top level function, then 'a' is a block // that post dominates 'b'. if (!bAncestor) @@ -122,7 +122,7 @@ bool DominanceInfo::properlyDominates(Instruction *a, Instruction *b) { /// Return true if value A properly dominates instruction B. 
bool DominanceInfo::properlyDominates(Value *a, Instruction *b) { - if (auto *aInst = a->getDefiningInst()) + if (auto *aInst = a->getDefiningOp()) return properlyDominates(aInst, b); // block arguments properly dominate all instructions in their own block, so diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index ab2598985db..eb272389957 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -54,7 +54,7 @@ void mlir::buildTripCountMapAndOperands( int64_t loopSpan; int64_t step = forOp.getStep(); - FuncBuilder b(forOp.getInstruction()); + FuncBuilder b(forOp.getOperation()); if (forOp.hasConstantBounds()) { int64_t lb = forOp.getConstantLowerBound(); @@ -101,7 +101,7 @@ void mlir::buildTripCountMapAndOperands( // simplification, and canonicalization above. for (auto *v : ubs) if (v->use_empty()) - v->getDefiningInst()->erase(); + v->getDefiningOp()->erase(); if (lb.use_empty()) lb.erase(); } @@ -280,7 +280,7 @@ using VectorizableInstFun = std::function; static bool isVectorizableLoopWithCond(AffineForOp loop, VectorizableInstFun isVectorizableInst) { - auto *forInst = loop.getInstruction(); + auto *forInst = loop.getOperation(); if (!matcher::isParallelLoop(*forInst) && !matcher::isReductionLoop(*forInst)) { return false; @@ -361,12 +361,12 @@ bool mlir::isVectorizableLoop(AffineForOp loop) { // violation when we have the support. bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { auto *forBody = forOp.getBody(); - assert(shifts.size() == forBody->getInstructions().size()); + assert(shifts.size() == forBody->getOperations().size()); // Work backwards over the body of the block so that the shift of a use's // ancestor instruction in the block gets recorded before it's looked up. DenseMap forBodyShift; - for (auto it : llvm::enumerate(llvm::reverse(forBody->getInstructions()))) { + for (auto it : llvm::enumerate(llvm::reverse(forBody->getOperations()))) { auto &inst = it.value(); // Get the index of the current instruction, note that we are iterating in diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index e1b3c63d08b..82320bd26ff 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -104,7 +104,7 @@ static void getBackwardSliceImpl(Instruction *inst, } for (auto *operand : inst->getOperands()) { - auto *inst = operand->getDefiningInst(); + auto *inst = operand->getDefiningOp(); if (backwardSlice->count(inst) == 0) { getBackwardSliceImpl(inst, backwardSlice, filter); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index a564592b2dd..a9c22d62f0b 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -73,7 +73,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { assert(cst->containsId(*value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto *inst = value->getDefiningInst()) { + if (auto *inst = value->getDefiningOp()) { if (auto constOp = inst->dyn_cast()) { cst->setIdToConstant(*value, constOp.getValue()); } @@ -242,7 +242,7 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, auto *symbol = operand; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
- if (auto *inst = symbol->getDefiningInst()) { + if (auto *inst = symbol->getDefiningOp()) { if (auto constOp = inst->dyn_cast()) { cst.setIdToConstant(*symbol, constOp.getValue()); } @@ -374,7 +374,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same::value, "argument should be either a LoadOp or a StoreOp"); - Instruction *opInst = loadOrStoreOp.getInstruction(); + Instruction *opInst = loadOrStoreOp.getOperation(); MemRefRegion region(opInst->getLoc()); if (failed(region.compute(opInst, /*loopDepth=*/0))) @@ -437,7 +437,7 @@ static void findInstPosition(Instruction *inst, Block *limitBlock, // rely on linear scans. int instPosInBlock = std::distance(block->begin(), inst->getIterator()); positions->push_back(instPosInBlock); - inst = block->getContainingInst(); + inst = block->getContainingOp(); block = inst->getBlock(); } std::reverse(positions->begin(), positions->end()); @@ -583,7 +583,7 @@ AffineForOp mlir::insertBackwardComputationSlice( // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. - findInstPosition(srcOpInst, srcLoopIVs[0].getInstruction()->getBlock(), + findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(), &positions); // Clone src loop nest and insert it a the beginning of the instruction block @@ -591,7 +591,7 @@ AffineForOp mlir::insertBackwardComputationSlice( auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); auto sliceLoopNest = - b.clone(*srcLoopIVs[0].getInstruction())->cast(); + b.clone(*srcLoopIVs[0].getOperation())->cast(); Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody()); @@ -670,7 +670,7 @@ unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i].getInstruction() != loopsB[i].getInstruction()) + if (loopsA[i].getOperation() != loopsB[i].getOperation()) break; ++numCommonLoops; } @@ -727,7 +727,7 @@ static Optional getMemoryFootprintBytes(Block &block, Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, int memorySpace) { - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); return ::getMemoryFootprintBytes( *forInst->getBlock(), Block::iterator(forInst), std::next(Block::iterator(forInst)), memorySpace); @@ -737,7 +737,7 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, /// at 'forOp'. void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { - forOp.getInstruction()->walk([&](Instruction *inst) { + forOp.getOperation()->walk([&](Instruction *inst) { if (auto innerFor = inst->dyn_cast()) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); @@ -748,13 +748,13 @@ void mlir::getSequentialLoops( bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; - forOp.getInstruction()->walk([&](Instruction *opInst) { + forOp.getOperation()->walk([&](Instruction *opInst) { if (opInst->isa() || opInst->isa()) loadAndStoreOpInsts.push_back(opInst); }); // Dep check depth would be number of enclosing loops + 1. 
- unsigned depth = getNestingDepth(*forOp.getInstruction()) + 1; + unsigned depth = getNestingDepth(*forOp.getOperation()) + 1; // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'. for (auto *srcOpInst : loadAndStoreOpInsts) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 781a3cde9fe..f211417b798 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -225,7 +225,7 @@ static bool canBlockHaveNoTerminator(Block &block) { // Allow the first block of an operation region to have no terminator if it is // the only block in the region. auto *parentList = block.getParent(); - return parentList->getContainingInst() && + return parentList->getContainingOp() && std::next(parentList->begin()) == parentList->end(); } @@ -295,7 +295,7 @@ bool FuncVerifier::verifyOperation(Instruction &op) { if (!attr.first.strref().contains('.')) continue; if (auto *dialect = getDialectForAttribute(attr, op)) - if (dialect->verifyInstructionAttribute(&op, attr)) + if (dialect->verifyOperationAttribute(&op, attr)) return true; } @@ -332,7 +332,7 @@ bool FuncVerifier::verifyInstDominance(Instruction &inst) { inst.emitError("operand #" + Twine(operandNo) + " does not dominate this use"); - if (auto *useInst = op->getDefiningInst()) + if (auto *useInst = op->getDefiningOp()) useInst->emitNote("operand defined here"); return true; } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 191b789dec6..e991817b6d7 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -87,7 +87,7 @@ mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, Instruction *inst = makeComposedAffineApply(ScopedContext::getBuilder(), ScopedContext::getLocation(), map, operands) - .getInstruction(); + .getOperation(); assert(inst->getNumResults() == 1 && "Not a single result AffineApply"); return ValueHandle(inst->getResult(0)); } @@ -145,8 +145,8 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, if (lbs.size() != 1 || ubs.size() != 1) return llvm::Optional(); - auto *lbDef = lbs.front().getValue()->getDefiningInst(); - auto *ubDef = ubs.front().getValue()->getDefiningInst(); + auto *lbDef = lbs.front().getValue()->getDefiningOp(); + auto *ubDef = ubs.front().getValue()->getDefiningOp(); if (!lbDef || !ubDef) return llvm::Optional(); @@ -267,7 +267,7 @@ categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, unsigned &numSymbols) { AffineExpr d; Value *resultVal = nullptr; - auto *inst = val->getDefiningInst(); + auto *inst = val->getDefiningOp(); auto constant = inst ? 
inst->dyn_cast() : ConstantIndexOp(); if (constant) { d = getAffineConstantExpr(constant.getValue(), context); diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 49b544a9b77..6c6262c2790 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -46,13 +46,13 @@ using namespace mlir::edsc; using namespace mlir::edsc::detail; static void printDefininingStatement(llvm::raw_ostream &os, Value &v) { - auto *inst = v.getDefiningInst(); + auto *inst = v.getDefiningOp(); if (inst) { inst->print(os); return; } if (auto forInst = getForInductionVarOwner(&v)) { - forInst.getInstruction()->print(os); + forInst.getOperation()->print(os); } else if (auto *bbArg = dyn_cast(&v)) { os << "block_argument"; } else { @@ -84,7 +84,7 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) { static void checkAffineProvenance(ArrayRef values) { for (Value *v : values) { - auto *def = v->getDefiningInst(); + auto *def = v->getDefiningOp(); (void)def; // There may be no defining instruction if the value is a function // argument. We accept such values. @@ -100,8 +100,8 @@ static AffineForOp emitStaticFor(FuncBuilder &builder, Location loc, if (lbs.size() != 1 || ubs.size() != 1) return AffineForOp(); - auto *lbDef = lbs.front()->getDefiningInst(); - auto *ubDef = ubs.front()->getDefiningInst(); + auto *lbDef = lbs.front()->getDefiningOp(); + auto *ubDef = ubs.front()->getDefiningOp(); if (!lbDef || !ubDef) return AffineForOp(); @@ -165,8 +165,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { checkAffineProvenance(ubs); // Step must be a static constant. - auto step = - stepExpr->getDefiningInst()->cast().getValue(); + auto step = stepExpr->getDefiningOp()->cast().getValue(); // Special case with more concise emitted code for static bounds. AffineForOp forOp = emitStaticFor(*builder, location, lbs, ubs, step); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a2c259e3377..de6654cf532 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -138,7 +138,7 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } // Visit functions. - void visitInstruction(Instruction *inst); + void visitOperation(Operation *inst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -189,7 +189,7 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitInstruction(Instruction *inst) { +void ModuleState::visitOperation(Operation *inst) { // Visit all the types used in the operation. for (auto *operand : inst->getOperands()) visitType(operand->getType()); @@ -270,7 +270,7 @@ void ModuleState::initialize(Module *module) { for (auto &fn : *module) { visitType(fn.getType()); - fn.walk([&](Instruction *op) { ModuleState::visitInstruction(op); }); + fn.walk([&](Operation *op) { ModuleState::visitOperation(op); }); } // Initialize the symbol aliases. @@ -1059,11 +1059,11 @@ public: void printFunctionSignature(); // Methods to print instructions. - void print(Instruction *inst); + void print(Operation *inst); void print(Block *block, bool printBlockArgs = true); - void printOperation(Instruction *op); - void printGenericOp(Instruction *op); + void printOperation(Operation *op); + void printGenericOp(Operation *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -1106,7 +1106,7 @@ public: return it != blockIDs.end() ? 
it->second : ~0U; } - void printSuccessorAndUseList(Instruction *term, unsigned index) override; + void printSuccessorAndUseList(Operation *term, unsigned index) override; /// Print a region. void printRegion(Region &blocks, bool printEntryBlockArgs) override { @@ -1196,7 +1196,7 @@ void FunctionPrinter::numberValueID(Value *value) { llvm::raw_svector_ostream specialName(specialNameBuffer); // Give constant integers special names. - if (auto *op = value->getDefiningInst()) { + if (auto *op = value->getDefiningOp()) { Attribute cst; if (m_Constant(&cst).match(op)) { Type type = op->getResult(0)->getType(); @@ -1236,7 +1236,7 @@ void FunctionPrinter::numberValueID(Value *value) { // Otherwise number it normally. valueIDs[value] = nextValueID++; return; - case Value::Kind::InstResult: + case Value::Kind::OpResult: // This is an uninteresting result, give it a boring number and be // done with it. valueIDs[value] = nextValueID++; @@ -1380,14 +1380,14 @@ void FunctionPrinter::print(Block *block, bool printBlockArgs) { currentIndent += indentWidth; - for (auto &inst : block->getInstructions()) { + for (auto &inst : block->getOperations()) { print(&inst); os << '\n'; } currentIndent -= indentWidth; } -void FunctionPrinter::print(Instruction *inst) { +void FunctionPrinter::print(Operation *inst) { os.indent(currentIndent); printOperation(inst); printTrailingLocation(inst->getLoc()); @@ -1400,12 +1400,12 @@ void FunctionPrinter::printValueID(Value *value, bool printResultNo) const { // If this is a reference to the result of a multi-result instruction or // instruction, print out the # identifier and make sure to map our lookup // to the first result of the instruction. - if (auto *result = dyn_cast(value)) { + if (auto *result = dyn_cast(value)) { if (result->getOwner()->getNumResults() != 1) { resultNo = result->getResultNumber(); lookupValue = result->getOwner()->getResult(0); } - } else if (auto *result = dyn_cast(value)) { + } else if (auto *result = dyn_cast(value)) { if (result->getOwner()->getNumResults() != 1) { resultNo = result->getResultNumber(); lookupValue = result->getOwner()->getResult(0); @@ -1431,7 +1431,7 @@ void FunctionPrinter::printValueID(Value *value, bool printResultNo) const { os << '#' << resultNo; } -void FunctionPrinter::printOperation(Instruction *op) { +void FunctionPrinter::printOperation(Operation *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1451,7 +1451,7 @@ void FunctionPrinter::printOperation(Instruction *op) { printGenericOp(op); } -void FunctionPrinter::printGenericOp(Instruction *op) { +void FunctionPrinter::printGenericOp(Operation *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1504,7 +1504,7 @@ void FunctionPrinter::printGenericOp(Instruction *op) { printRegion(region, /*printEntryBlockArgs=*/true); } -void FunctionPrinter::printSuccessorAndUseList(Instruction *term, +void FunctionPrinter::printSuccessorAndUseList(Operation *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -1586,14 +1586,14 @@ void Value::print(raw_ostream &os) { // TODO: Improve this. 
os << "\n"; return; - case Value::Kind::InstResult: - return getDefiningInst()->print(os); + case Value::Kind::OpResult: + return getDefiningOp()->print(os); } } void Value::dump() { print(llvm::errs()); } -void Instruction::print(raw_ostream &os) { +void Operation::print(raw_ostream &os) { auto *function = getFunction(); if (!function) { os << "<>\n"; @@ -1605,7 +1605,7 @@ void Instruction::print(raw_ostream &os) { FunctionPrinter(function, modulePrinter).print(this); } -void Instruction::dump() { +void Operation::dump() { print(llvm::errs()); llvm::errs() << "\n"; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 2a18260a378..cd5816009a5 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -37,22 +37,22 @@ unsigned BlockArgument::getArgNumber() { //===----------------------------------------------------------------------===// Block::~Block() { - assert(!verifyInstOrder() && "Expected valid instruction ordering."); + assert(!verifyInstOrder() && "Expected valid operation ordering."); clear(); llvm::DeleteContainerPointers(arguments); } -/// Returns the closest surrounding instruction that contains this block or -/// nullptr if this is a top-level instruction block. -Instruction *Block::getContainingInst() { - return getParent() ? getParent()->getContainingInst() : nullptr; +/// Returns the closest surrounding operation that contains this block or +/// nullptr if this is a top-level operation block. +Operation *Block::getContainingOp() { + return getParent() ? getParent()->getContainingOp() : nullptr; } Function *Block::getFunction() { Block *block = this; - while (auto *inst = block->getContainingInst()) { - block = inst->getBlock(); + while (auto *op = block->getContainingOp()) { + block = op->getBlock(); if (!block) return nullptr; } @@ -75,13 +75,13 @@ void Block::eraseFromFunction() { getFunction()->getBlocks().erase(this); } -/// Returns 'inst' if 'inst' lies in this block, or otherwise finds the -/// ancestor instruction of 'inst' that lies in this block. Returns nullptr if +/// Returns 'op' if 'op' lies in this block, or otherwise finds the +/// ancestor operation of 'op' that lies in this block. Returns nullptr if /// the latter fails. -Instruction *Block::findAncestorInstInBlock(Instruction &inst) { - // Traverse up the instruction hierarchy starting from the owner of operand to - // find the ancestor instruction that resides in the block of 'forInst'. - auto *currInst = &inst; +Operation *Block::findAncestorInstInBlock(Operation &op) { + // Traverse up the operation hierarchy starting from the owner of operand to + // find the ancestor operation that resides in the block of 'forInst'. + auto *currInst = &op; while (currInst->getBlock() != this) { currInst = currInst->getParentInst(); if (!currInst) @@ -90,36 +90,35 @@ Instruction *Block::findAncestorInstInBlock(Instruction &inst) { return currInst; } -/// This drops all operand uses from instructions within this block, which is +/// This drops all operand uses from operations within this block, which is /// an essential step in breaking cyclic dependences between references when /// they are to be deleted. 
void Block::dropAllReferences() { - for (Instruction &i : *this) + for (Operation &i : *this) i.dropAllReferences(); } void Block::dropAllDefinedValueUses() { for (auto *arg : getArguments()) arg->dropAllUses(); - for (auto &inst : *this) - inst.dropAllDefinedValueUses(); + for (auto &op : *this) + op.dropAllDefinedValueUses(); dropAllUses(); } -/// Verifies the current ordering of child instructions. Returns false if the +/// Verifies the current ordering of child operations. Returns false if the /// order is valid, true otherwise. bool Block::verifyInstOrder() { // The order is already known to be invalid. if (!isInstOrderValid()) return false; - // The order is valid if there are less than 2 instructions. - if (instructions.empty() || - std::next(instructions.begin()) == instructions.end()) + // The order is valid if there are less than 2 operations. + if (operations.empty() || std::next(operations.begin()) == operations.end()) return false; - Instruction *prev = nullptr; + Operation *prev = nullptr; for (auto &i : *this) { - // The previous instruction must have a smaller order index than the next as + // The previous operation must have a smaller order index than the next as // it appears earlier in the list. if (prev && prev->orderIndex >= i.orderIndex) return true; @@ -128,15 +127,15 @@ bool Block::verifyInstOrder() { return false; } -/// Recomputes the ordering of child instructions within the block. +/// Recomputes the ordering of child operations within the block. void Block::recomputeInstOrder() { parentValidInstOrderPair.setInt(true); // TODO(riverriddle) Have non-congruent indices to reduce the number of times // an insert invalidates the list. unsigned orderIndex = 0; - for (auto &inst : *this) - inst.orderIndex = orderIndex++; + for (auto &op : *this) + op.orderIndex = orderIndex++; } Block *PredecessorIterator::operator*() const { @@ -190,9 +189,9 @@ void Block::eraseArgument(unsigned index) { // Terminator management //===----------------------------------------------------------------------===// -/// Get the terminator instruction of this block. This function asserts that -/// the block has a valid terminator instruction. -Instruction *Block::getTerminator() { +/// Get the terminator operation of this block. This function asserts that +/// the block has a valid terminator operation. +Operation *Block::getTerminator() { assert(!empty() && !back().isKnownNonTerminator()); return &back(); } @@ -226,42 +225,42 @@ Block *Block::getSinglePredecessor() { } //===----------------------------------------------------------------------===// -// Instruction Walkers +// Operation Walkers //===----------------------------------------------------------------------===// -void Block::walk(const std::function &callback) { +void Block::walk(const std::function &callback) { walk(begin(), end(), callback); } void Block::walk(Block::iterator begin, Block::iterator end, - const std::function &callback) { - // Walk the instructions within this block. - for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) - inst.walk(callback); + const std::function &callback) { + // Walk the operations within this block. 
+ for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) + op.walk(callback); } -void Block::walkPostOrder(const std::function &callback) { +void Block::walkPostOrder(const std::function &callback) { walkPostOrder(begin(), end(), callback); } -/// Walk the instructions in the specified [begin, end) range of this block +/// Walk the operations in the specified [begin, end) range of this block /// in postorder, calling the callback for each operation. void Block::walkPostOrder(Block::iterator begin, Block::iterator end, - const std::function &callback) { - // Walk the instructions within this block. - for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) - inst.walkPostOrder(callback); + const std::function &callback) { + // Walk the operations within this block. + for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) + op.walkPostOrder(callback); } //===----------------------------------------------------------------------===// // Other //===----------------------------------------------------------------------===// -/// Split the block into two blocks before the specified instruction or +/// Split the block into two blocks before the specified operation or /// iterator. /// -/// Note that all instructions BEFORE the specified iterator stay as part of -/// the original basic block, and the rest of the instructions in the original +/// Note that all operations BEFORE the specified iterator stay as part of +/// the original basic block, and the rest of the operations in the original /// block are moved to the new block, including the old terminator. The /// original block is left without a terminator. /// @@ -275,8 +274,8 @@ Block *Block::splitBlock(iterator splitBefore) { // Move all of the operations from the split point to the end of the function // into the new block. - newBB->getInstructions().splice(newBB->end(), getInstructions(), splitBefore, - end()); + newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore, + end()); return newBB; } @@ -286,18 +285,18 @@ Block *Block::splitBlock(iterator splitBefore) { Region::Region(Function *container) : container(container) {} -Region::Region(Instruction *container) : container(container) {} +Region::Region(Operation *container) : container(container) {} Region::~Region() { - // Instructions may have cyclic references, which need to be dropped before we + // Operations may have cyclic references, which need to be dropped before we // can start deleting them. for (auto &bb : *this) bb.dropAllReferences(); } -Instruction *Region::getContainingInst() { +Operation *Region::getContainingOp() { assert(!container.isNull() && "no container"); - return container.dyn_cast(); + return container.dyn_cast(); } Function *Region::getContainingFunction() { @@ -327,20 +326,20 @@ void Region::cloneInto(Region *dest, BlockAndValueMapping &mapper, if (!mapper.contains(arg)) mapper.map(arg, newBlock->addArgument(arg->getType())); - // Clone and remap the instructions within this block. - for (auto &inst : block) - newBlock->push_back(inst.clone(mapper, context)); + // Clone and remap the operations within this block. + for (auto &op : block) + newBlock->push_back(op.clone(mapper, context)); dest->push_back(newBlock); } // Now that each of the blocks have been cloned, go through and remap the - // operands of each of the instructions. - auto remapOperands = [&](Instruction *inst) { - for (auto &instOp : inst->getInstOperands()) + // operands of each of the operations. 
+ auto remapOperands = [&](Operation *op) { + for (auto &instOp : op->getInstOperands()) if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) instOp.set(mappedOp); - for (auto &succOp : inst->getBlockOperands()) + for (auto &succOp : op->getBlockOperands()) if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) succOp.set(mappedOp); }; @@ -363,18 +362,18 @@ void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) { block->parentValidInstOrderPair.setPointer(getContainingRegion()); } -/// This is a trait method invoked when an instruction is removed from a +/// This is a trait method invoked when an operation is removed from a /// region. We keep the region pointer up to date. void llvm::ilist_traits<::mlir::Block>::removeNodeFromList(Block *block) { assert(block->getParent() && "not already in a region!"); block->parentValidInstOrderPair.setPointer(nullptr); } -/// This is a trait method invoked when an instruction is moved from one block +/// This is a trait method invoked when an operation is moved from one block /// to another. We keep the block pointer up to date. void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( ilist_traits &otherList, block_iterator first, block_iterator last) { - // If we are transferring instructions within the same function, the parent + // If we are transferring operations within the same function, the parent // pointer doesn't need to be updated. auto *curParent = getContainingRegion(); if (curParent == otherList.getContainingRegion()) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8ec883ae274..f4d532a482f 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -296,7 +296,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { } //===----------------------------------------------------------------------===// -// Instructions. +// Operations. //===----------------------------------------------------------------------===// /// Add new block and set the insertion point to the end of it. If an @@ -318,18 +318,18 @@ Block *FuncBuilder::createBlock(Block *insertBefore) { } /// Create an operation given the fields represented as an OperationState. -Instruction *FuncBuilder::createOperation(const OperationState &state) { +Operation *FuncBuilder::createOperation(const OperationState &state) { assert(block && "createOperation() called without setting builder's block"); unsigned numRegions = state.regions.size(); - auto *op = Instruction::create( - state.location, state.name, state.operands, state.types, state.attributes, - state.successors, numRegions, state.resizableOperandList, context); + auto *op = Operation::create(state.location, state.name, state.operands, + state.types, state.attributes, state.successors, + numRegions, state.resizableOperandList, context); for (unsigned i = 0; i < numRegions; ++i) if (state.regions[i]) op->getRegion(i).takeBody(*state.regions[i]); - block->getInstructions().insert(insertPoint, op); + block->getOperations().insert(insertPoint, op); return op; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 585d01e0fb1..fa9328f869e 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -91,7 +91,7 @@ void llvm::ilist_traits::removeNodeFromList(Function *function) { function->module = nullptr; } -/// This is a trait method invoked when an instruction is moved from one block +/// This is a trait method invoked when an operation is moved from one block /// to another. We keep the block pointer up to date. 
void llvm::ilist_traits::transferNodesFromList( ilist_traits &otherList, function_iterator first, @@ -115,7 +115,7 @@ void Function::erase() { getModule()->getFunctions().erase(this); } -/// Emit a note about this instruction, reporting up to any diagnostic +/// Emit a note about this operation, reporting up to any diagnostic /// handlers that may be listening. void Function::emitNote(const Twine &message) { getContext()->emitDiagnostic(getLoc(), message, @@ -209,14 +209,13 @@ void Function::addEntryBlock() { entry->addArguments(type.getInputs()); } -void Function::walk(const std::function &callback) { +void Function::walk(const std::function &callback) { // Walk each of the blocks within the function. for (auto &block : getBlocks()) block.walk(callback); } -void Function::walkPostOrder( - const std::function &callback) { +void Function::walkPostOrder(const std::function &callback) { // Walk each of the blocks within the function. for (auto &block : llvm::reverse(getBlocks())) block.walkPostOrder(callback); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 9db1f6ba964..aee0ac96917 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -624,7 +624,7 @@ auto MLIRContext::getDiagnosticHandler() -> DiagnosticHandlerTy { /// This emits a diagnostic using the registered issue handle if present, or /// with the default behavior if not. The MLIR compiler should not generally -/// interact with this, it should use methods on Instruction instead. +/// interact with this, it should use methods on Operation instead. void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message, DiagnosticKind kind) { // Check to see if we are emitting a diagnostic on a fused location. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 4d0222e598a..4bc8c3e2508 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -53,14 +53,14 @@ OperationName OperationName::getFromOpaquePointer(void *pointer) { OpAsmParser::~OpAsmParser() {} //===----------------------------------------------------------------------===// -// InstResult +// OpResult //===----------------------------------------------------------------------===// /// Return the result number of this result. -unsigned InstResult::getResultNumber() { +unsigned OpResult::getResultNumber() { // Results are always stored consecutively, so use pointer subtraction to // figure out what number this is. - return this - &getOwner()->getInstResults()[0]; + return this - &getOwner()->getOpResults()[0]; } //===----------------------------------------------------------------------===// @@ -112,7 +112,7 @@ Operation *Operation::create(Location location, OperationName name, unsigned numOperands = operands.size() - numSuccessors; // Compute the byte size for the operation and the operand storage. 
- auto byteSize = totalSizeToAlloc( resultTypes.size(), numSuccessors, numSuccessors, numRegions, /*detail::OperandStorage*/ 1); @@ -137,9 +137,9 @@ Operation *Operation::create(Location location, OperationName name, new (&op->getOperandStorage()) detail::OperandStorage(numOperands, resizableOperandList); - auto instResults = op->getInstResults(); + auto instResults = op->getOpResults(); for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) - new (&instResults[i]) InstResult(resultTypes[i], op); + new (&instResults[i]) OpResult(resultTypes[i], op); auto InstOperands = op->getInstOperands(); @@ -215,8 +215,8 @@ Operation::~Operation() { // Explicitly run the destructors for the operands and results. getOperandStorage().~OperandStorage(); - for (auto &result : getInstResults()) - result.~InstResult(); + for (auto &result : getOpResults()) + result.~OpResult(); // Explicitly run the destructors for the successors. for (auto &successor : getBlockOperands()) @@ -261,7 +261,7 @@ Dialect *Operation::getDialect() { } Operation *Operation::getParentInst() { - return block ? block->getContainingInst() : nullptr; + return block ? block->getContainingOp() : nullptr; } Function *Operation::getFunction() { @@ -394,8 +394,7 @@ void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { /// This is a trait method invoked when a operation is moved from one block /// to another. We keep the block pointer up to date. void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( - ilist_traits &otherList, inst_iterator first, - inst_iterator last) { + ilist_traits &otherList, op_iterator first, op_iterator last) { Block *curParent = getContainingBlock(); // Invalidate the ordering of the parent block. @@ -415,7 +414,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( /// all of them. void Operation::erase() { assert(getBlock() && "Operation has no block"); - getBlock()->getInstructions().erase(this); + getBlock()->getOperations().erase(this); } /// Unlink this operation from its current block and insert it right before @@ -429,8 +428,8 @@ void Operation::moveBefore(Operation *existingInst) { /// it right before `iterator` in the specified basic block. void Operation::moveBefore(Block *block, llvm::iplist::iterator iterator) { - block->getInstructions().splice(iterator, getBlock()->getInstructions(), - getIterator()); + block->getOperations().splice(iterator, getBlock()->getOperations(), + getIterator()); } /// This drops all operand uses from this operation, which is an essential @@ -451,7 +450,7 @@ void Operation::dropAllReferences() { /// This drops all uses of any values defined by this operation or its nested /// regions, wherever they are located. void Operation::dropAllDefinedValueUses() { - for (auto &val : getInstResults()) + for (auto &val : getOpResults()) val.dropAllUses(); for (auto ®ion : getRegions()) @@ -620,32 +619,32 @@ bool OpState::parse(OpAsmParser *parser, OperationState *result) { } // The fallback for the printer is to print in the generic assembly form. -void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getInstruction()); } +void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getOperation()); } /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening. NOTE: This may terminate /// the containing application, only use when the IR is in an inconsistent /// state. 
bool OpState::emitError(const Twine &message) { - return getInstruction()->emitError(message); + return getOperation()->emitError(message); } /// Emit an error with the op name prefixed, like "'dim' op " which is /// convenient for verifiers. bool OpState::emitOpError(const Twine &message) { - return getInstruction()->emitOpError(message); + return getOperation()->emitOpError(message); } /// Emit a warning about this operation, reporting up to any diagnostic /// handlers that may be listening. void OpState::emitWarning(const Twine &message) { - getInstruction()->emitWarning(message); + getOperation()->emitWarning(message); } /// Emit a note about this operation, reporting up to any diagnostic /// handlers that may be listening. void OpState::emitNote(const Twine &message) { - getInstruction()->emitNote(message); + getOperation()->emitNote(message); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 404a0e9a2cb..7132408bfa5 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -46,18 +46,17 @@ void Pattern::anchor() {} // RewritePattern and PatternRewriter implementation //===----------------------------------------------------------------------===// -void RewritePattern::rewrite(Instruction *op, - std::unique_ptr state, +void RewritePattern::rewrite(Operation *op, std::unique_ptr state, PatternRewriter &rewriter) const { rewrite(op, rewriter); } -void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const { +void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { llvm_unreachable("need to implement either matchAndRewrite or one of the " "rewrite functions!"); } -PatternMatchResult RewritePattern::match(Instruction *op) const { +PatternMatchResult RewritePattern::match(Operation *op) const { llvm_unreachable("need to implement either match or matchAndRewrite!"); } @@ -71,7 +70,7 @@ PatternRewriter::~PatternRewriter() { /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those ops are dead, this will /// remove them as well. -void PatternRewriter::replaceOp(Instruction *op, ArrayRef newValues, +void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootReplaced(op); @@ -91,8 +90,7 @@ void PatternRewriter::replaceOp(Instruction *op, ArrayRef newValues, /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Instruction *op, Instruction *newOp, - ArrayRef valuesToRemoveIfDead) { + Operation *op, Operation *newOp, ArrayRef valuesToRemoveIfDead) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) @@ -105,14 +103,14 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp( /// This method is used as the final notification hook for patterns that end /// up modifying the pattern root in place, by changing its operands. This is -/// a minor efficiency win (it avoids creating a new instruction and removing +/// a minor efficiency win (it avoids creating a new operation and removing /// the old one) but also often allows simpler code in the client. 
/// /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter /// should remove if they are dead at this point. /// void PatternRewriter::updatedRootInPlace( - Instruction *op, ArrayRef valuesToRemoveIfDead) { + Operation *op, ArrayRef valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootUpdated(op); @@ -136,7 +134,7 @@ RewritePatternMatcher::RewritePatternMatcher( } /// Try to match the given operation to a pattern and rewrite it. -void RewritePatternMatcher::matchAndRewrite(Instruction *op) { +void RewritePatternMatcher::matchAndRewrite(Operation *op) { for (auto &pattern : patterns) { // Ignore patterns that are for the wrong root or are impossible to match. if (pattern->getRootKind() != op->getName() || diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index b369794770a..bd81784c512 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -20,10 +20,10 @@ #include "mlir/IR/Operation.h" using namespace mlir; -/// If this value is the result of an Instruction, return the instruction -/// that defines it. -Instruction *Value::getDefiningInst() { - if (auto *result = dyn_cast(this)) +/// If this value is the result of an Operation, return the operation that +/// defines it. +Operation *Value::getDefiningOp() { + if (auto *result = dyn_cast(this)) return result->getOwner(); return nullptr; } @@ -33,8 +33,8 @@ Function *Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); - case Value::Kind::InstResult: - return getDefiningInst()->getFunction(); + case Value::Kind::OpResult: + return getDefiningOp()->getFunction(); } } diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 6bf460a5d24..222f4a657bd 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -450,7 +450,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { if (numResults == 0) return {}; if (numResults == 1) - return {newOp.getInstruction()->getResult(0)}; + return {newOp.getOperation()->getResult(0)}; // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. @@ -460,7 +460,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { auto type = TypeConverter::convert(op->getResult(i)->getType(), this->dialect.getLLVMModule()); results.push_back(rewriter.create( - op->getLoc(), type, newOp.getInstruction()->getResult(0), + op->getLoc(), type, newOp.getOperation()->getResult(0), this->getIntegerArrayAttr(rewriter, i))); } return results; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f5d55f73df9..62258595cd7 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2534,7 +2534,7 @@ FunctionParser::~FunctionParser() { // Drop all uses of undefined forward declared reference and destroy // defining instruction. fwd.first->dropAllUses(); - fwd.first->getDefiningInst()->destroy(); + fwd.first->getDefiningOp()->destroy(); } } @@ -2560,7 +2560,7 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) { // the actual definition instead, delete the forward ref, and remove it // from our set of forward references we track. 
existing->replaceAllUsesWith(value); - existing->getDefiningInst()->destroy(); + existing->getDefiningOp()->destroy(); forwardReferencePlaceholders.erase(existing); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 228a4a5acc4..6bbf8c74bbb 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -133,7 +133,7 @@ struct MemRefCastFolder : public RewritePattern { void rewrite(Instruction *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningInst()) + if (auto *memref = op->getOperand(i)->getDefiningOp()) if (auto cast = memref->dyn_cast()) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); @@ -270,7 +270,7 @@ bool AllocOp::verify() { // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. - if (getInstruction()->getNumOperands() != numDynamicDims + numSymbols) { + if (getOperation()->getNumOperands() != numDynamicDims + numSymbols) { return emitOpError( "operand count does not equal dimension plus symbol operand count"); } @@ -318,7 +318,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningInst(); + auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp(); ConstantIndexOp constantIndexOp; if (defOp && (constantIndexOp = defOp->dyn_cast())) { // Dynamic shape dimension will be folded. @@ -396,17 +396,17 @@ bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { void BranchOp::print(OpAsmPrinter *p) { *p << "br "; - p->printSuccessorAndUseList(getInstruction(), 0); + p->printSuccessorAndUseList(getOperation(), 0); } -Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } void BranchOp::setDest(Block *block) { - return getInstruction()->setSuccessor(block, 0); + return getOperation()->setSuccessor(block, 0); } void BranchOp::eraseOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(0, index); + getOperation()->eraseSuccessorOperand(0, index); } //===----------------------------------------------------------------------===// @@ -869,9 +869,9 @@ void CondBranchOp::print(OpAsmPrinter *p) { *p << "cond_br "; p->printOperand(getCondition()); *p << ", "; - p->printSuccessorAndUseList(getInstruction(), trueIndex); + p->printSuccessorAndUseList(getOperation(), trueIndex); *p << ", "; - p->printSuccessorAndUseList(getInstruction(), falseIndex); + p->printSuccessorAndUseList(getOperation(), falseIndex); } bool CondBranchOp::verify() { @@ -886,27 +886,27 @@ void CondBranchOp::getCanonicalizationPatterns( } Block *CondBranchOp::getTrueDest() { - return getInstruction()->getSuccessor(trueIndex); + return getOperation()->getSuccessor(trueIndex); } Block *CondBranchOp::getFalseDest() { - return getInstruction()->getSuccessor(falseIndex); + return getOperation()->getSuccessor(falseIndex); } unsigned CondBranchOp::getNumTrueOperands() { - return getInstruction()->getNumSuccessorOperands(trueIndex); + return getOperation()->getNumSuccessorOperands(trueIndex); } void CondBranchOp::eraseTrueOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(trueIndex, index); + getOperation()->eraseSuccessorOperand(trueIndex, index); } unsigned CondBranchOp::getNumFalseOperands() { - return 
getInstruction()->getNumSuccessorOperands(falseIndex); + return getOperation()->getNumSuccessorOperands(falseIndex); } void CondBranchOp::eraseFalseOperand(unsigned index) { - getInstruction()->eraseSuccessorOperand(falseIndex, index); + getOperation()->eraseSuccessorOperand(falseIndex, index); } //===----------------------------------------------------------------------===// @@ -1095,7 +1095,7 @@ struct SimplifyDeadDealloc : public RewritePattern { // Check that the memref operand's defining instruction is an AllocOp. Value *memref = dealloc.getMemRef(); - Instruction *defOp = memref->getDefiningInst(); + Instruction *defOp = memref->getDefiningOp(); if (!defOp || !defOp->isa()) return matchFailure(); @@ -1802,7 +1802,7 @@ void ReturnOp::print(OpAsmPrinter *p) { } bool ReturnOp::verify() { - auto *function = getInstruction()->getFunction(); + auto *function = getOperation()->getFunction(); // The operand number and types must match the function signature. const auto &results = function->getType().getResults(); @@ -1852,7 +1852,7 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { void SelectOp::print(OpAsmPrinter *p) { *p << "select "; - p->printOperands(getInstruction()->getOperands()); + p->printOperands(getOperation()->getOperands()); *p << " : " << getTrueValue()->getType(); p->printOptionalAttrDict(getAttrs()); } diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 1e0c01a5df1..15bd31f4bff 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -87,7 +87,7 @@ void VectorTransferReadOp::build(Builder *builder, OperationState *result, llvm::iterator_range VectorTransferReadOp::getIndices() { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } @@ -288,7 +288,7 @@ void VectorTransferWriteOp::build(Builder *builder, OperationState *result, llvm::iterator_range VectorTransferWriteOp::getIndices() { - auto begin = getInstruction()->operand_begin() + Offsets::FirstIndexOffset; + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index ece87ce6b6c..4c4c8cc4019 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -54,7 +54,7 @@ void ConstantFold::foldInstruction(Instruction *op) { SmallVector operandConstants; for (auto *operand : op->getOperands()) { Attribute operandCst = nullptr; - if (auto *operandOp = operand->getDefiningInst()) { + if (auto *operandOp = operand->getDefiningOp()) { if (auto operandConstantOp = operandOp->dyn_cast()) operandCst = operandConstantOp.getValue(); } @@ -112,7 +112,7 @@ void ConstantFold::runOnFunction() { // around dead constants. Check for them now and remove them. 
for (auto *cst : existingConstants) { if (cst->use_empty()) - cst->getDefiningInst()->erase(); + cst->getDefiningOp()->erase(); } } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index e20472770ae..50edcc6e64c 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -206,7 +206,7 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, } static void emitNoteForBlock(Block &block, const Twine &message) { - auto *inst = block.getContainingInst(); + auto *inst = block.getContainingOp(); if (!inst) { block.getFunction()->emitNote(message); } else { @@ -403,7 +403,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, zeroIndex, stride, numEltPerStride); // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. - *nEnd = Block::iterator(op.getInstruction()); + *nEnd = Block::iterator(op.getOperation()); } // Matching DMA wait to block on completion; tag always has a 0 index. @@ -414,7 +414,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, if (*nEnd == end) // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. - *nEnd = Block::iterator(tagDeallocOp.getInstruction()); + *nEnd = Block::iterator(tagDeallocOp.getOperation()); // Generate dealloc for the DMA buffer. if (!existingBuf) @@ -567,9 +567,9 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, if (it != enclosingFors.rbegin()) { auto lastInvariantIV = *std::prev(it); - *dmaPlacementReadStart = Block::iterator(lastInvariantIV.getInstruction()); + *dmaPlacementReadStart = Block::iterator(lastInvariantIV.getOperation()); *dmaPlacementWriteStart = std::next(*dmaPlacementReadStart); - *dmaPlacementBlock = lastInvariantIV.getInstruction()->getBlock(); + *dmaPlacementBlock = lastInvariantIV.getOperation()->getBlock(); } else { *dmaPlacementReadStart = begin; *dmaPlacementWriteStart = end; @@ -744,7 +744,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) { StringRef str = "Total size of all DMA buffers' for this block " "exceeds fast memory capacity\n"; - if (auto *inst = block->getContainingInst()) + if (auto *inst = block->getContainingOp()) inst->emitError(str); else block->getFunction()->emitError(str); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index df5005bc7b1..d76aca20b6d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -301,7 +301,7 @@ public: Node *node = getNode(id); for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast().getMemRef(); - auto *inst = memref->getDefiningInst(); + auto *inst = memref->getDefiningOp(); // Return true if 'memref' is a block argument. 
if (!inst) return true; @@ -696,8 +696,8 @@ bool MemRefDependenceGraph::init(Function &f) { getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0].getInstruction()) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0].getInstruction()]; + assert(forToNodeMap.count(loops[0].getOperation()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()]; addEdge(node.id, userLoopNestId, value); } } @@ -745,8 +745,8 @@ struct LoopNestStatsCollector { void collect(Instruction *inst) { inst->walk([&](AffineForOp forOp) { - auto *forInst = forOp.getInstruction(); - auto *parentInst = forOp.getInstruction()->getParentInst(); + auto *forInst = forOp.getOperation(); + auto *parentInst = forOp.getOperation()->getParentInst(); if (parentInst != nullptr) { assert(parentInst->isa() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. @@ -796,7 +796,7 @@ static int64_t getComputeCost( int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { for (auto childForOp : stats->loopMap[forInst]) { - opCount += getComputeCost(childForOp.getInstruction(), stats, + opCount += getComputeCost(childForOp.getOperation(), stats, tripCountOverrideMap, computeCostMap); } } @@ -856,7 +856,7 @@ static bool buildSliceTripCountMap( // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. if (srcLoopIVs[i].hasConstantLowerBound() && srcLoopIVs[i].hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i].getInstruction()] = + (*tripCountMap)[srcLoopIVs[i].getOperation()] = srcLoopIVs[i].getConstantUpperBound() - srcLoopIVs[i].getConstantLowerBound(); continue; @@ -866,7 +866,7 @@ static bool buildSliceTripCountMap( Optional tripCount = getConstDifference(lbMap, ubMap); if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i].getInstruction()] = tripCount.getValue(); + (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue(); } return true; } @@ -1091,7 +1091,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { } } assert(loopNestRootIndex != -1 && "invalid root index"); - node->inst = loops[loopNestRootIndex].getInstruction(); + node->inst = loops[loopNestRootIndex].getOperation(); } // TODO(mlir-team): improve/complete this when we have target data. @@ -1119,7 +1119,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, unsigned dstLoopDepth, Optional fastMemorySpace, uint64_t localBufSizeThreshold) { - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. FuncBuilder b(forInst); @@ -1437,7 +1437,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.collect(srcLoopIVs[0].getInstruction()); + srcStatsCollector.collect(srcLoopIVs[0].getOperation()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) { LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); @@ -1449,7 +1449,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.collect(dstLoopIVs[0].getInstruction()); + dstStatsCollector.collect(dstLoopIVs[0].getOperation()); // Currently only constant trip count loop nests are supported. 
if (dstStatsCollector.hasLoopWithNonConstTripCount) { LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); @@ -1484,7 +1484,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest without iteration slicing. uint64_t srcLoopNestCost = - getComputeCost(srcLoopIVs[0].getInstruction(), &srcLoopNestStats, + getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); @@ -1504,7 +1504,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest. uint64_t dstLoopNestCost = - getComputeCost(dstLoopIVs[0].getInstruction(), &dstLoopNestStats, + getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); @@ -1543,7 +1543,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // TODO(andydavis) Add load coalescing to memref data flow opt pass. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. - computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getInstruction()] = -1; + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1; for (auto *loadOp : dstLoadOpInsts) { auto *parentInst = loadOp->getParentInst(); if (parentInst && parentInst->isa()) @@ -1553,15 +1553,15 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0].getInstruction(), &srcLoopNestStats, + getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats, /*tripCountOverrideMap=*/&sliceTripCountMap, /*computeCostMap=*/&computeCostMap); // Compute cost of fusion for this depth. - computeCostMap[dstLoopIVs[i - 1].getInstruction()] = sliceComputeCost; + computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost; int64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0].getInstruction(), &dstLoopNestStats, + getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); double additionalComputeFraction = @@ -1936,18 +1936,18 @@ public: srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest) { LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" - << *sliceLoopNest.getInstruction() << "\n"); + << *sliceLoopNest.getOperation() << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. auto dstAffineForOp = dstNode->inst->cast(); - if (insertPointInst != dstAffineForOp.getInstruction()) { - dstAffineForOp.getInstruction()->moveBefore(insertPointInst); + if (insertPointInst != dstAffineForOp.getOperation()) { + dstAffineForOp.getOperation()->moveBefore(insertPointInst); } // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, memref); // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getInstruction()); + sliceCollector.collect(sliceLoopNest.getOperation()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); @@ -1966,14 +1966,14 @@ public: visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = - mdg->addNode(newMemRef->getDefiningInst()); + mdg->addNode(newMemRef->getDefiningOp()); // Add edge from 'newMemRef' node to dstNode. 
mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getInstruction()); + dstLoopCollector.collect(dstAffineForOp.getOperation()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. @@ -2096,8 +2096,8 @@ public: if (sliceLoopNest != nullptr) { auto dstForInst = dstNode->inst->cast(); // Update instruction position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getInstruction()) { - dstForInst.getInstruction()->moveBefore(insertPointInst); + if (insertPointInst != dstForInst.getOperation()) { + dstForInst.getOperation()->moveBefore(insertPointInst); } // Update data dependence graph state post fusion. updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); @@ -2189,7 +2189,7 @@ public: // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getInstruction()); + sliceCollector.collect(sliceLoopNest.getOperation()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); @@ -2198,7 +2198,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. auto dstForInst = dstNode->inst->cast(); LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstForInst.getInstruction()); + dstLoopCollector.collect(dstForInst.getOperation()); // Clear and add back loads and stores mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, @@ -2222,7 +2222,7 @@ public: if (!memref->use_empty()) continue; // Use list expected to match the dep graph info. - auto *inst = memref->getDefiningInst(); + auto *inst = memref->getDefiningOp(); if (inst && inst->isa()) inst->erase(); } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index eafa7bca4d4..d9f74808ad8 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -92,8 +92,7 @@ FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { // location in destination's body. static inline void moveLoopBody(AffineForOp src, AffineForOp dest, Block::iterator loc) { - dest.getBody()->getInstructions().splice(loc, - src.getBody()->getInstructions()); + dest.getBody()->getOperations().splice(loc, src.getBody()->getOperations()); } // Move the loop body of AffineForOp 'src' from 'src' to the start of dest's @@ -114,7 +113,7 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0].getInstruction()); + FuncBuilder b(origLoops[0].getOperation()); unsigned width = origLoops.size(); // Bounds for tile space loops. @@ -181,8 +180,8 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i].getInstruction()->getParentInst() == - band[i - 1].getInstruction()); + assert(band[i].getOperation()->getParentInst() == + band[i - 1].getOperation()); } auto origLoops = band; @@ -196,7 +195,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, AffineForOp innermostPointLoop; // The outermost among the loops as we add more.. 
- auto *topLoop = rootAffineForOp.getInstruction(); + auto *topLoop = rootAffineForOp.getOperation(); // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { @@ -204,11 +203,11 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Loop bounds will be set later. auto pointLoop = b.create(loc, 0, 0); pointLoop.createBody(); - pointLoop.getBody()->getInstructions().splice( - pointLoop.getBody()->begin(), topLoop->getBlock()->getInstructions(), + pointLoop.getBody()->getOperations().splice( + pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; - topLoop = pointLoop.getInstruction(); + topLoop = pointLoop.getOperation(); if (i == 0) innermostPointLoop = pointLoop; } @@ -219,11 +218,11 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Loop bounds will be set later. auto tileSpaceLoop = b.create(loc, 0, 0); tileSpaceLoop.createBody(); - tileSpaceLoop.getBody()->getInstructions().splice( - tileSpaceLoop.getBody()->begin(), - topLoop->getBlock()->getInstructions(), topLoop); + tileSpaceLoop.getBody()->getOperations().splice( + tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), + topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; - topLoop = tileSpaceLoop.getInstruction(); + topLoop = tileSpaceLoop.getOperation(); } // Move the loop body of the original nest to the new one. @@ -265,7 +264,7 @@ static void getTileableBands(Function &f, AffineForOp currInst = root; do { band.push_back(currInst); - } while (currInst.getBody()->getInstructions().size() == 1 && + } while (currInst.getBody()->getOperations().size() == 1 && (currInst = currInst.getBody()->front().dyn_cast())); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 174f93e4d2d..3fa4eab93da 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -173,7 +173,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, mayBeConstantTripCount.getValue() < unrollJamFactor) return failure(); - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 162eed00b6c..5046bf2596b 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -321,7 +321,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { auto loc = forOp.getLoc(); - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); // Start by splitting the block containing the 'affine.for' into two parts. // The part before will get the init code, the part after will be the end @@ -340,9 +340,9 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { bodyBlock->insertBefore(endBlock); auto *oldBody = forOp.getBody(); - bodyBlock->getInstructions().splice(bodyBlock->begin(), - oldBody->getInstructions(), - oldBody->begin(), oldBody->end()); + bodyBlock->getOperations().splice(bodyBlock->begin(), + oldBody->getOperations(), oldBody->begin(), + oldBody->end()); // The code in the body of the forOp now uses 'iv' as its indvar. 
forOp.getInductionVar()->replaceAllUsesWith(iv); @@ -454,7 +454,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { // +--------------------------------+ // bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { - auto *ifInst = ifOp.getInstruction(); + auto *ifInst = ifOp.getOperation(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'affine.if' into two parts. The @@ -478,9 +478,9 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { Block *oldThen = &oldThenBlocks.front(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + thenBlock->getOperations().splice(thenBlock->begin(), + oldThen->getOperations(), + oldThen->begin(), oldThen->end()); } FuncBuilder builder(thenBlock); @@ -499,9 +499,9 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { elseBlock = new Block(); elseBlock->insertBefore(continueBlock); - elseBlock->getInstructions().splice(elseBlock->begin(), - oldElse->getInstructions(), - oldElse->begin(), oldElse->end()); + elseBlock->getOperations().splice(elseBlock->begin(), + oldElse->getOperations(), + oldElse->begin(), oldElse->end()); builder.setInsertionPointToEnd(elseBlock); builder.create(loc, continueBlock); } @@ -570,7 +570,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { // Convert an "affine.apply" operation into a sequence of arithmetic // instructions using the StandardOps dialect. Return true on error. bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { - FuncBuilder builder(op.getInstruction()); + FuncBuilder builder(op.getOperation()); auto maybeExpandedMap = expandAffineMap(&builder, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index e6b1950c222..708ad7d1693 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -266,8 +266,7 @@ template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc::intrinsics; // 1. Setup all the captures. - ScopedContext scope(FuncBuilder(transfer.getInstruction()), - transfer.getLoc()); + ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc()); IndexedValue remote(transfer.getMemRef()); MemRefView view(transfer.getMemRef()); VectorView vectorView(transfer.getVector()); @@ -321,8 +320,7 @@ template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc::intrinsics; // 1. Setup all the captures. 
- ScopedContext scope(FuncBuilder(transfer.getInstruction()), - transfer.getLoc()); + ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc()); IndexedValue remote(transfer.getMemRef()); MemRefView view(transfer.getMemRef()); ValueHandle vectorValue(transfer.getVector()); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index a4deba26d83..8ea9d4e8020 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -256,7 +256,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, DenseMap *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { - auto *opInst = v->getDefiningInst(); + auto *opInst = v->getDefiningOp(); if (opInst->isa()) { FuncBuilder b(opInst); auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap); @@ -265,7 +265,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, assert(res.second && "Insertion failed"); return res.first->second; } - v->getDefiningInst()->emitError("Missing substitution"); + v->getDefiningOp()->emitError("Missing substitution"); return nullptr; } return it->second; @@ -496,7 +496,7 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp read, auto cloned = b->create(read.getLoc(), hwVectorType, read.getMemRef(), affineIndices, map, read.getPaddingValue()); - return cloned.getInstruction(); + return cloned.getOperation(); } /// Creates an instantiated version of `write` for the instance of @@ -518,7 +518,7 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, substitute(write.getVector(), hwVectorType, substitutionsMap), write.getMemRef(), affineIndices, projectedPermutationMap(write, hwVectorType)); - return cloned.getInstruction(); + return cloned.getOperation(); } /// Returns `true` if inst instance is properly cloned and inserted, false diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index e1e253d1869..9779ab78a3f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -95,7 +95,7 @@ FunctionPassBase *mlir::createMemRefDataFlowOptPass() { // this in the future if needed. void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { Instruction *lastWriteStoreOp = nullptr; - Instruction *loadOpInst = loadOp.getInstruction(); + Instruction *loadOpInst = loadOp.getOperation(); // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across @@ -106,7 +106,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { auto storeOp = use.getOwner()->dyn_cast(); if (!storeOp) continue; - auto *storeOpInst = storeOp.getInstruction(); + auto *storeOpInst = storeOp.getOperation(); unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); minSurroundingLoops = std::min(nsLoops, minSurroundingLoops); storeOps.push_back(storeOpInst); @@ -236,7 +236,7 @@ void MemRefDataFlowOpt::runOnFunction() { // to do this as well, but we'll do it here since we collected these anyway. for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - Instruction *defInst = memref->getDefiningInst(); + Instruction *defInst = memref->getDefiningOp(); if (!defInst || !defInst->isa()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we // could still erase it if the call had no side-effects. 
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 051ac733c14..a7d37161aa1 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -93,7 +93,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto newMemRefType = doubleShape(oldMemRefType); // The double buffer is allocated right before 'forInst'. - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); FuncBuilder bOuter(forInst); // Put together alloc operands for any dynamic dimensions of the memref. SmallVector allocOperands; @@ -287,14 +287,14 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // order to create the double buffer above.) // '-canonicalize' does this in a more general way, but we'll anyway do the // simple/common case so that the output / test cases looks clear. - if (auto *allocInst = oldMemRef->getDefiningInst()) { + if (auto *allocInst = oldMemRef->getDefiningOp()) { if (oldMemRef->use_empty()) { allocInst->erase(); } else if (oldMemRef->hasOneUse()) { auto *singleUse = oldMemRef->use_begin()->getOwner(); if (singleUse->isa()) { singleUse->erase(); - oldMemRef->getDefiningInst()->erase(); + oldMemRef->getDefiningOp()->erase(); } } } @@ -312,7 +312,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocInst = oldTagMemRef->getDefiningInst()) + if (auto *allocInst = oldTagMemRef->getDefiningOp()) allocInst->erase(); } @@ -331,7 +331,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); if (!sliceOps.empty()) { for (auto sliceOp : sliceOps) { - instShiftMap[sliceOp.getInstruction()] = 0; + instShiftMap[sliceOp.getOperation()] = 0; } } else { // If a slice wasn't created, the reachable affine.apply op's from its @@ -352,7 +352,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { } // Get shifts stored in map. - std::vector shifts(forOp.getBody()->getInstructions().size()); + std::vector shifts(forOp.getBody()->getOperations().size()); unsigned s = 0; for (auto &inst : *forOp.getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index e8dce29729d..79a2b12d242 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -116,7 +116,7 @@ private: if (!operand->use_empty() && std::next(operand->use_begin()) != operand->use_end()) continue; - if (auto *defInst = operand->getDefiningInst()) + if (auto *defInst = operand->getDefiningOp()) addToWorklist(defInst); } } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a6f1c8dd1ca..9a7db193d29 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -102,7 +102,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, // Remove any affine.apply's that became dead from the simplification above. 
for (auto *v : bumpValues) { if (v->use_empty()) { - v->getDefiningInst()->erase(); + v->getDefiningOp()->erase(); } } if (lb.use_empty()) @@ -123,7 +123,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { // Replaces all IV uses to its single iteration value. auto *iv = forOp.getInductionVar(); - Instruction *forInst = forOp.getInstruction(); + Instruction *forInst = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); @@ -147,8 +147,8 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { } // Move the loop body instructions to the loop's containing block. auto *block = forInst->getBlock(); - block->getInstructions().splice(Block::iterator(forInst), - forOp.getBody()->getInstructions()); + block->getOperations().splice(Block::iterator(forInst), + forOp.getBody()->getOperations()); forOp.erase(); return success(); } @@ -252,7 +252,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, int64_t step = forOp.getStep(); - unsigned numChildInsts = forOp.getBody()->getInstructions().size(); + unsigned numChildInsts = forOp.getBody()->getOperations().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -290,7 +290,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, auto origLbMap = forOp.getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forOp.getInstruction()); + FuncBuilder b(forOp.getOperation()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -345,7 +345,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); if (unrollPrologueEpilogue && !epilogue && - epilogue.getInstruction() != prologue.getInstruction()) + epilogue.getOperation() != prologue.getOperation()) loopUnrollFull(epilogue); return success(); @@ -404,7 +404,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, return failure(); // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - Instruction *forInst = forOp.getInstruction(); + Instruction *forInst = forOp.getOperation(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); auto cleanupForInst = builder.clone(*forInst)->cast(); @@ -467,20 +467,20 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is /// nested within 'forOpA' as the only instruction in its block. void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { - auto *forOpAInst = forOpA.getInstruction(); + auto *forOpAInst = forOpA.getOperation(); // 1) Slice forOpA's instruction list (which is just forOpB) just before // forOpA (in forOpA's parent's block) this should leave 'forOpA's // instruction list empty (because its perfectly nested). - assert(&*forOpA.getBody()->begin() == forOpB.getInstruction()); - forOpAInst->getBlock()->getInstructions().splice( - Block::iterator(forOpAInst), forOpA.getBody()->getInstructions()); + assert(&*forOpA.getBody()->begin() == forOpB.getOperation()); + forOpAInst->getBlock()->getOperations().splice( + Block::iterator(forOpAInst), forOpA.getBody()->getOperations()); // 2) Slice forOpB's instruction list into forOpA's instruction list (this // leaves forOpB's instruction list empty). 
- forOpA.getBody()->getInstructions().splice( - forOpA.getBody()->begin(), forOpB.getBody()->getInstructions()); + forOpA.getBody()->getOperations().splice(forOpA.getBody()->begin(), + forOpB.getBody()->getOperations()); // 3) Slice forOpA into forOpB's instruction list. - forOpB.getBody()->getInstructions().splice( - forOpB.getBody()->begin(), forOpAInst->getBlock()->getInstructions(), + forOpB.getBody()->getOperations().splice( + forOpB.getBody()->begin(), forOpAInst->getBlock()->getOperations(), Block::iterator(forOpAInst)); } @@ -526,7 +526,7 @@ static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, for (auto it = forOp.getBody()->begin(), end = forOp.getBody()->end(); it != end; ++it) { // Step over newForOp in case it is nested under forOp. - if (&*it == newForOp.getInstruction()) { + if (&*it == newForOp.getOperation()) { continue; } auto *inst = b.clone(*it, map); @@ -558,7 +558,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, auto scaledStep = originalStep * factor; forOp.setStep(scaledStep); - auto *forInst = forOp.getInstruction(); + auto *forInst = forOp.getOperation(); FuncBuilder b(forInst->getBlock(), ++Block::iterator(forInst)); // Lower-bound map creation. @@ -581,7 +581,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, newForOp.createBody(); cloneLoopBodyInto(t, forOp.getInductionVar(), newForOp); // Remove all instructions from `t` except `newForOp`. - auto rit = ++newForOp.getInstruction()->getReverseIterator(); + auto rit = ++newForOp.getOperation()->getReverseIterator(); auto re = t.getBody()->rend(); for (auto &inst : llvm::make_early_inc_range(llvm::make_range(rit, re))) { inst.erase(); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 7a44a6277a6..b5225d08827 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -127,7 +127,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, FuncBuilder builder(opInst); for (auto *extraIndex : extraIndices) { - assert(extraIndex->getDefiningInst()->getNumResults() == 1 && + assert(extraIndex->getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && "invalid memory op index"); @@ -226,7 +226,7 @@ void mlir::createAffineComputationSlice( SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); for (auto *operand : opInst->getOperands()) { - auto *defInst = operand->getDefiningInst(); + auto *defInst = operand->getDefiningOp(); if (defInst && defInst->isa()) { subOperands.push_back(operand); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index a1e2c609653..3c6ab6c2cac 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -828,7 +828,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. - auto *opInst = memoryOp.getInstruction(); + auto *opInst = memoryOp.getOperation(); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. 
// TODO(ntv): increase the expressiveness power of vector_transfer operations @@ -842,7 +842,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, auto transfer = b.create( opInst->getLoc(), vectorType, memoryOp.getMemRef(), map(makePtrDynCaster(), memoryOp.getIndices()), permutationMap); - state->registerReplacement(opInst, transfer.getInstruction()); + state->registerReplacement(opInst, transfer.getOperation()); } else { state->registerTerminal(opInst); } @@ -867,7 +867,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; - loadAndStores.match(loop.getInstruction(), &loadAndStoresMatches); + loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedInstruction(); auto load = opInst->dyn_cast(); @@ -953,7 +953,7 @@ static Value *vectorizeConstant(Instruction *inst, ConstantOp constant, Location loc = inst->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpInst = constant.getInstruction(); + auto *constantOpInst = constant.getOperation(); OperationState state(b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, @@ -988,7 +988,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); // 1. If this value has already been vectorized this round, we are done. - if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) { + if (state->vectorizedSet.count(operand->getDefiningOp()) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1009,7 +1009,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningInst()->dyn_cast()) { + if (auto constant = operand->getDefiningOp()->dyn_cast()) { return vectorizeConstant( inst, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); @@ -1051,7 +1051,7 @@ static Instruction *vectorizeOneInstruction(Instruction *opInst, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( opInst->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = transfer.getInstruction(); + auto *res = transfer.getOperation(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminals" (i.e. StoreOps) are erased on the spot. opInst->erase(); @@ -1163,7 +1163,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Sets up error handling for this root loop. This is how the root match /// maintains a clone for handling failure and restores the proper state via /// RAII. - auto *loopInst = loop.getInstruction(); + auto *loopInst = loop.getOperation(); FuncBuilder builder(loopInst); auto clonedLoop = builder.clone(*loopInst)->cast(); struct Guard { diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 8a31eeb9f05..f607a3d4d3f 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -386,8 +386,8 @@ TEST_FUNC(custom_ops) { ih0 = MY_CUSTOM_INST_0({m, m + n}, {}), ih2 = MY_CUSTOM_INST_2({m, m + n}, {indexType, indexType}), // These captures are verbose for now, can improve when used in practice. 
- vh20 = ValueHandle(ih2.getInstruction()->getResult(0)), - vh21 = ValueHandle(ih2.getInstruction()->getResult(1)), + vh20 = ValueHandle(ih2.getOperation()->getResult(0)), + vh21 = ValueHandle(ih2.getOperation()->getResult(1)), MY_CUSTOM_OP({vh20, vh21}, {indexType}, {}), }); diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 303d5c9be96..7849be2c00c 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -8,5 +8,5 @@ def OneOperandOp : Op<"one_operand_op", []> { // CHECK-LABEL: OneOperandOp definitions // CHECK: bool OneOperandOp::verify() { -// CHECK: if (!((this->getInstruction()->getOperand(0)->getType().isInteger(32)))) +// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) // CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index f98564c5d28..d9e82956505 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -8,7 +8,7 @@ def OneResultOp : Op<"one_result_op", []> { // CHECK-LABEL: OneResultOp definitions // CHECK: bool OneResultOp::verify() { -// CHECK: if (!((this->getInstruction()->getResult(0)->getType().isInteger(32)))) +// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) // CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 756505f1a1d..9dabf7955d7 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -34,12 +34,12 @@ def IdentityI32 : Op<"identity_i32", [PredOpTrait< // CHECK-LABEL: Identity::verify // Verify arg constraints. -// CHECK: this->getInstruction()->getOperand(0)->getType().cast().getElementType().isInteger(32) || -// CHECK-SAME: this->getInstruction()->getOperand(0)->getType().cast().getElementType().isF32() +// CHECK: this->getOperation()->getOperand(0)->getType().cast().getElementType().isInteger(32) || +// CHECK-SAME: this->getOperation()->getOperand(0)->getType().cast().getElementType().isF32() // Verify tautology constraint. 
-// CHECK: if (!((((*this->getInstruction()).getNumOperands() > std::max(0,0))) && (((*this->getInstruction()).getOperand(0)->getType().isa())) && (((*this->getInstruction()).getOperand(0)->getType().isa())) && (((*this->getInstruction()).getOperand(0)->getType().cast().getElementType() == (*this->getInstruction()).getOperand(0)->getType().cast().getElementType())))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) // CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as itself"); // CHECK-LABEL: IdentityI32::verify -// CHECK: if (!((((*this->getInstruction()).getNumOperands() > 0)) && (((*this->getInstruction()).getOperand(0)->getType().isa())) && (((*this->getInstruction()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) // CHECK-NEXT: return emitOpError("failed to verify that first operand has i32 element type"); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 0ce9719a414..198aa2599df 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -478,12 +478,12 @@ void OpEmitter::genNamedOperandGetters() { if (!operand.constraint.isVariadic()) { auto &m = opClass.newMethod("Value *", operand.name); - m.body() << " return this->getInstruction()->getOperand(" << i << ");\n"; + m.body() << " return this->getOperation()->getOperand(" << i << ");\n"; } else { assert(i + 1 == e && "only the last operand can be variadic"); const char *const code = R"( - assert(getInstruction()->getNumOperands() >= {0}); + assert(getOperation()->getNumOperands() >= {0}); return {std::next(operand_begin(), {0}), operand_end()}; )"; auto &m = opClass.newMethod("Instruction::operand_range", operand.name); @@ -499,7 +499,7 @@ void OpEmitter::genNamedResultGetters() { continue; auto &m = opClass.newMethod("Value *", result.name); - m.body() << " return this->getInstruction()->getResult(" << i << ");\n"; + m.body() << " return this->getOperation()->getResult(" << i << ");\n"; } } @@ -846,7 +846,7 @@ void OpEmitter::genVerifier() { auto description = value.constraint.getDescription(); body << " if (!(" << formatv(value.constraint.getConditionTemplate(), - "this->getInstruction()->get" + + "this->getOperation()->get" + Twine(isOperand ? 
"Operand" : "Result") + "(" + Twine(index) + ")->getType()") << "))\n"; @@ -869,7 +869,7 @@ void OpEmitter::genVerifier() { for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { body << " if (!(" - << formatv(t->getPredTemplate().c_str(), "(*this->getInstruction())") + << formatv(t->getPredTemplate().c_str(), "(*this->getOperation())") << "))\n"; body << " return emitOpError(\"failed to verify that " << t->getDescription() << "\");\n"; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 21276a9e4f9..23fb49501e1 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -184,9 +184,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { os.indent(indent) << "{\n"; - os.indent(indent + 2) << formatv( - "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", - depth + 1, depth, i); + os.indent(indent + 2) + << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", + depth + 1, depth, i); emitOpMatch(argTree, depth + 1); os.indent(indent) << "}\n"; continue; -- cgit v1.2.3 From 5a5bba0279a5754c8e7aa2a9bf415aee2a0f1774 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 27 Mar 2019 05:11:58 -0700 Subject: Introduce affine terminator Due to legacy reasons (ML/CFG function separation), regions in affine control flow operations require contained blocks not to have terminators. This is inconsistent with the notion of the block and may complicate code motion between regions of affine control operations and other regions. Introduce `affine.terminator`, a special terminator operation that must be used to terminate blocks inside affine operations and transfers the control back to he region enclosing the affine operation. For brevity and readability reasons, allow `affine.for` and `affine.if` to omit the `affine.terminator` in their regions when using custom printing and parsing format. The custom parser injects the `affine.terminator` if it is missing so as to always have it present in constructed operations. Update transformations to account for the presence of terminator. In particular, most code motion transformation between loops should leave the terminator in place, and code motion between loops and non-affine blocks should drop the terminator. 
PiperOrigin-RevId: 240536998 --- mlir/g3doc/Dialects/Affine.md | 36 +++++++++++-- mlir/include/mlir/AffineOps/AffineOps.h | 47 ++++++++++++----- mlir/include/mlir/EDSC/Builders.h | 18 +++++-- mlir/include/mlir/IR/OpImplementation.h | 3 +- mlir/include/mlir/IR/Operation.h | 4 ++ mlir/lib/AffineOps/AffineOps.cpp | 83 +++++++++++++++++++++--------- mlir/lib/Analysis/Verifier.cpp | 13 +---- mlir/lib/EDSC/Builders.cpp | 5 +- mlir/lib/EDSC/MLIREmitter.cpp | 1 - mlir/lib/IR/AsmPrinter.cpp | 24 ++++++--- mlir/lib/IR/Builders.cpp | 11 +--- mlir/lib/IR/Operation.cpp | 12 +++++ mlir/lib/Transforms/LoopFusion.cpp | 6 +-- mlir/lib/Transforms/LoopTiling.cpp | 10 ++-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 7 +-- mlir/lib/Transforms/LowerAffine.cpp | 24 +++++---- mlir/lib/Transforms/Utils/LoopUtils.cpp | 88 +++++++++++++++++--------------- mlir/test/AffineOps/ops.mlir | 41 ++++++++++++++- mlir/test/IR/invalid.mlir | 3 +- mlir/test/IR/parser.mlir | 2 + 20 files changed, 294 insertions(+), 144 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/Dialects/Affine.md b/mlir/g3doc/Dialects/Affine.md index 0c69c60cbe9..02264d85b7e 100644 --- a/mlir/g3doc/Dialects/Affine.md +++ b/mlir/g3doc/Dialects/Affine.md @@ -60,9 +60,12 @@ upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal ``` -The `affine.for` operation represents an affine loop nest, defining an SSA value -for its induction variable. This SSA value always has type -[`index`](LangRef.md#index-type), which is the size of the machine word. +The `affine.for` operation represents an affine loop nest. It has one region +containing its body. This region must contain one block that terminates with +[`affine.terminator`](#'affine.terminator"-operation). *Note:* when `affine.for` +is printed in custom format, the terminator is omitted. The block has one +argument of [`index`](LangRef.md#index-type) type that represents the induction +variable of the loop. The `affine.for` operation executes its body a number of times iterating from a lower bound to an upper bound by a stride. The stride, represented by `step`, is @@ -124,6 +127,13 @@ and the SSA values bound to the dimensions and symbols in the integer set. The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for these SSA values as for all bindings of SSA values to dimensions and symbols. +The `if` operation contains two regions for the "then" and "else" clauses. The +latter may be empty (i.e. contain no blocks), meaning the absence of the else +clause. When non-empty, both regions must contain exactly one block terminating +with [`affine.terminator`](#'affine.terminator'-operation). *Note:* when `if` is +printed in custom format, the terminator is omitted. These blocks must not have +any arguments. + Example: ```mlir {.mlir} @@ -143,3 +153,23 @@ func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { return } ``` + +#### `affine.terminator` operation {#'affine.terminator'-operation} + +Syntax: + +``` {.ebnf} +operation ::= `"affine.terminator"() : () -> ()` +``` + +Affine terminator is a special terminator operation for blocks inside affine +loops ([`for`](#'for'-operation)) and branches ([`if`](#'if'-operation)). It +unconditionally transmits the control flow to the successor of the operation +enclosing the region. + +*Rationale*: bodies of affine operations are [blocks](LangRef.md#block) that +must have terminators. 
Loops and branches represent structured control flow and +should not accept arbitrary branches as terminators. + +This operation does _not_ have a custom syntax. However, affine control +operations omit the terminator in their custom syntax for brevity. diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 65467677bd6..2108f16bb31 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -88,15 +88,18 @@ public: MLIRContext *context); }; -/// The "affine.for" instruction represents an affine loop nest, defining an SSA -/// value for its induction variable. The induction variable is represented as a -/// BlockArgument to the entry block of the body. The body and induction -/// variable can be created automatically for new "affine.for" ops with -/// 'createBody'. This SSA value always has type index, which is the size of the -/// machine word. The stride, represented by step, is a positive constant -/// integer which defaults to "1" if not present. The lower and upper bounds -/// specify a half-open range: the range includes the lower bound but does not -/// include the upper bound. +/// The "affine.for" operation represents an affine loop nest, defining an SSA +/// value for its induction variable. It has one region capturing the loop body. +/// The induction variable is represented as a argument of this region. This SSA +/// value always has type index, which is the size of the machine word. The +/// stride, represented by step, is a positive constant integer which defaults +/// to "1" if not present. The lower and upper bounds specify a half-open range: +/// the range includes the lower bound but does not include the upper bound. +/// +/// The body region must contain exactly one block that terminates with +/// "affine.terminator". Calling AffineForOp::build will create such region +/// and insert the terminator, so will the parsing even in cases if it is absent +/// from the custom format. /// /// The lower and upper bounds of a for operation are represented as an /// application of an affine mapping to a list of SSA values passed to the map. @@ -136,9 +139,9 @@ public: static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getUpperBoundAttrName() { return "upper_bound"; } - /// Generate a body block for this AffineForOp. The operation must not already - /// have a body. The operation must contain a parent function. - Block *createBody(); + /// Return a Builder set up to insert operations immediately before the + /// terminator. + FuncBuilder getBodyBuilder(); /// Get the body of the AffineForOp. Block *getBody() { return &getRegion().front(); } @@ -322,6 +325,26 @@ public: void print(OpAsmPrinter *p); }; +/// Affine terminator is a special terminator operation for blocks inside affine +/// loops and branches. It unconditionally transmits the control flow to the +/// successor of the operation enclosing the region. +/// +/// This operation does _not_ have a custom syntax. However, affine control +/// operations omit the terminator in their custom syntax for brevity. +class AffineTerminatorOp + : public Op { +public: + using Op::Op; + + static void build(Builder *, OperationState *) {} + + static StringRef getOperationName() { return "affine.terminator"; } + +private: + friend Instruction; +}; + /// Returns true if the given Value can be used as a dimension id. 
bool isValidDim(Value *value); diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 4b0e5b938b5..e01f43e562e 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -112,9 +112,12 @@ protected: /// scoping itself, we use enter/exit pairs of instructions. /// As a consequence we must allocate a new FuncBuilder + ScopedContext and /// let the escape. - void enter(mlir::Block *block) { - bodyScope = new ScopedContext(FuncBuilder(block, block->end()), - ScopedContext::getLocation()); + /// Step back "prev" times from the end of the block to set up the insertion + /// point, which is useful for non-empty blocks. + void enter(mlir::Block *block, int prev = 0) { + bodyScope = + new ScopedContext(FuncBuilder(block, std::prev(block->end(), prev)), + ScopedContext::getLocation()); bodyScope->nestedBuilder = this; } @@ -326,6 +329,12 @@ public: bool hasType() const { return t != Type(); } Type getType() const { return t; } + Instruction *getOperation() const { + if (!v) + return nullptr; + return v->getDefiningOp(); + } + protected: ValueHandle() : t(), v(nullptr) {} @@ -359,7 +368,7 @@ struct InstructionHandle : public CapturableHandle { ArrayRef attributes = {}); operator Instruction *() { return inst; } - Instruction *getOperation() { return inst; } + Instruction *getOperation() const { return inst; } private: Instruction *inst; @@ -433,7 +442,6 @@ ValueHandle ValueHandle::create(Args... args) { return ValueHandle(inst->getResult(0)); } else if (inst->getNumResults() == 0) { if (auto f = inst->dyn_cast()) { - f.createBody(); return ValueHandle(f.getInductionVar()); } } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index eeb35b2d51a..3fa1cccd8c8 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,7 +89,8 @@ public: virtual void printGenericOp(Operation *op) = 0; /// Prints a region. - virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true) = 0; + virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, + bool printBlockTerminators = true) = 0; private: OpAsmPrinter(const OpAsmPrinter &) = delete; diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 0fae3fc495f..ef7a6e56368 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -34,6 +34,7 @@ class BlockAndValueMapping; class Location; class MLIRContext; class OperandIterator; +class OperationState; class ResultIterator; class ResultTypeIterator; @@ -66,6 +67,9 @@ public: ArrayRef successors, unsigned numRegions, bool resizableOperandList, MLIRContext *context); + /// Create a new Operation from the fields stored in `state`. + static Operation *create(const OperationState &state); + /// The name of an operation is the key identifier for it. 
OperationName getName() { return name; } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index c1f9606eba3..d23f2841e15 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -37,7 +37,7 @@ using llvm::dbgs; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"affine", context) { - addOperations(); + addOperations(); } /// A utility function to check if a value is defined at the top level of a @@ -690,6 +690,37 @@ void AffineApplyOp::getCanonicalizationPatterns( // AffineForOp //===----------------------------------------------------------------------===// +// Check that if a "block" has a terminator, it is an `AffineTerminatorOp`. +// Return true on success, report errors and return true on failure. +static bool checkHasAffineTerminator(OpState &op, Block &block) { + if (block.empty() || block.back().isa()) + return false; + + op.emitOpError("expects regions to end with '" + + AffineTerminatorOp::getOperationName() + "'"); + op.emitNote("in custom textual format, the absence of terminator implies '" + + AffineTerminatorOp::getOperationName() + "'"); + return true; +} + +// Insert `affine.terminator` at the end of the region's only block if it does +// not have a terminator already. If the region is empty, insert a new block +// first. +static void ensureAffineTerminator(Region ®ion, Builder &builder, + Location loc) { + if (region.empty()) + region.push_back(new Block); + + Block &block = region.back(); + if (!block.empty() && block.back().isKnownTerminator()) + return; + + OperationState terminatorState(builder.getContext(), loc, + AffineTerminatorOp::getOperationName()); + AffineTerminatorOp::build(&builder, &terminatorState); + block.push_back(Operation::create(terminatorState)); +} + void AffineForOp::build(Builder *builder, OperationState *result, ArrayRef lbOperands, AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, @@ -716,8 +747,13 @@ void AffineForOp::build(Builder *builder, OperationState *result, builder->getAffineMapAttr(ubMap)); result->addOperands(ubOperands); - // Create a region for the body. - result->addRegion(); + // Create a region and a block for the body. The argument of the region is + // the loop induction variable. + Region *bodyRegion = result->addRegion(); + Block *body = new Block(); + body->addArgument(IndexType::get(builder->getContext())); + bodyRegion->push_back(body); + ensureAffineTerminator(*bodyRegion, *builder, result->location); // Set the operands list as resizable so that we can freely modify the bounds. result->setOperandListToResizable(); @@ -745,9 +781,8 @@ bool AffineForOp::verify() { return emitOpError("expected body to have a single index argument for the " "induction variable"); - // Check that the body has no terminator. - if (!body->empty() && body->back().isKnownTerminator()) - return emitOpError("expects body block to not have a terminator"); + if (checkHasAffineTerminator(*this, *body)) + return true; // Verify that there are enough operands for the bounds. AffineMap lowerBoundMap = getLowerBoundMap(), @@ -897,6 +932,8 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->parseRegion(*body)) return true; + ensureAffineTerminator(*body, builder, result->location); + // Parse the optional attribute list. 
if (parser->parseOptionalAttributeDict(result->attributes)) return true; @@ -955,8 +992,9 @@ void AffineForOp::print(OpAsmPrinter *p) { if (getStep() != 1) *p << " step " << getStep(); - p->printRegion(getOperation()->getRegion(0), - /*printEntryBlockArgs=*/false); + p->printRegion(getRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getLowerBoundAttrName(), getUpperBoundAttrName(), @@ -1030,16 +1068,9 @@ void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.push_back(llvm::make_unique(context)); } -Block *AffineForOp::createBody() { - auto &bodyRegion = getRegion(); - assert(bodyRegion.empty() && "expected no existing body blocks"); - - // Create a new block for the body, and add an argument for the induction - // variable. - Block *body = new Block(); - body->addArgument(IndexType::get(getContext())); - bodyRegion.push_back(body); - return body; +FuncBuilder AffineForOp::getBodyBuilder() { + Block *body = getBody(); + return FuncBuilder(body, std::prev(body->end())); } AffineBound AffineForOp::getLowerBound() { @@ -1215,8 +1246,7 @@ bool AffineIfOp::verify() { // regions. if (std::next(region.begin()) != region.end()) return emitOpError("expects only one block per 'then' or 'else' regions"); - if (region.front().back().isKnownTerminator()) - return emitOpError("expects region block to not have a terminator"); + checkHasAffineTerminator(*this, region.front()); for (auto &b : region) if (b.getNumArguments() != 0) @@ -1255,11 +1285,14 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { // Parse the 'then' region. if (parser->parseRegion(*thenRegion)) return true; + ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location); // If we find an 'else' keyword then parse the 'else' region. - if (!parser->parseOptionalKeyword("else")) + if (!parser->parseOptionalKeyword("else")) { if (parser->parseRegion(*elseRegion)) return true; + ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location); + } // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) @@ -1273,13 +1306,17 @@ void AffineIfOp::print(OpAsmPrinter *p) { *p << "affine.if " << conditionAttr; printDimAndSymbolList(operand_begin(), operand_end(), conditionAttr.getValue().getNumDims(), p); - p->printRegion(getOperation()->getRegion(0)); + p->printRegion(getOperation()->getRegion(0), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); // Print the 'else' regions if it has any blocks. auto &elseRegion = getOperation()->getRegion(1); if (!elseRegion.empty()) { *p << " else"; - p->printRegion(elseRegion); + p->printRegion(elseRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); } // Print the attribute list. diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index f211417b798..d7eb578ab1a 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -220,15 +220,6 @@ bool FuncVerifier::verify() { return false; } -// Returns if the given block is allowed to have no terminator. -static bool canBlockHaveNoTerminator(Block &block) { - // Allow the first block of an operation region to have no terminator if it is - // the only block in the region. 
- auto *parentList = block.getParent(); - return parentList->getContainingOp() && - std::next(parentList->begin()) == parentList->end(); -} - bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { for (auto *arg : block.getArguments()) { if (arg->getOwner() != &block) @@ -237,8 +228,6 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { // Verify that this block has a terminator. if (block.empty()) { - if (canBlockHaveNoTerminator(block)) - return false; return failure("block with no terminator", block); } @@ -257,7 +246,7 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { // Verify the terminator. if (verifyOperation(block.back())) return true; - if (block.back().isKnownNonTerminator() && !canBlockHaveNoTerminator(block)) + if (block.back().isKnownNonTerminator()) return failure("block with no terminator", block); // Verify that this block is not branching to a block of a different diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index e991817b6d7..a6bd1d977d8 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -101,9 +101,6 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, return ValueHandle(inst->getResult(0)); } if (auto f = inst->dyn_cast()) { - // Immediately create the loop body so we can just insert instructions right - // away. - f.createBody(); return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported instruction, use an InstructionHandle instead"); @@ -174,7 +171,7 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, step); } auto *body = getForInductionVarOwner(iv->getValue()).getBody(); - enter(body); + enter(body, /*prev=*/1); } ValueHandle diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 6c6262c2790..069fb6cb40b 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -175,7 +175,6 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { forOp = builder->create( location, lbs, builder->getMultiDimIdentityMap(lbs.size()), ubs, builder->getMultiDimIdentityMap(ubs.size()), step); - forOp.createBody(); res = forOp.getInductionVar(); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index de6654cf532..ae550e7a3d2 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1060,7 +1060,8 @@ public: // Methods to print instructions. void print(Operation *inst); - void print(Block *block, bool printBlockArgs = true); + void print(Block *block, bool printBlockArgs = true, + bool printBlockTerminator = true); void printOperation(Operation *op); void printGenericOp(Operation *op); @@ -1109,12 +1110,14 @@ public: void printSuccessorAndUseList(Operation *term, unsigned index) override; /// Print a region. 
- void printRegion(Region &blocks, bool printEntryBlockArgs) override { + void printRegion(Region &blocks, bool printEntryBlockArgs, + bool printBlockTerminators) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); print(entryBlock, - printEntryBlockArgs && entryBlock->getNumArguments() != 0); + printEntryBlockArgs && entryBlock->getNumArguments() != 0, + printBlockTerminators); for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) print(&b); } @@ -1284,7 +1287,8 @@ void FunctionPrinter::print() { printTrailingLocation(function->getLoc()); if (!function->empty()) { - printRegion(function->getBody(), /*printEntryBlockArgs=*/false); + printRegion(function->getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); os << "\n"; } os << '\n'; @@ -1335,7 +1339,8 @@ void FunctionPrinter::printFunctionSignature() { } } -void FunctionPrinter::print(Block *block, bool printBlockArgs) { +void FunctionPrinter::print(Block *block, bool printBlockArgs, + bool printBlockTerminator) { // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); @@ -1379,8 +1384,10 @@ void FunctionPrinter::print(Block *block, bool printBlockArgs) { } currentIndent += indentWidth; - - for (auto &inst : block->getOperations()) { + auto range = llvm::make_range( + block->getOperations().begin(), + std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); + for (auto &inst : range) { print(&inst); os << '\n'; } @@ -1501,7 +1508,8 @@ void FunctionPrinter::printGenericOp(Operation *op) { // Print any trailing regions. for (auto ®ion : op->getRegions()) - printRegion(region, /*printEntryBlockArgs=*/true); + printRegion(region, /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true); } void FunctionPrinter::printSuccessorAndUseList(Operation *term, diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index f4d532a482f..a0d9367fa5f 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -320,16 +320,7 @@ Block *FuncBuilder::createBlock(Block *insertBefore) { /// Create an operation given the fields represented as an OperationState. Operation *FuncBuilder::createOperation(const OperationState &state) { assert(block && "createOperation() called without setting builder's block"); - - unsigned numRegions = state.regions.size(); - auto *op = Operation::create(state.location, state.name, state.operands, - state.types, state.attributes, state.successors, - numRegions, state.resizableOperandList, context); - - for (unsigned i = 0; i < numRegions; ++i) - if (state.regions[i]) - op->getRegion(i).takeBody(*state.regions[i]); - + auto *op = Operation::create(state); block->getOperations().insert(insertPoint, op); return op; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index f1c89813ead..3de620b524c 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -98,6 +98,18 @@ Operation *Operation::create(Location location, OperationName name, resizableOperandList, context); } +/// Create a new Operation from operation state. 
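As a reading aid, a minimal sketch of the pattern this overload enables: populate an OperationState and materialize it directly, without going through a FuncBuilder. The helper name and the append-to-a-block use case are illustrative assumptions, not code from the patch.

    static void appendOp(Block *block, MLIRContext *ctx, Location loc,
                         StringRef opName) {
      // Populate the state; operands, result types, attributes and regions
      // would be added here the same way the affine terminator is built
      // earlier in this patch.
      OperationState state(ctx, loc, opName);
      // Materialize the operation from the state and append it to the block.
      block->push_back(Operation::create(state));
    }

The overload defined next forwards to the fully spelled-out create() and then moves any regions from the state into the new operation.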
+Operation *Operation::create(const OperationState &state) { + unsigned numRegions = state.regions.size(); + Operation *inst = create( + state.location, state.name, state.operands, state.types, state.attributes, + state.successors, numRegions, state.resizableOperandList, state.context); + for (unsigned i = 0; i < numRegions; ++i) + if (state.regions[i]) + inst->getRegion(i).takeBody(*state.regions[i]); + return inst; +} + /// Overload of create that takes an existing NamedAttributeList to avoid /// unnecessarily uniquing a list of attributes. Operation *Operation::create(Location location, OperationName name, diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d76aca20b6d..8c29d1a76b4 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1055,14 +1055,14 @@ computeLoopInterchangePermutation(ArrayRef ops, // pushing loop carried dependence to a greater depth in the loop nest. static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(node->inst->isa()); - // Get perfectly nested sequence of loops starting at root of loop nest. + // Get perfectly nested sequence of loops starting at root of loop nest + // (the first op being another AffineFor, and the second op - a terminator). // TODO(andydavis,bondhugula) Share this with similar code in loop tiling. SmallVector loops; AffineForOp curr = node->inst->cast(); loops.push_back(curr); auto *currBody = curr.getBody(); - while (!currBody->empty() && - std::next(currBody->begin()) == currBody->end() && + while (currBody->begin() == std::prev(currBody->end(), 2) && (curr = curr.getBody()->front().dyn_cast())) { loops.push_back(curr); currBody = curr.getBody(); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index d9f74808ad8..c235190b4b7 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -89,10 +89,12 @@ FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { } // Move the loop body of AffineForOp 'src' from 'src' into the specified -// location in destination's body. +// location in destination's body, ignoring the terminator. static inline void moveLoopBody(AffineForOp src, AffineForOp dest, Block::iterator loc) { - dest.getBody()->getOperations().splice(loc, src.getBody()->getOperations()); + auto &insts = src.getBody()->getOperations(); + dest.getBody()->getOperations().splice(loc, insts, insts.begin(), + std::prev(insts.end())); } // Move the loop body of AffineForOp 'src' from 'src' to the start of dest's @@ -202,7 +204,6 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto pointLoop = b.create(loc, 0, 0); - pointLoop.createBody(); pointLoop.getBody()->getOperations().splice( pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -217,7 +218,6 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, FuncBuilder b(topLoop); // Loop bounds will be set later. 
auto tileSpaceLoop = b.create(loc, 0, 0); - tileSpaceLoop.createBody(); tileSpaceLoop.getBody()->getOperations().splice( tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -264,7 +264,7 @@ static void getTileableBands(Function &f, AffineForOp currInst = root; do { band.push_back(currInst); - } while (currInst.getBody()->getOperations().size() == 1 && + } while (currInst.getBody()->getOperations().size() == 2 && (currInst = currInst.getBody()->front().dyn_cast())); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 3fa4eab93da..3ea20c0c282 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though - // in its tree). + // in its tree). Ignore the block terminators. struct JamBlockGatherer { // Store iterators to the first and last inst of each sub-block found. std::vector> subBlocks; @@ -137,7 +137,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, walk(block); } void walk(Block &block) { - for (auto it = block.begin(), e = block.end(); it != e;) { + for (auto it = block.begin(), e = std::prev(block.end()); it != e;) { auto subBlockStart = it; while (it != e && !it->isa()) ++it; @@ -155,7 +155,8 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, if (unrollJamFactor == 1) return promoteIfSingleIteration(forOp); - if (forOp.getBody()->empty()) + if (forOp.getBody()->empty() || + forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) return failure(); // Loops where both lower and upper bounds are multi-result maps won't be diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 5046bf2596b..acc9481e89c 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -335,14 +335,15 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { conditionBlock->insertBefore(endBlock); auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext())); - // Create the body block, moving the body of the forOp over to it. + // Create the body block, moving the body of the forOp over to it and dropping + // the affine terminator. auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); auto *oldBody = forOp.getBody(); bodyBlock->getOperations().splice(bodyBlock->begin(), oldBody->getOperations(), oldBody->begin(), - oldBody->end()); + std::prev(oldBody->end())); // The code in the body of the forOp now uses 'iv' as its indvar. forOp.getInductionVar()->replaceAllUsesWith(iv); @@ -406,7 +407,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { // enabling easy nesting of "if" instructions and if-then-else-if chains. 
// // +--------------------------------+ -// | | +// | | // | %zero = constant 0 : index | // | %v = affine.apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -450,7 +451,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { // v v // +--------------------------------+ // | continue: | -// | | +// | | // +--------------------------------+ // bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { @@ -469,7 +470,8 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - // If the 'then' block is not empty, then splice the instructions. + // If the 'then' block is not empty, then splice the instructions except for + // the terminator. auto &oldThenBlocks = ifOp.getThenBlocks(); if (!oldThenBlocks.empty()) { // We currently only handle one 'then' block. @@ -478,9 +480,9 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { Block *oldThen = &oldThenBlocks.front(); - thenBlock->getOperations().splice(thenBlock->begin(), - oldThen->getOperations(), - oldThen->begin(), oldThen->end()); + thenBlock->getOperations().splice( + thenBlock->begin(), oldThen->getOperations(), oldThen->begin(), + std::prev(oldThen->end())); } FuncBuilder builder(thenBlock); @@ -499,9 +501,9 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { elseBlock = new Block(); elseBlock->insertBefore(continueBlock); - elseBlock->getOperations().splice(elseBlock->begin(), - oldElse->getOperations(), - oldElse->begin(), oldElse->end()); + elseBlock->getOperations().splice( + elseBlock->begin(), oldElse->getOperations(), oldElse->begin(), + std::prev(oldElse->end())); builder.setInsertionPointToEnd(elseBlock); builder.create(loc, continueBlock); } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 9a7db193d29..2760e8b8bd3 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -145,8 +145,10 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { } } } - // Move the loop body instructions to the loop's containing block. + // Move the loop body instructions, except for terminator, to the loop's + // containing block. auto *block = forInst->getBlock(); + forOp.getBody()->getOperations().back().erase(); block->getOperations().splice(Block::iterator(forInst), forOp.getBody()->getOperations()); forOp.erase(); @@ -182,12 +184,12 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, auto loopChunk = b->create(srcForInst.getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForInst.getStep()); - loopChunk.createBody(); auto *loopChunkIV = loopChunk.getInductionVar(); auto *srcIV = srcForInst.getInductionVar(); BlockAndValueMapping operandMap; + FuncBuilder bodyBuilder = loopChunk.getBodyBuilder(); for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end(); it != e; ++it) { uint64_t shift = it->first; @@ -197,10 +199,9 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. 
if (!srcIV->use_empty() && shift != 0) { - FuncBuilder b(loopChunk.getBody()); - auto ivRemap = b.create( + auto ivRemap = bodyBuilder.create( srcForInst.getLoc(), - b.getSingleDimShiftAffineMap( + bodyBuilder.getSingleDimShiftAffineMap( -static_cast(srcForInst.getStep() * shift)), loopChunkIV); operandMap.map(srcIV, ivRemap); @@ -208,9 +209,10 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, operandMap.map(srcIV, loopChunkIV); } for (auto *inst : insts) { - loopChunk.getBody()->push_back(inst->clone(operandMap, b->getContext())); + if (!inst->isa()) + bodyBuilder.clone(*inst, operandMap); } - } + }; if (succeeded(promoteIfSingleIteration(loopChunk))) return AffineForOp(); return loopChunk; @@ -233,7 +235,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // method. LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue) { - if (forOp.getBody()->empty()) + if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) return success(); // If the trip counts aren't constant, we would need versioning and @@ -385,7 +387,8 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, if (unrollFactor == 1) return promoteIfSingleIteration(forOp); - if (forOp.getBody()->empty()) + if (forOp.getBody()->empty() || + forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) return failure(); // Loops where the lower bound is a max expression isn't supported for @@ -428,13 +431,13 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, int64_t step = forOp.getStep(); forOp.setStep(step * unrollFactor); - // Builder to insert unrolled bodies right after the last instruction in the - // body of 'forOp'. - FuncBuilder builder(forOp.getBody(), forOp.getBody()->end()); + // Builder to insert unrolled bodies just before the terminator of the body of + // 'forOp'. + FuncBuilder builder = forOp.getBodyBuilder(); - // Keep a pointer to the last instruction in the original block so that we - // know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end()); + // Keep a pointer to the last non-terminator instruction in the original block + // so that we know what to clone (since we are doing this in-place). + Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2); // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). auto *forOpIV = forOp.getInductionVar(); @@ -465,23 +468,27 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, } /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is -/// nested within 'forOpA' as the only instruction in its block. +/// nested within 'forOpA' as the only non-terminator operation in its block. void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { auto *forOpAInst = forOpA.getOperation(); - // 1) Slice forOpA's instruction list (which is just forOpB) just before - // forOpA (in forOpA's parent's block) this should leave 'forOpA's - // instruction list empty (because its perfectly nested). + assert(&*forOpA.getBody()->begin() == forOpB.getOperation()); - forOpAInst->getBlock()->getOperations().splice( - Block::iterator(forOpAInst), forOpA.getBody()->getOperations()); - // 2) Slice forOpB's instruction list into forOpA's instruction list (this - // leaves forOpB's instruction list empty). - forOpA.getBody()->getOperations().splice(forOpA.getBody()->begin(), - forOpB.getBody()->getOperations()); - // 3) Slice forOpA into forOpB's instruction list. 
- forOpB.getBody()->getOperations().splice( - forOpB.getBody()->begin(), forOpAInst->getBlock()->getOperations(), - Block::iterator(forOpAInst)); + auto &forOpABody = forOpA.getBody()->getOperations(); + auto &forOpBBody = forOpB.getBody()->getOperations(); + + // 1) Splice forOpA's non-terminator operations (which is just forOpB) just + // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's + // body containing only the terminator. + forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst), + forOpABody, forOpABody.begin(), + std::prev(forOpABody.end())); + // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's + // body (this leaves forOpB's body containing only the terminator). + forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(), + std::prev(forOpBBody.end())); + // 3) Splice forOpA into the beginning of forOpB's body. + forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(), + Block::iterator(forOpAInst)); } /// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels @@ -516,25 +523,27 @@ static void augmentMapAndBounds(FuncBuilder *b, Value *iv, AffineMap *map, // Clone the original body of `forOp` into the body of `newForOp` while // substituting `oldIv` in place of -// `forOp.getInductionVariable()`. +// `forOp.getInductionVariable()` and ignoring the terminator. // Note: `newForOp` may be nested under `forOp`. static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, AffineForOp newForOp) { BlockAndValueMapping map; map.map(oldIv, newForOp.getInductionVar()); - FuncBuilder b(newForOp.getBody(), newForOp.getBody()->end()); - for (auto it = forOp.getBody()->begin(), end = forOp.getBody()->end(); - it != end; ++it) { + FuncBuilder b = newForOp.getBodyBuilder(); + for (auto &inst : *forOp.getBody()) { // Step over newForOp in case it is nested under forOp. - if (&*it == newForOp.getOperation()) { + if (&inst == newForOp.getOperation()) { + continue; + } + if (inst.isa()) { continue; } - auto *inst = b.clone(*it, map); + auto *instClone = b.clone(inst, map); unsigned idx = 0; - for (auto r : it->getResults()) { + for (auto r : inst.getResults()) { // Since we do a forward pass over the body, we iteratively augment // the `map` with everything we clone. - map.map(r, inst->getResult(idx++)); + map.map(r, instClone->getResult(idx++)); } } } @@ -574,11 +583,10 @@ stripmineSink(AffineForOp forOp, uint64_t factor, SmallVector innerLoops; for (auto t : targets) { - // Insert newForOp at the end of `t`. - FuncBuilder b(t.getBody(), t.getBody()->end()); + // Insert newForOp before the terminator of `t`. + FuncBuilder b = t.getBodyBuilder(); auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, ubOperands, ubMap, originalStep); - newForOp.createBody(); cloneLoopBodyInto(t, forOp.getInductionVar(), newForOp); // Remove all instructions from `t` except `newForOp`. auto rit = ++newForOp.getOperation()->getReverseIterator(); diff --git a/mlir/test/AffineOps/ops.mlir b/mlir/test/AffineOps/ops.mlir index 6e60c624c72..92581ff10f7 100644 --- a/mlir/test/AffineOps/ops.mlir +++ b/mlir/test/AffineOps/ops.mlir @@ -1,19 +1,39 @@ // RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s // Check that the attributes for the affine operations are round-tripped. -func @attributes() { +// Check that `affine.terminator` is visible in the generic form. 
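The recurring change in the transforms above is that loop bodies are never empty anymore: they always end with an implicit `affine.terminator`, so code that used to call createBody() and then append at the end of the block now obtains a builder positioned just before the terminator. A minimal sketch of the new idiom, with illustrative names (not code from the patch):

    // Append a nested loop to an existing affine.for body without disturbing
    // its terminator: getBodyBuilder() puts the insertion point just before it.
    static void appendInnerLoop(AffineForOp outer, Location loc) {
      FuncBuilder b = outer.getBodyBuilder();
      // With the new build(), this also creates the inner loop's own body
      // block and its implicit affine.terminator.
      b.create<AffineForOp>(loc, /*lowerBound=*/0, /*upperBound=*/10);
    }

The ops.mlir tests that follow check that this terminator shows up in the generic form only and stays hidden in the custom syntax.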
+// CHECK-LABEL: @empty +func @empty() { // CHECK: affine.for %i // CHECK-NEXT: } {some_attr: true} + // + // GENERIC: "affine.for"() + // GENERIC-NEXT: ^bb1(%i0: index): + // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: } affine.for %i = 0 to 10 { } {some_attr: true} - // CHECK: if + // CHECK: affine.if // CHECK-NEXT: } {some_attr: true} + // + // GENERIC: "affine.if"() + // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: } { + // GENERIC-NEXT: } affine.if () : () () { } {some_attr: true} // CHECK: } else { // CHECK: } {some_attr: true} + // + // GENERIC: "affine.if"() + // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: } { + // GENERIC-NEXT: "foo"() : () -> () + // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: } affine.if () : () () { } else { "foo"() : () -> () @@ -21,3 +41,20 @@ func @attributes() { return } + +// Check that an explicit affine terminator is not printed in custom format. +// Check that no extra terminator is introduced. +// CHEKC-LABEL: @affine_terminator +func @affine_terminator() { + // CHECK: affine.for %i + // CHECK-NEXT: } + // + // GENERIC: "affine.for"() {lower_bound: #map0, step: 1 : index, upper_bound: #map1} : () -> () { + // GENERIC-NEXT: ^bb1(%i0: index): // no predecessors + // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: } + affine.for %i = 0 to 10 { + "affine.terminator"() : () -> () + } + return +} diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 224d8813648..9a96ccdc413 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -478,7 +478,8 @@ func @return_type_mismatch() -> i32 { func @return_inside_loop() { affine.for %i = 1 to 100 { - // expected-error@-1 {{op expects body block to not have a terminator}} + // expected-error@-1 {{op expects regions to end with 'affine.terminator'}} + // expected-note@-2 {{in custom textual format, the absence of terminator implies}} return } return diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index c66c6c0614b..bdf13f7cfe1 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -791,10 +791,12 @@ func @verbose_if(%N: index) { "affine.if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index + "affine.terminator"() : () -> () // CHECK-NEXT: } else { } { // The else region. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index + "affine.terminator"() : () -> () } return } -- cgit v1.2.3 From 9c085406904780d25673ad213ac53a4c6e1558c0 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 27 Mar 2019 08:55:17 -0700 Subject: Replace usages of Instruction with Operation in the /Analysis directory. 
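The change is mechanical: the `using Instruction = Operation;` aliases are dropped from the analysis headers, and signatures, comments and walkers are spelled in terms of Operation. A sketch of what a call site looks like afterwards (illustrative helper, not taken from the patch):

    // Walk a function counting operations; the callback parameter is now
    // written as Operation* rather than the old Instruction* alias.
    static unsigned countOperations(Function *function) {
      unsigned count = 0;
      function->walk([&](Operation *op) { ++count; });
      return count;
    }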
PiperOrigin-RevId: 240569775 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 21 ++-- mlir/include/mlir/Analysis/AffineStructures.h | 6 +- mlir/include/mlir/Analysis/Dominance.h | 26 ++--- mlir/include/mlir/Analysis/LoopAnalysis.h | 3 +- mlir/include/mlir/Analysis/NestedMatcher.h | 47 +++++---- mlir/include/mlir/Analysis/SliceAnalysis.h | 65 ++++++------ mlir/include/mlir/Analysis/Utils.h | 25 +++-- mlir/include/mlir/Analysis/VectorAnalysis.h | 7 +- mlir/include/mlir/IR/Operation.h | 6 +- mlir/lib/AffineOps/AffineOps.cpp | 22 ++--- mlir/lib/Analysis/AffineAnalysis.cpp | 67 ++++++------- mlir/lib/Analysis/Dominance.cpp | 20 ++-- mlir/lib/Analysis/LoopAnalysis.cpp | 44 ++++----- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/MemRefDependenceCheck.cpp | 10 +- mlir/lib/Analysis/NestedMatcher.cpp | 60 ++++++----- mlir/lib/Analysis/OpStats.cpp | 3 +- mlir/lib/Analysis/SliceAnalysis.cpp | 93 +++++++++-------- mlir/lib/Analysis/Utils.cpp | 110 ++++++++++----------- mlir/lib/Analysis/VectorAnalysis.cpp | 53 +++++----- mlir/lib/Analysis/Verifier.cpp | 54 +++++----- mlir/lib/IR/Block.cpp | 2 +- mlir/lib/IR/Operation.cpp | 6 +- mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/LoopTiling.cpp | 3 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- .../Vectorization/VectorizerTestPass.cpp | 22 ++--- mlir/lib/Transforms/Vectorize.cpp | 10 +- 28 files changed, 390 insertions(+), 405 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 2944a6bca97..4873475e58b 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -36,15 +36,14 @@ class AffineForOp; class AffineValueMap; class FlatAffineConstraints; class Operation; -using Instruction = Operation; class Value; /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp -/// Instructions that are reachable via a search starting from `operands` and +/// Operations that are reachable via a search starting from `operands` and /// ending at those operands that are not the result of an AffineApplyOp. void getReachableAffineApplyOps( llvm::ArrayRef operands, - llvm::SmallVectorImpl &affineApplyOps); + llvm::SmallVectorImpl &affineApplyOps); /// Builds a system of constraints with dimensional identifiers corresponding to /// the loop IVs of the forOps appearing in that order. Bounds of the loop are @@ -58,13 +57,13 @@ LogicalResult getIndexSet(llvm::MutableArrayRef forOps, /// Encapsulates a memref load or store access information. struct MemRefAccess { Value *memref; - Instruction *opInst; + Operation *opInst; llvm::SmallVector indices; - /// Constructs a MemRefAccess from a load or store operation instruction. + /// Constructs a MemRefAccess from a load or store operation. // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess. - explicit MemRefAccess(Instruction *opInst); + explicit MemRefAccess(Operation *opInst); // Returns the rank of the memref associated with this access. unsigned getRank() const; @@ -92,11 +91,11 @@ struct DependenceComponent { /// Checks whether two accesses to the same memref access the same element. /// Each access is specified using the MemRefAccess structure, which contains -/// the operation instruction, indices and memref associated with the access. 
-/// Returns 'success' if it can be determined conclusively that the accesses do -/// not access the same memref element. -/// If 'allowRAR' is true, will consider read-after-read dependences (typically -/// used by applications trying to optimize input reuse). +/// the operation, indices and memref associated with the access. Returns +/// 'false' if it can be determined conclusively that the accesses do not +/// access the same memref element. If 'allowRAR' is true, will consider +/// read-after-read dependences (typically used by applications trying to +/// optimize input reuse). // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into // a single struct. // TODO(andydavis) Make 'dependenceConstraints' optional arg. diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index f9ea873d0f7..7dcab234143 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -377,12 +377,12 @@ public: AffineExpr toAffineExpr(unsigned idx, MLIRContext *context); /// Adds constraints (lower and upper bounds) for the specified 'affine.for' - /// instruction's Value using IR information stored in its bound maps. The + /// operation's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Asserts if the - /// Value corresponding to the 'affine.for' instruction isn't found in the + /// Value corresponding to the 'affine.for' operation isn't found in the /// constraint system. Returns failure for the yet unimplemented/unsupported /// cases. Any new identifiers that are found in the bound operands of the - /// 'affine.for' instruction are added as trailing identifiers (either + /// 'affine.for' operation are added as trailing identifiers (either /// dimensional or symbolic depending on whether the operand is a valid ML /// Function symbol). // TODO(bondhugula): add support for non-unit strides. diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 1c3ca02e41c..f22def7699d 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -65,20 +65,20 @@ class DominanceInfo : public detail::DominanceInfoBase { public: using super::super; - /// Return true if instruction A properly dominates instruction B. - bool properlyDominates(Instruction *a, Instruction *b); + /// Return true if operation A properly dominates operation B. + bool properlyDominates(Operation *a, Operation *b); - /// Return true if instruction A dominates instruction B. - bool dominates(Instruction *a, Instruction *b) { + /// Return true if operation A dominates operation B. + bool dominates(Operation *a, Operation *b) { return a == b || properlyDominates(a, b); } - /// Return true if value A properly dominates instruction B. - bool properlyDominates(Value *a, Instruction *b); + /// Return true if value A properly dominates operation B. + bool properlyDominates(Value *a, Operation *b); - /// Return true if instruction A dominates instruction B. - bool dominates(Value *a, Instruction *b) { - return (Instruction *)a->getDefiningOp() == b || properlyDominates(a, b); + /// Return true if operation A dominates operation B. + bool dominates(Value *a, Operation *b) { + return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b); } /// Return true if the specified block A dominates block B. 
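With the rename, dominance queries are phrased directly against Operation and Value. A small sketch of how a client might use them, assuming a DominanceInfo has already been computed for the enclosing function (the helper name is made up):

    // Returns true when the value `v` may be used by `user`: block arguments
    // dominate every operation in their own block, so dominates() rather than
    // properlyDominates() is the right query for values.
    static bool isAvailableAt(DominanceInfo &domInfo, Value *v,
                              Operation *user) {
      return domInfo.dominates(v, user);
    }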
@@ -97,11 +97,11 @@ class PostDominanceInfo : public detail::DominanceInfoBase { public: using super::super; - /// Return true if instruction A properly postdominates instruction B. - bool properlyPostDominates(Instruction *a, Instruction *b); + /// Return true if operation A properly postdominates operation B. + bool properlyPostDominates(Operation *a, Operation *b); - /// Return true if instruction A postdominates instruction B. - bool postDominates(Instruction *a, Instruction *b) { + /// Return true if operation A postdominates operation B. + bool postDominates(Operation *a, Operation *b) { return a == b || properlyPostDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index cc7af1184b3..b364f084295 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -32,7 +32,6 @@ class AffineExpr; class AffineForOp; class AffineMap; class Operation; -using Instruction = Operation; class MemRefType; class Value; @@ -102,7 +101,7 @@ bool isVectorizableLoopAlongFastestVaryingMemRefDim(AffineForOp loop, unsigned fastestVaryingDim); /// Checks where SSA dominance would be violated if a for inst's body -/// instructions are shifted by the specified shifts. This method checks if a +/// operations are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 393abdb33a4..8ee5ba826b2 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -25,7 +25,6 @@ namespace mlir { struct NestedPattern; class Operation; -using Instruction = Operation; /// An NestedPattern captures nested patterns in the IR. /// It is used in conjunction with a scoped NestedPatternContext which is an @@ -47,20 +46,20 @@ using Instruction = Operation; /// /// /// Nested abstraction for matching results. -/// Provides access to the nested Instruction* captured by a Matcher. +/// Provides access to the nested Operation* captured by a Matcher. /// -/// A NestedMatch contains an Instruction* and the children NestedMatch and is +/// A NestedMatch contains an Operation* and the children NestedMatch and is /// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose /// lifetime is managed by an RAII NestedPatternContext. struct NestedMatch { - static NestedMatch build(Instruction *instruction, + static NestedMatch build(Operation *operation, ArrayRef nestedMatches); NestedMatch(const NestedMatch &) = default; NestedMatch &operator=(const NestedMatch &) = default; - explicit operator bool() { return matchedInstruction != nullptr; } + explicit operator bool() { return matchedOperation != nullptr; } - Instruction *getMatchedInstruction() { return matchedInstruction; } + Operation *getMatchedOperation() { return matchedOperation; } ArrayRef getMatchedChildren() { return matchedChildren; } private: @@ -73,11 +72,11 @@ private: NestedMatch() = default; /// Payload, holds a NestedMatch and all its children along this branch. - Instruction *matchedInstruction; + Operation *matchedOperation; ArrayRef matchedChildren; }; -/// A NestedPattern is a nested instruction walker that: +/// A NestedPattern is a nested operation walker that: /// 1. recursively matches a substructure in the tree; /// 2. 
uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); @@ -93,10 +92,10 @@ private: /// /// The NestedMatches captured in the IR can grow large, especially after /// aggressive unrolling. As experience has shown, it is generally better to use -/// a plain walk over instructions to match flat patterns but the current +/// a plain walk over operations to match flat patterns but the current /// implementation is competitive nonetheless. -using FilterFunctionType = std::function; -static bool defaultFilterFunction(Instruction &) { return true; }; +using FilterFunctionType = std::function; +static bool defaultFilterFunction(Operation &) { return true; }; struct NestedPattern { NestedPattern(ArrayRef nested, FilterFunctionType filter = defaultFilterFunction); @@ -105,12 +104,12 @@ struct NestedPattern { /// Returns all the top-level matches in `func`. void match(Function *func, SmallVectorImpl *matches) { - func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); }); + func->walkPostOrder([&](Operation *op) { matchOne(op, matches); }); } - /// Returns all the top-level matches in `inst`. - void match(Instruction *inst, SmallVectorImpl *matches) { - inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); }); + /// Returns all the top-level matches in `op`. + void match(Operation *op, SmallVectorImpl *matches) { + op->walkPostOrder([&](Operation *child) { matchOne(child, matches); }); } /// Returns the depth of the pattern. @@ -124,9 +123,9 @@ private: /// Underlying global bump allocator managed by a NestedPatternContext. static llvm::BumpPtrAllocator *&allocator(); - /// Matches this pattern against a single `inst` and fills matches with the + /// Matches this pattern against a single `op` and fills matches with the /// result. - void matchOne(Instruction *inst, SmallVectorImpl *matches); + void matchOne(Operation *op, SmallVectorImpl *matches); /// Nested patterns to be matched. ArrayRef nestedPatterns; @@ -135,19 +134,19 @@ private: FilterFunctionType filter; /// skip is an implementation detail needed so that we can implement match - /// without switching on the type of the Instruction. The idea is that a + /// without switching on the type of the Operation. The idea is that a /// NestedPattern first checks if it matches locally and then recursively /// applies its nested matchers to its elem->nested. Since we want to rely on - /// the existing instruction walking functionality rather than duplicate + /// the existing operation walking functionality rather than duplicate /// it, we allow an off-by-one traversal to account for the fact that we /// write: /// - /// void match(Instruction *elem) { + /// void match(Operation *elem) { /// for (auto &c : getNestedPatterns()) { /// NestedPattern childPattern(...); /// ^~~~ Needs off-by-one skip. 
/// - Instruction *skip; + Operation *skip; }; /// RAII structure to transparently manage the bump allocator for @@ -183,9 +182,9 @@ NestedPattern For(ArrayRef nested = {}); NestedPattern For(FilterFunctionType filter, ArrayRef nested = {}); -bool isParallelLoop(Instruction &inst); -bool isReductionLoop(Instruction &inst); -bool isLoadOrStore(Instruction &inst); +bool isParallelLoop(Operation &op); +bool isReductionLoop(Operation &op); +bool isLoadOrStore(Operation &op); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index a3fb841092a..c76f0b2a03c 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -28,24 +28,23 @@ namespace mlir { class Operation; -using Instruction = Operation; /// Type of the condition to limit the propagation of transitive use-defs. /// This can be used in particular to limit the propagation to a given Scope or -/// to avoid passing through certain types of instruction in a configurable +/// to avoid passing through certain types of operation in a configurable /// manner. -using TransitiveFilter = std::function; +using TransitiveFilter = std::function; /// Fills `forwardSlice` with the computed forward slice (i.e. all -/// the transitive uses of inst), **without** including that instruction. +/// the transitive uses of op), **without** including that operation. /// /// This additionally takes a TransitiveFilter which acts as a frontier: -/// when looking at uses transitively, a instruction that does not pass the +/// when looking at uses transitively, a operation that does not pass the /// filter is never propagated through. This allows in particular to carve out /// the scope within a ForInst or the scope within an IfInst. /// /// The implementation traverses the use chains in postorder traversal for -/// efficiency reasons: if a instruction is already in `forwardSlice`, no +/// efficiency reasons: if a operation is already in `forwardSlice`, no /// need to traverse its uses again. Since use-def chains form a DAG, this /// terminates. /// @@ -78,20 +77,20 @@ using TransitiveFilter = std::function; /// {4, 3, 6, 2, 1, 5, 8, 7, 9} /// void getForwardSlice( - Instruction *inst, llvm::SetVector *forwardSlice, + Operation *op, llvm::SetVector *forwardSlice, TransitiveFilter filter = /* pass-through*/ - [](Instruction *) { return true; }); + [](Operation *) { return true; }); /// Fills `backwardSlice` with the computed backward slice (i.e. -/// all the transitive defs of inst), **without** including that instruction. +/// all the transitive defs of op), **without** including that operation. /// /// This additionally takes a TransitiveFilter which acts as a frontier: -/// when looking at defs transitively, a instruction that does not pass the +/// when looking at defs transitively, a operation that does not pass the /// filter is never propagated through. This allows in particular to carve out /// the scope within a ForInst or the scope within an IfInst. /// /// The implementation traverses the def chains in postorder traversal for -/// efficiency reasons: if a instruction is already in `backwardSlice`, no +/// efficiency reasons: if a operation is already in `backwardSlice`, no /// need to traverse its definitions again. Since useuse-def chains form a DAG, /// this terminates. 
/// @@ -117,18 +116,18 @@ void getForwardSlice( /// {1, 2, 5, 7, 3, 4, 6, 8} /// void getBackwardSlice( - Instruction *inst, llvm::SetVector *backwardSlice, + Operation *op, llvm::SetVector *backwardSlice, TransitiveFilter filter = /* pass-through*/ - [](Instruction *) { return true; }); + [](Operation *) { return true; }); /// Iteratively computes backward slices and forward slices until -/// a fixed point is reached. Returns an `llvm::SetVector` which -/// **includes** the original instruction. +/// a fixed point is reached. Returns an `llvm::SetVector` which +/// **includes** the original operation. /// /// This allows building a slice (i.e. multi-root DAG where everything /// that is reachable from an Value in forward and backward direction is /// contained in the slice). -/// This is the abstraction we need to materialize all the instructions for +/// This is the abstraction we need to materialize all the operations for /// supervectorization without worrying about orderings and Value /// replacements. /// @@ -157,20 +156,20 @@ void getBackwardSlice( /// /// Additional implementation considerations /// ======================================== -/// Consider the defs-inst-uses hourglass. +/// Consider the defs-op-uses hourglass. /// ____ /// \ / defs (in some topological order) /// \/ -/// inst +/// op /// /\ /// / \ uses (in some topological order) /// /____\ /// /// We want to iteratively apply `getSlice` to construct the whole -/// list of Instruction that are reachable by (use|def)+ from inst. +/// list of Operation that are reachable by (use|def)+ from op. /// We want the resulting slice in topological order. /// Ideally we would like the ordering to be maintained in-place to avoid -/// copying Instruction at each step. Keeping this ordering by construction +/// copying Operation at each step. Keeping this ordering by construction /// seems very unclear, so we list invariants in the hope of seeing whether /// useful properties pop up. /// @@ -182,34 +181,34 @@ void getBackwardSlice( /// =========== /// We wish to maintain the following property by a recursive argument: /// """ -/// defs << {inst} < getSlice( - Instruction *inst, +llvm::SetVector getSlice( + Operation *op, TransitiveFilter backwardFilter = /* pass-through*/ - [](Instruction *) { return true; }, + [](Operation *) { return true; }, TransitiveFilter forwardFilter = /* pass-through*/ - [](Instruction *) { return true; }); + [](Operation *) { return true; }); /// Multi-root DAG topological sort. -/// Performs a topological sort of the Instruction in the `toSort` SetVector. +/// Performs a topological sort of the Operation in the `toSort` SetVector. /// Returns a topologically sorted SetVector. -llvm::SetVector -topologicalSort(const llvm::SetVector &toSort); +llvm::SetVector +topologicalSort(const llvm::SetVector &toSort); } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index e6af0ce3ff2..8ce4de10eb7 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -41,17 +41,16 @@ class FlatAffineConstraints; class Location; class MemRefAccess; class Operation; -using Instruction = Operation; class Value; -/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'affine.for' instruction to the innermost one. -// TODO(bondhugula): handle 'affine.if' inst's. 
-void getLoopIVs(Instruction &inst, SmallVectorImpl *loops); +/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from +/// the outermost 'affine.for' operation to the innermost one. +// TODO(bondhugula): handle 'affine.if' ops. +void getLoopIVs(Operation &op, SmallVectorImpl *loops); -/// Returns the nesting depth of this instruction, i.e., the number of loops -/// surrounding this instruction. -unsigned getNestingDepth(Instruction &inst); +/// Returns the nesting depth of this operation, i.e., the number of loops +/// surrounding this operation. +unsigned getNestingDepth(Operation &op); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. @@ -96,15 +95,15 @@ LogicalResult getBackwardComputationSliceState( /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcOpInst', slices the iteration space of src loop based on slice bounds /// in 'sliceState', and inserts the computation slice at the beginning of the -/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding +/// operation block of the loop at 'dstLoopDepth' in the loop nest surrounding /// 'dstOpInst'. Returns the top-level loop of the computation slice on /// success, returns nullptr otherwise. // Loop depth is a crucial optimization choice that determines where to // materialize the results of the backward slice - presenting a trade-off b/w // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. -AffineForOp insertBackwardComputationSlice(Instruction *srcOpInst, - Instruction *dstOpInst, +AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, + Operation *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState); @@ -155,7 +154,7 @@ struct MemRefRegion { /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. /// - LogicalResult compute(Instruction *inst, unsigned loopDepth, + LogicalResult compute(Operation *op, unsigned loopDepth, ComputationSliceState *sliceState = nullptr); FlatAffineConstraints *getConstraints() { return &cst; } @@ -229,7 +228,7 @@ LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError = true); /// Returns the number of surrounding loops common to both A and B. -unsigned getNumCommonSurroundingLoops(Instruction &A, Instruction &B); +unsigned getNumCommonSurroundingLoops(Operation &A, Operation &B); /// Gets the memory footprint of all data touched in the specified memory space /// in bytes; if the memory space is unspecified, considers all memory spaces. diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index deb630b1708..c7726ed8a89 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -31,7 +31,6 @@ class FuncBuilder; class Location; class MemRefType; class Operation; -using Instruction = Operation; class Value; class VectorType; @@ -123,8 +122,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. 
/// AffineMap makePermutationMap( - Instruction *opInst, - const llvm::DenseMap &loopToVectorDim); + Operation *op, + const llvm::DenseMap &loopToVectorDim); namespace matcher { @@ -136,7 +135,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnSuperVectors(Instruction &inst, VectorType subVectorType); +bool operatesOnSuperVectors(Operation &op, VectorType subVectorType); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index ef7a6e56368..f2dd357a1ae 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -108,7 +108,7 @@ public: /// Returns the closest surrounding operation that contains this operation /// or nullptr if this is a top-level operation. - Operation *getParentInst(); + Operation *getParentOp(); /// Returns the function that this operation is part of. /// The function is determined by traversing the chain of parent operations. @@ -131,8 +131,8 @@ public: /// function. void moveBefore(Operation *existingInst); - /// Unlink this operation operation from its current block and insert it - /// right before `iterator` in the specified block. + /// Unlink this operation from its current block and insert it right before + /// `iterator` in the specified block. void moveBefore(Block *block, llvm::iplist::iterator iterator); /// Given an operation 'other' that is within the same parent block, return diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index d23f2841e15..76889168d09 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -45,7 +45,7 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context) bool mlir::isTopLevelSymbol(Value *value) { if (auto *arg = dyn_cast(value)) return arg->getOwner()->getParent()->getContainingFunction(); - return value->getDefiningOp()->getParentInst() == nullptr; + return value->getDefiningOp()->getParentOp() == nullptr; } // Value can be used as a dimension id if it is valid as a symbol, or @@ -56,16 +56,16 @@ bool mlir::isValidDim(Value *value) { if (!value->getType().isIndex()) return false; - if (auto *inst = value->getDefiningOp()) { + if (auto *op = value->getDefiningOp()) { // Top level instruction or constant operation is ok. - if (inst->getParentInst() == nullptr || inst->isa()) + if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto op = inst->dyn_cast()) - return op.isValidDim(); + if (auto applyOp = op->dyn_cast()) + return applyOp.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = inst->dyn_cast()) + if (auto dimOp = op->dyn_cast()) return isTopLevelSymbol(dimOp.getOperand()); return false; } @@ -81,16 +81,16 @@ bool mlir::isValidSymbol(Value *value) { if (!value->getType().isIndex()) return false; - if (auto *inst = value->getDefiningOp()) { + if (auto *op = value->getDefiningOp()) { // Top level instruction or constant operation is ok. - if (inst->getParentInst() == nullptr || inst->isa()) + if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. 
- if (auto op = inst->dyn_cast()) - return op.isValidSymbol(); + if (auto applyOp = op->dyn_cast()) + return applyOp.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = inst->dyn_cast()) + if (auto dimOp = op->dyn_cast()) return isTopLevelSymbol(dimOp.getOperand()); return false; } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index b3548f96b29..9fac3c8d11b 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -41,14 +41,13 @@ using namespace mlir; using llvm::dbgs; -/// Returns the sequence of AffineApplyOp Instructions operation in +/// Returns the sequence of AffineApplyOp Operations operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( - ArrayRef operands, - SmallVectorImpl &affineApplyOps) { + ArrayRef operands, SmallVectorImpl &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. Value *value; @@ -64,28 +63,27 @@ void mlir::getReachableAffineApplyOps( State &state = worklist.back(); auto *opInst = state.value->getDefiningOp(); // Note: getDefiningOp will return nullptr if the operand is not an - // Instruction (i.e. AffineForOp), which is a terminator for the search. + // Operation (i.e. block argument), which is a terminator for the search. if (opInst == nullptr || !opInst->isa()) { worklist.pop_back(); continue; } - if (auto affineApplyOp = opInst->dyn_cast()) { - if (state.operandIndex == 0) { - // Pre-Visit: Add 'opInst' to reachable sequence. - affineApplyOps.push_back(opInst); - } - if (state.operandIndex < opInst->getNumOperands()) { - // Visit: Add next 'affineApplyOp' operand to worklist. - // Get next operand to visit at 'operandIndex'. - auto *nextOperand = opInst->getOperand(state.operandIndex); - // Increment 'operandIndex' in 'state'. - ++state.operandIndex; - // Add 'nextOperand' to worklist. - worklist.push_back({nextOperand, 0}); - } else { - // Post-visit: done visiting operands AffineApplyOp, pop off stack. - worklist.pop_back(); - } + + if (state.operandIndex == 0) { + // Pre-Visit: Add 'opInst' to reachable sequence. + affineApplyOps.push_back(opInst); + } + if (state.operandIndex < opInst->getNumOperands()) { + // Visit: Add next 'affineApplyOp' operand to worklist. + // Get next operand to visit at 'operandIndex'. + auto *nextOperand = opInst->getOperand(state.operandIndex); + // Increment 'operandIndex' in 'state'. + ++state.operandIndex; + // Add 'nextOperand' to worklist. + worklist.push_back({nextOperand, 0}); + } else { + // Post-visit: done visiting operands AffineApplyOp, pop off stack. + worklist.pop_back(); } } } @@ -115,15 +113,15 @@ LogicalResult mlir::getIndexSet(MutableArrayRef forOps, // Computes the iteration domain for 'opInst' and populates 'indexSet', which // encapsulates the constraints involving loops surrounding 'opInst' and // potentially involving any Function symbols. The dimensional identifiers in -// 'indexSet' correspond to the loops surounding 'inst' from outermost to +// 'indexSet' correspond to the loops surounding 'op' from outermost to // innermost. -// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'. 
-static LogicalResult getInstIndexSet(Instruction *inst, +// TODO(andydavis) Add support to handle IfInsts surrounding 'op'. +static LogicalResult getInstIndexSet(Operation *op, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. SmallVector loops; - getLoopIVs(*inst, &loops); + getLoopIVs(*op, &loops); return getIndexSet(loops, indexSet); } @@ -549,13 +547,12 @@ static Block *getCommonBlock(const MemRefAccess &srcAccess, return forOp.getBody(); } -// Returns true if the ancestor operation instruction of 'srcAccess' appears -// before the ancestor operation instruction of 'dstAccess' in the common -// ancestral block. Returns false otherwise. +// Returns true if the ancestor operation of 'srcAccess' appears before the +// ancestor operation of 'dstAccess' in the common ancestral block. Returns +// false otherwise. // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals, -// the function is named 'srcAppearsBeforeDstInCommonBlock'. -// Note that 'numCommonLoops' is the number of contiguous surrounding outer -// loops. +// the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that +// 'numCommonLoops' is the number of contiguous surrounding outer loops. static bool srcAppearsBeforeDstInAncestralBlock( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { @@ -791,19 +788,19 @@ bool mlir::checkMemrefAccessDependence( AffineValueMap dstAccessMap; dstAccess.getAccessMap(&dstAccessMap); - // Get iteration domain for the 'srcAccess' instruction. + // Get iteration domain for the 'srcAccess' operation. FlatAffineConstraints srcDomain; if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain))) return false; - // Get iteration domain for 'dstAccess' instruction. + // Get iteration domain for 'dstAccess' operation. FlatAffineConstraints dstDomain; if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain))) return false; // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation - // instruction of 'srcAccess' does not properly dominate the ancestor - // operation instruction of 'dstAccess' in the same common instruction block. + // operation of 'srcAccess' does not properly dominate the ancestor + // operation of 'dstAccess' in the same common operation block. // Note: this check is skipped if 'allowRAR' is true, because because RAR // deps can exist irrespective of lexicographic ordering b/w src and dst. unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index b8a9e1c0218..d914f36cdaf 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -46,8 +46,8 @@ void DominanceInfoBase::recalculate(Function *function) { std::move(functionDominance)); /// Build the dominance for each of the operation regions. - function->walk([&](Instruction *inst) { - for (auto ®ion : inst->getRegions()) { + function->walk([&](Operation *op) { + for (auto ®ion : op->getRegions()) { // Don't compute dominance if the region is empty. 
if (region.empty()) continue; @@ -66,11 +66,11 @@ bool DominanceInfoBase::properlyDominates(Block *a, Block *b) { return false; // If both blocks are not in the same region, 'a' properly dominates 'b' if - // 'b' is defined in an instruction region that (recursively) ends up being + // 'b' is defined in an operation region that (recursively) ends up being // dominated by 'a'. Walk up the list of containers enclosing B. auto *regionA = a->getParent(), *regionB = b->getParent(); if (regionA != regionB) { - Instruction *bAncestor; + Operation *bAncestor; do { bAncestor = regionB->getContainingOp(); // If 'bAncestor' is the top level function, then 'a' is a block @@ -100,8 +100,8 @@ template class mlir::detail::DominanceInfoBase; // DominanceInfo //===----------------------------------------------------------------------===// -/// Return true if instruction A properly dominates instruction B. -bool DominanceInfo::properlyDominates(Instruction *a, Instruction *b) { +/// Return true if operation A properly dominates operation B. +bool DominanceInfo::properlyDominates(Operation *a, Operation *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, then check if b is before a in the block. @@ -120,12 +120,12 @@ bool DominanceInfo::properlyDominates(Instruction *a, Instruction *b) { return properlyDominates(aBlock, bBlock); } -/// Return true if value A properly dominates instruction B. -bool DominanceInfo::properlyDominates(Value *a, Instruction *b) { +/// Return true if value A properly dominates operation B. +bool DominanceInfo::properlyDominates(Value *a, Operation *b) { if (auto *aInst = a->getDefiningOp()) return properlyDominates(aInst, b); - // block arguments properly dominate all instructions in their own block, so + // block arguments properly dominate all operations in their own block, so // we use a dominates check here, not a properlyDominates check. return dominates(cast(a)->getOwner(), b->getBlock()); } @@ -135,7 +135,7 @@ bool DominanceInfo::properlyDominates(Value *a, Instruction *b) { //===----------------------------------------------------------------------===// /// Returns true if statement 'a' properly postdominates statement b. -bool PostDominanceInfo::properlyPostDominates(Instruction *a, Instruction *b) { +bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b) { auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); // If the blocks are the same, check if b is before a in the block. 
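Within a single block, the dominance rules above reduce to a plain ordering check, with the documented special case that block arguments dominate every operation in their block (hence the non-strict dominates check for them). A toy sketch of just that same-block case follows; the cross-region case in the patch additionally hoists 'b' to an ancestor in the same region as 'a' before comparing blocks.

// Illustrative sketch only -- a single-block toy model, not MLIR code.
#include <cassert>

struct ToyValue {
  bool isBlockArgument; // block arguments are not defined by any operation
  int defPos;           // position of the defining operation, if any
};

// Within one block, operation a properly dominates operation b iff it
// appears strictly earlier.
bool properlyDominates(int aPos, int bPos) { return aPos < bPos; }

// A value properly dominates a use iff its defining operation does; block
// arguments dominate every operation in their block, which is why the helper
// in the patch falls back to a non-strict dominates check for them.
bool valueProperlyDominates(const ToyValue &v, int usePos) {
  return v.isBlockArgument || properlyDominates(v.defPos, usePos);
}

int main() {
  ToyValue arg{true, -1}, res0{false, 0};
  assert(valueProperlyDominates(arg, 0));   // argument dominates the first op
  assert(valueProperlyDominates(res0, 1));  // def at 0 dominates a use at 1
  assert(!valueProperlyDominates(res0, 0)); // a result does not dominate its own def
  (void)arg; (void)res0;
  return 0;
}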
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index eb272389957..e720e194814 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -177,7 +177,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { bool mlir::isAccessInvariant(Value &iv, Value &index) { assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa() && "index must be of IndexType"); - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps({&index}, affineApplyOps); if (affineApplyOps.empty()) { @@ -272,11 +272,11 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa(); } -static bool isVectorTransferReadOrWrite(Instruction &inst) { - return inst.isa() || inst.isa(); +static bool isVectorTransferReadOrWrite(Operation &op) { + return op.isa() || op.isa(); } -using VectorizableInstFun = std::function; +using VectorizableInstFun = std::function; static bool isVectorizableLoopWithCond(AffineForOp loop, VectorizableInstFun isVectorizableInst) { @@ -295,9 +295,9 @@ static bool isVectorizableLoopWithCond(AffineForOp loop, } // No vectorization across unknown regions. - auto regions = matcher::Op([](Instruction &inst) -> bool { - return inst.getNumRegions() != 0 && - !(inst.isa() || inst.isa()); + auto regions = matcher::Op([](Operation &op) -> bool { + return op.getNumRegions() != 0 && + !(op.isa() || op.isa()); }); SmallVector regionsMatched; regions.match(forInst, ®ionsMatched); @@ -316,7 +316,7 @@ static bool isVectorizableLoopWithCond(AffineForOp loop, SmallVector loadAndStoresMatched; loadAndStores.match(forInst, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { - auto *op = ls.getMatchedInstruction(); + auto *op = ls.getMatchedOperation(); auto load = op->dyn_cast(); auto store = op->dyn_cast(); // Only scalar types are considered vectorizable, all load/store must be @@ -336,7 +336,7 @@ static bool isVectorizableLoopWithCond(AffineForOp loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( AffineForOp loop, unsigned fastestVaryingDim) { VectorizableInstFun fun( - [fastestVaryingDim](AffineForOp loop, Instruction &op) { + [fastestVaryingDim](AffineForOp loop, Operation &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); return load ? isContiguousAccess(*loop.getInductionVar(), load, @@ -350,12 +350,12 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( bool mlir::isVectorizableLoop(AffineForOp loop) { VectorizableInstFun fun( // TODO: implement me - [](AffineForOp loop, Instruction &op) { return true; }); + [](AffineForOp loop, Operation &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } -/// Checks whether SSA dominance would be violated if a for inst's body -/// instructions are shifted by the specified shifts. This method checks if a +/// Checks whether SSA dominance would be violated if a for op's body +/// operations are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. @@ -364,24 +364,24 @@ bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { assert(shifts.size() == forBody->getOperations().size()); // Work backwards over the body of the block so that the shift of a use's - // ancestor instruction in the block gets recorded before it's looked up. 
- DenseMap forBodyShift; + // ancestor operation in the block gets recorded before it's looked up. + DenseMap forBodyShift; for (auto it : llvm::enumerate(llvm::reverse(forBody->getOperations()))) { - auto &inst = it.value(); + auto &op = it.value(); - // Get the index of the current instruction, note that we are iterating in + // Get the index of the current operation, note that we are iterating in // reverse so we need to fix it up. size_t index = shifts.size() - it.index() - 1; - // Remember the shift of this instruction. + // Remember the shift of this operation. uint64_t shift = shifts[index]; - forBodyShift.try_emplace(&inst, shift); + forBodyShift.try_emplace(&op, shift); - // Validate the results of this instruction if it were to be shifted. - for (unsigned i = 0, e = inst.getNumResults(); i < e; ++i) { - Value *result = inst.getResult(i); + // Validate the results of this operation if it were to be shifted. + for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { + Value *result = op.getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor instruction doesn't lie in the block of forOp, + // If an ancestor operation doesn't lie in the block of forOp, // there is no shift to check. if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) { assert(forBodyShift.count(ancInst) > 0 && "ancestor expected in map"); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 8edf79d6db3..0fb88620fa1 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -47,7 +47,7 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { } void MemRefBoundCheck::runOnFunction() { - getFunction().walk([](Instruction *opInst) { + getFunction().walk([](Operation *opInst) { if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 8e438108bce..2872c4cf256 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -37,7 +37,7 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. struct MemRefDependenceCheck : public FunctionPass { - SmallVector loadsAndStores; + SmallVector loadsAndStores; void runOnFunction() override; }; @@ -79,7 +79,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. -static void checkDependences(ArrayRef loadsAndStores) { +static void checkDependences(ArrayRef loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpInst = loadsAndStores[i]; MemRefAccess srcAccess(srcOpInst); @@ -113,9 +113,9 @@ static void checkDependences(ArrayRef loadsAndStores) { void MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. 
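The shift-validity check from LoopAnalysis.cpp above boils down to one rule: a shift assignment is acceptable only if every def inside the loop body and all of its uses within that body receive the same shift factor. A self-contained sketch of that rule, using toy types rather than MLIR's, is:

// Illustrative sketch only -- toy model of a loop body, not MLIR code.
#include <cstddef>
#include <cstdint>
#include <vector>

// Each body operation lists the indices of the body operations whose results
// it uses (intra-body def-use edges only).
struct ToyOp {
  std::vector<std::size_t> usedDefs;
};

// Mirrors the rule behind isInstwiseShiftValid: shifting is only valid if
// every def and every use of it within the body get the same shift.
// Expects shifts.size() == body.size().
bool shiftsPreserveSSADominance(const std::vector<ToyOp> &body,
                                const std::vector<std::uint64_t> &shifts) {
  for (std::size_t i = 0; i < body.size(); ++i)
    for (std::size_t def : body[i].usedDefs)
      if (shifts[def] != shifts[i])
        return false;
  return true;
}

// For example, a body of two ops where op 1 uses op 0's result is valid for
// shifts {2, 2} but not for shifts {0, 1}.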
loadsAndStores.clear(); - getFunction().walk([&](Instruction *inst) { - if (inst->isa() || inst->isa()) - loadsAndStores.push_back(inst); + getFunction().walk([&](Operation *op) { + if (op->isa() || op->isa()) + loadsAndStores.push_back(op); }); checkDependences(loadsAndStores); diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 83b3591ce5c..43a725a3b7d 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -31,13 +31,13 @@ llvm::BumpPtrAllocator *&NestedMatch::allocator() { return allocator; } -NestedMatch NestedMatch::build(Instruction *instruction, +NestedMatch NestedMatch::build(Operation *operation, ArrayRef nestedMatches) { auto *result = allocator()->Allocate(); auto *children = allocator()->Allocate(nestedMatches.size()); std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); new (result) NestedMatch(); - result->matchedInstruction = instruction; + result->matchedOperation = operation; result->matchedChildren = ArrayRef(children, nestedMatches.size()); return *result; @@ -69,29 +69,29 @@ unsigned NestedPattern::getDepth() const { return depth + 1; } -/// Matches a single instruction in the following way: -/// 1. checks the kind of instruction against the matcher, if different then +/// Matches a single operation in the following way: +/// 1. checks the kind of operation against the matcher, if different then /// there is no match; -/// 2. calls the customizable filter function to refine the single instruction +/// 2. calls the customizable filter function to refine the single operation /// match with extra semantic constraints; /// 3. if all is good, recursivey matches the nested patterns; -/// 4. if all nested match then the single instruction matches too and is +/// 4. if all nested match then the single operation matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. -void NestedPattern::matchOne(Instruction *inst, +void NestedPattern::matchOne(Operation *op, SmallVectorImpl *matches) { - if (skip == inst) { + if (skip == op) { return; } // Local custom filter function - if (!filter(*inst)) { + if (!filter(*op)) { return; } if (nestedPatterns.empty()) { SmallVector nestedMatches; - matches->push_back(NestedMatch::build(inst, nestedMatches)); + matches->push_back(NestedMatch::build(op, nestedMatches)); return; } // Take a copy of each nested pattern so we can match it. @@ -99,20 +99,20 @@ void NestedPattern::matchOne(Instruction *inst, SmallVector nestedMatches; // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walkPostOrder here. - nestedPattern.skip = inst; - nestedPattern.match(inst, &nestedMatches); + nestedPattern.skip = op; + nestedPattern.match(op, &nestedMatches); // If we could not match even one of the specified nestedPattern, early exit // as this whole branch is not a match. 
if (nestedMatches.empty()) { return; } - matches->push_back(NestedMatch::build(inst, nestedMatches)); + matches->push_back(NestedMatch::build(op, nestedMatches)); } } -static bool isAffineForOp(Instruction &inst) { return inst.isa(); } +static bool isAffineForOp(Operation &op) { return op.isa(); } -static bool isAffineIfOp(Instruction &inst) { return inst.isa(); } +static bool isAffineIfOp(Operation &op) { return op.isa(); } namespace mlir { namespace matcher { @@ -125,16 +125,16 @@ NestedPattern If(NestedPattern child) { return NestedPattern(child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [filter](Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); + return NestedPattern(child, [filter](Operation &op) { + return isAffineIfOp(op) && filter(op); }); } NestedPattern If(ArrayRef nested) { return NestedPattern(nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [filter](Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); + return NestedPattern(nested, [filter](Operation &op) { + return isAffineIfOp(op) && filter(op); }); } @@ -142,33 +142,31 @@ NestedPattern For(NestedPattern child) { return NestedPattern(child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(child, [=](Instruction &inst) { - return isAffineForOp(inst) && filter(inst); - }); + return NestedPattern( + child, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); } NestedPattern For(ArrayRef nested) { return NestedPattern(nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(nested, [=](Instruction &inst) { - return isAffineForOp(inst) && filter(inst); - }); + return NestedPattern( + nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); } // TODO(ntv): parallel annotation on loops. -bool isParallelLoop(Instruction &inst) { - auto loop = inst.cast(); +bool isParallelLoop(Operation &op) { + auto loop = op.cast(); return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. -bool isReductionLoop(Instruction &inst) { - auto loop = inst.cast(); +bool isReductionLoop(Operation &op) { + auto loop = op.cast(); return loop || true; // loop->isReduction(); }; -bool isLoadOrStore(Instruction &inst) { - return inst.isa() || inst.isa(); +bool isLoadOrStore(Operation &op) { + return op.isa() || op.isa(); }; } // end namespace matcher diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 0986ac480fe..7be7e9d9f12 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -46,8 +46,7 @@ void PrintOpStatsPass::runOnModule() { // Compute the operation statistics for each function in the module. 
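The nested matcher changes above keep the builder-style API; a rough usage sketch, based only on the calls visible in this patch, is shown next. The header path, the type accepted by match(), and any matcher context setup the library requires are assumptions made for the example, not verified against this revision.

// Illustrative sketch only -- usage shape inferred from this patch; the
// header path and surrounding setup are assumptions.
#include "mlir/Analysis/NestedMatcher.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Matches every affine.for that contains at least one load or store, and
// collects those memory accesses.
static void collectLoopAccesses(Function *f,
                                llvm::SmallVectorImpl<Operation *> &result) {
  auto pattern = matcher::For(matcher::Op(matcher::isLoadOrStore));
  llvm::SmallVector<NestedMatch, 8> matches;
  pattern.match(f, &matches);
  for (auto m : matches) {
    // m.getMatchedOperation() is the matched affine.for; its children are
    // the matches of the nested pattern, i.e. the load/store operations.
    for (auto nested : m.getMatchedChildren())
      result.push_back(nested.getMatchedOperation());
  }
}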
for (auto &fn : getModule()) - fn.walk( - [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); + fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 82320bd26ff..496c0b33e1e 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -38,21 +38,21 @@ using namespace mlir; using llvm::DenseSet; using llvm::SetVector; -static void getForwardSliceImpl(Instruction *inst, - SetVector *forwardSlice, +static void getForwardSliceImpl(Operation *op, + SetVector *forwardSlice, TransitiveFilter filter) { - if (!inst) { + if (!op) { return; } // Evaluate whether we should keep this use. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. - if (!filter(inst)) { + if (!filter(op)) { return; } - if (auto forOp = inst->dyn_cast()) { + if (auto forOp = op->dyn_cast()) { for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { @@ -60,9 +60,9 @@ static void getForwardSliceImpl(Instruction *inst, } } } else { - assert(inst->getNumResults() <= 1 && "NYI: multiple results"); - if (inst->getNumResults() > 0) { - for (auto &u : inst->getResult(0)->getUses()) { + assert(op->getNumResults() <= 1 && "NYI: multiple results"); + if (op->getNumResults() > 0) { + for (auto &u : op->getResult(0)->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { getForwardSliceImpl(ownerInst, forwardSlice, filter); @@ -71,67 +71,66 @@ static void getForwardSliceImpl(Instruction *inst, } } - forwardSlice->insert(inst); + forwardSlice->insert(op); } -void mlir::getForwardSlice(Instruction *inst, - SetVector *forwardSlice, +void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, TransitiveFilter filter) { - getForwardSliceImpl(inst, forwardSlice, filter); - // Don't insert the top level instruction, we just queried on it and don't + getForwardSliceImpl(op, forwardSlice, filter); + // Don't insert the top level operation, we just queried on it and don't // want it in the results. - forwardSlice->remove(inst); + forwardSlice->remove(op); // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an // in-place swap based thing (the real std::reverse, not the LLVM adapter). - std::vector v(forwardSlice->takeVector()); + std::vector v(forwardSlice->takeVector()); forwardSlice->insert(v.rbegin(), v.rend()); } -static void getBackwardSliceImpl(Instruction *inst, - SetVector *backwardSlice, +static void getBackwardSliceImpl(Operation *op, + SetVector *backwardSlice, TransitiveFilter filter) { - if (!inst) { + if (!op) { return; } // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. 
- if (!filter(inst)) { + if (!filter(op)) { return; } - for (auto *operand : inst->getOperands()) { - auto *inst = operand->getDefiningOp(); - if (backwardSlice->count(inst) == 0) { - getBackwardSliceImpl(inst, backwardSlice, filter); + for (auto *operand : op->getOperands()) { + auto *op = operand->getDefiningOp(); + if (backwardSlice->count(op) == 0) { + getBackwardSliceImpl(op, backwardSlice, filter); } } - backwardSlice->insert(inst); + backwardSlice->insert(op); } -void mlir::getBackwardSlice(Instruction *inst, - SetVector *backwardSlice, +void mlir::getBackwardSlice(Operation *op, + SetVector *backwardSlice, TransitiveFilter filter) { - getBackwardSliceImpl(inst, backwardSlice, filter); + getBackwardSliceImpl(op, backwardSlice, filter); - // Don't insert the top level instruction, we just queried on it and don't + // Don't insert the top level operation, we just queried on it and don't // want it in the results. - backwardSlice->remove(inst); + backwardSlice->remove(op); } -SetVector mlir::getSlice(Instruction *inst, - TransitiveFilter backwardFilter, - TransitiveFilter forwardFilter) { - SetVector slice; - slice.insert(inst); +SetVector mlir::getSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); unsigned currentIndex = 0; - SetVector backwardSlice; - SetVector forwardSlice; + SetVector backwardSlice; + SetVector forwardSlice; while (currentIndex != slice.size()) { auto *currentInst = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentInst. @@ -151,23 +150,23 @@ SetVector mlir::getSlice(Instruction *inst, namespace { /// DFS post-order implementation that maintains a global count to work across /// multiple invocations, to help implement topological sort on multi-root DAGs. -/// We traverse all instructions but only record the ones that appear in +/// We traverse all operations but only record the ones that appear in /// `toSort` for the final result. struct DFSState { - DFSState(const SetVector &set) + DFSState(const SetVector &set) : toSort(set), topologicalCounts(), seen() {} - const SetVector &toSort; - SmallVector topologicalCounts; - DenseSet seen; + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; }; } // namespace -static void DFSPostorder(Instruction *current, DFSState *state) { +static void DFSPostorder(Operation *current, DFSState *state) { assert(current->getNumResults() <= 1 && "NYI: multi-result"); if (current->getNumResults() > 0) { for (auto &u : current->getResult(0)->getUses()) { - auto *inst = u.getOwner(); - DFSPostorder(inst, state); + auto *op = u.getOwner(); + DFSPostorder(op, state); } } bool inserted; @@ -181,8 +180,8 @@ static void DFSPostorder(Instruction *current, DFSState *state) { } } -SetVector -mlir::topologicalSort(const SetVector &toSort) { +SetVector +mlir::topologicalSort(const SetVector &toSort) { if (toSort.empty()) { return toSort; } @@ -195,7 +194,7 @@ mlir::topologicalSort(const SetVector &toSort) { } // Reorder and return. 
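The slice utilities above compute transitive uses (forward) or defs (backward) with a post-order insertion followed by a reversal, which is what makes the returned forward slice come out in topological, defs-before-uses order. A standalone sketch of the forward case on a toy def-use graph:

// Illustrative sketch only -- toy def-use graph, not MLIR code.
#include <algorithm>
#include <vector>

// users[i] lists the nodes that consume node i's result.
struct ToyGraph {
  std::vector<std::vector<int>> users;
};

// Post-order DFS mirroring getForwardSliceImpl: recurse into unseen users
// first, then record the node itself.
static void forwardSliceImpl(const ToyGraph &g, int node,
                             std::vector<bool> &seen, std::vector<int> &out) {
  if (seen[node])
    return;
  seen[node] = true;
  for (int user : g.users[node])
    forwardSliceImpl(g, user, seen, out);
  out.push_back(node);
}

// Forward slice of `root`: all transitive users, excluding the root itself,
// in topological order. The post-processing (drop the queried node, then
// reverse) is the same as in the patch.
std::vector<int> forwardSlice(const ToyGraph &g, int root) {
  std::vector<bool> seen(g.users.size(), false);
  std::vector<int> slice;
  forwardSliceImpl(g, root, seen, slice);
  slice.pop_back();                         // the root is always pushed last
  std::reverse(slice.begin(), slice.end()); // defs before uses
  return slice;
}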
- SetVector res; + SetVector res; for (auto it = state.topologicalCounts.rbegin(), eit = state.topologicalCounts.rend(); it != eit; ++it) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index a9c22d62f0b..5999b357e96 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -37,18 +37,18 @@ using namespace mlir; using llvm::SmallDenseMap; -/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from -/// the outermost 'affine.for' instruction to the innermost one. -void mlir::getLoopIVs(Instruction &inst, SmallVectorImpl *loops) { - auto *currInst = inst.getParentInst(); +/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from +/// the outermost 'affine.for' operation to the innermost one. +void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { + auto *currOp = op.getParentOp(); AffineForOp currAffineForOp; - // Traverse up the hierarchy collecing all 'affine.for' instruction while - // skipping over 'affine.if' instructions. - while (currInst && ((currAffineForOp = currInst->dyn_cast()) || - currInst->isa())) { + // Traverse up the hierarchy collecing all 'affine.for' operation while + // skipping over 'affine.if' operations. + while (currOp && ((currAffineForOp = currOp->dyn_cast()) || + currOp->isa())) { if (currAffineForOp) loops->push_back(currAffineForOp); - currInst = currInst->getParentInst(); + currOp = currOp->getParentOp(); } std::reverse(loops->begin(), loops->end()); } @@ -73,8 +73,8 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { assert(cst->containsId(*value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto *inst = value->getDefiningOp()) { - if (auto constOp = inst->dyn_cast()) { + if (auto *op = value->getDefiningOp()) { + if (auto constOp = op->dyn_cast()) { cst->setIdToConstant(*value, constOp.getValue()); } } @@ -173,23 +173,22 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, +LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, ComputationSliceState *sliceState) { - assert((inst->isa() || inst->isa()) && - "load/store op expected"); + assert((op->isa() || op->isa()) && "load/store op expected"); - MemRefAccess access(inst); + MemRefAccess access(op); memref = access.memref; write = access.isStore(); unsigned rank = access.getRank(); - LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *inst + LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op << "depth: " << loopDepth << "\n";); if (rank == 0) { SmallVector ivs; - getLoopIVs(*inst, &ivs); + getLoopIVs(*op, &ivs); SmallVector regionSymbols; extractForInductionVars(ivs, ®ionSymbols); // A rank 0 memref has a 0-d region. @@ -242,8 +241,8 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, auto *symbol = operand; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. - if (auto *inst = symbol->getDefiningOp()) { - if (auto constOp = inst->dyn_cast()) { + if (auto *op = symbol->getDefiningOp()) { + if (auto constOp = op->dyn_cast()) { cst.setIdToConstant(*symbol, constOp.getValue()); } } @@ -267,7 +266,7 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // Add access function equalities to connect loop IVs to data dimensions. 
if (failed(cst.composeMap(&accessValueMap))) { - inst->emitError("getMemRefRegion: compose affine map failed"); + op->emitError("getMemRefRegion: compose affine map failed"); LLVM_DEBUG(accessValueMap.getAffineMap().dump()); return failure(); } @@ -280,7 +279,7 @@ LogicalResult MemRefRegion::compute(Instruction *inst, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. SmallVector enclosingIVs; - getLoopIVs(*inst, &enclosingIVs); + getLoopIVs(*op, &enclosingIVs); assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); enclosingIVs.resize(loopDepth); SmallVector ids; @@ -374,7 +373,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same::value, "argument should be either a LoadOp or a StoreOp"); - Instruction *opInst = loadOrStoreOp.getOperation(); + Operation *opInst = loadOrStoreOp.getOperation(); MemRefRegion region(opInst->getLoc()); if (failed(region.compute(opInst, /*loopDepth=*/0))) @@ -427,40 +426,40 @@ template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOp loadOp, template LogicalResult mlir::boundCheckLoadOrStoreOp(StoreOp storeOp, bool emitError); -// Returns in 'positions' the Block positions of 'inst' in each ancestor -// Block from the Block containing instruction, stopping at 'limitBlock'. -static void findInstPosition(Instruction *inst, Block *limitBlock, +// Returns in 'positions' the Block positions of 'op' in each ancestor +// Block from the Block containing operation, stopping at 'limitBlock'. +static void findInstPosition(Operation *op, Block *limitBlock, SmallVectorImpl *positions) { - Block *block = inst->getBlock(); + Block *block = op->getBlock(); while (block != limitBlock) { // FIXME: This algorithm is unnecessarily O(n) and should be improved to not // rely on linear scans. - int instPosInBlock = std::distance(block->begin(), inst->getIterator()); + int instPosInBlock = std::distance(block->begin(), op->getIterator()); positions->push_back(instPosInBlock); - inst = block->getContainingOp(); - block = inst->getBlock(); + op = block->getContainingOp(); + block = op->getBlock(); } std::reverse(positions->begin(), positions->end()); } -// Returns the Instruction in a possibly nested set of Blocks, where the -// position of the instruction is represented by 'positions', which has a +// Returns the Operation in a possibly nested set of Blocks, where the +// position of the operation is represented by 'positions', which has a // Block position for each level of nesting. -static Instruction *getInstAtPosition(ArrayRef positions, - unsigned level, Block *block) { +static Operation *getInstAtPosition(ArrayRef positions, + unsigned level, Block *block) { unsigned i = 0; - for (auto &inst : *block) { + for (auto &op : *block) { if (i != positions[level]) { ++i; continue; } if (level == positions.size() - 1) - return &inst; - if (auto childAffineForOp = inst.dyn_cast()) + return &op; + if (auto childAffineForOp = op.dyn_cast()) return getInstAtPosition(positions, level + 1, childAffineForOp.getBody()); - for (auto ®ion : inst.getRegions()) { + for (auto ®ion : op.getRegions()) { for (auto &b : region) if (auto *ret = getInstAtPosition(positions, level + 1, &b)) return ret; @@ -563,9 +562,10 @@ LogicalResult mlir::getBackwardComputationSliceState( // entire destination index set. Subtract out the dependent destination // iterations from destination index set and check for emptiness --- this is one // solution. 
-AffineForOp mlir::insertBackwardComputationSlice( - Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, - ComputationSliceState *sliceState) { +AffineForOp +mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, + unsigned dstLoopDepth, + ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -580,20 +580,20 @@ AffineForOp mlir::insertBackwardComputationSlice( return AffineForOp(); } - // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. + // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(), &positions); - // Clone src loop nest and insert it a the beginning of the instruction block + // Clone src loop nest and insert it a the beginning of the operation block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); auto sliceLoopNest = b.clone(*srcLoopIVs[0].getOperation())->cast(); - Instruction *sliceInst = + Operation *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody()); // Get loop nest surrounding 'sliceInst'. SmallVector sliceSurroundingLoops; @@ -620,7 +620,7 @@ AffineForOp mlir::insertBackwardComputationSlice( // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. -MemRefAccess::MemRefAccess(Instruction *loadOrStoreOpInst) { +MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; @@ -650,11 +650,11 @@ bool MemRefAccess::isStore() const { return opInst->isa(); } /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. -unsigned mlir::getNestingDepth(Instruction &inst) { - Instruction *currInst = &inst; +unsigned mlir::getNestingDepth(Operation &op) { + Operation *currOp = &op; unsigned depth = 0; - while ((currInst = currInst->getParentInst())) { - if (currInst->isa()) + while ((currOp = currOp->getParentOp())) { + if (currOp->isa()) depth++; } return depth; @@ -662,7 +662,7 @@ unsigned mlir::getNestingDepth(Instruction &inst) { /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', /// where each lists loops from outer-most to inner-most in loop nest. -unsigned mlir::getNumCommonSurroundingLoops(Instruction &A, Instruction &B) { +unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) { SmallVector loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); @@ -683,9 +683,9 @@ static Optional getMemoryFootprintBytes(Block &block, int memorySpace) { SmallDenseMap, 4> regions; - // Walk this 'affine.for' instruction to gather all memory regions. + // Walk this 'affine.for' operation to gather all memory regions. bool error = false; - block.walk(start, end, [&](Instruction *opInst) { + block.walk(start, end, [&](Operation *opInst) { if (!opInst->isa() && !opInst->isa()) { // Neither load nor a store op. return; @@ -737,8 +737,8 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, /// at 'forOp'. 
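getNestingDepth and getNumCommonSurroundingLoops above only inspect the chain of enclosing loops, listed from outermost to innermost; the number of common surrounding loops is just the length of the shared prefix of those chains. A toy sketch with loop names standing in for induction variables:

// Illustrative sketch only -- loop names stand in for AffineForOp IVs.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Each access is described by its surrounding loops, outermost first; the
// nesting depth of the access is simply the length of this list.
std::size_t numCommonSurroundingLoops(const std::vector<std::string> &loopsA,
                                      const std::vector<std::string> &loopsB) {
  std::size_t minDepth = std::min(loopsA.size(), loopsB.size());
  std::size_t common = 0;
  while (common < minDepth && loopsA[common] == loopsB[common])
    ++common;
  return common;
}

// e.g. an access under loops {i, j, k} and another under {i, j, m} share two
// surrounding loops, and each has nesting depth three.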
void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { - forOp.getOperation()->walk([&](Instruction *inst) { - if (auto innerFor = inst->dyn_cast()) + forOp.getOperation()->walk([&](Operation *op) { + if (auto innerFor = op->dyn_cast()) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); }); @@ -747,8 +747,8 @@ void mlir::getSequentialLoops( /// Returns true if 'forOp' is parallel. bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. - SmallVector loadAndStoreOpInsts; - forOp.getOperation()->walk([&](Instruction *opInst) { + SmallVector loadAndStoreOpInsts; + forOp.getOperation()->walk([&](Operation *opInst) { if (opInst->isa() || opInst->isa()) loadAndStoreOpInsts.push_back(opInst); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 232fe1a16ff..9a2e72f66be 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -105,8 +105,8 @@ Optional> mlir::shapeRatio(VectorType superVectorType, /// header file. static AffineMap makePermutationMap( MLIRContext *context, - llvm::iterator_range indices, - const DenseMap &enclosingLoopToVectorDim) { + llvm::iterator_range indices, + const DenseMap &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster(), indices); @@ -140,10 +140,10 @@ static AffineMap makePermutationMap( /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. template -static SetVector getParentsOfType(Instruction *inst) { - SetVector res; - auto *current = inst; - while (auto *parent = current->getParentInst()) { +static SetVector getParentsOfType(Operation *op) { + SetVector res; + auto *current = op; + while (auto *parent = current->getParentOp()) { if (auto typedParent = parent->template dyn_cast()) { assert(res.count(parent) == 0 && "Already inserted"); res.insert(parent); @@ -154,15 +154,14 @@ static SetVector getParentsOfType(Instruction *inst) { } /// Returns the enclosing AffineForOp, from closest to farthest. 
-static SetVector getEnclosingforOps(Instruction *inst) { - return getParentsOfType(inst); +static SetVector getEnclosingforOps(Operation *op) { + return getParentsOfType(op); } AffineMap mlir::makePermutationMap( - Instruction *opInst, - const DenseMap &loopToVectorDim) { - DenseMap enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingforOps(opInst); + Operation *op, const DenseMap &loopToVectorDim) { + DenseMap enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforOps(op); for (auto *forInst : enclosingLoops) { auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { @@ -170,17 +169,17 @@ AffineMap mlir::makePermutationMap( } } - if (auto load = opInst->dyn_cast()) { - return ::makePermutationMap(opInst->getContext(), load.getIndices(), + if (auto load = op->dyn_cast()) { + return ::makePermutationMap(op->getContext(), load.getIndices(), enclosingLoopToVectorDim); } - auto store = opInst->cast(); - return ::makePermutationMap(opInst->getContext(), store.getIndices(), + auto store = op->cast(); + return ::makePermutationMap(op->getContext(), store.getIndices(), enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, +bool mlir::matcher::operatesOnSuperVectors(Operation &op, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, @@ -193,20 +192,20 @@ bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = opInst.dyn_cast()) { + if (auto read = op.dyn_cast()) { superVectorType = read.getResultType(); mustDivide = true; - } else if (auto write = opInst.dyn_cast()) { + } else if (auto write = op.dyn_cast()) { superVectorType = write.getVectorType(); mustDivide = true; - } else if (opInst.getNumResults() == 0) { - if (!opInst.isa()) { - opInst.emitError("NYI: assuming only return instructions can have 0 " - " results at this point"); + } else if (op.getNumResults() == 0) { + if (!op.isa()) { + op.emitError("NYI: assuming only return operations can have 0 " + " results at this point"); } return false; - } else if (opInst.getNumResults() == 1) { - if (auto v = opInst.getResult(0)->getType().dyn_cast()) { + } else if (op.getNumResults() == 1) { + if (auto v = op.getResult(0)->getType().dyn_cast()) { superVectorType = v; } else { // Not a vector type. @@ -215,7 +214,7 @@ bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, } else { // Not a vector_transfer and has more than 1 result, fail hard for now to // wake us up when something changes. - opInst.emitError("NYI: instruction has more than 1 result"); + op.emitError("NYI: operation has more than 1 result"); return false; } @@ -224,7 +223,7 @@ bool mlir::matcher::operatesOnSuperVectors(Instruction &opInst, // Sanity check. 
assert((ratio.hasValue() || !mustDivide) && - "vector_transfer instruction in which super-vector size is not an" + "vector_transfer operation in which super-vector size is not an" " integer multiple of sub-vector size"); // This catches cases that are not strictly necessary to have multiplicity but diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index d7eb578ab1a..fddd9ac25e4 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -23,9 +23,9 @@ // The checks in this file are only for things that can occur as part of IR // transformations: e.g. violation of dominance information, malformed operation // attributes, etc. MLIR supports transformations moving IR through locally -// invalid states (e.g. unlinking an instruction from an instruction before -// re-inserting it in a new place), but each transformation must complete with -// the IR in a valid form. +// invalid states (e.g. unlinking an operation from a block before re-inserting +// it in a new place), but each transformation must complete with the IR in a +// valid form. // // This should not check for things that are always wrong by construction (e.g. // affine maps or other immutable structures that are incorrect), because those @@ -52,7 +52,7 @@ namespace { /// class FuncVerifier { public: - bool failure(const Twine &message, Instruction &value) { + bool failure(const Twine &message, Operation &value) { return value.emitError(message); } @@ -61,7 +61,7 @@ public: } bool failure(const Twine &message, Block &bb) { - // Take the location information for the first instruction in the block. + // Take the location information for the first operation in the block. if (!bb.empty()) return failure(message, bb.front()); @@ -108,9 +108,9 @@ public: bool verify(); bool verifyBlock(Block &block, bool isTopLevel); - bool verifyOperation(Instruction &op); + bool verifyOperation(Operation &op); bool verifyDominance(Block &block); - bool verifyInstDominance(Instruction &inst); + bool verifyOpDominance(Operation &op); explicit FuncVerifier(Function &fn) : fn(fn), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} @@ -231,15 +231,15 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { return failure("block with no terminator", block); } - // Verify the non-terminator instructions separately so that we can verify + // Verify the non-terminator operations separately so that we can verify // they has no successors. - for (auto &inst : llvm::make_range(block.begin(), std::prev(block.end()))) { - if (inst.getNumSuccessors() != 0) + for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) { + if (op.getNumSuccessors() != 0) return failure( - "instruction with block successors must terminate its parent block", - inst); + "operation with block successors must terminate its parent block", + op); - if (verifyOperation(inst)) + if (verifyOperation(op)) return true; } @@ -259,7 +259,7 @@ bool FuncVerifier::verifyBlock(Block &block, bool isTopLevel) { } /// Check the invariants of the specified operation. -bool FuncVerifier::verifyOperation(Instruction &op) { +bool FuncVerifier::verifyOperation(Operation &op) { if (op.getFunction() != &fn) return failure("operation in the wrong function", op); @@ -304,30 +304,30 @@ bool FuncVerifier::verifyOperation(Instruction &op) { } bool FuncVerifier::verifyDominance(Block &block) { - // Verify the dominance of each of the held instructions. 
- for (auto &inst : block) - if (verifyInstDominance(inst)) + // Verify the dominance of each of the held operations. + for (auto &op : block) + if (verifyOpDominance(op)) return true; return false; } -bool FuncVerifier::verifyInstDominance(Instruction &inst) { +bool FuncVerifier::verifyOpDominance(Operation &op) { // Check that operands properly dominate this use. - for (unsigned operandNo = 0, e = inst.getNumOperands(); operandNo != e; + for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e; ++operandNo) { - auto *op = inst.getOperand(operandNo); - if (domInfo->properlyDominates(op, &inst)) + auto *operand = op.getOperand(operandNo); + if (domInfo->properlyDominates(operand, &op)) continue; - inst.emitError("operand #" + Twine(operandNo) + - " does not dominate this use"); - if (auto *useInst = op->getDefiningOp()) - useInst->emitNote("operand defined here"); + op.emitError("operand #" + Twine(operandNo) + + " does not dominate this use"); + if (auto *useOp = operand->getDefiningOp()) + useOp->emitNote("operand defined here"); return true; } - // Verify the dominance of each of the nested blocks within this instruction. - for (auto ®ion : inst.getRegions()) + // Verify the dominance of each of the nested blocks within this operation. + for (auto ®ion : op.getRegions()) for (auto &block : region) if (verifyDominance(block)) return true; diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index cd5816009a5..455f2a0b5fe 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -83,7 +83,7 @@ Operation *Block::findAncestorInstInBlock(Operation &op) { // find the ancestor operation that resides in the block of 'forInst'. auto *currInst = &op; while (currInst->getBlock() != this) { - currInst = currInst->getParentInst(); + currInst = currInst->getParentOp(); if (!currInst) return nullptr; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 3de620b524c..c54b5a24a3d 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -273,7 +273,7 @@ Dialect *Operation::getDialect() { return getContext()->getRegisteredDialect(dialectPrefix); } -Operation *Operation::getParentInst() { +Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } @@ -437,8 +437,8 @@ void Operation::moveBefore(Operation *existingInst) { moveBefore(existingInst->getBlock(), existingInst->getIterator()); } -/// Unlink this operation operation from its current basic block and insert -/// it right before `iterator` in the specified basic block. +/// Unlink this operation from its current basic block and insert it right +/// before `iterator` in the specified basic block. void Operation::moveBefore(Block *block, llvm::iplist::iterator iterator) { block->getOperations().splice(iterator, getBlock()->getOperations(), diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8c29d1a76b4..7a6f188e6af 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -746,7 +746,7 @@ struct LoopNestStatsCollector { void collect(Instruction *inst) { inst->walk([&](AffineForOp forOp) { auto *forInst = forOp.getOperation(); - auto *parentInst = forOp.getOperation()->getParentInst(); + auto *parentInst = forOp.getOperation()->getParentOp(); if (parentInst != nullptr) { assert(parentInst->isa() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. 
@@ -1545,7 +1545,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // A single store disappears: -1 for that. computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1; for (auto *loadOp : dstLoadOpInsts) { - auto *parentInst = loadOp->getParentInst(); + auto *parentInst = loadOp->getParentOp(); if (parentInst && parentInst->isa()) computeCostMap[parentInst] = -1; } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index c235190b4b7..f99b602cf0b 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -182,8 +182,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i].getOperation()->getParentInst() == - band[i - 1].getOperation()); + assert(band[i].getOperation()->getParentOp() == band[i - 1].getOperation()); } auto origLoops = band; diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 8ea9d4e8020..c15108530fb 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -687,7 +687,7 @@ static bool materialize(Function *f, // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. - auto *enclosingScope = term->getParentInst(); + auto *enclosingScope = term->getParentOp(); auto keepIfInSameScope = [enclosingScope, &domInfo](Instruction *inst) { assert(inst && "NULL inst"); if (!enclosingScope) { @@ -760,7 +760,7 @@ void MaterializeVectorsPass::runOnFunction() { pat.match(f, &matches); SetVector terminators; for (auto m : matches) { - terminators.insert(m.getMatchedInstruction()); + terminators.insert(m.getMatchedOperation()); } if (materialize(f, terminators, &state)) diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b5109a20ba9..c06e9359324 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -122,7 +122,7 @@ void VectorizerTestPass::testVectorShapeRatio() { SmallVector matches; pat.match(f, &matches); for (auto m : matches) { - auto *opInst = m.getMatchedInstruction(); + auto *opInst = m.getMatchedOperation(); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. 
If we need to test more intricate behavior in the @@ -164,9 +164,9 @@ void VectorizerTestPass::testBackwardSlicing() { patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector backwardSlice; - getBackwardSlice(m.getMatchedInstruction(), &backwardSlice); + getBackwardSlice(m.getMatchedOperation(), &backwardSlice); auto strs = map(toString, backwardSlice); - outs() << "\nmatched: " << *m.getMatchedInstruction() + outs() << "\nmatched: " << *m.getMatchedOperation() << " backward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; @@ -180,9 +180,9 @@ void VectorizerTestPass::testForwardSlicing() { patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector forwardSlice; - getForwardSlice(m.getMatchedInstruction(), &forwardSlice); + getForwardSlice(m.getMatchedOperation(), &forwardSlice); auto strs = map(toString, forwardSlice); - outs() << "\nmatched: " << *m.getMatchedInstruction() + outs() << "\nmatched: " << *m.getMatchedOperation() << " forward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; @@ -196,9 +196,9 @@ void VectorizerTestPass::testSlicing() { SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector staticSlice = getSlice(m.getMatchedInstruction()); + SetVector staticSlice = getSlice(m.getMatchedOperation()); auto strs = map(toString, staticSlice); - outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: "; + outs() << "\nmatched: " << *m.getMatchedOperation() << " static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -220,7 +220,7 @@ void VectorizerTestPass::testComposeMaps() { SmallVector maps; maps.reserve(matches.size()); for (auto m : llvm::reverse(matches)) { - auto *opInst = m.getMatchedInstruction(); + auto *opInst = m.getMatchedOperation(); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast() .getValue(); @@ -257,15 +257,15 @@ void VectorizerTestPass::testNormalizeMaps() { SmallVector matches; pattern.match(f, &matches); for (auto m : matches) { - auto app = m.getMatchedInstruction()->cast(); - FuncBuilder b(m.getMatchedInstruction()); + auto app = m.getMatchedOperation()->cast(); + FuncBuilder b(m.getMatchedOperation()); SmallVector operands(app.getOperands()); makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands); } } // We should now be able to erase everything in reverse order in this test. 
for (auto m : llvm::reverse(toErase)) { - m.getMatchedInstruction()->erase(); + m.getMatchedOperation()->erase(); } } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 3c6ab6c2cac..8a7a7a6dbba 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -705,7 +705,7 @@ static LogicalResult analyzeProfitability(ArrayRef matches, patternDepth, strategy))) { return failure(); } - vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern, + vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern, patternDepth, strategy); } return success(); @@ -869,7 +869,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, SmallVector loadAndStoresMatches; loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { - auto *opInst = ls.getMatchedInstruction(); + auto *opInst = ls.getMatchedOperation(); auto load = opInst->dyn_cast(); auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); @@ -904,7 +904,7 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { static LogicalResult vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, VectorizationState *state) { - auto *loopInst = oneMatch.getMatchedInstruction(); + auto *loopInst = oneMatch.getMatchedOperation(); auto loop = loopInst->cast(); auto childrenMatches = oneMatch.getMatchedChildren(); @@ -1144,7 +1144,7 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) { /// anything below it fails. static LogicalResult vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto loop = m.getMatchedInstruction()->cast(); + auto loop = m.getMatchedOperation()->cast(); VectorizationState state; state.strategy = strategy; @@ -1248,7 +1248,7 @@ void Vectorize::runOnFunction() { &strategy))) { continue; } - vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth, + vectorizeLoopIfProfitable(m.getMatchedOperation(), 0, patternDepth, &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. -- cgit v1.2.3 From 99b87c9707b389183de33961f81d4b2730b033c8 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 27 Mar 2019 14:02:02 -0700 Subject: Replace usages of Instruction with Operation in the Transforms/ directory. 
PiperOrigin-RevId: 240636130 --- mlir/include/mlir/Transforms/DialectConversion.h | 7 +- mlir/include/mlir/Transforms/LoopUtils.h | 8 +- .../mlir/Transforms/MLPatternLoweringPass.h | 16 +- mlir/include/mlir/Transforms/Passes.h | 2 +- mlir/include/mlir/Transforms/Utils.h | 20 +- mlir/lib/Transforms/CSE.cpp | 24 +- mlir/lib/Transforms/ConstantFold.cpp | 12 +- mlir/lib/Transforms/DialectConversion.cpp | 32 +-- mlir/lib/Transforms/DmaGeneration.cpp | 44 ++-- mlir/lib/Transforms/LoopFusion.cpp | 269 ++++++++++----------- mlir/lib/Transforms/LoopTiling.cpp | 6 +- mlir/lib/Transforms/LoopUnroll.cpp | 6 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 16 +- mlir/lib/Transforms/LowerAffine.cpp | 56 ++--- mlir/lib/Transforms/LowerVectorTransfers.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 102 ++++---- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 22 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 76 +++--- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 10 +- mlir/lib/Transforms/StripDebugInfo.cpp | 6 +- .../Utils/GreedyPatternRewriteDriver.cpp | 28 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 104 ++++---- mlir/lib/Transforms/Utils/Utils.cpp | 27 +-- .../Vectorization/VectorizerTestPass.cpp | 38 ++- mlir/lib/Transforms/Vectorize.cpp | 126 +++++----- 25 files changed, 522 insertions(+), 539 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index fe3a46d6050..27af342079b 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -33,7 +33,6 @@ class Block; class FuncBuilder; class MLIRContext; class Operation; -using Instruction = Operation; class Type; class Value; @@ -43,7 +42,7 @@ class FunctionConversion; } /// Base class for the dialect op conversion patterns. Specific conversions -/// must derive this class and implement `PatternMatch match(Instruction *)` +/// must derive this class and implement `PatternMatch match(Operation *)` /// defined in `Pattern` and at least one of `rewrite` and `rewriteTerminator`. // // TODO(zinenko): this should eventually converge with RewritePattern. So far, @@ -67,7 +66,7 @@ public: /// DialectOpConversion ever needs to replace an operation that does not have /// successors. This function should not fail. If some specific cases of the /// operation are not supported, these cases should not be matched. - virtual SmallVector rewrite(Instruction *op, + virtual SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const { llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); @@ -85,7 +84,7 @@ public: /// successors. This function should not fail the pass. If some specific /// cases of the operation are not supported, these cases should not be /// matched. - virtual void rewriteTerminator(Instruction *op, + virtual void rewriteTerminator(Operation *op, ArrayRef properOperands, ArrayRef destinations, ArrayRef> operands, diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 1d5203e77d5..f1e7b503769 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -34,10 +34,10 @@ class Function; class FuncBuilder; class Value; -/// Unrolls this for instruction completely if the trip count is known to be +/// Unrolls this for operation completely if the trip count is known to be /// constant. Returns failure otherwise. 
LogicalResult loopUnrollFull(AffineForOp forOp); -/// Unrolls this for instruction by the specified unroll factor. Returns failure +/// Unrolls this for operation by the specified unroll factor. Returns failure /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor); @@ -73,8 +73,8 @@ void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, SmallVectorImpl *operands, FuncBuilder *builder); -/// Skew the instructions in the body of a 'affine.for' instruction with the -/// specified instruction-wise shifts. The shifts are with respect to the +/// Skew the operations in the body of a 'affine.for' operation with the +/// specified operation-wise shifts. The shifts are with respect to the /// original execution order, and are multiplied by the loop 'step' before being /// applied. LLVM_NODISCARD diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index c9ed3a38a65..c43b551c49a 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -37,7 +37,7 @@ public: FuncBuilder *getBuilder() { return builder; } - Instruction *createOperation(const OperationState &state) override { + Operation *createOperation(const OperationState &state) override { auto *result = builder->createOperation(state); return result; } @@ -66,7 +66,7 @@ public: /// must override). It will be passed the function-wise state, common to all /// matches, and the state returned by the `match` call, if any. The subclass /// must use `rewriter` to modify the function. - virtual void rewriteOpInst(Instruction *op, + virtual void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const = 0; @@ -123,14 +123,14 @@ void applyMLPatternsGreedily( FuncBuilder builder(f); MLFuncLoweringRewriter rewriter(&builder); - llvm::SmallVector ops; - f->walk([&ops](Instruction *inst) { ops.push_back(inst); }); + llvm::SmallVector ops; + f->walk([&ops](Operation *op) { ops.push_back(op); }); - for (Instruction *inst : ops) { + for (Operation *op : ops) { for (const auto &pattern : patterns) { - builder.setInsertionPoint(inst); - if (auto matchResult = pattern->match(inst)) { - pattern->rewriteOpInst(inst, funcWiseState, std::move(*matchResult), + builder.setInsertionPoint(op); + if (auto matchResult = pattern->match(op)) { + pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult), &rewriter); break; } diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 3a75a2619f4..634f690c451 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -82,7 +82,7 @@ FunctionPassBase *createLoopFusionPass(unsigned fastMemorySpace = 0, /// memory hierarchy. FunctionPassBase *createPipelineDataTransferPass(); -/// Lowers affine control flow instructions (ForStmt, IfStmt and AffineApplyOp) +/// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). 
FunctionPassBase *createLowerAffinePass(); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 97bf16c2c54..db25ed1f26d 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -73,8 +73,8 @@ bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - Instruction *domInstFilter = nullptr, - Instruction *postDomInstFilter = nullptr); + Operation *domInstFilter = nullptr, + Operation *postDomInstFilter = nullptr); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition @@ -83,13 +83,13 @@ bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, /// these will also be collected into a single (multi-result) affine apply op. /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. -Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, - ArrayRef operands, - ArrayRef affineApplyOps, - SmallVectorImpl *results); +Operation *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, + ArrayRef operands, + ArrayRef affineApplyOps, + SmallVectorImpl *results); -/// Given an instruction, inserts one or more single result affine apply -/// operations, results of which are exclusively used by this instruction. +/// Given an operation, inserts one or more single result affine apply +/// operations, results of which are exclusively used by this operation. /// The operands of these newly created affine apply ops are /// guaranteed to be loop iterators or terminal symbols of a function. /// @@ -117,13 +117,13 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// (i.e., there was no affine computation slice to create). /// 2. If all the affine.apply op's supplying operands to this opInst did not /// have any uses other than those in this opInst. -void createAffineComputationSlice(Instruction *opInst, +void createAffineComputationSlice(Operation *opInst, SmallVectorImpl *sliceOps); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". void remapFunctionAttrs( - Instruction &op, const DenseMap &remappingTable); + Operation &op, const DenseMap &remappingTable); /// Replaces (potentially nested) function attributes all operations of the /// Function "fn" with those specified in "remappingTable". diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index c3916a07c18..f90e12db772 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -38,11 +38,11 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. 
-struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Instruction *opC) { - auto *op = const_cast(opC); +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const Operation *opC) { + auto *op = const_cast(opC); // Hash the operations based upon their: - // - Instruction Name + // - Operation Name // - Attributes // - Result Types // - Operands @@ -51,9 +51,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const Instruction *lhsC, const Instruction *rhsC) { - auto *lhs = const_cast(lhsC); - auto *rhs = const_cast(rhsC); + static bool isEqual(const Operation *lhsC, const Operation *rhsC) { + auto *lhs = const_cast(lhsC); + auto *rhs = const_cast(rhsC); if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -90,8 +90,8 @@ struct CSE : public FunctionPass { /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; - using ScopedMapTy = llvm::ScopedHashTable>; + using ScopedMapTy = llvm::ScopedHashTable; /// Represents a single entry in the depth first traversal of a CFG. @@ -112,7 +112,7 @@ struct CSE : public FunctionPass { /// Attempt to eliminate a redundant operation. Returns true if the operation /// was marked for removal, false otherwise. - bool simplifyOperation(Instruction *op); + bool simplifyOperation(Operation *op); void simplifyBlock(DominanceInfo &domInfo, Block *bb); void simplifyRegion(DominanceInfo &domInfo, Region ®ion); @@ -124,12 +124,12 @@ private: ScopedMapTy knownValues; /// Operations marked as dead and to be erased. - std::vector opsToErase; + std::vector opsToErase; }; } // end anonymous namespace /// Attempt to eliminate a redundant operation. -bool CSE::simplifyOperation(Instruction *op) { +bool CSE::simplifyOperation(Operation *op) { // Don't simplify operations with nested blocks. We don't currently model // equality comparisons correctly among other things. It is also unclear // whether we would want to CSE such operations. diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 4c4c8cc4019..364c3dcd6ad 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -31,9 +31,9 @@ struct ConstantFold : public FunctionPass { // All constants in the function post folding. SmallVector existingConstants; // Operations that were folded and that need to be erased. - std::vector opInstsToErase; + std::vector opInstsToErase; - void foldInstruction(Instruction *op); + void foldOperation(Operation *op); void runOnFunction() override; }; } // end anonymous namespace @@ -41,7 +41,7 @@ struct ConstantFold : public FunctionPass { /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::foldInstruction(Instruction *op) { +void ConstantFold::foldOperation(Operation *op) { // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. 
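The CSE hunk above keys operations by their name, attributes, result types, and operands. As a rough standalone illustration of that structural-hashing idea (plain C++ over a toy op struct, not MLIR's Operation or DenseMapInfo API; attributes omitted for brevity):

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy stand-in for an operation: name, operand ids, and result type names.
struct ToyOp {
  std::string name;
  std::vector<int> operandIds;
  std::vector<std::string> resultTypes;
};

// Boost-style hash_combine.
static std::size_t combineHash(std::size_t seed, std::size_t v) {
  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hash an op by name, operands, and result types, mirroring the fields the
// pass combines; ops that hash and compare equal are CSE candidates.
static std::size_t hashOp(const ToyOp &op) {
  std::size_t h = std::hash<std::string>{}(op.name);
  for (int id : op.operandIds)
    h = combineHash(h, std::hash<int>{}(id));
  for (const std::string &t : op.resultTypes)
    h = combineHash(h, std::hash<std::string>{}(t));
  return h;
}

static bool equalOps(const ToyOp &a, const ToyOp &b) {
  return a.name == b.name && a.operandIds == b.operandIds &&
         a.resultTypes == b.resultTypes;
}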
if (auto constant = op->dyn_cast()) { @@ -97,15 +97,15 @@ void ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction().walk([&](Operation *op) { foldOperation(op); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no // side effects. When we have side effect modeling, we should verify that // the operation is effect-free before we remove it. Until then this is // close enough. - for (auto *inst : opInstsToErase) { - inst->erase(); + for (auto *op : opInstsToErase) { + op->erase(); } // By the time we are done, we may have simplified a bunch of code, leaving diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a659b2e480b..2d16f23d41f 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -50,7 +50,7 @@ private: // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. SmallVector lookupValues( - const llvm::iterator_range &operands); + const llvm::iterator_range &operands); // Converts the given function to the dialect using hooks defined in // `dialectConversion`. Returns the converted function or `nullptr` on error. @@ -61,16 +61,16 @@ private: // passes them to `converter->rewriteTerminator` function defined in the // pattern, together with `builder`. LogicalResult convertOpWithSuccessors(DialectOpConversion *converter, - Instruction *op, FuncBuilder &builder); + Operation *op, FuncBuilder &builder); // Converts an operation without successors. Extracts the converted operands // from `valueRemapping` and passes them to the `converter->rewrite` function // defined in the pattern, together with `builder`. - LogicalResult convertOp(DialectOpConversion *converter, Instruction *op, + LogicalResult convertOp(DialectOpConversion *converter, Operation *op, FuncBuilder &builder); - // Converts a block by traversing its instructions sequentially, looking for - // the first pattern match and dispatching the instruction conversion to + // Converts a block by traversing its operations sequentially, looking for + // the first pattern match and dispatching the operation conversion to // either `convertOp` or `convertOpWithSuccessors` depending on the presence // of successors. If there is no match, clones the operation. 
// @@ -101,7 +101,7 @@ private: } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues( - const llvm::iterator_range &operands) { + const llvm::iterator_range &operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); for (Value *operand : operands) { @@ -114,7 +114,7 @@ SmallVector impl::FunctionConversion::lookupValues( } LogicalResult impl::FunctionConversion::convertOpWithSuccessors( - DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) { + DialectOpConversion *converter, Operation *op, FuncBuilder &builder) { SmallVector destinations; destinations.reserve(op->getNumSuccessors()); SmallVector operands = lookupValues(op->getOperands()); @@ -146,7 +146,7 @@ LogicalResult impl::FunctionConversion::convertOpWithSuccessors( LogicalResult impl::FunctionConversion::convertOp(DialectOpConversion *converter, - Instruction *op, FuncBuilder &builder) { + Operation *op, FuncBuilder &builder) { auto operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && "converting op before ops defining its operands"); @@ -170,22 +170,22 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder, builder.setInsertionPointToStart(mapping.lookupOrNull(block)); // Iterate over ops and convert them. - for (Instruction &inst : *block) { - if (inst.getNumRegions() != 0) { - inst.emitError("unsupported region instruction"); + for (Operation &op : *block) { + if (op.getNumRegions() != 0) { + op.emitError("unsupported region operation"); return failure(); } // Find the first matching conversion and apply it. bool converted = false; for (auto *conversion : conversions) { - if (!conversion->match(&inst)) + if (!conversion->match(&op)) continue; - if (inst.getNumSuccessors() != 0) { - if (failed(convertOpWithSuccessors(conversion, &inst, builder))) + if (op.getNumSuccessors() != 0) { + if (failed(convertOpWithSuccessors(conversion, &op, builder))) return failure(); - } else if (failed(convertOp(conversion, &inst, builder))) { + } else if (failed(convertOp(conversion, &op, builder))) { return failure(); } converted = true; @@ -193,7 +193,7 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder, } // If there is no conversion provided for the op, clone the op as is. if (!converted) - builder.clone(inst, mapping); + builder.clone(op, mapping); } // Recurse to children unless they have been already visited. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index e04ae3d45bb..83ba858447b 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -170,7 +170,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of /// enclosing loop IVs of opInst (starting from the outermost) that the region /// is parametric on. 
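The convertBlock hunk above dispatches each operation to the first conversion whose match succeeds, and clones the operation unchanged when nothing matches. A minimal sketch of that dispatch loop, using toy callback-based types rather than the pass's real classes (error handling omitted):

#include <functional>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
};

// A conversion pattern: a predicate plus a rewrite callback.
struct Conversion {
  std::function<bool(const ToyOp &)> match;
  std::function<void(const ToyOp &, std::vector<ToyOp> &)> rewrite;
};

// Dispatch each op to the first matching conversion; clone it when none match.
void convertBlock(const std::vector<ToyOp> &block,
                  const std::vector<Conversion> &conversions,
                  std::vector<ToyOp> &converted) {
  for (const ToyOp &op : block) {
    bool didConvert = false;
    for (const Conversion &c : conversions) {
      if (!c.match(op))
        continue;
      c.rewrite(op, converted); // first successful match wins
      didConvert = true;
      break;
    }
    if (!didConvert)
      converted.push_back(op); // no pattern applies: clone the op as is
  }
}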
-static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, +static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; if (auto loadOp = opInst->dyn_cast()) { @@ -212,11 +212,11 @@ static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, } static void emitNoteForBlock(Block &block, const Twine &message) { - auto *inst = block.getContainingOp(); - if (!inst) { + auto *op = block.getContainingOp(); + if (!op) { block.getFunction()->emitNote(message); } else { - inst->emitNote(message); + op->emitNote(message); } } @@ -350,7 +350,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace); // Create the fast memory space buffer just before the 'affine.for' - // instruction. + // operation. fastMemRef = prologue.create(loc, fastMemRefType).getResult(); // Record it. fastBufferMap[memref] = fastMemRef; @@ -391,7 +391,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, top.create(loc, strideInfos[0].numEltPerStride); } - // Record the last instruction just before the point where we insert the + // Record the last operation just before the point where we insert the // outgoing DMAs. We later do the memref replacement later only in [begin, // postDomFilter] so that the original memref's in the DMA ops themselves // don't get replaced. @@ -464,7 +464,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, } /// Generate DMAs for this block. The block is partitioned into separate -/// `regions`; each region is either a sequence of one or more instructions +/// `regions`; each region is either a sequence of one or more operations /// starting and ending with a load or store op, or just a loop (which could /// have other loops nested within). Returns false on an error, true otherwise. bool DmaGeneration::runOnBlock(Block *block) { @@ -472,20 +472,19 @@ bool DmaGeneration::runOnBlock(Block *block) { return true; // Every loop in the block starts and ends a region. A contiguous sequence of - // operation instructions starting and ending with a load/store op is also + // operations starting and ending with a load/store op is also // identified as a region. Straightline code (contiguous chunks of operation - // instructions) are always assumed to not exhaust memory. As a result, this + // operations) are always assumed to not exhaust memory. As a result, this // approach is conservative in some cases at the moment, we do a check later // and report an error with location info. - // TODO(bondhugula): An 'affine.if' instruction is being treated similar to an - // operation instruction. 'affine.if''s could have 'affine.for's in them; + // TODO(bondhugula): An 'affine.if' operation is being treated similar to an + // operation. 'affine.if''s could have 'affine.for's in them; // treat them separately. // Get to the first load, store, or for op. auto curBegin = - std::find_if(block->begin(), block->end(), [&](Instruction &inst) { - return inst.isa() || inst.isa() || - inst.isa(); + std::find_if(block->begin(), block->end(), [&](Operation &op) { + return op.isa() || op.isa() || op.isa(); }); for (auto it = curBegin; it != block->end(); ++it) { @@ -513,7 +512,7 @@ bool DmaGeneration::runOnBlock(Block *block) { runOnBlock(/*begin=*/curBegin, /*end=*/it); // Recurse onto the body of this loop. 
runOnBlock(forOp.getBody()); - // The next region starts right after the 'affine.for' instruction. + // The next region starts right after the 'affine.for' operation. curBegin = std::next(it); } else { // We have enough capacity, i.e., DMAs will be computed for the portion @@ -583,10 +582,10 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, } } -/// Generates DMAs for a contiguous sequence of instructions in `block` in the +/// Generates DMAs for a contiguous sequence of operations in `block` in the /// iterator range [begin, end). Returns the total size of the DMA buffers used. // Since we generate alloc's and dealloc's for all DMA buffers (before and -// after the range of instructions resp), all of the fast memory capacity is +// after the range of operations resp), all of the fast memory capacity is // assumed to be available. uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (begin == end) @@ -610,8 +609,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // To check for errors when walking the block. bool error = false; - // Walk this range of instructions to gather all memory regions. - block->walk(begin, end, [&](Instruction *opInst) { + // Walk this range of operations to gather all memory regions. + block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = opInst->dyn_cast()) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) @@ -738,8 +737,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { return totalDmaBuffersSizeInBytes; } - // For a range of operation instructions, a note will be emitted at the - // caller. + // For a range of operations, a note will be emitted at the caller. AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { @@ -750,8 +748,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) { StringRef str = "Total size of all DMA buffers' for this block " "exceeds fast memory capacity\n"; - if (auto *inst = block->getContainingOp()) - inst->emitError(str); + if (auto *op = block->getContainingOp()) + op->emitError(str); else block->getFunction()->emitError(str); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7a6f188e6af..80308ea6a40 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -123,26 +123,26 @@ namespace { // operations, and whether or not an IfInst was encountered in the loop nest. 
struct LoopNestStateCollector { SmallVector forOps; - SmallVector loadOpInsts; - SmallVector storeOpInsts; + SmallVector loadOpInsts; + SmallVector storeOpInsts; bool hasNonForRegion = false; - void collect(Instruction *instToWalk) { - instToWalk->walk([&](Instruction *opInst) { - if (opInst->isa()) - forOps.push_back(opInst->cast()); - else if (opInst->getNumRegions() != 0) + void collect(Operation *opToWalk) { + opToWalk->walk([&](Operation *op) { + if (op->isa()) + forOps.push_back(op->cast()); + else if (op->getNumRegions() != 0) hasNonForRegion = true; - else if (opInst->isa()) - loadOpInsts.push_back(opInst); - else if (opInst->isa()) - storeOpInsts.push_back(opInst); + else if (op->isa()) + loadOpInsts.push_back(op); + else if (op->isa()) + storeOpInsts.push_back(op); }); } }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(Instruction &op) { +static bool isMemRefDereferencingOp(Operation &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -150,7 +150,7 @@ static bool isMemRefDereferencingOp(Instruction &op) { } // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level instructions in a Function which contain load/store ops, and edges +// top-level operations in a Function which contain load/store ops, and edges // are memref dependences between the nodes. // TODO(andydavis) Add a more flexible dependece graph representation. // TODO(andydavis) Add a depth parameter to dependence graph construction. @@ -163,12 +163,12 @@ public: // The unique identifier of this node in the graph. unsigned id; // The top-level statment which is (or contains) loads/stores. - Instruction *inst; + Operation *op; // List of load operations. - SmallVector loads; + SmallVector loads; // List of store op insts. - SmallVector stores; - Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} + SmallVector stores; + Node(unsigned id, Operation *op) : id(id), op(op) {} // Returns the load op count for 'memref'. unsigned getLoadOpCount(Value *memref) { @@ -192,7 +192,7 @@ public: // Returns all store ops in 'storeOps' which access 'memref'. void getStoreOpsForMemref(Value *memref, - SmallVectorImpl *storeOps) { + SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { if (memref == storeOpInst->cast().getMemRef()) storeOps->push_back(storeOpInst); @@ -201,7 +201,7 @@ public: // Returns all load ops in 'loadOps' which access 'memref'. void getLoadOpsForMemref(Value *memref, - SmallVectorImpl *loadOps) { + SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { if (memref == loadOpInst->cast().getMemRef()) loadOps->push_back(loadOpInst); @@ -236,7 +236,7 @@ public: // which contain accesses to the same memref 'value'. If the value is a // non-memref value, then the dependence is between a graph node which // defines an SSA value and another graph node which uses the SSA value - // (e.g. a constant instruction defining a value which is used inside a loop + // (e.g. a constant operation defining a value which is used inside a loop // nest). Value *value; }; @@ -266,9 +266,9 @@ public: return &it->second; } - // Adds a node with 'inst' to the graph and returns its unique identifier. - unsigned addNode(Instruction *inst) { - Node node(nextNodeId++, inst); + // Adds a node with 'op' to the graph and returns its unique identifier. 
+ unsigned addNode(Operation *op) { + Node node(nextNodeId++, op); nodes.insert({node.id, node}); return node.id; } @@ -301,9 +301,9 @@ public: Node *node = getNode(id); for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast().getMemRef(); - auto *inst = memref->getDefiningOp(); + auto *op = memref->getDefiningOp(); // Return true if 'memref' is a block argument. - if (!inst) + if (!op) return true; // Return true if any use of 'memref' escapes the function. for (auto &use : memref->getUses()) @@ -436,50 +436,50 @@ public: return outEdgeCount; } - // Computes and returns an insertion point instruction, before which the + // Computes and returns an insertion point operation, before which the // the fused loop nest can be inserted while preserving // dependences. Returns nullptr if no such insertion point is found. - Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { + Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { if (outEdges.count(srcId) == 0) - return getNode(dstId)->inst; + return getNode(dstId)->op; // Build set of insts in range (srcId, dstId) which depend on 'srcId'. - SmallPtrSet srcDepInsts; + SmallPtrSet srcDepInsts; for (auto &outEdge : outEdges[srcId]) if (outEdge.id != dstId) - srcDepInsts.insert(getNode(outEdge.id)->inst); + srcDepInsts.insert(getNode(outEdge.id)->op); // Build set of insts in range (srcId, dstId) on which 'dstId' depends. - SmallPtrSet dstDepInsts; + SmallPtrSet dstDepInsts; for (auto &inEdge : inEdges[dstId]) if (inEdge.id != srcId) - dstDepInsts.insert(getNode(inEdge.id)->inst); + dstDepInsts.insert(getNode(inEdge.id)->op); - Instruction *srcNodeInst = getNode(srcId)->inst; - Instruction *dstNodeInst = getNode(dstId)->inst; + Operation *srcNodeInst = getNode(srcId)->op; + Operation *dstNodeInst = getNode(dstId)->op; // Computing insertion point: - // *) Walk all instruction positions in Block instruction list in the - // range (src, dst). For each instruction 'inst' visited in this search: - // *) Store in 'firstSrcDepPos' the first position where 'inst' has a + // *) Walk all operation positions in Block operation list in the + // range (src, dst). For each operation 'op' visited in this search: + // *) Store in 'firstSrcDepPos' the first position where 'op' has a // dependence edge from 'srcNode'. - // *) Store in 'lastDstDepPost' the last position where 'inst' has a + // *) Store in 'lastDstDepPost' the last position where 'op' has a // dependence edge to 'dstNode'. // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the - // instruction insertion point (or return null pointer if no such + // operation insertion point (or return null pointer if no such // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos'). - SmallVector depInsts; + SmallVector depInsts; Optional firstSrcDepPos; Optional lastDstDepPos; unsigned pos = 0; for (Block::iterator it = std::next(Block::iterator(srcNodeInst)); it != Block::iterator(dstNodeInst); ++it) { - Instruction *inst = &(*it); - if (srcDepInsts.count(inst) > 0 && firstSrcDepPos == None) + Operation *op = &(*it); + if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None) firstSrcDepPos = pos; - if (dstDepInsts.count(inst) > 0) + if (dstDepInsts.count(op) > 0) lastDstDepPos = pos; - depInsts.push_back(inst); + depInsts.push_back(op); ++pos; } @@ -557,8 +557,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. 
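The getFusedLoopNestInsertionPoint hunk above scans the operations strictly between the source and destination nodes, recording the first position that depends on the source ('firstSrcDepPos') and the last position the destination depends on ('lastDstDepPos'), and gives up when firstSrcDepPos <= lastDstDepPos. A small standalone sketch of that bookkeeping over plain integers (illustrative names, not the pass's API):

#include <cstddef>
#include <optional>
#include <set>
#include <vector>

// Returns the position (index into 'opsBetween') before which the fused nest
// can be placed, or std::nullopt if no legal insertion point exists.
std::optional<std::size_t>
findInsertionPos(const std::vector<int> &opsBetween, // ops strictly between src and dst
                 const std::set<int> &srcDepOps,     // ops that depend on src
                 const std::set<int> &dstDepOps) {   // ops that dst depends on
  std::optional<std::size_t> firstSrcDepPos, lastDstDepPos;
  for (std::size_t pos = 0; pos < opsBetween.size(); ++pos) {
    if (srcDepOps.count(opsBetween[pos]) && !firstSrcDepPos)
      firstSrcDepPos = pos;
    if (dstDepOps.count(opsBetween[pos]))
      lastDstDepPos = pos;
  }
  // Per the comment in the hunk: no legal point when a src-dependent op sits
  // at or before the last op the destination depends on.
  if (firstSrcDepPos && lastDstDepPos && *firstSrcDepPos <= *lastDstDepPos)
    return std::nullopt;
  // Otherwise insert just before the first src-dependent op, or right before
  // the destination if nothing in between depends on src.
  return firstSrcDepPos ? *firstSrcDepPos : opsBetween.size();
}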
- void addToNode(unsigned id, const SmallVectorImpl &loads, - const SmallVectorImpl &stores) { + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { Node *node = getNode(id); for (auto *loadOpInst : loads) node->loads.push_back(loadOpInst); @@ -596,7 +596,7 @@ public: continue; assert(nodes.count(edge.id) > 0); // Skip if 'edge.id' is not a loop nest. - if (!getNode(edge.id)->inst->isa()) + if (!getNode(edge.id)->op->isa()) continue; // Visit current input edge 'edge'. callback(edge); @@ -623,7 +623,7 @@ public: void dump() const { print(llvm::errs()); } }; -// Intializes the data dependence graph by walking instructions in 'f'. +// Intializes the data dependence graph by walking operations in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. @@ -634,18 +634,18 @@ bool MemRefDependenceGraph::init(Function &f) { if (f.getBlocks().size() != 1) return false; - DenseMap forToNodeMap; - for (auto &inst : f.front()) { - if (auto forOp = inst.dyn_cast()) { + DenseMap forToNodeMap; + for (auto &op : f.front()) { + if (auto forOp = op.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.collect(&inst); + collector.collect(&op); // Return false if a non 'affine.for' region was found (not currently // supported). if (collector.hasNonForRegion) return false; - Node node(nextNodeId++, &inst); + Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); auto *memref = opInst->cast().getMemRef(); @@ -656,29 +656,29 @@ bool MemRefDependenceGraph::init(Function &f) { auto *memref = opInst->cast().getMemRef(); memrefAccesses[memref].insert(node.id); } - forToNodeMap[&inst] = node.id; + forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = inst.dyn_cast()) { + } else if (auto loadOp = op.dyn_cast()) { // Create graph node for top-level load op. - Node node(nextNodeId++, &inst); - node.loads.push_back(&inst); - auto *memref = inst.cast().getMemRef(); + Node node(nextNodeId++, &op); + node.loads.push_back(&op); + auto *memref = op.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = inst.dyn_cast()) { + } else if (auto storeOp = op.dyn_cast()) { // Create graph node for top-level store op. - Node node(nextNodeId++, &inst); - node.stores.push_back(&inst); - auto *memref = inst.cast().getMemRef(); + Node node(nextNodeId++, &op); + node.stores.push_back(&op); + auto *memref = op.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (inst.getNumRegions() != 0) { + } else if (op.getNumRegions() != 0) { // Return false if another region is found (not currently supported). return false; - } else if (inst.getNumResults() > 0 && !inst.use_empty()) { + } else if (op.getNumResults() > 0 && !op.use_empty()) { // Create graph node for top-level producer of SSA values, which // could be used by loop nest nodes. 
- Node node(nextNodeId++, &inst); + Node node(nextNodeId++, &op); nodes.insert({node.id, node}); } } @@ -689,7 +689,7 @@ bool MemRefDependenceGraph::init(Function &f) { const Node &node = idAndNode.second; if (!node.loads.empty() || !node.stores.empty()) continue; - auto *opInst = node.inst; + auto *opInst = node.op; for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { SmallVector loops; @@ -728,11 +728,11 @@ namespace { // and operation count) for a loop nest up until the innermost loop body. struct LoopNestStats { // Map from AffineForOp to immediate child AffineForOps in its loop body. - DenseMap> loopMap; + DenseMap> loopMap; // Map from AffineForOp to count of operations in its loop body. - DenseMap opCountMap; + DenseMap opCountMap; // Map from AffineForOp to its constant trip count. - DenseMap tripCountMap; + DenseMap tripCountMap; }; // LoopNestStatsCollector walks a single loop nest and gathers per-loop @@ -743,8 +743,8 @@ struct LoopNestStatsCollector { LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void collect(Instruction *inst) { - inst->walk([&](AffineForOp forOp) { + void collect(Operation *op) { + op->walk([&](AffineForOp forOp) { auto *forInst = forOp.getOperation(); auto *parentInst = forOp.getOperation()->getParentOp(); if (parentInst != nullptr) { @@ -753,11 +753,11 @@ struct LoopNestStatsCollector { stats->loopMap[parentInst].push_back(forOp); } - // Record the number of op instructions in the body of 'forOp'. + // Record the number of op operations in the body of 'forOp'. unsigned count = 0; stats->opCountMap[forInst] = 0; - for (auto &inst : *forOp.getBody()) { - if (!inst.isa() && !inst.isa()) + for (auto &op : *forOp.getBody()) { + if (!op.isa() && !op.isa()) ++count; } stats->opCountMap[forInst] = count; @@ -789,9 +789,9 @@ struct LoopNestStatsCollector { // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. static int64_t getComputeCost( - Instruction *forInst, LoopNestStats *stats, - llvm::SmallDenseMap *tripCountOverrideMap, - DenseMap *computeCostMap) { + Operation *forInst, LoopNestStats *stats, + llvm::SmallDenseMap *tripCountOverrideMap, + DenseMap *computeCostMap) { // 'opCount' is the total number operations in one iteration of 'forOp' body int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { @@ -843,8 +843,8 @@ static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { // was encountered). // TODO(andydavis) Make this work with non-unit step loops. static bool buildSliceTripCountMap( - Instruction *srcOpInst, ComputationSliceState *sliceState, - llvm::SmallDenseMap *tripCountMap) { + Operation *srcOpInst, ComputationSliceState *sliceState, + llvm::SmallDenseMap *tripCountMap) { SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); @@ -873,12 +873,11 @@ static bool buildSliceTripCountMap( // Removes load operations from 'srcLoads' which operate on 'memref', and // adds them to 'dstLoads'. 
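The getComputeCost hunk above sums the operation count of one loop-body iteration (including the cost of any nested loops) and scales it by the loop's trip count. A toy recursive version of that computation, assuming constant trip counts and ignoring the trip-count-override and per-loop cost maps the real function takes:

#include <cstdint>
#include <vector>

// Toy loop-nest description (illustrative only).
struct Loop {
  int64_t tripCount;          // constant trip count of this loop
  int64_t bodyOpCount;        // non-loop ops directly in the body
  std::vector<Loop> children; // immediately nested loops
};

int64_t computeCost(const Loop &loop) {
  int64_t opCount = loop.bodyOpCount;
  for (const Loop &child : loop.children)
    opCount += computeCost(child); // cost of one iteration of 'loop'
  return opCount * loop.tripCount; // scale by this loop's trip count
}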
-static void -moveLoadsAccessingMemrefTo(Value *memref, - SmallVectorImpl *srcLoads, - SmallVectorImpl *dstLoads) { +static void moveLoadsAccessingMemrefTo(Value *memref, + SmallVectorImpl *srcLoads, + SmallVectorImpl *dstLoads) { dstLoads->clear(); - SmallVector srcLoadsToKeep; + SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { if (load->cast().getMemRef() == memref) dstLoads->push_back(load); @@ -889,7 +888,7 @@ moveLoadsAccessingMemrefTo(Value *memref, } // Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { +static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { unsigned numOps = ops.size(); assert(numOps > 0); @@ -917,10 +916,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { // Returns the maximum loop depth at which no dependences between 'loadOpInsts' // and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, - ArrayRef storeOpInsts) { +static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, + ArrayRef storeOpInsts) { // Merge loads and stores into the same array. - SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); + SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); ops.append(storeOpInsts.begin(), storeOpInsts.end()); // Compute the innermost common loop depth for loads and stores. @@ -970,7 +969,7 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, // dependence componenent lexicographically negative. // TODO(andydavis) Move this function to LoopUtils. static bool -computeLoopInterchangePermutation(ArrayRef ops, +computeLoopInterchangePermutation(ArrayRef ops, unsigned maxLoopDepth, SmallVectorImpl *loopPermMap) { // Gather dependence components for dependences between all ops in 'ops' @@ -1054,12 +1053,12 @@ computeLoopInterchangePermutation(ArrayRef ops, // This can increase the loop depth at which we can fuse a slice, since we are // pushing loop carried dependence to a greater depth in the loop nest. static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { - assert(node->inst->isa()); + assert(node->op->isa()); // Get perfectly nested sequence of loops starting at root of loop nest // (the first op being another AffineFor, and the second op - a terminator). // TODO(andydavis,bondhugula) Share this with similar code in loop tiling. SmallVector loops; - AffineForOp curr = node->inst->cast(); + AffineForOp curr = node->op->cast(); loops.push_back(curr); auto *currBody = curr.getBody(); while (currBody->begin() == std::prev(currBody->end(), 2) && @@ -1071,7 +1070,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { return; // Merge loads and stores into the same array. - SmallVector memOps(node->loads.begin(), node->loads.end()); + SmallVector memOps(node->loads.begin(), node->loads.end()); memOps.append(node->stores.begin(), node->stores.end()); // Compute loop permutation in 'loopPermMap'. @@ -1091,7 +1090,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { } } assert(loopNestRootIndex != -1 && "invalid root index"); - node->inst = loops[loopNestRootIndex].getOperation(); + node->op = loops[loopNestRootIndex].getOperation(); } // TODO(mlir-team): improve/complete this when we have target data. @@ -1114,8 +1113,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. 
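getInnermostCommonLoopDepth in the hunk above amounts to the longest common prefix of the ops' enclosing-loop lists. A standalone sketch of that computation over plain loop ids (assumed to be listed outermost first; not the pass's real data structures):

#include <algorithm>
#include <cstddef>
#include <vector>

// Depth of the innermost loop enclosing all ops: the longest common prefix of
// their enclosing-loop-id lists.
std::size_t innermostCommonLoopDepth(const std::vector<std::vector<int>> &loopIVs) {
  if (loopIVs.empty())
    return 0;
  std::size_t depth = loopIVs[0].size();
  for (const auto &ivs : loopIVs)
    depth = std::min(depth, ivs.size());
  for (std::size_t d = 0; d < depth; ++d)
    for (const auto &ivs : loopIVs)
      if (ivs[d] != loopIVs[0][d])
        return d; // first level at which the enclosing loops diverge
  return depth;
}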
-static Value *createPrivateMemRef(AffineForOp forOp, - Instruction *srcStoreOpInst, +static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, Optional fastMemorySpace, uint64_t localBufSizeThreshold) { @@ -1228,7 +1226,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, // Does the slice have a single iteration? static uint64_t getSliceIterationCount( - const llvm::SmallDenseMap &sliceTripCountMap) { + const llvm::SmallDenseMap &sliceTripCountMap) { uint64_t iterCount = 1; for (const auto &count : sliceTripCountMap) { iterCount *= count.second; @@ -1275,7 +1273,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, return false; // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. - SmallVector dstStoreOps; + SmallVector dstStoreOps; dstNode->getStoreOpsForMemref(memref, &dstStoreOps); assert(dstStoreOps.size() == 1); auto *dstStoreOpInst = dstStoreOps[0]; @@ -1305,8 +1303,8 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns // the union in 'sliceState'. Returns true on success, false otherwise. // TODO(andydavis) Move this to a loop fusion utility function. -static bool getSliceUnion(Instruction *srcOpInst, - ArrayRef dstLoadOpInsts, +static bool getSliceUnion(Operation *srcOpInst, + ArrayRef dstLoadOpInsts, unsigned numSrcLoopIVs, unsigned dstLoopDepth, ComputationSliceState *sliceState) { MemRefAccess srcAccess(srcOpInst); @@ -1415,10 +1413,9 @@ static bool getSliceUnion(Instruction *srcOpInst, // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -static bool isFusionProfitable(Instruction *srcOpInst, - Instruction *srcStoreOpInst, - ArrayRef dstLoadOpInsts, - ArrayRef dstStoreOpInsts, +static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, + ArrayRef dstLoadOpInsts, + ArrayRef dstStoreOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth, bool maximalFusion) { LLVM_DEBUG({ @@ -1492,7 +1489,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for source instruction\n."); + << "Unable to compute MemRefRegion for source operation\n."); return false; } @@ -1510,8 +1507,8 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Evaluate all depth choices for materializing the slice in the destination // loop nest. - llvm::SmallDenseMap sliceTripCountMap; - DenseMap computeCostMap; + llvm::SmallDenseMap sliceTripCountMap; + DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i, @@ -1754,7 +1751,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // bounds to be functions of 'dstLoopNest' IVs and symbols. // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', // at a loop depth determined by the cost model in 'isFusionProfitable'. -// *) Add the newly fused load/store operation instructions to the state, +// *) Add the newly fused load/store operations to the state, // and also add newly fuse load ops to 'dstLoopOps' to be considered // as fusion dst load ops in another iteration. 
// *) Remove old src loop nest and its associated state. @@ -1773,7 +1770,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // is preserved in the fused loop nest. // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. // -// Given a graph where top-level instructions are vertices in the set 'V' and +// Given a graph where top-level operations are vertices in the set 'V' and // edges in the set 'E' are dependences between vertices, this algorithm // takes O(V) time for initialization, and has runtime O(V + E). // @@ -1844,7 +1841,7 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!dstNode->inst->isa()) + if (!dstNode->op->isa()) continue; // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. This can increase the maximum loop @@ -1852,8 +1849,8 @@ public: // consumer loop nest. sinkSequentialLoops(dstNode); - SmallVector loads = dstNode->loads; - SmallVector dstLoadOpInsts; + SmallVector loads = dstNode->loads; + SmallVector dstLoadOpInsts; DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. @@ -1882,7 +1879,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!srcNode->inst->isa()) + if (!srcNode->op->isa()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1908,9 +1905,9 @@ public: if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) continue; - // Compute an instruction list insertion point for the fused loop + // Compute an operation list insertion point for the fused loop // nest which preserves dependences. - Instruction *insertPointInst = + Operation *insertPointInst = mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); if (insertPointInst == nullptr) continue; @@ -1918,7 +1915,7 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; + SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) if (storeOpInst->cast().getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); @@ -1938,7 +1935,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" << *sliceLoopNest.getOperation() << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = dstNode->inst->cast(); + auto dstAffineForOp = dstNode->op->cast(); if (insertPointInst != dstAffineForOp.getOperation()) { dstAffineForOp.getOperation()->moveBefore(insertPointInst); } @@ -1954,7 +1951,7 @@ public: } if (!writesToLiveInOrOut) { // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector storesForMemref; + SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast().getMemRef() == memref) storesForMemref.push_back(storeOpInst); @@ -1995,7 +1992,7 @@ public: // so it is safe to remove. if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); - srcNode->inst->erase(); + srcNode->op->erase(); } else { // Add remaining users of 'oldMemRef' back on the worklist (if not // already there), as its replacement with a local/private memref @@ -2034,7 +2031,7 @@ public: // Get 'dstNode' into which to attempt fusion. 
auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!dstNode->inst->isa()) + if (!dstNode->op->isa()) continue; // Attempt to fuse 'dstNode' with its sibling nodes in the graph. fuseWithSiblingNodes(dstNode); @@ -2051,11 +2048,11 @@ public: // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other // stores to the same memref in 'sibNode' loop nest. auto *sibNode = mdg->getNode(sibId); - // Compute an instruction list insertion point for the fused loop + // Compute an operation list insertion point for the fused loop // nest which preserves dependences. - assert(sibNode->inst->getBlock() == dstNode->inst->getBlock()); - Instruction *insertPointInst = - sibNode->inst->isBeforeInBlock(dstNode->inst) + assert(sibNode->op->getBlock() == dstNode->op->getBlock()); + Operation *insertPointInst = + sibNode->op->isBeforeInBlock(dstNode->op) ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id) : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id); if (insertPointInst == nullptr) @@ -2064,21 +2061,21 @@ public: // Check if fusion would be profitable and at what depth. // Get unique 'sibNode' load op to 'memref'. - SmallVector sibLoadOpInsts; + SmallVector sibLoadOpInsts; sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts); // Currently findSiblingNodeToFuse searches for siblings with one load. assert(sibLoadOpInsts.size() == 1); - Instruction *sibLoadOpInst = sibLoadOpInsts[0]; + Operation *sibLoadOpInst = sibLoadOpInsts[0]; assert(!sibNode->stores.empty()); // TODO(andydavis) Choose the store which postdominates all other stores. auto *sibStoreOpInst = sibNode->stores.back(); // Gather 'dstNode' load ops to 'memref'. - SmallVector dstLoadOpInsts; + SmallVector dstLoadOpInsts; dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; + SmallVector dstStoreOpInsts; dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); unsigned bestDstLoopDepth; @@ -2094,8 +2091,8 @@ public: auto sliceLoopNest = mlir::insertBackwardComputationSlice( sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { - auto dstForInst = dstNode->inst->cast(); - // Update instruction position of fused loop nest (if needed). + auto dstForInst = dstNode->op->cast(); + // Update operation position of fused loop nest (if needed). if (insertPointInst != dstForInst.getOperation()) { dstForInst.getOperation()->moveBefore(insertPointInst); } @@ -2140,7 +2137,7 @@ public: if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) return; auto *sibNode = mdg->getNode(sibNodeId); - if (!sibNode->inst->isa()) + if (!sibNode->op->isa()) return; // Skip if 'outEdge' is not a read-after-write dependence. // TODO(andydavis) Remove restrict to single load op restriction. @@ -2196,7 +2193,7 @@ public: } // Collect dst loop stats after memref privatizaton transformation. - auto dstForInst = dstNode->inst->cast(); + auto dstForInst = dstNode->op->cast(); LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstForInst.getOperation()); // Clear and add back loads and stores @@ -2208,7 +2205,7 @@ public: // function. if (mdg->getOutEdgeCount(sibNode->id) == 0) { mdg->removeNode(sibNode->id); - sibNode->inst->cast().erase(); + sibNode->op->cast().erase(); } } @@ -2218,13 +2215,13 @@ public: if (pair.second > 0) continue; auto *memref = pair.first; - // Skip if there exist other uses (return instruction or function calls). 
+ // Skip if there exist other uses (return operation or function calls). if (!memref->use_empty()) continue; // Use list expected to match the dep graph info. - auto *inst = memref->getDefiningOp(); - if (inst && inst->isa()) - inst->erase(); + auto *op = memref->getDefiningOp(); + if (op && op->isa()) + op->erase(); } } }; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index f99b602cf0b..f7fef1a428c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -180,7 +180,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, assert(!band.empty()); assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes"); - // Check if the supplied for inst's are all successively nested. + // Check if the supplied for op's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { assert(band[i].getOperation()->getParentOp() == band[i - 1].getOperation()); } @@ -269,8 +269,8 @@ static void getTileableBands(Function &f, }; for (auto &block : f) - for (auto &inst : block) - if (auto forOp = inst.dyn_cast()) + for (auto &op : block) + if (auto forOp = op.dyn_cast()) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 5687c6126d1..3b79d6245be 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -79,7 +79,7 @@ struct LoopUnroll : public FunctionPass { void runOnFunction() override; - /// Unroll this for inst. Returns failure if nothing was done. + /// Unroll this for op. Returns failure if nothing was done. LogicalResult runOnAffineForOp(AffineForOp forOp); static const unsigned kDefaultUnrollFactor = 4; @@ -106,7 +106,7 @@ void LoopUnroll::runOnFunction() { hasInnerLoops |= walkPostOrder(&(*Start++)); return hasInnerLoops; } - bool walkPostOrder(Instruction *opInst) { + bool walkPostOrder(Operation *opInst) { bool hasInnerLoops = false; for (auto ®ion : opInst->getRegions()) for (auto &block : region) @@ -158,7 +158,7 @@ void LoopUnroll::runOnFunction() { } } -/// Unrolls a 'affine.for' inst. Returns success if the loop was unrolled, +/// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, /// failure otherwise. The default unroll factor is 4. LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { // Use the function callback if one was provided. diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 3ea20c0c282..a3a24f6c0f7 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -17,7 +17,7 @@ // // This file implements loop unroll and jam. Unroll and jam is a transformation // that improves locality, in particular, register reuse, while also improving -// instruction level parallelism. The example below shows what it does in nearly +// operation level parallelism. The example below shows what it does in nearly // the general case. Loop unroll and jam currently works if the bounds of the // loops inner to the loop being unroll-jammed do not depend on the latter. // @@ -39,7 +39,7 @@ // S6(i+1); // // Note: 'if/else' blocks are not jammed. So, if there are loops inside if -// inst's, bodies of those loops will not be jammed. +// op's, bodies of those loops will not be jammed. 
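The LoopUnrollAndJam header above describes the transformation informally. For a concrete if simplified picture, a plain C++ before/after of unroll-and-jam by a factor of 2 on a small reduction nest, assuming the outer trip count is divisible by the factor (names are illustrative, not taken from the pass):

void before(float A[128][128], float x[128]) {
  for (int i = 0; i < 128; ++i)
    for (int j = 0; j < 128; ++j)
      x[i] += A[i][j];                // S(i)
}

void after(float A[128][128], float x[128]) {
  for (int i = 0; i < 128; i += 2)    // outer loop unrolled by 2
    for (int j = 0; j < 128; ++j) {   // inner loop bodies "jammed" together
      x[i] += A[i][j];                // S(i)
      x[i + 1] += A[i + 1][j];        // S(i+1)
    }
}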
//===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" @@ -96,7 +96,7 @@ void LoopUnrollAndJam::runOnFunction() { runOnAffineForOp(forOp); } -/// Unroll and jam a 'affine.for' inst. Default unroll jam factor is +/// Unroll and jam a 'affine.for' op. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return failure if nothing was done. LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) { // Unroll and jam by the factor that was passed if any. @@ -123,16 +123,16 @@ LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp, /// Unrolls and jams this loop by the specified factor. LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor) { - // Gathers all maximal sub-blocks of instructions that do not themselves - // include a for inst (a instruction could have a descendant for inst though + // Gathers all maximal sub-blocks of operations that do not themselves + // include a for op (a operation could have a descendant for op though // in its tree). Ignore the block terminators. struct JamBlockGatherer { - // Store iterators to the first and last inst of each sub-block found. + // Store iterators to the first and last op of each sub-block found. std::vector> subBlocks; // This is a linear time walk. - void walk(Instruction *inst) { - for (auto ®ion : inst->getRegions()) + void walk(Operation *op) { + for (auto ®ion : op->getRegions()) for (auto &block : region) walk(block); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index acc9481e89c..3676c2faae9 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -32,7 +32,7 @@ using namespace mlir; namespace { -// Visit affine expressions recursively and build the sequence of instructions +// Visit affine expressions recursively and build the sequence of operations // that correspond to it. Visitation functions return an Value of the // expression subtree they visited or `nullptr` on error. class AffineApplyExpander @@ -102,7 +102,7 @@ public: // Floor division operation (rounds towards negative infinity). // // For positive divisors, it can be implemented without branching and with a - // single division instruction as + // single division operation as // // a floordiv b = // let negative = a < 0 in @@ -144,7 +144,7 @@ public: // Ceiling division operation (rounds towards positive infinity). // // For positive divisors, it can be implemented without branching and with a - // single division instruction as + // single division operation as // // a ceildiv b = // let negative = a <= 0 in @@ -213,7 +213,7 @@ private: }; } // namespace -// Create a sequence of instructions that implement the `expr` applied to the +// Create a sequence of operations that implement the `expr` applied to the // given dimension and symbol values. static mlir::Value *expandAffineExpr(FuncBuilder *builder, Location loc, AffineExpr expr, @@ -222,7 +222,7 @@ static mlir::Value *expandAffineExpr(FuncBuilder *builder, Location loc, return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); } -// Create a sequence of instructions that implement the `affineMap` applied to +// Create a sequence of operations that implement the `affineMap` applied to // the given `operands` (as it it were an AffineApplyOp). 
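The floordiv/ceildiv comments in the LowerAffine hunk above sketch a branch-free lowering scheme for positive divisors. One branch-free implementation consistent with the opening lines of those comments, written as plain C++ with ternaries standing in for the 'select' operations the pass would emit (a sketch, not the pass's generated IR):

#include <cassert>
#include <cstdint>

// a floordiv b, rounding towards negative infinity; b must be positive.
int64_t floordiv(int64_t a, int64_t b) {
  assert(b > 0);
  bool negative = a < 0;
  int64_t absolute = negative ? -a - 1 : a;
  int64_t quotient = absolute / b;
  return negative ? -quotient - 1 : quotient;
}

// a ceildiv b, rounding towards positive infinity; b must be positive.
int64_t ceildiv(int64_t a, int64_t b) {
  assert(b > 0);
  bool negative = a <= 0;
  int64_t absolute = negative ? -a : a - 1;
  int64_t quotient = absolute / b;
  return negative ? -quotient : quotient + 1;
}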
Optional> static expandAffineMap( FuncBuilder *builder, Location loc, AffineMap affineMap, @@ -395,16 +395,16 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) { return false; } -// Convert an "if" instruction into a flow of basic blocks. +// Convert an "if" operation into a flow of basic blocks. // -// Create an SESE region for the if instruction (including its "then" and -// optional "else" instruction blocks) and append it to the end of the current +// Create an SESE region for the if operation (including its "then" and +// optional "else" operation blocks) and append it to the end of the current // region. The conditional region consists of a sequence of condition-checking // blocks that implement the short-circuit scheme, followed by a "then" SESE // region and an "else" SESE region, and the continuation block that -// post-dominates all blocks of the "if" instruction. The flow of blocks that +// post-dominates all blocks of the "if" operation. The flow of blocks that // correspond to the "then" and "else" clauses are constructed recursively, -// enabling easy nesting of "if" instructions and if-then-else-if chains. +// enabling easy nesting of "if" operations and if-then-else-if chains. // // +--------------------------------+ // | | @@ -465,12 +465,12 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the AffineIfOp and add a + // continue blocks. Move the operations over from the AffineIfOp and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - // If the 'then' block is not empty, then splice the instructions except for + // If the 'then' block is not empty, then splice the operations except for // the terminator. auto &oldThenBlocks = ifOp.getThenBlocks(); if (!oldThenBlocks.empty()) { @@ -570,7 +570,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) { } // Convert an "affine.apply" operation into a sequence of arithmetic -// instructions using the StandardOps dialect. Return true on error. +// operations using the StandardOps dialect. Return true on error. bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { FuncBuilder builder(op.getOperation()); auto maybeExpandedMap = @@ -590,12 +590,12 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { // Entry point of the function convertor. // -// Conversion is performed by recursively visiting instructions of a Function. +// Conversion is performed by recursively visiting operations of a Function. // It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the // IR builder since we never change the first block after its creation. "Block" -// instructions such as loops and branches create new SESE regions for their +// operations such as loops and branches create new SESE regions for their // bodies, and surround them with additional basic blocks for the control flow. // Individual operations are simply appended to the end of the last basic block // of the current region. The SESE invariant allows us to easily handle nested @@ -607,32 +607,32 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) { // corresponding Value that has been defined previously. 
The value flow // starts with function arguments converted to basic block arguments. void LowerAffinePass::runOnFunction() { - SmallVector instsToRewrite; + SmallVector instsToRewrite; - // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // Collect all the For operations as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. - getFunction().walk([&](Instruction *inst) { - if (inst->isa() || inst->isa() || - inst->isa()) - instsToRewrite.push_back(inst); + getFunction().walk([&](Operation *op) { + if (op->isa() || op->isa() || + op->isa()) + instsToRewrite.push_back(op); }); - // Rewrite all of the ifs and fors. We walked the instructions in preorder, + // Rewrite all of the ifs and fors. We walked the operations in preorder, // so we know that we will rewrite them in the same order. - for (auto *inst : instsToRewrite) { - if (auto ifOp = inst->dyn_cast()) { + for (auto *op : instsToRewrite) { + if (auto ifOp = op->dyn_cast()) { if (lowerAffineIf(ifOp)) return signalPassFailure(); - } else if (auto forOp = inst->dyn_cast()) { + } else if (auto forOp = op->dyn_cast()) { if (lowerAffineFor(forOp)) return signalPassFailure(); - } else if (lowerAffineApply(inst->cast())) { + } else if (lowerAffineApply(op->cast())) { return signalPassFailure(); } } } -/// Lowers If and For instructions within a function into their lower level CFG +/// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. FunctionPassBase *mlir::createLowerAffinePass() { return new LowerAffinePass(); @@ -640,4 +640,4 @@ FunctionPassBase *mlir::createLowerAffinePass() { static PassRegistration pass("lower-affine", - "Lower If, For, AffineApply instructions to primitive equivalents"); + "Lower If, For, AffineApply operations to primitive equivalents"); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 708ad7d1693..0e5a8680f77 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -356,12 +356,12 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult match(Operation *op) const override { if (m_Op().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr opState, MLFuncLoweringRewriter *rewriter) const override { VectorTransferRewriter( diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2a877c45680..7e4a459326f 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -55,7 +55,7 @@ /// to the pass. This pass is thus a partial lowering that opens the "greybox" /// that is the super-vector abstraction. In particular, this pass can turn the /// vector_transfer_read and vector_transfer_write ops in either: -/// 1. a loop nest with either scalar and vector load/store instructions; or +/// 1. a loop nest with either scalar and vector load/store operations; or /// 2. a loop-nest with DmaStartOp / DmaWaitOp; or /// 3. a pre-existing blackbox library call that can be written manually or /// synthesized using search and superoptimization. 
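A hypothetical, self-contained sketch of the covering idea described above: a super-vector op is instantiated once per hardware-vector instance, and each instance's position is obtained by delinearizing a linear index over the shape ratio of the super-vector type to the hw vector type. The row-major ordering and the helper names here are assumptions for illustration, not the in-tree implementation.

#include <cassert>
#include <cstdio>
#include <vector>

// Element-wise ratio of super-vector shape to hw-vector shape.
static std::vector<unsigned> shapeRatio(const std::vector<unsigned> &superShape,
                                        const std::vector<unsigned> &hwShape) {
  assert(superShape.size() == hwShape.size());
  std::vector<unsigned> ratio;
  for (size_t d = 0; d < superShape.size(); ++d) {
    assert(superShape[d] % hwShape[d] == 0 && "ratio must be integral");
    ratio.push_back(superShape[d] / hwShape[d]);
  }
  return ratio;
}

// Convert a linear instance index into per-dimension coordinates (row-major).
static std::vector<unsigned> delinearize(unsigned linearIndex,
                                         const std::vector<unsigned> &ratio) {
  std::vector<unsigned> coords(ratio.size());
  for (size_t d = ratio.size(); d-- > 0;) {
    coords[d] = linearIndex % ratio[d];
    linearIndex /= ratio[d];
  }
  return coords;
}

int main() {
  // A 4x32 super-vector covered by 1x8 hw vectors gives a 4x4 ratio,
  // i.e. 16 hw-vector instances to materialize.
  auto ratio = shapeRatio({4, 32}, {1, 8});
  unsigned total = 1;
  for (unsigned r : ratio)
    total *= r;
  for (unsigned i = 0; i < total; ++i) {
    auto c = delinearize(i, ratio);
    std::printf("instance %2u -> (%u, %u)\n", i, c[0], c[1]);
  }
  return 0;
}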
@@ -239,9 +239,9 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, - VectorType hwVectorType, - DenseMap *substitutionsMap); +static Operation *instantiate(FuncBuilder *b, Operation *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately /// enclosing loop. @@ -259,9 +259,8 @@ static Value *substitute(Value *v, VectorType hwVectorType, auto *opInst = v->getDefiningOp(); if (opInst->isa()) { FuncBuilder b(opInst); - auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap); - auto res = - substitutionsMap->insert(std::make_pair(v, inst->getResult(0))); + auto *op = instantiate(&b, opInst, hwVectorType, substitutionsMap); + auto res = substitutionsMap->insert(std::make_pair(v, op->getResult(0))); assert(res.second && "Insertion failed"); return res.first->second; } @@ -384,7 +383,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector -materializeAttributes(Instruction *opInst, VectorType hwVectorType) { +materializeAttributes(Operation *opInst, VectorType hwVectorType) { SmallVector res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast()) { @@ -404,9 +403,9 @@ materializeAttributes(Instruction *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static Instruction *instantiate(FuncBuilder *b, Instruction *opInst, - VectorType hwVectorType, - DenseMap *substitutionsMap) { +static Operation *instantiate(FuncBuilder *b, Operation *opInst, + VectorType hwVectorType, + DenseMap *substitutionsMap) { assert(!opInst->isa() && "Should call the function specialized for VectorTransferReadOp"); assert(!opInst->isa() && @@ -481,10 +480,10 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp read, - VectorType hwVectorType, - ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Operation *instantiate(FuncBuilder *b, VectorTransferReadOp read, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), read.getIndices()); auto affineIndices = @@ -505,10 +504,10 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferReadOp read, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. 
-static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, - VectorType hwVectorType, - ArrayRef hwVectorInstance, - DenseMap *substitutionsMap) { +static Operation *instantiate(FuncBuilder *b, VectorTransferWriteOp write, + VectorType hwVectorType, + ArrayRef hwVectorInstance, + DenseMap *substitutionsMap) { SmallVector indices = map(makePtrDynCaster(), write.getIndices()); auto affineIndices = @@ -521,11 +520,11 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, return cloned.getOperation(); } -/// Returns `true` if inst instance is properly cloned and inserted, false +/// Returns `true` if op instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of /// super-vector type to hw vector type. -/// A cloned instance of `inst` is formed as follows: +/// A cloned instance of `op` is formed as follows: /// 1. vector_transfer_read: the return `superVectorType` is replaced by /// `hwVectorType`. Additionally, affine indices are reindexed with /// `reindexAffineIndices` using `hwVectorInstance` and vector type @@ -542,26 +541,26 @@ static Instruction *instantiate(FuncBuilder *b, VectorTransferWriteOp write, /// possible. /// /// Returns true on failure. -static bool instantiateMaterialization(Instruction *inst, +static bool instantiateMaterialization(Operation *op, MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *op); // Create a builder here for unroll-and-jam effects. - FuncBuilder b(inst); + FuncBuilder b(op); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. - if (inst->isa()) { + if (op->isa()) { return false; } - if (inst->getNumRegions() != 0) - return inst->emitError("NYI path Op with region"); + if (op->getNumRegions() != 0) + return op->emitError("NYI path Op with region"); - if (auto write = inst->dyn_cast()) { + if (auto write = op->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = inst->dyn_cast()) { + if (auto read = op->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { @@ -574,19 +573,19 @@ static bool instantiateMaterialization(Instruction *inst, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. 
- if (inst->getNumResults() != 1) { - return inst->emitError("NYI: ops with != 1 results"); + if (op->getNumResults() != 1) { + return op->emitError("NYI: ops with != 1 results"); } - if (inst->getResult(0)->getType() != state->superVectorType) { - return inst->emitError("Op does not return a supervector."); + if (op->getResult(0)->getType() != state->superVectorType) { + return op->emitError("Op does not return a supervector."); } auto *clone = - instantiate(&b, inst, state->hwVectorType, state->substitutionsMap); + instantiate(&b, op, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(inst->getResult(0), clone->getResult(0))); + std::make_pair(op->getResult(0), clone->getResult(0))); return false; } @@ -612,7 +611,7 @@ static bool instantiateMaterialization(Instruction *inst, /// TODO(ntv): full loops + materialized allocs. /// TODO(ntv): partial unrolling + materialized allocs. static bool emitSlice(MaterializationState *state, - SetVector *slice) { + SetVector *slice) { auto ratio = shapeRatio(state->superVectorType, state->hwVectorType); assert(ratio.hasValue() && "ratio of super-vector to HW-vector shape is not integral"); @@ -627,10 +626,10 @@ static bool emitSlice(MaterializationState *state, DenseMap substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. - for (auto *inst : *slice) { - auto fail = instantiateMaterialization(inst, &scopedState); + for (auto *op : *slice) { + auto fail = instantiateMaterialization(op, &scopedState); if (fail) { - inst->emitError("Unhandled super-vector materialization failure"); + op->emitError("Unhandled super-vector materialization failure"); return true; } } @@ -653,7 +652,7 @@ static bool emitSlice(MaterializationState *state, /// Materializes super-vector types into concrete hw vector types as follows: /// 1. start from super-vector terminators (current vector_transfer_write /// ops); -/// 2. collect all the instructions that can be reached by transitive use-defs +/// 2. collect all the operations that can be reached by transitive use-defs /// chains; /// 3. get the superVectorType for this particular terminator and the /// corresponding hardware vector type (for now limited to F32) @@ -664,14 +663,13 @@ static bool emitSlice(MaterializationState *state, /// Notes /// ===== /// The `slice` is sorted in topological order by construction. -/// Additionally, this set is limited to instructions in the same lexical scope +/// Additionally, this set is limited to operations in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. /// TODO(ntv): please document return value. -static bool materialize(Function *f, - const SetVector &terminators, +static bool materialize(Function *f, const SetVector &terminators, MaterializationState *state) { - DenseSet seen; + DenseSet seen; DominanceInfo domInfo(f); for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some @@ -688,15 +686,15 @@ static bool materialize(Function *f, // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. 
auto *enclosingScope = term->getParentOp(); - auto keepIfInSameScope = [enclosingScope, &domInfo](Instruction *inst) { - assert(inst && "NULL inst"); + auto keepIfInSameScope = [enclosingScope, &domInfo](Operation *op) { + assert(op && "NULL op"); if (!enclosingScope) { // by construction, everyone is always under the top scope (null scope). return true; } - return domInfo.properlyDominates(enclosingScope, inst); + return domInfo.properlyDominates(enclosingScope, op); }; - SetVector slice = + SetVector slice = getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); @@ -749,16 +747,16 @@ void MaterializeVectorsPass::runOnFunction() { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. - auto filter = [subVectorType](Instruction &inst) { - if (!inst.isa()) { + auto filter = [subVectorType](Operation &op) { + if (!op.isa()) { return false; } - return matcher::operatesOnSuperVectors(inst, subVectorType); + return matcher::operatesOnSuperVectors(op, subVectorType); }; auto pat = Op(filter); SmallVector matches; pat.match(f, &matches); - SetVector terminators; + SetVector terminators; for (auto m : matches) { terminators.insert(m.getMatchedOperation()); } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 9779ab78a3f..a579d439368 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -54,8 +54,8 @@ namespace { // iteration of the innermost loop enclosing both the store op and the load op. // // (* A dependence being satisfied at a block: a dependence that is satisfied by -// virtue of the destination instruction appearing textually / lexically after -// the source instruction within the body of a 'affine.for' instruction; thus, a +// virtue of the destination operation appearing textually / lexically after +// the source operation within the body of a 'affine.for' operation; thus, a // dependence is always either satisfied by a loop or by a block). // // The above conditions are simple to check, sufficient, and powerful for most @@ -77,7 +77,7 @@ struct MemRefDataFlowOpt : public FunctionPass { // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. - std::vector loadOpsToErase; + std::vector loadOpsToErase; DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; @@ -94,13 +94,13 @@ FunctionPassBase *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { - Instruction *lastWriteStoreOp = nullptr; - Instruction *loadOpInst = loadOp.getOperation(); + Operation *lastWriteStoreOp = nullptr; + Operation *loadOpInst = loadOp.getOperation(); // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. - SmallVector storeOps; + SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (InstOperand &use : loadOp.getMemRef()->getUses()) { auto storeOp = use.getOwner()->dyn_cast(); @@ -119,11 +119,11 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { // and loadOp. // The list of store op candidates for forwarding - need to satisfy the // conditions listed at the top. 
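A toy, hypothetical illustration of the forwarding idea stated above, restricted to one straight-line block with integer addresses: a load takes the value of the closest preceding store to the same address. The real pass reasons about affine accesses, dominance and post-dominance rather than a flat list.

#include <cstdio>
#include <optional>
#include <vector>

struct MemOp {
  enum Kind { Store, Load } kind;
  int addr;
  int value; // meaningful for stores only
};

// Returns the value forwarded to the load at position loadIdx, if any.
static std::optional<int> forwardStoreToLoad(const std::vector<MemOp> &ops,
                                             size_t loadIdx) {
  std::optional<int> lastValue;
  for (size_t i = 0; i < loadIdx; ++i)
    if (ops[i].kind == MemOp::Store && ops[i].addr == ops[loadIdx].addr)
      lastValue = ops[i].value; // later stores shadow earlier ones
  return lastValue;
}

int main() {
  std::vector<MemOp> ops = {{MemOp::Store, /*addr=*/0, /*value=*/7},
                            {MemOp::Store, /*addr=*/0, /*value=*/9},
                            {MemOp::Load, /*addr=*/0, /*value=*/0}};
  if (auto v = forwardStoreToLoad(ops, 2))
    std::printf("load forwarded value: %d\n", *v); // prints 9
  return 0;
}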
- SmallVector fwdingCandidates; + SmallVector fwdingCandidates; // Store ops that have a dependence into the load (even if they aren't // forwarding candidates). Each forwarding candidate will be checked for a // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. - SmallVector depSrcStores; + SmallVector depSrcStores; for (auto *storeOpInst : storeOps) { MemRefAccess srcAccess(storeOpInst); MemRefAccess destAccess(loadOpInst); @@ -186,7 +186,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { // that postdominates all 'depSrcStores' (if such a store exists) is the // unique store providing the value to the load, i.e., provably the last // writer to that memref loc. - if (llvm::all_of(depSrcStores, [&](Instruction *depStore) { + if (llvm::all_of(depSrcStores, [&](Operation *depStore) { return postDomInfo->postDominates(storeOpInst, depStore); })) { lastWriteStoreOp = storeOpInst; @@ -236,9 +236,9 @@ void MemRefDataFlowOpt::runOnFunction() { // to do this as well, but we'll do it here since we collected these anyway. for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - Instruction *defInst = memref->getDefiningOp(); + Operation *defInst = memref->getDefiningOp(); if (!defInst || !defInst->isa()) - // TODO(mlir-team): if the memref was returned by a 'call' instruction, we + // TODO(mlir-team): if the memref was returned by a 'call' operation, we // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index a7d37161aa1..667aad2f79d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -53,23 +53,23 @@ FunctionPassBase *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Returns the position of the tag memref operand given a DMA instruction. +// Returns the position of the tag memref operand given a DMA operation. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(Instruction &dmaInst) { +static unsigned getTagMemRefPos(Operation &dmaInst) { assert(dmaInst.isa() || dmaInst.isa()); if (dmaInst.isa()) { // Second to last operand. return dmaInst.getNumOperands() - 2; } - // First operand for a dma finish instruction. + // First operand for a dma finish operation. return 0; } /// Doubles the buffer of the supplied memref on the specified 'affine.for' -/// instruction by adding a leading dimension of size two to the memref. +/// operation by adding a leading dimension of size two to the memref. /// Replaces all uses of the old memref by the new one while indexing the newly -/// added dimension by the loop IV of the specified 'affine.for' instruction +/// added dimension by the loop IV of the specified 'affine.for' operation /// modulo 2. Returns false if such a replacement cannot be performed. static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto *forBody = forOp.getBody(); @@ -104,7 +104,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { dynamicDimCount++)); } - // Create and place the alloc right before the 'affine.for' instruction. + // Create and place the alloc right before the 'affine.for' operation. 
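A hypothetical scalar analogue of the double-buffering rewrite described above: the buffer gains a leading dimension of size two and every access indexes it with the loop IV modulo 2, so the transfer for iteration i can overlap the compute on iteration i-1's half. Here the "transfer" is an ordinary function call executed sequentially; only the indexing scheme is illustrated.

#include <cstdio>

enum { N = 8, TILE = 4 };

static void fill(float *dst, int iteration) {
  for (int j = 0; j < TILE; ++j)
    dst[j] = static_cast<float>(iteration * TILE + j); // stands in for a DMA
}

static float consume(const float *src) {
  float sum = 0.f;
  for (int j = 0; j < TILE; ++j)
    sum += src[j];
  return sum;
}

int main() {
  float buf[2][TILE]; // leading dimension of size two
  fill(buf[0], 0);    // prologue: first transfer
  float total = 0.f;
  for (int i = 1; i < N; ++i) {
    fill(buf[i % 2], i);                // start the transfer into one half...
    total += consume(buf[(i - 1) % 2]); // ...while consuming the other half
  }
  total += consume(buf[(N - 1) % 2]);   // epilogue: last half
  std::printf("%f\n", total);
  return 0;
}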
Value *newMemRef = bOuter.create(forInst->getLoc(), newMemRefType, allocOperands); @@ -139,7 +139,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { /// Returns success if the IR is in a valid state. void PipelineDataTransfer::runOnFunction() { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'affine.for' instructions nested within would otherwise + // necessary since 'affine.for' operations nested within would otherwise // become invalid (erased) when the outer loop is pipelined (the pipelined one // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). @@ -173,27 +173,27 @@ static bool checkTagMatch(DmaStartOp startOp, DmaWaitOp waitOp) { return true; } -// Identify matching DMA start/finish instructions to overlap computation with. +// Identify matching DMA start/finish operations to overlap computation with. static void findMatchingStartFinishInsts( AffineForOp forOp, - SmallVectorImpl> &startWaitPairs) { + SmallVectorImpl> &startWaitPairs) { - // Collect outgoing DMA instructions - needed to check for dependences below. + // Collect outgoing DMA operations - needed to check for dependences below. SmallVector outgoingDmaOps; - for (auto &inst : *forOp.getBody()) { - auto dmaStartOp = inst.dyn_cast(); + for (auto &op : *forOp.getBody()) { + auto dmaStartOp = op.dyn_cast(); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forOp.getBody()) { - // Collect DMA finish instructions. - if (inst.isa()) { - dmaFinishInsts.push_back(&inst); + SmallVector dmaStartInsts, dmaFinishInsts; + for (auto &op : *forOp.getBody()) { + // Collect DMA finish operations. + if (op.isa()) { + dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = inst.dyn_cast(); + auto dmaStartOp = op.dyn_cast(); if (!dmaStartOp) continue; @@ -228,10 +228,10 @@ static void findMatchingStartFinishInsts( } } if (!escapingUses) - dmaStartInsts.push_back(&inst); + dmaStartInsts.push_back(&op); } - // For each start instruction, we look for a matching finish instruction. + // For each start operation, we look for a matching finish operation. for (auto *dmaStartInst : dmaStartInsts) { for (auto *dmaFinishInst : dmaFinishInsts) { if (checkTagMatch(dmaStartInst->cast(), @@ -253,7 +253,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { return; } - SmallVector, 4> startWaitPairs; + SmallVector, 4> startWaitPairs; findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { @@ -263,7 +263,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // Double the buffers for the higher memory space memref's. // Identify memref's to replace by scanning through all DMA start - // instructions. A DMA start instruction has two memref's - the one from the + // operations. A DMA start operation has two memref's - the one from the // higher level of memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that @@ -320,13 +320,13 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { startWaitPairs.clear(); findMatchingStartFinishInsts(forOp, startWaitPairs); - // Store shift for instruction for later lookup for AffineApplyOp's. - DenseMap instShiftMap; + // Store shift for operation for later lookup for AffineApplyOp's. 
+ DenseMap instShiftMap; for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; assert(dmaStartInst->isa()); instShiftMap[dmaStartInst] = 0; - // Set shifts for DMA start inst's affine operand computation slices to 0. + // Set shifts for DMA start op's affine operand computation slices to 0. SmallVector sliceOps; mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); if (!sliceOps.empty()) { @@ -336,32 +336,32 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { } else { // If a slice wasn't created, the reachable affine.apply op's from its // operands are the ones that go with it. - SmallVector affineApplyInsts; + SmallVector affineApplyInsts; SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); - for (auto *inst : affineApplyInsts) { - instShiftMap[inst] = 0; + for (auto *op : affineApplyInsts) { + instShiftMap[op] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (auto &inst : *forOp.getBody()) { - if (instShiftMap.find(&inst) == instShiftMap.end()) { - instShiftMap[&inst] = 1; + for (auto &op : *forOp.getBody()) { + if (instShiftMap.find(&op) == instShiftMap.end()) { + instShiftMap[&op] = 1; } } // Get shifts stored in map. std::vector shifts(forOp.getBody()->getOperations().size()); unsigned s = 0; - for (auto &inst : *forOp.getBody()) { - assert(instShiftMap.find(&inst) != instShiftMap.end()); - shifts[s++] = instShiftMap[&inst]; + for (auto &op : *forOp.getBody()) { + assert(instShiftMap.find(&op) != instShiftMap.end()); + shifts[s++] = instShiftMap[&op]; - // Tagging instructions with shifts for debugging purposes. + // Tagging operations with shifts for debugging purposes. LLVM_DEBUG({ - FuncBuilder b(&inst); - inst.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); + FuncBuilder b(&op); + op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); }); } @@ -372,7 +372,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { } if (failed(instBodySkew(forOp, shifts))) { - LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); + LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";); return; } } diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 4ff5367abbb..e777a4d9ca3 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -32,7 +32,7 @@ using namespace mlir; namespace { -/// Simplifies all affine expressions appearing in the operation instructions of +/// Simplifies all affine expressions appearing in the operations of /// the Function. This is mainly to test the simplifyAffineExpr method. /// TODO(someone): This should just be defined as a canonicalization pattern /// on AffineMap and driven from the existing canonicalization pass. @@ -41,9 +41,9 @@ struct SimplifyAffineStructures void runOnFunction() override; /// Utility to simplify an affine attribute and update its entry in the parent - /// instruction if necessary. + /// operation if necessary. template - void simplifyAndUpdateAttribute(Instruction *inst, Identifier name, + void simplifyAndUpdateAttribute(Operation *op, Identifier name, AttributeT attr) { auto &simplified = simplifiedAttributes[attr]; if (simplified == attr) @@ -62,7 +62,7 @@ struct SimplifyAffineStructures } // Simplification was successful, so update the attribute. 
- inst->setAttr(name, simplified); + op->setAttr(name, simplified); } /// Performs basic integer set simplifications. Checks if it's empty, and @@ -93,7 +93,7 @@ FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { void SimplifyAffineStructures::runOnFunction() { simplifiedAttributes.clear(); - getFunction().walk([&](Instruction *opInst) { + getFunction().walk([&](Operation *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 9d6b7a0ba27..1691976a05a 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -32,9 +32,9 @@ void StripDebugInfo::runOnFunction() { Function &func = getFunction(); UnknownLoc unknownLoc = UnknownLoc::get(&getContext()); - // Strip the debug info from the function and its instructions. + // Strip the debug info from the function and its operations. func.setLoc(unknownLoc); - func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func.walk([&](Operation *op) { op->setLoc(unknownLoc); }); } /// Creates a pass to strip debug information from a function. @@ -43,4 +43,4 @@ FunctionPassBase *mlir::createStripDebugInfoPass() { } static PassRegistration - pass("strip-debuginfo", "Strip debug info from functions and instructions"); + pass("strip-debuginfo", "Strip debug info from functions and operations"); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 79a2b12d242..b01b8dba598 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,13 +39,13 @@ public: worklist.reserve(64); // Add all operations to the worklist. - fn.walk([&](Instruction *inst) { addToWorklist(inst); }); + fn.walk([&](Operation *op) { addToWorklist(op); }); } /// Perform the rewrites. void simplifyFunction(); - void addToWorklist(Instruction *op) { + void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -54,7 +54,7 @@ public: worklist.push_back(op); } - Instruction *popFromWorklist() { + Operation *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -66,7 +66,7 @@ public: /// If the specified operation is in the worklist, remove it. If not, this is /// a no-op. - void removeFromWorklist(Instruction *op) { + void removeFromWorklist(Operation *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -78,7 +78,7 @@ public: protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - Instruction *createOperation(const OperationState &state) override { + Operation *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); addToWorklist(result); return result; @@ -86,7 +86,7 @@ protected: // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. 
- void notifyOperationRemoved(Instruction *op) override { + void notifyOperationRemoved(Operation *op) override { addToWorklist(op->getOperands()); removeFromWorklist(op); } @@ -94,7 +94,7 @@ protected: // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(Instruction *op) override { + void notifyRootReplaced(Operation *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. for (auto &user : result->getUses()) @@ -102,15 +102,15 @@ protected: } private: - // Look over the provided operands for any defining instructions that should + // Look over the provided operands for any defining operations that should // be re-added to the worklist. This function should be called when an // operation is modified or removed, as it may trigger further // simplifications. template void addToWorklist(Operands &&operands) { for (Value *operand : operands) { // If the use count of this operand is now < 2, we re-add the defining - // instruction to the worklist. - // TODO(riverriddle) This is based on the fact that zero use instructions + // operation to the worklist. + // TODO(riverriddle) This is based on the fact that zero use operations // may be deleted, and that single use values often have more // canonicalization opportunities. if (!operand->use_empty() && @@ -131,13 +131,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from /// the function, even if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; + std::vector worklist; + DenseMap worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap, Instruction *> uniquedConstants; + DenseMap, Operation *> uniquedConstants; }; }; // end anonymous namespace @@ -199,7 +199,7 @@ void GreedyPatternRewriteDriver::simplifyFunction() { continue; } - // Check to see if any operands to the instruction is constant and whether + // Check to see if any operands to the operation is constant and whether // the operation knows how to constant fold itself. operandConstants.assign(op->getNumOperands(), Attribute()); for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 2760e8b8bd3..0f962657fad 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -123,10 +123,10 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { // Replaces all IV uses to its single iteration value. 
auto *iv = forOp.getInductionVar(); - Instruction *forInst = forOp.getOperation(); + Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - auto *mlFunc = forInst->getFunction(); + auto *mlFunc = op->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); @@ -134,28 +134,28 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { } else { AffineBound lb = forOp.getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); + FuncBuilder builder(op->getBlock(), Block::iterator(op)); if (lb.getMap() == builder.getDimIdentityMap()) { // No need of generating an affine.apply. iv->replaceAllUsesWith(lbOperands[0]); } else { auto affineApplyOp = builder.create( - forInst->getLoc(), lb.getMap(), lbOperands); + op->getLoc(), lb.getMap(), lbOperands); iv->replaceAllUsesWith(affineApplyOp); } } } - // Move the loop body instructions, except for terminator, to the loop's + // Move the loop body operations, except for terminator, to the loop's // containing block. - auto *block = forInst->getBlock(); + auto *block = op->getBlock(); forOp.getBody()->getOperations().back().erase(); - block->getOperations().splice(Block::iterator(forInst), + block->getOperations().splice(Block::iterator(op), forOp.getBody()->getOperations()); forOp.erase(); return success(); } -/// Promotes all single iteration for inst's in the Function, i.e., moves +/// Promotes all single iteration for op's in the Function, i.e., moves /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. @@ -163,16 +163,16 @@ void mlir::promoteSingleIterationLoops(Function *f) { [](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); } -/// Generates a 'affine.for' inst with the specified lower and upper bounds -/// while generating the right IV remappings for the shifted instructions. The -/// instruction blocks that go into the loop are specified in instGroupQueue +/// Generates a 'affine.for' op with the specified lower and upper bounds +/// while generating the right IV remappings for the shifted operations. The +/// operation blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of -/// the pair specifies the shift applied to that group of instructions; note +/// the pair specifies the shift applied to that group of operations; note /// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. static AffineForOp generateLoop(AffineMap lbMap, AffineMap ubMap, - const std::vector>> + const std::vector>> &instGroupQueue, unsigned offset, AffineForOp srcForInst, FuncBuilder *b) { SmallVector lbOperands(srcForInst.getLowerBoundOperands()); @@ -194,8 +194,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, it != e; ++it) { uint64_t shift = it->first; auto insts = it->second; - // All 'same shift' instructions get added with their operands being - // remapped to results of cloned instructions, and their IV used remapped. + // All 'same shift' operations get added with their operands being + // remapped to results of cloned operations, and their IV used remapped. // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. 
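A hypothetical scalar sketch of the skewing and IV remapping described above, for step 1 and two groups with shifts 0 and 1: uses of the IV in a shifted group are rewritten as remappedIV = newIV - shift, and the shifted schedule splits into a prologue, a steady-state loop and an epilogue.

#include <cstdio>

enum { N = 6 };

static void A(int i) { std::printf("A(%d) ", i); } // e.g. starts a transfer
static void B(int i) { std::printf("B(%d) ", i); } // e.g. computes on it

int main() {
  // Original order: for i in [0, N): A(i); B(i);
  // Skewed schedule, indexed by the new IV t:
  A(0);             // prologue: only the shift-0 group runs
  for (int t = 1; t < N; ++t) {
    A(t);           // shift 0: IV unchanged
    B(t - 1);       // shift 1: remappedIV = t - 1
  }
  B(N - 1);         // epilogue: drain the shift-1 group
  std::printf("\n");
  return 0;
}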
if (!srcIV->use_empty() && shift != 0) { @@ -208,9 +208,9 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, } else { operandMap.map(srcIV, loopChunkIV); } - for (auto *inst : insts) { - if (!inst->isa()) - bodyBuilder.clone(*inst, operandMap); + for (auto *op : insts) { + if (!op->isa()) + bodyBuilder.clone(*op, operandMap); } }; if (succeeded(promoteIfSingleIteration(loopChunk))) @@ -218,17 +218,17 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the instructions in the body of a 'affine.for' instruction with the -/// specified instruction-wise shifts. The shifts are with respect to the +/// Skew the operations in the body of a 'affine.for' operation with the +/// specified operation-wise shifts. The shifts are with respect to the /// original execution order, and are multiplied by the loop 'step' before being -/// applied. A shift of zero for each instruction will lead to no change. -// The skewing of instructions with respect to one another can be used for +/// applied. A shift of zero for each operation will lead to no change. +// The skewing of operations with respect to one another can be used for // example to allow overlap of asynchronous operations (such as DMA -// communication) with computation, or just relative shifting of instructions +// communication) with computation, or just relative shifting of operations // for better register reuse, locality or parallelism. As such, the shifts are -// typically expected to be at most of the order of the number of instructions. +// typically expected to be at most of the order of the number of operations. // This method should not be used as a substitute for loop distribution/fission. -// This method uses an algorithm// in time linear in the number of instructions +// This method uses an algorithm// in time linear in the number of operations // in the body of the for loop - (using the 'sweep line' paradigm). This method // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this @@ -267,14 +267,14 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, return success(); } - // An array of instruction groups sorted by shift amount; each group has all - // instructions with the same shift in the order in which they appear in the - // body of the 'affine.for' inst. - std::vector> sortedInstGroups(maxShift + 1); + // An array of operation groups sorted by shift amount; each group has all + // operations with the same shift in the order in which they appear in the + // body of the 'affine.for' op. + std::vector> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &inst : *forOp.getBody()) { + for (auto &op : *forOp.getBody()) { auto shift = shifts[pos++]; - sortedInstGroups[shift].push_back(&inst); + sortedInstGroups[shift].push_back(&op); } // Unless the shifts have a specific pattern (which actually would be the @@ -287,8 +287,8 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block - // of instructions is paired with its shift. - std::vector>> instGroupQueue; + // of operations is paired with its shift. 
+ std::vector>> instGroupQueue; auto origLbMap = forOp.getLowerBoundMap(); uint64_t lbShift = 0; @@ -302,14 +302,14 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, "Queue expected to be empty when the first block is found"); // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the - // loop needs to have all instructions in instQueue in that order. + // loop needs to have all operations in instQueue in that order. AffineForOp res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), instGroupQueue, 0, forOp, &b); - // Entire loop for the queued inst groups generated, empty it. + // Entire loop for the queued op groups generated, empty it. instGroupQueue.clear(); lbShift += tripCount * step; } else { @@ -325,11 +325,11 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, // Start of first interval. lbShift = d * step; } - // Augment the list of instructions that get into the current open interval. + // Augment the list of operations that get into the current open interval. instGroupQueue.push_back({d, sortedInstGroups[d]}); } - // Those instructions groups left in the queue now need to be processed (FIFO) + // Those operations groups left in the queue now need to be processed (FIFO) // and their loops completed. for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) { uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; @@ -341,7 +341,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, prologue = epilogue; } - // Erase the original for inst. + // Erase the original for op. forOp.erase(); if (unrollPrologueEpilogue && prologue) @@ -407,10 +407,10 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, return failure(); // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - Instruction *forInst = forOp.getOperation(); + Operation *op = forOp.getOperation(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { - FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto cleanupForInst = builder.clone(*forInst)->cast(); + FuncBuilder builder(op->getBlock(), ++Block::iterator(op)); + auto cleanupForInst = builder.clone(*op)->cast(); AffineMap cleanupMap; SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands, @@ -435,7 +435,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, // 'forOp'. FuncBuilder builder = forOp.getBodyBuilder(); - // Keep a pointer to the last non-terminator instruction in the original block + // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2); @@ -530,17 +530,17 @@ static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, BlockAndValueMapping map; map.map(oldIv, newForOp.getInductionVar()); FuncBuilder b = newForOp.getBodyBuilder(); - for (auto &inst : *forOp.getBody()) { + for (auto &op : *forOp.getBody()) { // Step over newForOp in case it is nested under forOp. 
- if (&inst == newForOp.getOperation()) { + if (&op == newForOp.getOperation()) { continue; } - if (inst.isa()) { + if (op.isa()) { continue; } - auto *instClone = b.clone(inst, map); + auto *instClone = b.clone(op, map); unsigned idx = 0; - for (auto r : inst.getResults()) { + for (auto r : op.getResults()) { // Since we do a forward pass over the body, we iteratively augment // the `map` with everything we clone. map.map(r, instClone->getResult(idx++)); @@ -567,8 +567,8 @@ stripmineSink(AffineForOp forOp, uint64_t factor, auto scaledStep = originalStep * factor; forOp.setStep(scaledStep); - auto *forInst = forOp.getOperation(); - FuncBuilder b(forInst->getBlock(), ++Block::iterator(forInst)); + auto *op = forOp.getOperation(); + FuncBuilder b(op->getBlock(), ++Block::iterator(op)); // Lower-bound map creation. auto lbMap = forOp.getLowerBoundMap(); @@ -588,11 +588,11 @@ stripmineSink(AffineForOp forOp, uint64_t factor, auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, ubOperands, ubMap, originalStep); cloneLoopBodyInto(t, forOp.getInductionVar(), newForOp); - // Remove all instructions from `t` except `newForOp`. + // Remove all operations from `t` except `newForOp`. auto rit = ++newForOp.getOperation()->getReverseIterator(); auto re = t.getBody()->rend(); - for (auto &inst : llvm::make_early_inc_range(llvm::make_range(rit, re))) { - inst.erase(); + for (auto &op : llvm::make_early_inc_range(llvm::make_range(rit, re))) { + op.erase(); } innerLoops.push_back(newForOp); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index b5225d08827..422d6b136ab 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,7 +37,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(Instruction &op) { +static bool isMemRefDereferencingOp(Operation &op) { if (op.isa() || op.isa() || op.isa() || op.isa()) return true; @@ -48,8 +48,8 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - Instruction *domInstFilter, - Instruction *postDomInstFilter) { + Operation *domInstFilter, + Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -76,7 +76,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, llvm::make_unique(postDomInstFilter->getFunction()); // The ops where memref replacement succeeds are replaced with new ones. - SmallVector opsToErase; + SmallVector opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) { @@ -115,7 +115,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, }; unsigned memRefOperandPos = getMemRefOperandPos(); - // Construct the new operation instruction using this memref. + // Construct the new operation using this memref. 
OperationState state(opInst->getContext(), opInst->getLoc(), opInst->getName()); state.setOperandListToResizable(opInst->hasResizableOperandsList()); @@ -192,9 +192,9 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, return true; } -/// Given an operation instruction, inserts one or more single result affine +/// Given an operation, inserts one or more single result affine /// apply operations, results of which are exclusively used by this operation -/// instruction. The operands of these newly created affine apply ops are +/// operation. The operands of these newly created affine apply ops are /// guaranteed to be loop iterators or terminal symbols of a function. /// /// Before @@ -221,7 +221,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, /// uses besides this opInst; otherwise returns the list of affine.apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( - Instruction *opInst, SmallVectorImpl *sliceOps) { + Operation *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); @@ -233,13 +233,13 @@ void mlir::createAffineComputationSlice( } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector affineApplyOps; + SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) return; - // Check if all uses of the affine apply op's lie only in this op inst, in + // Check if all uses of the affine apply op's lie only in this op op, in // which case there would be nothing to do. bool localized = true; for (auto *op : affineApplyOps) { @@ -291,7 +291,7 @@ void mlir::createAffineComputationSlice( } void mlir::remapFunctionAttrs( - Instruction &op, const DenseMap &remappingTable) { + Operation &op, const DenseMap &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. @@ -310,9 +310,8 @@ void mlir::remapFunctionAttrs( void mlir::remapFunctionAttrs( Function &fn, const DenseMap &remappingTable) { - // Look at all instructions in a Function. - fn.walk( - [&](Instruction *inst) { remapFunctionAttrs(*inst, remappingTable); }); + // Look at all operations in a Function. + fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); }); } void mlir::remapFunctionAttrs( diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 5590dbad7f1..cf0684ef90d 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -77,7 +77,7 @@ static llvm::cl::opt clTestNormalizeMaps( llvm::cl::desc( "Enable testing the normalization of AffineAffineApplyOp " "where each AffineAffineApplyOp in the composition is a single output " - "instruction."), + "operation."), llvm::cl::cat(clOptionsCategory)); namespace { @@ -104,16 +104,16 @@ void VectorizerTestPass::testVectorShapeRatio() { clTestVectorShapeRatio.end()); auto subVectorType = VectorType::get(shape, FloatType::getF32(f->getContext())); - // Only filter instructions that operate on a strict super-vector and have one + // Only filter operations that operate on a strict super-vector and have one // return. This makes testing easier. 
- auto filter = [subVectorType](Instruction &inst) { + auto filter = [subVectorType](Operation &op) { assert(subVectorType.getElementType() == FloatType::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnSuperVectors(inst, subVectorType)) { + if (!matcher::operatesOnSuperVectors(op, subVectorType)) { return false; } - if (inst.getNumResults() != 1) { + if (op.getNumResults() != 1) { return false; } return true; @@ -138,10 +138,10 @@ void VectorizerTestPass::testVectorShapeRatio() { } } -static std::string toString(Instruction *inst) { +static std::string toString(Operation *op) { std::string res; llvm::raw_string_ostream os(res); - inst->print(os); + op->print(os); return res; } @@ -150,9 +150,9 @@ static NestedPattern patternTestSlicingOps() { constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; using matcher::Op; - // Match all OpInstructions with the kTestSlicingOpName name. - auto filter = [](Instruction &inst) { - return inst.getName().getStringRef() == kTestSlicingOpName; + // Match all operations with the kTestSlicingOpName name. + auto filter = [](Operation &op) { + return op.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); } @@ -163,7 +163,7 @@ void VectorizerTestPass::testBackwardSlicing() { SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector backwardSlice; + SetVector backwardSlice; getBackwardSlice(m.getMatchedOperation(), &backwardSlice); auto strs = map(toString, backwardSlice); outs() << "\nmatched: " << *m.getMatchedOperation() @@ -179,7 +179,7 @@ void VectorizerTestPass::testForwardSlicing() { SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector forwardSlice; + SetVector forwardSlice; getForwardSlice(m.getMatchedOperation(), &forwardSlice); auto strs = map(toString, forwardSlice); outs() << "\nmatched: " << *m.getMatchedOperation() @@ -196,7 +196,7 @@ void VectorizerTestPass::testSlicing() { SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector staticSlice = getSlice(m.getMatchedOperation()); + SetVector staticSlice = getSlice(m.getMatchedOperation()); auto strs = map(toString, staticSlice); outs() << "\nmatched: " << *m.getMatchedOperation() << " static slice: "; for (const auto &s : strs) { @@ -205,8 +205,8 @@ void VectorizerTestPass::testSlicing() { } } -static bool customOpWithAffineMapAttribute(Instruction &inst) { - return inst.getName().getStringRef() == +static bool customOpWithAffineMapAttribute(Operation &op) { + return op.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -233,12 +233,10 @@ void VectorizerTestPass::testComposeMaps() { simplifyAffineMap(res).print(outs() << "\nComposed map: "); } -static bool affineApplyOp(Instruction &inst) { - return inst.isa(); -} +static bool affineApplyOp(Operation &op) { return op.isa(); } -static bool singleResultAffineApplyOpWithoutUses(Instruction &inst) { - auto app = inst.dyn_cast(); +static bool singleResultAffineApplyOpWithoutUses(Operation &op) { + auto app = op.dyn_cast(); return app && app.use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 8a7a7a6dbba..98e4053c633 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -166,7 +166,7 @@ using namespace mlir; /// references along fastest varying dimensions and loops with recursive nested /// patterns capturing 
imperfectly-nested loop nests; the SLP vectorizer, on /// the other hand, performs flat pattern matching inside a single unrolled loop -/// body and stitches together pieces of load and store instructions into full +/// body and stitches together pieces of load and store operations into full /// 1-D vectors. We envision that the SLP vectorizer is a good way to capture /// innermost loop, control-flow dependent patterns that super-vectorization may /// not be able to capture easily. In other words, super-vectorization does not @@ -662,13 +662,12 @@ namespace { struct VectorizationStrategy { SmallVector vectorSizes; - DenseMap loopToVectorDim; + DenseMap loopToVectorDim; }; } // end anonymous namespace -static void vectorizeLoopIfProfitable(Instruction *loop, - unsigned depthInPattern, +static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { assert(patternDepth > depthInPattern && @@ -716,23 +715,23 @@ static LogicalResult analyzeProfitability(ArrayRef matches, namespace { struct VectorizationState { - /// Adds an entry of pre/post vectorization instructions in the state. - void registerReplacement(Instruction *key, Instruction *value); + /// Adds an entry of pre/post vectorization operations in the state. + void registerReplacement(Operation *key, Operation *value); /// When the current vectorization pattern is successful, this erases the - /// instructions that were marked for erasure in the proper order and resets + /// operations that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original Instruction that have been vectorized. + // In-order tracking of original Operation that have been vectorized. // Erase in reverse order. - SmallVector toErase; - // Set of Instruction that have been vectorized (the values in the + SmallVector toErase; + // Set of Operation that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in - // particular to filter the instructions that have already been vectorized by + // particular to filter the operations that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet vectorizedSet; - // Map of old scalar Instruction to new vectorized Instruction. - DenseMap vectorizationMap; + DenseSet vectorizedSet; + // Map of old scalar Operation to new vectorized Operation. + DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -742,17 +741,16 @@ struct VectorizationState { // operations that have been vectorized. They can be retrieved from // `vectorizationMap` but it is convenient to keep track of them in a separate // data structure. - DenseSet roots; - // Terminal instructions for the worklist in the vectorizeNonTerminals + DenseSet roots; + // Terminal operations for the worklist in the vectorizeNonTerminals // function. They consist of the subset of store operations that have been // vectorized. They can be retrieved from `vectorizationMap` but it is // convenient to keep track of them in a separate data structure. Since they // do not necessarily belong to use-def chains starting from loads (e.g // storing a constant), we need to handle them in a post-pass. 
- DenseSet terminals; - // Checks that the type of `inst` is StoreOp and adds it to the terminals - // set. - void registerTerminal(Instruction *inst); + DenseSet terminals; + // Checks that the type of `op` is StoreOp and adds it to the terminals set. + void registerTerminal(Operation *op); private: void registerReplacement(Value *key, Value *value); @@ -760,8 +758,7 @@ private: } // end namespace -void VectorizationState::registerReplacement(Instruction *key, - Instruction *value) { +void VectorizationState::registerReplacement(Operation *key, Operation *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -780,19 +777,19 @@ void VectorizationState::registerReplacement(Instruction *key, } } -void VectorizationState::registerTerminal(Instruction *inst) { - assert(inst->isa() && "terminal must be a StoreOp"); - assert(terminals.count(inst) == 0 && +void VectorizationState::registerTerminal(Operation *op) { + assert(op->isa() && "terminal must be a StoreOp"); + assert(terminals.count(op) == 0 && "terminal was already inserted previously"); - terminals.insert(inst); + terminals.insert(op); } void VectorizationState::finishVectorizationPattern() { while (!toErase.empty()) { - auto *inst = toErase.pop_back_val(); + auto *op = toErase.pop_back_val(); LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: "); - LLVM_DEBUG(inst->print(dbgs())); - inst->erase(); + LLVM_DEBUG(op->print(dbgs())); + op->erase(); } } @@ -857,13 +854,13 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, using namespace functional; loop.setStep(step); - FilterFunctionType notVectorizedThisPattern = [state](Instruction &inst) { - if (!matcher::isLoadOrStore(inst)) { + FilterFunctionType notVectorizedThisPattern = [state](Operation &op) { + if (!matcher::isLoadOrStore(op)) { return false; } - return state->vectorizationMap.count(&inst) == 0 && - state->vectorizedSet.count(&inst) == 0 && - state->roots.count(&inst) == 0 && state->terminals.count(&inst) == 0; + return state->vectorizationMap.count(&op) == 0 && + state->vectorizedSet.count(&op) == 0 && + state->roots.count(&op) == 0 && state->terminals.count(&op) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector loadAndStoresMatches; @@ -891,8 +888,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, /// we can build a cost model and a search procedure. static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](Instruction &forInst) { - auto loop = forInst.cast(); + return [fastestVaryingMemRefDimension](Operation &forOp) { + auto loop = forOp.cast(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -943,14 +940,13 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. 
-static Value *vectorizeConstant(Instruction *inst, ConstantOp constant, - Type type) { +static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { if (!type || !type.isa() || !VectorType::isValidElementType(constant.getType())) { return nullptr; } - FuncBuilder b(inst); - Location loc = inst->getLoc(); + FuncBuilder b(op); + Location loc = op->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); auto *constantOpInst = constant.getOperation(); @@ -962,10 +958,10 @@ static Value *vectorizeConstant(Instruction *inst, ConstantOp constant, return b.createOperation(state)->getResult(0); } -/// Tries to vectorize a given operand `op` of Instruction `inst` during +/// Tries to vectorize a given operand `op` of Operation `op` during /// def-chain propagation or during terminal vectorization, by applying the /// following logic: -/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized +/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized /// useby -def propagation), `op` is already in the proper vector form; /// 2. otherwise, the `op` may be in some other vector form that fails to /// vectorize atm (i.e. broadcasting required), returns nullptr to indicate @@ -983,7 +979,7 @@ static Value *vectorizeConstant(Instruction *inst, ConstantOp constant, /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static Value *vectorizeOperand(Value *operand, Instruction *inst, +static Value *vectorizeOperand(Value *operand, Operation *op, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); @@ -1011,7 +1007,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, // 3. vectorize constant. if (auto constant = operand->getDefiningOp()->dyn_cast()) { return vectorizeConstant( - inst, constant, + op, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); } // 4. currently non-vectorizable. @@ -1020,7 +1016,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, return nullptr; }; -/// Encodes Instruction-specific behavior for vectorization. In general we +/// Encodes Operation-specific behavior for vectorization. In general we /// assume that all operands of an op must be vectorized but this is not always /// true. In the future, it would be nice to have a trait that describes how a /// particular operation vectorizes. For now we implement the case distinction @@ -1029,8 +1025,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. -static Instruction *vectorizeOneInstruction(Instruction *opInst, - VectorizationState *state) { +static Operation *vectorizeOneOperation(Operation *opInst, + VectorizationState *state) { // Sanity checks. assert(!opInst->isa() && "all loads must have already been fully vectorized independently"); @@ -1079,9 +1075,8 @@ static Instruction *vectorizeOneInstruction(Instruction *opInst, // Create a clone of the op with the proper operands and return types. // TODO(ntv): The following assumes there is always an op with a fixed // name that works both in scalar mode and vector mode. 
- // TODO(ntv): Is it worth considering an Instruction.clone operation - // which changes the type so we can promote an Instruction with less - // boilerplate? + // TODO(ntv): Is it worth considering an Operation.clone operation which + // changes the type so we can promote an Operation with less boilerplate? FuncBuilder b(opInst); OperationState newOp(b.getContext(), opInst->getLoc(), opInst->getName().getStringRef(), vectorOperands, @@ -1100,31 +1095,31 @@ static Instruction *vectorizeOneInstruction(Instruction *opInst, /// replacementMap. If any such replacement is missing, vectorization fails. static LogicalResult vectorizeNonTerminals(VectorizationState *state) { // 1. create initial worklist with the uses of the roots. - SetVector worklist; + SetVector worklist; // Note: state->roots have already been vectorized and must not be vectorized - // again. This fits `getForwardSlice` which does not insert `inst` in the + // again. This fits `getForwardSlice` which does not insert `op` in the // result. // Note: we have to exclude terminals because some of their defs may not be // nested under the vectorization pattern (e.g. constants defined in an // encompassing scope). // TODO(ntv): Use a backward slice for terminals, avoid special casing and // merge implementations. - for (auto *inst : state->roots) { - getForwardSlice(inst, &worklist, [state](Instruction *inst) { - return state->terminals.count(inst) == 0; // propagate if not terminal + for (auto *op : state->roots) { + getForwardSlice(op, &worklist, [state](Operation *op) { + return state->terminals.count(op) == 0; // propagate if not terminal }); } // We merged multiple slices, topological order may not hold anymore. worklist = topologicalSort(worklist); for (unsigned i = 0; i < worklist.size(); ++i) { - auto *inst = worklist[i]; + auto *op = worklist[i]; LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: "); - LLVM_DEBUG(inst->print(dbgs())); + LLVM_DEBUG(op->print(dbgs())); - // Create vector form of the instruction. - // Insert it just before inst, on success register inst as replaced. - auto *vectorizedInst = vectorizeOneInstruction(inst, state); + // Create vector form of the operation. + // Insert it just before op, on success register op as replaced. + auto *vectorizedInst = vectorizeOneOperation(op, state); if (!vectorizedInst) { return failure(); } @@ -1133,7 +1128,7 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(inst, vectorizedInst); + state->registerReplacement(op, vectorizedInst); } return success(); } @@ -1193,9 +1188,8 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, return guard.failure(); } - // 2. Vectorize operations reached by use-def chains from root - // except the terminals (store instructions) that need to be - // post-processed separately. + // 2. Vectorize operations reached by use-def chains from root except the + // terminals (store operations) that need to be post-processed separately. // TODO(ntv): add more as we expand. if (failed(vectorizeNonTerminals(&state))) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals"); @@ -1208,8 +1202,8 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, // encompassing scope). // TODO(ntv): Use a backward slice for terminals, avoid special casing and // merge implementations. 
- for (auto *inst : state.terminals) { - if (!vectorizeOneInstruction(inst, &state)) { // nullptr == failure + for (auto *op : state.terminals) { + if (!vectorizeOneOperation(op, &state)) { // nullptr == failure LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals"); return guard.failure(); } -- cgit v1.2.3 From 9d9675fc8fa96e78efa17dcc2d6fcc3e773f7a5f Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Thu, 28 Mar 2019 14:54:49 -0700 Subject: Remove overly conservative check in LoopFusion pass (enables fusion in tutorial example). PiperOrigin-RevId: 240859227 --- mlir/lib/Transforms/LoopFusion.cpp | 10 +---- mlir/test/Transforms/loop-fusion.mlir | 69 ++++++++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 17 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 80308ea6a40..c35b75ff5ed 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1275,7 +1275,8 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. SmallVector dstStoreOps; dstNode->getStoreOpsForMemref(memref, &dstStoreOps); - assert(dstStoreOps.size() == 1); + // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for + // each store op in 'dstStoreOps'). auto *dstStoreOpInst = dstStoreOps[0]; MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) { @@ -1886,13 +1887,6 @@ public: if (srcNode->stores.size() != 1) continue; - // Skip 'srcNode' if it has in edges on 'memref'. - // TODO(andydavis) Track dependence type with edges, and just check - // for WAW dependence edge here. Note that this check is overly - // conservative and will be removed in the future. - if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) - continue; - // Skip if 'srcNode' writes to any live in or escaping memrefs, // and cannot be fused. bool writesToLiveInOrOut = diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 4d21d006ff1..dd3af0664f4 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -342,8 +342,10 @@ func @should_not_fuse_would_create_cycle() { // ----- -// CHECK-LABEL: func @should_not_fuse_across_waw_dep() { -func @should_not_fuse_across_waw_dep() { +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) + +// CHECK-LABEL: func @should_fuse_producer_consumer() { +func @should_fuse_producer_consumer() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -356,15 +358,20 @@ func @should_not_fuse_across_waw_dep() { affine.for %i2 = 0 to 10 { %v1 = load %m[%i2] : memref<10xf32> } - // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 + // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and + // %i1, but OK to fuse %i1 into %i2. + // TODO(andydavis) When the fusion pass is run to a fixed-point, it should + // fuse all three of these loop nests. 
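+  // For now only %i1 is fused into %i2: its store is privatized into the
+  // single-element buffer %0 below (indexed through [[MAP0]]), while the
+  // unfused %i0 loop keeps writing the full original memref %1.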
+ // CHECK: %0 = alloc() : memref<1xf32> + // CHECK: %1 = alloc() : memref<10xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> + // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -2289,3 +2296,47 @@ func @should_fuse_with_slice_union() { // CHECK-NEXT: return return } + +// ----- + +func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) { + affine.for %i2 = 0 to 1024 { + affine.for %i3 = 0 to 1024 { + %0 = load %arg3[%i2, %i3] : memref<1024x1024xf32> + %1 = load %arg2[%i2, %i3] : memref<1024x1024xf32> + %2 = addf %1, %0 : f32 + store %2, %arg2[%i2, %i3] : memref<1024x1024xf32> + } + } + affine.for %i4 = 0 to 1024 { + affine.for %i5 = 0 to 1024 { + affine.for %i6 = 0 to 1024 { + %3 = load %arg1[%i6, %i5] : memref<1024x1024xf32> + %4 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %5 = mulf %4, %3 : f32 + %6 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %7 = addf %6, %5 : f32 + store %7, %arg2[%i4, %i5] : memref<1024x1024xf32> + } + } + } + // Should fuse elementwise add loop at loop depth 2, above loop-carried + // dependence between load/store on '%arg2', carried on reduction loop %i6. + // CHECK: affine.for %i0 = 0 to 1024 { + // CHECK-NEXT: affine.for %i1 = 0 to 1024 { + // CHECK-NEXT: %0 = load %arg3[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %2 = addf %1, %0 : f32 + // CHECK-NEXT: store %2, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i2 = 0 to 1024 { + // CHECK-NEXT: %3 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %5 = mulf %4, %3 : f32 + // CHECK-NEXT: %6 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %7 = addf %6, %5 : f32 + // CHECK-NEXT: store %7, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + return +} -- cgit v1.2.3 From 9d30b36aaf8eb80d3744dd881057ae7618f8f08d Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 29 Mar 2019 08:06:25 -0700 Subject: Enable input-reuse fusion to search function arguments for fusion candidates (takes care of a TODO, enables another tutorial test case). PiperOrigin-RevId: 240979894 --- mlir/lib/Transforms/LoopFusion.cpp | 116 ++++++++++++++++++++++++---------- mlir/test/Transforms/loop-fusion.mlir | 70 ++++++++++++++++++++ 2 files changed, 152 insertions(+), 34 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c35b75ff5ed..900c45fce12 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -266,6 +266,14 @@ public: return &it->second; } + // Returns the graph node for 'forOp'. 
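+  // Note: this is a linear scan over all graph nodes. With only a handful of
+  // top-level loop nests per function that is cheap, but a reverse map from
+  // Operation* to node id could be added if it ever becomes a bottleneck.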
+ Node *getForOpNode(AffineForOp forOp) { + for (auto &idAndNode : nodes) + if (idAndNode.second.op == forOp.getOperation()) + return &idAndNode.second; + return nullptr; + } + // Adds a node with 'op' to the graph and returns its unique identifier. unsigned addNode(Operation *op) { Node node(nextNodeId++, op); @@ -2096,17 +2104,79 @@ public: } } - // Searches the graph from 'dstNode' looking for a fusion candidate sibling - // node which shares no dependences with 'dstNode' but which loads from the - // same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns - // false otherwise. + // Searches function argument uses and the graph from 'dstNode' looking for a + // fusion candidate sibling node which shares no dependences with 'dstNode' + // but which loads from the same memref. Returns true and sets + // 'idAndMemrefToFuse' on success. Returns false otherwise. bool findSiblingNodeToFuse(Node *dstNode, DenseSet *visitedSibNodeIds, std::pair *idAndMemrefToFuse) { - // TODO(andydavis) Currently we discover siblings by following edges - // through an intermediate src node. We should also consider siblings - // which load from the same memref, but which do not necessarily share - // a src node parent (e.g. loading from a memref which is a function arg). + // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse + // on 'memref'. + auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) { + // Skip if 'outEdge' is not a read-after-write dependence. + // TODO(andydavis) Remove restrict to single load op restriction. + if (sibNode->getLoadOpCount(memref) != 1) + return false; + // Skip if there exists a path of dependent edges between + // 'sibNode' and 'dstNode'. + if (mdg->hasDependencePath(sibNode->id, dstNode->id) || + mdg->hasDependencePath(dstNode->id, sibNode->id)) + return false; + // Skip sib node if it loads to (and stores from) the same memref on + // which it also has an input dependence edge. + DenseSet loadAndStoreMemrefSet; + sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); + if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { + return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; + })) + return false; + + // Check that all stores are to the same memref. + DenseSet storeMemrefs; + for (auto *storeOpInst : sibNode->stores) { + storeMemrefs.insert(storeOpInst->cast().getMemRef()); + } + if (storeMemrefs.size() != 1) + return false; + return true; + }; + + // Search for siblings which load the same memref function argument. + auto *fn = dstNode->op->getFunction(); + for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { + for (auto &use : fn->getArgument(i)->getUses()) { + if (auto loadOp = use.getOwner()->dyn_cast()) { + // Gather loops surrounding 'use'. + SmallVector loops; + getLoopIVs(*use.getOwner(), &loops); + // Skip 'use' if it is not within a loop nest. + if (loops.empty()) + continue; + Node *sibNode = mdg->getForOpNode(loops[0]); + assert(sibNode != nullptr); + // Skip 'use' if it not a sibling to 'dstNode'. + if (sibNode->id == dstNode->id) + continue; + // Skip 'use' if it has been visited. + if (visitedSibNodeIds->count(sibNode->id) > 0) + continue; + // Skip 'use' if it does not load from the same memref as 'dstNode'. + auto *memref = loadOp.getMemRef(); + if (dstNode->getLoadOpCount(memref) == 0) + continue; + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. 
+ if (canFuseWithSibNode(sibNode, memref)) { + visitedSibNodeIds->insert(sibNode->id); + idAndMemrefToFuse->first = sibNode->id; + idAndMemrefToFuse->second = memref; + return true; + } + } + } + } + + // Search for siblings by following edges through an intermediate src node. // Collect candidate 'dstNode' input edges in 'inEdges'. SmallVector inEdges; mdg->forEachMemRefInputEdge( @@ -2133,33 +2203,11 @@ public: auto *sibNode = mdg->getNode(sibNodeId); if (!sibNode->op->isa()) return; - // Skip if 'outEdge' is not a read-after-write dependence. - // TODO(andydavis) Remove restrict to single load op restriction. - if (sibNode->getLoadOpCount(inEdge.value) != 1) - return; - // Skip if there exists a path of dependent edges between - // 'sibNode' and 'dstNode'. - if (mdg->hasDependencePath(sibNodeId, dstNode->id) || - mdg->hasDependencePath(dstNode->id, sibNodeId)) - return; - // Skip sib node if it loads to (and stores from) the same memref on - // which it also has an input dependence edge. - DenseSet loadAndStoreMemrefSet; - sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); - if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { - return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > - 0; - })) - return; - // Check that all stores are to the same memref. - DenseSet storeMemrefs; - for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(storeOpInst->cast().getMemRef()); + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, outEdge.value)) { + // Add candidate 'outEdge' to sibling node. + outEdges.push_back(outEdge); } - if (storeMemrefs.size() != 1) - return; - // Add candidate 'outEdge' to sibling node. - outEdges.push_back(outEdge); }); // Add first candidate if any were returned. diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index dd3af0664f4..7da36dd9edb 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2340,3 +2340,73 @@ func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024x // CHECK-NEXT: } return } + +// ----- + +func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 1024 { + affine.for %i1 = 0 to 1024 { + store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32> + } + } + affine.for %i2 = 0 to 1024 { + affine.for %i3 = 0 to 1024 { + store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32> + } + } + affine.for %i4 = 0 to 1024 { + affine.for %i5 = 0 to 1024 { + affine.for %i6 = 0 to 1024 { + %0 = load %arg1[%i6, %i5] : memref<1024x1024xf32> + %1 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %2 = mulf %1, %0 : f32 + %3 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %4 = addf %3, %2 : f32 + store %4, %arg2[%i4, %i5] : memref<1024x1024xf32> + } + } + } + affine.for %i7 = 0 to 1024 { + affine.for %i8 = 0 to 1024 { + affine.for %i9 = 0 to 1024 { + %5 = load %arg1[%i9, %i8] : memref<1024x1024xf32> + %6 = load %arg0[%i7, %i9] : memref<1024x1024xf32> + %7 = mulf %6, %5 : f32 + %8 = load %arg4[%i7, %i8] : memref<1024x1024xf32> + %9 = addf %8, %7 : f32 + store %9, %arg4[%i7, %i8] : memref<1024x1024xf32> + } + } + } + + // Should fuse MM intialization loops into their consumers, then fuse the + // two matmul loops together for input reuse on '%arg0/%arg1'. 
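+  // In the expected output below, each output element is zero-initialized
+  // immediately before its reduction loop, and both matmuls end up under the
+  // same outermost row loop %i0, which is what enables the input reuse.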
+ + // CHECK: affine.for %i0 = 0 to 1024 { + // CHECK-NEXT: affine.for %i1 = 0 to 1024 { + // CHECK-NEXT: store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i2 = 0 to 1024 { + // CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %2 = mulf %1, %0 : f32 + // CHECK-NEXT: %3 = load %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = addf %3, %2 : f32 + // CHECK-NEXT: store %4, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %i3 = 0 to 1024 { + // CHECK-NEXT: store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i4 = 0 to 1024 { + // CHECK-NEXT: %5 = load %arg1[%i4, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = load %arg0[%i0, %i4] : memref<1024x1024xf32> + // CHECK-NEXT: %7 = mulf %6, %5 : f32 + // CHECK-NEXT: %8 = load %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %9 = addf %8, %7 : f32 + // CHECK-NEXT: store %9, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + + return +} -- cgit v1.2.3 From 7c1fc9e795e479accb9614cf2a9686b5eae855fa Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 2 Apr 2019 06:37:40 -0700 Subject: Enable producer-consumer fusion for liveout memrefs if consumer read region matches producer write region. -- PiperOrigin-RevId: 241517207 --- mlir/lib/Transforms/LoopFusion.cpp | 40 +++++++++------- mlir/test/Transforms/loop-fusion.mlir | 89 +++++++++++++++++++++++++++++++++-- 2 files changed, 107 insertions(+), 22 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 900c45fce12..2ed159ca647 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1244,11 +1244,10 @@ static uint64_t getSliceIterationCount( // Checks if node 'srcId' (which writes to a live out memref), can be safely // fused into node 'dstId'. Returns true if the following conditions are met: -// *) 'srcNode' writes only writes to live out 'memref'. +// *) 'srcNode' only writes to live out 'memref'. // *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId'). -// *) 'dstNode' does write to 'memref'. -// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write -// region to 'memref'. +// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's +// write region to 'memref'. // TODO(andydavis) Generalize this to handle more live in/out cases. static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, Value *memref, @@ -1256,13 +1255,17 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, auto *srcNode = mdg->getNode(srcId); auto *dstNode = mdg->getNode(dstId); + // Gather all memrefs from 'srcNode' store ops. + DenseSet storeMemrefs; + for (auto *storeOpInst : srcNode->stores) { + storeMemrefs.insert(storeOpInst->cast().getMemRef()); + } // Return false if any of the following are true: // *) 'srcNode' writes to a live in/out memref other than 'memref'. // *) 'srcNode' has more than one output edge on 'memref'. - // *) 'dstNode' does not write to 'memref'. - if (srcNode->getStoreOpCount(memref) != 1 || - mdg->getOutEdgeCount(srcNode->id, memref) != 1 || - dstNode->getStoreOpCount(memref) == 0) + // Check that all stores are to the same memref. 
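+  // If 'srcNode' also stored to some other memref, removing it after fusion
+  // would silently drop those writes.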
+ if (storeMemrefs.size() != 1 || + mdg->getOutEdgeCount(srcNode->id, memref) != 1) return false; // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. auto *srcStoreOpInst = srcNode->stores.front(); @@ -1280,23 +1283,26 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, if (!srcNumElements.hasValue()) return false; - // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. - SmallVector dstStoreOps; - dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'. // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for // each store op in 'dstStoreOps'). - auto *dstStoreOpInst = dstStoreOps[0]; - MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); - if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) { + SmallVector dstStoreOps; + dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + SmallVector dstLoadOps; + dstNode->getLoadOpsForMemref(memref, &dstLoadOps); + + auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0]; + MemRefRegion dstRegion(dstOpInst->getLoc()); + if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for dest operation\n."); return false; } SmallVector dstShape; - // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'. - // by 'dstStoreOpInst' at depth 'dstLoopDepth'. + // Query 'dstRegion' for 'dstShape' and 'dstNumElements'. + // by 'dstOpInst' at depth 'dstLoopDepth'. Optional dstNumElements = - dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape); + dstRegion.getConstantBoundingSizeAndShape(&dstShape); if (!dstNumElements.hasValue()) return false; diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 7da36dd9edb..2fe10b750a6 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1261,15 +1261,18 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { affine.for %i0 = 0 to 10 { store %cf7, %arg0[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + affine.for %i1 = 0 to 9 { %v0 = load %arg0[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion - // because it writes to memref argument '%arg0'. + // because it writes to memref argument '%arg0', and its read region + // does not cover its write region (so fusion would shrink the write region + // in the fused loop nest, so complete live out data region would not + // be written). // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 9 { // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1278,6 +1281,29 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { // ----- +// CHECK-LABEL: func @should_fuse_live_out_arg(%arg0: memref<10xf32>) { +func @should_fuse_live_out_arg(%arg0: memref<10xf32>) { + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + store %cf7, %arg0[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = load %arg0[%i1] : memref<10xf32> + } + // The read/write regions for memref '%arg0' are the same for both + // loops, so they should fuse. 
+ + // CHECK: affine.for %i0 = 0 to 10 { + // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + // CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32> func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 @@ -1285,7 +1311,7 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> { affine.for %i0 = 0 to 10 { store %cf7, %m[%i0] : memref<10xf32> } - affine.for %i1 = 0 to 10 { + affine.for %i1 = 0 to 9 { %v0 = load %m[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion @@ -1294,7 +1320,7 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> { // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %i1 = 0 to 10 { + // CHECK-NEXT: affine.for %i1 = 0 to 9 { // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> @@ -2410,3 +2436,56 @@ func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32 return } + +// ----- + +func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) { + affine.for %i0 = 0 to 1024 { + affine.for %i1 = 0 to 1024 { + affine.for %i2 = 0 to 1024 { + %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + %2 = mulf %1, %0 : f32 + %3 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + %4 = addf %3, %2 : f32 + store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> + } + } + } + affine.for %i3 = 0 to 1024 { + affine.for %i4 = 0 to 1024 { + affine.for %i5 = 0 to 1024 { + %5 = load %arg3[%i5, %i4] : memref<1024x1024xf32> + %6 = load %arg2[%i3, %i5] : memref<1024x1024xf32> + %7 = mulf %6, %5 : f32 + %8 = load %arg4[%i3, %i4] : memref<1024x1024xf32> + %9 = addf %8, %7 : f32 + store %9, %arg4[%i3, %i4] : memref<1024x1024xf32> + } + } + } + + // CHECK: affine.for %i0 = 0 to 1024 { + // CHECK-NEXT: affine.for %i1 = 0 to 1024 { + // CHECK-NEXT: affine.for %i2 = 0 to 1024 { + // CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %2 = mulf %1, %0 : f32 + // CHECK-NEXT: %3 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = addf %3, %2 : f32 + // CHECK-NEXT: store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %i3 = 0 to 1024 { + // CHECK-NEXT: affine.for %i4 = 0 to 1024 { + // CHECK-NEXT: %5 = load %arg3[%i4, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = load %arg2[%i0, %i4] : memref<1024x1024xf32> + // CHECK-NEXT: %7 = mulf %6, %5 : f32 + // CHECK-NEXT: %8 = load %arg4[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %9 = addf %8, %7 : f32 + // CHECK-NEXT: store %9, %arg4[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + return +} -- cgit v1.2.3 From 0cd589c337eae12454452356f96752b229b373f9 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Thu, 4 Apr 2019 15:19:17 -0700 Subject: Create a LoopUtil function to return perfectly nested loop set -- PiperOrigin-RevId: 242019230 --- mlir/include/mlir/Transforms/LoopUtils.h | 9 +++++++++ mlir/lib/Transforms/LoopFusion.cpp | 11 +---------- mlir/lib/Transforms/LoopTiling.cpp | 6 +----- 
mlir/lib/Transforms/Utils/LoopUtils.cpp | 16 ++++++++++++++++ 4 files changed, 27 insertions(+), 15 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index f1e7b503769..2aecdceff7f 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -37,14 +37,23 @@ class Value; /// Unrolls this for operation completely if the trip count is known to be /// constant. Returns failure otherwise. LogicalResult loopUnrollFull(AffineForOp forOp); + /// Unrolls this for operation by the specified unroll factor. Returns failure /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor); + /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); +/// Get perfectly nested sequence of loops starting at root of loop nest +/// (the first op being another AffineFor, and the second op - a terminator). +/// A loop is perfectly nested iff: the first op in the loop's body is another +/// AffineForOp, and the second op is a terminator). +void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, + AffineForOp root); + /// Unrolls and jams this loop by the specified factor. Returns success if the /// loop is successfully unroll-jammed. LogicalResult loopUnrollJamByFactor(AffineForOp forOp, diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2ed159ca647..39ed5a100a1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1062,18 +1062,9 @@ computeLoopInterchangePermutation(ArrayRef ops, // pushing loop carried dependence to a greater depth in the loop nest. static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(node->op->isa()); - // Get perfectly nested sequence of loops starting at root of loop nest - // (the first op being another AffineFor, and the second op - a terminator). - // TODO(andydavis,bondhugula) Share this with similar code in loop tiling. SmallVector loops; AffineForOp curr = node->op->cast(); - loops.push_back(curr); - auto *currBody = curr.getBody(); - while (currBody->begin() == std::prev(currBody->end(), 2) && - (curr = curr.getBody()->front().dyn_cast())) { - loops.push_back(curr); - currBody = curr.getBody(); - } + getPerfectlyNestedLoops(loops, curr); if (loops.size() < 2) return; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 956d50ec26f..c215fa34172 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -270,11 +270,7 @@ static void getTileableBands(Function &f, // (inclusive). 
auto getMaximalPerfectLoopNest = [&](AffineForOp root) { SmallVector band; - AffineForOp currInst = root; - do { - band.push_back(currInst); - } while (currInst.getBody()->getOperations().size() == 2 && - (currInst = currInst.getBody()->front().dyn_cast())); + getPerfectlyNestedLoops(band, root); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 2b17f4b84fa..1e9697a9ad5 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -353,6 +353,22 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, return success(); } +/// Get perfectly nested sequence of loops starting at root of loop nest +/// (the first op being another AffineFor, and the second op - a terminator). +/// A loop is perfectly nested iff: the first op in the loop's body is another +/// AffineForOp, and the second op is a terminator). +void mlir::getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, + AffineForOp root) { + AffineForOp curr = root; + nestedLoops.push_back(curr); + auto *currBody = curr.getBody(); + while (currBody->begin() == std::prev(currBody->end(), 2) && + (curr = curr.getBody()->front().dyn_cast())) { + nestedLoops.push_back(curr); + currBody = curr.getBody(); + } +} + /// Unrolls this loop completely. LogicalResult mlir::loopUnrollFull(AffineForOp forOp) { Optional mayBeConstantTripCount = getConstantTripCount(forOp); -- cgit v1.2.3 From e4628b79fb810be529157cdf1197bea78f059c3e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 5 Apr 2019 12:24:03 -0700 Subject: Add new utilities for RTTI Operation casting: dyn_cast_or_null and isa_nonnull * dyn_cast_or_null - This will first check if the operation is null before trying to 'dyn_cast': Value *v = ...; if (auto forOp = dyn_cast_or_null(v->getDefiningOp())) ... * isa_nonnull - This will first check if the pointer is null before trying to 'isa': Value *v = ...; if (isa_nonnull(v->getDefiningOp()); ... 
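  Spelled out with explicit template arguments, the intended usage is roughly:

    Value *v = ...;
    if (auto forOp = dyn_cast_or_null<AffineForOp>(v->getDefiningOp()))
      ...
    if (isa_nonnull<AffineForOp>(v->getDefiningOp()))
      ...

  Both helpers simply fold the null check into the cast, so call sites no
  longer need a separate "if (op && op->isa<AffineForOp>())" guard.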
-- PiperOrigin-RevId: 242171343 --- mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp | 2 +- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 2 +- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 22 ++++++++++------------ mlir/include/mlir/IR/Operation.h | 11 +++++++++++ mlir/lib/AffineOps/AffineOps.cpp | 17 +++++------------ mlir/lib/Analysis/AffineAnalysis.cpp | 11 ++++------- mlir/lib/Analysis/AffineStructures.cpp | 7 ++----- mlir/lib/Analysis/Utils.cpp | 16 ++++++---------- mlir/lib/EDSC/Builders.cpp | 4 +--- mlir/lib/StandardOps/Ops.cpp | 5 ++--- mlir/lib/Transforms/LoopFusion.cpp | 10 ++++------ 11 files changed, 47 insertions(+), 60 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp index a3bdceee073..0b68f20a7f1 100644 --- a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp @@ -93,7 +93,7 @@ void linalg::SliceOp::print(OpAsmPrinter *p) { *p << "*"; } else { auto *v = getIndexing(); - if (v->getDefiningOp() && v->getDefiningOp()->isa()) { + if (isa_nonnull(v->getDefiningOp())) { *p << *v << ".."; } else { *p << *v; diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 46be3252f45..f81930ad5e0 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -30,5 +30,5 @@ unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); if (auto viewOp = view->getDefiningOp()->dyn_cast()) return viewOp.getRank(); - return view->getDefiningOp()->dyn_cast().getRank(); + return view->getDefiningOp()->cast().getRank(); } diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 3b86b4c5d05..f3e8ff06781 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -50,14 +50,14 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { // We can directly cast the current operation as this will only get invoked // on TransposeOp. TransposeOp transpose = op->cast(); - // look through the input to the current transpose + // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); + TransposeOp transposeInputOp = + mlir::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! - if (!matchPattern(transposeInput, mlir::m_Op())) + if (!transposeInputOp) return matchFailure(); - auto transposeInputOp = - transposeInput->getDefiningOp()->cast(); // Use the rewriter to perform the replacement rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); return matchSuccess(); @@ -74,15 +74,13 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast(); - // look through the input to the current reshape - mlir::Value *reshapeInput = reshape.getOperand(); - mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); - // If the input is defined by another reshape, bingo! - if (!reshapeInputInst || !reshapeInputInst->template isa()) + // Look through the input of the current reshape. + ConstantOp constantOp = mlir::dyn_cast_or_null( + reshape.getOperand()->getDefiningOp()); + // If the input is defined by another constant, bingo! 
+ if (!constantOp) return matchFailure(); - ConstantOp constantOp = reshapeInputInst->template cast(); - auto reshapeType = op->getResult(0)->getType().cast(); if (auto valueAttr = constantOp.getAttrOfType("value")) { @@ -123,7 +121,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern { matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast(); - // look through the input to the current reshape + // Look through the input of the current reshape. mlir::Value *reshapeInput = reshape.getOperand(); // If the input is defined by another reshape, bingo! if (!matchPattern(reshapeInput, mlir::m_Op())) diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 3eb1560e4f6..7007f81e471 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -527,6 +527,17 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } +/// Provide dyn_cast_or_null functionality for Operation casts. +template T dyn_cast_or_null(Operation *op) { + return op ? op->dyn_cast() : T(); +} + +/// Provide isa_nonnull functionality for Operation casts, i.e. if the operation +/// is non-null and a class of 'T'. +template bool isa_nonnull(Operation *op) { + return op && op->isa(); +} + /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 8b8b794ab99..f1bcb121bff 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -316,12 +316,9 @@ AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { static llvm::SetVector indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; - for (auto en : llvm::enumerate(operands)) { - auto *t = en.value(); - if (t->getDefiningOp() && t->getDefiningOp()->isa()) { + for (auto en : llvm::enumerate(operands)) + if (isa_nonnull(en.value()->getDefiningOp())) res.insert(en.index()); - } - } return res; } @@ -459,9 +456,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { auto *t = operands[i]; - auto affineApply = t->getDefiningOp() - ? t->getDefiningOp()->dyn_cast() - : AffineApplyOp(); + auto affineApply = dyn_cast_or_null(t->getDefiningOp()); if (affineApply) { // a. Compose affine.apply operations. 
LLVM_DEBUG(affineApply.getOperation()->print( @@ -536,7 +531,7 @@ static void composeAffineMapAndOperands(AffineMap *map, void mlir::fullyComposeAffineMapAndOperands( AffineMap *map, SmallVectorImpl *operands) { while (llvm::any_of(*operands, [](Value *v) { - return v->getDefiningOp() && v->getDefiningOp()->isa(); + return isa_nonnull(v->getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); } @@ -1190,9 +1185,7 @@ AffineForOp mlir::getForInductionVarOwner(Value *val) { if (!ivArg || !ivArg->getOwner()) return AffineForOp(); auto *containingInst = ivArg->getOwner()->getParent()->getContainingOp(); - if (!containingInst) - return AffineForOp(); - return containingInst->dyn_cast(); + return dyn_cast_or_null(containingInst); } /// Extracts the induction variables from a list of AffineForOps and returns diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index c7d992cef0e..f53571d896b 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -64,7 +64,7 @@ void mlir::getReachableAffineApplyOps( auto *opInst = state.value->getDefiningOp(); // Note: getDefiningOp will return nullptr if the operand is not an // Operation (i.e. block argument), which is a terminator for the search. - if (opInst == nullptr || !opInst->isa()) { + if (!isa_nonnull(opInst)) { worklist.pop_back(); continue; } @@ -463,12 +463,9 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, auto *symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. - if (auto *opInst = symbol->getDefiningOp()) { - if (auto constOp = opInst->dyn_cast()) { - dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), - constOp.getValue()); - } - } + if (auto cOp = dyn_cast_or_null(symbol->getDefiningOp())) + dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), + cOp.getValue()); } }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4a796299b2c..c38881a1b57 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -727,11 +727,8 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { // Add top level symbol. addSymbolId(getNumSymbolIds(), id); // Check if the symbol is a constant. - if (auto *opInst = id->getDefiningOp()) { - if (auto constOp = opInst->dyn_cast()) { - setIdToConstant(*id, constOp.getValue()); - } - } + if (auto constOp = dyn_cast_or_null(id->getDefiningOp())) + setIdToConstant(*id, constOp.getValue()); } LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 5999b357e96..b2d004be11c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -73,16 +73,12 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { assert(cst->containsId(*value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. 
- if (auto *op = value->getDefiningOp()) { - if (auto constOp = op->dyn_cast()) { - cst->setIdToConstant(*value, constOp.getValue()); - } - } - } else { - if (auto loop = getForInductionVarOwner(value)) { - if (failed(cst->addAffineForOpDomain(loop))) - return failure(); - } + + if (auto cOp = dyn_cast_or_null(value->getDefiningOp())) + cst->setIdToConstant(*value, cOp.getValue()); + } else if (auto loop = getForInductionVarOwner(value)) { + if (failed(cst->addAffineForOpDomain(loop))) + return failure(); } } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 4cd31be7692..610c8b66320 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -264,9 +264,7 @@ categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, unsigned &numSymbols) { AffineExpr d; Value *resultVal = nullptr; - auto *op = val->getDefiningOp(); - auto constant = op ? op->dyn_cast() : ConstantIndexOp(); - if (constant) { + if (auto constant = dyn_cast_or_null(val->getDefiningOp())) { d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { d = getAffineSymbolExpr(numSymbols++, context); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 2b998e98e1a..000ced61448 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -318,8 +318,7 @@ struct SimplifyAllocConst : public RewritePattern { continue; } auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp(); - ConstantIndexOp constantIndexOp; - if (defOp && (constantIndexOp = defOp->dyn_cast())) { + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); // Record to check for zero uses later below. @@ -1095,7 +1094,7 @@ struct SimplifyDeadDealloc : public RewritePattern { // Check that the memref operand's defining operation is an AllocOp. Value *memref = dealloc.getMemRef(); Operation *defOp = memref->getDefiningOp(); - if (!defOp || !defOp->isa()) + if (!isa_nonnull(defOp)) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 39ed5a100a1..2eb76ba20cd 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1547,11 +1547,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1; - for (auto *loadOp : dstLoadOpInsts) { - auto *parentInst = loadOp->getParentOp(); - if (parentInst && parentInst->isa()) - computeCostMap[parentInst] = -1; - } + for (auto *loadOp : dstLoadOpInsts) + if (auto forOp = dyn_cast_or_null(loadOp->getParentOp())) + computeCostMap[forOp] = -1; } // Compute op instance count for the src loop nest with iteration slicing. @@ -2259,7 +2257,7 @@ public: continue; // Use list expected to match the dep graph info. 
auto *op = memref->getDefiningOp(); - if (op && op->isa()) + if (isa_nonnull(op)) op->erase(); } } -- cgit v1.2.3 From 70a416de14ce87439f491036aee46fb5ec66e1fd Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Tue, 9 Apr 2019 09:17:40 -0700 Subject: Fix typos in LoopFusion -- PiperOrigin-RevId: 242679298 --- mlir/lib/Transforms/LoopFusion.cpp | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2eb76ba20cd..3c7ecc20e99 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -162,7 +162,7 @@ public: struct Node { // The unique identifier of this node in the graph. unsigned id; - // The top-level statment which is (or contains) loads/stores. + // The top-level statement which is (or contains) a load/store. Operation *op; // List of load operations. SmallVector loads; @@ -587,6 +587,7 @@ public: if (inEdges.count(id) > 0) forEachMemRefEdge(inEdges[id], callback); } + // Calls 'callback' for each output edge from node 'id' which carries a // memref dependence. void forEachMemRefOutputEdge(unsigned id, @@ -594,6 +595,7 @@ public: if (outEdges.count(id) > 0) forEachMemRefEdge(outEdges[id], callback); } + // Calls 'callback' for each edge in 'edges' which carries a memref // dependence. void forEachMemRefEdge(ArrayRef edges, @@ -787,14 +789,14 @@ struct LoopNestStatsCollector { // operation count * loop trip count) for the entire loop nest. // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops // specified in the map when computing the total op instance count. -// NOTE: this is used to compute the cost of computation slices, which are +// NOTEs: 1) This is used to compute the cost of computation slices, which are // sliced along the iteration dimension, and thus reduce the trip count. // If 'computeCostMap' is non-null, the total op count for forOps specified // in the map is increased (not overridden) by adding the op count from the // map to the existing op count for the for loop. This is done before // multiplying by the loop's trip count, and is used to model the cost of // inserting a sliced loop nest of known cost into the loop's body. -// NOTE: this is used to compute the cost of fusing a slice of some loop nest +// 2) This is also used to compute the cost of fusing a slice of some loop nest // within another loop. static int64_t getComputeCost( Operation *forInst, LoopNestStats *stats, @@ -973,7 +975,7 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, // *) Computes the loop permutation which sinks sequential loops deeper into // the loop nest, while preserving the relative order between other loops. // *) Checks each dependence component against the permutation to see if the -// desired loop interchange would violated dependences by making the a +// desired loop interchange would violate dependences by making the // dependence componenent lexicographically negative. // TODO(andydavis) Move this function to LoopUtils. static bool @@ -1001,7 +1003,7 @@ computeLoopInterchangePermutation(ArrayRef ops, FlatAffineConstraints dependenceConstraints; llvm::SmallVector depComps; // TODO(andydavis,bondhugula) Explore whether it would be profitable - // to pre-compute and store deps instead of repeatidly checking. + // to pre-compute and store deps instead of repeatedly checking. 
if (checkMemrefAccessDependence(srcAccess, dstAccess, d, &dependenceConstraints, &depComps)) { isParallelLoop[d - 1] = false; @@ -1010,6 +1012,7 @@ computeLoopInterchangePermutation(ArrayRef ops, } } } + // Count the number of parallel loops. unsigned numParallelLoops = 0; for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i) @@ -1223,7 +1226,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, return newMemRef; } -// Does the slice have a single iteration? +// Return the number of iterations in the given slice. static uint64_t getSliceIterationCount( const llvm::SmallDenseMap &sliceTripCountMap) { uint64_t iterCount = 1; @@ -1236,7 +1239,7 @@ static uint64_t getSliceIterationCount( // Checks if node 'srcId' (which writes to a live out memref), can be safely // fused into node 'dstId'. Returns true if the following conditions are met: // *) 'srcNode' only writes to live out 'memref'. -// *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId'). +// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId'). // *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's // write region to 'memref'. // TODO(andydavis) Generalize this to handle more live in/out cases. @@ -1570,7 +1573,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, (static_cast(srcLoopNestCost) + dstLoopNestCost) - 1; - // Compute what the slice write MemRefRegion would be, if the src loop + // Determine what the slice write MemRefRegion would be, if the src loop // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop // nest at loop depth 'i' MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); @@ -1744,11 +1747,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: -// *) Pop a AffineForOp of the worklist. This 'dstAffineForOp' will be a +// *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a // candidate destination AffineForOp into which fusion will be attempted. // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: -// *) Lookup dependent loop nests which have a single store op to the same +// *) Look up dependent loop nests which have a single store op to the same // memref. // *) Check if dependences would be violated by the fusion. // *) Get a computation slice of 'srcLoopNest', which adjusts its loop @@ -1756,7 +1759,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', // at a loop depth determined by the cost model in 'isFusionProfitable'. // *) Add the newly fused load/store operations to the state, -// and also add newly fuse load ops to 'dstLoopOps' to be considered +// and also add newly fused load ops to 'dstLoopOps' to be considered // as fusion dst load ops in another iteration. // *) Remove old src loop nest and its associated state. // @@ -1867,7 +1870,7 @@ public: // Skip if no input edges along which to fuse. if (mdg->inEdges.count(dstId) == 0) continue; - // Iterate through in edges for 'dstId' and src node id for any + // Iterate through in-edges for 'dstId' and src node id for any // edges on 'memref'. 
SmallVector srcNodeIds; for (auto &srcEdge : mdg->inEdges[dstId]) { @@ -1977,12 +1980,12 @@ public: loads.push_back(loadOpInst); } - // Clear and add back loads and stores + // Clear and add back loads and stores. mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); // Remove old src loop nest if it no longer has outgoing dependence - // edges, and it does not write to a memref which escapes the + // edges, and if it does not write to a memref which escapes the // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has // been fused into 'dstNode' and write region of 'dstNode' covers // the write region of 'srcNode', and 'srcNode' has no other users -- cgit v1.2.3 From 44f6dffbf8ed634b242751071197706532b5e8cc Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 9 Apr 2019 12:21:28 -0700 Subject: Factor code to compute dependence components out of loop fusion pass, and into a reusable utility function (NFC). -- PiperOrigin-RevId: 242716259 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 12 ++++++- mlir/lib/Analysis/AffineAnalysis.cpp | 50 +++++++++++++++++++++++++-- mlir/lib/Transforms/LoopFusion.cpp | 53 ++++++++++------------------- 3 files changed, 76 insertions(+), 39 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 4873475e58b..1b92bd1b14c 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -76,12 +76,14 @@ struct MemRefAccess { }; // DependenceComponent contains state about the direction of a dependence as an -// interval [lb, ub]. +// interval [lb, ub] for an AffineForOp. // Distance vectors components are represented by the interval [lb, ub] with // lb == ub. // Direction vectors components are represented by the interval [lb, ub] with // lb < ub. Note that ub/lb == None means unbounded. struct DependenceComponent { + // The AffineForOp Operation associated with this dependence component. + Operation *op; // The lower bound of the dependence distance. llvm::Optional lb; // The upper bound of the dependence distance (inclusive). @@ -104,6 +106,14 @@ bool checkMemrefAccessDependence( unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, llvm::SmallVector *dependenceComponents, bool allowRAR = false); + +/// Returns in 'depCompsVec', dependence components for dependences between all +/// load and store ops in loop nest rooted at 'forOp', at loop depths in range +/// [1, maxLoopDepth]. +void getDependenceComponents( + AffineForOp forOp, unsigned maxLoopDepth, + std::vector> *depCompsVec); + } // end namespace mlir #endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index f53571d896b..53120af8574 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -517,8 +517,11 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Returns the number of outer loop common to 'src/dstDomain'. -static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, - const FlatAffineConstraints &dstDomain) { +// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null. 
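+// A dimension counts as a common loop only while the i-th identifiers of both
+// domains are bound to the same AffineForOp induction variable; the scan stops
+// at the first position where they differ.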
+static unsigned +getNumCommonLoops(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + SmallVectorImpl *commonLoops = nullptr) { // Find the number of common loops shared by src and dst accesses. unsigned minNumLoops = std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); @@ -528,8 +531,12 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, !isForInductionVar(dstDomain.getIdValue(i)) || srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) break; + if (commonLoops != nullptr) + commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i))); ++numCommonLoops; } + if (commonLoops != nullptr) + assert(commonLoops->size() == numCommonLoops); return numCommonLoops; } @@ -628,7 +635,9 @@ static void computeDirectionVector( FlatAffineConstraints *dependenceDomain, llvm::SmallVector *dependenceComponents) { // Find the number of common loops shared by src and dst accesses. - unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); + SmallVector commonLoops; + unsigned numCommonLoops = + getNumCommonLoops(srcDomain, dstDomain, &commonLoops); if (numCommonLoops == 0) return; // Compute direction vectors for requested loop depth. @@ -658,6 +667,7 @@ static void computeDirectionVector( // on eliminated constraint system. dependenceComponents->resize(numCommonLoops); for (unsigned j = 0; j < numCommonLoops; ++j) { + (*dependenceComponents)[j].op = commonLoops[j].getOperation(); auto lbConst = dependenceDomain->getConstantLowerBound(j); (*dependenceComponents)[j].lb = lbConst.getValueOr(std::numeric_limits::min()); @@ -856,3 +866,37 @@ bool mlir::checkMemrefAccessDependence( LLVM_DEBUG(dependenceConstraints->dump()); return true; } + +/// Gathers dependence components for dependences between all ops in loop nest +/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth]. +void mlir::getDependenceComponents( + AffineForOp forOp, unsigned maxLoopDepth, + std::vector> *depCompsVec) { + // Collect all load and store ops in loop nest rooted at 'forOp'. + SmallVector loadAndStoreOpInsts; + forOp.getOperation()->walk([&](Operation *opInst) { + if (opInst->isa() || opInst->isa()) + loadAndStoreOpInsts.push_back(opInst); + }); + + unsigned numOps = loadAndStoreOpInsts.size(); + for (unsigned d = 1; d <= maxLoopDepth; ++d) { + for (unsigned i = 0; i < numOps; ++i) { + auto *srcOpInst = loadAndStoreOpInsts[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < numOps; ++j) { + auto *dstOpInst = loadAndStoreOpInsts[j]; + MemRefAccess dstAccess(dstOpInst); + + FlatAffineConstraints dependenceConstraints; + llvm::SmallVector depComps; + // TODO(andydavis,bondhugula) Explore whether it would be profitable + // to pre-compute and store deps instead of repeatedly checking. + if (checkMemrefAccessDependence(srcAccess, dstAccess, d, + &dependenceConstraints, &depComps)) { + depCompsVec->push_back(depComps); + } + } + } + } +} diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 3c7ecc20e99..011423bcc55 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -968,8 +968,8 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, } // Compute loop interchange permutation: -// *) Computes dependence components between all op pairs in 'ops' for loop -// depths in range [1, 'maxLoopDepth']. +// *) Computes dependence components between all op pairs of ops in loop nest +// rooted at 'loops[0]', for loop depths in range [1, 'maxLoopDepth']. 
// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either // parallel or sequential. // *) Computes the loop permutation which sinks sequential loops deeper into @@ -979,37 +979,24 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, // dependence componenent lexicographically negative. // TODO(andydavis) Move this function to LoopUtils. static bool -computeLoopInterchangePermutation(ArrayRef ops, - unsigned maxLoopDepth, +computeLoopInterchangePermutation(ArrayRef loops, SmallVectorImpl *loopPermMap) { - // Gather dependence components for dependences between all ops in 'ops' - // at loop depths in range [1, maxLoopDepth]. - // TODO(andydavis) Refactor this loop into a LoopUtil utility function: - // mlir::getDependenceComponents(). - // TODO(andydavis) Split this loop into two: first check all dependences, - // and construct dep vectors. Then, scan through them to detect the parallel - // ones. + assert(loops.size() > 1); + // Gather dependence components for dependences between all ops in loop nest + // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. + unsigned maxLoopDepth = loops.size(); std::vector> depCompsVec; + getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); + // Mark loops as either parallel or sequential. llvm::SmallVector isParallelLoop(maxLoopDepth, true); - unsigned numOps = ops.size(); - for (unsigned d = 1; d <= maxLoopDepth; ++d) { - for (unsigned i = 0; i < numOps; ++i) { - auto *srcOpInst = ops[i]; - MemRefAccess srcAccess(srcOpInst); - for (unsigned j = 0; j < numOps; ++j) { - auto *dstOpInst = ops[j]; - MemRefAccess dstAccess(dstOpInst); - - FlatAffineConstraints dependenceConstraints; - llvm::SmallVector depComps; - // TODO(andydavis,bondhugula) Explore whether it would be profitable - // to pre-compute and store deps instead of repeatedly checking. - if (checkMemrefAccessDependence(srcAccess, dstAccess, d, - &dependenceConstraints, &depComps)) { - isParallelLoop[d - 1] = false; - depCompsVec.push_back(depComps); - } - } + for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { + llvm::SmallVector &depComps = depCompsVec[i]; + assert(depComps.size() >= maxLoopDepth); + for (unsigned j = 0; j < maxLoopDepth; ++j) { + DependenceComponent &depComp = depComps[j]; + assert(depComp.lb.hasValue() && depComp.ub.hasValue()); + if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0) + isParallelLoop[j] = false; } } @@ -1071,13 +1058,9 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { if (loops.size() < 2) return; - // Merge loads and stores into the same array. - SmallVector memOps(node->loads.begin(), node->loads.end()); - memOps.append(node->stores.begin(), node->stores.end()); - // Compute loop permutation in 'loopPermMap'. llvm::SmallVector loopPermMap; - if (!computeLoopInterchangePermutation(memOps, loops.size(), &loopPermMap)) + if (!computeLoopInterchangePermutation(loops, &loopPermMap)) return; int loopNestRootIndex = -1; -- cgit v1.2.3 From 1423acc03cdba3114356b2986386e359061e248d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 23 Apr 2019 14:38:26 -0700 Subject: Rename isa_nonnull to isa_and_nonnull to match the upstream llvm name. 
-- PiperOrigin-RevId: 244928036 --- mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp | 2 +- mlir/include/mlir/IR/Operation.h | 6 +++--- mlir/lib/AffineOps/AffineOps.cpp | 4 ++-- mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp index 818a770c58d..b7337a1c5cd 100644 --- a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp @@ -93,7 +93,7 @@ void linalg::SliceOp::print(OpAsmPrinter *p) { *p << "*"; } else { auto *v = getIndexing(); - if (isa_nonnull(v->getDefiningOp())) { + if (isa_and_nonnull(v->getDefiningOp())) { *p << *v << ".."; } else { *p << *v; diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 7007f81e471..aacf9ee1117 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -532,9 +532,9 @@ template T dyn_cast_or_null(Operation *op) { return op ? op->dyn_cast() : T(); } -/// Provide isa_nonnull functionality for Operation casts, i.e. if the operation -/// is non-null and a class of 'T'. -template bool isa_nonnull(Operation *op) { +/// Provide isa_and_nonnull functionality for Operation casts, i.e. if the +/// operation is non-null and a class of 'T'. +template bool isa_and_nonnull(Operation *op) { return op && op->isa(); } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 0ff7db22e81..63c2b890628 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -317,7 +317,7 @@ static llvm::SetVector indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; for (auto en : llvm::enumerate(operands)) - if (isa_nonnull(en.value()->getDefiningOp())) + if (isa_and_nonnull(en.value()->getDefiningOp())) res.insert(en.index()); return res; } @@ -531,7 +531,7 @@ static void composeAffineMapAndOperands(AffineMap *map, void mlir::fullyComposeAffineMapAndOperands( AffineMap *map, SmallVectorImpl *operands) { while (llvm::any_of(*operands, [](Value *v) { - return isa_nonnull(v->getDefiningOp()); + return isa_and_nonnull(v->getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 53120af8574..ad9a87b69b3 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -64,7 +64,7 @@ void mlir::getReachableAffineApplyOps( auto *opInst = state.value->getDefiningOp(); // Note: getDefiningOp will return nullptr if the operand is not an // Operation (i.e. block argument), which is a terminator for the search. - if (!isa_nonnull(opInst)) { + if (!isa_and_nonnull(opInst)) { worklist.pop_back(); continue; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 38a86c4ea35..99f93d1cc47 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1094,7 +1094,7 @@ struct SimplifyDeadDealloc : public RewritePattern { // Check that the memref operand's defining operation is an AllocOp. Value *memref = dealloc.getMemRef(); Operation *defOp = memref->getDefiningOp(); - if (!isa_nonnull(defOp)) + if (!isa_and_nonnull(defOp)) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. 
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 011423bcc55..a69836c8653 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -2243,7 +2243,7 @@ public: continue; // Use list expected to match the dep graph info. auto *op = memref->getDefiningOp(); - if (isa_nonnull(op)) + if (isa_and_nonnull(op)) op->erase(); } } -- cgit v1.2.3 From 258e8d9ce2e7da290dbda335771d5e84e04c813a Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 3 May 2019 11:07:37 -0700 Subject: Prepend an "affine-" prefix to Affine pass option names - NFC Trying to activate both LLVM and MLIR passes in mlir-cpu-runner showed name collisions when registering pass names. One possible way of disambiguating that should also work across dialects is to prepend the dialect name to the passes that specifically operate on that dialect. With this CL, mlir-cpu-runner tests still run when both LLVM and MLIR passes are registered -- PiperOrigin-RevId: 246539917 --- mlir/g3doc/Passes.md | 12 ++++++------ mlir/lib/Transforms/DmaGeneration.cpp | 4 ++-- mlir/lib/Transforms/LoopFusion.cpp | 5 +++-- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 4 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 4 ++-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 ++-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 7 ++++--- mlir/lib/Transforms/MaterializeVectors.cpp | 5 +++-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 ++-- mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 5 +++-- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/Transforms/Vectorize/compose_maps.mlir | 2 +- mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir | 2 +- mlir/test/Transforms/Vectorize/materialize.mlir | 2 +- .../Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir | 2 +- .../Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir | 2 +- .../Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir | 2 +- mlir/test/Transforms/Vectorize/normalize_maps.mlir | 2 +- mlir/test/Transforms/Vectorize/vector_utils.mlir | 4 ++-- mlir/test/Transforms/Vectorize/vectorize_1d.mlir | 2 +- mlir/test/Transforms/Vectorize/vectorize_2d.mlir | 4 ++-- mlir/test/Transforms/Vectorize/vectorize_3d.mlir | 2 +- mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir | 2 +- .../Vectorize/vectorize_outer_loop_transpose_2d.mlir | 2 +- mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir | 2 +- mlir/test/Transforms/dma-generate.mlir | 4 ++-- mlir/test/Transforms/loop-fusion.mlir | 4 ++-- mlir/test/Transforms/loop-invariant-code-motion.mlir | 2 +- mlir/test/Transforms/loop-tiling.mlir | 4 ++-- mlir/test/Transforms/pipeline-data-transfer.mlir | 2 +- mlir/test/Transforms/slicing_utils.mlir | 6 +++--- mlir/test/Transforms/unroll-jam.mlir | 2 +- mlir/test/Transforms/unroll.mlir | 8 ++++---- 34 files changed, 63 insertions(+), 59 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/g3doc/Passes.md b/mlir/g3doc/Passes.md index baf585a5669..4a4391aa286 100644 --- a/mlir/g3doc/Passes.md +++ b/mlir/g3doc/Passes.md @@ -76,7 +76,7 @@ value is returned, packed into an LLVM IR struct type. Function calls and returns are updated accordingly. Block argument types are updated to use LLVM IR types. 
-## DMA generation (`-dma-generate`) +## DMA generation (`-affine-dma-generate`) Replaces all loads and stores on memref's living in 'slowMemorySpace' by introducing DMA operations (strided DMA if necessary) to transfer data to/from @@ -143,22 +143,22 @@ func @loop_nest_tiled() -> memref<256x1024xf32> { } ``` -## Loop tiling (`-loop-tile`) +## Loop tiling (`-affine-loop-tile`) Performs tiling or blocking of loop nests. It currently works on perfect loop nests. -## Loop unroll (`-loop-unroll`) +## Loop unroll (`-affine-loop-unroll`) This pass implements loop unrolling. It is able to unroll loops with arbitrary bounds, and generate a cleanup loop when necessary. -## Loop unroll and jam (`-loop-unroll-jam`) +## Loop unroll and jam (`-affine-loop-unroll-jam`) This pass implements unroll and jam for loops. It works on both perfect or imperfect loop nests. -## Loop fusion (`-loop-fusion`) +## Loop fusion (`-affine-loop-fusion`) Performs fusion of loop nests using a slicing-based approach. The fused loop nests, when possible, are rewritten to access significantly smaller local @@ -245,7 +245,7 @@ test/Transforms/memref-dataflow-opt.mlir:232:7: note: dependence from 2 to 1 at store %cf9, %m[%idx] : memref<10xf32> ``` -## Pipeline data transfer (`-pipeline-data-transfer`) +## Pipeline data transfer (`-affine-pipeline-data-transfer`) This pass performs a transformation to overlap non-blocking DMA operations in a loop with computations through double buffering. This is achieved by advancing diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index c70b218e36d..f58905353ef 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -34,7 +34,7 @@ #include "llvm/Support/Debug.h" #include -#define DEBUG_TYPE "dma-generate" +#define DEBUG_TYPE "affine-dma-generate" using namespace mlir; using llvm::SmallMapVector; @@ -773,4 +773,4 @@ void DmaGeneration::runOnFunction() { } static PassRegistration - pass("dma-generate", "Generate DMAs for memory operations"); + pass("affine-dma-generate", "Generate DMAs for memory operations"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index a69836c8653..f8db53a3216 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -41,7 +41,7 @@ #include #include -#define DEBUG_TYPE "loop-fusion" +#define DEBUG_TYPE "affine-loop-fusion" using llvm::SetVector; @@ -2271,4 +2271,5 @@ void LoopFusion::runOnFunction() { .run(); } -static PassRegistration pass("loop-fusion", "Fuse loop nests"); +static PassRegistration pass("affine-loop-fusion", + "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 13e9cda2407..b03a3c70a17 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -126,5 +126,5 @@ void LoopInvariantCodeMotion::runOnFunction() { } static PassRegistration - pass("loop-invariant-code-motion", + pass("affine-loop-invariant-code-motion", "Hoist loop invariant instructions outside of the loop"); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 11f2468b1a9..4eb1ce22008 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -36,7 +36,7 @@ using namespace mlir; -#define DEBUG_TYPE "loop-tile" +#define DEBUG_TYPE "affine-loop-tile" static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); @@ 
-413,4 +413,4 @@ void LoopTiling::runOnFunction() { constexpr unsigned LoopTiling::kDefaultTileSize; constexpr uint64_t LoopTiling::kDefaultCacheMemCapacity; -static PassRegistration pass("loop-tile", "Tile loop nests"); +static PassRegistration pass("affine-loop-tile", "Tile loop nests"); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 1e92ebec655..236ef81ebd2 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -34,7 +34,7 @@ using namespace mlir; -#define DEBUG_TYPE "loop-unroll" +#define DEBUG_TYPE "affine-loop-unroll" static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); @@ -188,4 +188,4 @@ FunctionPassBase *mlir::createLoopUnrollPass( unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); } -static PassRegistration pass("loop-unroll", "Unroll loops"); +static PassRegistration pass("affine-loop-unroll", "Unroll loops"); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index a3a24f6c0f7..366a7ede5eb 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -56,7 +56,7 @@ using namespace mlir; -#define DEBUG_TYPE "loop-unroll-jam" +#define DEBUG_TYPE "affine-loop-unroll-jam" static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); @@ -240,5 +240,5 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, return success(); } -static PassRegistration pass("loop-unroll-jam", +static PassRegistration pass("affine-loop-unroll-jam", "Unroll and jam loops"); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index a1048b763e0..f0990e47313 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -85,7 +85,7 @@ using namespace mlir; -#define DEBUG_TYPE "lower-vector-transfers" +#define DEBUG_TYPE "affine-lower-vector-transfers" namespace { @@ -380,5 +380,6 @@ FunctionPassBase *mlir::createLowerVectorTransfersPass() { } static PassRegistration - pass("lower-vector-transfers", "Materializes vector transfer ops to a " - "proper abstraction for the hardware"); + pass("affine-lower-vector-transfers", + "Materializes vector transfer ops to a " + "proper abstraction for the hardware"); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 76dbdafb08d..2f06a9aa3bf 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -772,7 +772,8 @@ mlir::createMaterializeVectorsPass(llvm::ArrayRef vectorSize) { } static PassRegistration - pass("materialize-vectors", "Materializes super-vectors to vectors of the " - "proper size for the hardware"); + pass("affine-materialize-vectors", + "Materializes super-vectors to vectors of the " + "proper size for the hardware"); #undef DEBUG_TYPE diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 0ad24e6a711..66fbf4a1306 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -32,7 +32,7 @@ #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "pipeline-data-transfer" +#define DEBUG_TYPE "affine-pipeline-data-transfer" using namespace mlir; @@ -379,6 +379,6 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { } static PassRegistration pass( - 
"pipeline-data-transfer", + "affine-pipeline-data-transfer", "Pipeline non-blocking data transfers between explicitly managed levels of " "the memory hierarchy"); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b8640f05590..ecb391c1dbc 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -36,7 +36,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "vectorizer-test" +#define DEBUG_TYPE "affine-vectorizer-test" using namespace mlir; @@ -306,6 +306,7 @@ FunctionPassBase *mlir::createVectorizerTestPass() { } static PassRegistration - pass("vectorizer-test", "Tests vectorizer standalone functionality."); + pass("affine-vectorizer-test", + "Tests vectorizer standalone functionality."); #undef DEBUG_TYPE diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d36dca8eca9..e135e95f30a 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1262,5 +1262,5 @@ mlir::createVectorizePass(llvm::ArrayRef virtualVectorSize) { } static PassRegistration - pass("vectorize", + pass("affine-vectorize", "Vectorize to a target independent n-D vector abstraction"); diff --git a/mlir/test/Transforms/Vectorize/compose_maps.mlir b/mlir/test/Transforms/Vectorize/compose_maps.mlir index 0f1c599be3b..0b2b16aab58 100644 --- a/mlir/test/Transforms/Vectorize/compose_maps.mlir +++ b/mlir/test/Transforms/Vectorize/compose_maps.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorizer-test -compose-maps 2>&1 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorizer-test -compose-maps 2>&1 | FileCheck %s // For all these cases, the test traverses the `test_affine_map` ops and // composes them in order one-by-one. 
diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index f1ea8f46761..6331c38e9a4 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -lower-vector-transfers | FileCheck %s +// RUN: mlir-opt %s -affine-lower-vector-transfers | FileCheck %s // CHECK: #[[ADD:map[0-9]+]] = (d0, d1) -> (d0 + d1) // CHECK: #[[SUB:map[0-9]+]] = ()[s0] -> (s0 - 1) diff --git a/mlir/test/Transforms/Vectorize/materialize.mlir b/mlir/test/Transforms/Vectorize/materialize.mlir index 09f62a4cc16..40460e1b9b3 100644 --- a/mlir/test/Transforms/Vectorize/materialize.mlir +++ b/mlir/test/Transforms/Vectorize/materialize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -materialize-vectors -vector-size=4 -vector-size=4 | FileCheck %s +// RUN: mlir-opt %s -affine-materialize-vectors -vector-size=4 -vector-size=4 | FileCheck %s // CHECK-DAG: #[[ID1:map[0-9]+]] = (d0) -> (d0) // CHECK-DAG: #[[D0D1D2D3TOD1D0:map[0-9]+]] = (d0, d1, d2, d3) -> (d1, d0) diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir index e3023764c74..318373af381 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 --test-fastest-varying=0 -materialize-vectors -vector-size=8 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 --test-fastest-varying=0 -affine-materialize-vectors -vector-size=8 | FileCheck %s // vector<32xf32> -> vector<8xf32> // CHECK-DAG: [[ID1:#.*]] = (d0) -> (d0) diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir index b655790b271..e5034e41f74 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 16 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=8 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 3 -virtual-vector-size 16 --test-fastest-varying=1 --test-fastest-varying=0 -affine-materialize-vectors -vector-size=8 | FileCheck %s // vector<3x16xf32> -> vector<8xf32> // CHECK-DAG: [[ID1:#.*]] = (d0) -> (d0) diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir index bb48cda0907..ea1353db73c 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 32 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=3 -vector-size=16 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 3 -virtual-vector-size 32 --test-fastest-varying=1 --test-fastest-varying=0 -affine-materialize-vectors -vector-size=3 -vector-size=16 | FileCheck %s // vector<3x32xf32> -> vector<3x16xf32> // CHECK-DAG: [[ID1:#.*]] = (d0) -> (d0) diff --git a/mlir/test/Transforms/Vectorize/normalize_maps.mlir 
b/mlir/test/Transforms/Vectorize/normalize_maps.mlir index 076d2c75633..e7b08b70cce 100644 --- a/mlir/test/Transforms/Vectorize/normalize_maps.mlir +++ b/mlir/test/Transforms/Vectorize/normalize_maps.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorizer-test -normalize-maps | FileCheck %s +// RUN: mlir-opt %s -affine-vectorizer-test -normalize-maps | FileCheck %s // CHECK-DAG: #[[ZERO:[a-zA-Z0-9]+]] = () -> (0) // CHECK-DAG: #[[ID1:[a-zA-Z0-9]+]] = (d0) -> (d0) diff --git a/mlir/test/Transforms/Vectorize/vector_utils.mlir b/mlir/test/Transforms/Vectorize/vector_utils.mlir index ad35b67e549..ceb295b3784 100644 --- a/mlir/test/Transforms/Vectorize/vector_utils.mlir +++ b/mlir/test/Transforms/Vectorize/vector_utils.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -vectorizer-test -vector-shape-ratio 4 -vector-shape-ratio 8 2>&1 | FileCheck %s -// RUN: mlir-opt %s -vectorizer-test -vector-shape-ratio 2 -vector-shape-ratio 5 -vector-shape-ratio 2 2>&1 | FileCheck %s -check-prefix=TEST-3x4x5x8 +// RUN: mlir-opt %s -affine-vectorizer-test -vector-shape-ratio 4 -vector-shape-ratio 8 2>&1 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorizer-test -vector-shape-ratio 2 -vector-shape-ratio 5 -vector-shape-ratio 2 2>&1 | FileCheck %s -check-prefix=TEST-3x4x5x8 func @vector_add_2d(%arg0: index, %arg1: index) -> f32 { // Nothing should be matched in this first block. diff --git a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir index bd1337b5a7f..5a0fab1b715 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 128 --test-fastest-varying=0 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 128 --test-fastest-varying=0 | FileCheck %s // Permutation maps used in vectorization. // CHECK: #[[map_proj_d0d1_0:map[0-9]+]] = (d0, d1) -> (0) diff --git a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir index 5c6819502ae..217c7a6b39e 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 4 -virtual-vector-size 8 | FileCheck %s -check-prefix=VECT -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 4 -virtual-vector-size 8 | FileCheck %s -check-prefix=VECT +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s // Permutation maps used in vectorization. 
// CHECK-DAG: #[[map_id1:map[0-9]+]] = (d0) -> (d0) diff --git a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir index 5b0e9f1de20..3f766ea9ac4 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 64 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 64 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s // Permutation maps used in vectorization. // CHECK: #[[map_proj_d0d1d2_d0d1d2:map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir index 3db78be439a..2a50c426f12 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=0 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=0 | FileCheck %s // Permutation maps used in vectorization. // CHECK: #[[map_proj_d0d1d2_d0d2:map[0-9]+]] = (d0, d1, d2) -> (d0, d2) diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir index 8f335e607df..63e98ab795d 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=2 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=2 | FileCheck %s // Permutation maps used in vectorization. // CHECK: #[[map_proj_d0d1d2_d2d0:map[0-9]+]] = (d0, d1, d2) -> (d2, d0) diff --git a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir index 3e4e5ade50f..280ded81d2f 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=1 | FileCheck %s +// RUN: mlir-opt %s -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=1 | FileCheck %s // Permutation maps used in vectorization. 
// CHECK-DAG: #[[map_proj_d0d1d2_d2d1:map[0-9]+]] = (d0, d1, d2) -> (d2, d1) diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 99fa581c026..19c7a637e4d 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -split-input-file -dma-generate -dma-skip-non-unit-stride-loops -verify | FileCheck %s -// RUN: mlir-opt %s -split-input-file -dma-generate -dma-fast-mem-capacity=16 -dma-fast-mem-space=2 | FileCheck %s --check-prefix FAST-MEM-16KB +// RUN: mlir-opt %s -split-input-file -affine-dma-generate -dma-skip-non-unit-stride-loops -verify | FileCheck %s +// RUN: mlir-opt %s -split-input-file -affine-dma-generate -dma-fast-mem-capacity=16 -dma-fast-mem-space=2 | FileCheck %s --check-prefix FAST-MEM-16KB // We run most test cases with -dma-skip-non-unit-stride-loops to allow testing // DMA generation at inner levels easily - since the DMA generation would diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 91b0d040890..4a5eebb4dd8 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s -// RUN: mlir-opt %s -loop-fusion -fusion-maximal -split-input-file -verify | FileCheck %s --check-prefix=MAXIMAL +// RUN: mlir-opt %s -affine-loop-fusion -split-input-file -verify | FileCheck %s +// RUN: mlir-opt %s -affine-loop-fusion -fusion-maximal -split-input-file -verify | FileCheck %s --check-prefix=MAXIMAL // TODO(andydavis) Add more tests: // *) Add nested fusion test cases when non-constant loop bound support is diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index f2276d8d83c..af9560bf0f8 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -loop-invariant-code-motion -split-input-file -verify | FileCheck %s +// RUN: mlir-opt %s -affine-loop-invariant-code-motion -split-input-file -verify | FileCheck %s func @nested_loops_both_having_invariant_code() { %m = alloc() : memref<10xf32> diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index c65f3ae8be9..4686ff5bdb9 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -split-input-file -loop-tile -tile-size=32 | FileCheck %s -// RUN: mlir-opt %s -split-input-file -loop-tile -tile-cache-size=512 | FileCheck %s --check-prefix=MODEL +// RUN: mlir-opt %s -split-input-file -affine-loop-tile -tile-size=32 | FileCheck %s +// RUN: mlir-opt %s -split-input-file -affine-loop-tile -tile-cache-size=512 | FileCheck %s --check-prefix=MODEL // ----- diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 9fafab6c4e6..30e6be82e2a 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s +// RUN: mlir-opt %s -affine-pipeline-data-transfer | FileCheck %s // CHECK-DAG: [[MOD_2:#map[0-9]+]] = (d0) -> (d0 mod 2) // CHECK-DAG: [[FLOOR_MOD_2:#map[0-9]+]] = (d0) -> ((d0 floordiv 4) mod 2) diff --git a/mlir/test/Transforms/slicing_utils.mlir b/mlir/test/Transforms/slicing_utils.mlir index 
79b8771d946..07e1a509987 100644 --- a/mlir/test/Transforms/slicing_utils.mlir +++ b/mlir/test/Transforms/slicing_utils.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt %s -vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD -// RUN: mlir-opt %s -vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD -// RUN: mlir-opt %s -vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD +// RUN: mlir-opt %s -affine-vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD +// RUN: mlir-opt %s -affine-vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD +// RUN: mlir-opt %s -affine-vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD /// 1 2 3 4 /// |_______| |______| diff --git a/mlir/test/Transforms/unroll-jam.mlir b/mlir/test/Transforms/unroll-jam.mlir index 6ea152f48d1..1d04792ff8b 100644 --- a/mlir/test/Transforms/unroll-jam.mlir +++ b/mlir/test/Transforms/unroll-jam.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s +// RUN: mlir-opt %s -affine-loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s // CHECK-DAG: [[MAP_PLUS_1:#map[0-9]+]] = (d0) -> (d0 + 1) // CHECK-DAG: [[M1:#map[0-9]+]] = ()[s0] -> (s0 + 8) diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir index 9fd00886b3f..5dc6637cd7f 100644 --- a/mlir/test/Transforms/unroll.mlir +++ b/mlir/test/Transforms/unroll.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -loop-unroll -unroll-full | FileCheck %s --check-prefix UNROLL-FULL -// RUN: mlir-opt %s -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT -// RUN: mlir-opt %s -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4 -// RUN: mlir-opt %s -loop-unroll -unroll-factor=1 | FileCheck %s --check-prefix UNROLL-BY-1 +// RUN: mlir-opt %s -affine-loop-unroll -unroll-full | FileCheck %s --check-prefix UNROLL-FULL +// RUN: mlir-opt %s -affine-loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT +// RUN: mlir-opt %s -affine-loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4 +// RUN: mlir-opt %s -affine-loop-unroll -unroll-factor=1 | FileCheck %s --check-prefix UNROLL-BY-1 // UNROLL-FULL-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 1) // UNROLL-FULL-DAG: [[MAP1:#map[0-9]+]] = (d0) -> (d0 + 2) -- cgit v1.2.3 From 2fe8ae4f6cab83b753fded814a0992b0cddf1609 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 3 May 2019 19:48:57 -0700 Subject: Fix up some mixed sign warnings. 
-- PiperOrigin-RevId: 246614498 --- mlir/include/mlir/TableGen/Pattern.h | 6 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/Attributes.cpp | 11 ++-- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/TableGen/Pattern.cpp | 12 ++-- mlir/lib/Transforms/LoopFusion.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/test/mlir-tblgen/op-operand.td | 20 +++---- mlir/test/mlir-tblgen/op-result.td | 20 +++---- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 68 +++++++++++----------- mlir/tools/mlir-tblgen/RewriterGen.cpp | 12 ++-- 12 files changed, 80 insertions(+), 81 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index e833e49c73c..1b75712e083 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -144,10 +144,10 @@ public: // Returns the number of operations recursively involved in the DAG tree // rooted from this node. - unsigned getNumOps() const; + int getNumOps() const; // Returns the number of immediate arguments to this DAG node. - unsigned getNumArgs() const; + int getNumArgs() const; // Returns true if the `index`-th argument is a nested DAG construct. bool isNestedDagArg(unsigned index) const; @@ -192,7 +192,7 @@ public: DagNode getSourcePattern() const; // Returns the number of results generated by applying this rewrite pattern. - unsigned getNumResults() const; + int getNumResults() const; // Returns the DAG tree root node of the `index`-th result pattern. DagNode getResultPattern(unsigned index) const; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 5ffc749dc94..b98b4b891db 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -429,7 +429,7 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { numReservedCols++; } - unsigned absolutePos; + int absolutePos; if (kind == IdKind::Dimension) { absolutePos = pos; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 03fb5c6f85e..a86821ff8bc 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1099,7 +1099,7 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { // Print constraints. os << " : ("; - auto numConstraints = set.getNumConstraints(); + int numConstraints = set.getNumConstraints(); for (int i = 1; i < numConstraints; ++i) { printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); os << ", "; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index e4b5b78a4be..c69cb2bffc2 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -433,19 +433,19 @@ Attribute DenseElementsAttr::getValue(ArrayRef index) const { // Verify that the rank of the indices matches the held type. auto rank = type.getRank(); - if (rank != index.size()) + if (static_cast(rank) != index.size()) return Attribute(); // Verify that all of the indices are within the shape dimensions. auto shape = type.getShape(); for (unsigned i = 0; i != rank; ++i) - if (shape[i] <= index[i]) + if (shape[i] <= static_cast(index[i])) return Attribute(); // Reduce the provided multidimensional index into a 1D index. 
uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; - for (auto i = rank - 1; i >= 0; --i) { + for (int i = rank - 1; i >= 0; --i) { valueIndex += index[i] * dimMultiplier; dimMultiplier *= shape[i]; } @@ -701,7 +701,7 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { auto type = getType(); // Verify that the rank of the indices matches the held type. - auto rank = type.getRank(); + size_t rank = type.getRank(); if (rank != index.size()) return Attribute(); @@ -715,8 +715,7 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { llvm::SmallDenseMap, size_t> mappedIndices; auto numSparseIndices = sparseIndices.getType().getDimSize(0); for (size_t i = 0, e = numSparseIndices; i != e; ++i) - mappedIndices.try_emplace( - {sparseIndexValues + (i * rank), static_cast(rank)}, i); + mappedIndices.try_emplace({sparseIndexValues + (i * rank), rank}, i); // Look for the provided index key within the mapped indices. If the provided // index is not found, then return a zero attribute. diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 4cc192b4ce0..46baa7441ed 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1153,7 +1153,7 @@ LogicalResult DimOp::verify() { auto type = getOperand()->getType(); if (auto tensorType = type.dyn_cast()) { - if (index >= tensorType.getRank()) + if (index >= static_cast(tensorType.getRank())) return emitOpError("index is out of range"); } else if (auto memrefType = type.dyn_cast()) { if (index >= memrefType.getRank()) diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index ff0150dabad..d4ba6cdf292 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -114,16 +114,16 @@ Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { .first->second; } -unsigned tblgen::DagNode::getNumOps() const { - unsigned count = isReplaceWithValue() ? 0 : 1; - for (unsigned i = 0, e = getNumArgs(); i != e; ++i) { +int tblgen::DagNode::getNumOps() const { + int count = isReplaceWithValue() ? 0 : 1; + for (int i = 0, e = getNumArgs(); i != e; ++i) { if (auto child = getArgAsNestedDag(i)) count += child.getNumOps(); } return count; } -unsigned tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } +int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } bool tblgen::DagNode::isNestedDagArg(unsigned index) const { return isa(node->getArg(index)); @@ -161,7 +161,7 @@ tblgen::DagNode tblgen::Pattern::getSourcePattern() const { return tblgen::DagNode(def.getValueAsDag("sourcePattern")); } -unsigned tblgen::Pattern::getNumResults() const { +int tblgen::Pattern::getNumResults() const { auto *results = def.getValueAsListInit("resultPatterns"); return results->size(); } @@ -248,7 +248,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) { boundOps.insert(treeName); // TODO(jpienaar): Expand to multiple matches. - for (unsigned i = 0; i != numTreeArgs; ++i) { + for (int i = 0; i != numTreeArgs; ++i) { if (auto treeArg = tree.getArgAsNestedDag(i)) { // This DAG node argument is a DAG node itself. Go inside recursively. 
collectBoundArguments(treeArg); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index f8db53a3216..fda6574f503 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1676,7 +1676,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, << " fused mem: " << fusedMem << "\n" << " slice mem: " << sliceMemEstimate << "\n"); - if (fusedMem > srcMemSizeVal + dstMemSizeVal) { + if (static_cast(fusedMem) > srcMemSizeVal + dstMemSizeVal) { LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); return false; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 8c6a932f5dd..3fd44c79c79 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -53,7 +53,7 @@ public: /// Perform the rewrites. Return true if the rewrite converges in /// `maxIterations`. - bool simplifyFunction(unsigned maxIterations); + bool simplifyFunction(int maxIterations); void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. @@ -146,7 +146,7 @@ private: }; // end anonymous namespace /// Perform the rewrites. -bool GreedyPatternRewriteDriver::simplifyFunction(unsigned maxIterations) { +bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { Function *fn = builder.getFunction(); ConstantFoldHelper helper(fn); diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 6416905b545..89f260fe25a 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -41,13 +41,13 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> { } // CHECK-LABEL: Operation::operand_range OpC::input1() -// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0; +// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 0) / 2; +// CHECK-NEXT: offset = 0 + variadicOperandSize * 0; // CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; // CHECK-LABEL: Operation::operand_range OpC::input2() -// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1; +// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 0) / 2; +// CHECK-NEXT: offset = 0 + variadicOperandSize * 1; // CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; // CHECK-LABEL: OpC::build @@ -59,18 +59,18 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> } // CHECK-LABEL: Operation::operand_range OpD::input1() -// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0; +// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: offset = 0 + variadicOperandSize * 0; // CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; // CHECK-LABEL: Value *OpD::input2() -// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1; +// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: 
offset = 0 + variadicOperandSize * 1; // CHECK-NEXT: return this->getOperand(offset); // CHECK-LABEL: Operation::operand_range OpD::input3() -// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2; -// CHECK-NEXT: unsigned offset = 1 + variadicOperandSize * 1; +// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2; +// CHECK-NEXT: offset = 1 + variadicOperandSize * 1; // CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)}; // CHECK-LABEL: OpD::build diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 1adfef84220..268f0c0c514 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -98,13 +98,13 @@ def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> { } // CHECK-LABEL: Operation::result_range OpH::output1() -// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0; +// CHECK-NEXT: variadicResultSize = (this->getNumResults() - 0) / 2; +// CHECK-NEXT: offset = 0 + variadicResultSize * 0; // CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; // CHECK-LABEL: Operation::result_range OpH::output2() -// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1; +// CHECK-NEXT: variadicResultSize = (this->getNumResults() - 0) / 2; +// CHECK-NEXT: offset = 0 + variadicResultSize * 1; // CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; @@ -117,18 +117,18 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> } // CHECK-LABEL: Operation::result_range OpI::output1() -// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0; +// CHECK-NEXT: variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: offset = 0 + variadicResultSize * 0; // CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; // CHECK-LABEL: Value *OpI::output2() -// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; -// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1; +// CHECK-NEXT: variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: offset = 0 + variadicResultSize * 1; // CHECK-NEXT: return this->getResult(offset); // CHECK-LABEL: Operation::result_range OpI::output3() -// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2; -// CHECK-NEXT: unsigned offset = 1 + variadicResultSize * 1; +// CHECK-NEXT: variadicResultSize = (this->getNumResults() - 1) / 2; +// CHECK-NEXT: offset = 1 + variadicResultSize * 1; // CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)}; // CHECK-LABEL: OpI::build diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 58f735c938f..1a3167338ce 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -461,9 +461,9 @@ void OpEmitter::genAttrGetters() { } void OpEmitter::genNamedOperandGetters() { - const unsigned numOperands = op.getNumOperands(); - const unsigned numVariadicOperands = op.getNumVariadicOperands(); - const unsigned 
numNormalOperands = numOperands - numVariadicOperands; + const int numOperands = op.getNumOperands(); + const int numVariadicOperands = op.getNumVariadicOperands(); + const int numNormalOperands = numOperands - numVariadicOperands; // Special case for ops without variadic operands: the i-th value is for the // i-th operand defined in the op. @@ -473,7 +473,7 @@ void OpEmitter::genNamedOperandGetters() { // operand. if (numVariadicOperands <= 1) { bool emittedVariadicOperand = false; - for (unsigned i = 0; i != numOperands; ++i) { + for (int i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; @@ -506,17 +506,17 @@ void OpEmitter::genNamedOperandGetters() { "specification over their sizes"); } - unsigned emittedNormalOperands = 0; - unsigned emittedVariadicOperands = 0; + int emittedNormalOperands = 0; + int emittedVariadicOperands = 0; - for (unsigned i = 0; i != numOperands; ++i) { + for (int i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; const char *code = R"( - unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1}; - unsigned offset = {2} + variadicOperandSize * {3}; + int variadicOperandSize = (this->getNumOperands() - {0}) / {1}; + int offset = {2} + variadicOperandSize * {3}; return )"; auto sizeAndOffset = formatv(code, numNormalOperands, numVariadicOperands, @@ -537,9 +537,9 @@ void OpEmitter::genNamedOperandGetters() { } void OpEmitter::genNamedResultGetters() { - const unsigned numResults = op.getNumResults(); - const unsigned numVariadicResults = op.getNumVariadicResults(); - const unsigned numNormalResults = numResults - numVariadicResults; + const int numResults = op.getNumResults(); + const int numVariadicResults = op.getNumVariadicResults(); + const int numNormalResults = numResults - numVariadicResults; // Special case for ops without variadic results: the i-th value is for the // i-th result defined in the op. @@ -549,7 +549,7 @@ void OpEmitter::genNamedResultGetters() { // result. 
if (numVariadicResults <= 1) { bool emittedVariadicResult = false; - for (unsigned i = 0; i != numResults; ++i) { + for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); if (result.name.empty()) continue; @@ -582,17 +582,17 @@ void OpEmitter::genNamedResultGetters() { "specification over their sizes"); } - unsigned emittedNormalResults = 0; - unsigned emittedVariadicResults = 0; + int emittedNormalResults = 0; + int emittedVariadicResults = 0; - for (unsigned i = 0; i != numResults; ++i) { + for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); if (result.name.empty()) continue; const char *code = R"( - unsigned variadicResultSize = (this->getNumResults() - {0}) / {1}; - unsigned offset = {2} + variadicResultSize * {3}; + int variadicResultSize = (this->getNumResults() - {0}) / {1}; + int offset = {2} + variadicResultSize * {3}; return )"; auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults, emittedNormalResults, emittedVariadicResults); @@ -628,7 +628,7 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, // Emit parameters for all return types if (!useOperandType && !useAttrType) { - for (unsigned i = 0; i != numResults; ++i) { + for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); std::string resultName = result.name; if (resultName.empty()) @@ -674,7 +674,7 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, // Push all result types to the result if (numResults > 0) { if (!useOperandType && !useAttrType) { - for (unsigned i = 0; i < numResults; ++i) { + for (int i = 0; i < numResults; ++i) { const auto &result = op.getResult(i); m.body() << " " << builderOpState; if (result.isVariadic()) { @@ -699,14 +699,14 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str(); } m.body() << " " << builderOpState << "->addTypes({" << resultType; - for (unsigned i = 1; i != numResults; ++i) + for (int i = 1; i != numResults; ++i) m.body() << ", " << resultType; m.body() << "});\n\n"; } } // Push all operands to the result - for (unsigned i = 0; i < numOperands; ++i) { + for (int i = 0; i < numOperands; ++i) { const auto &operand = op.getOperand(i); m.body() << " " << builderOpState; if (operand.isVariadic()) { @@ -755,13 +755,13 @@ void OpEmitter::genBuilder() { } } - unsigned numResults = op.getNumResults(); - unsigned numVariadicResults = op.getNumVariadicResults(); - unsigned numNonVariadicResults = numResults - numVariadicResults; + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariadicResults(); + int numNonVariadicResults = numResults - numVariadicResults; - unsigned numOperands = op.getNumOperands(); - unsigned numVariadicOperands = op.getNumVariadicOperands(); - unsigned numNonVariadicOperands = numOperands - numVariadicOperands; + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariadicOperands(); + int numNonVariadicOperands = numOperands - numVariadicOperands; // Generate default builders that requires all result type, operands, and // attributes as parameters. 
@@ -955,11 +955,11 @@ void OpEmitter::genVerifier() { } }; - for (unsigned i = 0, e = op.getNumOperands(); i < e; ++i) { + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { verifyValue(op.getOperand(i), i, /*isOperand=*/true); } - for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { + for (int i = 0, e = op.getNumResults(); i < e; ++i) { verifyValue(op.getResult(i), i, /*isOperand=*/false); } @@ -979,8 +979,8 @@ void OpEmitter::genVerifier() { } void OpEmitter::genTraits() { - unsigned numResults = op.getNumResults(); - unsigned numVariadicResults = op.getNumVariadicResults(); + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariadicResults(); // Add return size trait. if (numVariadicResults != 0) { @@ -1008,8 +1008,8 @@ void OpEmitter::genTraits() { } // Add variadic size trait and normal op traits. - unsigned numOperands = op.getNumOperands(); - unsigned numVariadicOperands = op.getNumVariadicOperands(); + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariadicOperands(); // Add operand size trait. if (numVariadicOperands != 0) { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 501b7a15d30..6c0e5287706 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -420,8 +420,8 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { PrintFatalError(loc, "only support up to 4-entity constraints now"); } SmallVector names; - unsigned i = 0; - for (unsigned e = entities.size(); i < e; ++i) + int i = 0; + for (int e = entities.size(); i < e; ++i) names.push_back(resolveSymbol(entities[i])); for (; i < 4; ++i) names.push_back(""); @@ -475,7 +475,7 @@ void PatternEmitter::emit(StringRef rewriteName) { void PatternEmitter::emitRewriteMethod() { const Operator &rootOp = pattern.getSourceRootOp(); int numExpectedResults = rootOp.getNumResults(); - unsigned numProvidedResults = pattern.getNumResults(); + int numProvidedResults = pattern.getNumResults(); if (numProvidedResults < numExpectedResults) PrintFatalError( @@ -490,7 +490,7 @@ void PatternEmitter::emitRewriteMethod() { // Collect the replacement value for each result llvm::SmallVector resultValues; - for (unsigned i = 0; i < numProvidedResults; ++i) { + for (int i = 0; i < numProvidedResults; ++i) { DagNode resultTree = pattern.getResultPattern(i); resultValues.push_back(handleRewritePattern(resultTree, i, 0)); // Keep track of bound symbols at the top-level DAG nodes @@ -595,7 +595,7 @@ std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + Twine(tree.getNumArgs())); } - for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) { + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); } return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3], @@ -646,7 +646,7 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, // First go through all the child nodes who are nested DAG constructs to // create ops for them, so that we can use the results in the current node. // This happens in a recursive manner. 
- for (unsigned i = 0, e = resultOp.getNumOperands(); i != e; ++i) { + for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { if (auto child = tree.getArgAsNestedDag(i)) { childNodeNames[i] = handleRewritePattern(child, i, depth + 1); // Keep track of bound symbols at the middle-level DAG nodes -- cgit v1.2.3 From 0134b5df3a00018f7db2a0ee5be6c4abf9bee4b2 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 11 May 2019 08:28:15 -0700 Subject: Cleanups and simplifications to code, noticed by inspection. NFC. -- PiperOrigin-RevId: 247758075 --- mlir/lib/Analysis/SliceAnalysis.cpp | 2 -- mlir/lib/Transforms/LoopFusion.cpp | 2 -- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 24 +++++++----------------- mlir/lib/Transforms/LoopTiling.cpp | 3 --- 4 files changed, 7 insertions(+), 24 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 496c0b33e1e..bce000a4c1f 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -25,9 +25,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" - #include "llvm/ADT/SetVector.h" -#include /// /// Implements Analysis functions specific to slicing in Function. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index fda6574f503..796d2164ad9 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,8 +39,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include -#include - #define DEBUG_TYPE "affine-loop-fusion" using llvm::SetVector; diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index b03a3c70a17..2f95db95c6f 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -19,9 +19,6 @@ // //===----------------------------------------------------------------------===// -#include -#include - #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" @@ -57,7 +54,6 @@ namespace { struct LoopInvariantCodeMotion : public FunctionPass { void runOnFunction() override; void runOnAffineForOp(AffineForOp forOp); - std::vector forOps; }; } // end anonymous namespace @@ -81,7 +77,7 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { LLVM_DEBUG(for (auto i : loopDefinedOps) { - (i->print(llvm::dbgs() << "\nLoop-dependent op\n")); + i->print(llvm::dbgs() << "\nLoop-dependent op\n"); }); for (auto &op : *loopBody) { @@ -109,20 +105,14 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { } void LoopInvariantCodeMotion::runOnFunction() { - forOps.clear(); - - // Gather all loops in a function, and order them in innermost-loop-first - // order. This way, we first LICM from the inner loop, and place the ops in - // the outer loop, which in turn can be further LICM'ed. This saves iterating - // on the inner loop operations while LICMing through the outer loop. - getFunction().walk( - [&](AffineForOp forOp) { forOps.push_back(forOp); }); - // We gather loops first, and then go over them later because we don't want to - // mess the iterators up. - for (auto op : forOps) { + + // Walk through all loops in a function in innermost-loop-first order. 
This + // way, we first LICM from the inner loop, and place the ops in + // the outer loop, which in turn can be further LICM'ed. + getFunction().walk([&](AffineForOp op) { LLVM_DEBUG(op.getOperation()->print(llvm::dbgs() << "\nOriginal loop\n")); runOnAffineForOp(op); - } + }); } static PassRegistration diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 303193bcb96..ce42a5eba85 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -31,9 +31,6 @@ #include "mlir/Transforms/Utils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include -#include - using namespace mlir; #define DEBUG_TYPE "affine-loop-tile" -- cgit v1.2.3 From 02e03b9bf4a1fe60b89d4bd662895ebcc374129b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 11 May 2019 15:17:28 -0700 Subject: Add support for using llvm::dyn_cast/cast/isa for operation casts and replace usages of Operation::dyn_cast with llvm::dyn_cast. -- PiperOrigin-RevId: 247778391 --- mlir/examples/Linalg/Linalg1/lib/Analysis.cpp | 4 +- mlir/examples/Linalg/Linalg1/lib/Common.cpp | 2 +- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 2 +- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 4 +- .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 14 +++--- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 16 +++---- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 8 ++-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 4 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 8 ++-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 6 +-- mlir/include/mlir/EDSC/Builders.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Function.h | 2 +- mlir/include/mlir/IR/OpDefinition.h | 12 ++--- mlir/include/mlir/IR/Operation.h | 51 ++++++++++++---------- mlir/include/mlir/IR/PatternMatch.h | 4 +- mlir/include/mlir/Support/LLVM.h | 1 + mlir/lib/AffineOps/AffineOps.cpp | 8 ++-- mlir/lib/Analysis/LoopAnalysis.cpp | 8 ++-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 12 ++--- mlir/lib/Analysis/VectorAnalysis.cpp | 8 ++-- mlir/lib/EDSC/Builders.cpp | 6 +-- mlir/lib/Linalg/Transforms/Tiling.cpp | 6 +-- mlir/lib/Linalg/Utils/Utils.cpp | 6 +-- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 6 +-- mlir/lib/Transforms/DmaGeneration.cpp | 12 ++--- mlir/lib/Transforms/LoopFusion.cpp | 8 ++-- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 8 ++-- mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 4 +- 43 files changed, 140 insertions(+), 130 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp index ecb6309466a..a7fba179c79 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp @@ -31,7 +31,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) { auto viewType = view->getType().dyn_cast(); 
(void)viewType; assert(viewType.isa() && "expected a ViewType"); - while (auto slice = view->getDefiningOp()->dyn_cast()) { + while (auto slice = dyn_cast(view->getDefiningOp())) { view = slice.getParentView(); assert(viewType.isa() && "expected a ViewType"); } @@ -48,7 +48,7 @@ std::pair linalg::getViewRootIndexing(Value *view, (void)viewType; assert(viewType.isa() && "expected a ViewType"); assert(dim < viewType.getRank() && "dim exceeds rank"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = dyn_cast(view->getDefiningOp())) return std::make_pair(viewOp.getIndexing(dim), dim); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index bfdc40a6aa0..278f9c57607 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -40,7 +40,7 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( assert(ivs.size() == indexings.size()); for (unsigned i = 0, e = indexings.size(); i < e; ++i) { auto rangeOp = - indexings[i].getValue()->getDefiningOp()->dyn_cast(); + llvm::dyn_cast(indexings[i].getValue()->getDefiningOp()); if (!rangeOp) { continue; } diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 372c08f9eea..5bcebc79c18 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -33,7 +33,7 @@ using namespace linalg::intrinsics; unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = dyn_cast(view->getDefiningOp())) return viewOp.getRank(); return view->getDefiningOp()->cast().getRank(); } diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index d1af7503d1b..83fd9ad3143 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -43,7 +43,7 @@ using namespace linalg::intrinsics; // analyses. This builds the chain. static SmallVector getViewChain(mlir::Value *v) { assert(v->getType().isa() && "ViewType expected"); - if (v->getDefiningOp()->dyn_cast()) { + if (v->getDefiningOp()->isa()) { return SmallVector{v}; } @@ -53,7 +53,7 @@ static SmallVector getViewChain(mlir::Value *v) { tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa()); - assert(v->getDefiningOp()->cast() && "must be a ViewOp"); + assert(v->getDefiningOp()->isa() && "must be a ViewOp"); tmp.push_back(v); return SmallVector(tmp.rbegin(), tmp.rend()); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 9339d7309e3..3090f29dcfc 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -91,7 +91,7 @@ inline llvm::SmallVector extractRangesFromViewOrSliceOp(mlir::Value *view) { // This expects a viewType which must come from either ViewOp or SliceOp. 
assert(view->getType().isa() && "expected ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = llvm::dyn_cast(view->getDefiningOp())) return viewOp.getRanges(); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 42999aef7ae..bce7f58860d 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -46,9 +46,9 @@ void linalg::composeSliceOps(mlir::Function *f) { void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { f->walk([](Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { matmulOp.writeAsFinerGrainTensorContraction(); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { matvecOp.writeAsFinerGrainTensorContraction(); } else { return; @@ -205,11 +205,11 @@ writeContractionAsLoops(ContractionOp contraction) { llvm::Optional> linalg::writeAsLoops(Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return writeContractionAsLoops(matmulOp); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return writeContractionAsLoops(matvecOp); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return writeContractionAsLoops(dotOp); } return llvm::None; @@ -276,7 +276,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto load = op->cast(); - SliceOp slice = load.getView()->getDefiningOp()->dyn_cast(); + SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : load.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(load), load.getLoc()); @@ -291,7 +291,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto store = op->cast(); - SliceOp slice = store.getView()->getDefiningOp()->dyn_cast(); + SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); ViewOp view = slice ? 
emitAndReturnFullyComposedView(slice.getResult()) : store.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(store), store.getLoc()); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 05865e9e53c..6771257ae0f 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -52,8 +52,8 @@ void linalg::lowerToTiledLoops(mlir::Function *f, } static bool isZeroIndex(Value *v) { - return v->getDefiningOp() && v->getDefiningOp()->isa() && - v->getDefiningOp()->dyn_cast().getValue() == 0; + return isa_and_nonnull(v->getDefiningOp()) && + cast(v->getDefiningOp()).getValue() == 0; } template @@ -178,11 +178,11 @@ writeContractionAsTiledViews(TensorContractionBase &contraction, llvm::Optional> linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return writeContractionAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return writeContractionAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return writeContractionAsTiledViews(dotOp, tileSizes); } return llvm::None; @@ -190,11 +190,11 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef tileSizes) { f->walk([tileSizes](Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { writeAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { writeAsTiledViews(dotOp, tileSizes); } else { return; diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index a11c88266b7..c9f98e7d6a9 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -238,13 +238,13 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. - if (auto addOp = op->dyn_cast()) { + if (auto addOp = llvm::dyn_cast(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } // Transpose is easy: just invert the dimensions. - if (auto transpose = op->dyn_cast()) { + if (auto transpose = llvm::dyn_cast(op)) { SmallVector dims; auto arrayTy = transpose.getOperand()->getType().cast(); dims.insert(dims.end(), arrayTy.getShape().begin(), @@ -259,7 +259,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast()) { + if (auto mulOp = llvm::dyn_cast(op)) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -291,7 +291,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. 
- if (auto callOp = op->dyn_cast()) { + if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index f3e8ff06781..942ce866182 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); TransposeOp transposeInputOp = - mlir::dyn_cast_or_null(transposeInput->getDefiningOp()); + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! if (!transposeInputOp) return matchFailure(); @@ -75,7 +75,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast(); // Look through the input of the current reshape. - ConstantOp constantOp = mlir::dyn_cast_or_null( + ConstantOp constantOp = llvm::dyn_cast_or_null( reshape.getOperand()->getDefiningOp()); // If the input is defined by another constant, bingo! if (!constantOp) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 4ef62d33adc..534b5cbd2ab 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -366,7 +366,7 @@ struct LateLoweringPass : public ModulePass { // First patch calls type to return memref instead of ToyArray for (auto &function : getModule()) { function.walk([&](Operation *op) { - auto callOp = op->dyn_cast(); + auto callOp = dyn_cast(op); if (!callOp) return; if (!callOp.getNumResults()) @@ -382,14 +382,14 @@ struct LateLoweringPass : public ModulePass { for (auto &function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = op->dyn_cast()) { + if (auto allocOp = dyn_cast(op)) { auto result = allocTensor(allocOp); allocOp.replaceAllUsesWith(result); allocOp.erase(); return; } // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = op->dyn_cast()) { + if (auto typeCastOp = dyn_cast(op)) { typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); typeCastOp.erase(); return; @@ -429,7 +429,7 @@ struct LateLoweringPass : public ModulePass { // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. builder.getFunction()->walk([&](Operation *op) { - auto returnOp = op->dyn_cast(); + auto returnOp = dyn_cast(op); if (!returnOp) return; if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index a083e62f05f..4e17b234d14 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -238,7 +238,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. 
- if (auto addOp = op->dyn_cast()) { + if (auto addOp = llvm::dyn_cast(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } @@ -261,7 +261,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast()) { + if (auto mulOp = llvm::dyn_cast(op)) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -295,7 +295,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. - if (auto callOp = op->dyn_cast()) { + if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 5d23488c95d..39302f6c0f9 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -439,7 +439,7 @@ ValueHandle ValueHandle::create(Args... args) { if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } else if (op->getNumResults() == 0) { - if (auto f = op->dyn_cast()) { + if (auto f = dyn_cast(op)) { return ValueHandle(f.getInductionVar()); } } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1ee6c4806fb..7f182e882db 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -271,7 +271,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 0770d2cfa27..d4b85b56d0f 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -116,7 +116,7 @@ public: /// Specialization of walk to only visit operations of 'OpTy'. template void walk(std::function callback) { walk([&](Operation *opInst) { - if (auto op = opInst->dyn_cast()) + if (auto op = dyn_cast(opInst)) callback(op); }); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b80e8aca9bc..2eff412a71e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -792,7 +792,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. static void printAssembly(Operation *op, OpAsmPrinter *p) { - auto opPointer = op->dyn_cast(); + auto opPointer = dyn_cast(op); assert(opPointer && "op's name does not match name of concrete type instantiated with"); opPointer.print(p); @@ -825,11 +825,13 @@ public: /// This is a public constructor. Any op can be initialized to null. explicit Op() : OpState(nullptr) {} + Op(std::nullptr_t) : OpState(nullptr) {} -protected: - /// This is a private constructor only accessible through the - /// Operation::cast family of methods. - explicit Op(Operation *state) : OpState(state) {} + /// This is a public constructor to enable access via the llvm::cast family of + /// methods. This should not be used directly. 
+ explicit Op(Operation *state) : OpState(state) { + assert(!state || isa(state)); + } friend class Operation; private: diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 54e49b73e3b..31ec8ea54a6 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -389,14 +389,6 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// - /// The dyn_cast methods perform a dynamic cast from an Operation to a typed - /// Op like DimOp. This returns a null Op on failure. - template OpClass dyn_cast() { - if (isa()) - return cast(); - return OpClass(); - } - /// The cast methods perform a cast from an Operation to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. @@ -417,10 +409,10 @@ public: /// including this one. void walk(const std::function &callback); - /// Specialization of walk to only visit operations of 'OpTy'. - template void walk(std::function callback) { + /// Specialization of walk to only visit operations of 'T'. + template void walk(std::function callback) { walk([&](Operation *op) { - if (auto derivedOp = op->dyn_cast()) + if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }); } @@ -534,17 +526,6 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } -/// Provide dyn_cast_or_null functionality for Operation casts. -template T dyn_cast_or_null(Operation *op) { - return op ? op->dyn_cast() : T(); -} - -/// Provide isa_and_nonnull functionality for Operation casts, i.e. if the -/// operation is non-null and a class of 'T'. -template bool isa_and_nonnull(Operation *op) { - return op && op->isa(); -} - /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final @@ -598,4 +579,30 @@ inline auto Operation::getResultTypes() } // end namespace mlir +namespace llvm { +/// Provide isa functionality for operation casts. +template struct isa_impl { + static inline bool doit(const ::mlir::Operation &op) { + return T::classof(const_cast<::mlir::Operation *>(&op)); + } +}; + +/// Provide specializations for operation casts as the resulting T is value +/// typed. +template struct cast_retty_impl { + using ret_type = T; +}; +template struct cast_retty_impl { + using ret_type = T; +}; +template +struct cast_convert_val { + static T doit(::mlir::Operation &val) { return T(&val); } +}; +template +struct cast_convert_val { + static T doit(::mlir::Operation *val) { return T(val); } +}; +} // end namespace llvm + #endif // MLIR_IR_OPERATION_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 3b02ed55c34..51528c18d38 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -215,7 +215,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } @@ -231,7 +231,7 @@ public: // If the Operation we produce is valid, return it. 
if (!OpTy::verifyInvariants(op)) { - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 031dceb518e..6676ad0d818 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -69,6 +69,7 @@ using llvm::cast_or_null; using llvm::dyn_cast; using llvm::dyn_cast_or_null; using llvm::isa; +using llvm::isa_and_nonnull; // Containers. using llvm::ArrayRef; diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 51209da7385..2dfed934ee0 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -61,11 +61,11 @@ bool mlir::isValidDim(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast()) + if (auto applyOp = dyn_cast(op)) return applyOp.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast()) + if (auto dimOp = dyn_cast(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } @@ -86,11 +86,11 @@ bool mlir::isValidSymbol(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast()) + if (auto applyOp = dyn_cast(op)) return applyOp.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast()) + if (auto dimOp = dyn_cast(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 78caa4c2625..60f2b142986 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -320,8 +320,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, loadAndStores.match(forOp, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { auto *op = ls.getMatchedOperation(); - auto load = op->dyn_cast(); - auto store = op->dyn_cast(); + auto load = dyn_cast(op); + auto store = dyn_cast(op); // Only scalar types are considered vectorizable, all load/store must be // vectorizable for a loop to qualify as vectorizable. // TODO(ntv): ponder whether we want to be more general here. @@ -338,8 +338,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); + auto load = dyn_cast(op); + auto store = dyn_cast(op); return load ? 
isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 0fb88620fa1..4e23441d5a5 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -48,9 +48,9 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index bce000a4c1f..155a2bbbd1b 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -50,7 +50,7 @@ static void getForwardSliceImpl(Operation *op, return; } - if (auto forOp = op->dyn_cast()) { + if (auto forOp = dyn_cast(op)) { for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 1eaab676567..8d963e4739c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,7 +44,7 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { AffineForOp currAffineForOp; // Traverse up the hierarchy collecing all 'affine.for' operation while // skipping over 'affine.if' operations. - while (currOp && ((currAffineForOp = currOp->dyn_cast()) || + while (currOp && ((currAffineForOp = dyn_cast(currOp)) || currOp->isa())) { if (currAffineForOp) loops->push_back(currAffineForOp); @@ -239,7 +239,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *op = symbol->getDefiningOp()) { - if (auto constOp = op->dyn_cast()) { + if (auto constOp = dyn_cast(op)) { cst.setIdToConstant(*symbol, constOp.getValue()); } } @@ -467,7 +467,7 @@ static Operation *getInstAtPosition(ArrayRef positions, } if (level == positions.size() - 1) return &op; - if (auto childAffineForOp = op.dyn_cast()) + if (auto childAffineForOp = dyn_cast(op)) return getInstAtPosition(positions, level + 1, childAffineForOp.getBody()); @@ -633,7 +633,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. 
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { + if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -643,7 +643,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { } } else { assert(loadOrStoreOpInst->isa() && "load/store op expected"); - auto storeOp = loadOrStoreOpInst->dyn_cast(); + auto storeOp = dyn_cast(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -750,7 +750,7 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { - if (auto innerFor = op->dyn_cast()) + if (auto innerFor = dyn_cast(op)) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index b45ac001be4..8fecf058bfc 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -152,7 +152,7 @@ static SetVector getParentsOfType(Operation *op) { SetVector res; auto *current = op; while (auto *parent = current->getParentOp()) { - if (auto typedParent = parent->template dyn_cast()) { + if (auto typedParent = dyn_cast(parent)) { assert(res.count(parent) == 0 && "Already inserted"); res.insert(parent); } @@ -177,7 +177,7 @@ AffineMap mlir::makePermutationMap( } } - if (auto load = op->dyn_cast()) { + if (auto load = dyn_cast(op)) { return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); } @@ -198,10 +198,10 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = op.dyn_cast()) { + if (auto read = dyn_cast(op)) { superVectorType = read.getResultType(); mustDivide = true; - } else if (auto write = op.dyn_cast()) { + } else if (auto write = dyn_cast(op)) { superVectorType = write.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 610c8b66320..2c9117736ae 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -100,7 +100,7 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } - if (auto f = op->dyn_cast()) { + if (auto f = dyn_cast(op)) { return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported operation, use an OperationHandle instead"); @@ -147,8 +147,8 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, if (!lbDef || !ubDef) return llvm::Optional(); - auto lbConst = lbDef->dyn_cast(); - auto ubConst = ubDef->dyn_cast(); + auto lbConst = dyn_cast(lbDef); + auto ubConst = dyn_cast(ubDef); if (!lbConst || !ubConst) return llvm::Optional(); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 434f7206e04..6e20542a818 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -319,11 +319,11 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, // TODO(ntv) expose as a primitive for other passes. 
static LogicalResult tileLinalgOp(Operation *op, ArrayRef tileSizes, PerFunctionState &state) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return tileLinalgOp(matmulOp, tileSizes, state); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return tileLinalgOp(matvecOp, tileSizes, state); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return tileLinalgOp(dotOp, tileSizes, state); } return failure(); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 4b77ece21dd..98cf4b75b6a 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -68,9 +68,9 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( SmallVector mlir::getRanges(Operation *op) { SmallVector res; - if (auto view = op->dyn_cast()) { + if (auto view = dyn_cast(op)) { res.append(view.getIndexings().begin(), view.getIndexings().end()); - } else if (auto slice = op->dyn_cast()) { + } else if (auto slice = dyn_cast(op)) { for (auto *i : slice.getIndexings()) if (i->getType().isa()) res.push_back(i); @@ -100,7 +100,7 @@ SmallVector mlir::getRanges(Operation *op) { Value *mlir::createOrReturnView(FuncBuilder *b, Location loc, Operation *viewDefiningOp, ArrayRef ranges) { - if (auto view = viewDefiningOp->dyn_cast()) { + if (auto view = dyn_cast(viewDefiningOp)) { auto indexings = view.getIndexings(); if (std::equal(indexings.begin(), indexings.end(), ranges.begin())) return view.getResult(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 05e3b13eb4c..bc68a78bd0a 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -134,7 +134,7 @@ struct MemRefCastFolder : public RewritePattern { void rewrite(Operation *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = memref->dyn_cast()) + if (auto cast = dyn_cast(memref)) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 8a9c649feb3..597efc3ba37 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -199,11 +199,11 @@ bool ModuleTranslation::convertOperation(Operation &opInst, // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. - if (auto brOp = opInst.dyn_cast()) { + if (auto brOp = dyn_cast(opInst)) { builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); return false; } - if (auto condbrOp = opInst.dyn_cast()) { + if (auto condbrOp = dyn_cast(opInst)) { builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), blockMapping[condbrOp.getSuccessor(0)], blockMapping[condbrOp.getSuccessor(1)]); @@ -264,7 +264,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, // For conditional branches, we need to check if the current block is reached // through the "true" or the "false" branch and take the relevant operands. 
- auto condBranchOp = terminator.dyn_cast(); + auto condBranchOp = dyn_cast(terminator); assert(condBranchOp && "only branch operations can be terminators of a block that " "has successors"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 10f47fe9be1..937399cc703 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -173,11 +173,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { rank = loadOp.getMemRefType().getRank(); region->memref = loadOp.getMemRef(); region->setWrite(false); - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { rank = storeOp.getMemRefType().getRank(); region->memref = storeOp.getMemRef(); region->setWrite(true); @@ -483,7 +483,7 @@ bool DmaGeneration::runOnBlock(Block *block) { }); for (auto it = curBegin; it != block->end(); ++it) { - if (auto forOp = it->dyn_cast()) { + if (auto forOp = dyn_cast(&*it)) { // Returns true if the footprint is known to exceed capacity. auto exceedsCapacity = [&](AffineForOp forOp) { Optional footprint = @@ -607,10 +607,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -739,7 +739,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // For a range of operations, a note will be emitted at the caller. AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); - if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { + if (llvm::DebugFlag && (forOp = dyn_cast(&*begin))) { forOp.emitRemark() << sizeInKib << " KiB of DMA buffers in fast memory space for this block\n"; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 796d2164ad9..1c4a4d1f755 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -644,7 +644,7 @@ bool MemRefDependenceGraph::init(Function &f) { DenseMap forToNodeMap; for (auto &op : f.front()) { - if (auto forOp = op.dyn_cast()) { + if (auto forOp = dyn_cast(op)) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -666,14 +666,14 @@ bool MemRefDependenceGraph::init(Function &f) { } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = op.dyn_cast()) { + } else if (auto loadOp = dyn_cast(op)) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); auto *memref = op.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = op.dyn_cast()) { + } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. 
Node node(nextNodeId++, &op); node.stores.push_back(&op); @@ -2125,7 +2125,7 @@ public: auto *fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { for (auto &use : fn->getArgument(i)->getUses()) { - if (auto loadOp = use.getOwner()->dyn_cast()) { + if (auto loadOp = dyn_cast(use.getOwner())) { // Gather loops surrounding 'use'. SmallVector loops; getLoopIVs(*use.getOwner(), &loops); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index ce42a5eba85..28e13d89ada 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -273,7 +273,7 @@ static void getTileableBands(Function &f, for (auto &block : f) for (auto &op : block) - if (auto forOp = op.dyn_cast()) + if (auto forOp = dyn_cast(op)) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 366a7ede5eb..0a23295c8d9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,7 +92,7 @@ void LoopUnrollAndJam::runOnFunction() { // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. auto &entryBlock = getFunction().front(); - if (auto forOp = entryBlock.front().dyn_cast()) + if (auto forOp = dyn_cast(entryBlock.front())) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index dc389c8e37a..1ffe5e3ddd7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -620,10 +620,10 @@ void LowerAffinePass::runOnFunction() { // Rewrite all of the ifs and fors. We walked the operations in postorders, // so we know that we will rewrite them in the reverse order. 
for (auto *op : llvm::reverse(instsToRewrite)) { - if (auto ifOp = op->dyn_cast()) { + if (auto ifOp = dyn_cast(op)) { if (lowerAffineIf(ifOp)) return signalPassFailure(); - } else if (auto forOp = op->dyn_cast()) { + } else if (auto forOp = dyn_cast(op)) { if (lowerAffineFor(forOp)) return signalPassFailure(); } else if (lowerAffineApply(op->cast())) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2f06a9aa3bf..28dfb2278e0 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -556,12 +556,12 @@ static bool instantiateMaterialization(Operation *op, if (op->getNumRegions() != 0) return op->emitError("NYI path Op with region"), true; - if (auto write = op->dyn_cast()) { + if (auto write = dyn_cast(op)) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = op->dyn_cast()) { + if (auto read = dyn_cast(op)) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index a63d462c4a9..94df936c93f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -103,7 +103,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (auto &use : loadOp.getMemRef()->getUses()) { - auto storeOp = use.getOwner()->dyn_cast(); + auto storeOp = dyn_cast(use.getOwner()); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 66fbf4a1306..0da97f7d169 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -181,7 +181,7 @@ static void findMatchingStartFinishInsts( // Collect outgoing DMA operations - needed to check for dependences below. SmallVector outgoingDmaOps; for (auto &op : *forOp.getBody()) { - auto dmaStartOp = op.dyn_cast(); + auto dmaStartOp = dyn_cast(op); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -193,7 +193,7 @@ static void findMatchingStartFinishInsts( dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = op.dyn_cast(); + auto dmaStartOp = dyn_cast(op); if (!dmaStartOp) continue; diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp index 0990d7a73f6..ec1e971973e 100644 --- a/mlir/lib/Transforms/TestConstantFold.cpp +++ b/mlir/lib/Transforms/TestConstantFold.cpp @@ -48,7 +48,7 @@ void TestConstantFold::foldOperation(Operation *op, } // If this op is a constant that are used and cannot be de-duplicated, // remember it for cleanup later. - else if (auto constant = op->dyn_cast()) { + else if (auto constant = dyn_cast(op)) { existingConstants.push_back(op); } } diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp index fc8209be872..b907840b27d 100644 --- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp @@ -40,7 +40,7 @@ bool ConstantFoldHelper::tryToConstantFold( // into the value it contains. 
We need to consider constants before the // constant folding logic to avoid re-creating the same constant later. // TODO: Extend to support dialect-specific constant ops. - if (auto constant = op->dyn_cast()) { + if (auto constant = dyn_cast(op)) { // If this constant is dead, update bookkeeping and signal the caller. if (constant.use_empty()) { notifyRemoval(op); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a10e4a1ae49..7fbb48ecf99 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -363,7 +363,7 @@ void mlir::getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, nestedLoops.push_back(curr); auto *currBody = curr.getBody(); while (currBody->begin() == std::prev(currBody->end(), 2) && - (curr = curr.getBody()->front().dyn_cast())) { + (curr = dyn_cast(curr.getBody()->front()))) { nestedLoops.push_back(curr); currBody = curr.getBody(); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 753f7cf750f..b64dc53e037 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -234,7 +234,7 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { static bool affineApplyOp(Operation &op) { return op.isa(); } static bool singleResultAffineApplyOpWithoutUses(Operation &op) { - auto app = op.dyn_cast(); + auto app = dyn_cast(op); return app && app.use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 025a6535a78..9b8768a6445 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -839,8 +839,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedOperation(); - auto load = opInst->dyn_cast(); - auto store = opInst->dyn_cast(); + auto load = dyn_cast(opInst); + auto store = dyn_cast(opInst); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) @@ -982,7 +982,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningOp()->dyn_cast()) { + if (auto constant = dyn_cast(operand->getDefiningOp())) { return vectorizeConstant( op, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); @@ -1012,7 +1012,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, assert(!opInst->isa() && "vector.transfer_write cannot be further vectorized"); - if (auto store = opInst->dyn_cast()) { + if (auto store = dyn_cast(opInst)) { auto *memRef = store.getMemRef(); auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index ec566e28825..5c34ed160b2 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -161,8 +161,8 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { } // Output the check and the rewritten builder string. 
- os << "if (auto op = opInst.dyn_cast<" << op.getQualCppClassName() - << ">()) {\n"; + os << "if (auto op = dyn_cast<" << op.getQualCppClassName() + << ">(opInst)) {\n"; os << bs.str() << builderStrRef << "\n"; os << " return false;\n"; os << "}\n"; -- cgit v1.2.3 From 41d90a85bd7942a9a27011f09c3a49cc32fdaeae Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Sat, 11 May 2019 15:24:47 -0700 Subject: Automated rollback of changelist 247778391. PiperOrigin-RevId: 247778691 --- mlir/examples/Linalg/Linalg1/lib/Analysis.cpp | 4 +- mlir/examples/Linalg/Linalg1/lib/Common.cpp | 2 +- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 2 +- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 4 +- .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 14 +++--- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 16 +++---- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 8 ++-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 4 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 8 ++-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 6 +-- mlir/include/mlir/EDSC/Builders.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Function.h | 2 +- mlir/include/mlir/IR/OpDefinition.h | 12 +++-- mlir/include/mlir/IR/Operation.h | 51 ++++++++++------------ mlir/include/mlir/IR/PatternMatch.h | 4 +- mlir/include/mlir/Support/LLVM.h | 1 - mlir/lib/AffineOps/AffineOps.cpp | 8 ++-- mlir/lib/Analysis/LoopAnalysis.cpp | 8 ++-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 12 ++--- mlir/lib/Analysis/VectorAnalysis.cpp | 8 ++-- mlir/lib/EDSC/Builders.cpp | 6 +-- mlir/lib/Linalg/Transforms/Tiling.cpp | 6 +-- mlir/lib/Linalg/Utils/Utils.cpp | 6 +-- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 6 +-- mlir/lib/Transforms/DmaGeneration.cpp | 12 ++--- mlir/lib/Transforms/LoopFusion.cpp | 8 ++-- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 8 ++-- mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 4 +- 43 files changed, 130 insertions(+), 140 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp index a7fba179c79..ecb6309466a 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp @@ -31,7 +31,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) { auto viewType = view->getType().dyn_cast(); (void)viewType; assert(viewType.isa() && "expected a ViewType"); - while (auto slice = dyn_cast(view->getDefiningOp())) { + while (auto slice = view->getDefiningOp()->dyn_cast()) { view = slice.getParentView(); assert(viewType.isa() && "expected a ViewType"); } @@ -48,7 +48,7 @@ std::pair linalg::getViewRootIndexing(Value *view, (void)viewType; assert(viewType.isa() && "expected a ViewType"); assert(dim < viewType.getRank() && "dim exceeds rank"); - if (auto viewOp = dyn_cast(view->getDefiningOp())) + if (auto viewOp = 
view->getDefiningOp()->dyn_cast()) return std::make_pair(viewOp.getIndexing(dim), dim); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index 278f9c57607..bfdc40a6aa0 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -40,7 +40,7 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( assert(ivs.size() == indexings.size()); for (unsigned i = 0, e = indexings.size(); i < e; ++i) { auto rangeOp = - llvm::dyn_cast(indexings[i].getValue()->getDefiningOp()); + indexings[i].getValue()->getDefiningOp()->dyn_cast(); if (!rangeOp) { continue; } diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 5bcebc79c18..372c08f9eea 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -33,7 +33,7 @@ using namespace linalg::intrinsics; unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); - if (auto viewOp = dyn_cast(view->getDefiningOp())) + if (auto viewOp = view->getDefiningOp()->dyn_cast()) return viewOp.getRank(); return view->getDefiningOp()->cast().getRank(); } diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index 83fd9ad3143..d1af7503d1b 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -43,7 +43,7 @@ using namespace linalg::intrinsics; // analyses. This builds the chain. static SmallVector getViewChain(mlir::Value *v) { assert(v->getType().isa() && "ViewType expected"); - if (v->getDefiningOp()->isa()) { + if (v->getDefiningOp()->dyn_cast()) { return SmallVector{v}; } @@ -53,7 +53,7 @@ static SmallVector getViewChain(mlir::Value *v) { tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa()); - assert(v->getDefiningOp()->isa() && "must be a ViewOp"); + assert(v->getDefiningOp()->cast() && "must be a ViewOp"); tmp.push_back(v); return SmallVector(tmp.rbegin(), tmp.rend()); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 3090f29dcfc..9339d7309e3 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -91,7 +91,7 @@ inline llvm::SmallVector extractRangesFromViewOrSliceOp(mlir::Value *view) { // This expects a viewType which must come from either ViewOp or SliceOp. 
assert(view->getType().isa() && "expected ViewType"); - if (auto viewOp = llvm::dyn_cast(view->getDefiningOp())) + if (auto viewOp = view->getDefiningOp()->dyn_cast()) return viewOp.getRanges(); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index bce7f58860d..42999aef7ae 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -46,9 +46,9 @@ void linalg::composeSliceOps(mlir::Function *f) { void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { f->walk([](Operation *op) { - if (auto matmulOp = dyn_cast(op)) { + if (auto matmulOp = op->dyn_cast()) { matmulOp.writeAsFinerGrainTensorContraction(); - } else if (auto matvecOp = dyn_cast(op)) { + } else if (auto matvecOp = op->dyn_cast()) { matvecOp.writeAsFinerGrainTensorContraction(); } else { return; @@ -205,11 +205,11 @@ writeContractionAsLoops(ContractionOp contraction) { llvm::Optional> linalg::writeAsLoops(Operation *op) { - if (auto matmulOp = dyn_cast(op)) { + if (auto matmulOp = op->dyn_cast()) { return writeContractionAsLoops(matmulOp); - } else if (auto matvecOp = dyn_cast(op)) { + } else if (auto matvecOp = op->dyn_cast()) { return writeContractionAsLoops(matvecOp); - } else if (auto dotOp = dyn_cast(op)) { + } else if (auto dotOp = op->dyn_cast()) { return writeContractionAsLoops(dotOp); } return llvm::None; @@ -276,7 +276,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto load = op->cast(); - SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); + SliceOp slice = load.getView()->getDefiningOp()->dyn_cast(); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : load.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(load), load.getLoc()); @@ -291,7 +291,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto store = op->cast(); - SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); + SliceOp slice = store.getView()->getDefiningOp()->dyn_cast(); ViewOp view = slice ? 
emitAndReturnFullyComposedView(slice.getResult()) : store.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(store), store.getLoc()); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 6771257ae0f..05865e9e53c 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -52,8 +52,8 @@ void linalg::lowerToTiledLoops(mlir::Function *f, } static bool isZeroIndex(Value *v) { - return isa_and_nonnull(v->getDefiningOp()) && - cast(v->getDefiningOp()).getValue() == 0; + return v->getDefiningOp() && v->getDefiningOp()->isa() && + v->getDefiningOp()->dyn_cast().getValue() == 0; } template @@ -178,11 +178,11 @@ writeContractionAsTiledViews(TensorContractionBase &contraction, llvm::Optional> linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { - if (auto matmulOp = dyn_cast(op)) { + if (auto matmulOp = op->dyn_cast()) { return writeContractionAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = dyn_cast(op)) { + } else if (auto matvecOp = op->dyn_cast()) { return writeContractionAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = dyn_cast(op)) { + } else if (auto dotOp = op->dyn_cast()) { return writeContractionAsTiledViews(dotOp, tileSizes); } return llvm::None; @@ -190,11 +190,11 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef tileSizes) { f->walk([tileSizes](Operation *op) { - if (auto matmulOp = dyn_cast(op)) { + if (auto matmulOp = op->dyn_cast()) { writeAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = dyn_cast(op)) { + } else if (auto matvecOp = op->dyn_cast()) { writeAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = dyn_cast(op)) { + } else if (auto dotOp = op->dyn_cast()) { writeAsTiledViews(dotOp, tileSizes); } else { return; diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index c9f98e7d6a9..a11c88266b7 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -238,13 +238,13 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. - if (auto addOp = llvm::dyn_cast(op)) { + if (auto addOp = op->dyn_cast()) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } // Transpose is easy: just invert the dimensions. - if (auto transpose = llvm::dyn_cast(op)) { + if (auto transpose = op->dyn_cast()) { SmallVector dims; auto arrayTy = transpose.getOperand()->getType().cast(); dims.insert(dims.end(), arrayTy.getShape().begin(), @@ -259,7 +259,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = llvm::dyn_cast(op)) { + if (auto mulOp = op->dyn_cast()) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -291,7 +291,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. 
- if (auto callOp = llvm::dyn_cast(op)) { + if (auto callOp = op->dyn_cast()) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 942ce866182..f3e8ff06781 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + mlir::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! if (!transposeInputOp) return matchFailure(); @@ -75,7 +75,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast(); // Look through the input of the current reshape. - ConstantOp constantOp = llvm::dyn_cast_or_null( + ConstantOp constantOp = mlir::dyn_cast_or_null( reshape.getOperand()->getDefiningOp()); // If the input is defined by another constant, bingo! if (!constantOp) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 534b5cbd2ab..4ef62d33adc 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -366,7 +366,7 @@ struct LateLoweringPass : public ModulePass { // First patch calls type to return memref instead of ToyArray for (auto &function : getModule()) { function.walk([&](Operation *op) { - auto callOp = dyn_cast(op); + auto callOp = op->dyn_cast(); if (!callOp) return; if (!callOp.getNumResults()) @@ -382,14 +382,14 @@ struct LateLoweringPass : public ModulePass { for (auto &function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = dyn_cast(op)) { + if (auto allocOp = op->dyn_cast()) { auto result = allocTensor(allocOp); allocOp.replaceAllUsesWith(result); allocOp.erase(); return; } // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = dyn_cast(op)) { + if (auto typeCastOp = op->dyn_cast()) { typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); typeCastOp.erase(); return; @@ -429,7 +429,7 @@ struct LateLoweringPass : public ModulePass { // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. builder.getFunction()->walk([&](Operation *op) { - auto returnOp = dyn_cast(op); + auto returnOp = op->dyn_cast(); if (!returnOp) return; if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 4e17b234d14..a083e62f05f 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -238,7 +238,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. 
- if (auto addOp = llvm::dyn_cast(op)) { + if (auto addOp = op->dyn_cast()) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } @@ -261,7 +261,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = llvm::dyn_cast(op)) { + if (auto mulOp = op->dyn_cast()) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -295,7 +295,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. - if (auto callOp = llvm::dyn_cast(op)) { + if (auto callOp = op->dyn_cast()) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 39302f6c0f9..5d23488c95d 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -439,7 +439,7 @@ ValueHandle ValueHandle::create(Args... args) { if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } else if (op->getNumResults() == 0) { - if (auto f = dyn_cast(op)) { + if (auto f = op->dyn_cast()) { return ValueHandle(f.getInductionVar()); } } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 7f182e882db..1ee6c4806fb 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -271,7 +271,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = dyn_cast(op); + auto result = op->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index d4b85b56d0f..0770d2cfa27 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -116,7 +116,7 @@ public: /// Specialization of walk to only visit operations of 'OpTy'. template void walk(std::function callback) { walk([&](Operation *opInst) { - if (auto op = dyn_cast(opInst)) + if (auto op = opInst->dyn_cast()) callback(op); }); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 2eff412a71e..b80e8aca9bc 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -792,7 +792,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. static void printAssembly(Operation *op, OpAsmPrinter *p) { - auto opPointer = dyn_cast(op); + auto opPointer = op->dyn_cast(); assert(opPointer && "op's name does not match name of concrete type instantiated with"); opPointer.print(p); @@ -825,13 +825,11 @@ public: /// This is a public constructor. Any op can be initialized to null. explicit Op() : OpState(nullptr) {} - Op(std::nullptr_t) : OpState(nullptr) {} - /// This is a public constructor to enable access via the llvm::cast family of - /// methods. This should not be used directly. - explicit Op(Operation *state) : OpState(state) { - assert(!state || isa(state)); - } +protected: + /// This is a private constructor only accessible through the + /// Operation::cast family of methods. 
+ explicit Op(Operation *state) : OpState(state) {} friend class Operation; private: diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 31ec8ea54a6..54e49b73e3b 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -389,6 +389,14 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// + /// The dyn_cast methods perform a dynamic cast from an Operation to a typed + /// Op like DimOp. This returns a null Op on failure. + template OpClass dyn_cast() { + if (isa()) + return cast(); + return OpClass(); + } + /// The cast methods perform a cast from an Operation to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. @@ -409,10 +417,10 @@ public: /// including this one. void walk(const std::function &callback); - /// Specialization of walk to only visit operations of 'T'. - template void walk(std::function callback) { + /// Specialization of walk to only visit operations of 'OpTy'. + template void walk(std::function callback) { walk([&](Operation *op) { - if (auto derivedOp = dyn_cast(op)) + if (auto derivedOp = op->dyn_cast()) callback(derivedOp); }); } @@ -526,6 +534,17 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } +/// Provide dyn_cast_or_null functionality for Operation casts. +template T dyn_cast_or_null(Operation *op) { + return op ? op->dyn_cast() : T(); +} + +/// Provide isa_and_nonnull functionality for Operation casts, i.e. if the +/// operation is non-null and a class of 'T'. +template bool isa_and_nonnull(Operation *op) { + return op && op->isa(); +} + /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final @@ -579,30 +598,4 @@ inline auto Operation::getResultTypes() } // end namespace mlir -namespace llvm { -/// Provide isa functionality for operation casts. -template struct isa_impl { - static inline bool doit(const ::mlir::Operation &op) { - return T::classof(const_cast<::mlir::Operation *>(&op)); - } -}; - -/// Provide specializations for operation casts as the resulting T is value -/// typed. -template struct cast_retty_impl { - using ret_type = T; -}; -template struct cast_retty_impl { - using ret_type = T; -}; -template -struct cast_convert_val { - static T doit(::mlir::Operation &val) { return T(&val); } -}; -template -struct cast_convert_val { - static T doit(::mlir::Operation *val) { return T(val); } -}; -} // end namespace llvm - #endif // MLIR_IR_OPERATION_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 51528c18d38..3b02ed55c34 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -215,7 +215,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = dyn_cast(op); + auto result = op->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } @@ -231,7 +231,7 @@ public: // If the Operation we produce is valid, return it. 
if (!OpTy::verifyInvariants(op)) { - auto result = dyn_cast(op); + auto result = op->dyn_cast(); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 6676ad0d818..031dceb518e 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -69,7 +69,6 @@ using llvm::cast_or_null; using llvm::dyn_cast; using llvm::dyn_cast_or_null; using llvm::isa; -using llvm::isa_and_nonnull; // Containers. using llvm::ArrayRef; diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 2dfed934ee0..51209da7385 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -61,11 +61,11 @@ bool mlir::isValidDim(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) + if (auto applyOp = op->dyn_cast()) return applyOp.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = dyn_cast(op)) + if (auto dimOp = op->dyn_cast()) return isTopLevelSymbol(dimOp.getOperand()); return false; } @@ -86,11 +86,11 @@ bool mlir::isValidSymbol(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) + if (auto applyOp = op->dyn_cast()) return applyOp.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = dyn_cast(op)) + if (auto dimOp = op->dyn_cast()) return isTopLevelSymbol(dimOp.getOperand()); return false; } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 60f2b142986..78caa4c2625 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -320,8 +320,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, loadAndStores.match(forOp, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { auto *op = ls.getMatchedOperation(); - auto load = dyn_cast(op); - auto store = dyn_cast(op); + auto load = op->dyn_cast(); + auto store = op->dyn_cast(); // Only scalar types are considered vectorizable, all load/store must be // vectorizable for a loop to qualify as vectorizable. // TODO(ntv): ponder whether we want to be more general here. @@ -338,8 +338,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { - auto load = dyn_cast(op); - auto store = dyn_cast(op); + auto load = op.dyn_cast(); + auto store = op.dyn_cast(); return load ? 
isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 4e23441d5a5..0fb88620fa1 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -48,9 +48,9 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = opInst->dyn_cast()) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 155a2bbbd1b..bce000a4c1f 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -50,7 +50,7 @@ static void getForwardSliceImpl(Operation *op, return; } - if (auto forOp = dyn_cast(op)) { + if (auto forOp = op->dyn_cast()) { for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 8d963e4739c..1eaab676567 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,7 +44,7 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { AffineForOp currAffineForOp; // Traverse up the hierarchy collecing all 'affine.for' operation while // skipping over 'affine.if' operations. - while (currOp && ((currAffineForOp = dyn_cast(currOp)) || + while (currOp && ((currAffineForOp = currOp->dyn_cast()) || currOp->isa())) { if (currAffineForOp) loops->push_back(currAffineForOp); @@ -239,7 +239,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *op = symbol->getDefiningOp()) { - if (auto constOp = dyn_cast(op)) { + if (auto constOp = op->dyn_cast()) { cst.setIdToConstant(*symbol, constOp.getValue()); } } @@ -467,7 +467,7 @@ static Operation *getInstAtPosition(ArrayRef positions, } if (level == positions.size() - 1) return &op; - if (auto childAffineForOp = dyn_cast(op)) + if (auto childAffineForOp = op.dyn_cast()) return getInstAtPosition(positions, level + 1, childAffineForOp.getBody()); @@ -633,7 +633,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. 
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { + if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -643,7 +643,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { } } else { assert(loadOrStoreOpInst->isa() && "load/store op expected"); - auto storeOp = dyn_cast(loadOrStoreOpInst); + auto storeOp = loadOrStoreOpInst->dyn_cast(); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -750,7 +750,7 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { - if (auto innerFor = dyn_cast(op)) + if (auto innerFor = op->dyn_cast()) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 8fecf058bfc..b45ac001be4 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -152,7 +152,7 @@ static SetVector getParentsOfType(Operation *op) { SetVector res; auto *current = op; while (auto *parent = current->getParentOp()) { - if (auto typedParent = dyn_cast(parent)) { + if (auto typedParent = parent->template dyn_cast()) { assert(res.count(parent) == 0 && "Already inserted"); res.insert(parent); } @@ -177,7 +177,7 @@ AffineMap mlir::makePermutationMap( } } - if (auto load = dyn_cast(op)) { + if (auto load = op->dyn_cast()) { return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); } @@ -198,10 +198,10 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = dyn_cast(op)) { + if (auto read = op.dyn_cast()) { superVectorType = read.getResultType(); mustDivide = true; - } else if (auto write = dyn_cast(op)) { + } else if (auto write = op.dyn_cast()) { superVectorType = write.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 2c9117736ae..610c8b66320 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -100,7 +100,7 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } - if (auto f = dyn_cast(op)) { + if (auto f = op->dyn_cast()) { return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported operation, use an OperationHandle instead"); @@ -147,8 +147,8 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, if (!lbDef || !ubDef) return llvm::Optional(); - auto lbConst = dyn_cast(lbDef); - auto ubConst = dyn_cast(ubDef); + auto lbConst = lbDef->dyn_cast(); + auto ubConst = ubDef->dyn_cast(); if (!lbConst || !ubConst) return llvm::Optional(); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 6e20542a818..434f7206e04 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -319,11 +319,11 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, // TODO(ntv) expose as a primitive for other passes. 
static LogicalResult tileLinalgOp(Operation *op, ArrayRef tileSizes, PerFunctionState &state) { - if (auto matmulOp = dyn_cast(op)) { + if (auto matmulOp = op->dyn_cast()) { return tileLinalgOp(matmulOp, tileSizes, state); - } else if (auto matvecOp = dyn_cast(op)) { + } else if (auto matvecOp = op->dyn_cast()) { return tileLinalgOp(matvecOp, tileSizes, state); - } else if (auto dotOp = dyn_cast(op)) { + } else if (auto dotOp = op->dyn_cast()) { return tileLinalgOp(dotOp, tileSizes, state); } return failure(); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 98cf4b75b6a..4b77ece21dd 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -68,9 +68,9 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( SmallVector mlir::getRanges(Operation *op) { SmallVector res; - if (auto view = dyn_cast(op)) { + if (auto view = op->dyn_cast()) { res.append(view.getIndexings().begin(), view.getIndexings().end()); - } else if (auto slice = dyn_cast(op)) { + } else if (auto slice = op->dyn_cast()) { for (auto *i : slice.getIndexings()) if (i->getType().isa()) res.push_back(i); @@ -100,7 +100,7 @@ SmallVector mlir::getRanges(Operation *op) { Value *mlir::createOrReturnView(FuncBuilder *b, Location loc, Operation *viewDefiningOp, ArrayRef ranges) { - if (auto view = dyn_cast(viewDefiningOp)) { + if (auto view = viewDefiningOp->dyn_cast()) { auto indexings = view.getIndexings(); if (std::equal(indexings.begin(), indexings.end(), ranges.begin())) return view.getResult(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index bc68a78bd0a..05e3b13eb4c 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -134,7 +134,7 @@ struct MemRefCastFolder : public RewritePattern { void rewrite(Operation *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast(memref)) + if (auto cast = memref->dyn_cast()) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 597efc3ba37..8a9c649feb3 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -199,11 +199,11 @@ bool ModuleTranslation::convertOperation(Operation &opInst, // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. - if (auto brOp = dyn_cast(opInst)) { + if (auto brOp = opInst.dyn_cast()) { builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); return false; } - if (auto condbrOp = dyn_cast(opInst)) { + if (auto condbrOp = opInst.dyn_cast()) { builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), blockMapping[condbrOp.getSuccessor(0)], blockMapping[condbrOp.getSuccessor(1)]); @@ -264,7 +264,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, // For conditional branches, we need to check if the current block is reached // through the "true" or the "false" branch and take the relevant operands. 
- auto condBranchOp = dyn_cast(terminator); + auto condBranchOp = terminator.dyn_cast(); assert(condBranchOp && "only branch operations can be terminators of a block that " "has successors"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 937399cc703..10f47fe9be1 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -173,11 +173,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; - if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = opInst->dyn_cast()) { rank = loadOp.getMemRefType().getRank(); region->memref = loadOp.getMemRef(); region->setWrite(false); - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = opInst->dyn_cast()) { rank = storeOp.getMemRefType().getRank(); region->memref = storeOp.getMemRef(); region->setWrite(true); @@ -483,7 +483,7 @@ bool DmaGeneration::runOnBlock(Block *block) { }); for (auto it = curBegin; it != block->end(); ++it) { - if (auto forOp = dyn_cast(&*it)) { + if (auto forOp = it->dyn_cast()) { // Returns true if the footprint is known to exceed capacity. auto exceedsCapacity = [&](AffineForOp forOp) { Optional footprint = @@ -607,10 +607,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. - if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = opInst->dyn_cast()) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = opInst->dyn_cast()) { if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -739,7 +739,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // For a range of operations, a note will be emitted at the caller. AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); - if (llvm::DebugFlag && (forOp = dyn_cast(&*begin))) { + if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { forOp.emitRemark() << sizeInKib << " KiB of DMA buffers in fast memory space for this block\n"; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1c4a4d1f755..796d2164ad9 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -644,7 +644,7 @@ bool MemRefDependenceGraph::init(Function &f) { DenseMap forToNodeMap; for (auto &op : f.front()) { - if (auto forOp = dyn_cast(op)) { + if (auto forOp = op.dyn_cast()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -666,14 +666,14 @@ bool MemRefDependenceGraph::init(Function &f) { } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = dyn_cast(op)) { + } else if (auto loadOp = op.dyn_cast()) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); auto *memref = op.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = op.dyn_cast()) { // Create graph node for top-level store op. 
Node node(nextNodeId++, &op); node.stores.push_back(&op); @@ -2125,7 +2125,7 @@ public: auto *fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { for (auto &use : fn->getArgument(i)->getUses()) { - if (auto loadOp = dyn_cast(use.getOwner())) { + if (auto loadOp = use.getOwner()->dyn_cast()) { // Gather loops surrounding 'use'. SmallVector loops; getLoopIVs(*use.getOwner(), &loops); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 28e13d89ada..ce42a5eba85 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -273,7 +273,7 @@ static void getTileableBands(Function &f, for (auto &block : f) for (auto &op : block) - if (auto forOp = dyn_cast(op)) + if (auto forOp = op.dyn_cast()) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 0a23295c8d9..366a7ede5eb 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,7 +92,7 @@ void LoopUnrollAndJam::runOnFunction() { // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. auto &entryBlock = getFunction().front(); - if (auto forOp = dyn_cast(entryBlock.front())) + if (auto forOp = entryBlock.front().dyn_cast()) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 1ffe5e3ddd7..dc389c8e37a 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -620,10 +620,10 @@ void LowerAffinePass::runOnFunction() { // Rewrite all of the ifs and fors. We walked the operations in postorders, // so we know that we will rewrite them in the reverse order. 
for (auto *op : llvm::reverse(instsToRewrite)) { - if (auto ifOp = dyn_cast(op)) { + if (auto ifOp = op->dyn_cast()) { if (lowerAffineIf(ifOp)) return signalPassFailure(); - } else if (auto forOp = dyn_cast(op)) { + } else if (auto forOp = op->dyn_cast()) { if (lowerAffineFor(forOp)) return signalPassFailure(); } else if (lowerAffineApply(op->cast())) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 28dfb2278e0..2f06a9aa3bf 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -556,12 +556,12 @@ static bool instantiateMaterialization(Operation *op, if (op->getNumRegions() != 0) return op->emitError("NYI path Op with region"), true; - if (auto write = dyn_cast(op)) { + if (auto write = op->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = dyn_cast(op)) { + if (auto read = op->dyn_cast()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 94df936c93f..a63d462c4a9 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -103,7 +103,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (auto &use : loadOp.getMemRef()->getUses()) { - auto storeOp = dyn_cast(use.getOwner()); + auto storeOp = use.getOwner()->dyn_cast(); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 0da97f7d169..66fbf4a1306 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -181,7 +181,7 @@ static void findMatchingStartFinishInsts( // Collect outgoing DMA operations - needed to check for dependences below. SmallVector outgoingDmaOps; for (auto &op : *forOp.getBody()) { - auto dmaStartOp = dyn_cast(op); + auto dmaStartOp = op.dyn_cast(); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -193,7 +193,7 @@ static void findMatchingStartFinishInsts( dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = dyn_cast(op); + auto dmaStartOp = op.dyn_cast(); if (!dmaStartOp) continue; diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp index ec1e971973e..0990d7a73f6 100644 --- a/mlir/lib/Transforms/TestConstantFold.cpp +++ b/mlir/lib/Transforms/TestConstantFold.cpp @@ -48,7 +48,7 @@ void TestConstantFold::foldOperation(Operation *op, } // If this op is a constant that are used and cannot be de-duplicated, // remember it for cleanup later. - else if (auto constant = dyn_cast(op)) { + else if (auto constant = op->dyn_cast()) { existingConstants.push_back(op); } } diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp index b907840b27d..fc8209be872 100644 --- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp @@ -40,7 +40,7 @@ bool ConstantFoldHelper::tryToConstantFold( // into the value it contains. 
We need to consider constants before the // constant folding logic to avoid re-creating the same constant later. // TODO: Extend to support dialect-specific constant ops. - if (auto constant = dyn_cast(op)) { + if (auto constant = op->dyn_cast()) { // If this constant is dead, update bookkeeping and signal the caller. if (constant.use_empty()) { notifyRemoval(op); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 7fbb48ecf99..a10e4a1ae49 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -363,7 +363,7 @@ void mlir::getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, nestedLoops.push_back(curr); auto *currBody = curr.getBody(); while (currBody->begin() == std::prev(currBody->end(), 2) && - (curr = dyn_cast(curr.getBody()->front()))) { + (curr = curr.getBody()->front().dyn_cast())) { nestedLoops.push_back(curr); currBody = curr.getBody(); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b64dc53e037..753f7cf750f 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -234,7 +234,7 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { static bool affineApplyOp(Operation &op) { return op.isa(); } static bool singleResultAffineApplyOpWithoutUses(Operation &op) { - auto app = dyn_cast(op); + auto app = op.dyn_cast(); return app && app.use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 9b8768a6445..025a6535a78 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -839,8 +839,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedOperation(); - auto load = dyn_cast(opInst); - auto store = dyn_cast(opInst); + auto load = opInst->dyn_cast(); + auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) @@ -982,7 +982,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, return nullptr; } // 3. vectorize constant. - if (auto constant = dyn_cast(operand->getDefiningOp())) { + if (auto constant = operand->getDefiningOp()->dyn_cast()) { return vectorizeConstant( op, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); @@ -1012,7 +1012,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, assert(!opInst->isa() && "vector.transfer_write cannot be further vectorized"); - if (auto store = dyn_cast(opInst)) { + if (auto store = opInst->dyn_cast()) { auto *memRef = store.getMemRef(); auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 5c34ed160b2..ec566e28825 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -161,8 +161,8 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { } // Output the check and the rewritten builder string. 
- os << "if (auto op = dyn_cast<" << op.getQualCppClassName() - << ">(opInst)) {\n"; + os << "if (auto op = opInst.dyn_cast<" << op.getQualCppClassName() + << ">()) {\n"; os << bs.str() << builderStrRef << "\n"; os << " return false;\n"; os << "}\n"; -- cgit v1.2.3 From c5ecf9910a209d96ab768a205783871b4316d711 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 11 May 2019 15:56:50 -0700 Subject: Add support for using llvm::dyn_cast/cast/isa for operation casts and replace usages of Operation::dyn_cast with llvm::dyn_cast. -- PiperOrigin-RevId: 247780086 --- mlir/examples/Linalg/Linalg1/lib/Analysis.cpp | 4 +- mlir/examples/Linalg/Linalg1/lib/Common.cpp | 2 +- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 2 +- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 4 +- .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 14 +++--- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 16 +++---- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 8 ++-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 4 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 8 ++-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 6 +-- mlir/include/mlir/EDSC/Builders.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Function.h | 2 +- mlir/include/mlir/IR/OpDefinition.h | 12 ++--- mlir/include/mlir/IR/Operation.h | 51 ++++++++++++---------- mlir/include/mlir/IR/PatternMatch.h | 4 +- mlir/include/mlir/Support/LLVM.h | 1 + mlir/lib/AffineOps/AffineOps.cpp | 8 ++-- mlir/lib/Analysis/LoopAnalysis.cpp | 8 ++-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 12 ++--- mlir/lib/Analysis/VectorAnalysis.cpp | 8 ++-- mlir/lib/EDSC/Builders.cpp | 6 +-- mlir/lib/Linalg/Transforms/Tiling.cpp | 6 +-- mlir/lib/Linalg/Utils/Utils.cpp | 6 +-- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 6 +-- mlir/lib/Transforms/DmaGeneration.cpp | 12 ++--- mlir/lib/Transforms/LoopFusion.cpp | 8 ++-- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 8 ++-- mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 4 +- 43 files changed, 140 insertions(+), 130 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp index ecb6309466a..a7fba179c79 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp @@ -31,7 +31,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) { auto viewType = view->getType().dyn_cast(); (void)viewType; assert(viewType.isa() && "expected a ViewType"); - while (auto slice = view->getDefiningOp()->dyn_cast()) { + while (auto slice = dyn_cast(view->getDefiningOp())) { view = slice.getParentView(); assert(viewType.isa() && "expected a ViewType"); } @@ -48,7 +48,7 @@ std::pair linalg::getViewRootIndexing(Value *view, (void)viewType; assert(viewType.isa() && "expected a ViewType"); assert(dim < viewType.getRank() && "dim exceeds rank"); - if (auto 
viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = dyn_cast(view->getDefiningOp())) return std::make_pair(viewOp.getIndexing(dim), dim); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index bfdc40a6aa0..278f9c57607 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -40,7 +40,7 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( assert(ivs.size() == indexings.size()); for (unsigned i = 0, e = indexings.size(); i < e; ++i) { auto rangeOp = - indexings[i].getValue()->getDefiningOp()->dyn_cast(); + llvm::dyn_cast(indexings[i].getValue()->getDefiningOp()); if (!rangeOp) { continue; } diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 372c08f9eea..5bcebc79c18 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -33,7 +33,7 @@ using namespace linalg::intrinsics; unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = dyn_cast(view->getDefiningOp())) return viewOp.getRank(); return view->getDefiningOp()->cast().getRank(); } diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index d1af7503d1b..83fd9ad3143 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -43,7 +43,7 @@ using namespace linalg::intrinsics; // analyses. This builds the chain. static SmallVector getViewChain(mlir::Value *v) { assert(v->getType().isa() && "ViewType expected"); - if (v->getDefiningOp()->dyn_cast()) { + if (v->getDefiningOp()->isa()) { return SmallVector{v}; } @@ -53,7 +53,7 @@ static SmallVector getViewChain(mlir::Value *v) { tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa()); - assert(v->getDefiningOp()->cast() && "must be a ViewOp"); + assert(v->getDefiningOp()->isa() && "must be a ViewOp"); tmp.push_back(v); return SmallVector(tmp.rbegin(), tmp.rend()); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 9339d7309e3..3090f29dcfc 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -91,7 +91,7 @@ inline llvm::SmallVector extractRangesFromViewOrSliceOp(mlir::Value *view) { // This expects a viewType which must come from either ViewOp or SliceOp. 
assert(view->getType().isa() && "expected ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) + if (auto viewOp = llvm::dyn_cast(view->getDefiningOp())) return viewOp.getRanges(); auto sliceOp = view->getDefiningOp()->cast(); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 42999aef7ae..bce7f58860d 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -46,9 +46,9 @@ void linalg::composeSliceOps(mlir::Function *f) { void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { f->walk([](Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { matmulOp.writeAsFinerGrainTensorContraction(); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { matvecOp.writeAsFinerGrainTensorContraction(); } else { return; @@ -205,11 +205,11 @@ writeContractionAsLoops(ContractionOp contraction) { llvm::Optional> linalg::writeAsLoops(Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return writeContractionAsLoops(matmulOp); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return writeContractionAsLoops(matvecOp); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return writeContractionAsLoops(dotOp); } return llvm::None; @@ -276,7 +276,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto load = op->cast(); - SliceOp slice = load.getView()->getDefiningOp()->dyn_cast(); + SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : load.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(load), load.getLoc()); @@ -291,7 +291,7 @@ PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto store = op->cast(); - SliceOp slice = store.getView()->getDefiningOp()->dyn_cast(); + SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); ViewOp view = slice ? 
emitAndReturnFullyComposedView(slice.getResult()) : store.getView()->getDefiningOp()->cast(); ScopedContext scope(FuncBuilder(store), store.getLoc()); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 05865e9e53c..6771257ae0f 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -52,8 +52,8 @@ void linalg::lowerToTiledLoops(mlir::Function *f, } static bool isZeroIndex(Value *v) { - return v->getDefiningOp() && v->getDefiningOp()->isa() && - v->getDefiningOp()->dyn_cast().getValue() == 0; + return isa_and_nonnull(v->getDefiningOp()) && + cast(v->getDefiningOp()).getValue() == 0; } template @@ -178,11 +178,11 @@ writeContractionAsTiledViews(TensorContractionBase &contraction, llvm::Optional> linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return writeContractionAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return writeContractionAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return writeContractionAsTiledViews(dotOp, tileSizes); } return llvm::None; @@ -190,11 +190,11 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef tileSizes) { f->walk([tileSizes](Operation *op) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { writeAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { writeAsTiledViews(dotOp, tileSizes); } else { return; diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index a11c88266b7..c9f98e7d6a9 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -238,13 +238,13 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. - if (auto addOp = op->dyn_cast()) { + if (auto addOp = llvm::dyn_cast(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } // Transpose is easy: just invert the dimensions. - if (auto transpose = op->dyn_cast()) { + if (auto transpose = llvm::dyn_cast(op)) { SmallVector dims; auto arrayTy = transpose.getOperand()->getType().cast(); dims.insert(dims.end(), arrayTy.getShape().begin(), @@ -259,7 +259,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast()) { + if (auto mulOp = llvm::dyn_cast(op)) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -291,7 +291,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. 
- if (auto callOp = op->dyn_cast()) { + if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index f3e8ff06781..942ce866182 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); TransposeOp transposeInputOp = - mlir::dyn_cast_or_null(transposeInput->getDefiningOp()); + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! if (!transposeInputOp) return matchFailure(); @@ -75,7 +75,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast(); // Look through the input of the current reshape. - ConstantOp constantOp = mlir::dyn_cast_or_null( + ConstantOp constantOp = llvm::dyn_cast_or_null( reshape.getOperand()->getDefiningOp()); // If the input is defined by another constant, bingo! if (!constantOp) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 4ef62d33adc..534b5cbd2ab 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -366,7 +366,7 @@ struct LateLoweringPass : public ModulePass { // First patch calls type to return memref instead of ToyArray for (auto &function : getModule()) { function.walk([&](Operation *op) { - auto callOp = op->dyn_cast(); + auto callOp = dyn_cast(op); if (!callOp) return; if (!callOp.getNumResults()) @@ -382,14 +382,14 @@ struct LateLoweringPass : public ModulePass { for (auto &function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = op->dyn_cast()) { + if (auto allocOp = dyn_cast(op)) { auto result = allocTensor(allocOp); allocOp.replaceAllUsesWith(result); allocOp.erase(); return; } // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = op->dyn_cast()) { + if (auto typeCastOp = dyn_cast(op)) { typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); typeCastOp.erase(); return; @@ -429,7 +429,7 @@ struct LateLoweringPass : public ModulePass { // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. builder.getFunction()->walk([&](Operation *op) { - auto returnOp = op->dyn_cast(); + auto returnOp = dyn_cast(op); if (!returnOp) return; if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index a083e62f05f..4e17b234d14 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -238,7 +238,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. 
- if (auto addOp = op->dyn_cast()) { + if (auto addOp = llvm::dyn_cast(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } @@ -261,7 +261,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast()) { + if (auto mulOp = llvm::dyn_cast(op)) { auto lhs = mulOp.getLHS()->getType().cast(); auto rhs = mulOp.getRHS()->getType().cast(); auto lhsRank = lhs.getShape().size(); @@ -295,7 +295,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. - if (auto callOp = op->dyn_cast()) { + if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 5d23488c95d..39302f6c0f9 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -439,7 +439,7 @@ ValueHandle ValueHandle::create(Args... args) { if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } else if (op->getNumResults() == 0) { - if (auto f = op->dyn_cast()) { + if (auto f = dyn_cast(op)) { return ValueHandle(f.getInductionVar()); } } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1ee6c4806fb..7f182e882db 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -271,7 +271,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 0770d2cfa27..d4b85b56d0f 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -116,7 +116,7 @@ public: /// Specialization of walk to only visit operations of 'OpTy'. template void walk(std::function callback) { walk([&](Operation *opInst) { - if (auto op = opInst->dyn_cast()) + if (auto op = dyn_cast(opInst)) callback(op); }); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b80e8aca9bc..2eff412a71e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -792,7 +792,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. static void printAssembly(Operation *op, OpAsmPrinter *p) { - auto opPointer = op->dyn_cast(); + auto opPointer = dyn_cast(op); assert(opPointer && "op's name does not match name of concrete type instantiated with"); opPointer.print(p); @@ -825,11 +825,13 @@ public: /// This is a public constructor. Any op can be initialized to null. explicit Op() : OpState(nullptr) {} + Op(std::nullptr_t) : OpState(nullptr) {} -protected: - /// This is a private constructor only accessible through the - /// Operation::cast family of methods. - explicit Op(Operation *state) : OpState(state) {} + /// This is a public constructor to enable access via the llvm::cast family of + /// methods. This should not be used directly. 
+ explicit Op(Operation *state) : OpState(state) { + assert(!state || isa(state)); + } friend class Operation; private: diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 54e49b73e3b..31ec8ea54a6 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -389,14 +389,6 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// - /// The dyn_cast methods perform a dynamic cast from an Operation to a typed - /// Op like DimOp. This returns a null Op on failure. - template OpClass dyn_cast() { - if (isa()) - return cast(); - return OpClass(); - } - /// The cast methods perform a cast from an Operation to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. @@ -417,10 +409,10 @@ public: /// including this one. void walk(const std::function &callback); - /// Specialization of walk to only visit operations of 'OpTy'. - template void walk(std::function callback) { + /// Specialization of walk to only visit operations of 'T'. + template void walk(std::function callback) { walk([&](Operation *op) { - if (auto derivedOp = op->dyn_cast()) + if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }); } @@ -534,17 +526,6 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } -/// Provide dyn_cast_or_null functionality for Operation casts. -template T dyn_cast_or_null(Operation *op) { - return op ? op->dyn_cast() : T(); -} - -/// Provide isa_and_nonnull functionality for Operation casts, i.e. if the -/// operation is non-null and a class of 'T'. -template bool isa_and_nonnull(Operation *op) { - return op && op->isa(); -} - /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final @@ -598,4 +579,30 @@ inline auto Operation::getResultTypes() } // end namespace mlir +namespace llvm { +/// Provide isa functionality for operation casts. +template struct isa_impl { + static inline bool doit(const ::mlir::Operation &op) { + return T::classof(const_cast<::mlir::Operation *>(&op)); + } +}; + +/// Provide specializations for operation casts as the resulting T is value +/// typed. +template struct cast_retty_impl { + using ret_type = T; +}; +template struct cast_retty_impl { + using ret_type = T; +}; +template +struct cast_convert_val { + static T doit(::mlir::Operation &val) { return T(&val); } +}; +template +struct cast_convert_val { + static T doit(::mlir::Operation *val) { return T(val); } +}; +} // end namespace llvm + #endif // MLIR_IR_OPERATION_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 3b02ed55c34..51528c18d38 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -215,7 +215,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } @@ -231,7 +231,7 @@ public: // If the Operation we produce is valid, return it. 
if (!OpTy::verifyInvariants(op)) { - auto result = op->dyn_cast(); + auto result = dyn_cast(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 031dceb518e..6676ad0d818 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -69,6 +69,7 @@ using llvm::cast_or_null; using llvm::dyn_cast; using llvm::dyn_cast_or_null; using llvm::isa; +using llvm::isa_and_nonnull; // Containers. using llvm::ArrayRef; diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 51209da7385..2dfed934ee0 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -61,11 +61,11 @@ bool mlir::isValidDim(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast()) + if (auto applyOp = dyn_cast(op)) return applyOp.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast()) + if (auto dimOp = dyn_cast(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } @@ -86,11 +86,11 @@ bool mlir::isValidSymbol(Value *value) { if (op->getParentOp() == nullptr || op->isa()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast()) + if (auto applyOp = dyn_cast(op)) return applyOp.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast()) + if (auto dimOp = dyn_cast(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 78caa4c2625..60f2b142986 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -320,8 +320,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, loadAndStores.match(forOp, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { auto *op = ls.getMatchedOperation(); - auto load = op->dyn_cast(); - auto store = op->dyn_cast(); + auto load = dyn_cast(op); + auto store = dyn_cast(op); // Only scalar types are considered vectorizable, all load/store must be // vectorizable for a loop to qualify as vectorizable. // TODO(ntv): ponder whether we want to be more general here. @@ -338,8 +338,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { - auto load = op.dyn_cast(); - auto store = op.dyn_cast(); + auto load = dyn_cast(op); + auto store = dyn_cast(op); return load ? 
isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 0fb88620fa1..4e23441d5a5 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -48,9 +48,9 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index bce000a4c1f..155a2bbbd1b 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -50,7 +50,7 @@ static void getForwardSliceImpl(Operation *op, return; } - if (auto forOp = op->dyn_cast()) { + if (auto forOp = dyn_cast(op)) { for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 1eaab676567..8d963e4739c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,7 +44,7 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { AffineForOp currAffineForOp; // Traverse up the hierarchy collecing all 'affine.for' operation while // skipping over 'affine.if' operations. - while (currOp && ((currAffineForOp = currOp->dyn_cast()) || + while (currOp && ((currAffineForOp = dyn_cast(currOp)) || currOp->isa())) { if (currAffineForOp) loops->push_back(currAffineForOp); @@ -239,7 +239,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *op = symbol->getDefiningOp()) { - if (auto constOp = op->dyn_cast()) { + if (auto constOp = dyn_cast(op)) { cst.setIdToConstant(*symbol, constOp.getValue()); } } @@ -467,7 +467,7 @@ static Operation *getInstAtPosition(ArrayRef positions, } if (level == positions.size() - 1) return &op; - if (auto childAffineForOp = op.dyn_cast()) + if (auto childAffineForOp = dyn_cast(op)) return getInstAtPosition(positions, level + 1, childAffineForOp.getBody()); @@ -633,7 +633,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. 
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = loadOrStoreOpInst->dyn_cast()) { + if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -643,7 +643,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { } } else { assert(loadOrStoreOpInst->isa() && "load/store op expected"); - auto storeOp = loadOrStoreOpInst->dyn_cast(); + auto storeOp = dyn_cast(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -750,7 +750,7 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { - if (auto innerFor = op->dyn_cast()) + if (auto innerFor = dyn_cast(op)) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index b45ac001be4..8fecf058bfc 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -152,7 +152,7 @@ static SetVector getParentsOfType(Operation *op) { SetVector res; auto *current = op; while (auto *parent = current->getParentOp()) { - if (auto typedParent = parent->template dyn_cast()) { + if (auto typedParent = dyn_cast(parent)) { assert(res.count(parent) == 0 && "Already inserted"); res.insert(parent); } @@ -177,7 +177,7 @@ AffineMap mlir::makePermutationMap( } } - if (auto load = op->dyn_cast()) { + if (auto load = dyn_cast(op)) { return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); } @@ -198,10 +198,10 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = op.dyn_cast()) { + if (auto read = dyn_cast(op)) { superVectorType = read.getResultType(); mustDivide = true; - } else if (auto write = op.dyn_cast()) { + } else if (auto write = dyn_cast(op)) { superVectorType = write.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 610c8b66320..2c9117736ae 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -100,7 +100,7 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } - if (auto f = op->dyn_cast()) { + if (auto f = dyn_cast(op)) { return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported operation, use an OperationHandle instead"); @@ -147,8 +147,8 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, if (!lbDef || !ubDef) return llvm::Optional(); - auto lbConst = lbDef->dyn_cast(); - auto ubConst = ubDef->dyn_cast(); + auto lbConst = dyn_cast(lbDef); + auto ubConst = dyn_cast(ubDef); if (!lbConst || !ubConst) return llvm::Optional(); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 434f7206e04..6e20542a818 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -319,11 +319,11 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, // TODO(ntv) expose as a primitive for other passes. 
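Several of the traversals touched above (MemRefBoundCheck, getSequentialLoops) pair a generic walk with an explicit dyn_cast filter; the typed walk specialization updated in this change performs that filtering internally. A short sketch, assuming the Function::walk<OpTy> overload shown earlier (the collector itself is illustrative):

// Collect every load in a function; walk<LoadOp> invokes the callback only for
// operations that dyn_cast successfully to LoadOp.
void collectLoads(mlir::Function &f,
                  llvm::SmallVectorImpl<mlir::Operation *> &loads) {
  f.walk<mlir::LoadOp>(
      [&](mlir::LoadOp loadOp) { loads.push_back(loadOp.getOperation()); });
}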
static LogicalResult tileLinalgOp(Operation *op, ArrayRef tileSizes, PerFunctionState &state) { - if (auto matmulOp = op->dyn_cast()) { + if (auto matmulOp = dyn_cast(op)) { return tileLinalgOp(matmulOp, tileSizes, state); - } else if (auto matvecOp = op->dyn_cast()) { + } else if (auto matvecOp = dyn_cast(op)) { return tileLinalgOp(matvecOp, tileSizes, state); - } else if (auto dotOp = op->dyn_cast()) { + } else if (auto dotOp = dyn_cast(op)) { return tileLinalgOp(dotOp, tileSizes, state); } return failure(); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 4b77ece21dd..98cf4b75b6a 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -68,9 +68,9 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( SmallVector mlir::getRanges(Operation *op) { SmallVector res; - if (auto view = op->dyn_cast()) { + if (auto view = dyn_cast(op)) { res.append(view.getIndexings().begin(), view.getIndexings().end()); - } else if (auto slice = op->dyn_cast()) { + } else if (auto slice = dyn_cast(op)) { for (auto *i : slice.getIndexings()) if (i->getType().isa()) res.push_back(i); @@ -100,7 +100,7 @@ SmallVector mlir::getRanges(Operation *op) { Value *mlir::createOrReturnView(FuncBuilder *b, Location loc, Operation *viewDefiningOp, ArrayRef ranges) { - if (auto view = viewDefiningOp->dyn_cast()) { + if (auto view = dyn_cast(viewDefiningOp)) { auto indexings = view.getIndexings(); if (std::equal(indexings.begin(), indexings.end(), ranges.begin())) return view.getResult(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 05e3b13eb4c..bc68a78bd0a 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -134,7 +134,7 @@ struct MemRefCastFolder : public RewritePattern { void rewrite(Operation *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = memref->dyn_cast()) + if (auto cast = dyn_cast(memref)) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 8a9c649feb3..597efc3ba37 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -199,11 +199,11 @@ bool ModuleTranslation::convertOperation(Operation &opInst, // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. - if (auto brOp = opInst.dyn_cast()) { + if (auto brOp = dyn_cast(opInst)) { builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); return false; } - if (auto condbrOp = opInst.dyn_cast()) { + if (auto condbrOp = dyn_cast(opInst)) { builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), blockMapping[condbrOp.getSuccessor(0)], blockMapping[condbrOp.getSuccessor(1)]); @@ -264,7 +264,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, // For conditional branches, we need to check if the current block is reached // through the "true" or the "false" branch and take the relevant operands. 
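The tiling entry point above and the load/store handling elsewhere in this change share one dispatch idiom: try each candidate op type with dyn_cast, and finish with an assert-backed cast for the last remaining possibility. A hedged sketch of that idiom (the helper and the rank query are illustrative, not part of this change):

// Classify a memory operation and return the rank of the accessed memref.
static unsigned getAccessedRank(mlir::Operation *opInst) {
  if (auto loadOp = llvm::dyn_cast<mlir::LoadOp>(opInst))
    return loadOp.getMemRefType().getRank();
  assert(opInst->isa<mlir::StoreOp>() && "load or store expected");
  return llvm::cast<mlir::StoreOp>(opInst).getMemRefType().getRank();
}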
- auto condBranchOp = terminator.dyn_cast(); + auto condBranchOp = dyn_cast(terminator); assert(condBranchOp && "only branch operations can be terminators of a block that " "has successors"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 10f47fe9be1..937399cc703 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -173,11 +173,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { rank = loadOp.getMemRefType().getRank(); region->memref = loadOp.getMemRef(); region->setWrite(false); - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { rank = storeOp.getMemRefType().getRank(); region->memref = storeOp.getMemRef(); region->setWrite(true); @@ -483,7 +483,7 @@ bool DmaGeneration::runOnBlock(Block *block) { }); for (auto it = curBegin; it != block->end(); ++it) { - if (auto forOp = it->dyn_cast()) { + if (auto forOp = dyn_cast(&*it)) { // Returns true if the footprint is known to exceed capacity. auto exceedsCapacity = [&](AffineForOp forOp) { Optional footprint = @@ -607,10 +607,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. - if (auto loadOp = opInst->dyn_cast()) { + if (auto loadOp = dyn_cast(opInst)) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = opInst->dyn_cast()) { + } else if (auto storeOp = dyn_cast(opInst)) { if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -739,7 +739,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // For a range of operations, a note will be emitted at the caller. AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); - if (llvm::DebugFlag && (forOp = begin->dyn_cast())) { + if (llvm::DebugFlag && (forOp = dyn_cast(&*begin))) { forOp.emitRemark() << sizeInKib << " KiB of DMA buffers in fast memory space for this block\n"; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 796d2164ad9..1c4a4d1f755 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -644,7 +644,7 @@ bool MemRefDependenceGraph::init(Function &f) { DenseMap forToNodeMap; for (auto &op : f.front()) { - if (auto forOp = op.dyn_cast()) { + if (auto forOp = dyn_cast(op)) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -666,14 +666,14 @@ bool MemRefDependenceGraph::init(Function &f) { } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = op.dyn_cast()) { + } else if (auto loadOp = dyn_cast(op)) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); auto *memref = op.cast().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = op.dyn_cast()) { + } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. 
Node node(nextNodeId++, &op); node.stores.push_back(&op); @@ -2125,7 +2125,7 @@ public: auto *fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { for (auto &use : fn->getArgument(i)->getUses()) { - if (auto loadOp = use.getOwner()->dyn_cast()) { + if (auto loadOp = dyn_cast(use.getOwner())) { // Gather loops surrounding 'use'. SmallVector loops; getLoopIVs(*use.getOwner(), &loops); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index ce42a5eba85..28e13d89ada 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -273,7 +273,7 @@ static void getTileableBands(Function &f, for (auto &block : f) for (auto &op : block) - if (auto forOp = op.dyn_cast()) + if (auto forOp = dyn_cast(op)) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 366a7ede5eb..0a23295c8d9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,7 +92,7 @@ void LoopUnrollAndJam::runOnFunction() { // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. auto &entryBlock = getFunction().front(); - if (auto forOp = entryBlock.front().dyn_cast()) + if (auto forOp = dyn_cast(entryBlock.front())) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index dc389c8e37a..1ffe5e3ddd7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -620,10 +620,10 @@ void LowerAffinePass::runOnFunction() { // Rewrite all of the ifs and fors. We walked the operations in postorders, // so we know that we will rewrite them in the reverse order. 
for (auto *op : llvm::reverse(instsToRewrite)) { - if (auto ifOp = op->dyn_cast()) { + if (auto ifOp = dyn_cast(op)) { if (lowerAffineIf(ifOp)) return signalPassFailure(); - } else if (auto forOp = op->dyn_cast()) { + } else if (auto forOp = dyn_cast(op)) { if (lowerAffineFor(forOp)) return signalPassFailure(); } else if (lowerAffineApply(op->cast())) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2f06a9aa3bf..28dfb2278e0 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -556,12 +556,12 @@ static bool instantiateMaterialization(Operation *op, if (op->getNumRegions() != 0) return op->emitError("NYI path Op with region"), true; - if (auto write = op->dyn_cast()) { + if (auto write = dyn_cast(op)) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = op->dyn_cast()) { + if (auto read = dyn_cast(op)) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index a63d462c4a9..94df936c93f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -103,7 +103,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (auto &use : loadOp.getMemRef()->getUses()) { - auto storeOp = use.getOwner()->dyn_cast(); + auto storeOp = dyn_cast(use.getOwner()); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 66fbf4a1306..0da97f7d169 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -181,7 +181,7 @@ static void findMatchingStartFinishInsts( // Collect outgoing DMA operations - needed to check for dependences below. SmallVector outgoingDmaOps; for (auto &op : *forOp.getBody()) { - auto dmaStartOp = op.dyn_cast(); + auto dmaStartOp = dyn_cast(op); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -193,7 +193,7 @@ static void findMatchingStartFinishInsts( dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = op.dyn_cast(); + auto dmaStartOp = dyn_cast(op); if (!dmaStartOp) continue; diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp index 0990d7a73f6..ec1e971973e 100644 --- a/mlir/lib/Transforms/TestConstantFold.cpp +++ b/mlir/lib/Transforms/TestConstantFold.cpp @@ -48,7 +48,7 @@ void TestConstantFold::foldOperation(Operation *op, } // If this op is a constant that are used and cannot be de-duplicated, // remember it for cleanup later. - else if (auto constant = op->dyn_cast()) { + else if (auto constant = dyn_cast(op)) { existingConstants.push_back(op); } } diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp index fc8209be872..b907840b27d 100644 --- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp @@ -40,7 +40,7 @@ bool ConstantFoldHelper::tryToConstantFold( // into the value it contains. 
We need to consider constants before the // constant folding logic to avoid re-creating the same constant later. // TODO: Extend to support dialect-specific constant ops. - if (auto constant = op->dyn_cast()) { + if (auto constant = dyn_cast(op)) { // If this constant is dead, update bookkeeping and signal the caller. if (constant.use_empty()) { notifyRemoval(op); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a10e4a1ae49..7fbb48ecf99 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -363,7 +363,7 @@ void mlir::getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, nestedLoops.push_back(curr); auto *currBody = curr.getBody(); while (currBody->begin() == std::prev(currBody->end(), 2) && - (curr = curr.getBody()->front().dyn_cast())) { + (curr = dyn_cast(curr.getBody()->front()))) { nestedLoops.push_back(curr); currBody = curr.getBody(); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 753f7cf750f..b64dc53e037 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -234,7 +234,7 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { static bool affineApplyOp(Operation &op) { return op.isa(); } static bool singleResultAffineApplyOpWithoutUses(Operation &op) { - auto app = op.dyn_cast(); + auto app = dyn_cast(op); return app && app.use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 025a6535a78..9b8768a6445 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -839,8 +839,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedOperation(); - auto load = opInst->dyn_cast(); - auto store = opInst->dyn_cast(); + auto load = dyn_cast(opInst); + auto store = dyn_cast(opInst); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) @@ -982,7 +982,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningOp()->dyn_cast()) { + if (auto constant = dyn_cast(operand->getDefiningOp())) { return vectorizeConstant( op, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); @@ -1012,7 +1012,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, assert(!opInst->isa() && "vector.transfer_write cannot be further vectorized"); - if (auto store = opInst->dyn_cast()) { + if (auto store = dyn_cast(opInst)) { auto *memRef = store.getMemRef(); auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index ec566e28825..5c34ed160b2 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -161,8 +161,8 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { } // Output the check and the rewritten builder string. 
- os << "if (auto op = opInst.dyn_cast<" << op.getQualCppClassName() - << ">()) {\n"; + os << "if (auto op = dyn_cast<" << op.getQualCppClassName() + << ">(opInst)) {\n"; os << bs.str() << builderStrRef << "\n"; os << " return false;\n"; os << "}\n"; -- cgit v1.2.3 From adca3c2edcdd1375d8c421816ec53044537ccd64 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 11 May 2019 17:57:32 -0700 Subject: Replace Operation::cast with llvm::cast. -- PiperOrigin-RevId: 247785983 --- mlir/examples/Linalg/Linalg1/lib/Analysis.cpp | 6 +-- mlir/examples/Linalg/Linalg1/lib/Common.cpp | 4 +- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 6 +-- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 2 +- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 13 +++--- .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h | 2 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 16 +++---- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 8 ++-- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 2 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 10 ++--- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 10 ++--- mlir/include/mlir/IR/OpDefinition.h | 12 ++--- mlir/include/mlir/IR/Operation.h | 8 ---- mlir/lib/AffineOps/AffineOps.cpp | 4 +- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Analysis/VectorAnalysis.cpp | 4 +- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 6 +-- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 18 ++++---- mlir/lib/Linalg/IR/LinalgOps.cpp | 5 +-- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 12 ++--- mlir/lib/Linalg/Transforms/Tiling.cpp | 8 ++-- mlir/lib/Linalg/Utils/Utils.cpp | 6 +-- mlir/lib/Quantization/IR/QuantOps.cpp | 8 ++-- mlir/lib/Quantization/Transforms/ConvertConst.cpp | 2 +- .../Quantization/Transforms/ConvertSimQuant.cpp | 2 +- mlir/lib/StandardOps/Ops.cpp | 16 +++---- mlir/lib/Transforms/LoopFusion.cpp | 52 +++++++++++----------- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 6 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 4 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 8 ++-- 41 files changed, 139 insertions(+), 147 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp index a7fba179c79..092b83ae1ff 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp @@ -35,7 +35,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) { view = slice.getParentView(); assert(viewType.isa() && "expected a ViewType"); } - return view->getDefiningOp()->cast(); + return cast(view->getDefiningOp()); } Value *linalg::getViewSupportingMemRef(Value *view) { @@ -51,12 +51,12 @@ std::pair linalg::getViewRootIndexing(Value *view, if (auto viewOp = dyn_cast(view->getDefiningOp())) return std::make_pair(viewOp.getIndexing(dim), dim); - auto sliceOp = view->getDefiningOp()->cast(); + auto sliceOp = cast(view->getDefiningOp()); auto *parentView = sliceOp.getParentView(); unsigned sliceDim = sliceOp.getSlicingDim(); auto 
*indexing = sliceOp.getIndexing(); if (indexing->getDefiningOp()) { - if (auto rangeOp = indexing->getDefiningOp()->cast()) { + if (auto rangeOp = cast(indexing->getDefiningOp())) { // If I sliced with a range and I sliced at this dim, then I'm it. if (dim == sliceDim) { return std::make_pair(rangeOp.getResult(), dim); diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index 278f9c57607..1e211bfc928 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -47,8 +47,8 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( auto lb = rangeOp.getMin(); auto ub = rangeOp.getMax(); // This must be a constexpr index until we relax the affine.for constraint - auto step = - rangeOp.getStep()->getDefiningOp()->cast().getValue(); + auto step = llvm::cast(rangeOp.getStep()->getDefiningOp()) + .getValue(); loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step); } } diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 60972405b05..48884b12dad 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -155,7 +155,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto rangeOp = op->cast(); + auto rangeOp = cast(op); auto rangeDescriptorType = linalg::convertLinalgType(rangeOp.getResult()->getType()); @@ -187,7 +187,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto viewOp = op->cast(); + auto viewOp = cast(op); auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType()); auto memrefType = viewOp.getSupportingMemRef()->getType().cast(); @@ -319,7 +319,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto sliceOp = op->cast(); + auto sliceOp = cast(op); auto newViewDescriptorType = linalg::convertLinalgType(sliceOp.getViewType()); auto elementType = rewriter.getType( diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 5bcebc79c18..05070a9de30 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -35,7 +35,7 @@ unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); if (auto viewOp = dyn_cast(view->getDefiningOp())) return viewOp.getRank(); - return view->getDefiningOp()->cast().getRank(); + return cast(view->getDefiningOp()).getRank(); } ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) { diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index 83fd9ad3143..9df0af8d25d 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/StandardTypes.h" using llvm::ArrayRef; +using llvm::cast; using llvm::SmallVector; using mlir::FuncBuilder; using mlir::MemRefType; @@ -49,7 +50,7 @@ static SmallVector getViewChain(mlir::Value *v) { SmallVector tmp; do { - auto sliceOp = v->getDefiningOp()->cast(); // must be a slice op + auto sliceOp = cast(v->getDefiningOp()); // must be a slice op tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa()); @@ -62,15 +63,15 @@ static mlir::Value 
*createFullyComposedIndexing(unsigned dim, ArrayRef chain) { using namespace mlir::edsc::op; assert(chain.front()->getType().isa() && "must be a ViewType"); - auto viewOp = chain.front()->getDefiningOp()->cast(); + auto viewOp = cast(chain.front()->getDefiningOp()); auto *indexing = viewOp.getIndexing(dim); if (!indexing->getType().isa()) return indexing; - auto rangeOp = indexing->getDefiningOp()->cast(); + auto rangeOp = cast(indexing->getDefiningOp()); Value *min = rangeOp.getMin(), *max = rangeOp.getMax(), *step = rangeOp.getStep(); for (auto *v : chain.drop_front(1)) { - auto slice = v->getDefiningOp()->cast(); + auto slice = cast(v->getDefiningOp()); if (slice.getRank() != slice.getParentRank()) { // Rank-reducing slice. if (slice.getSlicingDim() == dim) { @@ -82,7 +83,7 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim, dim = (slice.getSlicingDim() < dim) ? dim - 1 : dim; } else { // not a rank-reducing slice. if (slice.getSlicingDim() == dim) { - auto range = slice.getIndexing()->getDefiningOp()->cast(); + auto range = cast(slice.getIndexing()->getDefiningOp()); auto oldMin = min; min = ValueHandle(min) + ValueHandle(range.getMin()); // ideally: max = min(oldMin + ValueHandle(range.getMax()), oldMax); @@ -110,5 +111,5 @@ ViewOp linalg::emitAndReturnFullyComposedView(Value *v) { for (unsigned idx = 0; idx < rank; ++idx) { ranges.push_back(createFullyComposedIndexing(idx, chain)); } - return view(memRef, ranges).getOperation()->cast(); + return cast(view(memRef, ranges).getOperation()); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 3090f29dcfc..2c475418b05 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -94,7 +94,7 @@ extractRangesFromViewOrSliceOp(mlir::Value *view) { if (auto viewOp = llvm::dyn_cast(view->getDefiningOp())) return viewOp.getRanges(); - auto sliceOp = view->getDefiningOp()->cast(); + auto sliceOp = llvm::cast(view->getDefiningOp()); unsigned slicingDim = sliceOp.getSlicingDim(); auto *indexing = *(sliceOp.getIndexings().begin()); bool isRankReducing = indexing->getType().isa(); diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index f1bb90dc618..22feb668966 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -71,7 +71,7 @@ public: // a getelementptr. 
Value *obtainDataPtr(Operation *op, Value *viewDescriptor, ArrayRef indices, FuncBuilder &rewriter) const { - auto loadOp = op->cast(); + auto loadOp = cast(op); auto elementType = loadOp.getViewType().template cast().getElementType(); auto *llvmPtrType = linalg::convertLinalgType(elementType) diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index bce7f58860d..63093005c9c 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -95,7 +95,7 @@ extractFromRanges(ArrayRef ranges, SmallVector res; res.reserve(ranges.size()); for (auto *v : ranges) { - auto r = v->getDefiningOp()->cast(); + auto r = cast(v->getDefiningOp()); res.push_back(extract(r)); } return res; @@ -149,9 +149,9 @@ linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps, for (auto z : llvm::zip(res.steps, tileSizes)) { auto *step = std::get<0>(z); auto tileSize = std::get<1>(z); - auto stepValue = step->getDefiningOp()->cast().getValue(); + auto stepValue = cast(step->getDefiningOp()).getValue(); auto tileSizeValue = - tileSize->getDefiningOp()->cast().getValue(); + cast(tileSize->getDefiningOp()).getValue(); assert(stepValue > 0); tiledSteps.push_back(constant_index(stepValue * tileSizeValue)); } @@ -236,7 +236,7 @@ emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) { operands.push_back(indexing); continue; } - RangeOp range = indexing->getDefiningOp()->cast(); + RangeOp range = cast(indexing->getDefiningOp()); ValueHandle min(range.getMin()); Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++); using edsc::op::operator+; @@ -275,10 +275,10 @@ template <> PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto load = op->cast(); + auto load = cast(op); SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) - : load.getView()->getDefiningOp()->cast(); + : cast(load.getView()->getDefiningOp()); ScopedContext scope(FuncBuilder(load), load.getLoc()); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(load, view); @@ -290,10 +290,10 @@ template <> PatternMatchResult Rewriter::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto store = op->cast(); + auto store = cast(op); SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) - : store.getView()->getDefiningOp()->cast(); + : cast(store.getView()->getDefiningOp()); ScopedContext scope(FuncBuilder(store), store.getLoc()); auto *valueToStore = store.getValueToStore(); auto *memRef = view.getSupportingMemRef(); diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index c9f98e7d6a9..5f024ea4033 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -350,7 +350,7 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. 
for (auto &block : f->getBlocks()) { - auto ret = block.getTerminator()->cast(); + auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 942ce866182..4175fc2c5bf 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { // We can directly cast the current operation as this will only get invoked // on TransposeOp. - TransposeOp transpose = op->cast(); + TransposeOp transpose = llvm::cast(op); // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); TransposeOp transposeInputOp = @@ -73,7 +73,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); // Look through the input of the current reshape. ConstantOp constantOp = llvm::dyn_cast_or_null( reshape.getOperand()->getDefiningOp()); @@ -120,7 +120,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); // Look through the input of the current reshape. mlir::Value *reshapeInput = reshape.getOperand(); // If the input is defined by another reshape, bingo! @@ -142,7 +142,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); if (reshape.getOperand()->getType() != reshape.getResult()->getType()) return matchFailure(); rewriter.replaceOp(reshape, {reshape.getOperand()}); diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index db6ba73a73b..3e640dea558 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -92,7 +92,7 @@ public: using intrinsics::constant_index; using linalg::intrinsics::range; using linalg::intrinsics::view; - toy::MulOp mul = op->cast(); + toy::MulOp mul = cast(op); auto loc = mul.getLoc(); Value *result = memRefTypeCast( rewriter, rewriter.create(loc, mul.getResult()->getType()) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 534b5cbd2ab..0a2ff1d733d 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -93,7 +93,7 @@ public: /// number must match the number of result of `op`. SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto add = op->cast(); + auto add = cast(op); auto loc = add.getLoc(); // Create a `toy.alloc` operation to allocate the output buffer for this op. Value *result = memRefTypeCast( @@ -135,7 +135,7 @@ public: // Get or create the declaration of the printf function in the module. 
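The Toy combine patterns above all rely on the same property: each RewritePattern is registered against a single op name, so matchAndRewrite may cast its Operation argument unconditionally and reach for dyn_cast_or_null only when chasing defining ops. A compact sketch of that idiom, modeled on the redundant-transpose pattern in this change (the pattern name is illustrative):

// Folds transpose(transpose(x)) -> x; the outer cast cannot fail because the
// pattern only ever matches ops named toy.transpose.
struct DropDoubleTranspose : public mlir::RewritePattern {
  DropDoubleTranspose(mlir::MLIRContext *ctx)
      : RewritePattern(toy::TransposeOp::getOperationName(), /*benefit=*/1, ctx) {}

  mlir::PatternMatchResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto transpose = llvm::cast<toy::TransposeOp>(op);
    auto inner = llvm::dyn_cast_or_null<toy::TransposeOp>(
        transpose.getOperand()->getDefiningOp());
    if (!inner)
      return matchFailure();
    rewriter.replaceOp(op, {inner.getOperand()});
    return matchSuccess();
  }
};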
Function *printfFunc = getPrintf(*op->getFunction()->getModule()); - auto print = op->cast(); + auto print = cast(op); auto loc = print.getLoc(); // We will operate on a MemRef abstraction, we use a type.cast to get one // if our operand is still a Toy array. @@ -234,7 +234,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - toy::ConstantOp cstOp = op->cast(); + toy::ConstantOp cstOp = cast(op); auto loc = cstOp.getLoc(); auto retTy = cstOp.getResult()->getType().cast(); auto shape = retTy.getShape(); @@ -277,7 +277,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto transpose = op->cast(); + auto transpose = cast(op); auto loc = transpose.getLoc(); Value *result = memRefTypeCast( rewriter, @@ -309,7 +309,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto retOp = op->cast(); + auto retOp = cast(op); using namespace edsc; auto loc = retOp.getLoc(); // Argument is optional, handle both cases. diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 4e17b234d14..ab990193ab6 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -357,7 +357,7 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. for (auto &block : f->getBlocks()) { - auto ret = block.getTerminator()->cast(); + auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 8d6aed63d53..260f6a6e092 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { // We can directly cast the current operation as this will only get invoked // on TransposeOp. 
- TransposeOp transpose = op->cast(); + TransposeOp transpose = llvm::cast(op); // look through the input to the current transpose mlir::Value *transposeInput = transpose.getOperand(); mlir::Operation *transposeInputInst = transposeInput->getDefiningOp(); @@ -74,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); // look through the input to the current reshape mlir::Value *reshapeInput = reshape.getOperand(); mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); @@ -125,7 +125,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); // look through the input to the current reshape mlir::Value *reshapeInput = reshape.getOperand(); mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); @@ -150,7 +150,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = op->cast(); + ReshapeOp reshape = llvm::cast(op); if (reshape.getOperand()->getType() != reshape.getResult()->getType()) return matchFailure(); rewriter.replaceOp(reshape, {reshape.getOperand()}); @@ -185,7 +185,7 @@ struct SimplifyIdentityTypeCast : public mlir::RewritePattern { mlir::PatternMatchResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - TypeCastOp typeCast = op->cast(); + TypeCastOp typeCast = llvm::cast(op); auto resTy = typeCast.getResult()->getType(); auto *candidateOp = op; while (candidateOp && candidateOp->isa()) { diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 2eff412a71e..250fb942fb1 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -183,8 +183,8 @@ public: static LogicalResult constantFoldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - return op->cast().constantFold(operands, results, - op->getContext()); + return cast(op).constantFold(operands, results, + op->getContext()); } /// Op implementations can implement this hook. It should attempt to constant @@ -205,7 +205,7 @@ public: /// This is an implementation detail of the folder hook for AbstractOperation. static LogicalResult foldHook(Operation *op, SmallVectorImpl &results) { - return op->cast().fold(results); + return cast(op).fold(results); } /// This hook implements a generalized folder for this operation. Operations @@ -253,7 +253,7 @@ public: ArrayRef operands, SmallVectorImpl &results) { auto result = - op->cast().constantFold(operands, op->getContext()); + cast(op).constantFold(operands, op->getContext()); if (!result) return failure(); @@ -277,7 +277,7 @@ public: /// This is an implementation detail of the folder hook for AbstractOperation. 
static LogicalResult foldHook(Operation *op, SmallVectorImpl &results) { - auto *result = op->cast().fold(); + auto *result = cast(op).fold(); if (!result) return failure(); if (result != op->getResult(0)) @@ -808,7 +808,7 @@ public: static LogicalResult verifyInvariants(Operation *op) { return failure( failed(BaseVerifier...>::verifyTrait(op)) || - failed(op->cast().verify())); + failed(cast(op).verify())); } // Returns the properties of an operation by combining the properties of the diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 31ec8ea54a6..088a4e473cd 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -389,14 +389,6 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// - /// The cast methods perform a cast from an Operation to a typed Op like - /// DimOp. This aborts if the parameter to the template isn't an instance of - /// the template type argument. - template OpClass cast() { - assert(isa() && "cast() argument of incompatible type!"); - return OpClass(this); - } - /// The is methods return true if the operation is a typed op (like DimOp) of /// of the given class. template bool isa() { return OpClass::classof(this); } diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 2dfed934ee0..f551afb2fe0 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -661,7 +661,7 @@ struct SimplifyAffineApply : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto apply = op->cast(); + auto apply = cast(op); auto map = apply.getAffineMap(); AffineMap oldMap = map; @@ -1010,7 +1010,7 @@ struct AffineForLoopBoundFolder : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto forOp = op->cast(); + auto forOp = cast(op); auto foldLowerOrUpperBound = [&forOp](bool lower) { // Check to see if each of the operands is the result of a constant. If // so, get the value. If not, ignore it. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 60f2b142986..3d984c5efcf 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -192,7 +192,7 @@ bool mlir::isAccessInvariant(Value *iv, Value *index) { return false; } - auto composeOp = affineApplyOps[0]->cast(); + auto composeOp = cast(affineApplyOps[0]); // We need yet another level of indirection because the `dim` index of the // access may not correspond to the `dim` index of composeOp. 
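With the Operation::cast member removed above, call sites that asserted the concrete type switch to the free llvm::cast, which is backed by the casting traits introduced in the previous commit and aborts on a mismatch just as the member did. A minimal before/after sketch (DimOp is used purely as an example op):

void useDim(mlir::Operation *op, mlir::Value *v) {
  // Before this change: auto dimOp = op->cast<DimOp>();
  auto dimOp = llvm::cast<mlir::DimOp>(op); // still asserts op is a DimOp
  (void)dimOp;
  // dyn_cast_or_null remains the tool for possibly-null defining ops.
  if (auto def = llvm::dyn_cast_or_null<mlir::DimOp>(v->getDefiningOp()))
    (void)def.getOperand();
}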
return !(AffineValueMap(composeOp).isFunctionOf(0, iv)); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 8d963e4739c..cc46d6558b6 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -603,7 +603,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); auto sliceLoopNest = - b.clone(*srcLoopIVs[0].getOperation())->cast(); + cast(b.clone(*srcLoopIVs[0].getOperation())); Operation *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody()); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 8fecf058bfc..627ca7add94 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -123,7 +123,7 @@ static AffineMap makePermutationMap( for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); auto invariants = getInvariantAccesses( - kvp.first->cast().getInductionVar(), indices); + cast(kvp.first).getInductionVar(), indices); unsigned numIndices = indices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { @@ -181,7 +181,7 @@ AffineMap mlir::makePermutationMap( return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); } - auto store = op->cast(); + auto store = cast(op); return ::makePermutationMap(store.getIndices(), enclosingLoopToVectorDim); } diff --git a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp index afd8152a0c9..8bfcfa5aac8 100644 --- a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -124,7 +124,7 @@ struct UniformDequantizePattern : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto dcastOp = op->cast(); + auto dcastOp = cast(op); Type inputType = dcastOp.arg()->getType(); Type outputType = dcastOp.getResult()->getType(); @@ -328,7 +328,7 @@ struct UniformRealAddEwPattern : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto addOp = op->cast(); + auto addOp = cast(op); const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(), addOp.clamp_min(), addOp.clamp_max()); if (!info.isValid()) { @@ -350,7 +350,7 @@ struct UniformRealMulEwPattern : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto mulOp = op->cast(); + auto mulOp = cast(op); const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(), mulOp.clamp_min(), mulOp.clamp_max()); if (!info.isValid()) { diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 0d9025bc917..e9aee954b35 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -414,14 +414,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern { PatternMatchResult match(Operation *op) const override { if (!LLVMLegalizationPattern::match(op)) return matchFailure(); - auto allocOp = op->cast(); + auto allocOp = cast(op); MemRefType type = allocOp.getType(); return isSupportedMemRefType(type) ? 
matchSuccess() : matchFailure(); } SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto allocOp = op->cast(); + auto allocOp = cast(op); MemRefType type = allocOp.getType(); // Get actual sizes of the memref as values: static sizes are constant @@ -557,7 +557,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { PatternMatchResult match(Operation *op) const override { if (!LLVMLegalizationPattern::match(op)) return matchFailure(); - auto memRefCastOp = op->cast(); + auto memRefCastOp = cast(op); MemRefType sourceType = memRefCastOp.getOperand()->getType().cast(); MemRefType targetType = memRefCastOp.getType(); @@ -569,7 +569,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto memRefCastOp = op->cast(); + auto memRefCastOp = cast(op); auto targetType = memRefCastOp.getType(); auto sourceType = memRefCastOp.getOperand()->getType().cast(); @@ -636,7 +636,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { PatternMatchResult match(Operation *op) const override { if (!LLVMLegalizationPattern::match(op)) return this->matchFailure(); - auto dimOp = op->cast(); + auto dimOp = cast(op); MemRefType type = dimOp.getOperand()->getType().cast(); return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); } @@ -644,7 +644,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { assert(operands.size() == 1 && "expected exactly one operand"); - auto dimOp = op->cast(); + auto dimOp = cast(op); MemRefType type = dimOp.getOperand()->getType().cast(); SmallVector results; @@ -683,7 +683,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { PatternMatchResult match(Operation *op) const override { if (!LLVMLegalizationPattern::match(op)) return this->matchFailure(); - auto loadOp = op->cast(); + auto loadOp = cast(op); MemRefType type = loadOp.getMemRefType(); return isSupportedMemRefType(type) ? 
this->matchSuccess() : this->matchFailure(); @@ -794,7 +794,7 @@ struct LoadOpLowering : public LoadStoreOpLowering { SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto loadOp = op->cast(); + auto loadOp = cast(op); auto type = loadOp.getMemRefType(); Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(), @@ -815,7 +815,7 @@ struct StoreOpLowering : public LoadStoreOpLowering { SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto storeOp = op->cast(); + auto storeOp = cast(op); auto type = storeOp.getMemRefType(); Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1], diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 6998da595ee..8ea45dfd6a1 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -320,7 +320,7 @@ void mlir::linalg::SliceOp::print(OpAsmPrinter *p) { } ViewOp mlir::linalg::SliceOp::getBaseViewOp() { - return getOperand(0)->getDefiningOp()->cast(); + return cast(getOperand(0)->getDefiningOp()); } ViewType mlir::linalg::SliceOp::getBaseViewType() { @@ -505,8 +505,7 @@ ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result); /// ``` void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); - *p << op->cast().getOperationName() << " " - << *op->getOperand(0); + *p << cast(op).getOperationName() << " " << *op->getOperand(0); p->printOptionalAttrDict(op->getAttrs()); *p << " : " << op->getOperand(0)->getType(); } diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 90111a88476..2d1f5f22da5 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -181,7 +181,7 @@ public: } // Get MLIR types for injecting element pointer. - auto allocOp = op->cast(); + auto allocOp = cast(op); auto elementType = allocOp.getElementType(); uint64_t elementSize = 0; if (auto vectorType = elementType.dyn_cast()) @@ -239,7 +239,7 @@ public: } // Get MLIR types for extracting element pointer. - auto deallocOp = op->cast(); + auto deallocOp = cast(op); auto elementPtrTy = rewriter.getType(getPtrToElementType( deallocOp.getOperand()->getType().cast(), lowering)); @@ -283,7 +283,7 @@ public: // a getelementptr. This must be called under an edsc::ScopedContext. 
Value *obtainDataPtr(Operation *op, Value *viewDescriptor, ArrayRef indices, FuncBuilder &rewriter) const { - auto loadOp = op->cast(); + auto loadOp = cast(op); auto elementTy = rewriter.getType( getPtrToElementType(loadOp.getViewType(), lowering)); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); @@ -329,7 +329,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto rangeOp = op->cast(); + auto rangeOp = cast(op); auto rangeDescriptorTy = convertLinalgType(rangeOp.getResult()->getType(), lowering); @@ -355,7 +355,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto sliceOp = op->cast(); + auto sliceOp = cast(op); auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering); auto viewType = sliceOp.getBaseViewType(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); @@ -453,7 +453,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto viewOp = op->cast(); + auto viewOp = cast(op); auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); auto elementTy = rewriter.getType( getPtrToElementType(viewOp.getViewType(), lowering)); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 6e20542a818..e1fa74da698 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -115,8 +115,8 @@ static SmallVector applyMapToRangePart(FuncBuilder *b, Location loc, } static bool isZero(Value *v) { - return v->getDefiningOp() && v->getDefiningOp()->isa() && - v->getDefiningOp()->cast().getValue() == 0; + return isa_and_nonnull(v->getDefiningOp()) && + cast(v->getDefiningOp()).getValue() == 0; } /// Returns a map that can be used to filter the zero values out of tileSizes. @@ -176,8 +176,8 @@ makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map, // Steps must be constant for now to abide by affine.for semantics. 
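The isZero rewrite in the Tiling.cpp hunk above shows why llvm::isa_and_nonnull was re-exported earlier in mlir/Support/LLVM.h: it behaves like isa but returns false for a null pointer instead of asserting, which suits queries on values that may be block arguments (whose getDefiningOp() is null). A small sketch mirroring that helper (the function name is illustrative):

// True when v is produced by a constant index op with value zero; block
// arguments have no defining op, so isa_and_nonnull short-circuits to false.
static bool isConstantZeroIndex(mlir::Value *v) {
  auto *def = v->getDefiningOp();
  return llvm::isa_and_nonnull<mlir::ConstantIndexOp>(def) &&
         llvm::cast<mlir::ConstantIndexOp>(def).getValue() == 0;
}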
auto *newStep = state.getOrCreate( - step->getDefiningOp()->cast().getValue() * - tileSize->getDefiningOp()->cast().getValue()); + cast(step->getDefiningOp()).getValue() * + cast(tileSize->getDefiningOp()).getValue()); res.push_back(b->create(loc, mins[idx], maxes[idx], newStep)); // clang-format on } diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 98cf4b75b6a..6732fa136bb 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -42,12 +42,12 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( assert(ranges[i].getType() && "expected !linalg.range type"); assert(ranges[i].getValue()->getDefiningOp() && "need operations to extract range parts"); - auto rangeOp = ranges[i].getValue()->getDefiningOp()->cast(); + auto rangeOp = cast(ranges[i].getValue()->getDefiningOp()); auto lb = rangeOp.min(); auto ub = rangeOp.max(); // This must be a constexpr index until we relax the affine.for constraint auto step = - rangeOp.step()->getDefiningOp()->cast().getValue(); + cast(rangeOp.step()->getDefiningOp()).getValue(); loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step); } assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); @@ -106,7 +106,7 @@ Value *mlir::createOrReturnView(FuncBuilder *b, Location loc, return view.getResult(); return b->create(loc, view.getResult(), ranges); } - auto slice = viewDefiningOp->cast(); + auto slice = cast(viewDefiningOp); unsigned idxRange = 0; SmallVector newIndexings; bool elide = true; diff --git a/mlir/lib/Quantization/IR/QuantOps.cpp b/mlir/lib/Quantization/IR/QuantOps.cpp index 046ad85a2a3..ab6d97f3a6f 100644 --- a/mlir/lib/Quantization/IR/QuantOps.cpp +++ b/mlir/lib/Quantization/IR/QuantOps.cpp @@ -43,9 +43,9 @@ public: : RewritePattern(StorageCastOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override { - auto scastOp = op->cast(); + auto scastOp = cast(op); if (matchPattern(scastOp.arg(), m_Op())) { - auto srcScastOp = scastOp.arg()->getDefiningOp()->cast(); + auto srcScastOp = cast(scastOp.arg()->getDefiningOp()); if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) { return matchSuccess(); } @@ -54,8 +54,8 @@ public: } void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto scastOp = op->cast(); - auto srcScastOp = scastOp.arg()->getDefiningOp()->cast(); + auto scastOp = cast(op); + auto srcScastOp = cast(scastOp.arg()->getDefiningOp()); rewriter.replaceOp(op, srcScastOp.arg()); } }; diff --git a/mlir/lib/Quantization/Transforms/ConvertConst.cpp b/mlir/lib/Quantization/Transforms/ConvertConst.cpp index 21a0de2ec63..ad41f8f5d75 100644 --- a/mlir/lib/Quantization/Transforms/ConvertConst.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertConst.cpp @@ -59,7 +59,7 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { State state; // Is the operand a constant? 
- auto qbarrier = op->cast(); + auto qbarrier = cast(op); if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { return matchFailure(); } diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp index 4df7b88bda0..c62adc8c58c 100644 --- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp @@ -59,7 +59,7 @@ public: } bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { - auto fqOp = op->cast(); + auto fqOp = cast(op); auto converter = ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType()); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index bc68a78bd0a..59c1400aa88 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -283,7 +283,7 @@ struct SimplifyAllocConst : public RewritePattern { : RewritePattern(AllocOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override { - auto alloc = op->cast(); + auto alloc = cast(op); // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. @@ -294,7 +294,7 @@ struct SimplifyAllocConst : public RewritePattern { } void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto allocOp = op->cast(); + auto allocOp = cast(op); auto memrefType = allocOp.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones @@ -352,7 +352,7 @@ struct SimplifyDeadAlloc : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Check if the alloc'ed value has any uses. - auto alloc = op->cast(); + auto alloc = cast(op); if (!alloc.use_empty()) return matchFailure(); @@ -468,7 +468,7 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto indirectCall = op->cast(); + auto indirectCall = cast(op); // Check that the callee is a constant operation. Attribute callee; @@ -978,7 +978,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto condbr = op->cast(); + auto condbr = cast(op); // Check that the condition is a constant. if (!matchPattern(condbr.getCondition(), m_Op())) @@ -1222,7 +1222,7 @@ struct SimplifyDeadDealloc : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto dealloc = op->cast(); + auto dealloc = cast(op); // Check that the memref operand's defining operation is an AllocOp. 
Value *memref = dealloc.memref(); @@ -2107,7 +2107,7 @@ struct SimplifyXMinusX : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto subi = op->cast(); + auto subi = cast(op); if (subi.getOperand(0) != subi.getOperand(1)) return matchFailure(); @@ -2192,7 +2192,7 @@ struct SimplifyXXOrX : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto xorOp = op->cast(); + auto xorOp = cast(op); if (xorOp.lhs() != xorOp.rhs()) return matchFailure(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1c4a4d1f755..d430c5d85cd 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -128,7 +128,7 @@ struct LoopNestStateCollector { void collect(Operation *opToWalk) { opToWalk->walk([&](Operation *op) { if (op->isa()) - forOps.push_back(op->cast()); + forOps.push_back(cast(op)); else if (op->getNumRegions() != 0) hasNonForRegion = true; else if (op->isa()) @@ -172,7 +172,7 @@ public: unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { - if (memref == loadOpInst->cast().getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) ++loadOpCount; } return loadOpCount; @@ -182,7 +182,7 @@ public: unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { - if (memref == storeOpInst->cast().getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) ++storeOpCount; } return storeOpCount; @@ -192,7 +192,7 @@ public: void getStoreOpsForMemref(Value *memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { - if (memref == storeOpInst->cast().getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) storeOps->push_back(storeOpInst); } } @@ -201,7 +201,7 @@ public: void getLoadOpsForMemref(Value *memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { - if (memref == loadOpInst->cast().getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) loadOps->push_back(loadOpInst); } } @@ -211,10 +211,10 @@ public: void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { - loadMemrefs.insert(loadOpInst->cast().getMemRef()); + loadMemrefs.insert(cast(loadOpInst).getMemRef()); } for (auto *storeOpInst : stores) { - auto *memref = storeOpInst->cast().getMemRef(); + auto *memref = cast(storeOpInst).getMemRef(); if (loadMemrefs.count(memref) > 0) loadAndStoreMemrefSet->insert(memref); } @@ -306,7 +306,7 @@ public: bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { - auto *memref = storeOpInst->cast().getMemRef(); + auto *memref = cast(storeOpInst).getMemRef(); auto *op = memref->getDefiningOp(); // Return true if 'memref' is a block argument. if (!op) @@ -331,7 +331,7 @@ public: Node *node = getNode(id); for (auto *storeOpInst : node->stores) { // Return false if there exist out edges from 'id' on 'memref'. 
- if (getOutEdgeCount(id, storeOpInst->cast().getMemRef()) > 0) + if (getOutEdgeCount(id, cast(storeOpInst).getMemRef()) > 0) return false; } return true; @@ -656,12 +656,12 @@ bool MemRefDependenceGraph::init(Function &f) { Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); - auto *memref = opInst->cast().getMemRef(); + auto *memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); - auto *memref = opInst->cast().getMemRef(); + auto *memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } forToNodeMap[&op] = node.id; @@ -670,14 +670,14 @@ bool MemRefDependenceGraph::init(Function &f) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); - auto *memref = op.cast().getMemRef(); + auto *memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. Node node(nextNodeId++, &op); node.stores.push_back(&op); - auto *memref = op.cast().getMemRef(); + auto *memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (op.getNumRegions() != 0) { @@ -887,7 +887,7 @@ static void moveLoadsAccessingMemrefTo(Value *memref, dstLoads->clear(); SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { - if (load->cast().getMemRef() == memref) + if (cast(load).getMemRef() == memref) dstLoads->push_back(load); else srcLoadsToKeep.push_back(load); @@ -1051,7 +1051,7 @@ computeLoopInterchangePermutation(ArrayRef loops, static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(node->op->isa()); SmallVector loops; - AffineForOp curr = node->op->cast(); + AffineForOp curr = cast(node->op); getPerfectlyNestedLoops(loops, curr); if (loops.size() < 2) return; @@ -1107,7 +1107,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Builder to create constants at the top level. FuncBuilder top(forInst->getFunction()); // Create new memref type based on slice bounds. - auto *oldMemRef = srcStoreOpInst->cast().getMemRef(); + auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); unsigned rank = oldMemRefType.getRank(); @@ -1233,7 +1233,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Gather all memrefs from 'srcNode' store ops. DenseSet storeMemrefs; for (auto *storeOpInst : srcNode->stores) { - storeMemrefs.insert(storeOpInst->cast().getMemRef()); + storeMemrefs.insert(cast(storeOpInst).getMemRef()); } // Return false if any of the following are true: // *) 'srcNode' writes to a live in/out memref other than 'memref'. @@ -1842,7 +1842,7 @@ public: DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. - auto *memref = loads.back()->cast().getMemRef(); + auto *memref = cast(loads.back()).getMemRef(); if (visitedMemrefs.count(memref) > 0) continue; visitedMemrefs.insert(memref); @@ -1898,7 +1898,7 @@ public: // Gather 'dstNode' store ops to 'memref'. 
SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) - if (storeOpInst->cast().getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); unsigned bestDstLoopDepth; @@ -1916,7 +1916,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" << *sliceLoopNest.getOperation() << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = dstNode->op->cast(); + auto dstAffineForOp = cast(dstNode->op); if (insertPointInst != dstAffineForOp.getOperation()) { dstAffineForOp.getOperation()->moveBefore(insertPointInst); } @@ -1934,7 +1934,7 @@ public: // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (storeOpInst->cast().getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == memref) storesForMemref.push_back(storeOpInst); } assert(storesForMemref.size() == 1); @@ -1956,7 +1956,7 @@ public: // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - auto *loadMemRef = loadOpInst->cast().getMemRef(); + auto *loadMemRef = cast(loadOpInst).getMemRef(); if (visitedMemrefs.count(loadMemRef) == 0) loads.push_back(loadOpInst); } @@ -2072,7 +2072,7 @@ public: auto sliceLoopNest = mlir::insertBackwardComputationSlice( sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { - auto dstForInst = dstNode->op->cast(); + auto dstForInst = cast(dstNode->op); // Update operation position of fused loop nest (if needed). if (insertPointInst != dstForInst.getOperation()) { dstForInst.getOperation()->moveBefore(insertPointInst); @@ -2114,7 +2114,7 @@ public: // Check that all stores are to the same memref. DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(storeOpInst->cast().getMemRef()); + storeMemrefs.insert(cast(storeOpInst).getMemRef()); } if (storeMemrefs.size() != 1) return false; @@ -2214,7 +2214,7 @@ public: } // Collect dst loop stats after memref privatizaton transformation. - auto dstForInst = dstNode->op->cast(); + auto dstForInst = cast(dstNode->op); LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstForInst.getOperation()); // Clear and add back loads and stores @@ -2226,7 +2226,7 @@ public: // function. if (mdg->getOutEdgeCount(sibNode->id) == 0) { mdg->removeNode(sibNode->id); - sibNode->op->cast().erase(); + sibNode->op->erase(); } } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 236ef81ebd2..1707f78143f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -113,7 +113,7 @@ void LoopUnroll::runOnFunction() { hasInnerLoops |= walkPostOrder(block.begin(), block.end()); if (opInst->isa()) { if (!hasInnerLoops) - loops.push_back(opInst->cast()); + loops.push_back(cast(opInst)); return true; } return hasInnerLoops; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 0a23295c8d9..43e8f4a7306 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -187,7 +187,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, // Insert the cleanup loop right after 'forOp'. 
FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto cleanupAffineForOp = builder.clone(*forInst)->cast(); + auto cleanupAffineForOp = cast(builder.clone(*forInst)); // Adjust the lower bound of the cleanup loop; its upper bound is the same // as the original loop's upper bound. AffineMap cleanupMap; diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 1ffe5e3ddd7..6f0162eaea6 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -626,7 +626,7 @@ void LowerAffinePass::runOnFunction() { } else if (auto forOp = dyn_cast(op)) { if (lowerAffineFor(forOp)) return signalPassFailure(); - } else if (lowerAffineApply(op->cast())) { + } else if (lowerAffineApply(cast(op))) { return signalPassFailure(); } } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index f7352d6e2b7..657169ad81e 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -264,7 +264,7 @@ VectorTransferRewriter::matchAndRewrite( using namespace mlir::edsc::op; using namespace mlir::edsc::intrinsics; - VectorTransferReadOp transfer = op->cast(); + VectorTransferReadOp transfer = cast(op); // 1. Setup all the captures. ScopedContext scope(FuncBuilder(op), transfer.getLoc()); @@ -323,7 +323,7 @@ VectorTransferRewriter::matchAndRewrite( using namespace mlir::edsc::op; using namespace mlir::edsc::intrinsics; - VectorTransferWriteOp transfer = op->cast(); + VectorTransferWriteOp transfer = cast(op); // 1. Setup all the captures. ScopedContext scope(FuncBuilder(op), transfer.getLoc()); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 28dfb2278e0..206ae53b4bd 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -679,7 +679,7 @@ static bool materialize(Function *f, const SetVector &terminators, continue; } - auto terminator = term->cast(); + auto terminator = cast(term); LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term); // Get the transitive use-defs starting from terminator, limited to the diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 94df936c93f..118efe5548d 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -201,7 +201,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { return; // Perform the actual store to load forwarding. - Value *storeVal = lastWriteStoreOp->cast().getValueToStore(); + Value *storeVal = cast(lastWriteStoreOp).getValueToStore(); loadOp.getResult()->replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 0da97f7d169..272972d233d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -234,8 +234,8 @@ static void findMatchingStartFinishInsts( // For each start operation, we look for a matching finish operation. 
for (auto *dmaStartInst : dmaStartInsts) { for (auto *dmaFinishInst : dmaFinishInsts) { - if (checkTagMatch(dmaStartInst->cast(), - dmaFinishInst->cast())) { + if (checkTagMatch(cast(dmaStartInst), + cast(dmaFinishInst))) { startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); break; } @@ -273,7 +273,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( - dmaStartInst->cast().getFasterMemPos()); + cast(dmaStartInst).getFasterMemPos()); if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 7fbb48ecf99..1ae75b4fbf7 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -426,7 +426,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, Operation *op = forOp.getOperation(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(op->getBlock(), ++Block::iterator(op)); - auto cleanupForInst = builder.clone(*op)->cast(); + auto cleanupForInst = cast(builder.clone(*op)); AffineMap cleanupMap; SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands, @@ -512,7 +512,7 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { for (unsigned i = 0; i < loopDepth; ++i) { assert(forOp.getBody()->front().isa()); - AffineForOp nextForOp = forOp.getBody()->front().cast(); + AffineForOp nextForOp = cast(forOp.getBody()->front()); interchangeLoops(forOp, nextForOp); } } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b64dc53e037..20138d56a3a 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -253,7 +253,7 @@ void VectorizerTestPass::testNormalizeMaps() { SmallVector matches; pattern.match(f, &matches); for (auto m : matches) { - auto app = m.getMatchedOperation()->cast(); + auto app = cast(m.getMatchedOperation()); FuncBuilder b(m.getMatchedOperation()); SmallVector operands(app.getOperands()); makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 9b8768a6445..4a58b15e720 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -859,7 +859,7 @@ static FilterFunctionType isVectorizableLoopPtrFactory(const llvm::DenseSet ¶llelLoops, int fastestVaryingMemRefDimension) { return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) { - auto loop = forOp.cast(); + auto loop = cast(forOp); auto parallelIt = parallelLoops.find(loop); if (parallelIt == parallelLoops.end()) return false; @@ -879,7 +879,7 @@ static LogicalResult vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, VectorizationState *state) { auto *loopInst = oneMatch.getMatchedOperation(); - auto loop = loopInst->cast(); + auto loop = cast(loopInst); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -1118,7 +1118,7 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) { /// anything below it fails. 
static LogicalResult vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto loop = m.getMatchedOperation()->cast(); + auto loop = cast(m.getMatchedOperation()); VectorizationState state; state.strategy = strategy; @@ -1139,7 +1139,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// RAII. auto *loopInst = loop.getOperation(); FuncBuilder builder(loopInst); - auto clonedLoop = builder.clone(*loopInst)->cast(); + auto clonedLoop = cast(builder.clone(*loopInst)); struct Guard { LogicalResult failure() { loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar()); -- cgit v1.2.3 From d5b60ee8407d12e51b017c6390af5c17683713e1 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 11 May 2019 18:59:54 -0700 Subject: Replace Operation::isa with llvm::isa. -- PiperOrigin-RevId: 247789235 --- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 6 ++--- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 5 +++-- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp | 8 +++---- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 2 +- mlir/include/mlir/IR/Matchers.h | 2 +- mlir/include/mlir/IR/Operation.h | 8 ------- mlir/lib/AffineOps/AffineOps.cpp | 6 ++--- mlir/lib/Analysis/AffineAnalysis.cpp | 6 ++--- mlir/lib/Analysis/LoopAnalysis.cpp | 4 ++-- mlir/lib/Analysis/NestedMatcher.cpp | 6 ++--- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 16 ++++++------- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 2 +- mlir/lib/Linalg/IR/LinalgOps.cpp | 6 ++--- mlir/lib/StandardOps/Ops.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 4 ++-- mlir/lib/Transforms/DmaGeneration.cpp | 4 ++-- mlir/lib/Transforms/LoopFusion.cpp | 26 +++++++++++----------- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 4 ++-- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 ++-- mlir/lib/Transforms/LowerAffine.cpp | 3 +-- mlir/lib/Transforms/MaterializeVectors.cpp | 10 ++++----- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 6 ++--- mlir/lib/Transforms/PipelineDataTransfer.cpp | 12 +++++----- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 5 ++--- mlir/lib/Transforms/Utils/Utils.cpp | 13 +++++------ .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 12 +++++----- mlir/tools/mlir-tblgen/RewriterGen.cpp | 2 +- 33 files changed, 92 insertions(+), 104 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 48884b12dad..f1fc4ed23ab 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -148,7 +148,7 @@ public: : DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override { - if (op->isa()) + if (isa(op)) return matchSuccess(); return matchFailure(); } @@ -180,7 +180,7 @@ public: : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override { - if (op->isa()) + if (isa(op)) return matchSuccess(); return matchFailure(); } @@ -312,7 +312,7 @@ public: : DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override { - if (op->isa()) + if 
(isa(op)) return matchSuccess(); return matchFailure(); } diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index 9df0af8d25d..d78d6aa02fd 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -29,6 +29,7 @@ using llvm::ArrayRef; using llvm::cast; +using llvm::isa; using llvm::SmallVector; using mlir::FuncBuilder; using mlir::MemRefType; @@ -44,7 +45,7 @@ using namespace linalg::intrinsics; // analyses. This builds the chain. static SmallVector getViewChain(mlir::Value *v) { assert(v->getType().isa() && "ViewType expected"); - if (v->getDefiningOp()->isa()) { + if (isa(v->getDefiningOp())) { return SmallVector{v}; } @@ -54,7 +55,7 @@ static SmallVector getViewChain(mlir::Value *v) { tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa()); - assert(v->getDefiningOp()->isa() && "must be a ViewOp"); + assert(isa(v->getDefiningOp()) && "must be a ViewOp"); tmp.push_back(v); return SmallVector(tmp.rbegin(), tmp.rend()); } diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 22feb668966..a2b39de7eac 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -60,7 +60,7 @@ public: // Match the Op specified as template argument. PatternMatchResult match(Operation *op) const override { - if (op->isa()) + if (isa(op)) return matchSuccess(); return matchFailure(); } diff --git a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp index a5b094c777e..2209e9d3127 100644 --- a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp @@ -103,8 +103,8 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() { auto *op = getOperation(); auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0)); auto indexingPosPair = getViewRootIndexing(vA, 0); - assert(indexingPosPair.first->getDefiningOp() && - indexingPosPair.first->getDefiningOp()->isa()); + assert( + llvm::isa_and_nonnull(indexingPosPair.first->getDefiningOp())); // clang-format off ScopedContext scope(FuncBuilder(op), op->getLoc()); IndexHandle i; @@ -177,8 +177,8 @@ void linalg::MatmulOp::writeAsFinerGrainTensorContraction() { auto *op = getOperation(); auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0)); auto indexingPosPair = getViewRootIndexing(vB, 1); - assert(indexingPosPair.first->getDefiningOp() && - indexingPosPair.first->getDefiningOp()->isa()); + assert( + llvm::isa_and_nonnull(indexingPosPair.first->getDefiningOp())); using linalg::common::LoopNestRangeBuilder; // clang-format off ScopedContext scope(FuncBuilder(op), op->getLoc()); diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 260f6a6e092..23304475ffa 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -188,7 +188,7 @@ struct SimplifyIdentityTypeCast : public mlir::RewritePattern { TypeCastOp typeCast = llvm::cast(op); auto resTy = typeCast.getResult()->getType(); auto *candidateOp = op; - while (candidateOp && candidateOp->isa()) { + while (llvm::isa_and_nonnull(candidateOp)) { if (resTy == candidateOp->getOperand(0)->getType()) { rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); return matchSuccess(); diff --git 
a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index fd139c68db1..3e337b29819 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -121,7 +121,7 @@ template struct constant_int_value_matcher { /// The matcher that matches a certain kind of op. template struct op_matcher { - bool match(Operation *op) { return op->isa(); } + bool match(Operation *op) { return isa(op); } }; } // end namespace detail diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 088a4e473cd..e71e8ed8d15 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -385,14 +385,6 @@ public: /// Attempt to fold this operation using the Op's registered foldHook. LogicalResult fold(SmallVectorImpl &results); - //===--------------------------------------------------------------------===// - // Conversions to declared operations like DimOp - //===--------------------------------------------------------------------===// - - /// The is methods return true if the operation is a typed op (like DimOp) of - /// of the given class. - template bool isa() { return OpClass::classof(this); } - //===--------------------------------------------------------------------===// // Operation Walkers //===--------------------------------------------------------------------===// diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index f551afb2fe0..40069f6887d 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -58,7 +58,7 @@ bool mlir::isValidDim(Value *value) { if (auto *op = value->getDefiningOp()) { // Top level operation or constant operation is ok. - if (op->getParentOp() == nullptr || op->isa()) + if (op->getParentOp() == nullptr || isa(op)) return true; // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(op)) @@ -83,7 +83,7 @@ bool mlir::isValidSymbol(Value *value) { if (auto *op = value->getDefiningOp()) { // Top level operation or constant operation is ok. - if (op->getParentOp() == nullptr || op->isa()) + if (op->getParentOp() == nullptr || isa(op)) return true; // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(op)) @@ -688,7 +688,7 @@ void AffineApplyOp::getCanonicalizationPatterns( // Check that if a "block" has a terminator, it is an `AffineTerminatorOp`. static LogicalResult checkHasAffineTerminator(OpState &op, Block &block) { - if (block.empty() || block.back().isa()) + if (block.empty() || isa(block.back())) return success(); op.emitOpError("expects regions to end with '" + diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index c1f455eb9dc..861c0a18c26 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -775,8 +775,8 @@ bool mlir::checkMemrefAccessDependence( if (srcAccess.memref != dstAccess.memref) return false; // Return 'false' if one of these accesses is not a StoreOp. - if (!allowRAR && !srcAccess.opInst->isa() && - !dstAccess.opInst->isa()) + if (!allowRAR && !isa(srcAccess.opInst) && + !isa(dstAccess.opInst)) return false; // Get composed access function for 'srcAccess'. @@ -860,7 +860,7 @@ void mlir::getDependenceComponents( // Collect all load and store ops in loop nest rooted at 'forOp'. 
SmallVector loadAndStoreOpInsts; forOp.getOperation()->walk([&](Operation *opInst) { - if (opInst->isa() || opInst->isa()) + if (isa(opInst) || isa(opInst)) loadAndStoreOpInsts.push_back(opInst); }); diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 3d984c5efcf..3ec4833329d 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -279,7 +279,7 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { } static bool isVectorTransferReadOrWrite(Operation &op) { - return op.isa() || op.isa(); + return isa(op) || isa(op); } using VectorizableOpFun = std::function; @@ -300,7 +300,7 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, // No vectorization across unknown regions. auto regions = matcher::Op([](Operation &op) -> bool { return op.getNumRegions() != 0 && - !(op.isa() || op.isa()); + !(isa(op) || isa(op)); }); SmallVector regionsMatched; regions.match(forOp, ®ionsMatched); diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 95270a1b946..f08f66df506 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -110,9 +110,9 @@ void NestedPattern::matchOne(Operation *op, } } -static bool isAffineForOp(Operation &op) { return op.isa(); } +static bool isAffineForOp(Operation &op) { return isa(op); } -static bool isAffineIfOp(Operation &op) { return op.isa(); } +static bool isAffineIfOp(Operation &op) { return isa(op); } namespace mlir { namespace matcher { @@ -154,7 +154,7 @@ NestedPattern For(FilterFunctionType filter, ArrayRef nested) { } bool isLoadOrStore(Operation &op) { - return op.isa() || op.isa(); + return isa(op) || isa(op); } } // end namespace matcher diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 4005871ba60..2b0f1ab50ad 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -114,7 +114,7 @@ void TestMemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); getFunction().walk([&](Operation *op) { - if (op->isa() || op->isa()) + if (isa(op) || isa(op)) loadsAndStores.push_back(op); }); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index cc46d6558b6..2a46c0e5b4f 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -45,7 +45,7 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { // Traverse up the hierarchy collecing all 'affine.for' operation while // skipping over 'affine.if' operations. 
while (currOp && ((currAffineForOp = dyn_cast(currOp)) || - currOp->isa())) { + isa(currOp))) { if (currAffineForOp) loops->push_back(currAffineForOp); currOp = currOp->getParentOp(); @@ -172,7 +172,7 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, ComputationSliceState *sliceState, bool addMemRefDimBounds) { - assert((op->isa() || op->isa()) && "load/store op expected"); + assert((isa(op) || isa(op)) && "load/store op expected"); MemRefAccess access(op); memref = access.memref; @@ -490,7 +490,7 @@ LogicalResult mlir::getBackwardComputationSliceState( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned dstLoopDepth, ComputationSliceState *sliceState) { bool readReadAccesses = - srcAccess.opInst->isa() && dstAccess.opInst->isa(); + isa(srcAccess.opInst) && isa(dstAccess.opInst); FlatAffineConstraints dependenceConstraints; if (!checkMemrefAccessDependence( srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, @@ -642,7 +642,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { indices.push_back(index); } } else { - assert(loadOrStoreOpInst->isa() && "load/store op expected"); + assert(isa(loadOrStoreOpInst) && "load/store op expected"); auto storeOp = dyn_cast(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); @@ -658,7 +658,7 @@ unsigned MemRefAccess::getRank() const { return memref->getType().cast().getRank(); } -bool MemRefAccess::isStore() const { return opInst->isa(); } +bool MemRefAccess::isStore() const { return isa(opInst); } /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. @@ -666,7 +666,7 @@ unsigned mlir::getNestingDepth(Operation &op) { Operation *currOp = &op; unsigned depth = 0; while ((currOp = currOp->getParentOp())) { - if (currOp->isa()) + if (isa(currOp)) depth++; } return depth; @@ -698,7 +698,7 @@ static Optional getMemoryFootprintBytes(Block &block, // Walk this 'affine.for' operation to gather all memory regions. bool error = false; block.walk(start, end, [&](Operation *opInst) { - if (!opInst->isa() && !opInst->isa()) { + if (!isa(opInst) && !isa(opInst)) { // Neither load nor a store op. return; } @@ -761,7 +761,7 @@ bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; forOp.getOperation()->walk([&](Operation *opInst) { - if (opInst->isa() || opInst->isa()) + if (isa(opInst) || isa(opInst)) loadAndStoreOpInsts.push_back(opInst); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 627ca7add94..7c0176d94ed 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -205,7 +205,7 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, superVectorType = write.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { - if (!op.isa()) { + if (!isa(op)) { op.emitError("NYI: assuming only return operations can have 0 " " results at this point"); } diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index e9aee954b35..ad16143afe6 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -209,7 +209,7 @@ public: // Match by type. 
PatternMatchResult match(Operation *op) const override { - if (op->isa()) + if (isa(op)) return this->matchSuccess(); return this->matchFailure(); } diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 8ea45dfd6a1..da102b3f394 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -572,17 +572,17 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { auto i = getAffineDimExpr(0, context); auto j = getAffineDimExpr(1, context); auto k = getAffineDimExpr(2, context); - if (op->isa()) + if (isa(op)) // A(r_i) * B(r_i) -> C() return SmallVector{AffineMap::get(1, 0, {i}, {}), AffineMap::get(1, 0, {i}, {}), AffineMap()}; - if (op->isa()) + if (isa(op)) // A(i, r_j) * B(r_j) -> C(i) return SmallVector{AffineMap::get(2, 0, {i, j}, {}), AffineMap::get(2, 0, {j}, {}), AffineMap::get(2, 0, {i}, {})}; - if (op->isa()) + if (isa(op)) // A(i, r_j) * B(r_j) -> C(i) return SmallVector{AffineMap::get(3, 0, {i, k}, {}), AffineMap::get(3, 0, {k, j}, {}), diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 59c1400aa88..d7b60a05d41 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1232,7 +1232,7 @@ struct SimplifyDeadDealloc : public RewritePattern { // Check that all of the uses of the AllocOp are other DeallocOps. for (auto &use : memref->getUses()) - if (!use.getOwner()->isa()) + if (!isa(use.getOwner())) return matchFailure(); // Erase the dealloc operation. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 597efc3ba37..9b1a42eaec7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -187,7 +187,7 @@ bool ModuleTranslation::convertOperation(Operation &opInst, // Emit calls. If the called function has a result, remap the corresponding // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. - if (opInst.isa()) { + if (isa(opInst)) { llvm::Value *result = convertCall(opInst); if (opInst.getNumResults() != 0) { valueMapping[opInst.getResult(0)] = result; @@ -258,7 +258,7 @@ bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { static Value *getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); - if (terminator.isa()) { + if (isa(terminator)) { return terminator.getOperand(index); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 937399cc703..00ae92bdbc8 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -479,7 +479,7 @@ bool DmaGeneration::runOnBlock(Block *block) { // Get to the first load, store, or for op. 
auto curBegin = std::find_if(block->begin(), block->end(), [&](Operation &op) { - return op.isa() || op.isa() || op.isa(); + return isa(op) || isa(op) || isa(op); }); for (auto it = curBegin; it != block->end(); ++it) { @@ -522,7 +522,7 @@ bool DmaGeneration::runOnBlock(Block *block) { runOnBlock(/*begin=*/it, /*end=*/std::next(it)); curBegin = std::next(it); } - } else if (!it->isa() && !it->isa()) { + } else if (!isa(&*it) && !isa(&*it)) { runOnBlock(/*begin=*/curBegin, /*end=*/it); curBegin = std::next(it); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d430c5d85cd..4e9e48cc3b3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -127,13 +127,13 @@ struct LoopNestStateCollector { void collect(Operation *opToWalk) { opToWalk->walk([&](Operation *op) { - if (op->isa()) + if (isa(op)) forOps.push_back(cast(op)); else if (op->getNumRegions() != 0) hasNonForRegion = true; - else if (op->isa()) + else if (isa(op)) loadOpInsts.push_back(op); - else if (op->isa()) + else if (isa(op)) storeOpInsts.push_back(op); }); } @@ -141,8 +141,8 @@ struct LoopNestStateCollector { // TODO(b/117228571) Replace when this is modeled through side-effects/op traits static bool isMemRefDereferencingOp(Operation &op) { - if (op.isa() || op.isa() || op.isa() || - op.isa()) + if (isa(op) || isa(op) || isa(op) || + isa(op)) return true; return false; } @@ -604,7 +604,7 @@ public: continue; assert(nodes.count(edge.id) > 0); // Skip if 'edge.id' is not a loop nest. - if (!getNode(edge.id)->op->isa()) + if (!isa(getNode(edge.id)->op)) continue; // Visit current input edge 'edge'. callback(edge); @@ -756,7 +756,7 @@ struct LoopNestStatsCollector { auto *forInst = forOp.getOperation(); auto *parentInst = forOp.getOperation()->getParentOp(); if (parentInst != nullptr) { - assert(parentInst->isa() && "Expected parent AffineForOp"); + assert(isa(parentInst) && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. stats->loopMap[parentInst].push_back(forOp); } @@ -765,7 +765,7 @@ struct LoopNestStatsCollector { unsigned count = 0; stats->opCountMap[forInst] = 0; for (auto &op : *forOp.getBody()) { - if (!op.isa() && !op.isa()) + if (!isa(op) && !isa(op)) ++count; } stats->opCountMap[forInst] = count; @@ -1049,7 +1049,7 @@ computeLoopInterchangePermutation(ArrayRef loops, // This can increase the loop depth at which we can fuse a slice, since we are // pushing loop carried dependence to a greater depth in the loop nest. static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { - assert(node->op->isa()); + assert(isa(node->op)); SmallVector loops; AffineForOp curr = cast(node->op); getPerfectlyNestedLoops(loops, curr); @@ -1829,7 +1829,7 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!dstNode->op->isa()) + if (!isa(dstNode->op)) continue; // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. This can increase the maximum loop @@ -1867,7 +1867,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!srcNode->op->isa()) + if (!isa(srcNode->op)) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -2012,7 +2012,7 @@ public: // Get 'dstNode' into which to attempt fusion. 
auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!dstNode->op->isa()) + if (!isa(dstNode->op)) continue; // Attempt to fuse 'dstNode' with its sibling nodes in the graph. fuseWithSiblingNodes(dstNode); @@ -2180,7 +2180,7 @@ public: if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) return; auto *sibNode = mdg->getNode(sibNodeId); - if (!sibNode->op->isa()) + if (!isa(sibNode->op)) return; // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. if (canFuseWithSibNode(sibNode, outEdge.value)) { diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 2f95db95c6f..402f7d9afef 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -82,7 +82,7 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { for (auto &op : *loopBody) { // If the operation is loop invariant, insert it into opsToMove. - if (!op.isa() && !op.isa() && + if (!isa(op) && !isa(op) && loopDefinedOps.count(&op) != 1) { LLVM_DEBUG(op.print(llvm::dbgs() << "\nLICM'ing op\n")); opsToMove.push_back(&op); @@ -99,7 +99,7 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { // If the for loop body has a single operation (the terminator), erase it. if (forOp.getBody()->getOperations().size() == 1) { - assert(forOp.getBody()->getOperations().front().isa()); + assert(isa(forOp.getBody()->front())); forOp.erase(); } } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 1707f78143f..05953926376 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -111,7 +111,7 @@ void LoopUnroll::runOnFunction() { for (auto ®ion : opInst->getRegions()) for (auto &block : region) hasInnerLoops |= walkPostOrder(block.begin(), block.end()); - if (opInst->isa()) { + if (isa(opInst)) { if (!hasInnerLoops) loops.push_back(cast(opInst)); return true; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 43e8f4a7306..609b42455f5 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -139,12 +139,12 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, void walk(Block &block) { for (auto it = block.begin(), e = std::prev(block.end()); it != e;) { auto subBlockStart = it; - while (it != e && !it->isa()) + while (it != e && !isa(&*it)) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != e && it->isa()) + while (it != e && isa(&*it)) walk(&*it++); } } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 6f0162eaea6..7f52e859309 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -612,8 +612,7 @@ void LowerAffinePass::runOnFunction() { // Collect all the For operations as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. 
getFunction().walk([&](Operation *op) { - if (op->isa() || op->isa() || - op->isa()) + if (isa(op) || isa(op) || isa(op)) instsToRewrite.push_back(op); }); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 206ae53b4bd..f81fabb2965 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -256,7 +256,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { auto *opInst = v->getDefiningOp(); - if (opInst->isa()) { + if (isa(opInst)) { FuncBuilder b(opInst); auto *op = instantiate(&b, opInst, hwVectorType, substitutionsMap); auto res = substitutionsMap->insert(std::make_pair(v, op->getResult(0))); @@ -407,9 +407,9 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) { static Operation *instantiate(FuncBuilder *b, Operation *opInst, VectorType hwVectorType, DenseMap *substitutionsMap) { - assert(!opInst->isa() && + assert(!isa(opInst) && "Should call the function specialized for VectorTransferReadOp"); - assert(!opInst->isa() && + assert(!isa(opInst) && "Should call the function specialized for VectorTransferWriteOp"); if (opInst->getNumRegions() != 0) return nullptr; @@ -550,7 +550,7 @@ static bool instantiateMaterialization(Operation *op, FuncBuilder b(op); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. - if (op->isa()) { + if (isa(op)) { return false; } if (op->getNumRegions() != 0) @@ -749,7 +749,7 @@ void MaterializeVectorsPass::runOnFunction() { // Capture terminators; i.e. vector.transfer_write ops involving a strict // super-vector of subVectorType. auto filter = [subVectorType](Operation &op) { - if (!op.isa()) { + if (!isa(op)) { return false; } return matcher::operatesOnSuperVectorsOf(op, subVectorType); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 118efe5548d..fcbaeab5132 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -237,15 +237,15 @@ void MemRefDataFlowOpt::runOnFunction() { for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. Operation *defInst = memref->getDefiningOp(); - if (!defInst || !defInst->isa()) + if (!defInst || !isa(defInst)) // TODO(mlir-team): if the memref was returned by a 'call' operation, we // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), [&](OpOperand &use) { auto *ownerInst = use.getOwner(); - return (!ownerInst->isa() && - !ownerInst->isa()); + return (!isa(ownerInst) && + !isa(ownerInst)); })) continue; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 272972d233d..0d4b2012ce1 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -57,8 +57,8 @@ FunctionPassBase *mlir::createPipelineDataTransferPass() { // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) static unsigned getTagMemRefPos(Operation &dmaInst) { - assert(dmaInst.isa() || dmaInst.isa()); - if (dmaInst.isa()) { + assert(isa(dmaInst) || isa(dmaInst)); + if (isa(dmaInst)) { // Second to last operand. 
return dmaInst.getNumOperands() - 2; } @@ -189,7 +189,7 @@ static void findMatchingStartFinishInsts( SmallVector dmaStartInsts, dmaFinishInsts; for (auto &op : *forOp.getBody()) { // Collect DMA finish operations. - if (op.isa()) { + if (isa(op)) { dmaFinishInsts.push_back(&op); continue; } @@ -218,7 +218,7 @@ static void findMatchingStartFinishInsts( bool escapingUses = false; for (const auto &use : memref->getUses()) { // We can double buffer regardless of dealloc's outside the loop. - if (use.getOwner()->isa()) + if (isa(use.getOwner())) continue; if (!forOp.getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() @@ -293,7 +293,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { allocInst->erase(); } else if (oldMemRef->hasOneUse()) { auto *singleUse = oldMemRef->use_begin()->getOwner(); - if (singleUse->isa()) { + if (isa(singleUse)) { singleUse->erase(); oldMemRef->getDefiningOp()->erase(); } @@ -325,7 +325,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { DenseMap instShiftMap; for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; - assert(dmaStartInst->isa()); + assert(isa(dmaStartInst)); instShiftMap[dmaStartInst] = 0; // Set shifts for DMA start op's affine operand computation slices to 0. SmallVector sliceOps; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 7fe62a2c86a..fbdee58beab 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -173,7 +173,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { if (op->hasNoSideEffect() && op->use_empty()) { // Be careful to update bookkeeping in ConstantHelper to keep // consistency if this is a constant op. - if (op->isa()) + if (isa(op)) helper.notifyRemoval(op); op->erase(); continue; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 1ae75b4fbf7..d0d564a76c6 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -209,7 +209,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, operandMap.map(srcIV, loopChunkIV); } for (auto *op : insts) { - if (!op->isa()) + if (!isa(op)) bodyBuilder.clone(*op, operandMap); } }; @@ -511,7 +511,6 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { /// deeper in the loop nest. void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { for (unsigned i = 0; i < loopDepth; ++i) { - assert(forOp.getBody()->front().isa()); AffineForOp nextForOp = cast(forOp.getBody()->front()); interchangeLoops(forOp, nextForOp); } @@ -551,7 +550,7 @@ static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, if (&op == newForOp.getOperation()) { continue; } - if (op.isa()) { + if (isa(op)) { continue; } auto *instClone = b.clone(op, map); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 1ab821a9366..00ee9554b25 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -38,8 +38,8 @@ using namespace mlir; // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. 
TODO(b/117228571) static bool isMemRefDereferencingOp(Operation &op) { - if (op.isa() || op.isa() || op.isa() || - op.isa()) + if (isa(op) || isa(op) || isa(op) || + isa(op)) return true; return false; } @@ -93,7 +93,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // Skip dealloc's - no replacement is necessary, and a replacement doesn't // hurt dealloc's. - if (opInst->isa()) + if (isa(opInst)) continue; // Check if the memref was used in a non-deferencing context. It is fine for @@ -225,12 +225,9 @@ void mlir::createAffineComputationSlice( // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); - for (auto *operand : opInst->getOperands()) { - auto *defInst = operand->getDefiningOp(); - if (defInst && defInst->isa()) { + for (auto *operand : opInst->getOperands()) + if (isa_and_nonnull(operand->getDefiningOp())) subOperands.push_back(operand); - } - } // Gather sequence of AffineApplyOps reachable from 'subOperands'. SmallVector affineApplyOps; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 20138d56a3a..7b4db1f6cd8 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -231,7 +231,7 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { simplifyAffineMap(res).print(outs << "\nComposed map: "); } -static bool affineApplyOp(Operation &op) { return op.isa(); } +static bool affineApplyOp(Operation &op) { return isa(op); } static bool singleResultAffineApplyOpWithoutUses(Operation &op) { auto app = dyn_cast(op); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 4a58b15e720..a5bb23fc559 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -741,14 +741,14 @@ void VectorizationState::registerReplacement(Operation *key, Operation *value) { vectorizedSet.insert(value); vectorizationMap.insert(std::make_pair(key, value)); registerReplacement(key->getResult(0), value->getResult(0)); - if (key->isa()) { + if (isa(key)) { assert(roots.count(key) == 0 && "root was already inserted previously"); roots.insert(key); } } void VectorizationState::registerTerminal(Operation *op) { - assert(op->isa() && "terminal must be a StoreOp"); + assert(isa(op) && "terminal must be a StoreOp"); assert(terminals.count(op) == 0 && "terminal was already inserted previously"); terminals.insert(op); @@ -800,7 +800,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector.transfer operations // as needed by various targets. - if (opInst->template isa()) { + if (isa(opInst)) { auto permutationMap = makePermutationMap(opInst, state->strategy->loopToVectorDim); if (!permutationMap) @@ -1005,11 +1005,11 @@ static Value *vectorizeOperand(Value *operand, Operation *op, static Operation *vectorizeOneOperation(Operation *opInst, VectorizationState *state) { // Sanity checks. 
- assert(!opInst->isa() && + assert(!isa(opInst) && "all loads must have already been fully vectorized independently"); - assert(!opInst->isa() && + assert(!isa(opInst) && "vector.transfer_read cannot be further vectorized"); - assert(!opInst->isa() && + assert(!isa(opInst) && "vector.transfer_write cannot be further vectorized"); if (auto store = dyn_cast(opInst)) { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 9cf85079ec9..57068e2bff2 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -251,7 +251,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // Skip if there is no defining operation (e.g., arguments to function). os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); os.indent(indent) << formatv( - "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, + "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, op.getQualCppClassName()); } if (tree.getNumArgs() != op.getNumArgs()) { -- cgit v1.2.3 From 90d4023c9b0fc67706860478153cf13295f48727 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Mon, 13 May 2019 06:57:56 -0700 Subject: Factor out loop interchange code from LoopFusion into LoopUtils (NFC). -- PiperOrigin-RevId: 247926512 --- mlir/include/mlir/Transforms/LoopUtils.h | 18 +++++ mlir/lib/Transforms/LoopFusion.cpp | 104 +------------------------- mlir/lib/Transforms/Utils/LoopUtils.cpp | 123 +++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 102 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 2aecdceff7f..1105688aa55 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -100,6 +100,24 @@ LogicalResult tileCodeGen(MutableArrayRef band, /// and 'forOpB' are part of a perfectly nested sequence of loops. void interchangeLoops(AffineForOp forOpA, AffineForOp forOpB); +/// Checks if the loop interchange permutation 'loopPermMap', of the perfectly +/// nested sequence of loops in 'loops', would violate dependences (loop 'i' in +/// 'loops' is mapped to location 'j = 'loopPermMap[i]' in the interchange). +bool isValidLoopInterchangePermutation(ArrayRef loops, + ArrayRef loopPermMap); + +/// Performs a sequence of loop interchanges on perfectly nested 'loops', as +/// specified by permutation 'loopPermMap' (loop 'i' in 'loops' is mapped to +/// location 'j = 'loopPermMap[i]' after the loop interchange). +unsigned interchangeLoops(ArrayRef loops, + ArrayRef loopPermMap); + +// Sinks all sequential loops to the innermost levels (while preserving +// relative order among them) and moves all parallel loops to the +// outermost (while again preserving relative order among them). +// Returns AffineForOp of the root of the new loop nest after loop interchanges. +AffineForOp sinkSequentialLoops(AffineForOp forOp); + /// Sinks 'forOp' by 'loopDepth' levels by performing a series of loop /// interchanges. Requires that 'forOp' is part of a perfect nest with /// 'loopDepth' AffineForOps consecutively nested under it. 
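[Editor's note] The declarations added to LoopUtils.h above are the whole surface a client pass needs to drive an interchange: validate a permutation against dependences, then apply it. A minimal usage sketch under those assumptions follows; the helper name 'permuteIfLegal' and the SmallVector capacity are illustrative and not part of this patch.

// Sketch only: assumes the LoopUtils.h declarations above, the headers
// mlir/Transforms/LoopUtils.h and mlir/AffineOps/AffineOps.h, and
// 'using namespace mlir' as in the surrounding files.
static void permuteIfLegal(AffineForOp root, ArrayRef<unsigned> perm) {
  SmallVector<AffineForOp, 4> loops;
  getPerfectlyNestedLoops(loops, root);  // collect the perfect nest rooted at 'root'
  if (loops.size() != perm.size())
    return;  // the permutation must name every loop in the nest exactly once
  if (!isValidLoopInterchangePermutation(loops, perm))
    return;  // interchange would make some dependence lexicographically negative
  unsigned rootIdx = interchangeLoops(loops, perm);
  // loops[rootIdx] is now the outermost loop of the permuted nest.
  (void)rootIdx;
}
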
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 4e9e48cc3b3..8a0d63a7e68 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -965,84 +965,6 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, return loopDepth; } -// Compute loop interchange permutation: -// *) Computes dependence components between all op pairs of ops in loop nest -// rooted at 'loops[0]', for loop depths in range [1, 'maxLoopDepth']. -// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either -// parallel or sequential. -// *) Computes the loop permutation which sinks sequential loops deeper into -// the loop nest, while preserving the relative order between other loops. -// *) Checks each dependence component against the permutation to see if the -// desired loop interchange would violate dependences by making the -// dependence componenent lexicographically negative. -// TODO(andydavis) Move this function to LoopUtils. -static bool -computeLoopInterchangePermutation(ArrayRef loops, - SmallVectorImpl *loopPermMap) { - assert(loops.size() > 1); - // Gather dependence components for dependences between all ops in loop nest - // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. - unsigned maxLoopDepth = loops.size(); - std::vector> depCompsVec; - getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); - // Mark loops as either parallel or sequential. - llvm::SmallVector isParallelLoop(maxLoopDepth, true); - for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { - llvm::SmallVector &depComps = depCompsVec[i]; - assert(depComps.size() >= maxLoopDepth); - for (unsigned j = 0; j < maxLoopDepth; ++j) { - DependenceComponent &depComp = depComps[j]; - assert(depComp.lb.hasValue() && depComp.ub.hasValue()); - if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0) - isParallelLoop[j] = false; - } - } - - // Count the number of parallel loops. - unsigned numParallelLoops = 0; - for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i) - if (isParallelLoop[i]) - ++numParallelLoops; - - // Compute permutation of loops that sinks sequential loops (and thus raises - // parallel loops) while preserving relative order. - llvm::SmallVector loopPermMapInv; - loopPermMapInv.resize(maxLoopDepth); - loopPermMap->resize(maxLoopDepth); - unsigned nextSequentialLoop = numParallelLoops; - unsigned nextParallelLoop = 0; - for (unsigned i = 0; i < maxLoopDepth; ++i) { - if (isParallelLoop[i]) { - (*loopPermMap)[i] = nextParallelLoop; - loopPermMapInv[nextParallelLoop++] = i; - } else { - (*loopPermMap)[i] = nextSequentialLoop; - loopPermMapInv[nextSequentialLoop++] = i; - } - } - - // Check each dependence component against the permutation to see if the - // desired loop interchange permutation would make the dependence vectors - // lexicographically negative. - // Example 1: [-1, 1][0, 0] - // Example 2: [0, 0][-1, 1] - for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { - llvm::SmallVector &depComps = depCompsVec[i]; - assert(depComps.size() >= maxLoopDepth); - // Check if the first non-zero dependence component is positive. 
- for (unsigned j = 0; j < maxLoopDepth; ++j) { - unsigned permIndex = loopPermMapInv[j]; - assert(depComps[permIndex].lb.hasValue()); - int64_t depCompLb = depComps[permIndex].lb.getValue(); - if (depCompLb > 0) - break; - if (depCompLb < 0) - return false; - } - } - return true; -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -1050,30 +972,8 @@ computeLoopInterchangePermutation(ArrayRef loops, // pushing loop carried dependence to a greater depth in the loop nest. static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(isa(node->op)); - SmallVector loops; - AffineForOp curr = cast(node->op); - getPerfectlyNestedLoops(loops, curr); - if (loops.size() < 2) - return; - - // Compute loop permutation in 'loopPermMap'. - llvm::SmallVector loopPermMap; - if (!computeLoopInterchangePermutation(loops, &loopPermMap)) - return; - - int loopNestRootIndex = -1; - for (int i = loops.size() - 1; i >= 0; --i) { - int permIndex = static_cast(loopPermMap[i]); - // Store the index of the for loop which will be the new loop nest root. - if (permIndex == 0) - loopNestRootIndex = i; - if (permIndex > i) { - // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest. - sinkLoop(loops[i], permIndex - i); - } - } - assert(loopNestRootIndex != -1 && "invalid root index"); - node->op = loops[loopNestRootIndex].getOperation(); + AffineForOp newRootForOp = sinkSequentialLoops(cast(node->op)); + node->op = newRootForOp.getOperation(); } // TODO(mlir-team): improve/complete this when we have target data. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index d0d564a76c6..47ee626f811 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -507,6 +507,129 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { Block::iterator(forOpAInst)); } +// Checks each dependence component against the permutation to see if the +// desired loop interchange would violate dependences by making the +// dependence componenent lexicographically negative. +static bool checkLoopInterchangeDependences( + const std::vector> &depCompsVec, + ArrayRef loops, ArrayRef loopPermMap) { + // Invert permutation map. + unsigned maxLoopDepth = loops.size(); + llvm::SmallVector loopPermMapInv; + loopPermMapInv.resize(maxLoopDepth); + for (unsigned i = 0; i < maxLoopDepth; ++i) + loopPermMapInv[loopPermMap[i]] = i; + + // Check each dependence component against the permutation to see if the + // desired loop interchange permutation would make the dependence vectors + // lexicographically negative. + // Example 1: [-1, 1][0, 0] + // Example 2: [0, 0][-1, 1] + for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { + const llvm::SmallVector &depComps = depCompsVec[i]; + assert(depComps.size() >= maxLoopDepth); + // Check if the first non-zero dependence component is positive. + // This iterates through loops in the desired order. + for (unsigned j = 0; j < maxLoopDepth; ++j) { + unsigned permIndex = loopPermMapInv[j]; + assert(depComps[permIndex].lb.hasValue()); + int64_t depCompLb = depComps[permIndex].lb.getValue(); + if (depCompLb > 0) + break; + if (depCompLb < 0) + return false; + } + } + return true; +} + +/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly +/// nested sequence of loops in 'loops' would violate dependences. 
+bool mlir::isValidLoopInterchangePermutation(ArrayRef loops, + ArrayRef loopPermMap) { + // Gather dependence components for dependences between all ops in loop nest + // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. + assert(loopPermMap.size() == loops.size()); + unsigned maxLoopDepth = loops.size(); + std::vector> depCompsVec; + getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); + return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap); +} + +/// Performs a sequence of loop interchanges of loops in perfectly nested +/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'. +unsigned mlir::interchangeLoops(ArrayRef loops, + ArrayRef loopPermMap) { + Optional loopNestRootIndex; + for (int i = loops.size() - 1; i >= 0; --i) { + int permIndex = static_cast(loopPermMap[i]); + // Store the index of the for loop which will be the new loop nest root. + if (permIndex == 0) + loopNestRootIndex = i; + if (permIndex > i) { + // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest. + sinkLoop(loops[i], permIndex - i); + } + } + assert(loopNestRootIndex.hasValue()); + return loopNestRootIndex.getValue(); +} + +// Sinks all sequential loops to the innermost levels (while preserving +// relative order among them) and moves all parallel loops to the +// outermost (while again preserving relative order among them). +AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) { + SmallVector loops; + getPerfectlyNestedLoops(loops, forOp); + if (loops.size() < 2) + return forOp; + + // Gather dependence components for dependences between all ops in loop nest + // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. + unsigned maxLoopDepth = loops.size(); + std::vector> depCompsVec; + getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); + + // Mark loops as either parallel or sequential. + llvm::SmallVector isParallelLoop(maxLoopDepth, true); + for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { + llvm::SmallVector &depComps = depCompsVec[i]; + assert(depComps.size() >= maxLoopDepth); + for (unsigned j = 0; j < maxLoopDepth; ++j) { + DependenceComponent &depComp = depComps[j]; + assert(depComp.lb.hasValue() && depComp.ub.hasValue()); + if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0) + isParallelLoop[j] = false; + } + } + + // Count the number of parallel loops. + unsigned numParallelLoops = 0; + for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i) + if (isParallelLoop[i]) + ++numParallelLoops; + + // Compute permutation of loops that sinks sequential loops (and thus raises + // parallel loops) while preserving relative order. + llvm::SmallVector loopPermMap(maxLoopDepth); + unsigned nextSequentialLoop = numParallelLoops; + unsigned nextParallelLoop = 0; + for (unsigned i = 0; i < maxLoopDepth; ++i) { + if (isParallelLoop[i]) { + loopPermMap[i] = nextParallelLoop++; + } else { + loopPermMap[i] = nextSequentialLoop++; + } + } + + // Check if permutation 'loopPermMap' would violate dependences. + if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap)) + return forOp; + // Perform loop interchange according to permutation 'loopPermMap'. + unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap); + return loops[loopNestRootIndex]; +} + /// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels /// deeper in the loop nest. 
void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { -- cgit v1.2.3 From 1a2ad06bae22a5a55dde62e654928620ba3cf184 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 13 May 2019 18:10:48 -0700 Subject: Fix lingering sign compare warnings in exposed by "ninja check-mlir". -- PiperOrigin-RevId: 248050178 --- mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 2 +- mlir/lib/IR/Attributes.cpp | 3 ++- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 6 ++++-- mlir/lib/Linalg/Transforms/Tiling.cpp | 4 ++-- mlir/lib/Parser/Parser.cpp | 3 ++- mlir/lib/Transforms/LoopFusion.cpp | 1 + 8 files changed, 14 insertions(+), 9 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 02c82d79104..8117aed9e14 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -210,7 +210,7 @@ public: // dimensions, extracts the size from the memref descriptor. auto memrefSize = [int64Ty, pos, i64cst](MemRefType type, Value *memref, int dim) -> Value * { - assert(dim < type.getRank()); + assert(static_cast(dim) < type.getRank()); if (type.getShape()[dim] != -1) { return i64cst(type.getShape()[dim]); } diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 5f024ea4033..907e3f1c566 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -189,7 +189,7 @@ public: // FIXME: this seems like a bug in `cloneInto()` above? auto &entryBlock = f->getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == f->getType().getInputs().size()); + assert(blockArgSize == static_cast(f->getType().getInputs().size())); entryBlock.addArguments(f->getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index ab990193ab6..5267586a5e4 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -189,7 +189,7 @@ public: // FIXME: this seems like a bug in `cloneInto()` above? 
auto &entryBlock = f->getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == f->getType().getInputs().size()); + assert(blockArgSize == static_cast(f->getType().getInputs().size())); entryBlock.addArguments(f->getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 47a44382349..fad12a89309 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -340,7 +340,8 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const { DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, ArrayRef data) { - assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) && + assert((static_cast(type.getSizeInBits()) <= + data.size() * APInt::APINT_WORD_SIZE) && "Input data bit size should be larger than that type requires"); switch (type.getElementType().getKind()) { case StandardTypes::BF16: diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 3a8ead491ed..d8d72c2e220 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -576,12 +576,14 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, nullptr; int position = positionElementAttr.getInt(); if (llvmContainerType->isArrayTy()) { - if (position < 0 || position >= llvmContainerType->getArrayNumElements()) + if (position < 0 || static_cast(position) >= + llvmContainerType->getArrayNumElements()) return parser->emitError(attributeLoc, "position out of bounds"), nullptr; llvmContainerType = llvmContainerType->getArrayElementType(); } else if (llvmContainerType->isStructTy()) { - if (position < 0 || position >= llvmContainerType->getStructNumElements()) + if (position < 0 || static_cast(position) >= + llvmContainerType->getStructNumElements()) return parser->emitError(attributeLoc, "position out of bounds"), nullptr; llvmContainerType = llvmContainerType->getStructElementType(position); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index f50076a1710..ba1fdbe2715 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -189,9 +189,9 @@ static SmallVector makeTiledViews(FuncBuilder *b, Location loc, ArrayRef ivs, ArrayRef tileSizes, PerFunctionState &state) { - assert(ivs.size() == llvm::count_if( + assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](Value *v) { return !isZero(v); }) && + [](Value *v) { return !isZero(v); })) && "expected as many ivs as non-zero sizes"); auto *context = op->getContext(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 31b77664df6..0fef0cdbb64 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3292,7 +3292,8 @@ public: break; } - if (requiredOperandCount != -1 && result.size() != requiredOperandCount) + if (requiredOperandCount != -1 && + result.size() != static_cast(requiredOperandCount)) return emitError(startLoc, "expected ") << requiredOperandCount << " operands"; return success(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8a0d63a7e68..3a4d6a94e80 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include +#include #define DEBUG_TYPE 
"affine-loop-fusion" using llvm::SetVector; -- cgit v1.2.3 From 8780d8d8ebcfb3519f93b034973f8e3c1629456e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 18 May 2019 11:09:07 -0700 Subject: Add user iterators to IRObjects, i.e. Values. -- PiperOrigin-RevId: 248877752 --- mlir/include/mlir/IR/UseDefLists.h | 43 +++++++++++++++++++++- mlir/lib/Analysis/LoopAnalysis.cpp | 4 +- mlir/lib/Analysis/SliceAnalysis.cpp | 14 ++----- mlir/lib/StandardOps/Ops.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 14 +++---- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 17 ++++----- mlir/lib/Transforms/PipelineDataTransfer.cpp | 11 +++--- .../Utils/GreedyPatternRewriteDriver.cpp | 13 +++---- mlir/lib/Transforms/Utils/Utils.cpp | 8 ++-- 9 files changed, 76 insertions(+), 52 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 623fd9cd64e..d266935c206 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -31,6 +31,7 @@ namespace mlir { class IROperand; class Operation; template class ValueUseIterator; +template class ValueUserIterator; class IRObjectWithUseList { public: @@ -53,6 +54,15 @@ public: /// Returns a range of all uses, which is useful for iterating over all uses. inline use_range getUses() const; + using user_iterator = ValueUserIterator; + using user_range = llvm::iterator_range; + + inline user_iterator user_begin() const; + inline user_iterator user_end() const; + + /// Returns a range of all users. + inline user_range getUsers() const; + /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. @@ -228,8 +238,7 @@ inline auto IRObjectWithUseList::use_end() const -> use_iterator { return use_iterator(nullptr); } -inline auto IRObjectWithUseList::getUses() const - -> llvm::iterator_range { +inline auto IRObjectWithUseList::getUses() const -> use_range { return {use_begin(), use_end()}; } @@ -238,6 +247,36 @@ inline bool IRObjectWithUseList::hasOneUse() const { return firstUse && firstUse->getNextOperandUsingThisValue() == nullptr; } +/// An iterator over all users of a ValueBase. +template +class ValueUserIterator final + : public llvm::mapped_iterator, + Operation *(*)(OperandType &)> { + static Operation *unwrap(OperandType &value) { return value.getOwner(); } + +public: + using pointer = Operation *; + using reference = Operation *; + + /// Initializes the result type iterator to the specified result iterator. + ValueUserIterator(ValueUseIterator it) + : llvm::mapped_iterator, + Operation *(*)(OperandType &)>(it, &unwrap) {} + Operation *operator->() { return **this; } +}; + +inline auto IRObjectWithUseList::user_begin() const -> user_iterator { + return user_iterator(use_begin()); +} + +inline auto IRObjectWithUseList::user_end() const -> user_iterator { + return user_iterator(use_end()); +} + +inline auto IRObjectWithUseList::getUsers() const -> user_range { + return {user_begin(), user_end()}; +} + } // namespace mlir #endif diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 3ec4833329d..97c2a87a10d 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -376,10 +376,10 @@ bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { // Validate the results of this operation if it were to be shifted. 
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { Value *result = op.getResult(i); - for (const auto &use : result->getUses()) { + for (auto *user : result->getUsers()) { // If an ancestor operation doesn't lie in the block of forOp, // there is no shift to check. - if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) { + if (auto *ancInst = forBody->findAncestorInstInBlock(*user)) { assert(forBodyShift.count(ancInst) > 0 && "ancestor expected in map"); if (shift != forBodyShift[ancInst]) return false; diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 155a2bbbd1b..7f1b2e3e0fa 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -51,21 +51,15 @@ static void getForwardSliceImpl(Operation *op, } if (auto forOp = dyn_cast(op)) { - for (auto &u : forOp.getInductionVar()->getUses()) { - auto *ownerInst = u.getOwner(); - if (forwardSlice->count(ownerInst) == 0) { + for (auto *ownerInst : forOp.getInductionVar()->getUsers()) + if (forwardSlice->count(ownerInst) == 0) getForwardSliceImpl(ownerInst, forwardSlice, filter); - } - } } else { assert(op->getNumResults() <= 1 && "NYI: multiple results"); if (op->getNumResults() > 0) { - for (auto &u : op->getResult(0)->getUses()) { - auto *ownerInst = u.getOwner(); - if (forwardSlice->count(ownerInst) == 0) { + for (auto *ownerInst : op->getResult(0)->getUsers()) + if (forwardSlice->count(ownerInst) == 0) getForwardSliceImpl(ownerInst, forwardSlice, filter); - } - } } } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 5b9f75c1419..aabaed53970 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1241,8 +1241,8 @@ struct SimplifyDeadDealloc : public RewritePattern { return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. - for (auto &use : memref->getUses()) - if (!isa(use.getOwner())) + for (auto *user : memref->getUsers()) + if (!isa(user)) return matchFailure(); // Erase the dealloc operation. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 3a4d6a94e80..33a7918b9b8 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -313,8 +313,8 @@ public: if (!op) return true; // Return true if any use of 'memref' escapes the function. - for (auto &use : memref->getUses()) - if (!isMemRefDereferencingOp(*use.getOwner())) + for (auto *user : memref->getUsers()) + if (!isMemRefDereferencingOp(*user)) return true; } return false; @@ -700,9 +700,9 @@ bool MemRefDependenceGraph::init(Function &f) { continue; auto *opInst = node.op; for (auto *value : opInst->getResults()) { - for (auto &use : value->getUses()) { + for (auto *user : value->getUsers()) { SmallVector loops; - getLoopIVs(*use.getOwner(), &loops); + getLoopIVs(*user, &loops); if (loops.empty()) continue; assert(forToNodeMap.count(loops[0].getOperation()) > 0); @@ -2025,11 +2025,11 @@ public: // Search for siblings which load the same memref function argument. auto *fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { - for (auto &use : fn->getArgument(i)->getUses()) { - if (auto loadOp = dyn_cast(use.getOwner())) { + for (auto *user : fn->getArgument(i)->getUsers()) { + if (auto loadOp = dyn_cast(user)) { // Gather loops surrounding 'use'. SmallVector loops; - getLoopIVs(*use.getOwner(), &loops); + getLoopIVs(*user, &loops); // Skip 'use' if it is not within a loop nest. 
if (loops.empty()) continue; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index fcbaeab5132..45a11efc3e3 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -102,8 +102,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { // all store ops. SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); - for (auto &use : loadOp.getMemRef()->getUses()) { - auto storeOp = dyn_cast(use.getOwner()); + for (auto *user : loadOp.getMemRef()->getUsers()) { + auto storeOp = dyn_cast(user); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); @@ -241,17 +241,14 @@ void MemRefDataFlowOpt::runOnFunction() { // TODO(mlir-team): if the memref was returned by a 'call' operation, we // could still erase it if the call had no side-effects. continue; - if (std::any_of(memref->use_begin(), memref->use_end(), - [&](OpOperand &use) { - auto *ownerInst = use.getOwner(); - return (!isa(ownerInst) && - !isa(ownerInst)); - })) + if (llvm::any_of(memref->getUsers(), [&](Operation *ownerInst) { + return (!isa(ownerInst) && !isa(ownerInst)); + })) continue; // Erase all stores, the dealloc, and the alloc on the memref. - for (auto &use : llvm::make_early_inc_range(memref->getUses())) - use.getOwner()->erase(); + for (auto *user : llvm::make_early_inc_range(memref->getUsers())) + user->erase(); defInst->erase(); } } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 0d4b2012ce1..c9e1dcefcf6 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -216,11 +216,11 @@ static void findMatchingStartFinishInsts( // We only double buffer if the buffer is not live out of loop. auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos()); bool escapingUses = false; - for (const auto &use : memref->getUses()) { + for (auto *user : memref->getUsers()) { // We can double buffer regardless of dealloc's outside the loop. - if (isa(use.getOwner())) + if (isa(user)) continue; - if (!forOp.getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp.getBody()->findAncestorInstInBlock(*user)) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -292,9 +292,8 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { if (oldMemRef->use_empty()) { allocInst->erase(); } else if (oldMemRef->hasOneUse()) { - auto *singleUse = oldMemRef->use_begin()->getOwner(); - if (isa(singleUse)) { - singleUse->erase(); + if (auto dealloc = dyn_cast(*oldMemRef->user_begin())) { + dealloc.erase(); oldMemRef->getDefiningOp()->erase(); } } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 6d34f034c1e..a2d2d03c3d8 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -105,9 +105,8 @@ protected: // before the root is changed. void notifyRootReplaced(Operation *op) override { for (auto *result : op->getResults()) - // TODO: Add a result->getUsers() iterator. 
- for (auto &user : result->getUses()) - addToWorklist(user.getOwner()); + for (auto *user : result->getUsers()) + addToWorklist(user); } private: @@ -183,11 +182,9 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { // Add all the users of the result to the worklist so we make sure // to revisit them. - // - // TODO: Add a result->getUsers() iterator. - for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) - for (auto &operand : op->getResult(i)->getUses()) - addToWorklist(operand.getOwner()); + for (auto *result : op->getResults()) + for (auto *operand : result->getUsers()) + addToWorklist(operand); }; // Try to fold this op. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 00ee9554b25..484131a9ce5 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -79,9 +79,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, SmallVector opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. - for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) { - auto *opInst = use.getOwner(); - + for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) { // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) continue; @@ -241,8 +239,8 @@ void mlir::createAffineComputationSlice( bool localized = true; for (auto *op : affineApplyOps) { for (auto *result : op->getResults()) { - for (auto &use : result->getUses()) { - if (use.getOwner() != opInst) { + for (auto *user : result->getUsers()) { + if (user != opInst) { localized = false; break; } -- cgit v1.2.3 From 80884d28ac3f63900146b3efb0b566493e4b1734 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Mon, 20 May 2019 14:18:43 -0700 Subject: [LoopFusion] Don't count terminator op in compute cost. -- PiperOrigin-RevId: 249124895 --- mlir/LICENSE.TXT | 227 ++++++------------------------------- mlir/lib/Transforms/LoopFusion.cpp | 5 +- 2 files changed, 36 insertions(+), 196 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/LICENSE.TXT b/mlir/LICENSE.TXT index a4b160b6e33..5af756e1e01 100644 --- a/mlir/LICENSE.TXT +++ b/mlir/LICENSE.TXT @@ -1,205 +1,44 @@ -Copyright 2019 The MLIR Authors. +============================================================================== +LLVM Release License +============================================================================== +University of Illinois/NCSA +Open Source License - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ +Copyright (c) 2003-2018 University of Illinois at Urbana-Champaign. +All rights reserved. - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +Developed by: - 1. Definitions. + LLVM Team - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. + University of Illinois at Urbana-Champaign - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. + http://llvm.org - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal with +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + * Neither the names of the LLVM Team, University of Illinois at + Urbana-Champaign, nor the names of its contributors may be used to + endorse or promote products derived from this Software without specific + prior written permission. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE +SOFTWARE. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 33a7918b9b8..7999e0a6557 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -801,8 +801,9 @@ static int64_t getComputeCost( Operation *forInst, LoopNestStats *stats, llvm::SmallDenseMap *tripCountOverrideMap, DenseMap *computeCostMap) { - // 'opCount' is the total number operations in one iteration of 'forOp' body - int64_t opCount = stats->opCountMap[forInst]; + // 'opCount' is the total number operations in one iteration of 'forOp' body, + // minus terminator op which is a no-op. + int64_t opCount = stats->opCountMap[forInst] - 1; if (stats->loopMap.count(forInst) > 0) { for (auto childForOp : stats->loopMap[forInst]) { opCount += getComputeCost(childForOp.getOperation(), stats, -- cgit v1.2.3 From a560f2c646d7a762a3cf0a74ce55fc9876c1d974 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Fri, 24 May 2019 10:54:22 -0700 Subject: Affine Loop Fusion Utility Module (1/n). *) Adds LoopFusionUtils which will expose a set of loop fusion utilities (e.g. dependence checks, fusion cost/storage reduction, loop fusion transformation) for use by loop fusion algorithms. Support for checking block-level fusion-preventing dependences is added in this CL (additional loop fusion utilities will be added in subsequent CLs). *) Adds TestLoopFusion test pass for testing LoopFusionUtils at a fine granularity. *) Adds unit test for testing dependence check for block-level fusion-preventing dependences. 
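[Editor's note] In client code the new block-level dependence check is a single call. A minimal sketch of the intended query pattern, mirroring how LoopFusion.cpp below uses it; the wrapper name 'fusionIsFeasible' is illustrative and not part of this change.

// Sketch only: assumes mlir/Transforms/LoopFusionUtils.h and
// 'using namespace mlir' as in the files changed by this commit.
static bool fusionIsFeasible(AffineForOp srcForOp, AffineForOp dstForOp,
                             unsigned dstLoopDepth) {
  FusionResult result =
      mlir::canFuseLoops(srcForOp, dstForOp, dstLoopDepth, /*srcSlice=*/nullptr);
  // Success means both loop nests are in the same block and fusing them at
  // 'dstLoopDepth' would not violate any other dependence in that block.
  return result.value == FusionResult::Success;
}
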
-- PiperOrigin-RevId: 249861071 --- mlir/include/mlir/Transforms/LoopFusionUtils.h | 59 ++++ mlir/include/mlir/Transforms/Passes.h | 3 + mlir/lib/Transforms/LoopFusion.cpp | 8 + mlir/lib/Transforms/TestLoopFusion.cpp | 112 +++++++ mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 202 ++++++++++++ .../Transforms/loop-fusion-dependence-check.mlir | 337 +++++++++++++++++++++ 6 files changed, 721 insertions(+) create mode 100644 mlir/include/mlir/Transforms/LoopFusionUtils.h create mode 100644 mlir/lib/Transforms/TestLoopFusion.cpp create mode 100644 mlir/lib/Transforms/Utils/LoopFusionUtils.cpp create mode 100644 mlir/test/Transforms/loop-fusion-dependence-check.mlir (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h new file mode 100644 index 00000000000..ccda6693f88 --- /dev/null +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -0,0 +1,59 @@ +//===- LoopFusionUtils.h - Loop fusion utilities ----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This header file defines prototypes for various loop fusion utility +// methods: these are not passes by themselves but are used either by passes, +// optimization sequences, or in turn by other transformation utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H +#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H + +namespace mlir { +class AffineForOp; +struct ComputationSliceState; + +// TODO(andydavis) Extend this module to include utility functions for querying +// fusion cost/storage reduction, and for performing the loop fusion +// transformation. + +struct FusionResult { + enum ResultEnum { + Success, + FailPrecondition, // Failed precondition for fusion. (e.g. same block). + FailBlockDependence, // Fusion would violate another dependence in block. + FailFusionDependence, // Fusion would reverse dependences between loops. + FailComputationSlice, // Unable to compute src loop computation slice. + } value; + FusionResult(ResultEnum v) : value(v) {} +}; + +/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the +/// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult +/// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are +/// in the same block and dependences would not be violated). Otherwise +/// returns a FusionResult explaining why fusion is not feasible. +/// NOTE: This function is not feature complete and should only be used in +/// testing. +/// TODO(andydavis) Update comments when this function is fully implemented. 
+FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice); +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index dc5d57fba4e..48822cdac86 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -116,6 +116,9 @@ FunctionPassBase *createMemRefDataFlowOptPass(); /// Creates a pass to strip debug information from a function. FunctionPassBase *createStripDebugInfoPass(); +/// Creates a pass which tests loop fusion utilities. +FunctionPassBase *createTestLoopFusionPass(); + } // end namespace mlir #endif // MLIR_TRANSFORMS_PASSES_H diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7999e0a6557..1f475f1fb44 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" +#include "mlir/Transforms/LoopFusionUtils.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -1810,6 +1811,13 @@ public: dstLoadOpInsts, dstStoreOpInsts, &sliceState, &bestDstLoopDepth, maximalFusion)) continue; + // TODO(andydavis) Remove assert and surrounding code when + // canFuseLoops is fully functional. + FusionResult result = mlir::canFuseLoops( + cast(srcNode->op), cast(dstNode->op), + bestDstLoopDepth, /*srcSlice=*/nullptr); + assert(result.value == FusionResult::Success); + (void)result; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto sliceLoopNest = mlir::insertBackwardComputationSlice( diff --git a/mlir/lib/Transforms/TestLoopFusion.cpp b/mlir/lib/Transforms/TestLoopFusion.cpp new file mode 100644 index 00000000000..9ace2fb4350 --- /dev/null +++ b/mlir/lib/Transforms/TestLoopFusion.cpp @@ -0,0 +1,112 @@ +//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a pass to test various loop fusion utility functions. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Passes.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Transforms/LoopFusionUtils.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "test-loop-fusion" + +using namespace mlir; + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +static llvm::cl::opt clTestDependenceCheck( + "test-loop-fusion-dependence-check", + llvm::cl::desc("Enable testing of loop fusion dependence check"), + llvm::cl::cat(clOptionsCategory)); + +namespace { + +struct TestLoopFusion : public FunctionPass { + void runOnFunction() override; +}; + +} // end anonymous namespace + +FunctionPassBase *mlir::createTestLoopFusionPass() { + return new TestLoopFusion; +} + +// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'. +static void +gatherLoops(Block *block, unsigned currLoopDepth, + DenseMap> &depthToLoops) { + auto &loopsAtDepth = depthToLoops[currLoopDepth]; + for (auto &op : *block) { + if (auto forOp = dyn_cast(op)) { + loopsAtDepth.push_back(forOp); + gatherLoops(forOp.getBody(), currLoopDepth + 1, depthToLoops); + } + } +} + +// Run fusion dependence check on 'loops[i]' and 'loops[j]' at 'loopDepth'. +// Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. +static void testDependenceCheck(SmallVector &loops, unsigned i, + unsigned j, unsigned loopDepth) { + AffineForOp srcForOp = loops[i]; + AffineForOp dstForOp = loops[j]; + FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth, + /*srcSlice=*/nullptr); + if (result.value == FusionResult::FailBlockDependence) { + srcForOp.getOperation()->emitRemark("block-level dependence preventing" + " fusion of loop nest ") + << i << " into loop nest " << j << " at depth " << loopDepth; + } +} + +void TestLoopFusion::runOnFunction() { + // Gather all AffineForOps by loop depth. + DenseMap> depthToLoops; + for (auto &block : getFunction()) { + gatherLoops(&block, /*currLoopDepth=*/0, depthToLoops); + } + + // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. + for (auto &depthAndLoops : depthToLoops) { + unsigned loopDepth = depthAndLoops.first; + auto &loops = depthAndLoops.second; + unsigned numLoops = loops.size(); + for (unsigned j = 0; j < numLoops; ++j) { + for (unsigned k = 0; k < numLoops; ++k) { + if (j == k) + continue; + if (clTestDependenceCheck) + testDependenceCheck(loops, j, k, loopDepth); + } + } + } +} + +static PassRegistration + pass("test-loop-fusion", "Tests loop fusion utility functions."); diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp new file mode 100644 index 00000000000..9de6766e075 --- /dev/null +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -0,0 +1,202 @@ +//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements loop fusion transformation utility functions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/LoopFusionUtils.h" + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/StandardOps/Ops.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "loop-fusion-utils" + +using namespace mlir; + +// Gathers all load and store operations in 'opA' into 'values', where +// 'values[memref] == true' for each store operation. +static void getLoadsAndStores(Operation *opA, DenseMap &values) { + opA->walk([&](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + if (values.count(loadOp.getMemRef()) == 0) + values[loadOp.getMemRef()] = false; + } else if (auto storeOp = dyn_cast(op)) { + values[storeOp.getMemRef()] = true; + } + }); +} + +// Returns true if 'op' is a load or store operation which access an memref +// accessed 'values' and at least one of the access is a store operation. +// Returns false otherwise. +static bool isDependentLoadOrStoreOp(Operation *op, + DenseMap &values) { + if (auto loadOp = dyn_cast(op)) { + return values.count(loadOp.getMemRef()) > 0 && + values[loadOp.getMemRef()] == true; + } else if (auto storeOp = dyn_cast(op)) { + return values.count(storeOp.getMemRef()) > 0; + } + return false; +} + +// Returns the first operation in range ('opA', 'opB') which has a data +// dependence on 'opA'. Returns 'nullptr' of no dependence exists. +static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { + // Record memref values from all loads/store in loop nest rooted at 'opA'. + // Map from memref value to bool which is true if store, false otherwise. + DenseMap values; + getLoadsAndStores(opA, values); + + // For each 'opX' in block in range ('opA', 'opB'), check if there is a data + // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref + // and at least one of the accesses is a store). + Operation *firstDepOp = nullptr; + for (Block::iterator it = std::next(Block::iterator(opA)); + it != Block::iterator(opB); ++it) { + Operation *opX = &(*it); + opX->walk([&](Operation *op) { + if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) + firstDepOp = opX; + }); + if (firstDepOp) + break; + } + return firstDepOp; +} + +// Returns the last operation 'opX' in range ('opA', 'opB'), for which there +// exists a data dependence from 'opX' to 'opB'. +// Returns 'nullptr' of no dependence exists. +static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { + // Record memref values from all loads/store in loop nest rooted at 'opB'. 
+ // Map from memref value to bool which is true if store, false otherwise. + DenseMap values; + getLoadsAndStores(opB, values); + + // For each 'opX' in block in range ('opA', 'opB') in reverse order, + // check if there is a data dependence from 'opX' to 'opB': + // *) 'opX' and 'opB' access the same memref and at least one of the accesses + // is a store. + // *) 'opX' produces an SSA Value which is used by 'opB'. + Operation *lastDepOp = nullptr; + for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); + it != Block::reverse_iterator(opA); ++it) { + Operation *opX = &(*it); + opX->walk([&](Operation *op) { + if (lastDepOp) + return; + if (isa(op) || isa(op)) { + if (isDependentLoadOrStoreOp(op, values)) + lastDepOp = opX; + return; + } + for (auto *value : op->getResults()) { + for (auto *user : value->getUsers()) { + SmallVector loops; + // Check if any loop in loop nest surrounding 'user' is 'opB'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast(opB))) { + lastDepOp = opX; + } + } + } + }); + if (lastDepOp) + break; + } + return lastDepOp; +} + +// Computes and returns an insertion point operation, before which the +// the fused loop nest can be inserted while preserving +// dependences. Returns nullptr if no such insertion point is found. +static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, + AffineForOp dstForOp) { + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + auto *firstDepOpA = + getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + auto *lastDepOpB = + getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + // Block: + // ... + // |-- opA + // | ... + // | lastDepOpB --| + // | ... | + // |-> firstDepOpA | + // ... | + // opB <--------- + // + // Valid insertion point range: (lastDepOpB, firstDepOpA) + // + if (firstDepOpA != nullptr) { + if (lastDepOpB != nullptr) { + if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) + // No valid insertion point exists which preserves dependences. + return nullptr; + } + // Return insertion point in valid range closest to 'opB'. + // TODO(andydavis) Consider other insertion points in valid range. + return firstDepOpA; + } + // No dependences from 'opA' to operation in range ('opA', 'opB'), return + // 'opB' insertion point. + return forOpB.getOperation(); +} + +// TODO(andydavis) Add support for the following features in subsequent CLs: +// *) Computing union of slices computed between src/dst loads and stores. +// *) Compute dependences of unfused src/dst loops. +// *) Compute dependences of src/dst loop as if they were fused. +// *) Check for fusion preventing dependences (e.g. a dependence which changes +// from loop-independent to backward loop-carried after fusion). +FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice) { + // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block. + auto *block = srcForOp.getOperation()->getBlock(); + if (block != dstForOp.getOperation()->getBlock()) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); + return FusionResult::FailPrecondition; + } + + // Return 'false' if no valid insertion point for fused loop nest in 'block' + // exists which would preserve dependences. 
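// An illustrative block layout for the insertion-point computation above (a
// sketch, not taken from the patch):
//
//   affine.for %i0 ... { store %cf, %m[%i0] }   <- forOpA
//   %x = load %m[%c0]                            <- firstDepOpA (reads forOpA's store)
//   affine.for %i1 ... { ... }                   <- forOpB, with lastDepOpB == nullptr
//
// Here getFusedLoopNestInsertionPoint returns firstDepOpA: the fused nest
// would have to be placed before the load to keep the store -> load
// dependence intact.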
+ if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); + return FusionResult::FailBlockDependence; + } + return FusionResult::Success; +} diff --git a/mlir/test/Transforms/loop-fusion-dependence-check.mlir b/mlir/test/Transforms/loop-fusion-dependence-check.mlir new file mode 100644 index 00000000000..3174f896c25 --- /dev/null +++ b/mlir/test/Transforms/loop-fusion-dependence-check.mlir @@ -0,0 +1,337 @@ +// RUN: mlir-opt %s -test-loop-fusion -test-loop-fusion-dependence-check -split-input-file -verify | FileCheck %s + +// ----- + +// CHECK-LABEL: func @cannot_fuse_would_create_cycle() { +func @cannot_fuse_would_create_cycle() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + // Set up the following dependences: + // 1) loop0 -> loop1 on memref '%a' + // 2) loop0 -> loop2 on memref '%b' + // 3) loop1 -> loop2 on memref '%c' + + // Fusing loop nest '%i0' and loop nest '%i2' would create a cycle. + affine.for %i0 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 0}} + %v0 = load %a[%i0] : memref<10xf32> + store %cf7, %b[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + store %cf7, %a[%i1] : memref<10xf32> + %v1 = load %c[%i1] : memref<10xf32> + } + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 0}} + %v2 = load %b[%i2] : memref<10xf32> + store %cf7, %c[%i2] : memref<10xf32> + } + return +} + +// ----- + +// CHECK-LABEL: func @can_fuse_rar_dependence() { +func @can_fuse_rar_dependence() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + // Set up the following dependences: + // Make dependence from 0 to 1 on '%a' read-after-read. + // 1) loop0 -> loop1 on memref '%a' + // 2) loop0 -> loop2 on memref '%b' + // 3) loop1 -> loop2 on memref '%c' + + // Should fuse: no fusion preventing remarks should be emitted for this test. + affine.for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %cf7, %b[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v1 = load %a[%i1] : memref<10xf32> + %v2 = load %c[%i1] : memref<10xf32> + } + affine.for %i2 = 0 to 10 { + %v3 = load %b[%i2] : memref<10xf32> + store %cf7, %c[%i2] : memref<10xf32> + } + return +} + +// ----- + +// CHECK-LABEL: func @can_fuse_different_memrefs() { +func @can_fuse_different_memrefs() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + %d = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + // Set up the following dependences: + // Make dependence from 0 to 1 on unrelated memref '%d'. + // 1) loop0 -> loop1 on memref '%a' + // 2) loop0 -> loop2 on memref '%b' + // 3) loop1 -> loop2 on memref '%c' + + // Should fuse: no fusion preventing remarks should be emitted for this test. 
+ affine.for %i0 = 0 to 10 { + %v0 = load %a[%i0] : memref<10xf32> + store %cf7, %b[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + store %cf7, %d[%i1] : memref<10xf32> + %v1 = load %c[%i1] : memref<10xf32> + } + affine.for %i2 = 0 to 10 { + %v2 = load %b[%i2] : memref<10xf32> + store %cf7, %c[%i2] : memref<10xf32> + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_intermediate_store() { +func @should_not_fuse_across_intermediate_store() { + %0 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} + %v0 = load %0[%i0] : memref<10xf32> + "op0"(%v0) : (f32) -> () + } + + // Should not fuse loop nests '%i0' and '%i1' across top-level store. + store %cf7, %0[%c0] : memref<10xf32> + + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} + %v1 = load %0[%i1] : memref<10xf32> + "op1"(%v1) : (f32) -> () + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_intermediate_load() { +func @should_not_fuse_across_intermediate_load() { + %0 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} + store %cf7, %0[%i0] : memref<10xf32> + } + + // Should not fuse loop nests '%i0' and '%i1' across top-level load. + %v0 = load %0[%c0] : memref<10xf32> + "op0"(%v0) : (f32) -> () + + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} + store %cf7, %0[%i1] : memref<10xf32> + } + + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_ssa_value_def() { +func @should_not_fuse_across_ssa_value_def() { + %0 = alloc() : memref<10xf32> + %1 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} + %v0 = load %0[%i0] : memref<10xf32> + store %v0, %1[%i0] : memref<10xf32> + } + + // Loop nest '%i0" cannot be fused past load from '%1' due to RAW dependence. + %v1 = load %1[%c0] : memref<10xf32> + "op0"(%v1) : (f32) -> () + + // Loop nest '%i1' cannot be fused past SSA value def '%c2' which it uses. 
+ %c2 = constant 2 : index + + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} + store %cf7, %0[%c2] : memref<10xf32> + } + + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_store_before_load() { +func @should_not_fuse_store_before_load() { + %0 = alloc() : memref<10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 0}} + store %cf7, %0[%i0] : memref<10xf32> + %v0 = load %0[%i0] : memref<10xf32> + } + + affine.for %i1 = 0 to 10 { + %v1 = load %0[%i1] : memref<10xf32> + } + + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 0}} + store %cf7, %0[%i2] : memref<10xf32> + %v2 = load %0[%i2] : memref<10xf32> + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_load_at_depth1() { +func @should_not_fuse_across_load_at_depth1() { + %0 = alloc() : memref<10x10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} + store %cf7, %0[%i0, %i1] : memref<10x10xf32> + } + + %v1 = load %0[%i0, %c0] : memref<10x10xf32> + + affine.for %i3 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} + store %cf7, %0[%i0, %i3] : memref<10x10xf32> + } + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_load_in_loop_at_depth1() { +func @should_not_fuse_across_load_in_loop_at_depth1() { + %0 = alloc() : memref<10x10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 1}} + store %cf7, %0[%i0, %i1] : memref<10x10xf32> + } + + affine.for %i2 = 0 to 10 { + %v1 = load %0[%i0, %i2] : memref<10x10xf32> + } + + affine.for %i3 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 1}} + store %cf7, %0[%i0, %i3] : memref<10x10xf32> + } + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_store_at_depth1() { +func @should_not_fuse_across_store_at_depth1() { + %0 = alloc() : memref<10x10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} + %v0 = load %0[%i0, %i1] : memref<10x10xf32> + } + + store %cf7, %0[%i0, %c0] : memref<10x10xf32> + + affine.for %i3 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} + %v1 = load %0[%i0, %i3] : memref<10x10xf32> + } + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_store_in_loop_at_depth1() { +func @should_not_fuse_across_store_in_loop_at_depth1() { + %0 = alloc() : memref<10x10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 1}} + %v0 = load %0[%i0, %i1] : 
memref<10x10xf32> + } + + affine.for %i2 = 0 to 10 { + store %cf7, %0[%i0, %i2] : memref<10x10xf32> + } + + affine.for %i3 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 1}} + %v1 = load %0[%i0, %i3] : memref<10x10xf32> + } + } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_across_ssa_value_def_at_depth1() { +func @should_not_fuse_across_ssa_value_def_at_depth1() { + %0 = alloc() : memref<10x10xf32> + %1 = alloc() : memref<10x10xf32> + %c0 = constant 0 : index + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} + %v0 = load %0[%i0, %i1] : memref<10x10xf32> + store %v0, %1[%i0, %i1] : memref<10x10xf32> + } + + // RAW dependence from store in loop nest '%i1' to 'load %1' prevents + // fusion loop nest '%i1' into loops after load. + %v1 = load %1[%i0, %c0] : memref<10x10xf32> + "op0"(%v1) : (f32) -> () + + // Loop nest '%i2' cannot be fused past SSA value def '%c2' which it uses. + %c2 = constant 2 : index + + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} + store %cf7, %0[%i0, %c2] : memref<10x10xf32> + } + } + return +} \ No newline at end of file -- cgit v1.2.3 From 1de0f97fff7b7f5fae21374e77d35c5c311c9f39 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Wed, 29 May 2019 14:02:14 -0700 Subject: LoopFusionUtils CL 2/n: Factor out and generalize slice union computation. *) Factors slice union computation out of LoopFusion into Analysis/Utils (where other iteration slice utilities exist). *) Generalizes slice union computation to take the union of slices computed on all loads/stores pairs between source and destination loop nests. *) Fixes a bug in FlatAffineConstraints::addSliceBounds where redundant constraints were added. *) Takes care of a TODO to expose FlatAffineConstraints::mergeAndAlignIds as a public method. -- PiperOrigin-RevId: 250561529 --- mlir/include/mlir/Analysis/AffineStructures.h | 19 ++++ mlir/include/mlir/Analysis/Utils.h | 8 ++ mlir/lib/Analysis/AffineStructures.cpp | 24 +++-- mlir/lib/Analysis/Utils.cpp | 148 ++++++++++++++++++++++++++ mlir/lib/Transforms/LoopFusion.cpp | 86 ++------------- mlir/lib/Transforms/TestLoopFusion.cpp | 6 +- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 57 ++++++++-- 7 files changed, 250 insertions(+), 98 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 1cff4290cc7..aadace079b0 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -541,6 +541,25 @@ public: /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. LogicalResult unionBoundingBox(const FlatAffineConstraints &other); + /// Returns 'true' if this constraint system and 'other' are in the same + /// space, i.e., if they are associated with the same set of identifiers, + /// appearing in the same order. Returns 'false' otherwise. 
+ bool areIdsAlignedWithOther(const FlatAffineConstraints &other); + + /// Merge and align the identifiers of 'this' and 'other' starting at + /// 'offset', so that both constraint systems get the union of the contained + /// identifiers that is dimension-wise and symbol-wise unique; both + /// constraint systems are updated so that they have the union of all + /// identifiers, with this's original identifiers appearing first followed by + /// any of other's identifiers that didn't appear in 'this'. Local + /// identifiers of each system are by design separate/local and are placed + /// one after other (this's followed by other's). + // Eg: Input: 'this' has ((%i %j) [%M %N]) + // 'other' has (%k, %j) [%P, %N, %M]) + // Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P] + // + void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other); + unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); } diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 34eb627dcad..d6bf0c617ae 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -92,6 +92,14 @@ LogicalResult getBackwardComputationSliceState( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned dstLoopDepth, ComputationSliceState *sliceState); +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the +/// same memref. Returns 'success' if union was computed, 'failure' otherwise. +LogicalResult computeSliceUnion(ArrayRef srcOps, + ArrayRef dstOps, + unsigned dstLoopDepth, + ComputationSliceState *sliceUnion); + /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcOpInst', slices the iteration space of src loop based on slice bounds /// in 'sliceState', and inserts the computation slice at the beginning of the diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 3b7d5a00d9e..9a821a0266d 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -482,13 +482,20 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { /// Checks if two constraint systems are in the same space, i.e., if they are /// associated with the same set of identifiers, appearing in the same order. -bool areIdsAligned(const FlatAffineConstraints &A, - const FlatAffineConstraints &B) { +static bool areIdsAligned(const FlatAffineConstraints &A, + const FlatAffineConstraints &B) { return A.getNumDimIds() == B.getNumDimIds() && A.getNumSymbolIds() == B.getNumSymbolIds() && A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds()); } +/// Calls areIdsAligned to check if two constraint systems have the same set +/// of identifiers in the same order. +bool FlatAffineConstraints::areIdsAlignedWithOther( + const FlatAffineConstraints &other) { + return areIdsAligned(*this, other); +} + /// Checks if the SSA values associated with `cst''s identifiers are unique. static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(const FlatAffineConstraints &cst) { @@ -527,7 +534,6 @@ static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) { // Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M]) // Output: both A, B have (%i, %j, %k) [%M, %N, %P] // -// TODO(mlir-team): expose this function at some point. 
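The TODO dropped here is addressed by the new public wrappers declared above. As a rough sketch of the call pattern they enable (mirroring the usage added to computeSliceUnion later in this patch; `sliceUnionCst` and `tmpSliceCst` stand in for two FlatAffineConstraints systems built from slice bounds):

  if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst))
    sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
  // With the identifier spaces aligned, the two sets of bounds can be unioned.
  if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst)))
    return failure();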
static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, FlatAffineConstraints *B) { assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds()); @@ -604,6 +610,12 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, assert(areIdsAligned(*A, *B) && "IDs expected to be aligned"); } +// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'. +void FlatAffineConstraints::mergeAndAlignIdsWithOther( + unsigned offset, FlatAffineConstraints *other) { + mergeAndAlignIds(offset, this, other); +} + // This routine may add additional local variables if the flattened expression // corresponding to the map has such variables due to mod's, ceildiv's, and // floordiv's in it. @@ -1745,15 +1757,9 @@ LogicalResult FlatAffineConstraints::addSliceBounds( if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, /*lower=*/true))) return failure(); - if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, - /*lower=*/true))) - return failure(); continue; } - if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, - /*lower=*/true))) - return failure(); if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, /*lower=*/true))) return failure(); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2a46c0e5b4f..3026074fd01 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Builders.h" #include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -481,6 +482,153 @@ static Operation *getInstAtPosition(ArrayRef positions, return nullptr; } +// Returns the MemRef accessed by load or store 'op'. +static Value *getLoadOrStoreMemRef(Operation *op) { + if (auto loadOp = dyn_cast(op)) + return loadOp.getMemRef(); + return cast(op).getMemRef(); +} + +// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. +LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, + FlatAffineConstraints *cst) { + for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) { + auto *value = cst->getIdValue(i); + if (ivs.count(value) == 0) { + assert(isForInductionVar(value)); + auto loop = getForInductionVarOwner(value); + if (failed(cst->addAffineForOpDomain(loop))) + return failure(); + } + } + return success(); +} + +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the +/// same memref. Returns 'Success' if union was computed, 'failure' otherwise. +LogicalResult mlir::computeSliceUnion(ArrayRef srcOps, + ArrayRef dstOps, + unsigned dstLoopDepth, + ComputationSliceState *sliceUnion) { + unsigned numSrcOps = srcOps.size(); + unsigned numDstOps = dstOps.size(); + assert(numSrcOps > 0 && numDstOps > 0); + + // Compute the intersection of 'srcMemrefToOps' and 'dstMemrefToOps'. + llvm::SmallDenseSet memrefIntersection; + for (auto *srcOp : srcOps) { + auto *srcMemRef = getLoadOrStoreMemRef(srcOp); + for (auto *dstOp : dstOps) { + if (srcMemRef == getLoadOrStoreMemRef(dstOp)) + memrefIntersection.insert(srcMemRef); + } + } + // Return failure if 'memrefIntersection' is empty. + if (memrefIntersection.empty()) + return failure(); + + // Compute the union of slice bounds between all pairs in 'srcOps' and + // 'dstOps' in 'sliceUnionCst'. 
+ FlatAffineConstraints sliceUnionCst; + assert(sliceUnionCst.getNumDimAndSymbolIds() == 0); + for (unsigned i = 0; i < numSrcOps; ++i) { + MemRefAccess srcAccess(srcOps[i]); + for (unsigned j = 0; j < numDstOps; ++j) { + MemRefAccess dstAccess(dstOps[j]); + if (srcAccess.memref != dstAccess.memref) + continue; + // Compute slice bounds for 'srcAccess' and 'dstAccess'. + ComputationSliceState tmpSliceState; + if (failed(mlir::getBackwardComputationSliceState( + srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); + return failure(); + } + + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) { + // Initialize 'sliceUnionCst' with the bounds computed in previous step. + if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + assert(sliceUnionCst.getNumDimAndSymbolIds() > 0); + continue; + } + + // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. + FlatAffineConstraints tmpSliceCst; + if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + + // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. + if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) { + + // Pre-constraint id alignment: record loop IVs used in each constraint + // system. + SmallPtrSet sliceUnionIVs; + for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k) + sliceUnionIVs.insert(sliceUnionCst.getIdValue(k)); + SmallPtrSet tmpSliceIVs; + for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k) + tmpSliceIVs.insert(tmpSliceCst.getIdValue(k)); + + sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst); + + // Post-constraint id alignment: add loop IV bounds missing after + // id alignment to constraint systems. This can occur if one constraint + // system uses an loop IV that is not used by the other. The call + // to unionBoundingBox below expects constraints for each Loop IV, even + // if they are the unsliced full loop bounds added here. + if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst))) + return failure(); + if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst))) + return failure(); + } + // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. + if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute union bounding box of slice bounds." + "\n."); + return failure(); + } + } + } + + // Store 'numSrcLoopIvs' before converting dst loop IVs to dims. + unsigned numSrcLoopIVs = sliceUnionCst.getNumDimIds(); + + // Convert any dst loop IVs which are symbol identifiers to dim identifiers. + sliceUnionCst.convertLoopIVSymbolsToDims(); + sliceUnion->clearBounds(); + sliceUnion->lbs.resize(numSrcLoopIVs, AffineMap()); + sliceUnion->ubs.resize(numSrcLoopIVs, AffineMap()); + + // Get slice bounds from slice union constraints 'sliceUnionCst'. + sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOps[0]->getContext(), + &sliceUnion->lbs, &sliceUnion->ubs); + + // Add slice bound operands of union. + SmallVector sliceBoundOperands; + sliceUnionCst.getIdValues(numSrcLoopIVs, + sliceUnionCst.getNumDimAndSymbolIds(), + &sliceBoundOperands); + + // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'. 
+ sliceUnion->ivs.clear(); + sliceUnionCst.getIdValues(0, numSrcLoopIVs, &sliceUnion->ivs); + + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceUnion->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceUnion->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + return success(); +} + const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; // Computes memref dependence between 'srcAccess' and 'dstAccess', projects // out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1f475f1fb44..7eb2c7289c0 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1192,82 +1192,6 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, return true; } -// Computes the union of all slice bounds computed between 'srcOpInst' -// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns -// the union in 'sliceState'. Returns true on success, false otherwise. -// TODO(andydavis) Move this to a loop fusion utility function. -static bool getSliceUnion(Operation *srcOpInst, - ArrayRef dstLoadOpInsts, - unsigned numSrcLoopIVs, unsigned dstLoopDepth, - ComputationSliceState *sliceState) { - MemRefAccess srcAccess(srcOpInst); - unsigned numDstLoadOpInsts = dstLoadOpInsts.size(); - assert(numDstLoadOpInsts > 0); - // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'. - if (failed(mlir::getBackwardComputationSliceState( - srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth, - sliceState))) - return false; - // Handle the common case of one dst load without a copy. - if (numDstLoadOpInsts == 1) - return true; - - // Initialize 'sliceUnionCst' with the bounds computed in previous step. - FlatAffineConstraints sliceUnionCst; - if (failed(sliceState->getAsConstraints(&sliceUnionCst))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n."); - return false; - } - - // Compute the union of slice bounds between 'srcOpInst' and each load - // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'. - for (unsigned i = 1; i < numDstLoadOpInsts; ++i) { - MemRefAccess dstAccess(dstLoadOpInsts[i]); - // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'. - ComputationSliceState tmpSliceState; - if (failed(mlir::getBackwardComputationSliceState( - srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); - return false; - } - - // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. - FlatAffineConstraints tmpSliceCst; - if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n."); - return false; - } - // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. - if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute union bounding box of slice bounds.\n."); - return false; - } - } - - // Convert any dst loop IVs which are symbol identifiers to dim identifiers. - sliceUnionCst.convertLoopIVSymbolsToDims(); - - sliceState->clearBounds(); - sliceState->lbs.resize(numSrcLoopIVs, AffineMap()); - sliceState->ubs.resize(numSrcLoopIVs, AffineMap()); - - // Get slice bounds from slice union constraints 'sliceUnionCst'. 
- sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(), - &sliceState->lbs, &sliceState->ubs); - // Add slice bound operands of union. - SmallVector sliceBoundOperands; - sliceUnionCst.getIdValues(numSrcLoopIVs, - sliceUnionCst.getNumDimAndSymbolIds(), - &sliceBoundOperands); - // Give each bound its own copy of 'sliceBoundOperands' for subsequent - // canonicalization. - sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); - sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); - return true; -} - // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // The argument 'srcStoreOpInst' is used to calculate the storage reduction on @@ -1404,10 +1328,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. - if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i, - &sliceStates[i - 1])) { + if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, + /*dstLoopDepth=*/i, + &sliceStates[i - 1]))) { LLVM_DEBUG(llvm::dbgs() - << "getSliceUnion failed for loopDepth: " << i << "\n"); + << "computeSliceUnion failed for loopDepth: " << i << "\n"); continue; } @@ -1813,9 +1738,10 @@ public: continue; // TODO(andydavis) Remove assert and surrounding code when // canFuseLoops is fully functional. + mlir::ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops( cast(srcNode->op), cast(dstNode->op), - bestDstLoopDepth, /*srcSlice=*/nullptr); + bestDstLoopDepth, &sliceUnion); assert(result.value == FusionResult::Success); (void)result; diff --git a/mlir/lib/Transforms/TestLoopFusion.cpp b/mlir/lib/Transforms/TestLoopFusion.cpp index 9ace2fb4350..638cf915b6a 100644 --- a/mlir/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/lib/Transforms/TestLoopFusion.cpp @@ -76,8 +76,10 @@ static void testDependenceCheck(SmallVector &loops, unsigned i, unsigned j, unsigned loopDepth) { AffineForOp srcForOp = loops[i]; AffineForOp dstForOp = loops[j]; - FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth, - /*srcSlice=*/nullptr); + mlir::ComputationSliceState sliceUnion; + // TODO(andydavis) Test at deeper loop depths current loop depth + 1. + FusionResult result = + mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion); if (result.value == FusionResult::FailBlockDependence) { srcForOp.getOperation()->emitRemark("block-level dependence preventing" " fusion of loop nest ") diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 9de6766e075..cb1d9d17ed0 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -40,9 +40,10 @@ using namespace mlir; -// Gathers all load and store operations in 'opA' into 'values', where +// Gathers all load and store memref accesses in 'opA' into 'values', where // 'values[memref] == true' for each store operation. 
-static void getLoadsAndStores(Operation *opA, DenseMap &values) { +static void getLoadAndStoreMemRefAccesses(Operation *opA, + DenseMap &values) { opA->walk([&](Operation *op) { if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) @@ -73,7 +74,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opA'. // Map from memref value to bool which is true if store, false otherwise. DenseMap values; - getLoadsAndStores(opA, values); + getLoadAndStoreMemRefAccesses(opA, values); // For each 'opX' in block in range ('opA', 'opB'), check if there is a data // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref @@ -99,7 +100,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opB'. // Map from memref value to bool which is true if store, false otherwise. DenseMap values; - getLoadsAndStores(opB, values); + getLoadAndStoreMemRefAccesses(opB, values); // For each 'opX' in block in range ('opA', 'opB') in reverse order, // check if there is a data dependence from 'opX' to 'opB': @@ -176,8 +177,22 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, return forOpB.getOperation(); } +// Gathers all load and store ops in loop nest rooted at 'forOp' into +// 'loadAndStoreOps'. +static bool +gatherLoadsAndStores(AffineForOp forOp, + SmallVectorImpl &loadAndStoreOps) { + bool hasIfOp = false; + forOp.getOperation()->walk([&](Operation *op) { + if (isa(op) || isa(op)) + loadAndStoreOps.push_back(op); + else if (isa(op)) + hasIfOp = true; + }); + return !hasIfOp; +} + // TODO(andydavis) Add support for the following features in subsequent CLs: -// *) Computing union of slices computed between src/dst loads and stores. // *) Compute dependences of unfused src/dst loops. // *) Compute dependences of src/dst loop as if they were fused. // *) Check for fusion preventing dependences (e.g. a dependence which changes @@ -185,18 +200,46 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice) { - // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block. + // Return 'failure' if 'dstLoopDepth == 0'. + if (dstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); + return FusionResult::FailPrecondition; + } + // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. auto *block = srcForOp.getOperation()->getBlock(); if (block != dstForOp.getOperation()->getBlock()) { LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); return FusionResult::FailPrecondition; } - // Return 'false' if no valid insertion point for fused loop nest in 'block' + // Return 'failure' if no valid insertion point for fused loop nest in 'block' // exists which would preserve dependences. if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); return FusionResult::FailBlockDependence; } + + // Gather all load and store ops in 'srcForOp'. 
+ SmallVector srcLoadAndStoreOps; + if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Gather all load and store ops in 'dstForOp'. + SmallVector dstLoadAndStoreOps; + if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Compute union of computation slices computed from all pairs in + // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}. + if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps, + dstLoopDepth, srcSlice))) { + LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + return FusionResult::FailPrecondition; + } + return FusionResult::Success; } -- cgit v1.2.3 From 5a91b9896ce4cd21d97f5df609931b7adf7806c3 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Wed, 29 May 2019 14:56:41 -0700 Subject: Remove "size" property of affine maps. -- PiperOrigin-RevId: 250572818 --- .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h | 2 +- mlir/examples/Linalg/Linalg3/lib/Analysis.cpp | 4 +- mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp | 19 +++++---- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 3 +- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 3 +- mlir/g3doc/Dialects/Affine.md | 37 +++++++----------- mlir/include/mlir/Analysis/AffineStructures.h | 1 - mlir/include/mlir/IR/AffineMap.h | 23 ++--------- mlir/include/mlir/IR/Builders.h | 3 +- mlir/include/mlir/VectorOps/VectorOps.h | 5 +-- mlir/lib/AffineOps/AffineOps.cpp | 3 +- mlir/lib/Analysis/AffineStructures.cpp | 23 +++++------ mlir/lib/Analysis/LoopAnalysis.cpp | 4 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/EDSC/Builders.cpp | 2 +- mlir/lib/IR/AffineMap.cpp | 45 +++++----------------- mlir/lib/IR/AffineMapDetail.h | 4 -- mlir/lib/IR/AsmPrinter.cpp | 10 ----- mlir/lib/IR/Builders.cpp | 18 ++++----- mlir/lib/IR/MLIRContext.cpp | 25 +++++------- mlir/lib/IR/StandardTypes.cpp | 4 +- mlir/lib/Linalg/IR/LinalgOps.cpp | 17 ++++---- mlir/lib/Linalg/Utils/Utils.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 40 +------------------ mlir/lib/Transforms/DmaGeneration.cpp | 5 +-- mlir/lib/Transforms/LoopFusion.cpp | 7 ++-- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 12 +++--- mlir/lib/Transforms/Utils/Utils.cpp | 7 ++-- mlir/lib/VectorOps/VectorOps.cpp | 6 --- mlir/test/IR/affine-map.mlir | 18 --------- mlir/test/IR/invalid-affinemap.mlir | 15 -------- 35 files changed, 106 insertions(+), 275 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 2c475418b05..d86c5344cbc 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -71,7 +71,7 @@ mlir::AffineMap linalg::operandRangesToLoopsMap( results.append(m.getResults().begin(), m.getResults().end()); current = mlir::AffineMap::get( std::max(current.getNumDims(), m.getNumDims()), - current.getNumSymbols() + m.getNumSymbols(), results, {}); + current.getNumSymbols() + m.getNumSymbols(), results); } return inverseSubMap(current); } diff --git 
a/mlir/examples/Linalg/Linalg3/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg3/lib/Analysis.cpp index 9e7c8eee5a0..9d7dfd08a37 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Analysis.cpp @@ -48,7 +48,7 @@ static AffineMap inversePermutationMap(AffineMap map) { seenExprs.push_back(expr); assert(map.getNumSymbols() == 0 && "expected map without symbols"); assert(seenExprs.size() == map.getNumInputs() && "map is not invertible"); - return AffineMap::get(map.getNumResults(), 0, seenExprs, {}); + return AffineMap::get(map.getNumResults(), 0, seenExprs); } mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult, @@ -57,6 +57,6 @@ mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult, endResult = map.getNumResults(); auto subMap = AffineMap::get( map.getNumDims(), map.getNumSymbols(), - map.getResults().slice(beginResult, endResult - beginResult), {}); + map.getResults().slice(beginResult, endResult - beginResult)); return inversePermutationMap(subMap); } diff --git a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp index 0a55fc2a4e5..f539c702549 100644 --- a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp @@ -46,9 +46,9 @@ SmallVector linalg::DotOp::loopsToOperandRangeMaps() { auto d0 = getAffineDimExpr(0, context); // K // A(K), B(K), C() // (d0) -> (d0, d0)(%k) - return SmallVector{AffineMap::get(1, 0, {d0}, {}), // A(K) - AffineMap::get(1, 0, {d0}, {}), // B(K) - AffineMap()}; // C() + return SmallVector{AffineMap::get(1, 0, {d0}), // A(K) + AffineMap::get(1, 0, {d0}), // B(K) + AffineMap()}; // C() } void linalg::DotOp::emitScalarImplementation( @@ -92,10 +92,9 @@ SmallVector linalg::MatvecOp::loopsToOperandRangeMaps() { auto d1 = getAffineDimExpr(1, context); // K // A(M, K), B(K), C(M) // (d0, d1) -> (d0, d1, d1, d0)(%m, %k) - return SmallVector{ - AffineMap::get(2, 0, {d0, d1}, {}), // A(M, K) - AffineMap::get(2, 0, {d1}, {}), // B(K) - AffineMap::get(2, 0, {d0}, {})}; // C(M) + return SmallVector{AffineMap::get(2, 0, {d0, d1}), // A(M, K) + AffineMap::get(2, 0, {d1}), // B(K) + AffineMap::get(2, 0, {d0})}; // C(M) } // The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j) @@ -163,9 +162,9 @@ SmallVector linalg::MatmulOp::loopsToOperandRangeMaps() { // A(M, K), B(K, N), C(M, N): // (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) return SmallVector{ - AffineMap::get(3, 0, {d0, d2}, {}), // A(M, K) - AffineMap::get(3, 0, {d2, d1}, {}), // B(K, N) - AffineMap::get(3, 0, {d0, d1}, {}) // C(M, N) + AffineMap::get(3, 0, {d0, d2}), // A(M, K) + AffineMap::get(3, 0, {d2, d1}), // B(K, N) + AffineMap::get(3, 0, {d0, d1}) // C(M, N) }; } diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 0fe70e27f1b..3a11c6d17d9 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -120,12 +120,11 @@ static RangeParts makeGenericRangeParts(AffineMap map, assert(map.getNumInputs() == ranges.size()); unsigned numDims = map.getNumDims(); assert(map.getNumSymbols() == 0); - assert(map.getRangeSizes().empty()); RangeParts res(map.getNumResults()); RangeParts rangeParts(ranges); for (auto expr : map.getResults()) { - AffineMap map = AffineMap::get(numDims, 0, expr, {}); + AffineMap map = AffineMap::get(numDims, 0, expr); res.mins.push_back(makeFoldedComposedAffineApply(map, 
rangeParts.mins)); res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes)); res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps)); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index d695f5404e3..3df6f4b00ac 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -73,7 +73,6 @@ makeTiledRanges(TensorContractionBase &contraction, // 1. Take the first ivs results of the map, the other ones are not composed // but merely copied over. assert(map.getNumSymbols() == 0); - assert(map.getRangeSizes().empty()); MLIRContext *context = ScopedContext::getContext(); unsigned numParallel = op->getNumParallelDims(); unsigned numReduction = op->getNumReductionDims(); @@ -93,7 +92,7 @@ makeTiledRanges(TensorContractionBase &contraction, for (auto en : llvm::enumerate(map.getResults())) { auto index = en.index(); auto expr = en.value(); - AffineMap exprMap = AffineMap::get(numDims, 0, expr, {}); + AffineMap exprMap = AffineMap::get(numDims, 0, expr); ValueHandle offset(makeFoldedComposedAffineApply(exprMap, ivs)); // Offset is normally a function of loop induction variables. // If it is 0, it must come from a dimension that was not tiled. diff --git a/mlir/g3doc/Dialects/Affine.md b/mlir/g3doc/Dialects/Affine.md index 91bf40e1dc5..a209d8b5a45 100644 --- a/mlir/g3doc/Dialects/Affine.md +++ b/mlir/g3doc/Dialects/Affine.md @@ -13,15 +13,16 @@ core concepts that are used throughout the document. ### Dimensions and Symbols Dimensions and symbols are the two kinds of identifiers that can appear in the -polyhedral structures, and are always of [`index`](../LangRef.md#index-type) type. Dimensions -are declared in parentheses and symbols are declared in square brackets. +polyhedral structures, and are always of [`index`](../LangRef.md#index-type) +type. Dimensions are declared in parentheses and symbols are declared in square +brackets. Examples: ```mlir {.mlir} // A 2d to 3d affine mapping. // d0/d1 are dimensions, s0 is a symbol -#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) size (10, 20, 30) +#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) ``` Dimensional identifiers correspond to the dimensions of the underlying structure @@ -51,7 +52,7 @@ SSA values bound to dimensions and symbols must always have 'index' type. Example: ```mlir {.mlir} -#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) size (10,20,30) +#affine_map2to3 = (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0) // Binds %N to the s0 symbol in affine_map2to3. %x = alloc()[%N] : memref<40x50xf32, #affine_map2to3> ``` @@ -98,10 +99,11 @@ less than or equal to that result. `mod` is the modulo operation: since its second argument is always positive, its results are always positive in our usage. The `integer-literal` operand for ceildiv, floordiv, and mod is always expected to be positive. `bare-id` is an identifier which must have type -[index](../LangRef.md#index-type). The precedence of operations in an affine expression are -ordered from highest to lowest in the order: (1) parenthesization, (2) negation, -(3) modulo, multiplication, floordiv, and ceildiv, and (4) addition and -subtraction. All of these operators associate from left to right. +[index](../LangRef.md#index-type). 
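For C++ clients, a sketch of building the 2d-to-3d map from the #affine_map2to3 example above with the updated builder API (assuming `using namespace mlir` and a valid `MLIRContext *ctx`; note the trailing range-size argument to AffineMap::get no longer exists):

  auto d0 = getAffineDimExpr(0, ctx);
  auto d1 = getAffineDimExpr(1, ctx);
  auto s0 = getAffineSymbolExpr(0, ctx);
  // (d0, d1)[s0] -> (d0, d1 + s0, d1 - s0)
  auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1,
                            {d0, d1 + s0, d1 - s0});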
The precedence of operations in an affine +expression are ordered from highest to lowest in the order: (1) +parenthesization, (2) negation, (3) modulo, multiplication, floordiv, and +ceildiv, and (4) addition and subtraction. All of these operators associate from +left to right. A _multi-dimensional affine expression_ is a comma separated list of one-dimensional affine expressions, with the entire list enclosed in @@ -129,20 +131,12 @@ Syntax: ``` {.ebnf} affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr - ( `size` `(` dim-size (`,` dim-size)* `)` )? - -dim-size ::= affine-expr - | `min` `(` affine-expr ( `,` affine-expr)+ `)` ``` The identifiers in the dimensions and symbols lists must be unique. These are -the only identifiers that may appear in 'multi-dim-affine-expr'. In addition, -only symbolic identifiers and constants can appear in 'dim-size'. Affine maps +the only identifiers that may appear in 'multi-dim-affine-expr'. Affine maps with one or more symbols in its specification are known as "symbolic affine -maps", and those with no symbols as "non-symbolic affine maps". An affine map -has an optional "size" tuple which provides the size for each corresponding -dimension. Affine maps with a size in their specification are known as "bounded -affine maps", and those without a size are "unbounded affine maps". +maps", and those with no symbols as "non-symbolic affine maps". **Context:** Affine maps are mathematical functions that transform a list of dimension indices and symbols into a list of results, with affine expressions @@ -180,16 +174,14 @@ Examples: ```mlir {.mlir} // Affine map out-of-line definition and usage example. -#affine_map42 = - (d0, d1)[s0] -> (d0, d0 + d1 + floordiv(s0,2)) size (10, s0) +#affine_map42 = (d0, d1)[s0] -> (d0, d0 + d1 + floordiv(s0,2)) // Use an affine mapping definition in an alloc operation, binding the // SSA value %N to the symbol s0. %a = alloc()[%N] : memref<4x4xf32, #affine_map42> // Same thing with an inline affine mapping definition. -%b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + floordiv(s0,2)) - size (10, s0)> +%b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + floordiv(s0,2))> ``` ### Semi-affine maps @@ -224,7 +216,6 @@ Syntax of semi-affine maps: ``` {.ebnf} semi-affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-semi-affine-expr - ( `size` `(` dim-size (`,` dim-size)* `)` )? ``` Semi-affine maps may be defined inline at the point of use, or may be hoisted to diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index aadace079b0..d3feb3436ff 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -72,7 +72,6 @@ public: private: // Same meaning as AffineMap's fields. SmallVector results; - SmallVector rangeSizes; unsigned numDims; unsigned numSymbols; /// A pointer to the IR's context to store all newly created diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 5670d45e9a4..4b3cd838d35 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -53,8 +53,7 @@ public: AffineMap &operator=(const AffineMap &other) = default; static AffineMap get(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes); + ArrayRef results); /// Returns a single constant result affine map. 
static AffineMap getConstantMap(int64_t val, MLIRContext *context); @@ -69,11 +68,6 @@ public: bool operator==(AffineMap other) const { return other.map == map; } bool operator!=(AffineMap other) const { return !(other.map == map); } - /// Returns true if the co-domain (or more loosely speaking, range) of this - /// map is bounded. Bounded affine maps have a size (extent) for each of - /// their range dimensions (more accurately co-domain dimensions). - bool isBounded() const; - /// Returns true if this affine map is an identity affine map. /// An identity affine map corresponds to an identity affine function on the /// dimensional identifiers. @@ -98,10 +92,7 @@ public: ArrayRef getResults() const; AffineExpr getResult(unsigned idx) const; - ArrayRef getRangeSizes() const; - - /// Walk all of the AffineExpr's in this mapping. The results are visited - /// first, and then the range sizes (if present). Each node in an expression + /// Walk all of the AffineExpr's in this mapping. Each node in an expression /// tree is visited in postorder. void walkExprs(std::function callback) const; @@ -128,15 +119,12 @@ public: /// Prerequisites: /// The maps are composable, i.e. that the number of AffineDimExpr of `this` /// matches the number of results of `map`. - /// At this time, composition of bounded AffineMap is not supported. Both - /// `this` and `map` must be unbounded. /// /// Example: /// map1: `(d0, d1)[s0, s1] -> (d0 + 1 + s1, d1 - 1 - s0)` /// map2: `(d0)[s0] -> (d0 + s0, d0 - s0))` /// map1.compose(map2): /// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)` - // TODO(ntv): support composition of bounded maps when we have a need for it. AffineMap compose(AffineMap map); friend ::llvm::hash_code hash_value(AffineMap arg); @@ -150,8 +138,7 @@ inline ::llvm::hash_code hash_value(AffineMap arg) { return ::llvm::hash_value(arg.map); } -/// Simplify an affine map by simplifying its underlying AffineExpr results and -/// sizes. +/// Simplify an affine map by simplifying its underlying AffineExpr results. AffineMap simplifyAffineMap(AffineMap map); /// Returns a map of codomain to domain dimensions such that the first codomain @@ -160,7 +147,6 @@ AffineMap simplifyAffineMap(AffineMap map); /// Prerequisites: /// 1. `map` is a permutation of full rank. /// 2. `map` has no symbols. -/// 3. `map` has empty `rangeSizes`. /// /// Example: /// @@ -177,8 +163,7 @@ AffineMap simplifyAffineMap(AffineMap map); AffineMap inversePermutation(AffineMap map); /// Concatenates a list of `maps` into a single AffineMap, stepping over -/// potentially empty maps. Assumes each of the underlying map has 0 symbols and -/// empty `rangeSizes`. +/// potentially empty maps. Assumes each of the underlying map has 0 symbols. /// The resulting map has a number of dims equal to the max of `maps`' dims and /// the concatenated results as its results. 
/// diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 9b0e27040d0..e8489661023 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -148,8 +148,7 @@ public: AffineExpr getAffineConstantExpr(int64_t constant); AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes); + ArrayRef results); // Special cases of affine maps and integer sets /// Returns a single constant result affine map with 0 dimensions and 0 diff --git a/mlir/include/mlir/VectorOps/VectorOps.h b/mlir/include/mlir/VectorOps/VectorOps.h index e9c7551ad89..434cda1af43 100644 --- a/mlir/include/mlir/VectorOps/VectorOps.h +++ b/mlir/include/mlir/VectorOps/VectorOps.h @@ -53,9 +53,8 @@ public: /// the access is statically guaranteed to be within bounds; /// 2. an attribute of type AffineMap to specify a slice of the original /// MemRef access and its transposition into the super-vector shape. -/// The permutation_map is an unbounded AffineMap that must -/// represent a permutation from the MemRef dim space projected onto the -/// vector dim space. +/// The permutation_map is an AffineMap that must represent a permutation +/// from the MemRef dim space projected onto the vector dim space. /// This permutation_map has as many output dimensions as the vector rank. /// However, it is not necessarily full rank on the target space to signify /// that broadcast operations will be needed along certain vector diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index ffb1dd28836..f6c044159ee 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -407,7 +407,6 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, ArrayRef operands) : AffineApplyNormalizer() { static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); - assert(map.getRangeSizes().empty() && "Unbounded map expected"); assert(map.getNumInputs() == operands.size() && "number of operands does not match the number of map inputs"); @@ -497,7 +496,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, "Unexpected number of concatenated symbols"); auto numDims = dimValueToPosition.size(); auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); - auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs, {}); + auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs); LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 9a821a0266d..41f8e075813 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -157,21 +157,16 @@ MutableAffineMap::MutableAffineMap(AffineMap map) context(map.getResult(0).getContext()) { for (auto result : map.getResults()) results.push_back(result); - for (auto rangeSize : map.getRangeSizes()) - results.push_back(rangeSize); } void MutableAffineMap::reset(AffineMap map) { results.clear(); - rangeSizes.clear(); numDims = map.getNumDims(); numSymbols = map.getNumSymbols(); // A map always has at least 1 result by construction context = map.getResult(0).getContext(); for (auto result : map.getResults()) results.push_back(result); - for (auto rangeSize : map.getRangeSizes()) - results.push_back(rangeSize); } bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { @@ -194,7 +189,7 @@ void 
MutableAffineMap::simplify() { } AffineMap MutableAffineMap::getAffineMap() const { - return AffineMap::get(numDims, numSymbols, results, rangeSizes); + return AffineMap::get(numDims, numSymbols, results); } MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) @@ -1454,8 +1449,8 @@ std::pair FlatAffineConstraints::getLowerAndUpperBound( auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); exprs.push_back(expr); } - auto lbMap = exprs.empty() ? AffineMap() - : AffineMap::get(dimCount, symCount, exprs, {}); + auto lbMap = + exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); exprs.clear(); exprs.reserve(ubIndices.size()); @@ -1468,8 +1463,8 @@ std::pair FlatAffineConstraints::getLowerAndUpperBound( // Upper bound is exclusive. exprs.push_back(expr + 1); } - auto ubMap = exprs.empty() ? AffineMap() - : AffineMap::get(dimCount, symCount, exprs, {}); + auto ubMap = + exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); return {lbMap, ubMap}; } @@ -1591,8 +1586,8 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, AffineMap &ubMap = (*ubMaps)[pos]; if (expr) { - lbMap = AffineMap::get(numMapDims, numMapSymbols, expr, {}); - ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {}); + lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); + ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1); } else { // TODO(bondhugula): Whenever there are local identifiers in the // dependence constraints, we'll conservatively over-approximate, since we @@ -1621,7 +1616,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, if (lbConst.hasValue()) { lbMap = AffineMap::get( numMapDims, numMapSymbols, - getAffineConstantExpr(lbConst.getValue(), context), {}); + getAffineConstantExpr(lbConst.getValue(), context)); } } if (!ubMap || ubMap.getNumResults() > 1) { @@ -1631,7 +1626,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, if (ubConst.hasValue()) { (ubMap) = AffineMap::get( numMapDims, numMapSymbols, - getAffineConstantExpr(ubConst.getValue() + 1, context), {}); + getAffineConstantExpr(ubConst.getValue() + 1, context)); } } } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 97c2a87a10d..117cf6e109e 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -80,7 +80,7 @@ void mlir::buildTripCountMapAndOperands( for (auto ubExpr : ubMap.getResults()) ubs.push_back(b.create( forOp.getLoc(), - b.getAffineMap(ubMap.getNumDims(), ubMap.getNumSymbols(), {ubExpr}, {}), + b.getAffineMap(ubMap.getNumDims(), ubMap.getNumSymbols(), {ubExpr}), ubOperands)); tripCountOperands->clear(); @@ -92,7 +92,7 @@ void mlir::buildTripCountMapAndOperands( for (unsigned i = 0, e = ubs.size(); i < e; i++) tripCountExprs[i] = (b.getAffineDimExpr(1 + i) - b.getAffineDimExpr(0)).ceilDiv(step); - *map = b.getAffineMap(1 + ubs.size(), 0, tripCountExprs, {}); + *map = b.getAffineMap(1 + ubs.size(), 0, tripCountExprs); fullyComposeAffineMapAndOperands(map, tripCountOperands); *map = simplifyAffineMap(*map); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index e1d31ad1d2b..0d1e2c0f416 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -140,7 +140,7 @@ static AffineMap makePermutationMap( "Vectorization prerequisite violated: at most 1 index may be " "invariant wrt a vectorized loop"); } - 
return AffineMap::get(indices.size(), 0, perm, {}); + return AffineMap::get(indices.size(), 0, perm); } /// Implementation detail that walks up the parents and records the ones with diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index d0fd3e97291..5c17454fb49 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -311,7 +311,7 @@ static ValueHandle createBinaryIndexHandle( if (v1) { operands.push_back(v1); } - auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}, {}); + auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); // TODO: createOrFold when available. return ValueHandle::createComposedAffineApply(map, operands); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 8039a38627a..e313c6fda9f 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -102,7 +102,7 @@ private: /// Returns a single constant result affine map. AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { return get(/*dimCount=*/0, /*symbolCount=*/0, - {getAffineConstantExpr(val, context)}, {}); + {getAffineConstantExpr(val, context)}); } AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, @@ -111,16 +111,11 @@ AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, dimExprs.reserve(numDims); for (unsigned i = 0; i < numDims; ++i) dimExprs.push_back(mlir::getAffineDimExpr(i, context)); - return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, {}); + return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs); } MLIRContext *AffineMap::getContext() const { return getResult(0).getContext(); } -bool AffineMap::isBounded() const { - assert(map && "uninitialized AffineMap"); - return !map->rangeSizes.empty(); -} - bool AffineMap::isIdentity() const { if (getNumDims() != getNumResults()) return false; @@ -167,10 +162,6 @@ AffineExpr AffineMap::getResult(unsigned idx) const { assert(map && "uninitialized map storage"); return map->results[idx]; } -ArrayRef AffineMap::getRangeSizes() const { - assert(map && "uninitialized map storage"); - return map->rangeSizes; -} /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, @@ -196,15 +187,11 @@ AffineMap::constantFold(ArrayRef operandConstants, return success(); } -/// Walk all of the AffineExpr's in this mapping. The results are visited -/// first, and then the range sizes (if present). Each node in an expression +/// Walk all of the AffineExpr's in this mapping. Each node in an expression /// tree is visited in postorder. void AffineMap::walkExprs(std::function callback) const { for (auto expr : getResults()) expr.walk(callback); - - for (auto expr : getRangeSizes()) - expr.walk(callback); } /// This method substitutes any uses of dimensions and symbols (e.g. 
@@ -222,19 +209,11 @@ AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef dimReplacements, results.push_back( expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); - SmallVector resultRanges; - resultRanges.reserve(getRangeSizes().size()); - for (auto expr : getRangeSizes()) - resultRanges.push_back( - expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); - - return get(numResultDims, numResultSyms, results, resultRanges); + return get(numResultDims, numResultSyms, results); } AffineMap AffineMap::compose(AffineMap map) { assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); - assert(getRangeSizes().empty() && "TODO: support bounded AffineMap"); - assert(map.getRangeSizes().empty() && "TODO: support bounded AffineMap"); // Prepare `map` by concatenating the symbols and rewriting its exprs. unsigned numDims = map.getNumDims(); unsigned numSymbolsThisMap = getNumSymbols(); @@ -254,25 +233,20 @@ AffineMap AffineMap::compose(AffineMap map) { exprs.reserve(getResults().size()); for (auto expr : getResults()) exprs.push_back(expr.compose(newMap)); - return AffineMap::get(numDims, numSymbols, exprs, {}); + return AffineMap::get(numDims, numSymbols, exprs); } AffineMap mlir::simplifyAffineMap(AffineMap map) { - SmallVector exprs, sizes; + SmallVector exprs; for (auto e : map.getResults()) { exprs.push_back( simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); } - for (auto e : map.getRangeSizes()) { - sizes.push_back( - simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); - } - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, sizes); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs); } AffineMap mlir::inversePermutation(AffineMap map) { assert(map.getNumSymbols() == 0 && "expected map without symbols"); - assert(map.getRangeSizes().empty() && "expected map without range sizes"); SmallVector exprs(map.getNumDims()); for (auto en : llvm::enumerate(map.getResults())) { auto expr = en.value(); @@ -287,7 +261,7 @@ AffineMap mlir::inversePermutation(AffineMap map) { if (expr) seenExprs.push_back(expr); assert(seenExprs.size() == map.getNumInputs() && "map is not full rank"); - return AffineMap::get(map.getNumResults(), 0, seenExprs, {}); + return AffineMap::get(map.getNumResults(), 0, seenExprs); } AffineMap mlir::concatAffineMaps(ArrayRef maps) { @@ -301,9 +275,8 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { if (!m) continue; assert(m.getNumSymbols() == 0 && "expected map without symbols"); - assert(m.getRangeSizes().empty() && "expected map without range sizes"); results.append(m.getResults().begin(), m.getResults().end()); numDims = std::max(m.getNumDims(), numDims); } - return AffineMap::get(numDims, 0, results, {}); + return AffineMap::get(numDims, 0, results); } diff --git a/mlir/lib/IR/AffineMapDetail.h b/mlir/lib/IR/AffineMapDetail.h index edbc714f00b..af1d89cd239 100644 --- a/mlir/lib/IR/AffineMapDetail.h +++ b/mlir/lib/IR/AffineMapDetail.h @@ -36,10 +36,6 @@ struct AffineMapStorage { /// The affine expressions for this (multi-dimensional) map. /// TODO: use trailing objects for this. ArrayRef results; - - /// The extents along each of the range dimensions if the map is bounded, - /// nullptr otherwise. 
- ArrayRef rangeSizes; }; } // end namespace detail diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 20770252b2d..4c056a1ca9d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1059,16 +1059,6 @@ void ModulePrinter::printAffineMap(AffineMap map) { interleaveComma(map.getResults(), [&](AffineExpr expr) { printAffineExpr(expr); }); os << ')'; - - if (!map.isBounded()) { - return; - } - - // Print range sizes for bounded affine maps. - os << " size ("; - interleaveComma(map.getRangeSizes(), - [&](AffineExpr expr) { printAffineExpr(expr); }); - os << ')'; } void ModulePrinter::printIntegerSet(IntegerSet set) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 65129cb926e..c6e84ff4858 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -267,9 +267,8 @@ Attribute Builder::getZeroAttr(Type type) { //===----------------------------------------------------------------------===// AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes) { - return AffineMap::get(dimCount, symbolCount, results, rangeSizes); + ArrayRef results) { + return AffineMap::get(dimCount, symbolCount, results); } AffineExpr Builder::getAffineDimExpr(unsigned position) { @@ -292,12 +291,12 @@ IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, AffineMap Builder::getConstantAffineMap(int64_t val) { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, - {getAffineConstantExpr(val)}, {}); + {getAffineConstantExpr(val)}); } AffineMap Builder::getDimIdentityMap() { return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - {getAffineDimExpr(0)}, {}); + {getAffineDimExpr(0)}); } AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { @@ -305,18 +304,18 @@ AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(getAffineDimExpr(i)); - return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {}); + return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs); } AffineMap Builder::getSymbolIdentityMap() { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, - {getAffineSymbolExpr(0)}, {}); + {getAffineSymbolExpr(0)}); } AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { // expr = d0 + shift. auto expr = getAffineDimExpr(0) + shift; - return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}); + return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}); } AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { @@ -325,8 +324,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { for (auto resultExpr : map.getResults()) { shiftedResults.push_back(resultExpr + shift); } - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults, - map.getRangeSizes()); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index c353fd5ec3d..1f7aca8ec00 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -149,27 +149,25 @@ struct BuiltinDialect : public Dialect { struct AffineMapKeyInfo : DenseMapInfo { // Affine maps are uniqued based on their dim/symbol counts and affine // expressions. 
- using KeyTy = std::tuple, - ArrayRef>; + using KeyTy = std::tuple>; using DenseMapInfo::isEqual; static unsigned getHashValue(const AffineMap &key) { - return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(), - key.getResults(), key.getRangeSizes())); + return getHashValue( + KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults())); } static unsigned getHashValue(KeyTy key) { return hash_combine( std::get<0>(key), std::get<1>(key), - hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), - hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end())); + hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end())); } static bool isEqual(const KeyTy &lhs, AffineMap rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), - rhs.getResults(), rhs.getRangeSizes()); + rhs.getResults()); } }; @@ -797,27 +795,22 @@ StorageUniquer &MLIRContext::getAffineUniquer() { } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, - ArrayRef results, - ArrayRef rangeSizes) { + ArrayRef results) { // The number of results can't be zero. assert(!results.empty()); - assert(rangeSizes.empty() || results.size() == rangeSizes.size()); - auto &impl = results[0].getContext()->getImpl(); - auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes); + auto key = std::make_tuple(dimCount, symbolCount, results); // Safely get or create an AffineMap instance. return safeGetOrCreate(impl.affineMaps, key, impl.affineMutex, [&] { auto *res = impl.affineAllocator.Allocate(); - // Copy the results and range sizes into the bump pointer. + // Copy the results into the bump pointer. results = copyArrayRefInto(impl.affineAllocator, results); - rangeSizes = copyArrayRefInto(impl.affineAllocator, rangeSizes); // Initialize the memory using placement new. - new (res) - detail::AffineMapStorage{dimCount, symbolCount, results, rangeSizes}; + new (res) detail::AffineMapStorage{dimCount, symbolCount, results}; return AffineMap(res); }); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 03bce65aa2f..6c0d74010bf 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -329,12 +329,12 @@ MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ++i; } - // Drop the unbounded identity maps from the composition. + // Drop identity maps from the composition. // This may lead to the composition becoming empty, which is interpreted as an // implicit identity. 
llvm::SmallVector cleanedAffineMapComposition; for (const auto &map : affineMapComposition) { - if (map.isIdentity() && !map.isBounded()) + if (map.isIdentity()) continue; cleanedAffineMapComposition.push_back(map); } diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index dd32d5c4504..55a791a6d63 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -741,19 +741,18 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { auto k = getAffineDimExpr(2, context); if (isa(op)) // A(r_i) * B(r_i) -> C() - return SmallVector{AffineMap::get(1, 0, {i}, {}), - AffineMap::get(1, 0, {i}, {}), - AffineMap()}; + return SmallVector{AffineMap::get(1, 0, {i}), + AffineMap::get(1, 0, {i}), AffineMap()}; if (isa(op)) // A(i, r_j) * B(r_j) -> C(i) - return SmallVector{AffineMap::get(2, 0, {i, j}, {}), - AffineMap::get(2, 0, {j}, {}), - AffineMap::get(2, 0, {i}, {})}; + return SmallVector{AffineMap::get(2, 0, {i, j}), + AffineMap::get(2, 0, {j}), + AffineMap::get(2, 0, {i})}; if (isa(op)) // A(i, r_k) * B(r_k, j) -> C(i, j) - return SmallVector{AffineMap::get(3, 0, {i, k}, {}), - AffineMap::get(3, 0, {k, j}, {}), - AffineMap::get(3, 0, {i, j}, {})}; + return SmallVector{AffineMap::get(3, 0, {i, k}), + AffineMap::get(3, 0, {k, j}), + AffineMap::get(3, 0, {i, j})}; llvm_unreachable("Missing loopToOperandRangesMaps for op"); } diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index c3fea9b227c..f19e61c5531 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -131,7 +131,7 @@ mlir::linalg::applyMapToValues(FuncBuilder *b, Location loc, AffineMap map, // ranges. If the resulting application can be folded into a Value*, the // folding occurs eagerly. Otherwise, an affine.apply operation is emitted. for (auto expr : map.getResults()) { - AffineMap map = AffineMap::get(numDims, 0, expr, {}); + AffineMap map = AffineMap::get(numDims, 0, expr); res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state)); } return res; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 553aa926696..f1a6601f8dd 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2115,8 +2115,6 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, /// Parse the range and sizes affine map definition inline. /// /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr -/// (`size` `(` dim-size (`,` dim-size)* `)`)? -/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)` /// /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) AffineMap AffineParser::parseAffineMapRange(unsigned numDims, @@ -2137,44 +2135,8 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) return AffineMap(); - // Parse optional range sizes. - // range-sizes ::= (`size` `(` dim-size (`,` dim-size)* `)`)? - // dim-size ::= affine-expr | `min` `(` affine-expr (`,` affine-expr)+ `)` - // TODO(bondhugula): support for min of several affine expressions. - // TODO: check if sizes are non-negative whenever they are constant. - SmallVector rangeSizes; - if (consumeIf(Token::kw_size)) { - // Location of the l_paren token (if it exists) for error reporting later. 
- auto loc = getToken().getLoc(); - if (parseToken(Token::l_paren, "expected '(' at start of affine map range")) - return AffineMap(); - - auto parseRangeSize = [&]() -> ParseResult { - auto loc = getToken().getLoc(); - auto elt = parseAffineExpr(); - if (!elt) - return failure(); - - if (!elt.isSymbolicOrConstant()) - return emitError(loc, - "size expressions cannot refer to dimension values"); - - rangeSizes.push_back(elt); - return success(); - }; - - if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false)) - return AffineMap(); - if (exprs.size() > rangeSizes.size()) - return (emitError(loc, "fewer range sizes than range expressions"), - AffineMap()); - if (exprs.size() < rangeSizes.size()) - return (emitError(loc, "more range sizes than range expressions"), - AffineMap()); - } - // Parsed a valid affine map. - return builder.getAffineMap(numDims, numSymbols, exprs, rangeSizes); + return builder.getAffineMap(numDims, numSymbols, exprs); } /// Parse an ambiguous reference to either and affine map or an integer set. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 6452346a6d1..143662763ae 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -326,7 +326,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, // The coordinate for the start location is just the lower bound along the // corresponding dimension on the memory region (stored in 'offset'). auto map = top.getAffineMap( - cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {}); + cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset); memIndices.push_back(b->create(loc, map, regionSymbols)); } // The fast buffer is DMAed into at location zero; addressing is relative. @@ -438,8 +438,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, auto dimExpr = b->getAffineDimExpr(regionSymbols.size() + i); remapExprs.push_back(dimExpr - offsets[i]); } - auto indexRemap = - b->getAffineMap(regionSymbols.size() + rank, 0, remapExprs, {}); + auto indexRemap = b->getAffineMap(regionSymbols.size() + rank, 0, remapExprs); // Record the begin since it may be invalidated by memref replacement. Block::iterator prev; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7eb2c7289c0..b7b69fa54fe 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1096,10 +1096,9 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); remapExprs.push_back(remapExpr); } - auto indexRemap = - zeroOffsetCount == rank - ? AffineMap() - : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); + auto indexRemap = zeroOffsetCount == rank + ? AffineMap() + : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs); // Replace all users of 'oldMemRef' with 'newMemRef'. 
bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 28e13d89ada..5233081b5f1 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -168,12 +168,12 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, boundExprs.append(origUbMap.getResults().begin(), origUbMap.getResults().end()); auto ubMap = b.getAffineMap(origUbMap.getNumDims() + 1, - origUbMap.getNumSymbols(), boundExprs, {}); + origUbMap.getNumSymbols(), boundExprs); newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap); } else { // No need of the min expression. auto dim = b.getAffineDimExpr(0); - auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {}); + auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i]); newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap); } } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 609b42455f5..731464bd7c1 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -223,7 +223,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); + auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}); auto ivUnroll = builder.create(forInst->getLoc(), bumpMap, forOpIV); operandMapping.map(forOpIV, ivUnroll); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 8094ff2f986..80e080f8fa6 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -373,7 +373,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector res; res.reserve(affineExprs.size()); for (auto expr : affineExprs) { - auto map = AffineMap::get(numIndices, 0, expr, {}); + auto map = AffineMap::get(numIndices, 0, expr); res.push_back(makeComposedAffineApply(b, b->getInsertionPoint()->getLoc(), map, memrefIndices)); } @@ -470,7 +470,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, if (keep.empty()) { return permutationMap; } - auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep, {}); + auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep); LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: ")); return simplifyAffineMap(projectionMap.compose(permutationMap)); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index c9e1dcefcf6..de8038c931c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -112,7 +112,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto d0 = bInner.getAffineDimExpr(0); int64_t step = forOp.getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, - {d0.floorDiv(step) % 2}, {}); + {d0.floorDiv(step) % 2}); auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, forOp.getInductionVar()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 47ee626f811..d5bdcea2c55 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -80,9 +80,8 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, for (unsigned i = 
0, e = tripCountMap.getNumResults(); i < e; i++) { auto tripCountExpr = tripCountMap.getResult(i); bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step; - auto bumpMap = - b->getAffineMap(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), - bumpExprs[i], {}); + auto bumpMap = b->getAffineMap(tripCountMap.getNumDims(), + tripCountMap.getNumSymbols(), bumpExprs[i]); bumpValues[i] = b->create(forOp.getLoc(), bumpMap, tripCountOperands); } @@ -94,7 +93,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, operands->clear(); operands->push_back(lb); operands->append(bumpValues.begin(), bumpValues.end()); - *map = b->getAffineMap(1 + tripCountMap.getNumResults(), 0, newUbExprs, {}); + *map = b->getAffineMap(1 + tripCountMap.getNumResults(), 0, newUbExprs); // Simplify the map + operands. fullyComposeAffineMapAndOperands(map, operands); *map = simplifyAffineMap(*map); @@ -465,7 +464,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, if (!forOpIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); + auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}); auto ivUnroll = builder.create(forOp.getLoc(), bumpMap, forOpIV); operandMap.map(forOpIV, ivUnroll); @@ -654,8 +653,7 @@ static void augmentMapAndBounds(FuncBuilder *b, Value *iv, AffineMap *map, auto bounds = llvm::to_vector<4>(map->getResults()); bounds.push_back(b->getAffineDimExpr(map->getNumDims()) + offset); operands->insert(operands->begin() + map->getNumDims(), iv); - *map = - b->getAffineMap(map->getNumDims() + 1, map->getNumSymbols(), bounds, {}); + *map = b->getAffineMap(map->getNumDims() + 1, map->getNumSymbols(), bounds); canonicalizeMapAndOperands(map, operands); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index e6ecd8e45fa..13e5b2f2f08 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -146,9 +146,8 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // Remapped indices. 
for (auto resultExpr : indexRemap.getResults()) { - auto singleResMap = - builder.getAffineMap(indexRemap.getNumDims(), - indexRemap.getNumSymbols(), resultExpr, {}); + auto singleResMap = builder.getAffineMap( + indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); auto afOp = builder.create(opInst->getLoc(), singleResMap, remapOperands); state.operands.push_back(afOp); @@ -259,7 +258,7 @@ void mlir::createAffineComputationSlice( sliceOps->reserve(composedMap.getNumResults()); for (auto resultExpr : composedMap.getResults()) { auto singleResMap = builder.getAffineMap( - composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr, {}); + composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); sliceOps->push_back(builder.create( opInst->getLoc(), singleResMap, composedOpOperands)); } diff --git a/mlir/lib/VectorOps/VectorOps.cpp b/mlir/lib/VectorOps/VectorOps.cpp index 05af0293989..b40a1d9ee39 100644 --- a/mlir/lib/VectorOps/VectorOps.cpp +++ b/mlir/lib/VectorOps/VectorOps.cpp @@ -231,9 +231,6 @@ LogicalResult VectorTransferReadOp::verify() { return emitOpError("requires an AffineMapAttr named 'permutation_map'"); } auto permutationMap = getPermutationMap(); - if (!permutationMap.getRangeSizes().empty()) { - return emitOpError("requires an unbounded permutation_map"); - } if (permutationMap.getNumSymbols() != 0) { return emitOpError("requires a permutation_map without symbols"); } @@ -364,9 +361,6 @@ LogicalResult VectorTransferWriteOp::verify() { return emitOpError("requires an AffineMapAttr named 'permutation_map'"); } auto permutationMap = getPermutationMap(); - if (!permutationMap.getRangeSizes().empty()) { - return emitOpError("requires an unbounded permutation_map"); - } if (permutationMap.getNumSymbols() != 0) { return emitOpError("requires a permutation_map without symbols"); } diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir index 9e67b234b9c..a393d77a068 100644 --- a/mlir/test/IR/affine-map.mlir +++ b/mlir/test/IR/affine-map.mlir @@ -135,15 +135,6 @@ // CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0 * s0, d0 + s0, d0 + 2, d1 * 2, s1 * 2, s0 + 2) #map39 = (i, j)[M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M) -// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20) -#map40 = (i, j) -> (i, j) size (10, 20) - -// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (s0, s1 + 10) -#map41 = (i, j)[N, M] -> (i, j) size (N, M+10) - -// CHECK: #map{{[0-9]+}} = (d0, d1)[s0, s1] -> (d0, d1) size (128, s0 * 2 + s1 + 5) -#map42 = (i, j)[N, M] -> (i, j) size (64 + 64, 5 + 2*N + M) - // CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> ((d0 * 5) floordiv 4, (d1 ceildiv 7) mod s0) #map43 = (i, j) [s0] -> ( i * 5 floordiv 4, j ceildiv 7 mod s0) @@ -317,15 +308,6 @@ func @f38(memref<2x4xi8, #map38, 1>) // CHECK: func @f39(memref<2x4xi8, #map{{[0-9]+}}, 1>) func @f39(memref<2x4xi8, #map39, 1>) -// CHECK: func @f40(memref<2x4xi8, #map{{[0-9]+}}, 1>) -func @f40(memref<2x4xi8, #map40, 1>) - -// CHECK: func @f41(memref<2x4xi8, #map{{[0-9]+}}, 1>) -func @f41(memref<2x4xi8, #map41, 1>) - -// CHECK: func @f42(memref<2x4xi8, #map{{[0-9]+}}, 1>) -func @f42(memref<2x4xi8, #map42, 1>) - // CHECK: func @f43(memref<2x4xi8, #map{{[0-9]+}}>) func @f43(memref<2x4xi8, #map43>) diff --git a/mlir/test/IR/invalid-affinemap.mlir b/mlir/test/IR/invalid-affinemap.mlir index f48ec3af929..0a9d5bb3853 100644 --- a/mlir/test/IR/invalid-affinemap.mlir +++ b/mlir/test/IR/invalid-affinemap.mlir @@ -99,21 +99,6 @@ // ----- #hello_world = (i, j) -> (i, 3*d0 + ) // expected-error 
{{use of undeclared identifier}} -// ----- -#hello_world = (i, j) -> (i, j) size (10, x) // expected-error {{use of undeclared identifier}} - -// ----- -#hello_world = (i, j) [M] -> (i, j) size (10, j) // expected-error {{size expressions cannot refer to dimension values}} - -// ----- -#hello_world = (i, j) [M] -> (i, j) size (10, M+i) // expected-error {{size expressions cannot refer to dimension values}} - -// ----- -#hello_world = (i, j) -> (i, j) size (10) // expected-error {{fewer range sizes than range expressions}} - -// ----- -#hello_world = (i, j) -> (i, j) size (10, 20, 30) // expected-error {{more range sizes than range expressions}} - // TODO(bondhugula): Add more tests; coverage of error messages emitted not complete // ----- -- cgit v1.2.3 From f1b848e4701a4cd3fa781c259e3728faff1c31df Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 4 Jun 2019 19:18:23 -0700 Subject: NFC: Rename FuncBuilder to OpBuilder and refactor to take a top level region instead of a function. PiperOrigin-RevId: 251563898 --- mlir/bindings/python/pybind.cpp | 4 +- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 3 +- mlir/examples/Linalg/Linalg2/Example.cpp | 4 +- mlir/examples/Linalg/Linalg2/lib/Transforms.cpp | 4 +- mlir/examples/Linalg/Linalg3/Conversion.cpp | 2 +- mlir/examples/Linalg/Linalg3/Example.cpp | 2 +- mlir/examples/Linalg/Linalg3/Execution.cpp | 2 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 5 +-- mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp | 10 ++--- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 6 +-- mlir/examples/Linalg/Linalg4/Example.cpp | 6 +-- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 2 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 4 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 4 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 4 +- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 4 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 10 ++--- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 4 +- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 2 +- mlir/g3doc/Tutorials/Linalg/LLVMConversion.md | 14 +++---- mlir/g3doc/Tutorials/Toy/Ch-2.md | 2 +- mlir/g3doc/Tutorials/Toy/Ch-3.md | 4 +- mlir/g3doc/Tutorials/Toy/Ch-5.md | 2 +- mlir/include/mlir/AffineOps/AffineOps.h | 7 ++-- mlir/include/mlir/Analysis/VectorAnalysis.h | 2 +- mlir/include/mlir/EDSC/Builders.h | 18 ++++----- mlir/include/mlir/IR/Block.h | 15 ++++++++ mlir/include/mlir/IR/Builders.h | 43 ++++++++++------------ mlir/include/mlir/IR/PatternMatch.h | 4 +- mlir/include/mlir/Linalg/IR/LinalgOps.h | 10 ++--- mlir/include/mlir/Linalg/Utils/Utils.h | 2 +- mlir/include/mlir/Transforms/DialectConversion.h | 1 - mlir/include/mlir/Transforms/LoopUtils.h | 4 +- mlir/include/mlir/Transforms/Utils.h | 6 +-- mlir/lib/AffineOps/AffineOps.cpp | 6 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/EDSC/Builders.cpp | 9 ++--- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 16 ++++---- mlir/lib/IR/Block.cpp | 18 +++++++++ mlir/lib/IR/Builders.cpp | 12 +++--- mlir/lib/IR/Function.cpp | 4 +- mlir/lib/IR/Operation.cpp | 3 +- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 2 +- mlir/lib/Linalg/IR/LinalgOps.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 4 +- mlir/lib/Linalg/Transforms/Tiling.cpp | 6 +-- mlir/lib/Linalg/Utils/Utils.cpp | 6 +-- mlir/lib/Parser/Parser.cpp | 4 +- .../Transforms/AddDefaultStatsTestPass.cpp | 
4 +- .../Transforms/InferQuantizedTypesPass.cpp | 4 +- mlir/lib/Transforms/DialectConversion.cpp | 8 ++-- mlir/lib/Transforms/DmaGeneration.cpp | 10 ++--- mlir/lib/Transforms/LoopFusion.cpp | 4 +- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 6 +-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 5 +-- mlir/lib/Transforms/LowerAffine.cpp | 16 ++++---- mlir/lib/Transforms/MaterializeVectors.cpp | 14 +++---- mlir/lib/Transforms/PipelineDataTransfer.cpp | 6 +-- mlir/lib/Transforms/Utils/FoldUtils.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 12 +++--- mlir/lib/Transforms/Utils/LoopUtils.cpp | 25 ++++++------- mlir/lib/Transforms/Utils/Utils.cpp | 4 +- .../Vectorization/VectorizerTestPass.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 10 ++--- mlir/test/EDSC/builder-api-test.cpp | 26 ++++++------- 70 files changed, 249 insertions(+), 229 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 76cec271621..6ec0860ff7f 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -248,7 +248,7 @@ struct PythonFunctionContext { PythonFunction enter() { assert(function.function && "function is not set up"); auto *mlirFunc = static_cast(function.function); - contextBuilder.emplace(mlirFunc); + contextBuilder.emplace(mlirFunc->getBody()); context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc()); return function; @@ -262,7 +262,7 @@ struct PythonFunctionContext { PythonFunction function; mlir::edsc::ScopedContext *context; - llvm::Optional contextBuilder; + llvm::Optional contextBuilder; }; PythonFunctionContext PythonMLIRModule::makeFunctionContext( diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 8cd970c56f1..d13f7f3de92 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -121,8 +121,7 @@ Type linalg::convertLinalgType(Type t) { // Create an array attribute containing integer attributes with values provided // in `position`. 
-static ArrayAttr makePositionAttr(FuncBuilder &builder, - ArrayRef position) { +static ArrayAttr makePositionAttr(OpBuilder &builder, ArrayRef position) { SmallVector attrs; attrs.reserve(position.size()); for (auto p : position) diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index 0de8a9044e1..10f4bf7e905 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -39,7 +39,7 @@ TEST_FUNC(linalg_ops) { mlir::Function *f = makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {}); - FuncBuilder builder(f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off @@ -78,7 +78,7 @@ TEST_FUNC(linalg_ops_folded_slices) { mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices", {indexType, indexType, indexType}, {}); - FuncBuilder builder(f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index 4523830129c..f4d3d68d28a 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -31,8 +31,8 @@ using llvm::ArrayRef; using llvm::cast; using llvm::isa; using llvm::SmallVector; -using mlir::FuncBuilder; using mlir::MemRefType; +using mlir::OpBuilder; using mlir::Value; using mlir::edsc::ScopedContext; using mlir::edsc::ValueHandle; @@ -101,7 +101,7 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim, } ViewOp linalg::emitAndReturnFullyComposedView(Value *v) { - FuncBuilder builder(v->getDefiningOp()); + OpBuilder builder(v->getDefiningOp()); ScopedContext scope(builder, v->getDefiningOp()->getLoc()); assert(v->getType().isa() && "must be a ViewType"); auto *memRef = getViewSupportingMemRef(v); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 0d7b22b0fe9..37d1b51f53e 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -44,7 +44,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - FuncBuilder builder(f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off ValueHandle diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index cf77785532b..f02aef920e4 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -41,7 +41,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::FuncBuilder builder(f); + mlir::OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off ValueHandle diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 902ea67abe1..00d571cbc99 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -44,7 +44,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::FuncBuilder builder(f); + mlir::OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off 
ValueHandle diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 60fdf60039f..ef0d8581a99 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -41,8 +41,7 @@ using namespace mlir; // Create an array attribute containing integer attributes with values provided // in `position`. -static ArrayAttr makePositionAttr(FuncBuilder &builder, - ArrayRef position) { +static ArrayAttr makePositionAttr(Builder &builder, ArrayRef position) { SmallVector attrs; attrs.reserve(position.size()); for (auto p : position) @@ -64,7 +63,7 @@ public: // descriptor to emit IR iteratively computing the actual offset, followed by // a getelementptr. Value *obtainDataPtr(Operation *op, Value *viewDescriptor, - ArrayRef indices, FuncBuilder &rewriter) const { + ArrayRef indices, Builder &rewriter) const { auto loadOp = cast(op); auto elementType = loadOp.getViewType().template cast().getElementType(); diff --git a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp index f539c702549..778f2ea5540 100644 --- a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp @@ -64,7 +64,7 @@ void linalg::DotOp::emitScalarImplementation( using edsc::intrinsics::select; // Account for affine.terminator in loop. - FuncBuilder builder(body, std::prev(body->end(), 1)); + OpBuilder builder(body, std::prev(body->end(), 1)); ScopedContext scope(builder, innermostLoop.getLoc()); FloatType fTy = getOperand(0) ->getType() @@ -107,7 +107,7 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() { assert( llvm::isa_and_nonnull(indexingPosPair.first->getDefiningOp())); // clang-format off - FuncBuilder builder(op); + OpBuilder builder(op); ScopedContext scope(builder, op->getLoc()); IndexHandle i; using linalg::common::LoopNestRangeBuilder; @@ -132,7 +132,7 @@ void linalg::MatvecOp::emitScalarImplementation( using edsc::op::operator==; using edsc::intrinsics::select; // Account for affine.terminator in loop. - FuncBuilder builder(body, std::prev(body->end(), 1)); + OpBuilder builder(body, std::prev(body->end(), 1)); ScopedContext scope(builder, innermostLoop.getLoc()); FloatType fTy = getOperand(0) ->getType() @@ -181,7 +181,7 @@ void linalg::MatmulOp::writeAsFinerGrainTensorContraction() { llvm::isa_and_nonnull(indexingPosPair.first->getDefiningOp())); using linalg::common::LoopNestRangeBuilder; // clang-format off - FuncBuilder builder(op); + OpBuilder builder(op); ScopedContext scope(builder, op->getLoc()); IndexHandle j; LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))( @@ -205,7 +205,7 @@ void linalg::MatmulOp::emitScalarImplementation( using edsc::op::operator==; using edsc::intrinsics::select; // Account for affine.terminator in loop. 
- FuncBuilder builder(body, std::prev(body->end(), 1)); + OpBuilder builder(body, std::prev(body->end(), 1)); ScopedContext scope(builder, innermostLoop.getLoc()); FloatType fTy = getOperand(0) ->getType() diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 3a11c6d17d9..5b16ce0eda5 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -161,7 +161,7 @@ linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps, template static SmallVector writeContractionAsLoops(ContractionOp contraction) { - FuncBuilder builder(contraction.getOperation()); + OpBuilder builder(contraction.getOperation()); ScopedContext scope(builder, contraction.getLoc()); auto allRanges = getRanges(contraction); auto loopRanges = @@ -274,7 +274,7 @@ Rewriter::matchAndRewrite(linalg::LoadOp load, SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast(load.getView()->getDefiningOp()); - FuncBuilder builder(load); + OpBuilder builder(load); ScopedContext scope(builder, load.getLoc()); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(load, view); @@ -289,7 +289,7 @@ Rewriter::matchAndRewrite(linalg::StoreOp store, SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast(store.getView()->getDefiningOp()); - FuncBuilder builder(store); + OpBuilder builder(store); ScopedContext scope(builder, store.getLoc()); auto *valueToStore = store.getValueToStore(); auto *memRef = view.getSupportingMemRef(); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 73e75706f11..cdc05a1cc21 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -41,7 +41,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - FuncBuilder builder(f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off @@ -97,7 +97,7 @@ TEST_FUNC(matmul_tiled_views) { MLIRContext context; Module module(&context); mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); - FuncBuilder b(f); + OpBuilder b(f->getBody()); lowerToTiledViews(f, {b.create(f->getLoc(), 8), b.create(f->getLoc(), 9)}); composeSliceOps(f); @@ -127,7 +127,7 @@ TEST_FUNC(matmul_tiled_views_as_loops) { Module module(&context); mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops"); - FuncBuilder b(f); + OpBuilder b(f->getBody()); lowerToTiledViews(f, {b.create(f->getLoc(), 8), b.create(f->getLoc(), 9)}); composeSliceOps(f); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 3df6f4b00ac..11cd6e5cd9a 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -148,7 +148,7 @@ writeContractionAsTiledViews(TensorContractionBase &contraction, contraction.getNumParallelDims() + contraction.getNumReductionDims()); auto *op = static_cast(&contraction); - mlir::FuncBuilder builder(op->getOperation()); + mlir::OpBuilder builder(op->getOperation()); ScopedContext scope(builder, op->getLoc()); SmallVector ivs(tileSizes.size()); auto pivs = 
IndexHandle::makeIndexHandlePointers(ivs); diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 5eb8cd089f5..df09cd0921e 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -104,7 +104,7 @@ private: /// convenience for emitting individual operations. /// The builder is stateful, in particular it keeeps an "insertion point": /// this is where the next operations will be introduced. - std::unique_ptr builder; + std::unique_ptr builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -174,7 +174,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.get()); + builder = llvm::make_unique(function->getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 7c580d25488..4001b308a3c 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -105,7 +105,7 @@ private: /// convenience for emitting individual operations. /// The builder is stateful, in particular it keeeps an "insertion point": /// this is where the next operations will be introduced. - std::unique_ptr builder; + std::unique_ptr builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -175,7 +175,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.get()); + builder = llvm::make_unique(function->getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index e2001fb575e..e091cbd9ec2 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -105,7 +105,7 @@ private: /// convenience for emitting individual operations. /// The builder is stateful, in particular it keeeps an "insertion point": /// this is where the next operations will be introduced. - std::unique_ptr builder; + std::unique_ptr builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -175,7 +175,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.get()); + builder = llvm::make_unique(function->getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 2c06526ec7b..440e3d8be00 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -315,7 +315,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. 
SmallVector operands(op->getOperands()); - mlir::FuncBuilder builder(op); + mlir::OpBuilder builder(op); auto newCall = builder.create(op->getLoc(), mangledCallee, operands); if (newCall.getNumResults()) { diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 45d608d72a5..189add05ee3 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -57,7 +57,7 @@ namespace { /// time both side of the cast (producer and consumer) will be lowered to a /// dialect like LLVM and end up with the same LLVM representation, at which /// point this becomes a no-op and is eliminated. -Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { +Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) { if (val->getType() == destTy) return val; return builder.create(val->getLoc(), val, destTy) @@ -67,7 +67,7 @@ Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { /// Create a type cast to turn a toy.array into a memref. The Toy Array will be /// lowered to a memref during buffer allocation, at which point the type cast /// becomes useless. -Value *memRefTypeCast(FuncBuilder &builder, Value *val) { +Value *memRefTypeCast(PatternRewriter &builder, Value *val) { if (val->getType().isa()) return val; auto toyArrayTy = val->getType().dyn_cast(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index d682d12b253..ecf6c9df05d 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -57,7 +57,7 @@ namespace { /// time both side of the cast (producer and consumer) will be lowered to a /// dialect like LLVM and end up with the same LLVM representation, at which /// point this becomes a no-op and is eliminated. -Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { +Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) { if (val->getType() == destTy) return val; return builder.create(val->getLoc(), val, destTy) @@ -67,7 +67,7 @@ Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) { /// Create a type cast to turn a toy.array into a memref. The Toy Array will be /// lowered to a memref during buffer allocation, at which point the type cast /// becomes useless. -Value *memRefTypeCast(FuncBuilder &builder, Value *val) { +Value *memRefTypeCast(PatternRewriter &builder, Value *val) { if (val->getType().isa()) return val; auto toyArrayTy = val->getType().dyn_cast(); @@ -183,7 +183,7 @@ public: private: // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence // of stores into the buffer, and return a MemRef into the buffer. - Value *getConstantCharBuffer(FuncBuilder &builder, Location loc, + Value *getConstantCharBuffer(PatternRewriter &builder, Location loc, StringRef data) const { auto retTy = builder.getMemRefType(data.size() + 1, builder.getIntegerType(8)); @@ -405,7 +405,7 @@ struct LateLoweringPass : public ModulePass { /// operating in a brand new function: we don't have the return to hook the /// dealloc operations. Value *allocTensor(toy::AllocOp alloc) { - FuncBuilder builder(alloc); + OpBuilder builder(alloc); auto retTy = alloc.getResult()->getType(); auto memRefTy = retTy.dyn_cast(); @@ -420,7 +420,7 @@ struct LateLoweringPass : public ModulePass { // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. 
- builder.getFunction()->walk([&](Operation *op) { + builder.getRegion()->walk([&](Operation *op) { auto returnOp = dyn_cast(op); if (!returnOp) return; diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index e2001fb575e..e091cbd9ec2 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -105,7 +105,7 @@ private: /// convenience for emitting individual operations. /// The builder is stateful, in particular it keeeps an "insertion point": /// this is where the next operations will be introduced. - std::unique_ptr builder; + std::unique_ptr builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -175,7 +175,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.get()); + builder = llvm::make_unique(function->getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index c9da85f8106..4294f7bbbbf 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -319,7 +319,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. SmallVector operands(op->getOperands()); - mlir::FuncBuilder builder(f); + mlir::OpBuilder builder(f->getBody()); builder.setInsertionPoint(op); auto newCall = builder.create(op->getLoc(), mangledCallee, operands); diff --git a/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md b/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md index 83a2a319713..af34c9ceec4 100644 --- a/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md +++ b/mlir/g3doc/Tutorials/Linalg/LLVMConversion.md @@ -233,7 +233,7 @@ public: // needs to define as many value as the original operation, but their types // may be different. SmallVector rewrite(Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override; + OpBuilder &rewriter) const override; } ``` @@ -296,7 +296,7 @@ operates. ```c++ SmallVector ViewOpConversion::rewrite( Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override { + OpBuilder &rewriter) const override { // Obtain the typed operation (we know we matched only one type). auto viewOp = op->cast(); @@ -437,7 +437,7 @@ struct ViewDescriptor { } // The builder into which we emit code. - FuncBuilder &builder; + OpBuilder &builder; // The actual descriptor. Value *d; @@ -450,7 +450,7 @@ rules described above: ```c++ SmallVector SliceOpConversion::rewrite( Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override { + OpBuilder &rewriter) const override { // Obtain the typed operation (we know we matched only one type). auto sliceOp = op->cast(); @@ -528,7 +528,7 @@ for the view descriptor: ```c++ Value *obtainDataPtr(Location loc, int rank, Value *viewDescriptorVal, - ArrayRef indices, FuncBuilder &rewriter) { + ArrayRef indices, OpBuilder &rewriter) { // Create the context object (RAII) in which we can use declarative builders. // Bring all the builders into the namespace. using namespace intrinsics; @@ -560,7 +560,7 @@ conversions for load and store operations. // Load Operation Conversion. 
SmallVector LoadOpConversion::rewrite( Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override { + OpBuilder &rewriter) const override { // Obtain the typed operation (we know we matched only one type). auto loadOp = op->cast(); @@ -582,7 +582,7 @@ SmallVector LoadOpConversion::rewrite( // Store Operation Conversion SmallVector StoreOpConversion::rewrite( Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override { + OpBuilder &rewriter) const override { // Obtain the typed operation (we know we matched only one type). auto loadOp = op->cast(); diff --git a/mlir/g3doc/Tutorials/Toy/Ch-2.md b/mlir/g3doc/Tutorials/Toy/Ch-2.md index 4a8b8dc4454..9b07385bb89 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-2.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-2.md @@ -123,7 +123,7 @@ generation through a simple depth-first search traversal of the Toy AST. Here is how we create a `toy.transpose` operation: ``` -mlir::Operation *createTransposeOp(FuncBuilder *builder, +mlir::Operation *createTransposeOp(OpBuilder *builder, mlir::Value *input_array) { // We bundle our custom type in a `toy` dialect. auto toyDialect = mlir::Identifier::get("toy", builder->getContext()); diff --git a/mlir/g3doc/Tutorials/Toy/Ch-3.md b/mlir/g3doc/Tutorials/Toy/Ch-3.md index 498438ad03c..9ff6c401b55 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-3.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-3.md @@ -202,11 +202,11 @@ class GenericCallOp bool verify(); /// Interface to the builder to allow: - /// mlir::FuncBuilder::create(...) + /// mlir::OpBuilder::create(...) /// This method populate the `state` that MLIR use to create operations. /// The `toy.generic_call` operation accepts a callee name and a list of /// arguments for the call. - static void build(mlir::FuncBuilder *builder, mlir::OperationState *state, + static void build(mlir::OpBuilder *builder, mlir::OperationState *state, llvm::StringRef callee, llvm::ArrayRef arguments); diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md index 24612755ce0..2681720e0fc 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-5.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-5.md @@ -80,7 +80,7 @@ public: /// The results created by the new IR with the builder are returned, and their /// number must match the number of result of `op`. SmallVector rewrite(Operation *op, ArrayRef operands, - FuncBuilder &rewriter) const override { + OpBuilder &rewriter) const override { ... // Return the newly allocated buffer, it will be used as an operand when diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index a3749a389cb..8fcd0abe920 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -32,7 +32,7 @@ namespace mlir { class AffineBound; class AffineValueMap; class FlatAffineConstraints; -class FuncBuilder; +class OpBuilder; /// A utility function to check if a value is defined at the top level of a /// function. A value defined at the top level is always a valid symbol. @@ -143,7 +143,7 @@ public: /// Return a Builder set up to insert operations immediately before the /// terminator. - FuncBuilder getBodyBuilder(); + OpBuilder getBodyBuilder(); /// Get the body of the AffineForOp. Block *getBody() { return &getRegion().front(); } @@ -361,8 +361,7 @@ void canonicalizeMapAndOperands(AffineMap *map, /// Returns a composed AffineApplyOp by composing `map` and `operands` with /// other AffineApplyOps supplying those operands. 
The operands of the resulting /// AffineApplyOp do not change the length of AffineApplyOp chains. -AffineApplyOp makeComposedAffineApply(FuncBuilder *b, Location loc, - AffineMap map, +AffineApplyOp makeComposedAffineApply(OpBuilder *b, Location loc, AffineMap map, llvm::ArrayRef operands); /// Given an affine map `map` and its input `operands`, this method composes diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index bf070e81be8..1f4e50c1178 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -27,9 +27,9 @@ namespace mlir { class AffineApplyOp; class AffineForOp; class AffineMap; -class FuncBuilder; class Location; class MemRefType; +class OpBuilder; class Operation; class Value; class VectorType; diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index c925e0a39e7..aa5c321627d 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -50,17 +50,17 @@ class ValueHandle; /// setting and restoring of insertion points. class ScopedContext { public: - ScopedContext(FuncBuilder &builder, Location location); + ScopedContext(OpBuilder &builder, Location location); /// Sets the insertion point of the builder to 'newInsertPt' for the duration /// of the scope. The existing insertion point of the builder is restored on /// destruction. - ScopedContext(FuncBuilder &builder, FuncBuilder::InsertPoint newInsertPt, + ScopedContext(OpBuilder &builder, OpBuilder::InsertPoint newInsertPt, Location location); ~ScopedContext(); static MLIRContext *getContext(); - static FuncBuilder *getBuilder(); + static OpBuilder *getBuilder(); static Location getLocation(); private: @@ -74,10 +74,10 @@ private: static ScopedContext *&getCurrentScopedContext(); - /// Top level FuncBuilder. - FuncBuilder &builder; + /// Top level OpBuilder. + OpBuilder &builder; /// The previous insertion point of the builder. - llvm::Optional prevBuilderInsertPoint; + llvm::Optional prevBuilderInsertPoint; /// Current location. Location location; /// Parent context we return into. @@ -116,20 +116,20 @@ protected: /// Enter an mlir::Block and setup a ScopedContext to insert operations at /// the end of it. Since we cannot use c++ language-level scoping to implement /// scoping itself, we use enter/exit pairs of operations. - /// As a consequence we must allocate a new FuncBuilder + ScopedContext and + /// As a consequence we must allocate a new OpBuilder + ScopedContext and /// let the escape. /// Step back "prev" times from the end of the block to set up the insertion /// point, which is useful for non-empty blocks. void enter(mlir::Block *block, int prev = 0) { bodyScope = new ScopedContext( *ScopedContext::getBuilder(), - FuncBuilder::InsertPoint(block, std::prev(block->end(), prev)), + OpBuilder::InsertPoint(block, std::prev(block->end(), prev)), ScopedContext::getLocation()); bodyScope->nestedBuilder = this; } /// Exit the current mlir::Block by explicitly deleting the dynamically - /// allocated FuncBuilder and ScopedContext. + /// allocated OpBuilder and ScopedContext. void exit() { // Reclaim now to exit the scope. bodyScope->nestedBuilder = nullptr; diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 3c627b4ef3f..381a790cfe7 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -344,6 +344,14 @@ public: explicit Region(Operation *container); ~Region(); + /// Return the context this region is inserted in. 
The region must have a + /// valid parent container. + MLIRContext *getContext(); + + /// Return a location for this region. This is the location attached to the + /// parent container. The region must have a valid parent container. + Location getLoc(); + using RegionType = llvm::iplist; RegionType &getBlocks() { return blocks; } @@ -409,6 +417,13 @@ public: /// the operation with an offending use. bool isIsolatedAbove(llvm::Optional noteLoc = llvm::None); + /// Walk the operations in this block in postorder, calling the callback for + /// each operation. + void walk(const std::function &callback) { + for (auto &block : *this) + block.walk(callback); + } + private: RegionType blocks; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index b869dcdaf4a..09eaf5669a6 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -181,40 +181,37 @@ protected: MLIRContext *context; }; -/// This class helps build a Function. Operations that are created are -/// automatically inserted at an insertion point. The builder is copyable. -class FuncBuilder : public Builder { +/// This class helps build Operations. Operations that are created are +/// automatically inserted at an insertion point. The builder is copyable. +class OpBuilder : public Builder { public: - /// Create a function builder and set the insertion point to the start of - /// the function. - explicit FuncBuilder(Function *func) - : Builder(func->getContext()), function(func) { - if (!func->empty()) - setInsertionPoint(&func->front(), func->front().begin()); + /// Create a builder and set the insertion point to the start of the region. + explicit OpBuilder(Region *region) + : Builder(region->getContext()), region(region) { + if (!region->empty()) + setInsertionPoint(®ion->front(), region->front().begin()); else clearInsertionPoint(); } + explicit OpBuilder(Region ®ion) : OpBuilder(®ion) {} - explicit FuncBuilder(Function &func) : FuncBuilder(&func) {} - virtual ~FuncBuilder(); + virtual ~OpBuilder(); - /// Create a function builder and set insertion point to the given - /// operation, which will cause subsequent insertions to go right before it. - FuncBuilder(Operation *op) : FuncBuilder(op->getFunction()) { + /// Create a builder and set insertion point to the given operation, which + /// will cause subsequent insertions to go right before it. + OpBuilder(Operation *op) : OpBuilder(op->getContainingRegion()) { setInsertionPoint(op); } - FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) { - setInsertionPoint(block, block->end()); - } + OpBuilder(Block *block) : OpBuilder(block, block->end()) {} - FuncBuilder(Block *block, Block::iterator insertPoint) - : FuncBuilder(block->getFunction()) { + OpBuilder(Block *block, Block::iterator insertPoint) + : OpBuilder(block->getParent()) { setInsertionPoint(block, insertPoint); } - /// Return the function this builder is referring to. - Function *getFunction() const { return function; } + /// Return the region this builder is referring to. + Region *getRegion() const { return region; } /// This class represents a saved insertion point. class InsertPoint { @@ -291,7 +288,7 @@ public: /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the - /// current function. + /// current region. Block *createBlock(Block *insertBefore = nullptr); /// Returns the current block of the builder. 
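The `OpBuilder` hunks above change how client code obtains a builder: instead of anchoring to a `Function`, a builder is now anchored to a `Region`, a `Block`, or an `Operation`. A minimal usage sketch follows (hypothetical caller code, not part of this patch), assuming the caller already has a `Function *func`, an `Operation *op`, and a `Block *block` in hand:

```c++
// Hypothetical migration sketch; `func`, `op`, and `block` are placeholders
// for values the caller already has.

// Before this change a builder was tied to a function:
//   FuncBuilder builder(func);
// After it, the builder is tied to a region (here the function body), and the
// insertion point starts at the beginning of that region's first block.
OpBuilder builder(func->getBody());

// A builder can also be created directly at an operation (inserting right
// before it) or at the end of a block, mirroring the constructors above.
OpBuilder beforeOp(op);
OpBuilder atBlockEnd(block);

// The new Region::walk mirrors Function::walk, so a pass can traverse a
// single region's operations in postorder.
func->getBody().walk([](Operation *nested) { /* visit nested */ });
```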
@@ -342,7 +339,7 @@ public: } private: - Function *function; + Region *region; Block *block = nullptr; Block::iterator insertPoint; }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index bbca58bf521..08fb4905e37 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -266,7 +266,7 @@ template struct OpRewritePattern : public RewritePattern { /// to apply patterns and observe their effects (e.g. to keep worklists or /// other data structures up to date). /// -class PatternRewriter : public FuncBuilder { +class PatternRewriter : public OpBuilder { public: /// Create operation of specific op type at the current insertion point /// without verifying to see if it is valid. @@ -342,7 +342,7 @@ public: ArrayRef valuesToRemoveIfDead = {}); protected: - PatternRewriter(Function *fn) : FuncBuilder(fn) {} + PatternRewriter(Region ®ion) : OpBuilder(region) {} virtual ~PatternRewriter(); // These are the callback methods that subclasses can choose to implement if diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 92f2630dae0..9b023448240 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -113,9 +113,9 @@ public: /// Return a Builder set up to insert operations immediately before the /// terminator. - FuncBuilder getBodyBuilder() { + OpBuilder getBodyBuilder() { Block *body = getBody(); - return FuncBuilder(body, std::prev(body->end())); + return OpBuilder(body, std::prev(body->end())); } /// Get the body of the ForOp. @@ -408,7 +408,7 @@ public: unsigned getNumInputsAndOutputs() { return impl->getNumInputsAndOutputs(getOperation()); } - Operation *create(FuncBuilder &builder, Location loc, + Operation *create(OpBuilder &builder, Location loc, ArrayRef operands) { return impl->create(builder, loc, operands); } @@ -425,7 +425,7 @@ private: virtual unsigned getNumReductionLoops(Operation *op) = 0; virtual unsigned getNumWindowLoops(Operation *op) = 0; virtual unsigned getNumLoops(Operation *op) = 0; - virtual Operation *create(FuncBuilder &builder, Location loc, + virtual Operation *create(OpBuilder &builder, Location loc, ArrayRef operands) = 0; }; @@ -458,7 +458,7 @@ private: unsigned getNumLoops(Operation *op) override { return cast(op).getNumLoops(); } - Operation *create(FuncBuilder &builder, Location loc, + Operation *create(OpBuilder &builder, Location loc, ArrayRef operands) override { return builder.create(loc, operands); } diff --git a/mlir/include/mlir/Linalg/Utils/Utils.h b/mlir/include/mlir/Linalg/Utils/Utils.h index 31963b243b0..594a9d116ca 100644 --- a/mlir/include/mlir/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Linalg/Utils/Utils.h @@ -88,7 +88,7 @@ SmallVector getViewSizes(LinalgOp &linalgOp); /// Returns the values obtained by applying `map` to the list of values. /// Performs simplifications and foldings where possible. -SmallVector applyMapToValues(FuncBuilder *b, Location loc, +SmallVector applyMapToValues(OpBuilder *b, Location loc, AffineMap map, ArrayRef values, FunctionConstants &state); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 3886f0c7a0a..8b476c0dc3e 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -31,7 +31,6 @@ namespace mlir { // Forward declarations. 
class Block; -class FuncBuilder; class MLIRContext; class Operation; class Type; diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 1105688aa55..8a255228893 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -31,7 +31,7 @@ namespace mlir { class AffineMap; class AffineForOp; class Function; -class FuncBuilder; +class OpBuilder; class Value; /// Unrolls this for operation completely if the trip count is known to be @@ -80,7 +80,7 @@ void promoteSingleIterationLoops(Function *f); void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, SmallVectorImpl *operands, - FuncBuilder *builder); + OpBuilder *builder); /// Skew the operations in the body of a 'affine.for' operation with the /// specified operation-wise shifts. The shifts are with respect to the diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 75407ad59b5..1b32a98206c 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -34,11 +34,9 @@ namespace mlir { class AffineApplyOp; class AffineForOp; -class FuncBuilder; class Location; class Module; - -class Function; +class OpBuilder; /// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally /// remapping the old memref's indices using the supplied affine map, @@ -83,7 +81,7 @@ bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, /// these will also be collected into a single (multi-result) affine apply op. /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. -Operation *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, +Operation *createComposedAffineApplyOp(OpBuilder *builder, Location loc, ArrayRef operands, ArrayRef affineApplyOps, SmallVectorImpl *results); diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 28594a34d45..9189acf5f50 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -544,7 +544,7 @@ void mlir::fullyComposeAffineMapAndOperands( } } -AffineApplyOp mlir::makeComposedAffineApply(FuncBuilder *b, Location loc, +AffineApplyOp mlir::makeComposedAffineApply(OpBuilder *b, Location loc, AffineMap map, ArrayRef operands) { AffineMap normalizedMap = map; @@ -1069,9 +1069,9 @@ void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.push_back(llvm::make_unique(context)); } -FuncBuilder AffineForOp::getBodyBuilder() { +OpBuilder AffineForOp::getBodyBuilder() { Block *body = getBody(); - return FuncBuilder(body, std::prev(body->end())); + return OpBuilder(body, std::prev(body->end())); } AffineBound AffineForOp::getLowerBound() { diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 117cf6e109e..16e092b8205 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -54,7 +54,7 @@ void mlir::buildTripCountMapAndOperands( int64_t loopSpan; int64_t step = forOp.getStep(); - FuncBuilder b(forOp.getOperation()); + OpBuilder b(forOp.getOperation()); if (forOp.hasConstantBounds()) { int64_t lb = forOp.getConstantLowerBound(); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index ae5551db215..cbda6d40224 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ 
b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -44,7 +44,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // parallel. void TestParallelismDetection::runOnFunction() { Function &f = getFunction(); - FuncBuilder b(f); + OpBuilder b(f.getBody()); f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) forOp.emitRemark("parallel loop"); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 476c7c87aec..aa842364f26 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -749,7 +749,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Clone src loop nest and insert it a the beginning of the operation block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; - FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); + OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin()); auto sliceLoopNest = cast(b.clone(*srcLoopIVs[0].getOperation())); diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 22f91399498..6f6363ffcc5 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -24,8 +24,7 @@ using namespace mlir; using namespace mlir::edsc; -mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder, - Location location) +mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location) : builder(builder), location(location), enclosingScopedContext(ScopedContext::getCurrentScopedContext()), nestedBuilder(nullptr) { @@ -35,8 +34,8 @@ mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder, /// Sets the insertion point of the builder to 'newInsertPt' for the duration /// of the scope. The existing insertion point of the builder is restored on /// destruction. -mlir::edsc::ScopedContext::ScopedContext(FuncBuilder &builder, - FuncBuilder::InsertPoint newInsertPt, +mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, + OpBuilder::InsertPoint newInsertPt, Location location) : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()), location(location), @@ -59,7 +58,7 @@ ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() { return context; } -FuncBuilder *mlir::edsc::ScopedContext::getBuilder() { +OpBuilder *mlir::edsc::ScopedContext::getBuilder() { assert(ScopedContext::getCurrentScopedContext() && "Unexpected Null ScopedContext"); return &ScopedContext::getCurrentScopedContext()->builder; diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 163a7cf6b6c..86fab1a20f1 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -30,7 +30,7 @@ using namespace mlir; namespace { template -void createForAllDimensions(FuncBuilder &builder, Location loc, +void createForAllDimensions(OpBuilder &builder, Location loc, SmallVectorImpl &values) { for (StringRef dim : {"x", "y", "z"}) { Value *v = builder.create(loc, builder.getIndexType(), @@ -42,12 +42,12 @@ void createForAllDimensions(FuncBuilder &builder, Location loc, // Add operations generating block/thread ids and gird/block dimensions at the // beginning of `kernelFunc` and replace uses of the respective function args. 
void injectGpuIndexOperations(Location loc, Function &kernelFunc) { - FuncBuilder funcBuilder(kernelFunc); + OpBuilder OpBuilder(kernelFunc.getBody()); SmallVector indexOps; - createForAllDimensions(funcBuilder, loc, indexOps); - createForAllDimensions(funcBuilder, loc, indexOps); - createForAllDimensions(funcBuilder, loc, indexOps); - createForAllDimensions(funcBuilder, loc, indexOps); + createForAllDimensions(OpBuilder, loc, indexOps); + createForAllDimensions(OpBuilder, loc, indexOps); + createForAllDimensions(OpBuilder, loc, indexOps); + createForAllDimensions(OpBuilder, loc, indexOps); // Replace the leading 12 function args with the respective thread/block index // operations. Iterate backwards since args are erased and indices change. for (int i = 11; i >= 0; --i) { @@ -78,10 +78,10 @@ Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) { // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching // `kernelFunc`. void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, Function &kernelFunc) { - FuncBuilder funcBuilder(launchOp); + OpBuilder OpBuilder(launchOp); SmallVector kernelOperandValues( launchOp.getKernelOperandValues()); - funcBuilder.create( + OpBuilder.create( launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(), launchOp.getBlockSizeOperandValues(), kernelOperandValues); launchOp.erase(); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index cf85cc86ef0..9595a72c226 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -282,6 +282,24 @@ Region::~Region() { bb.dropAllReferences(); } +/// Return the context this region is inserted in. The region must have a valid +/// parent container. +MLIRContext *Region::getContext() { + assert(!container.isNull() && "region is not attached to a container"); + if (auto *inst = getContainingOp()) + return inst->getContext(); + return getContainingFunction()->getContext(); +} + +/// Return a location for this region. This is the location attached to the +/// parent container. The region must have a valid parent container. +Location Region::getLoc() { + assert(!container.isNull() && "region is not attached to a container"); + if (auto *inst = getContainingOp()) + return inst->getLoc(); + return getContainingFunction()->getLoc(); +} + Region *Region::getContainingRegion() { if (auto *inst = getContainingOp()) return inst->getContainingRegion(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4accfb54e28..d32e705785a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -332,31 +332,31 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { } //===----------------------------------------------------------------------===// -// Operations. +// OpBuilder. //===----------------------------------------------------------------------===// -FuncBuilder::~FuncBuilder() {} +OpBuilder::~OpBuilder() {} /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the /// current function. -Block *FuncBuilder::createBlock(Block *insertBefore) { +Block *OpBuilder::createBlock(Block *insertBefore) { Block *b = new Block(); // If we are supposed to insert before a specific block, do so, otherwise add // the block to the end of the function. 
if (insertBefore) - function->getBlocks().insert(Function::iterator(insertBefore), b); + region->getBlocks().insert(Function::iterator(insertBefore), b); else - function->push_back(b); + region->push_back(b); setInsertionPointToEnd(b); return b; } /// Create an operation given the fields represented as an OperationState. -Operation *FuncBuilder::createOperation(const OperationState &state) { +Operation *OpBuilder::createOperation(const OperationState &state) { assert(block && "createOperation() called without setting builder's block"); auto *op = Operation::create(state); block->getOperations().insert(insertPoint, op); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index f53c7156825..6ab5a6febf0 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -214,9 +214,7 @@ void Function::addEntryBlock() { } void Function::walk(const std::function &callback) { - // Walk each of the blocks within the function. - for (auto &block : getBlocks()) - block.walk(callback); + getBody().walk(callback); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 64fb8bcd0e2..5804770a88f 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -312,8 +312,7 @@ void Operation::replaceUsesOfWith(Value *from, Value *to) { void Operation::walk(const std::function &callback) { // Visit any internal operations. for (auto ®ion : getRegions()) - for (auto &block : region) - block.walk(callback); + region.walk(callback); // Visit the current operation. callback(this); diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 0e30a8e147e..1b50320071c 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -889,7 +889,7 @@ static void ensureDistinctSuccessors(Block &bb) { position != end; ++position) { auto *dummyBlock = new Block(); bb.getParent()->push_back(dummyBlock); - auto builder = FuncBuilder(dummyBlock); + auto builder = OpBuilder(dummyBlock); SmallVector operands( terminator->getSuccessorOperands(*position)); builder.create(terminator->getLoc(), successor.first, operands); diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 55a791a6d63..3b3a04012fd 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -773,7 +773,7 @@ void mlir::linalg::emitScalarImplementation( using edsc::intrinsics::select; // account for affine.terminator in loop. 
- FuncBuilder b(body, std::prev(body->end(), 1)); + OpBuilder b(body, std::prev(body->end(), 1)); ScopedContext scope(b, innermostLoop.getLoc()); auto *op = linalgOp.getOperation(); if (isa(op)) { diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 60c0daf7938..b3857ac5171 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -621,7 +621,7 @@ static void lowerLinalgForToCFG(Function &f) { auto *op = forOp.getOperation(); auto loc = op->getLoc(); using namespace edsc::op; - FuncBuilder builder(op); + OpBuilder builder(op); ScopedContext scope(builder, loc); ValueHandle lb(forOp.getLowerBound()), ub(forOp.getUpperBound()), step(forOp.getStep()); diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index b2f59c43da1..5e22f8601fc 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -35,7 +35,7 @@ using namespace mlir::linalg; // Creates a number of ranges equal to the number of results in `map`. // The returned ranges correspond to the loop ranges, in the proper order, for // which new loops will be created. -static SmallVector emitLoopRanges(FuncBuilder *b, Location loc, +static SmallVector emitLoopRanges(OpBuilder *b, Location loc, AffineMap map, ArrayRef allViewSizes, FunctionConstants &state) { @@ -51,7 +51,7 @@ static SmallVector emitLoopRanges(FuncBuilder *b, Location loc, } static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) { - FuncBuilder b(linalgOp.getOperation()); + OpBuilder b(linalgOp.getOperation()); ScopedContext scope(b, linalgOp.getOperation()->getLoc()); auto loopRanges = emitLoopRanges( scope.getBuilder(), scope.getLocation(), diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 22090ca6aac..bc2ed2b60de 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -58,7 +58,7 @@ static bool isZero(Value *v) { // The returned ranges correspond to the loop ranges, in the proper order, that // are tiled and for which new loops will be created. 
static SmallVector -makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map, +makeTiledLoopRanges(OpBuilder *b, Location loc, AffineMap map, ArrayRef allViewSizes, ArrayRef allTileSizes, FunctionConstants &state) { assert(allTileSizes.size() == map.getNumResults()); @@ -127,7 +127,7 @@ static Value *foldRange(Value *view, unsigned dim) { return nullptr; } -static SmallVector makeTiledViews(FuncBuilder *b, Location loc, +static SmallVector makeTiledViews(OpBuilder *b, Location loc, LinalgOp &linalgOp, ArrayRef ivs, ArrayRef tileSizes, @@ -210,7 +210,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, tileSizes.size() && "expected matching number of tile sizes and loops"); - FuncBuilder builder(op.getOperation()); + OpBuilder builder(op.getOperation()); ScopedContext scope(builder, op.getLoc()); auto loopRanges = makeTiledLoopRanges( scope.getBuilder(), scope.getLocation(), diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index f19e61c5531..81fad1c870a 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -109,7 +109,7 @@ static Value *tryFold(AffineMap map, ArrayRef operands, return nullptr; } -static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc, +static Value *emitOrFoldComposedAffineApply(OpBuilder *b, Location loc, AffineMap map, ArrayRef operandsRef, FunctionConstants &state) { @@ -121,7 +121,7 @@ static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc, } SmallVector -mlir::linalg::applyMapToValues(FuncBuilder *b, Location loc, AffineMap map, +mlir::linalg::applyMapToValues(OpBuilder *b, Location loc, AffineMap map, ArrayRef values, FunctionConstants &state) { SmallVector res; @@ -141,7 +141,7 @@ Value *FunctionConstants::getOrCreateIndex(int64_t v) { auto it = map.find(v); if (it != map.end()) return it->second; - FuncBuilder builder(f); + OpBuilder builder(f.getBody()); edsc::ScopedContext s(builder, f.getLoc()); return map.insert(std::make_pair(v, edsc::intrinsics::constant_index(v))) .first->getSecond(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 6cc933a8169..a3d44f9935e 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2302,11 +2302,11 @@ public: /// more specific builder type. 
#pragma clang diagnostic push #pragma clang diagnostic ignored "-Wshadow-field" - FuncBuilder builder; + OpBuilder builder; #pragma clang diagnostic pop FunctionParser(ParserState &state, Function *function) - : Parser(state), builder(function), function(function) {} + : Parser(state), builder(function->getBody()), function(function) {} ~FunctionParser(); diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 75c082fd8ce..375a64d8f2d 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -77,7 +77,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, for (auto *arg : func.getArguments()) { if (!config.isHandledType(arg->getType())) continue; - FuncBuilder b(func); + OpBuilder b(func.getBody()); APFloat minValue(-1.0f); APFloat maxValue(1.0f); ElementsAttr layerStats = DenseFPElementsAttr::get( @@ -102,7 +102,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, if (!config.isHandledType(originalResult->getType())) return; - FuncBuilder b(op->getBlock(), ++op->getIterator()); + OpBuilder b(op->getBlock(), ++op->getIterator()); APFloat minValue(-1.0f); APFloat maxValue(1.0f); diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index 94bac98598c..c443354714f 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -184,7 +184,7 @@ void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, Type newType) { Value *inputValue = anchor->getValue(); Operation *op = anchor->getOp(); - FuncBuilder b(op->getBlock(), Block::iterator(op)); + OpBuilder b(op->getBlock(), Block::iterator(op)); SmallVector removeValuesIfDead; @@ -240,7 +240,7 @@ void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, Type newType) { Value *origResultValue = anchor->getValue(); Operation *op = origResultValue->getDefiningOp(); - FuncBuilder b(op->getBlock(), ++Block::iterator(op)); + OpBuilder b(op->getBlock(), ++Block::iterator(op)); Value *replacedResultValue = nullptr; Value *newResultValue = nullptr; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 6002cadd37d..1deedc1520c 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -108,8 +108,8 @@ struct DialectConversionRewriter final : public PatternRewriter { SmallVector newValues; }; - DialectConversionRewriter(Function *fn) - : PatternRewriter(fn), argConverter(fn->getContext()) {} + DialectConversionRewriter(Region ®ion) + : PatternRewriter(region), argConverter(region.getContext()) {} ~DialectConversionRewriter() = default; /// Cleanup and destroy any generated rewrite operations. This method is @@ -151,7 +151,7 @@ struct DialectConversionRewriter final : public PatternRewriter { /// PatternRewriter hook for creating a new operation. Operation *createOperation(const OperationState &state) override { - auto *result = FuncBuilder::createOperation(state); + auto *result = OpBuilder::createOperation(state); createdOps.push_back(result); return result; } @@ -572,7 +572,7 @@ LogicalResult FunctionConverter::convertFunction(Function *f) { return success(); // Rewrite the function body. 
- DialectConversionRewriter rewriter(f); + DialectConversionRewriter rewriter(f->getBody()); if (failed(convertRegion(rewriter, f->getBody(), f->getLoc()))) { // Reset any of the generated rewrites. rewriter.discardRewrites(); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 1ead2e5c8e2..7c745aa1a0b 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -240,14 +240,14 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, return true; // DMAs for read regions are going to be inserted just before the for loop. - FuncBuilder prologue(block, begin); + OpBuilder prologue(block, begin); // DMAs for write regions are going to be inserted just after the for loop. - FuncBuilder epilogue(block, end); - FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; + OpBuilder epilogue(block, end); + OpBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. auto *func = block->getFunction(); - FuncBuilder top(func); + OpBuilder top(func->getBody()); auto loc = region.loc; auto *memref = region.memref; @@ -759,7 +759,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { void DmaGeneration::runOnFunction() { Function &f = getFunction(); - FuncBuilder topBuilder(f); + OpBuilder topBuilder(f.getBody()); zeroIndex = topBuilder.create(f.getLoc(), 0); // Override default is a command line option is provided. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index b7b69fa54fe..0f39e52eefb 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1006,9 +1006,9 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, auto *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. - FuncBuilder b(forInst); + OpBuilder b(forInst); // Builder to create constants at the top level. - FuncBuilder top(forInst->getFunction()); + OpBuilder top(forInst->getFunction()->getBody()); // Create new memref type based on slice bounds. auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 31875664922..c4c1184fa82 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -203,7 +203,7 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { SmallPtrSet definedOps; // This is the place where hoisted instructions would reside. - FuncBuilder b(forOp.getOperation()); + OpBuilder b(forOp.getOperation()); SmallPtrSet opsToHoist; SmallVector opsToMove; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 5233081b5f1..c1be6e8f6b1 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -112,7 +112,7 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0].getOperation()); + OpBuilder b(origLoops[0].getOperation()); unsigned width = origLoops.size(); // Bounds for tile space loops. @@ -207,7 +207,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { - FuncBuilder b(topLoop); + OpBuilder b(topLoop); // Loop bounds will be set later. 
auto pointLoop = b.create(loc, 0, 0); pointLoop.getBody()->getOperations().splice( @@ -221,7 +221,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Add tile space loops; for (unsigned i = width; i < 2 * width; i++) { - FuncBuilder b(topLoop); + OpBuilder b(topLoop); // Loop bounds will be set later. auto tileSpaceLoop = b.create(loc, 0, 0); tileSpaceLoop.getBody()->getOperations().splice( diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 731464bd7c1..409eb397df4 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -185,8 +185,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, // unrollJamFactor. if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) { // Insert the cleanup loop right after 'forOp'. - FuncBuilder builder(forInst->getBlock(), - std::next(Block::iterator(forInst))); + OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); auto cleanupAffineForOp = cast(builder.clone(*forInst)); // Adjust the lower bound of the cleanup loop; its upper bound is the same // as the original loop's upper bound. @@ -212,7 +211,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. - FuncBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); + OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); // Unroll and jam (appends unrollJamFactor-1 additional copies). for (unsigned i = 1; i < unrollJamFactor; i++) { diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 4dcc82f1178..b890b43da81 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -41,7 +41,7 @@ class AffineApplyExpander public: // This internal class expects arguments to be non-null, checks must be // performed at the call site. - AffineApplyExpander(FuncBuilder *builder, ArrayRef dimValues, + AffineApplyExpander(OpBuilder *builder, ArrayRef dimValues, ArrayRef symbolValues, Location loc) : builder(*builder), dimValues(dimValues), symbolValues(symbolValues), loc(loc) {} @@ -206,7 +206,7 @@ public: } private: - FuncBuilder &builder; + OpBuilder &builder; ArrayRef dimValues; ArrayRef symbolValues; @@ -216,7 +216,7 @@ private: // Create a sequence of operations that implement the `expr` applied to the // given dimension and symbol values. -static mlir::Value *expandAffineExpr(FuncBuilder *builder, Location loc, +static mlir::Value *expandAffineExpr(OpBuilder *builder, Location loc, AffineExpr expr, ArrayRef dimValues, ArrayRef symbolValues) { @@ -226,7 +226,7 @@ static mlir::Value *expandAffineExpr(FuncBuilder *builder, Location loc, // Create a sequence of operations that implement the `affineMap` applied to // the given `operands` (as it it were an AffineApplyOp). Optional> static expandAffineMap( - FuncBuilder *builder, Location loc, AffineMap affineMap, + OpBuilder *builder, Location loc, AffineMap affineMap, ArrayRef operands) { auto numDims = affineMap.getNumDims(); auto expanded = functional::map( @@ -260,7 +260,7 @@ struct LowerAffinePass : public FunctionPass { // recognize as a reduction by the subsequent passes. 
static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, ArrayRef values, - FuncBuilder &builder) { + OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); @@ -348,7 +348,7 @@ static LogicalResult lowerAffineFor(AffineForOp forOp) { // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. - FuncBuilder builder(bodyBlock); + OpBuilder builder(bodyBlock); auto affStep = builder.getAffineConstantExpr(forOp.getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); @@ -482,7 +482,7 @@ static LogicalResult lowerAffineIf(AffineIfOp ifOp) { std::prev(oldThen->end())); } - FuncBuilder builder(thenBlock); + OpBuilder builder(thenBlock); builder.create(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else @@ -569,7 +569,7 @@ static LogicalResult lowerAffineIf(AffineIfOp ifOp) { // Convert an "affine.apply" operation into a sequence of arithmetic // operations using the StandardOps dialect. Return true on error. static LogicalResult lowerAffineApply(AffineApplyOp op) { - FuncBuilder builder(op.getOperation()); + OpBuilder builder(op.getOperation()); auto maybeExpandedMap = expandAffineMap(&builder, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 80e080f8fa6..0d8cfea7312 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -238,7 +238,7 @@ static SmallVector delinearize(unsigned linearIndex, return res; } -static Operation *instantiate(FuncBuilder *b, Operation *opInst, +static Operation *instantiate(OpBuilder *b, Operation *opInst, VectorType hwVectorType, DenseMap *substitutionsMap); @@ -257,7 +257,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, if (it == substitutionsMap->end()) { auto *opInst = v->getDefiningOp(); if (isa(opInst)) { - FuncBuilder b(opInst); + OpBuilder b(opInst); auto *op = instantiate(&b, opInst, hwVectorType, substitutionsMap); auto res = substitutionsMap->insert(std::make_pair(v, op->getResult(0))); assert(res.second && "Insertion failed"); @@ -331,7 +331,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. static SmallVector -reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, +reindexAffineIndices(OpBuilder *b, VectorType hwVectorType, ArrayRef hwVectorInstance, ArrayRef memrefIndices) { auto vectorShape = hwVectorType.getShape(); @@ -404,7 +404,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static Operation *instantiate(FuncBuilder *b, Operation *opInst, +static Operation *instantiate(OpBuilder *b, Operation *opInst, VectorType hwVectorType, DenseMap *substitutionsMap) { assert(!isa(opInst) && @@ -481,7 +481,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. 
-static Operation *instantiate(FuncBuilder *b, VectorTransferReadOp read, +static Operation *instantiate(OpBuilder *b, VectorTransferReadOp read, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -505,7 +505,7 @@ static Operation *instantiate(FuncBuilder *b, VectorTransferReadOp read, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static Operation *instantiate(FuncBuilder *b, VectorTransferWriteOp write, +static Operation *instantiate(OpBuilder *b, VectorTransferWriteOp write, VectorType hwVectorType, ArrayRef hwVectorInstance, DenseMap *substitutionsMap) { @@ -547,7 +547,7 @@ static bool instantiateMaterialization(Operation *op, LLVM_DEBUG(dbgs() << "\ninstantiate: " << *op); // Create a builder here for unroll-and-jam effects. - FuncBuilder b(op); + OpBuilder b(op); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. if (isa(op)) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index de8038c931c..d0e0d18d586 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -73,7 +73,7 @@ static unsigned getTagMemRefPos(Operation &dmaInst) { /// modulo 2. Returns false if such a replacement cannot be performed. static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto *forBody = forOp.getBody(); - FuncBuilder bInner(forBody, forBody->begin()); + OpBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); // Doubles the shape with a leading dimension extent of 2. @@ -94,7 +94,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { // The double buffer is allocated right before 'forInst'. auto *forInst = forOp.getOperation(); - FuncBuilder bOuter(forInst); + OpBuilder bOuter(forInst); // Put together alloc operands for any dynamic dimensions of the memref. SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -360,7 +360,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // Tagging operations with shifts for debugging purposes. LLVM_DEBUG({ - FuncBuilder b(&op); + OpBuilder b(&op); op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); }); } diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index fbf1a2ae9c1..3983dda98fd 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -110,7 +110,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op, assert(foldResults.size() == op->getNumResults()); // Create the result constants and replace the results. 
- FuncBuilder builder(op); + OpBuilder builder(op); for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a2e64271c46..0cd32255561 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -46,7 +46,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(Function &fn, OwningRewritePatternList &&patterns) - : PatternRewriter(&fn), matcher(std::move(patterns)) { + : PatternRewriter(fn.getBody()), matcher(std::move(patterns)) { worklist.reserve(64); } @@ -88,7 +88,7 @@ protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. Operation *createOperation(const OperationState &state) override { - auto *result = FuncBuilder::createOperation(state); + auto *result = OpBuilder::createOperation(state); addToWorklist(result); return result; } @@ -142,14 +142,16 @@ private: /// Perform the rewrites. bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { - Function *fn = getFunction(); - OperationFolder helper(fn); + Region *region = getRegion(); + + // TODO(riverriddle) OperationFolder should take a region to insert into. + OperationFolder helper(region->getContainingFunction()); bool changed = false; int i = 0; do { // Add all operations to the worklist. - fn->walk([&](Operation *op) { addToWorklist(op); }); + region->walk([&](Operation *op) { addToWorklist(op); }); // These are scratch vectors used in the folding loop below. SmallVector originalOperands, resultValues; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index d5bdcea2c55..23375e7b472 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -46,7 +46,7 @@ using namespace mlir; void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, SmallVectorImpl *operands, - FuncBuilder *b) { + OpBuilder *b) { auto lbMap = forOp.getLowerBoundMap(); // Single result lower bound map only. @@ -125,15 +125,14 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - auto *mlFunc = op->getFunction(); - FuncBuilder topBuilder(mlFunc); + OpBuilder topBuilder(op->getFunction()->getBody()); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); } else { AffineBound lb = forOp.getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(op->getBlock(), Block::iterator(op)); + OpBuilder builder(op->getBlock(), Block::iterator(op)); if (lb.getMap() == builder.getDimIdentityMap()) { // No need of generating an affine.apply. 
iv->replaceAllUsesWith(lbOperands[0]); @@ -173,7 +172,7 @@ static AffineForOp generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, - unsigned offset, AffineForOp srcForInst, FuncBuilder *b) { + unsigned offset, AffineForOp srcForInst, OpBuilder *b) { SmallVector lbOperands(srcForInst.getLowerBoundOperands()); SmallVector ubOperands(srcForInst.getUpperBoundOperands()); @@ -188,7 +187,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, BlockAndValueMapping operandMap; - FuncBuilder bodyBuilder = loopChunk.getBodyBuilder(); + OpBuilder bodyBuilder = loopChunk.getBodyBuilder(); for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end(); it != e; ++it) { uint64_t shift = it->first; @@ -291,7 +290,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, auto origLbMap = forOp.getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forOp.getOperation()); + OpBuilder b(forOp.getOperation()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -424,7 +423,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. Operation *op = forOp.getOperation(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { - FuncBuilder builder(op->getBlock(), ++Block::iterator(op)); + OpBuilder builder(op->getBlock(), ++Block::iterator(op)); auto cleanupForInst = cast(builder.clone(*op)); AffineMap cleanupMap; SmallVector cleanupOperands; @@ -448,7 +447,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, // Builder to insert unrolled bodies just before the terminator of the body of // 'forOp'. - FuncBuilder builder = forOp.getBodyBuilder(); + OpBuilder builder = forOp.getBodyBuilder(); // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). @@ -647,7 +646,7 @@ void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { // ... // } // ``` -static void augmentMapAndBounds(FuncBuilder *b, Value *iv, AffineMap *map, +static void augmentMapAndBounds(OpBuilder *b, Value *iv, AffineMap *map, SmallVector *operands, int64_t offset = 0) { auto bounds = llvm::to_vector<4>(map->getResults()); @@ -665,7 +664,7 @@ static void cloneLoopBodyInto(AffineForOp forOp, Value *oldIv, AffineForOp newForOp) { BlockAndValueMapping map; map.map(oldIv, newForOp.getInductionVar()); - FuncBuilder b = newForOp.getBodyBuilder(); + OpBuilder b = newForOp.getBodyBuilder(); for (auto &op : *forOp.getBody()) { // Step over newForOp in case it is nested under forOp. if (&op == newForOp.getOperation()) { @@ -704,7 +703,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, forOp.setStep(scaledStep); auto *op = forOp.getOperation(); - FuncBuilder b(op->getBlock(), ++Block::iterator(op)); + OpBuilder b(op->getBlock(), ++Block::iterator(op)); // Lower-bound map creation. auto lbMap = forOp.getLowerBoundMap(); @@ -720,7 +719,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, SmallVector innerLoops; for (auto t : targets) { // Insert newForOp before the terminator of `t`. 
- FuncBuilder b = t.getBodyBuilder(); + OpBuilder b = t.getBodyBuilder(); auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, ubOperands, ubMap, originalStep); cloneLoopBodyInto(t, forOp.getInductionVar(), newForOp); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 13e5b2f2f08..2e2bc08b24d 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -123,7 +123,7 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, opInst->operand_begin() + memRefOperandPos); state.operands.push_back(newMemRef); - FuncBuilder builder(opInst); + OpBuilder builder(opInst); for (auto *extraIndex : extraIndices) { assert(extraIndex->getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); @@ -249,7 +249,7 @@ void mlir::createAffineComputationSlice( if (localized) return; - FuncBuilder builder(opInst); + OpBuilder builder(opInst); SmallVector composedOpOperands(subOperands); auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index aeaea02dad0..9220b7bb751 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -240,7 +240,7 @@ void VectorizerTestPass::testNormalizeMaps() { pattern.match(f, &matches); for (auto m : matches) { auto app = cast(m.getMatchedOperation()); - FuncBuilder b(m.getMatchedOperation()); + OpBuilder b(m.getMatchedOperation()); SmallVector operands(app.getOperands()); makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ddaf112dece..a96713be547 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -805,7 +805,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, return LogicalResult::Failure; LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - FuncBuilder b(opInst); + OpBuilder b(opInst); auto transfer = b.create( opInst->getLoc(), vectorType, memoryOp.getMemRef(), map(makePtrDynCaster(), memoryOp.getIndices()), permutationMap); @@ -920,7 +920,7 @@ static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { !VectorType::isValidElementType(constant.getType())) { return nullptr; } - FuncBuilder b(op); + OpBuilder b(op); Location loc = op->getLoc(); auto vectorType = type.cast(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); @@ -1015,7 +1015,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); auto indices = map(makePtrDynCaster(), store.getIndices()); - FuncBuilder b(opInst); + OpBuilder b(opInst); auto permutationMap = makePermutationMap(opInst, state->strategy->loopToVectorDim); if (!permutationMap) @@ -1054,7 +1054,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, // name that works both in scalar mode and vector mode. // TODO(ntv): Is it worth considering an Operation.clone operation which // changes the type so we can promote an Operation with less boilerplate? 
- FuncBuilder b(opInst); + OpBuilder b(opInst); OperationState newOp(b.getContext(), opInst->getLoc(), opInst->getName().getStringRef(), vectorOperands, vectorTypes, opInst->getAttrs(), /*successors=*/{}, @@ -1136,7 +1136,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// maintains a clone for handling failure and restores the proper state via /// RAII. auto *loopInst = loop.getOperation(); - FuncBuilder builder(loopInst); + OpBuilder builder(loopInst); auto clonedLoop = cast(builder.clone(*loopInst)); struct Guard { LogicalResult failure() { diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 55c457443e5..6018abe0043 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -62,7 +62,7 @@ TEST_FUNC(builder_dynamic_for_func_args) { auto f = makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)), ub(f->getArgument(1)); @@ -113,7 +113,7 @@ TEST_FUNC(builder_dynamic_for) { auto f = makeFunction("builder_dynamic_for", {}, {indexType, indexType, indexType, indexType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)), c(f->getArgument(2)), d(f->getArgument(3)); @@ -136,7 +136,7 @@ TEST_FUNC(builder_max_min_for) { auto f = makeFunction("builder_max_min_for", {}, {indexType, indexType, indexType, indexType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)), ub1(f->getArgument(2)), ub2(f->getArgument(3)); @@ -157,7 +157,7 @@ TEST_FUNC(builder_blocks) { using namespace edsc::op; auto f = makeFunction("builder_blocks"); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); @@ -201,7 +201,7 @@ TEST_FUNC(builder_blocks_eager) { using namespace edsc::op; auto f = makeFunction("builder_blocks_eager"); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); @@ -244,7 +244,7 @@ TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle funcArg(f->getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), @@ -281,7 +281,7 @@ TEST_FUNC(builder_cond_branch_eager) { auto f = makeFunction("builder_cond_branch_eager", {}, {IntegerType::get(1, &globalContext())}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle funcArg(f->getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), @@ -321,7 +321,7 @@ TEST_FUNC(builder_helpers) { auto f = makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off ValueHandle f7( @@ -373,7 +373,7 @@ TEST_FUNC(custom_ops) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("custom_ops", {}, 
{indexType, indexType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); CustomOperation MY_CUSTOM_OP("my_custom_op"); CustomOperation MY_CUSTOM_OP_0("my_custom_op_0"); @@ -412,7 +412,7 @@ TEST_FUNC(insertion_in_block) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("insertion_in_block", {}, {indexType, indexType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); BlockHandle b1; // clang-format off @@ -438,7 +438,7 @@ TEST_FUNC(select_op) { auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); // clang-format off ValueHandle zero = constant_index(0), one = constant_index(1); @@ -474,7 +474,7 @@ TEST_FUNC(tile_2d) { MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0); auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType}); - FuncBuilder builder(*f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle zero = constant_index(0); MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), @@ -548,7 +548,7 @@ TEST_FUNC(vectorize_2d) { mlir::Module module(&globalContext()); module.getFunctions().push_back(f); - FuncBuilder builder(f); + OpBuilder builder(f->getBody()); ScopedContext scope(builder, f->getLoc()); ValueHandle zero = constant_index(0); MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), -- cgit v1.2.3 From e33e36f1788c47d077ff9d4e8dfc12a665ac28e6 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Mon, 10 Jun 2019 10:50:08 -0700 Subject: Return dependence result enum to distinguish between dependence result and error cases (NFC). PiperOrigin-RevId: 252437616 --- mlir/include/mlir/Analysis/AffineAnalysis.h | 19 ++++++++++-- mlir/lib/Analysis/AffineAnalysis.cpp | 39 +++++++++++++------------ mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 8 +++-- mlir/lib/Analysis/Utils.cpp | 14 +++++---- mlir/lib/Transforms/LoopFusion.cpp | 7 +++-- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 7 +++-- 6 files changed, 59 insertions(+), 35 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 1b92bd1b14c..bb25a65205c 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -94,19 +94,34 @@ struct DependenceComponent { /// Checks whether two accesses to the same memref access the same element. /// Each access is specified using the MemRefAccess structure, which contains /// the operation, indices and memref associated with the access. Returns -/// 'false' if it can be determined conclusively that the accesses do not +/// 'NoDependence' if it can be determined conclusively that the accesses do not /// access the same memref element. If 'allowRAR' is true, will consider /// read-after-read dependences (typically used by applications trying to /// optimize input reuse). // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into // a single struct. // TODO(andydavis) Make 'dependenceConstraints' optional arg. -bool checkMemrefAccessDependence( +struct DependenceResult { + enum ResultEnum { + HasDependence, // A dependence exists between 'srcAccess' and 'dstAccess'. 
+ NoDependence, // No dependence exists between 'srcAccess' and 'dstAccess'. + Failure, // Dependence check failed due to unsupported cases. + } value; + DependenceResult(ResultEnum v) : value(v) {} +}; + +DependenceResult checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, llvm::SmallVector *dependenceComponents, bool allowRAR = false); +/// Utility function that returns true if the provided DependenceResult +/// corresponds to a dependence result. +inline bool hasDependence(DependenceResult result) { + return result.value == DependenceResult::HasDependence; +} + /// Returns in 'depCompsVec', dependence components for dependences between all /// load and store ops in loop nest rooted at 'forOp', at loop depths in range /// [1, maxLoopDepth]. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index a9dce13d041..fc8c712e8b0 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -681,8 +681,10 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // Builds a flat affine constraint system to check if there exists a dependence // between memref accesses 'srcAccess' and 'dstAccess'. -// Returns 'false' if the accesses can be definitively shown not to access the -// same element. Returns 'true' otherwise. +// Returns 'NoDependence' if the accesses can be definitively shown not to +// access the same element. +// Returns 'HasDependence' if the accesses do access the same element. +// Returns 'Failure' if an error or unsupported case was encountered. // If a dependence exists, returns in 'dependenceComponents' a direction // vector for the dependence, with a component for each loop IV in loops // common to both accesses (see Dependence in AffineAnalysis.h for details). @@ -764,7 +766,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // // // TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv. -bool mlir::checkMemrefAccessDependence( +DependenceResult mlir::checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, llvm::SmallVector *dependenceComponents, @@ -774,13 +776,14 @@ bool mlir::checkMemrefAccessDependence( LLVM_DEBUG(srcAccess.opInst->dump();); LLVM_DEBUG(dstAccess.opInst->dump();); - // Return 'false' if these accesses do not acces the same memref. + // Return 'NoDependence' if these accesses do not access the same memref. if (srcAccess.memref != dstAccess.memref) - return false; - // Return 'false' if one of these accesses is not a StoreOp. + return DependenceResult::NoDependence; + + // Return 'NoDependence' if one of these accesses is not a StoreOp. if (!allowRAR && !isa(srcAccess.opInst) && !isa(dstAccess.opInst)) - return false; + return DependenceResult::NoDependence; // Get composed access function for 'srcAccess'. AffineValueMap srcAccessMap; @@ -793,14 +796,14 @@ bool mlir::checkMemrefAccessDependence( // Get iteration domain for the 'srcAccess' operation. FlatAffineConstraints srcDomain; if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain))) - return false; + return DependenceResult::Failure; // Get iteration domain for 'dstAccess' operation. 
FlatAffineConstraints dstDomain; if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain))) - return false; + return DependenceResult::Failure; - // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation + // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor // operation of 'srcAccess' does not properly dominate the ancestor // operation of 'dstAccess' in the same common operation block. // Note: this check is skipped if 'allowRAR' is true, because because RAR @@ -810,7 +813,7 @@ bool mlir::checkMemrefAccessDependence( if (!allowRAR && loopDepth > numCommonLoops && !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain, numCommonLoops)) { - return false; + return DependenceResult::NoDependence; } // Build dim and symbol position maps for each access from access operand // Value to position in merged contstraint system. @@ -830,7 +833,7 @@ bool mlir::checkMemrefAccessDependence( // local variables for mod/div exprs are supported. if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, dependenceConstraints))) - return true; + return DependenceResult::Failure; // Add 'src' happens before 'dst' ordering constraints. addOrderingConstraints(srcDomain, dstDomain, loopDepth, @@ -839,9 +842,9 @@ bool mlir::checkMemrefAccessDependence( addDomainConstraints(srcDomain, dstDomain, valuePosMap, dependenceConstraints); - // Return false if the solution space is empty: no dependence. + // Return 'NoDependence' if the solution space is empty: no dependence. if (dependenceConstraints->isEmpty()) { - return false; + return DependenceResult::NoDependence; } // Compute dependence direction vector and return true. @@ -852,7 +855,7 @@ bool mlir::checkMemrefAccessDependence( LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); LLVM_DEBUG(dependenceConstraints->dump()); - return true; + return DependenceResult::HasDependence; } /// Gathers dependence components for dependences between all ops in loop nest @@ -880,10 +883,10 @@ void mlir::getDependenceComponents( llvm::SmallVector depComps; // TODO(andydavis,bondhugula) Explore whether it would be profitable // to pre-compute and store deps instead of repeatedly checking. - if (checkMemrefAccessDependence(srcAccess, dstAccess, d, - &dependenceConstraints, &depComps)) { + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, &depComps); + if (hasDependence(result)) depCompsVec->push_back(depComps); - } } } } diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 2b0f1ab50ad..4456ac2b50b 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -93,9 +93,11 @@ static void checkDependences(ArrayRef loadsAndStores) { for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { FlatAffineConstraints dependenceConstraints; llvm::SmallVector dependenceComponents; - bool ret = checkMemrefAccessDependence(srcAccess, dstAccess, d, - &dependenceConstraints, - &dependenceComponents); + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, + &dependenceComponents); + assert(result.value != DependenceResult::Failure); + bool ret = hasDependence(result); // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance // vectors from ([1, 1], [3, 3]) to (1, 3). 
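The new result enum makes the tri-state outcome explicit at call sites. A minimal caller sketch, illustrative only and not part of the patch, using the signatures declared above and assuming 'srcAccess' and 'dstAccess' are MemRefAccess values built from affine load/store operations:

  FlatAffineConstraints dependenceConstraints;
  DependenceResult result = checkMemrefAccessDependence(
      srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
      /*dependenceComponents=*/nullptr);
  if (result.value == DependenceResult::Failure) {
    // Unsupported case encountered by the analysis; be conservative.
  } else if (hasDependence(result)) {
    // The accesses may touch the same element; 'dependenceConstraints' now
    // holds the dependence polyhedron.
  } else {
    // DependenceResult::NoDependence: the accesses are provably disjoint.
  }
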
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index aa842364f26..e5418fc17a2 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -640,9 +640,10 @@ LogicalResult mlir::getBackwardComputationSliceState( bool readReadAccesses = isa(srcAccess.opInst) && isa(dstAccess.opInst); FlatAffineConstraints dependenceConstraints; - if (!checkMemrefAccessDependence( - srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, - /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) { + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, + /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses); + if (!hasDependence(result)) { return failure(); } // Get loop nest surrounding src operation. @@ -922,9 +923,10 @@ bool mlir::isLoopParallel(AffineForOp forOp) { for (auto *dstOpInst : loadAndStoreOpInsts) { MemRefAccess dstAccess(dstOpInst); FlatAffineConstraints dependenceConstraints; - if (checkMemrefAccessDependence(srcAccess, dstAccess, depth, - &dependenceConstraints, - /*dependenceComponents=*/nullptr)) + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, depth, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (result.value != DependenceResult::NoDependence) return false; } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0f39e52eefb..829b1b221ef 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -954,9 +954,10 @@ static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { FlatAffineConstraints dependenceConstraints; // TODO(andydavis) Cache dependence analysis results, check cache here. - if (checkMemrefAccessDependence(srcAccess, dstAccess, d, - &dependenceConstraints, - /*dependenceComponents=*/nullptr)) { + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (hasDependence(result)) { // Store minimum loop depth and break because we want the min 'd' at // which there is a dependence. loopDepth = std::min(loopDepth, d - 1); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 45a11efc3e3..c5676afaf63 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -131,9 +131,10 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); // Dependences at loop depth <= minSurroundingLoops do NOT matter. for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) { - if (!checkMemrefAccessDependence(srcAccess, destAccess, d, - &dependenceConstraints, - /*dependenceComponents=*/nullptr)) + DependenceResult result = checkMemrefAccessDependence( + srcAccess, destAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (!hasDependence(result)) continue; depSrcStores.push_back(storeOpInst); // Check if this store is a candidate for forwarding; we only forward if -- cgit v1.2.3 From 898cf0e96878530e76a98ebe00c8f9d1492bea7e Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Mon, 17 Jun 2019 09:59:35 -0700 Subject: LoopFusion: adds support for computing forward computation slices, which will enable fusion of consumer loop nests into their producers in subsequent CLs. 
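A sketch of how the generalized slice computation is meant to be driven (illustrative only, based on the signatures added to Analysis/Utils.h in this change; 'srcOps' and 'dstOps' are assumed to hold the dependent load/store operations of two loop nests, and 'depth' a valid loop depth in the nest that anchors the slice):

  // Backward slice: bounds of the nest around 'srcOps' are expressed in terms
  // of IVs and symbols of the nest around 'dstOps' at 'depth'.
  mlir::ComputationSliceState backwardSliceUnion;
  if (failed(mlir::computeSliceUnion(srcOps, dstOps, depth,
                                     /*numCommonLoops=*/0,
                                     /*isBackwardSlice=*/true,
                                     &backwardSliceUnion))) {
    // No dependent op pairs, or slice bounds could not be computed.
  }

  // Forward slice: the roles flip, and bounds of the nest around 'dstOps' are
  // expressed in terms of IVs and symbols of the nest around 'srcOps'.
  mlir::ComputationSliceState forwardSliceUnion;
  if (failed(mlir::computeSliceUnion(srcOps, dstOps, depth,
                                     /*numCommonLoops=*/0,
                                     /*isBackwardSlice=*/false,
                                     &forwardSliceUnion))) {
    // Likewise a failure; callers fall back to not fusing.
  }
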
PiperOrigin-RevId: 253601994 --- mlir/include/mlir/Analysis/AffineStructures.h | 25 ++- mlir/include/mlir/Analysis/Utils.h | 72 +++++- mlir/lib/Analysis/AffineStructures.cpp | 55 +++-- mlir/lib/Analysis/Utils.cpp | 243 +++++++++++++-------- mlir/lib/Transforms/LoopFusion.cpp | 21 +- mlir/lib/Transforms/TestLoopFusion.cpp | 81 ++++++- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 37 ++-- .../Transforms/loop-fusion-slice-computation.mlir | 145 ++++++++++++ 8 files changed, 515 insertions(+), 164 deletions(-) create mode 100644 mlir/test/Transforms/loop-fusion-slice-computation.mlir (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index d3feb3436ff..3e2b90d6557 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -393,12 +393,12 @@ public: bool lower = true); /// Computes the lower and upper bounds of the first 'num' dimensional - /// identifiers as an affine map of the remaining identifiers (dimensional and - /// symbolic). This method is able to detect identifiers as floordiv's - /// and mod's of affine expressions of other identifiers with respect to - /// (positive) constants. Sets bound map to a null AffineMap if such a bound - /// can't be found (or yet unimplemented). - void getSliceBounds(unsigned num, MLIRContext *context, + /// identifiers (starting at 'offset') as an affine map of the remaining + /// identifiers (dimensional and symbolic). This method is able to detect + /// identifiers as floordiv's and mod's of affine expressions of other + /// identifiers with respect to (positive) constants. Sets bound map to a + /// null AffineMap if such a bound can't be found (or yet unimplemented). + void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl *lbMaps, SmallVectorImpl *ubMaps); @@ -648,13 +648,14 @@ public: Optional getConstantUpperBound(unsigned pos) const; /// Gets the lower and upper bound of the pos^th identifier treating - /// [dimStartPos, symbStartPos) as dimensions and [symStartPos, - /// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps - /// in the pair represent the max and min of potentially multiple affine - /// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed - /// AffineExpr's for all local identifiers in the system. + /// [0, offset) U [offset + num, symbStartPos) as dimensions and + /// [symStartPos, getNumDimAndSymbolIds) as symbols. The returned + /// multi-dimensional maps in the pair represent the max and min of + /// potentially multiple affine expressions. The upper bound is exclusive. + /// 'localExprs' holds pre-computed AffineExpr's for all local identifiers in + /// the system. std::pair - getLowerAndUpperBound(unsigned pos, unsigned dimStartPos, + getLowerAndUpperBound(unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, ArrayRef localExprs, MLIRContext *context); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index d6bf0c617ae..5c1f47a1348 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -73,6 +73,8 @@ struct ComputationSliceState { std::vector> lbOperands; // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). std::vector> ubOperands; + // Slice loop nest insertion point in target loop nest. 
+ Block::iterator insertPoint; // Adds to 'cst' with constraints which represent the slice bounds on 'ivs' // in 'this'. Specifically, the values in 'ivs' are added to 'cst' as dim // identifiers and the values in 'lb/ubOperands' are added as symbols. @@ -85,19 +87,67 @@ struct ComputationSliceState { void clearBounds(); }; -/// Computes computation slice loop bounds for the loop nest surrounding -/// 'srcAccess', where the returned loop bound AffineMaps are functions of -/// loop IVs from the loop nest surrounding 'dstAccess'. -LogicalResult getBackwardComputationSliceState( - const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, - unsigned dstLoopDepth, ComputationSliceState *sliceState); +/// Computes the computation slice loop bounds for one loop nest as affine maps +/// of the other loop nest's IVs and symbols, using 'dependenceConstraints' +/// computed between 'depSourceAccess' and 'depSinkAccess'. +/// If 'isBackwardSlice' is true, a backwards slice is computed in which the +/// slice bounds of loop nest surrounding 'depSourceAccess' are computed in +/// terms of loop IVs and symbols of the loop nest surrounding 'depSinkAccess' +/// at 'loopDepth'. +/// If 'isBackwardSlice' is false, a forward slice is computed in which the +/// slice bounds of loop nest surrounding 'depSinkAccess' are computed in terms +/// of loop IVs and symbols of the loop nest surrounding 'depSourceAccess' at +/// 'loopDepth'. +/// The slice loop bounds and associated operands are returned in 'sliceState'. +// +// Backward slice example: +// +// affine.for %i0 = 0 to 10 { +// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// affine.for %i1 = 0 to 10 { +// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +// // Backward computation slice of loop nest '%i0'. +// affine.for %i0 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 1)(%i1) { +// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// +// Forward slice example: +// +// affine.for %i0 = 0 to 10 { +// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// } +// affine.for %i1 = 0 to 10 { +// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +// // Forward computation slice of loop nest '%i1'. +// affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) { +// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// } +// +void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, + FlatAffineConstraints *dependenceConstraints, + unsigned loopDepth, bool isBackwardSlice, + ComputationSliceState *sliceState); /// Computes in 'sliceUnion' the union of all slice bounds computed at -/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the -/// same memref. Returns 'success' if union was computed, 'failure' otherwise. -LogicalResult computeSliceUnion(ArrayRef srcOps, - ArrayRef dstOps, - unsigned dstLoopDepth, +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. +/// The parameter 'numCommonLoops' is the number of loops common to the +/// operations in 'opsA' and 'opsB'. +/// If 'isBackwardSlice' is true, computes slice bounds for loop nest +/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest +/// surrounding ops in 'opsB' at 'loopDepth'. +/// If 'isBackwardSlice' is false, computes slice bounds for loop nest +/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest +/// surrounding ops in 'opsA' at 'loopDepth'. +/// Returns 'success' if union was computed, 'failure' otherwise. 
+// TODO(andydavis) Change this API to take 'forOpA'/'forOpB'. +LogicalResult computeSliceUnion(ArrayRef opsA, + ArrayRef opsB, unsigned loopDepth, + unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion); /// Creates a clone of the computation contained in the loop nest surrounding diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 41f8e075813..46e45351d54 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1423,19 +1423,28 @@ void FlatAffineConstraints::removeRedundantInequalities() { } std::pair FlatAffineConstraints::getLowerAndUpperBound( - unsigned pos, unsigned dimStartPos, unsigned symStartPos, + unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, ArrayRef localExprs, MLIRContext *context) { - assert(pos < dimStartPos && "invalid dim start pos"); - assert(symStartPos >= dimStartPos && "invalid sym start pos"); + assert(pos + offset < getNumDimIds() && "invalid dim start pos"); + assert(symStartPos >= (pos + offset) && "invalid sym start pos"); assert(getNumLocalIds() == localExprs.size() && "incorrect local exprs count"); SmallVector lbIndices, ubIndices; - getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices); + getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices); + + /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). + auto addCoeffs = [&](ArrayRef a, SmallVectorImpl &b) { + b.clear(); + for (unsigned i = 0, e = a.size(); i < e; ++i) { + if (i < offset || i >= offset + num) + b.push_back(a[i]); + } + }; SmallVector lb, ub; SmallVector exprs; - unsigned dimCount = symStartPos - dimStartPos; + unsigned dimCount = symStartPos - num; unsigned symCount = getNumDimAndSymbolIds() - symStartPos; exprs.reserve(lbIndices.size()); // Lower bound expressions. @@ -1444,7 +1453,7 @@ std::pair FlatAffineConstraints::getLowerAndUpperBound( // Extract the lower bound (in terms of other coeff's + const), i.e., if // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j // - 1. - lb.assign(ineq.begin() + dimStartPos, ineq.end()); + addCoeffs(ineq, lb); std::transform(lb.begin(), lb.end(), lb.begin(), std::negate()); auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); exprs.push_back(expr); @@ -1458,7 +1467,7 @@ std::pair FlatAffineConstraints::getLowerAndUpperBound( for (auto idx : ubIndices) { auto ineq = getInequality(idx); // Extract the upper bound (in terms of other coeff's + const). - ub.assign(ineq.begin() + dimStartPos, ineq.end()); + addCoeffs(ineq, ub); auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); // Upper bound is exclusive. exprs.push_back(expr + 1); @@ -1470,10 +1479,12 @@ std::pair FlatAffineConstraints::getLowerAndUpperBound( } /// Computes the lower and upper bounds of the first 'num' dimensional -/// identifiers as affine maps of the remaining identifiers (dimensional and -/// symbolic identifiers). Local identifiers are themselves explicitly computed -/// as affine functions of other identifiers in this process if needed. -void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, +/// identifiers (starting at 'offset') as affine maps of the remaining +/// identifiers (dimensional and symbolic identifiers). Local identifiers are +/// themselves explicitly computed as affine functions of other identifiers in +/// this process if needed. 
+void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, + MLIRContext *context, SmallVectorImpl *lbMaps, SmallVectorImpl *ubMaps) { assert(num < getNumDimIds() && "invalid range"); @@ -1488,8 +1499,12 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, // Record computed/detected identifiers. SmallVector memo(getNumIds()); // Initialize dimensional and symbolic identifiers. - for (unsigned i = num, e = getNumDimIds(); i < e; i++) - memo[i] = getAffineDimExpr(i - num, context); + for (unsigned i = 0, e = getNumDimIds(); i < e; i++) { + if (i < offset) + memo[i] = getAffineDimExpr(i, context); + else if (i >= offset + num) + memo[i] = getAffineDimExpr(i - num, context); + } for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); @@ -1578,7 +1593,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, for (unsigned pos = 0; pos < num; pos++) { unsigned numMapDims = getNumDimIds() - num; unsigned numMapSymbols = getNumSymbolIds(); - AffineExpr expr = memo[pos]; + AffineExpr expr = memo[pos + offset]; if (expr) expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); @@ -1601,7 +1616,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, tmpClone->removeRedundantInequalities(); } std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( - pos, num, getNumDimIds(), {}, context); + pos, offset, num, getNumDimIds(), {}, context); } // If the above fails, we'll just use the constant lower bound and the @@ -1612,7 +1627,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, if (!lbMap || lbMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice lb\n"); - auto lbConst = getConstantLowerBound(pos); + auto lbConst = getConstantLowerBound(pos + offset); if (lbConst.hasValue()) { lbMap = AffineMap::get( numMapDims, numMapSymbols, @@ -1622,7 +1637,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, if (!ubMap || ubMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice ub\n"); - auto ubConst = getConstantUpperBound(pos); + auto ubConst = getConstantUpperBound(pos + offset); if (ubConst.hasValue()) { (ubMap) = AffineMap::get( numMapDims, numMapSymbols, @@ -1630,9 +1645,11 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context, } } } - LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: "); + LLVM_DEBUG(llvm::dbgs() + << "lb map for pos = " << Twine(pos + offset) << ", expr: "); LLVM_DEBUG(lbMap.dump();); - LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: "); + LLVM_DEBUG(llvm::dbgs() + << "ub map for pos = " << Twine(pos + offset) << ", expr: "); LLVM_DEBUG(ubMap.dump();); } } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e5418fc17a2..ae991f796e0 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -504,48 +504,84 @@ LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, return success(); } -/// Computes in 'sliceUnion' the union of all slice bounds computed at -/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the -/// same memref. Returns 'Success' if union was computed, 'failure' otherwise. 
-LogicalResult mlir::computeSliceUnion(ArrayRef srcOps, - ArrayRef dstOps, - unsigned dstLoopDepth, - ComputationSliceState *sliceUnion) { - unsigned numSrcOps = srcOps.size(); - unsigned numDstOps = dstOps.size(); - assert(numSrcOps > 0 && numDstOps > 0); - - // Compute the intersection of 'srcMemrefToOps' and 'dstMemrefToOps'. - llvm::SmallDenseSet memrefIntersection; - for (auto *srcOp : srcOps) { - auto *srcMemRef = getLoadOrStoreMemRef(srcOp); - for (auto *dstOp : dstOps) { - if (srcMemRef == getLoadOrStoreMemRef(dstOp)) - memrefIntersection.insert(srcMemRef); +// Returns the innermost common loop depth for the set of operations in 'ops'. +// TODO(andydavis) Move this to LoopUtils. +static unsigned +getInnermostCommonLoopDepth(ArrayRef ops, + SmallVectorImpl &surroundingLoops) { + unsigned numOps = ops.size(); + assert(numOps > 0); + + std::vector> loops(numOps); + unsigned loopDepthLimit = std::numeric_limits::max(); + for (unsigned i = 0; i < numOps; ++i) { + getLoopIVs(*ops[i], &loops[i]); + loopDepthLimit = + std::min(loopDepthLimit, static_cast(loops[i].size())); + } + + unsigned loopDepth = 0; + for (unsigned d = 0; d < loopDepthLimit; ++d) { + unsigned i; + for (i = 1; i < numOps; ++i) { + if (loops[i - 1][d] != loops[i][d]) + return loopDepth; } + surroundingLoops.push_back(loops[i - 1][d]); + ++loopDepth; } - // Return failure if 'memrefIntersection' is empty. - if (memrefIntersection.empty()) - return failure(); + return loopDepth; +} - // Compute the union of slice bounds between all pairs in 'srcOps' and - // 'dstOps' in 'sliceUnionCst'. +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. +/// Returns 'Success' if union was computed, 'failure' otherwise. +LogicalResult mlir::computeSliceUnion(ArrayRef opsA, + ArrayRef opsB, + unsigned loopDepth, + unsigned numCommonLoops, + bool isBackwardSlice, + ComputationSliceState *sliceUnion) { + // Compute the union of slice bounds between all pairs in 'opsA' and + // 'opsB' in 'sliceUnionCst'. FlatAffineConstraints sliceUnionCst; assert(sliceUnionCst.getNumDimAndSymbolIds() == 0); - for (unsigned i = 0; i < numSrcOps; ++i) { - MemRefAccess srcAccess(srcOps[i]); - for (unsigned j = 0; j < numDstOps; ++j) { - MemRefAccess dstAccess(dstOps[j]); + std::vector> dependentOpPairs; + for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) { + MemRefAccess srcAccess(opsA[i]); + for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) { + MemRefAccess dstAccess(opsB[j]); if (srcAccess.memref != dstAccess.memref) continue; - // Compute slice bounds for 'srcAccess' and 'dstAccess'. - ComputationSliceState tmpSliceState; - if (failed(mlir::getBackwardComputationSliceState( - srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); + // Check if 'loopDepth' exceeds nesting depth of src/dst ops. + if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) || + (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) { + LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n."); return failure(); } + bool readReadAccesses = + isa(srcAccess.opInst) && isa(dstAccess.opInst); + FlatAffineConstraints dependenceConstraints; + // Check dependence between 'srcAccess' and 'dstAccess'. 
+ DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1, + &dependenceConstraints, /*dependenceComponents=*/nullptr, + /*allowRAR=*/readReadAccesses); + if (result.value == DependenceResult::Failure) { + LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n."); + return failure(); + } + if (result.value == DependenceResult::NoDependence) + continue; + dependentOpPairs.push_back({opsA[i], opsB[j]}); + + // Compute slice bounds for 'srcAccess' and 'dstAccess'. + ComputationSliceState tmpSliceState; + mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints, + loopDepth, isBackwardSlice, + &tmpSliceState); + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) { // Initialize 'sliceUnionCst' with the bounds computed in previous step. if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { @@ -599,116 +635,147 @@ LogicalResult mlir::computeSliceUnion(ArrayRef srcOps, } } - // Store 'numSrcLoopIvs' before converting dst loop IVs to dims. - unsigned numSrcLoopIVs = sliceUnionCst.getNumDimIds(); + // Empty union. + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) + return failure(); + + // Gather loops surrounding ops from loop nest where slice will be inserted. + SmallVector ops; + for (auto &dep : dependentOpPairs) { + ops.push_back(isBackwardSlice ? dep.second : dep.first); + } + SmallVector surroundingLoops; + unsigned innermostCommonLoopDepth = + getInnermostCommonLoopDepth(ops, surroundingLoops); + if (loopDepth > innermostCommonLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n."); + return failure(); + } + + // Store 'numSliceLoopIVs' before converting dst loop IVs to dims. + unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds(); // Convert any dst loop IVs which are symbol identifiers to dim identifiers. sliceUnionCst.convertLoopIVSymbolsToDims(); sliceUnion->clearBounds(); - sliceUnion->lbs.resize(numSrcLoopIVs, AffineMap()); - sliceUnion->ubs.resize(numSrcLoopIVs, AffineMap()); + sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap()); + sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap()); // Get slice bounds from slice union constraints 'sliceUnionCst'. - sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOps[0]->getContext(), - &sliceUnion->lbs, &sliceUnion->ubs); + sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs, + opsA[0]->getContext(), &sliceUnion->lbs, + &sliceUnion->ubs); // Add slice bound operands of union. SmallVector sliceBoundOperands; - sliceUnionCst.getIdValues(numSrcLoopIVs, + sliceUnionCst.getIdValues(numSliceLoopIVs, sliceUnionCst.getNumDimAndSymbolIds(), &sliceBoundOperands); // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'. sliceUnion->ivs.clear(); - sliceUnionCst.getIdValues(0, numSrcLoopIVs, &sliceUnion->ivs); + sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs); + + // Set loop nest insertion point to block start at 'loopDepth'. + sliceUnion->insertPoint = + isBackwardSlice + ? surroundingLoops[loopDepth - 1].getBody()->begin() + : std::prev(surroundingLoops[loopDepth - 1].getBody()->end()); // Give each bound its own copy of 'sliceBoundOperands' for subsequent // canonicalization. 
- sliceUnion->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); - sliceUnion->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); + sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); return success(); } const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; -// Computes memref dependence between 'srcAccess' and 'dstAccess', projects -// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice -// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs, -// symbols and constants. -LogicalResult mlir::getBackwardComputationSliceState( - const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, - unsigned dstLoopDepth, ComputationSliceState *sliceState) { - bool readReadAccesses = - isa(srcAccess.opInst) && isa(dstAccess.opInst); - FlatAffineConstraints dependenceConstraints; - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints, - /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses); - if (!hasDependence(result)) { - return failure(); - } +// Computes slice bounds by projecting out any loop IVs from +// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice +// bounds in 'sliceState' which represent the one loop nest's IVs in terms of +// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice'). +void mlir::getComputationSliceState( + Operation *depSourceOp, Operation *depSinkOp, + FlatAffineConstraints *dependenceConstraints, unsigned loopDepth, + bool isBackwardSlice, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. SmallVector srcLoopIVs; - getLoopIVs(*srcAccess.opInst, &srcLoopIVs); + getLoopIVs(*depSourceOp, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. SmallVector dstLoopIVs; - getLoopIVs(*dstAccess.opInst, &dstLoopIVs); + getLoopIVs(*depSinkOp, &dstLoopIVs); unsigned numDstLoopIVs = dstLoopIVs.size(); - if (dstLoopDepth > numDstLoopIVs) { - dstAccess.opInst->emitError("invalid destination loop depth"); - return failure(); - } - // Project out dimensions other than those up to 'dstLoopDepth'. - dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth, - numDstLoopIVs - dstLoopDepth); + assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) || + (isBackwardSlice && loopDepth <= numDstLoopIVs)); + + // Project out dimensions other than those up to 'loopDepth'. + unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth; + unsigned num = + isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth; + dependenceConstraints->projectOut(pos, num); - // Add src loop IV values to 'sliceState'. - dependenceConstraints.getIdValues(0, numSrcLoopIVs, &sliceState->ivs); + // Add slice loop IV values to 'sliceState'. + unsigned offset = isBackwardSlice ? 0 : loopDepth; + unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs; + dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs, + &sliceState->ivs); // Set up lower/upper bound affine maps for the slice. - sliceState->lbs.resize(numSrcLoopIVs, AffineMap()); - sliceState->ubs.resize(numSrcLoopIVs, AffineMap()); + sliceState->lbs.resize(numSliceLoopIVs, AffineMap()); + sliceState->ubs.resize(numSliceLoopIVs, AffineMap()); - // Get bounds for src IVs in terms of dst IVs, symbols, and constants. 
- dependenceConstraints.getSliceBounds(numSrcLoopIVs, - srcAccess.opInst->getContext(), - &sliceState->lbs, &sliceState->ubs); + // Get bounds for slice IVs in terms of other IVs, symbols, and constants. + dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs, + depSourceOp->getContext(), + &sliceState->lbs, &sliceState->ubs); // Set up bound operands for the slice's lower and upper bounds. SmallVector sliceBoundOperands; - dependenceConstraints.getIdValues( - numSrcLoopIVs, dependenceConstraints.getNumDimAndSymbolIds(), - &sliceBoundOperands); + unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds(); + for (unsigned i = 0; i < numDimsAndSymbols; ++i) { + if (i < offset || i >= offset + numSliceLoopIVs) { + sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i)); + } + } + // Give each bound its own copy of 'sliceBoundOperands' for subsequent // canonicalization. - sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); - sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); + sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); + + // Set destination loop nest insertion point to block start at 'dstLoopDepth'. + sliceState->insertPoint = + isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin() + : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); llvm::SmallDenseSet sequentialLoops; - if (readReadAccesses) { + if (isa(depSourceOp) && isa(depSinkOp)) { // For read-read access pairs, clear any slice bounds on sequential loops. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. - getSequentialLoops(srcLoopIVs[0], &sequentialLoops); + getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0], + &sequentialLoops); } // Clear all sliced loop bounds beginning at the first sequential loop, or // first loop with a slice fusion barrier attribute.. // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of // using 'kSliceFusionBarrierAttrName'. - for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - Value *iv = srcLoopIVs[i].getInductionVar(); + auto getSliceLoop = [&](unsigned i) { + return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i]; + }; + for (unsigned i = 0; i < numSliceLoopIVs; ++i) { + Value *iv = getSliceLoop(i).getInductionVar(); if (sequentialLoops.count(iv) == 0 && - srcLoopIVs[i].getAttr(kSliceFusionBarrierAttrName) == nullptr) + getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr) continue; - for (unsigned j = i; j < numSrcLoopIVs; ++j) { + for (unsigned j = i; j < numSliceLoopIVs; ++j) { sliceState->lbs[j] = AffineMap(); sliceState->ubs[j] = AffineMap(); } break; } - - return success(); } /// Creates a computation slice of the loop nest surrounding 'srcOpInst', diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 829b1b221ef..95890a68126 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1329,7 +1329,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, for (unsigned i = maxDstLoopDepth; i >= 1; --i) { // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. 
if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, - /*dstLoopDepth=*/i, + /*loopDepth=*/i, + /*numCommonLoops=*/0, + /*isBackwardSlice=*/true, &sliceStates[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed for loopDepth: " << i << "\n"); @@ -1736,15 +1738,16 @@ public: dstLoadOpInsts, dstStoreOpInsts, &sliceState, &bestDstLoopDepth, maximalFusion)) continue; - // TODO(andydavis) Remove assert and surrounding code when - // canFuseLoops is fully functional. + // TODO(andydavis) Remove the following test code when canFuseLoops + // is fully functional. mlir::ComputationSliceState sliceUnion; - FusionResult result = mlir::canFuseLoops( - cast(srcNode->op), cast(dstNode->op), - bestDstLoopDepth, &sliceUnion); - assert(result.value == FusionResult::Success); - (void)result; - + if (!maximalFusion) { + FusionResult result = mlir::canFuseLoops( + cast(srcNode->op), cast(dstNode->op), + bestDstLoopDepth, &sliceUnion); + assert(result.value == FusionResult::Success); + (void)result; + } // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); diff --git a/mlir/lib/Transforms/TestLoopFusion.cpp b/mlir/lib/Transforms/TestLoopFusion.cpp index 638cf915b6a..39990968a34 100644 --- a/mlir/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/lib/Transforms/TestLoopFusion.cpp @@ -45,6 +45,11 @@ static llvm::cl::opt clTestDependenceCheck( llvm::cl::desc("Enable testing of loop fusion dependence check"), llvm::cl::cat(clOptionsCategory)); +static llvm::cl::opt clTestSliceComputation( + "test-loop-fusion-slice-computation", + llvm::cl::desc("Enable testing of loop fusion slice computation"), + llvm::cl::cat(clOptionsCategory)); + namespace { struct TestLoopFusion : public FunctionPass { @@ -70,20 +75,74 @@ gatherLoops(Block *block, unsigned currLoopDepth, } } -// Run fusion dependence check on 'loops[i]' and 'loops[j]' at 'loopDepth'. +// Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths +// in range ['loopDepth' + 1, 'maxLoopDepth']. // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. static void testDependenceCheck(SmallVector &loops, unsigned i, - unsigned j, unsigned loopDepth) { + unsigned j, unsigned loopDepth, + unsigned maxLoopDepth) { AffineForOp srcForOp = loops[i]; AffineForOp dstForOp = loops[j]; mlir::ComputationSliceState sliceUnion; - // TODO(andydavis) Test at deeper loop depths current loop depth + 1. - FusionResult result = - mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion); - if (result.value == FusionResult::FailBlockDependence) { - srcForOp.getOperation()->emitRemark("block-level dependence preventing" - " fusion of loop nest ") - << i << " into loop nest " << j << " at depth " << loopDepth; + for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { + FusionResult result = + mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); + if (result.value == FusionResult::FailBlockDependence) { + srcForOp.getOperation()->emitRemark("block-level dependence preventing" + " fusion of loop nest ") + << i << " into loop nest " << j << " at depth " << loopDepth; + } + } +} + +// Returns the index of 'op' in its block. +static unsigned getBlockIndex(Operation &op) { + unsigned index = 0; + for (auto &opX : *op.getBlock()) { + if (&op == &opX) + break; + ++index; + } + return index; +} + +// Returns a string representation of 'sliceUnion'. 
+static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) { + std::string result; + llvm::raw_string_ostream os(result); + // Slice insertion point format [loop-depth, operation-block-index] + unsigned ipd = getNestingDepth(*sliceUnion.insertPoint); + unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); + os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) + << ")"; + assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); + os << " loop bounds: "; + for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { + os << '['; + sliceUnion.lbs[k].print(os); + os << ", "; + sliceUnion.ubs[k].print(os); + os << "] "; + } + return os.str(); +} + +// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths +// in range ['loopDepth' + 1, 'maxLoopDepth']. +// Emits a string representation of the slice union as a remark on 'loops[j]'. +static void testSliceComputation(SmallVector &loops, unsigned i, + unsigned j, unsigned loopDepth, + unsigned maxLoopDepth) { + AffineForOp forOpA = loops[i]; + AffineForOp forOpB = loops[j]; + for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { + mlir::ComputationSliceState sliceUnion; + FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); + if (result.value == FusionResult::Success) { + forOpB.getOperation()->emitRemark("slice (") + << " src loop: " << i << ", dst loop: " << j << ", depth: " << d + << " : " << getSliceStr(sliceUnion) << ")"; + } } } @@ -104,7 +163,9 @@ void TestLoopFusion::runOnFunction() { if (j == k) continue; if (clTestDependenceCheck) - testDependenceCheck(loops, j, k, loopDepth); + testDependenceCheck(loops, j, k, loopDepth, depthToLoops.size()); + if (clTestSliceComputation) + testSliceComputation(loops, j, k, loopDepth, depthToLoops.size()); } } } diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index cb1d9d17ed0..1fb41a2a5e2 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -192,11 +192,7 @@ gatherLoadsAndStores(AffineForOp forOp, return !hasIfOp; } -// TODO(andydavis) Add support for the following features in subsequent CLs: -// *) Compute dependences of unfused src/dst loops. -// *) Compute dependences of src/dst loop as if they were fused. -// *) Check for fusion preventing dependences (e.g. a dependence which changes -// from loop-independent to backward loop-carried after fusion). +// TODO(andydavis) Prevent fusion of loop nests with side-effecting operations. FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice) { @@ -219,24 +215,35 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, return FusionResult::FailBlockDependence; } - // Gather all load and store ops in 'srcForOp'. - SmallVector srcLoadAndStoreOps; - if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) { + // Check if 'srcForOp' precedes 'dstForOp' in 'block'. + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + // 'forOpA' executes before 'forOpB' in 'block'. + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + // Gather all loads and stores from 'forOpA' which precedes 'forOpB' in 'block'. 
+ SmallVector opsA; + if (!gatherLoadsAndStores(forOpA, opsA)) { LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); return FusionResult::FailPrecondition; } - // Gather all load and store ops in 'dstForOp'. - SmallVector dstLoadAndStoreOps; - if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) { + // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. + SmallVector opsB; + if (!gatherLoadsAndStores(forOpB, opsB)) { LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); return FusionResult::FailPrecondition; } - // Compute union of computation slices computed from all pairs in - // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}. - if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps, - dstLoopDepth, srcSlice))) { + // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. + unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( + *srcForOp.getOperation(), *dstForOp.getOperation()); + + // Compute union of computation slices computed between all pairs of ops + // from 'forOpA' and 'forOpB'. + if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops, + isSrcForOpBeforeDstForOp, srcSlice))) { LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); return FusionResult::FailPrecondition; } diff --git a/mlir/test/Transforms/loop-fusion-slice-computation.mlir b/mlir/test/Transforms/loop-fusion-slice-computation.mlir new file mode 100644 index 00000000000..9550cb72506 --- /dev/null +++ b/mlir/test/Transforms/loop-fusion-slice-computation.mlir @@ -0,0 +1,145 @@ +// RUN: mlir-opt %s -test-loop-fusion -test-loop-fusion-slice-computation -split-input-file -verify | FileCheck %s + +// ----- + +// CHECK-LABEL: func @slice_depth1_loop_nest() { +func @slice_depth1_loop_nest() { + %0 = alloc() : memref<100xf32> + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} + store %cst, %0[%i0] : memref<100xf32> + } + affine.for %i1 = 0 to 5 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} + %1 = load %0[%i1] : memref<100xf32> + } + return +} + +// ----- + +// Loop %i0 writes to locations [2, 17] and loop %i0 reads from locations [3, 6] +// Slice loop bounds should be adjusted such that the load/store are for the +// same location. +// CHECK-LABEL: func @slice_depth1_loop_nest_with_offsets() { +func @slice_depth1_loop_nest_with_offsets() { + %0 = alloc() : memref<100xf32> + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}} + %a0 = affine.apply (d0) -> (d0 + 2)(%i0) + store %cst, %0[%a0] : memref<100xf32> + } + affine.for %i1 = 4 to 8 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0 - 3), (d0) -> (d0 - 2)] )}} + %a1 = affine.apply (d0) -> (d0 - 1)(%i1) + %1 = load %0[%a1] : memref<100xf32> + } + return +} + +// ----- + +// Slices at loop depth 1 should only slice the loop bounds of the first loop. +// Slices at loop detph 2 should slice loop bounds of both loops. 
+// CHECK-LABEL: func @slice_depth2_loop_nest() { +func @slice_depth2_loop_nest() { + %0 = alloc() : memref<100x100xf32> + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} + // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + affine.for %i1 = 0 to 16 { + store %cst, %0[%i0, %i1] : memref<100x100xf32> + } + } + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} + // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + affine.for %i3 = 0 to 8 { + %1 = load %0[%i2, %i3] : memref<100x100xf32> + } + } + return +} + +// ----- + +// The load at depth 1 in loop nest %i2 prevents slicing loop nest %i0 at depths +// greater than 1. However, loop nest %i2 can be sliced into loop nest %i0 at +// depths 1 and 2 because the dependent store in loop nest %i0 is at depth 2. +// CHECK-LABEL: func @slice_depth2_loop_nest_two_loads() { +func @slice_depth2_loop_nest_two_loads() { + %0 = alloc() : memref<100x100xf32> + %c0 = constant 0 : index + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} + // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1)[s0] -> (d0), (d0, d1)[s0] -> (d0 + 1)] [(d0, d1)[s0] -> (0), (d0, d1)[s0] -> (8)] )}} + affine.for %i1 = 0 to 16 { + store %cst, %0[%i0, %i1] : memref<100x100xf32> + } + } + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} + affine.for %i3 = 0 to 8 { + %1 = load %0[%i2, %i3] : memref<100x100xf32> + } + %2 = load %0[%i2, %c0] : memref<100x100xf32> + } + return +} + +// ----- + +// The store at depth 1 in loop nest %i0 prevents slicing loop nest %i2 at +// depths greater than 1 into loop nest %i0. However, loop nest %i0 can be +// sliced into loop nest %i2 at depths 1 and 2 because the dependent load in +// loop nest %i2 is at depth 2. 
+// CHECK-LABEL: func @slice_depth2_loop_nest_two_stores() { +func @slice_depth2_loop_nest_two_stores() { + %0 = alloc() : memref<100x100xf32> + %c0 = constant 0 : index + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} + affine.for %i1 = 0 to 16 { + store %cst, %0[%i0, %i1] : memref<100x100xf32> + } + store %cst, %0[%i0, %c0] : memref<100x100xf32> + } + affine.for %i2 = 0 to 10 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (16)] )}} + // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1)[s0] -> (d0), (d0, d1)[s0] -> (d0 + 1)] [(d0, d1)[s0] -> (0), (d0, d1)[s0] -> (16)] )}} + affine.for %i3 = 0 to 8 { + %1 = load %0[%i2, %i3] : memref<100x100xf32> + } + } + return +} + +// ----- + +// Test loop nest which has a smaller outer trip count than its inner loop. +// CHECK-LABEL: func @slice_loop_nest_with_smaller_outer_trip_count() { +func @slice_loop_nest_with_smaller_outer_trip_count() { + %0 = alloc() : memref<100x100xf32> + %c0 = constant 0 : index + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} + // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + affine.for %i1 = 0 to 16 { + store %cst, %0[%i0, %i1] : memref<100x100xf32> + } + } + affine.for %i2 = 0 to 8 { + // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} + // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + affine.for %i3 = 0 to 10 { + %1 = load %0[%i2, %i3] : memref<100x100xf32> + } + } + return +} \ No newline at end of file -- cgit v1.2.3 From 59b68146ffb34adddcdf0e7daff3fe1c66badb4b Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Tue, 18 Jun 2019 08:52:09 -0700 Subject: Factor fusion compute cost calculation out of LoopFusion and into LoopFusionUtils (NFC). 
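A sketch of the intended call sequence for the utilities moved below (illustrative only; 'srcForOp' and 'dstForOp' are assumed to be the candidate fusion loop roots, and 'slice' a ComputationSliceState previously computed by canFuseLoops):

  mlir::LoopNestStats srcStats, dstStats;
  if (!mlir::getLoopNestStats(srcForOp, &srcStats) ||
      !mlir::getLoopNestStats(dstForOp, &dstStats)) {
    // Bail out, e.g. on non-constant trip counts.
  }
  // Cost of each nest in isolation: op count * trip count over the whole nest.
  int64_t srcCost = mlir::getComputeCost(srcForOp, srcStats);
  int64_t dstCost = mlir::getComputeCost(dstForOp, dstStats);
  // Cost of 'dstForOp' with the slice of 'srcForOp' fused into it.
  int64_t fusedCost;
  if (mlir::getFusionComputeCost(srcForOp, srcStats, dstForOp, dstStats,
                                 &slice, &fusedCost)) {
    // Compare 'fusedCost' with 'srcCost + dstCost' to decide profitability.
  }
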
PiperOrigin-RevId: 253797886 --- mlir/include/mlir/Transforms/LoopFusionUtils.h | 41 +++++ mlir/lib/Transforms/LoopFusion.cpp | 231 ++---------------------- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 234 +++++++++++++++++++++++++ 3 files changed, 285 insertions(+), 221 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h index ccda6693f88..b6d1ea41ce6 100644 --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -24,9 +24,13 @@ #ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H #define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + namespace mlir { class AffineForOp; struct ComputationSliceState; +class Operation; // TODO(andydavis) Extend this module to include utility functions for querying // fusion cost/storage reduction, and for performing the loop fusion @@ -54,6 +58,43 @@ struct FusionResult { FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice); + +/// LoopNestStats aggregates various per-loop statistics (eg. loop trip count +/// and operation count) for a loop nest up until (and including) the innermost +/// loop body. +struct LoopNestStats { + /// Map from AffineForOp to immediate child AffineForOps in its loop body. + llvm::DenseMap> loopMap; + /// Map from AffineForOp to count of operations in its loop body. + llvm::DenseMap opCountMap; + /// Map from AffineForOp to its constant trip count. + llvm::DenseMap tripCountMap; +}; + +/// Collect loop nest statistics (eg. loop trip count and operation count) +/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, +/// returns false otherwise. +// TODO(andydavis) Consider moving this to LoopUtils. +bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats); + +/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. +/// Currently, the total cost is computed by counting the total operation +/// instance count (i.e. total number of operations in the loop body * loop +/// trip count) for the entire loop nest. +// TODO(andydavis) Improve this cost model. +int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats); + +/// Computes and returns in 'computeCost', the total compute cost of fusing the +/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, +/// the total cost is computed by counting the total operation instance count +/// (i.e. total number of operations in the loop body * loop trip count) for +/// the entire loop nest. +/// Returns true on success, failure otherwise (e.g. non-constant trip counts). +// TODO(andydavis) Improve this cost model. +bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, + AffineForOp dstForOp, LoopNestStats &dstStats, + ComputationSliceState *slice, int64_t *computeCost); + } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 95890a68126..8d2e75b2dca 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -732,156 +732,6 @@ bool MemRefDependenceGraph::init(Function &f) { return true; } -namespace { - -// LoopNestStats aggregates various per-loop statistics (eg. loop trip count -// and operation count) for a loop nest up until the innermost loop body. 
-struct LoopNestStats { - // Map from AffineForOp to immediate child AffineForOps in its loop body. - DenseMap> loopMap; - // Map from AffineForOp to count of operations in its loop body. - DenseMap opCountMap; - // Map from AffineForOp to its constant trip count. - DenseMap tripCountMap; -}; - -// LoopNestStatsCollector walks a single loop nest and gathers per-loop -// trip count and operation count statistics and records them in 'stats'. -struct LoopNestStatsCollector { - LoopNestStats *stats; - bool hasLoopWithNonConstTripCount = false; - - LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - - void collect(Operation *op) { - op->walk([&](AffineForOp forOp) { - auto *forInst = forOp.getOperation(); - auto *parentInst = forOp.getOperation()->getParentOp(); - if (parentInst != nullptr) { - assert(isa(parentInst) && "Expected parent AffineForOp"); - // Add mapping to 'forOp' from its parent AffineForOp. - stats->loopMap[parentInst].push_back(forOp); - } - - // Record the number of op operations in the body of 'forOp'. - unsigned count = 0; - stats->opCountMap[forInst] = 0; - for (auto &op : *forOp.getBody()) { - if (!isa(op) && !isa(op)) - ++count; - } - stats->opCountMap[forInst] = count; - // Record trip count for 'forOp'. Set flag if trip count is not - // constant. - Optional maybeConstTripCount = getConstantTripCount(forOp); - if (!maybeConstTripCount.hasValue()) { - hasLoopWithNonConstTripCount = true; - return; - } - stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); - }); - } -}; - -// Computes the total cost of the loop nest rooted at 'forOp'. -// Currently, the total cost is computed by counting the total operation -// instance count (i.e. total number of operations in the loop bodyloop -// operation count * loop trip count) for the entire loop nest. -// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops -// specified in the map when computing the total op instance count. -// NOTEs: 1) This is used to compute the cost of computation slices, which are -// sliced along the iteration dimension, and thus reduce the trip count. -// If 'computeCostMap' is non-null, the total op count for forOps specified -// in the map is increased (not overridden) by adding the op count from the -// map to the existing op count for the for loop. This is done before -// multiplying by the loop's trip count, and is used to model the cost of -// inserting a sliced loop nest of known cost into the loop's body. -// 2) This is also used to compute the cost of fusing a slice of some loop nest -// within another loop. -static int64_t getComputeCost( - Operation *forInst, LoopNestStats *stats, - llvm::SmallDenseMap *tripCountOverrideMap, - DenseMap *computeCostMap) { - // 'opCount' is the total number operations in one iteration of 'forOp' body, - // minus terminator op which is a no-op. - int64_t opCount = stats->opCountMap[forInst] - 1; - if (stats->loopMap.count(forInst) > 0) { - for (auto childForOp : stats->loopMap[forInst]) { - opCount += getComputeCost(childForOp.getOperation(), stats, - tripCountOverrideMap, computeCostMap); - } - } - // Add in additional op instances from slice (if specified in map). - if (computeCostMap != nullptr) { - auto it = computeCostMap->find(forInst); - if (it != computeCostMap->end()) { - opCount += it->second; - } - } - // Override trip count (if specified in map). 
- int64_t tripCount = stats->tripCountMap[forInst]; - if (tripCountOverrideMap != nullptr) { - auto it = tripCountOverrideMap->find(forInst); - if (it != tripCountOverrideMap->end()) { - tripCount = it->second; - } - } - // Returns the total number of dynamic instances of operations in loop body. - return tripCount * opCount; -} - -} // end anonymous namespace - -// TODO(andydavis,b/126426796): extend this to handle multiple result maps. -static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { - assert(lbMap.getNumResults() == 1 && "expected single result bound map"); - assert(ubMap.getNumResults() == 1 && "expected single result bound map"); - assert(lbMap.getNumDims() == ubMap.getNumDims()); - assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); - AffineExpr lbExpr(lbMap.getResult(0)); - AffineExpr ubExpr(ubMap.getResult(0)); - auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), - lbMap.getNumSymbols()); - auto cExpr = loopSpanExpr.dyn_cast(); - if (!cExpr) - return None; - return cExpr.getValue(); -} - -// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop -// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'. -// Returns true on success, false otherwise (if a non-constant trip count -// was encountered). -// TODO(andydavis) Make this work with non-unit step loops. -static bool buildSliceTripCountMap( - Operation *srcOpInst, ComputationSliceState *sliceState, - llvm::SmallDenseMap *tripCountMap) { - SmallVector srcLoopIVs; - getLoopIVs(*srcOpInst, &srcLoopIVs); - unsigned numSrcLoopIVs = srcLoopIVs.size(); - // Populate map from AffineForOp -> trip count - for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - AffineMap lbMap = sliceState->lbs[i]; - AffineMap ubMap = sliceState->ubs[i]; - if (lbMap == AffineMap() || ubMap == AffineMap()) { - // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. - if (srcLoopIVs[i].hasConstantLowerBound() && - srcLoopIVs[i].hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i].getOperation()] = - srcLoopIVs[i].getConstantUpperBound() - - srcLoopIVs[i].getConstantLowerBound(); - continue; - } - return false; - } - Optional tripCount = getConstDifference(lbMap, ubMap); - if (!tripCount.hasValue()) - return false; - (*tripCountMap)[srcLoopIVs[i].getOperation()] = tripCount.getValue(); - } - return true; -} - // Removes load operations from 'srcLoads' which operate on 'memref', and // adds them to 'dstLoads'. static void moveLoadsAccessingMemrefTo(Value *memref, @@ -1110,16 +960,6 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, return newMemRef; } -// Return the number of iterations in the given slice. -static uint64_t getSliceIterationCount( - const llvm::SmallDenseMap &sliceTripCountMap) { - uint64_t iterCount = 1; - for (const auto &count : sliceTripCountMap) { - iterCount *= count.second; - } - return iterCount; -} - // Checks if node 'srcId' (which writes to a live out memref), can be safely // fused into node 'dstId'. Returns true if the following conditions are met: // *) 'srcNode' only writes to live out 'memref'. @@ -1250,25 +1090,16 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; - LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.collect(srcLoopIVs[0].getOperation()); - // Currently only constant trip count loop nests are supported. 
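
For intuition on the operation-instance-count model being moved out of this file, consider a hypothetical two-deep nest: the inner loop has trip count 50 and 3 non-terminator ops in its body, and the outer loop's body contains only the inner loop. Then opCount(inner) = 3 and cost(inner) = 50 * 3 = 150; opCount(outer) = 0 + cost(inner) = 150, so a nest with outer trip count 100 costs 100 * 150 = 15,000 dynamic operation instances. Overriding the outer trip count to 1 through 'tripCountOverrideMap', as is done when costing a single-iteration slice, reduces the estimate to 150.
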
- if (srcStatsCollector.hasLoopWithNonConstTripCount) { - LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); + if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) return false; - } + // Compute cost of dst loop nest. SmallVector dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; - LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.collect(dstLoopIVs[0].getOperation()); - // Currently only constant trip count loop nests are supported. - if (dstStatsCollector.hasLoopWithNonConstTripCount) { - LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count loops unsupported.\n"); + if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) return false; - } // Compute the maximum loop depth at which we can can insert the src slice // and still satisfy dest loop nest dependences, for producer-consumer fusion. @@ -1297,10 +1128,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, Optional bestDstLoopDepth = None; // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = - getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); // Compute src loop nest write region size. MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); @@ -1317,15 +1145,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = - getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats); // Evaluate all depth choices for materializing the slice in the destination // loop nest. - llvm::SmallDenseMap sliceTripCountMap; - DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, @@ -1338,47 +1161,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, continue; } - // Build trip count map for computation slice. We'll skip cases where the - // trip count was non-constant. - sliceTripCountMap.clear(); - if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], - &sliceTripCountMap)) { - LLVM_DEBUG(llvm::dbgs() << "Unable to build slice trip count map.\n."); + int64_t fusedLoopNestComputeCost; + if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], + dstLoopNestStats, &sliceStates[i - 1], + &fusedLoopNestComputeCost)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; } - // Checks whether a store to load forwarding will happen. - int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); - assert(sliceIterationCount > 0); - bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); - - // Compute cost of fusion for this dest loop depth. - - computeCostMap.clear(); - - // The store and loads to this memref will disappear. - // TODO(andydavis) Add load coalescing to memref data flow opt pass. - if (storeLoadFwdGuaranteed) { - // A single store disappears: -1 for that. 
- computeCostMap[srcLoopIVs[numSrcLoopIVs - 1].getOperation()] = -1; - for (auto *loadOp : dstLoadOpInsts) - if (auto forOp = dyn_cast_or_null(loadOp->getParentOp())) - computeCostMap[forOp] = -1; - } - - // Compute op instance count for the src loop nest with iteration slicing. - int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0].getOperation(), &srcLoopNestStats, - /*tripCountOverrideMap=*/&sliceTripCountMap, - /*computeCostMap=*/&computeCostMap); - - // Compute cost of fusion for this depth. - computeCostMap[dstLoopIVs[i - 1].getOperation()] = sliceComputeCost; - - int64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0].getOperation(), &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, &computeCostMap); - double additionalComputeFraction = fusedLoopNestComputeCost / (static_cast(srcLoopNestCost) + dstLoopNestCost) - @@ -1427,7 +1217,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, << 100.0 * additionalComputeFraction << "%\n" << " storage reduction factor: " << storageReduction << "x\n" << " fused nest cost: " << fusedLoopNestComputeCost << "\n" - << " slice iteration count: " << sliceIterationCount << "\n" << " src write region size: " << srcWriteRegionSizeBytes << "\n" << " slice write region size: " << sliceWriteRegionSizeBytes << "\n"; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 1fb41a2a5e2..93503d11e0a 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -24,6 +24,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -250,3 +251,236 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, return FusionResult::Success; } + +/// Collect loop nest statistics (eg. loop trip count and operation count) +/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, +/// returns false otherwise. +bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { + bool ret = true; + forOpRoot.getOperation()->walk([&](AffineForOp forOp) { + auto *childForOp = forOp.getOperation(); + auto *parentForOp = forOp.getOperation()->getParentOp(); + if (parentForOp != nullptr) { + if (!isa(parentForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); + ret = false; + return; + } + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentForOp].push_back(forOp); + } + + // Record the number of op operations in the body of 'forOp'. + unsigned count = 0; + stats->opCountMap[childForOp] = 0; + for (auto &op : *forOp.getBody()) { + if (!isa(op) && !isa(op)) + ++count; + } + stats->opCountMap[childForOp] = count; + // Record trip count for 'forOp'. Set flag if trip count is not + // constant. + Optional maybeConstTripCount = getConstantTripCount(forOp); + if (!maybeConstTripCount.hasValue()) { + // Currently only constant trip count loop nests are supported. + LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); + ret = false; + return; + } + stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); + }); + return ret; +} + +// Computes the total cost of the loop nest rooted at 'forOp'. +// Currently, the total cost is computed by counting the total operation +// instance count (i.e. 
total number of operations in the loop bodyloop +// operation count * loop trip count) for the entire loop nest. +// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops +// specified in the map when computing the total op instance count. +// NOTEs: 1) This is used to compute the cost of computation slices, which are +// sliced along the iteration dimension, and thus reduce the trip count. +// If 'computeCostMap' is non-null, the total op count for forOps specified +// in the map is increased (not overridden) by adding the op count from the +// map to the existing op count for the for loop. This is done before +// multiplying by the loop's trip count, and is used to model the cost of +// inserting a sliced loop nest of known cost into the loop's body. +// 2) This is also used to compute the cost of fusing a slice of some loop nest +// within another loop. +static int64_t getComputeCostHelper( + Operation *forOp, LoopNestStats &stats, + llvm::SmallDenseMap *tripCountOverrideMap, + DenseMap *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forOp' body, + // minus terminator op which is a no-op. + int64_t opCount = stats.opCountMap[forOp] - 1; + if (stats.loopMap.count(forOp) > 0) { + for (auto childForOp : stats.loopMap[forOp]) { + opCount += getComputeCostHelper(childForOp.getOperation(), stats, + tripCountOverrideMap, computeCostMap); + } + } + // Add in additional op instances from slice (if specified in map). + if (computeCostMap != nullptr) { + auto it = computeCostMap->find(forOp); + if (it != computeCostMap->end()) { + opCount += it->second; + } + } + // Override trip count (if specified in map). + int64_t tripCount = stats.tripCountMap[forOp]; + if (tripCountOverrideMap != nullptr) { + auto it = tripCountOverrideMap->find(forOp); + if (it != tripCountOverrideMap->end()) { + tripCount = it->second; + } + } + // Returns the total number of dynamic instances of operations in loop body. + return tripCount * opCount; +} + +// TODO(andydavis,b/126426796): extend this to handle multiple result maps. +static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { + assert(lbMap.getNumResults() == 1 && "expected single result bound map"); + assert(ubMap.getNumResults() == 1 && "expected single result bound map"); + assert(lbMap.getNumDims() == ubMap.getNumDims()); + assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); + AffineExpr lbExpr(lbMap.getResult(0)); + AffineExpr ubExpr(ubMap.getResult(0)); + auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), + lbMap.getNumSymbols()); + auto cExpr = loopSpanExpr.dyn_cast(); + if (!cExpr) + return None; + return cExpr.getValue(); +} + +// Return the number of iterations in the given slice. +static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + +// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop +// nest surrounding represented by slice loop bounds in 'slice'. +// Returns true on success, false otherwise (if a non-constant trip count +// was encountered). +// TODO(andydavis) Make this work with non-unit step loops. 
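
To see how getConstDifference and getSliceIterationCount feed the slice cost: for slice bound maps like those reported by the loop-nest slicing tests earlier in this log, lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 1) give ubExpr - lbExpr = 1, a constant, so that sliced loop contributes a trip count of 1; bounds (d0) -> (0) and (d0) -> (8) give a constant difference of 8. If the difference does not simplify to a constant, getConstDifference returns None and the slice cost computation bails out. getSliceIterationCount then multiplies the per-loop trip counts, so an all-single-iteration slice has 1 * 1 = 1 iteration.
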
+static bool buildSliceTripCountMap( + ComputationSliceState *slice, + llvm::SmallDenseMap *tripCountMap) { + unsigned numSrcLoopIVs = slice->ivs.size(); + // Populate map from AffineForOp -> trip count + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]); + auto *op = forOp.getOperation(); + AffineMap lbMap = slice->lbs[i]; + AffineMap ubMap = slice->ubs[i]; + if (lbMap == AffineMap() || ubMap == AffineMap()) { + // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. + if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { + (*tripCountMap)[op] = + forOp.getConstantUpperBound() - forOp.getConstantLowerBound(); + continue; + } + Optional maybeConstTripCount = getConstantTripCount(forOp); + if (maybeConstTripCount.hasValue()) { + (*tripCountMap)[op] = maybeConstTripCount.getValue(); + continue; + } + return false; + } + Optional tripCount = getConstDifference(lbMap, ubMap); + // Slice bounds are created with a constant ub - lb difference. + if (!tripCount.hasValue()) + return false; + (*tripCountMap)[op] = tripCount.getValue(); + } + return true; +} + +/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. +/// Currently, the total cost is computed by counting the total operation +/// instance count (i.e. total number of operations in the loop body * loop +/// trip count) for the entire loop nest. +int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { + return getComputeCostHelper(forOp.getOperation(), stats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); +} + +/// Computes and returns in 'computeCost', the total compute cost of fusing the +/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, +/// the total cost is computed by counting the total operation instance count +/// (i.e. total number of operations in the loop body * loop trip count) for +/// the entire loop nest. +bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, + AffineForOp dstForOp, LoopNestStats &dstStats, + ComputationSliceState *slice, + int64_t *computeCost) { + llvm::SmallDenseMap sliceTripCountMap; + DenseMap computeCostMap; + + // Build trip count map for computation slice. + if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) + return false; + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + assert(sliceIterationCount > 0); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + auto *insertPointParent = slice->insertPoint->getParentOp(); + + // The store and loads to this memref will disappear. + // TODO(andydavis) Add load coalescing to memref data flow opt pass. + if (storeLoadFwdGuaranteed) { + // Subtract from operation count the loads/store we expect load/store + // forwarding to remove. + unsigned storeCount = 0; + llvm::SmallDenseSet storeMemrefs; + srcForOp.getOperation()->walk([&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + storeMemrefs.insert(storeOp.getMemRef()); + ++storeCount; + } + }); + // Subtract out any store ops in single-iteration src slice loop nest. + if (storeCount > 0) + computeCostMap[insertPointParent] = -storeCount; + // Subtract out any load users of 'storeMemrefs' nested below + // 'insertPointParent'. 
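
A single-iteration slice (sliceIterationCount == 1) is exactly the storeLoadFwdGuaranteed case handled here: the stores in the source nest, and any loads of the same memrefs nested under the insertion point in the destination nest, are expected to disappear via store-to-load forwarding, so their counts are subtracted from the corresponding loops in 'computeCostMap' before the fused nest is costed with getComputeCostHelper.
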
+ for (auto *value : storeMemrefs) { + for (auto *user : value->getUsers()) { + if (auto loadOp = dyn_cast(user)) { + SmallVector loops; + // Check if any loop in loop nest surrounding 'user' is + // 'insertPointParent'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast(insertPointParent))) { + if (auto forOp = + dyn_cast_or_null(user->getParentOp())) { + if (computeCostMap.count(forOp) == 0) + computeCostMap[forOp] = 0; + computeCostMap[forOp] -= 1; + } + } + } + } + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = getComputeCostHelper( + srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); + + // Compute cost of fusion for this depth. + computeCostMap[insertPointParent] = sliceComputeCost; + + *computeCost = + getComputeCostHelper(dstForOp.getOperation(), dstStats, + /*tripCountOverrideMap=*/nullptr, &computeCostMap); + return true; +} -- cgit v1.2.3 From 54cd6a7e97a226738e2c85b86559918dd9e3cd5d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 1 Jul 2019 10:29:09 -0700 Subject: NFC: Refactor Function to be value typed. Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed). PiperOrigin-RevId: 255983022 --- mlir/bindings/python/pybind.cpp | 35 +-- .../Linalg/Linalg1/include/linalg1/Common.h | 24 +-- mlir/examples/Linalg/Linalg2/Example.cpp | 20 +- mlir/examples/Linalg/Linalg3/Conversion.cpp | 22 +- mlir/examples/Linalg/Linalg3/Example.cpp | 32 +-- mlir/examples/Linalg/Linalg3/Execution.cpp | 22 +- .../Linalg/Linalg3/include/linalg3/Transforms.h | 6 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 12 +- mlir/examples/Linalg/Linalg4/Example.cpp | 40 ++-- .../Linalg/Linalg4/include/linalg4/Transforms.h | 4 +- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 9 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 54 ++--- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 22 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 28 +-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 54 ++--- mlir/include/mlir/Analysis/Dominance.h | 4 +- mlir/include/mlir/Analysis/NestedMatcher.h | 4 +- mlir/include/mlir/ExecutionEngine/MemRefUtils.h | 2 +- mlir/include/mlir/GPU/GPUDialect.h | 6 +- mlir/include/mlir/IR/Attributes.h | 3 +- mlir/include/mlir/IR/Block.h | 2 +- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Dialect.h | 10 +- mlir/include/mlir/IR/Function.h | 234 ++++++++++++++------- mlir/include/mlir/IR/Module.h | 66 ++++-- mlir/include/mlir/IR/Operation.h | 2 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/IR/Region.h | 11 +- mlir/include/mlir/IR/SymbolTable.h | 12 +- mlir/include/mlir/IR/Value.h | 4 +- mlir/include/mlir/LLVMIR/LLVMDialect.h | 2 +- mlir/include/mlir/Pass/AnalysisManager.h | 18 +- mlir/include/mlir/Pass/Pass.h | 17 +- mlir/include/mlir/Pass/PassInstrumentation.h | 10 +- mlir/include/mlir/StandardOps/Ops.td | 4 +- mlir/include/mlir/Transforms/DialectConversion.h | 4 +- mlir/include/mlir/Transforms/LowerAffine.h | 2 +- mlir/include/mlir/Transforms/ViewFunctionGraph.h | 4 +- mlir/lib/AffineOps/AffineOps.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 9 +- 
mlir/lib/Analysis/OpStats.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Verifier.cpp | 14 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 6 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 106 +++++----- .../GPUToCUDA/GenerateCubinAccessors.cpp | 26 +-- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 18 +- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 4 +- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 +- .../QuantOps/Transforms/ConvertSimQuant.cpp | 2 +- mlir/lib/ExecutionEngine/MemRefUtils.cpp | 10 +- mlir/lib/GPU/IR/GPUDialect.cpp | 18 +- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 28 +-- mlir/lib/IR/AsmPrinter.cpp | 50 ++--- mlir/lib/IR/Attributes.cpp | 5 - mlir/lib/IR/Block.cpp | 2 +- mlir/lib/IR/Builders.cpp | 4 +- mlir/lib/IR/Dialect.cpp | 15 ++ mlir/lib/IR/Function.cpp | 59 +++--- mlir/lib/IR/Operation.cpp | 9 +- mlir/lib/IR/Region.cpp | 10 +- mlir/lib/IR/SymbolTable.cpp | 22 +- mlir/lib/IR/Value.cpp | 8 +- mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 4 +- mlir/lib/Linalg/Transforms/Fusion.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 69 +++--- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 3 +- mlir/lib/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 28 +-- mlir/lib/Pass/IRPrinting.cpp | 12 +- mlir/lib/Pass/Pass.cpp | 23 +- mlir/lib/Pass/PassDetail.h | 2 +- .../Transforms/AddDefaultStatsTestPass.cpp | 2 +- .../Transforms/InferQuantizedTypesPass.cpp | 2 +- .../Transforms/RemoveInstrumentationPass.cpp | 2 +- mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp | 6 +- mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp | 2 +- .../SPIRV/Transforms/StdOpsToSPIRVConversion.cpp | 2 +- mlir/lib/StandardOps/Ops.cpp | 12 +- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 31 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 4 +- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- mlir/lib/Transforms/DialectConversion.cpp | 54 ++--- mlir/lib/Transforms/DmaGeneration.cpp | 10 +- mlir/lib/Transforms/LoopFusion.cpp | 12 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 8 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 12 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 4 +- mlir/lib/Transforms/ViewFunctionGraph.cpp | 4 +- mlir/test/EDSC/builder-api-test.cpp | 150 +++++++------ .../test/lib/Transforms/TestVectorizationUtils.cpp | 16 +- mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp | 18 +- mlir/unittests/Pass/AnalysisManagerTest.cpp | 20 +- 103 files changed, 986 insertions(+), 874 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 222ef52b9be..cdf4a7fe89c 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/IR/Function.h" #include "llvm/IR/Module.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" @@ -110,13 +111,14 @@ struct PythonValueHandle { struct PythonFunction { PythonFunction() : function{nullptr} {} PythonFunction(mlir_func_t f) : function{f} {} - PythonFunction(mlir::Function *f) : function{f} {} + PythonFunction(mlir::Function f) + : function(const_cast(f.getAsOpaquePointer())) {} operator mlir_func_t() { return 
function; } std::string str() { - mlir::Function *f = reinterpret_cast(function); + mlir::Function f = mlir::Function::getFromOpaquePointer(function); std::string res; llvm::raw_string_ostream os(res); - f->print(os); + f.print(os); return res; } @@ -124,18 +126,18 @@ struct PythonFunction { // declaration, add the entry block, transforming the declaration into a // definition. Return true if the block was added, false otherwise. bool define() { - auto *f = reinterpret_cast(function); - if (!f->getBlocks().empty()) + auto f = mlir::Function::getFromOpaquePointer(function); + if (!f.getBlocks().empty()) return false; - f->addEntryBlock(); + f.addEntryBlock(); return true; } PythonValueHandle arg(unsigned index) { - Function *f = static_cast(function); - assert(index < f->getNumArguments() && "argument index out of bounds"); - return PythonValueHandle(ValueHandle(f->getArgument(index))); + auto f = mlir::Function::getFromOpaquePointer(function); + assert(index < f.getNumArguments() && "argument index out of bounds"); + return PythonValueHandle(ValueHandle(f.getArgument(index))); } mlir_func_t function; @@ -250,10 +252,9 @@ struct PythonFunctionContext { PythonFunction enter() { assert(function.function && "function is not set up"); - auto *mlirFunc = static_cast(function.function); - contextBuilder.emplace(mlirFunc->getBody()); - context = - new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc()); + auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function); + contextBuilder.emplace(mlirFunc.getBody()); + context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc()); return function; } @@ -594,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name, } // Create the function itself. - auto *func = new mlir::Function( + auto func = mlir::Function::create( UnknownLoc::get(&mlirContext), name, mlir::Type::getFromOpaquePointer(funcType).cast(), attrs, inputAttrs); @@ -652,9 +653,9 @@ PYBIND11_MODULE(pybind, m) { return ValueHandle::create(value, floatType); }); m.def("constant_function", [](PythonFunction func) -> PythonValueHandle { - auto *function = reinterpret_cast(func.function); - auto attr = FunctionAttr::get(function); - return ValueHandle::create(function->getType(), attr); + auto function = Function::getFromOpaquePointer(func.function); + auto attr = FunctionAttr::get(function.getName(), function.getContext()); + return ValueHandle::create(function.getType(), attr); }); m.def("appendTo", [](const PythonBlockHandle &handle) { return PythonBlockAppender(handle); diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index ddd6df9fb89..1f129c6b283 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, } /// A basic function builder -inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, - llvm::ArrayRef types, - llvm::ArrayRef resultTypes) { +inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name, + llvm::ArrayRef types, + llvm::ArrayRef resultTypes) { auto *context = module.getContext(); - auto *function = new mlir::Function( + auto function = mlir::Function::create( mlir::UnknownLoc::get(context), name, mlir::FunctionType::get({types}, resultTypes, context)); - function->addEntryBlock(); - module.getFunctions().push_back(function); + function.addEntryBlock(); + 
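
A minimal sketch of the value-typed idiom this commit introduces (loc, type, and module are assumed to be an existing Location, FunctionType, and Module; the snippet is illustrative, not part of the patch):

// Before: functions were heap-allocated and manipulated through pointers:
//   mlir::Function *f = new mlir::Function(loc, "foo", type);
//   module.getFunctions().push_back(f);
//   f->addEntryBlock();
// After: Function is value-typed and created via a static factory.
mlir::Function f = mlir::Function::create(loc, "foo", type);
module.push_back(f);
f.addEntryBlock();

// Round-tripping through an opaque handle, as the Python bindings above do:
void *handle = const_cast<void *>(f.getAsOpaquePointer());
mlir::Function same = mlir::Function::getFromOpaquePointer(handle);
(void)same;
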
module.push_back(function); return function; } @@ -83,19 +83,19 @@ inline std::unique_ptr cleanupPassManager() { /// llvm::outs() for FileCheck'ing. /// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() /// which will make the associated FileCheck test fail. -inline void cleanupAndPrintFunction(mlir::Function *f) { +inline void cleanupAndPrintFunction(mlir::Function f) { bool printToOuts = true; - auto check = [f, &printToOuts](mlir::LogicalResult result) { + auto check = [&f, &printToOuts](mlir::LogicalResult result) { if (failed(result)) { - f->emitError("Verification and cleanup passes failed"); + f.emitError("Verification and cleanup passes failed"); printToOuts = false; } }; auto pm = cleanupPassManager(); - check(f->getModule()->verify()); - check(pm->run(f->getModule())); + check(f.getModule()->verify()); + check(pm->run(f.getModule())); if (printToOuts) - f->print(llvm::outs()); + f.print(llvm::outs()); } /// Helper class to sugar building loop nests from indexings that appear in diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index a415daebdf5..9534711f1f4 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = + mlir::Function f = makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices", - {indexType, indexType, indexType}, {}); + mlir::Function f = makeFunction(module, "linalg_ops_folded_slices", + {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) { // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view // clang-format on - f->walk([](SliceOp slice) { + f.walk([](SliceOp slice) { auto *sliceResult = slice.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 37d1b51f53e..23d1cfef5dc 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using 
namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(foo) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index f02aef920e4..8b04344b19e 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -34,26 +34,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_as_matvec) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); + mlir::Function f = 
makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); // clang-format off @@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) { TEST_FUNC(matmul_as_dot) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); @@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) { TEST_FUNC(matmul_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); composeSliceOps(f); // clang-format off @@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) { TEST_FUNC(matmul_as_matvec_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); lowerToFinerGrainedTensorContraction(f); lowerToLoops(f); @@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) { TEST_FUNC(matmul_as_matvec_as_affine) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); lowerToLoops(f); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 00d571cbc99..94b233a56b0 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -110,7 +110,7 @@ TEST_FUNC(execution) { // dialect through partial conversions. 
MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 9af528e8c51..6c0aec0b000 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, /// Traverses `f` and rewrites linalg.slice, and the operations it depends on, /// to only use linalg.view operations. -void composeSliceOps(mlir::Function *f); +void composeSliceOps(mlir::Function f); /// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) /// as linalg.matvec(resp. linalg.dot). -void lowerToFinerGrainedTensorContraction(mlir::Function *f); +void lowerToFinerGrainedTensorContraction(mlir::Function f); /// Operation-wise writing of linalg operations to loop form. /// It is the caller's responsibility to erase the `op` if necessary. @@ -69,7 +69,7 @@ llvm::Optional> writeAsLoops(mlir::Operation *op); /// Traverses `f` and rewrites linalg operations in loop form. -void lowerToLoops(mlir::Function *f); +void lowerToLoops(mlir::Function f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 7b559bf2f21..96b0f371ef1 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns( void linalg::convertLinalg3ToLLVM(Module &module) { // Remove affine constructs. 
- for (auto &func : module) { + for (auto func : module) { auto rr = lowerAffineConstructs(func); (void)rr; assert(succeeded(rr) && "affine loop lowering failed"); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index d5c8641acbe..7b9e5ffee96 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; -void linalg::composeSliceOps(mlir::Function *f) { - f->walk([](SliceOp sliceOp) { +void linalg::composeSliceOps(mlir::Function f) { + f.walk([](SliceOp sliceOp) { auto *sliceResult = sliceOp.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); @@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) { }); } -void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) { + f.walk([](Operation *op) { if (auto matmulOp = dyn_cast(op)) { matmulOp.writeAsFinerGrainTensorContraction(); } else if (auto matvecOp = dyn_cast(op)) { @@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) { return llvm::None; } -void linalg::lowerToLoops(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToLoops(mlir::Function f) { + f.walk([](Operation *op) { if (writeAsLoops(op)) op->erase(); }); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index cdc05a1cc21..873e57e78f3 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -34,27 +34,27 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_tiled_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); 
- if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off @@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) { TEST_FUNC(matmul_tiled_views) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create(f->getLoc(), 8), - b.create(f->getLoc(), 9)}); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create(f.getLoc(), 8), + b.create(f.getLoc(), 9)}); composeSliceOps(f); // clang-format off @@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) { TEST_FUNC(matmul_tiled_views_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create(f->getLoc(), 8), - b.create(f->getLoc(), 9)}); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create(f.getLoc(), 8), + b.create(f.getLoc(), 9)}); composeSliceOps(f); lowerToLoops(f); // This cannot lower below linalg.load and linalg.store due to lost diff --git a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h index 2165cab6ac1..ba7273e409d 100644 --- a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h +++ b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h @@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef tileSizes); +void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. 
-void lowerToTiledViews(mlir::Function *f, +void lowerToTiledViews(mlir::Function f, llvm::ArrayRef tileSizes); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 1a308df1313..16b395da506 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledLoops(mlir::Function *f, - ArrayRef tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef tileSizes) { + f.walk([tileSizes](Operation *op) { if (writeAsTiledLoops(op, tileSizes).hasValue()) op->erase(); }); @@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledViews(mlir::Function f, ArrayRef tileSizes) { + f.walk([tileSizes](Operation *op) { if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); } else if (auto matvecOp = dyn_cast(op)) { diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 842c7a1d0f8..73789fa41a4 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -75,7 +75,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -129,40 +129,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. 
- function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -172,16 +172,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index e365f37f8c8..23cb85309c2 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. 
for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 032766a547f..f2132c29c33 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. 
- builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 688c73645a5..f237fd9fb53 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -113,14 +113,14 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector argumentsType; }; void runOnModule() override { auto &module = getModule(); - auto *main = module.getNamedFunction("main"); + auto main = module.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -139,7 +139,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -153,7 +153,7 @@ public: mlir::LogicalResult specialize(SmallVectorImpl &funcWorklist) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -169,36 +169,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); - getModule().getFunctions().push_back(newFunction); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); + getModule().push_back(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? 
- auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -211,7 +211,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast().isGeneric()) @@ -292,9 +292,9 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = getModule().getNamedFunction(calleeName); + auto callee = getModule().getNamedFunction(calleeName); if (!callee) { - f->emitError("Shape inference failed, call to unknown '") + f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; signalPassFailure(); return mlir::failure(); @@ -302,7 +302,7 @@ public: auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = getModule().getNamedFunction(mangledName); + auto mangledCallee = getModule().getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -327,7 +327,7 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { @@ -337,31 +337,31 @@ public: << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) errorMsg << " - " << *ope << "\n"; - f->emitError(errorMsg.str()); + f.emitError(errorMsg.str()); signalPassFailure(); return mlir::failure(); } // Finally, update the return type of the function based on the argument to // the return operation. 
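The specialization path in the ShapeInferencePass hunk above clones a generic function into a freshly created one and then re-adds the entry-block arguments with the concrete types. A compressed sketch of that sequence, using only calls that appear in the hunk; the helper name and the argTypes/resultType parameters are placeholders.

static mlir::Function specializeClone(mlir::Function f, mlir::Module &module,
                                      llvm::StringRef mangledName,
                                      llvm::ArrayRef<mlir::Type> argTypes,
                                      mlir::Type resultType,
                                      mlir::MLIRContext *ctx) {
  auto type = mlir::FunctionType::get(argTypes, {resultType}, ctx);
  auto newFunction =
      mlir::Function::create(f.getLoc(), mangledName, type, f.getAttrs());
  module.push_back(newFunction);

  // Clone the body. cloneInto does not rewrite entry-block arguments, so the
  // new ones are appended and the stale ones erased one by one; after each
  // eraseArgument(0) the matching new argument is again at index oldArgCount.
  mlir::BlockAndValueMapping mapper;
  f.cloneInto(newFunction, mapper);
  auto &entryBlock = newFunction.getBlocks().front();
  int oldArgCount = entryBlock.getArguments().size();
  entryBlock.addArguments(newFunction.getType().getInputs());
  for (int i = 0; i < oldArgCount; ++i) {
    entryBlock.getArguments()[0]->replaceAllUsesWith(
        entryBlock.getArguments()[oldArgCount]);
    entryBlock.eraseArgument(0);
  }
  return newFunction;
}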
- for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 8b2a3927d78..60a8b5a3b9a 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,14 +136,14 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function *printfFunc = getPrintf(*op->getFunction()->getModule()); + Function printfFunc = getPrintf(*op->getFunction().getModule()); auto print = cast(op); auto loc = print.getLoc(); // We will operate on a MemRef abstraction, we use a type.cast to get one // if our operand is still a Toy array. Value *operand = memRefTypeCast(rewriter, operands[0]); - Type retTy = printfFunc->getType().getResult(0); + Type retTy = printfFunc.getType().getResult(0); // Create our loop nest now using namespace edsc; @@ -205,8 +205,8 @@ private: /// Return the prototype declaration for printf in the module, create it if /// necessary. - Function *getPrintf(Module &module) const { - auto *printfFunc = module.getNamedFunction("printf"); + Function getPrintf(Module &module) const { + auto printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; @@ -218,10 +218,10 @@ private: auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); - printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); + printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. - printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); - module.getFunctions().push_back(printfFunc); + printfFunc.setAttr("std.varargs", builder.getBoolAttr(true)); + module.push_back(printfFunc); return printfFunc; } }; @@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass { // affine dialect: they already include conversion to the LLVM dialect. // First patch calls type to return memref instead of ToyArray - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { auto callOp = dyn_cast(op); if (!callOp) @@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass { }); } - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). 
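The LateLowering hunk above keeps the familiar get-or-create pattern for the printf declaration, just phrased with value handles: getNamedFunction returns a Function that tests false when the lookup misses. A hedged sketch of the pattern in isolation; getOrInsertDecl is an illustrative name, and the location parameter stands in for builder.getUnknownLoc().

static mlir::Function getOrInsertDecl(mlir::Module &module, llvm::StringRef name,
                                      mlir::FunctionType type, mlir::Location loc) {
  // A null Function handle converts to false, so the lookup reads like the
  // old pointer check.
  if (auto existing = module.getNamedFunction(name))
    return existing;
  auto decl = mlir::Function::create(loc, name, type);
  module.push_back(decl);
  return decl;
}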
if (auto allocOp = dyn_cast(op)) { @@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass { } // Lower Linalg to affine - for (auto &function : getModule()) - linalg::lowerToLoops(&function); + for (auto function : getModule()) + linalg::lowerToLoops(function); getModule().dump(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index f7e6fad568e..9ebfeb438ca 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function->getBody()); + builder = llvm::make_unique(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. 
// FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index cad2deda57e..0abcb4bb850 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -113,7 +113,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector argumentsType; }; @@ -121,7 +121,7 @@ public: void runOnModule() override { auto &module = getModule(); mlir::ModuleManager moduleManager(&module); - auto *main = moduleManager.getNamedFunction("main"); + auto main = moduleManager.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -140,7 +140,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -155,7 +155,7 @@ public: specialize(SmallVectorImpl &funcWorklist, mlir::ModuleManager &moduleManager) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -171,36 +171,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); moduleManager.insert(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? 
- auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -212,7 +212,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast().isGeneric()) @@ -295,16 +295,16 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = moduleManager.getNamedFunction(calleeName); + auto callee = moduleManager.getNamedFunction(calleeName); if (!callee) { signalPassFailure(); - return f->emitError("Shape inference failed, call to unknown '") + return f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; } auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = moduleManager.getNamedFunction(mangledName); + auto mangledCallee = moduleManager.getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -315,7 +315,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. SmallVector operands(op->getOperands()); - mlir::OpBuilder builder(f->getBody()); + mlir::OpBuilder builder(f.getBody()); builder.setInsertionPoint(op); auto newCall = builder.create(op->getLoc(), mangledCallee, operands); @@ -330,12 +330,12 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { signalPassFailure(); - auto diag = f->emitError("Shape inference failed, ") + auto diag = f.emitError("Shape inference failed, ") << opWorklist.size() << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) diag << " - " << *ope << "\n"; @@ -344,24 +344,24 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. 
- for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index e69756e73f4..8d7b2d59afe 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -34,7 +34,7 @@ template class DominanceInfoBase { using base = llvm::DominatorTreeBase; public: - DominanceInfoBase(Function *function) { recalculate(function); } + DominanceInfoBase(Function function) { recalculate(function); } DominanceInfoBase(Operation *op) { recalculate(op); } DominanceInfoBase(DominanceInfoBase &&) = default; DominanceInfoBase &operator=(DominanceInfoBase &&) = default; @@ -43,7 +43,7 @@ public: DominanceInfoBase &operator=(const DominanceInfoBase &) = delete; /// Recalculate the dominance info. - void recalculate(Function *function); + void recalculate(Function function); void recalculate(Operation *op); /// Get the root dominance node of the given region. diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 3ab24f84640..b89011a28e3 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -104,8 +104,8 @@ struct NestedPattern { NestedPattern &operator=(const NestedPattern &) = default; /// Returns all the top-level matches in `func`. - void match(Function *func, SmallVectorImpl *matches) { - func->walk([&](Operation *op) { matchOne(op, matches); }); + void match(Function func, SmallVectorImpl *matches) { + func.walk([&](Operation *op) { matchOne(op, matches); }); } /// Returns all the top-level matches in `op`. diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h index a2d982d299b..3d20eaff46c 100644 --- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -44,7 +44,7 @@ struct StaticFloatMemRef { /// each of the arguments, initialize the storage with `initialValue`, and /// return a list of type-erased descriptor pointers. llvm::Expected> -allocateMemRefArguments(Function *func, float initialValue = 0.0); +allocateMemRefArguments(Function func, float initialValue = 0.0); /// Free a list of type-erased descriptors to statically-shaped memrefs with /// element type f32. diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index 8f682ce7c2e..c0326deb7cd 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -44,7 +44,7 @@ public: /// Returns whether the given function is a kernel function, i.e., has the /// 'gpu.kernel' attribute. 
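Both ShapeInferencePass hunks end by recomputing the function type from the operand of the return op and calling setType. A small sketch of that retyping step on its own; retype and newResults are illustrative names.

static void retype(mlir::Function f, llvm::ArrayRef<mlir::Type> newResults) {
  std::vector<mlir::Type> argTypes;
  for (auto *arg : f.getArguments())
    argTypes.push_back(arg->getType());
  auto newType = mlir::FunctionType::get(argTypes, newResults, f.getContext());
  // As the Function.h change later in this patch shows, setType also resizes
  // the per-argument attribute lists to the new input count.
  f.setType(newType);
}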
- static bool isKernel(Function *function); + static bool isKernel(Function function); }; /// Utility class for the GPU dialect to represent triples of `Value`s @@ -122,12 +122,12 @@ public: using Op::Op; static void build(Builder *builder, OperationState *result, - Function *kernelFunc, Value *gridSizeX, Value *gridSizeY, + Function kernelFunc, Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, ArrayRef kernelOperands); static void build(Builder *builder, OperationState *result, - Function *kernelFunc, KernelDim3 gridSize, + Function kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef kernelOperands); /// The kernel function specified by the operation's `kernel` attribute. diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 5b9bfca35ad..b46e160174b 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -313,9 +313,8 @@ class FunctionAttr detail::StringAttributeStorage> { public: using Base::Base; - using ValueType = Function *; + using ValueType = StringRef; - static FunctionAttr get(Function *value); static FunctionAttr get(StringRef value, MLIRContext *ctx); /// Returns the name of the held function reference. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f4ecb4ec6d7..feae5c93fea 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -101,7 +101,7 @@ public: /// Returns the function that this block is part of, even if the block is /// nested under an operation region. - Function *getFunction(); + Function getFunction(); /// Insert this block (which must not already be in a function) right before /// the specified block. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 6ce5c22eadc..e5c8c035c46 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -112,7 +112,7 @@ public: AffineMapAttr getAffineMapAttr(AffineMap map); IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type type); - FunctionAttr getFunctionAttr(Function *value); + FunctionAttr getFunctionAttr(Function value); FunctionAttr getFunctionAttr(StringRef value); ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 56d06619c79..4e82689efff 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -145,17 +145,13 @@ public: /// Verify an attribute from this dialect on the given function. Returns /// failure if the verification failed, success otherwise. - virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) { - return success(); - } + virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute); /// Verify an attribute from this dialect on the argument at 'argIndex' for /// the given function. Returns failure if the verification failed, success /// otherwise. - virtual LogicalResult - verifyFunctionArgAttribute(Function *, unsigned argIndex, NamedAttribute) { - return success(); - } + virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex, + NamedAttribute); /// Verify an attribute from this dialect on the given operation. Returns /// failure if the verification failed, success otherwise. 
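The Attributes.h hunk above changes FunctionAttr's ValueType from Function * to StringRef: a function attribute now refers to its callee by name, which is part of what allows Function to become a value type. A sketch of how a call-like operation records its callee under this scheme, mirroring the CallOp builder that appears later in the patch; addCalleeAttr is an illustrative wrapper.

static void addCalleeAttr(mlir::Builder &builder, mlir::OperationState *result,
                          mlir::Function callee) {
  // getFunctionAttr(Function) resolves to the callee's name rather than
  // pinning a Function pointer inside the attribute.
  result->addAttribute("callee", builder.getFunctionAttr(callee));
  result->addTypes(callee.getType().getResults());
}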
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 8f3b3b0df13..e11a45ba033 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -29,29 +29,79 @@ namespace mlir { class BlockAndValueMapping; class FunctionType; +class Function; class MLIRContext; class Module; -/// This is the base class for all of the MLIR function types. -class Function : public llvm::ilist_node_with_parent { +namespace detail { +/// This class represents all of the internal state of a Function. This allows +/// for the Function class to be value typed. +class FunctionStorage + : public llvm::ilist_node_with_parent { + FunctionStorage(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}); + FunctionStorage(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs); + /// The name of the function. + Identifier name; + + /// The module this function is embedded into. + Module *module = nullptr; + + /// The source location the function was defined or derived from. + Location location; + + /// The type of the function. + FunctionType type; + + /// This holds general named attributes for the function. + NamedAttributeList attrs; + + /// The attributes lists for each of the function arguments. + std::vector argAttrs; + + /// The body of the function. + Region body; + + friend struct llvm::ilist_traits; + friend Function; +}; +} // namespace detail + +/// This class represents an MLIR function, or the common unit of computation. +/// The region of a function is not allowed to implicitly capture global values, +/// and all external references must use Function arguments or attributes. +class Function { public: - Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs = {}); - Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs); + Function(detail::FunctionStorage *impl = nullptr) : impl(impl) {} + + static Function create(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}) { + return new detail::FunctionStorage(location, name, type, attrs); + } + static Function create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) { + return new detail::FunctionStorage(location, name, type, attrs, argAttrs); + } + + /// Allow converting a Function to bool for null checks. + operator bool() const { return impl; } + bool operator==(Function other) const { return impl == other.impl; } + bool operator!=(Function other) const { return !(*this == other); } /// The source location the function was defined or derived from. - Location getLoc() { return location; } + Location getLoc() { return impl->location; } /// Set the source location this function was defined or derived from. - void setLoc(Location loc) { location = loc; } + void setLoc(Location loc) { impl->location = loc; } /// Return the name of this function, without the @. - Identifier getName() { return name; } + Identifier getName() { return impl->name; } /// Return the type of this function. - FunctionType getType() { return type; } + FunctionType getType() { return impl->type; } /// Change the type of this function in place. This is an extremely dangerous /// operation and it is up to the caller to ensure that this is legal for this @@ -61,12 +111,12 @@ public: /// parameters we drop the extra attributes, if there are more parameters /// they won't have any attributes. 
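To make the new Function.h design concrete: Function is a pointer-sized wrapper over detail::FunctionStorage, so copies alias the same function and comparison is identity. A small demonstration sketch (not from the patch), using only members of the new Function class; the attribute name reuses "toy.generic" purely as an example.

static void handleSemantics(mlir::Function f) {
  mlir::Function alias = f;              // copies the impl pointer only.
  assert(alias == f);                    // identity, not structural equality.
  alias.setAttr("toy.generic", mlir::BoolAttr::get(false, f.getContext()));
  assert(f.getAttr("toy.generic"));      // visible through every handle.
  mlir::Function empty;                  // default-constructed handle is null
  assert(!empty);                        // and tests false, like a pointer.
}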
void setType(FunctionType newType) { - type = newType; - argAttrs.resize(type.getNumInputs()); + impl->type = newType; + impl->argAttrs.resize(newType.getNumInputs()); } MLIRContext *getContext(); - Module *getModule() { return module; } + Module *getModule() { return impl->module; } /// Add an entry block to an empty function, and set up the block arguments /// to match the signature of the function. @@ -82,28 +132,28 @@ public: // Body Handling //===--------------------------------------------------------------------===// - Region &getBody() { return body; } - void eraseBody() { body.getBlocks().clear(); } + Region &getBody() { return impl->body; } + void eraseBody() { getBody().getBlocks().clear(); } /// This is the list of blocks in the function. using RegionType = Region::RegionType; - RegionType &getBlocks() { return body.getBlocks(); } + RegionType &getBlocks() { return getBody().getBlocks(); } // Iteration over the block in the function. using iterator = RegionType::iterator; using reverse_iterator = RegionType::reverse_iterator; - iterator begin() { return body.begin(); } - iterator end() { return body.end(); } - reverse_iterator rbegin() { return body.rbegin(); } - reverse_iterator rend() { return body.rend(); } + iterator begin() { return getBody().begin(); } + iterator end() { return getBody().end(); } + reverse_iterator rbegin() { return getBody().rbegin(); } + reverse_iterator rend() { return getBody().rend(); } - bool empty() { return body.empty(); } - void push_back(Block *block) { body.push_back(block); } - void push_front(Block *block) { body.push_front(block); } + bool empty() { return getBody().empty(); } + void push_back(Block *block) { getBody().push_back(block); } + void push_front(Block *block) { getBody().push_front(block); } - Block &back() { return body.back(); } - Block &front() { return body.front(); } + Block &back() { return getBody().back(); } + Block &front() { return getBody().front(); } //===--------------------------------------------------------------------===// // Operation Walkers @@ -150,53 +200,55 @@ public: /// the lifetime of an function. /// Return all of the attributes on this function. - ArrayRef getAttrs() { return attrs.getAttrs(); } + ArrayRef getAttrs() { return impl->attrs.getAttrs(); } /// Return the internal attribute list on this function. - NamedAttributeList &getAttrList() { return attrs; } + NamedAttributeList &getAttrList() { return impl->attrs; } /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].getAttrs(); + return impl->argAttrs[index].getAttrs(); } /// Set the attributes held by this function. void setAttrs(ArrayRef attributes) { - attrs.setAttrs(attributes); + impl->attrs.setAttrs(attributes); } /// Set the attributes held by the argument at 'index'. 
void setArgAttrs(unsigned index, ArrayRef attributes) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].setAttrs(attributes); + impl->argAttrs[index].setAttrs(attributes); } void setArgAttrs(unsigned index, NamedAttributeList attributes) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index] = attributes; + impl->argAttrs[index] = attributes; } void setAllArgAttrs(ArrayRef attributes) { assert(attributes.size() == getNumArguments()); for (unsigned i = 0, e = attributes.size(); i != e; ++i) - argAttrs[i] = attributes[i]; + impl->argAttrs[i] = attributes[i]; } /// Return all argument attributes of this function. - MutableArrayRef getAllArgAttrs() { return argAttrs; } + MutableArrayRef getAllArgAttrs() { + return impl->argAttrs; + } /// Return the specified attribute if present, null otherwise. - Attribute getAttr(Identifier name) { return attrs.get(name); } - Attribute getAttr(StringRef name) { return attrs.get(name); } + Attribute getAttr(Identifier name) { return impl->attrs.get(name); } + Attribute getAttr(StringRef name) { return impl->attrs.get(name); } /// Return the specified attribute, if present, for the argument at 'index', /// null otherwise. Attribute getArgAttr(unsigned index, Identifier name) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].get(name); + return impl->argAttrs[index].get(name); } Attribute getArgAttr(unsigned index, StringRef name) { assert(index < getNumArguments() && "invalid argument number"); - return argAttrs[index].get(name); + return impl->argAttrs[index].get(name); } template AttrClass getAttrOfType(Identifier name) { @@ -219,13 +271,15 @@ public: /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } + void setAttr(Identifier name, Attribute value) { + impl->attrs.set(name, value); + } void setAttr(StringRef name, Attribute value) { setAttr(Identifier::get(name, getContext()), value); } void setArgAttr(unsigned index, Identifier name, Attribute value) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].set(name, value); + impl->argAttrs[index].set(name, value); } void setArgAttr(unsigned index, StringRef name, Attribute value) { setArgAttr(index, Identifier::get(name, getContext()), value); @@ -234,12 +288,12 @@ public: /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. NamedAttributeList::RemoveResult removeAttr(Identifier name) { - return attrs.remove(name); + return impl->attrs.remove(name); } NamedAttributeList::RemoveResult removeArgAttr(unsigned index, Identifier name) { assert(index < getNumArguments() && "invalid argument number"); - return attrs.remove(name); + return impl->attrs.remove(name); } //===--------------------------------------------------------------------===// @@ -281,44 +335,37 @@ public: /// contains entries for function arguments, these arguments are not included /// in the new function. Replaces references to cloned sub-values with the /// corresponding value that is copied, and adds those mappings to the mapper. - Function *clone(BlockAndValueMapping &mapper); - Function *clone(); + Function clone(BlockAndValueMapping &mapper); + Function clone(); /// Clone the internal blocks and attributes from this function into dest. 
Any /// cloned blocks are appended to the back of dest. This function asserts that /// the attributes of the current function and dest are compatible. - void cloneInto(Function *dest, BlockAndValueMapping &mapper); + void cloneInto(Function dest, BlockAndValueMapping &mapper); + + /// Methods for supporting PointerLikeTypeTraits. + const void *getAsOpaquePointer() const { + return static_cast(impl); + } + static Function getFromOpaquePointer(const void *pointer) { + return reinterpret_cast( + const_cast(pointer)); + } private: /// Set the name of this function. - void setName(Identifier newName) { name = newName; } - - /// The name of the function. - Identifier name; - - /// The module this function is embedded into. - Module *module = nullptr; - - /// The source location the function was defined or derived from. - Location location; - - /// The type of the function. - FunctionType type; - - /// This holds general named attributes for the function. - NamedAttributeList attrs; + void setName(Identifier newName) { impl->name = newName; } - /// The attributes lists for each of the function arguments. - std::vector argAttrs; - - /// The body of the function. - Region body; - - void operator=(Function &) = delete; - friend struct llvm::ilist_traits; + /// A pointer to the impl storage instance for this function. This allows for + /// 'Function' to be treated as a value type. + detail::FunctionStorage *impl = nullptr; // Allow access to 'setName'. friend class SymbolTable; + + // Allow access to 'impl'. + friend class Module; + friend class Region; }; //===--------------------------------------------------------------------===// @@ -487,21 +534,52 @@ private: namespace llvm { template <> -struct ilist_traits<::mlir::Function> - : public ilist_alloc_traits<::mlir::Function> { - using Function = ::mlir::Function; - using function_iterator = simple_ilist::iterator; +struct ilist_traits<::mlir::detail::FunctionStorage> + : public ilist_alloc_traits<::mlir::detail::FunctionStorage> { + using FunctionStorage = ::mlir::detail::FunctionStorage; + using function_iterator = simple_ilist::iterator; - static void deleteNode(Function *function) { delete function; } + static void deleteNode(FunctionStorage *function) { delete function; } - void addNodeToList(Function *function); - void removeNodeFromList(Function *function); - void transferNodesFromList(ilist_traits &otherList, + void addNodeToList(FunctionStorage *function); + void removeNodeFromList(FunctionStorage *function); + void transferNodesFromList(ilist_traits &otherList, function_iterator first, function_iterator last); private: mlir::Module *getContainingModule(); }; -} // end namespace llvm + +// Functions hash just like pointers. +template <> struct DenseMapInfo { + static mlir::Function getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::Function::getFromOpaquePointer(pointer); + } + static mlir::Function getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::Function::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::Function val) { + return hash_value(val.getAsOpaquePointer()); + } + static bool isEqual(mlir::Function LHS, mlir::Function RHS) { + return LHS == RHS; + } +}; + +/// Allow stealing the low bits of FunctionStorage. 
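The DenseMapInfo specialization above (together with the PointerLikeTypeTraits one that follows) is what lets the value-typed Function be used directly as a DenseMap key and inside PointerIntPair/PointerUnion, as the pass-manager and symbol-table code later in this patch relies on. A tiny usage sketch under those assumptions; the block-counting logic is illustrative only.

static llvm::DenseMap<mlir::Function, unsigned> countBlocks(mlir::Module &module) {
  llvm::DenseMap<mlir::Function, unsigned> blockCounts;
  for (auto fn : module)   // module iteration now yields Function handles.
    blockCounts[fn] = std::distance(fn.begin(), fn.end());
  return blockCounts;
}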
+template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::Function I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::Function getFromVoidPointer(void *P) { + return mlir::Function::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // namespace llvm #endif // MLIR_IR_FUNCTION_H diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index 8161a305fb5..d8a47891ace 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -34,34 +34,54 @@ public: MLIRContext *getContext() { return context; } + /// An iterator class used to iterate over the held functions. + class iterator : public llvm::mapped_iterator< + llvm::iplist::iterator, + Function (*)(detail::FunctionStorage &)> { + static Function unwrap(detail::FunctionStorage &impl) { return &impl; } + + public: + using reference = Function; + + /// Initializes the operand type iterator to the specified operand iterator. + iterator(llvm::iplist::iterator it) + : llvm::mapped_iterator::iterator, + Function (*)(detail::FunctionStorage &)>( + it, &unwrap) {} + iterator(Function it) + : iterator(llvm::iplist::iterator(it.impl)) {} + }; + /// This is the list of functions in the module. - using FunctionListType = llvm::iplist; - FunctionListType &getFunctions() { return functions; } + llvm::iterator_range getFunctions() { return {begin(), end()}; } // Iteration over the functions in the module. - using iterator = FunctionListType::iterator; - using reverse_iterator = FunctionListType::reverse_iterator; - iterator begin() { return functions.begin(); } iterator end() { return functions.end(); } - reverse_iterator rbegin() { return functions.rbegin(); } - reverse_iterator rend() { return functions.rend(); } + Function front() { return &functions.front(); } + Function back() { return &functions.back(); } + + void push_back(Function fn) { functions.push_back(fn.impl); } + void insert(iterator insertPt, Function fn) { + functions.insert(insertPt.getCurrent(), fn.impl); + } // Interfaces for working with the symbol table. /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Function *getNamedFunction(StringRef name) { + Function getNamedFunction(StringRef name) { return getNamedFunction(Identifier::get(name, getContext())); } /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Function *getNamedFunction(Identifier name) { - auto it = llvm::find_if( - functions, [name](Function &fn) { return fn.getName() == name; }); + Function getNamedFunction(Identifier name) { + auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) { + return Function(&fn).getName() == name; + }); return it == functions.end() ? 
nullptr : &*it; } @@ -74,11 +94,13 @@ public: void dump(); private: - friend struct llvm::ilist_traits; - friend class Function; + friend struct llvm::ilist_traits; + friend detail::FunctionStorage; + friend Function; /// getSublistAccess() - Returns pointer to member of function list - static FunctionListType Module::*getSublistAccess(Function *) { + static llvm::iplist Module::* + getSublistAccess(detail::FunctionStorage *) { return &Module::functions; } @@ -86,7 +108,7 @@ private: MLIRContext *context; /// This is the actual list of functions the module contains. - FunctionListType functions; + llvm::iplist functions; }; /// A class used to manage the symbols held by a module. This class handles @@ -98,24 +120,24 @@ public: /// Look up a symbol with the specified name, returning null if no such /// name exists. Names must never include the @ on them. - template Function *getNamedFunction(NameTy &&name) const { + template Function getNamedFunction(NameTy &&name) const { return symbolTable.lookup(name); } /// Insert a new symbol into the module, auto-renaming it as necessary. - void insert(Function *function) { + void insert(Function function) { symbolTable.insert(function); - module->getFunctions().push_back(function); + module->push_back(function); } - void insert(Module::iterator insertPt, Function *function) { + void insert(Module::iterator insertPt, Function function) { symbolTable.insert(function); - module->getFunctions().insert(insertPt, function); + module->insert(insertPt, function); } /// Remove the given symbol from the module symbol table and then erase it. - void erase(Function *function) { + void erase(Function function) { symbolTable.erase(function); - function->erase(); + function.erase(); } /// Return the internally held module. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index e5323999df7..f916f4ba583 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -128,7 +128,7 @@ public: /// Returns the function that this operation is part of. /// The function is determined by traversing the chain of parent operations. /// Returns nullptr if the operation is unlinked. - Function *getFunction(); + Function getFunction(); /// Replace any uses of 'from' with 'to' within this operation. void replaceUsesOfWith(Value *from, Value *to); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index a1b81fcde40..921437601e1 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -420,7 +420,7 @@ private: /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. /// -bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); +bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns); /// Helper class to create a list of rewrite patterns given a list of their /// types and a list of attributes perfect-forwarded to each of the conversion diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 2189ad490f8..ad0692b0864 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -27,11 +27,16 @@ namespace mlir { class BlockAndValueMapping; +namespace detail { +class FunctionStorage; +} + /// This class contains a list of basic blocks and has a notion of the object it /// is part of - a Function or an Operation. 
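The ModuleManager changes above keep symbol-table maintenance and function erasure in one place, now phrased in terms of handles. A short sketch of the intended usage; replaceFunction and its parameters are illustrative.

static void replaceFunction(mlir::ModuleManager &manager,
                            mlir::Function newFn, llvm::StringRef oldName) {
  // getNamedFunction returns a null handle when the lookup misses.
  if (auto oldFn = manager.getNamedFunction(oldName))
    manager.erase(oldFn);   // drops the symbol-table entry, then erases the function.
  manager.insert(newFn);    // auto-renames on collision and appends to the module.
}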
class Region { public: - explicit Region(Function *container = nullptr); + Region() = default; + explicit Region(Function container); explicit Region(Operation *container); ~Region(); @@ -77,7 +82,7 @@ public: /// A Region is either a function body or a part of an operation. If it is /// a Function body, then return this function, otherwise return null. - Function *getContainingFunction(); + Function getContainingFunction(); /// Return true if this region is a proper ancestor of the `other` region. bool isProperAncestor(Region *other); @@ -118,7 +123,7 @@ private: RegionType blocks; /// This is the object we are part of. - llvm::PointerUnion container; + llvm::PointerUnion container; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index 30749582031..a351f66eb2e 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -18,7 +18,7 @@ #ifndef MLIR_IR_SYMBOLTABLE_H #define MLIR_IR_SYMBOLTABLE_H -#include "mlir/IR/Identifier.h" +#include "mlir/IR/Function.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -35,18 +35,18 @@ public: /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. - Function *lookup(StringRef name) const; + Function lookup(StringRef name) const; /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. - Function *lookup(Identifier name) const; + Function lookup(Identifier name) const; /// Erase the given symbol from the table. - void erase(Function *symbol); + void erase(Function symbol); /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. - void insert(Function *symbol); + void insert(Function symbol); /// Returns the context held by this symbol table. MLIRContext *getContext() const { return context; } @@ -55,7 +55,7 @@ private: MLIRContext *context; /// This is a mapping from a name to the function with that name. - llvm::DenseMap symbolTable; + llvm::DenseMap symbolTable; /// This is used when name conflicts are detected. unsigned uniquingCounter = 0; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index e90505ec90d..4604ed99c77 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -72,7 +72,7 @@ public: } /// Return the function that this Value is defined in. - Function *getFunction(); + Function getFunction(); /// If this value is the result of an operation, return the operation that /// defines it. @@ -128,7 +128,7 @@ public: } /// Return the function that this argument is defined in. - Function *getFunction(); + Function getFunction(); Block *getOwner() { return owner; } diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index bd3286df8f4..a28aa719965 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -153,7 +153,7 @@ public: /// Verify a function argument attribute registered to this dialect. /// Returns failure if the verification failed, success otherwise. 
- LogicalResult verifyFunctionArgAttribute(Function *func, unsigned argIdx, + LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx, NamedAttribute argAttr) override; private: diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index 3751a93629d..c44f88f6763 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -106,7 +106,7 @@ template class AnalysisMap { } public: - explicit AnalysisMap(IRUnitT *ir) : ir(ir) {} + explicit AnalysisMap(IRUnitT ir) : ir(ir) {} /// Get an analysis for the current IR unit, computing it if necessary. template AnalysisT &getAnalysis(PassInstrumentor *pi) { @@ -140,8 +140,8 @@ public: } /// Returns the IR unit that this analysis map represents. - IRUnitT *getIRUnit() { return ir; } - const IRUnitT *getIRUnit() const { return ir; } + IRUnitT getIRUnit() { return ir; } + const IRUnitT getIRUnit() const { return ir; } /// Clear any held analyses. void clear() { analyses.clear(); } @@ -158,7 +158,7 @@ public: } private: - IRUnitT *ir; + IRUnitT ir; ConceptMap analyses; }; @@ -231,14 +231,14 @@ public: /// Query for the analysis of a function. The analysis is computed if it does /// not exist. template - AnalysisT &getFunctionAnalysis(Function *function) { + AnalysisT &getFunctionAnalysis(Function function) { return slice(function).getAnalysis(); } /// Query for a cached analysis of a child function, or return null. template llvm::Optional> - getCachedFunctionAnalysis(Function *function) const { + getCachedFunctionAnalysis(Function function) const { auto it = functionAnalyses.find(function); if (it == functionAnalyses.end()) return llvm::None; @@ -258,7 +258,7 @@ public: } /// Create an analysis slice for the given child function. - FunctionAnalysisManager slice(Function *function); + FunctionAnalysisManager slice(Function function); /// Invalidate any non preserved analyses. void invalidate(const detail::PreservedAnalyses &pa); @@ -269,11 +269,11 @@ public: private: /// The cached analyses for functions within the current module. - llvm::DenseMap>> + llvm::DenseMap>> functionAnalyses; /// The analyses for the owning module. - detail::AnalysisMap moduleAnalyses; + detail::AnalysisMap moduleAnalyses; /// An optional instrumentation object. PassInstrumentor *passInstrumentor; diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 5fd6dfd18b5..41d20ccdd63 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -70,12 +70,12 @@ class ModulePassExecutor; /// interface for accessing and initializing necessary state for pass execution. template struct PassExecutionState { - PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager) + PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager) : irAndPassFailed(ir, false), analysisManager(analysisManager) {} /// The current IR unit being transformed and a bool for if the pass signaled /// a failure. - llvm::PointerIntPair irAndPassFailed; + llvm::PointerIntPair irAndPassFailed; /// The analysis manager for the IR unit. AnalysisManagerT &analysisManager; @@ -107,9 +107,7 @@ protected: virtual FunctionPassBase *clone() const = 0; /// Return the current function being transformed. - Function &getFunction() { - return *getPassState().irAndPassFailed.getPointer(); - } + Function getFunction() { return getPassState().irAndPassFailed.getPointer(); } /// Return the MLIR context for the current function being transformed. 
MLIRContext &getContext() { return *getFunction().getContext(); } @@ -128,7 +126,7 @@ protected: private: /// Forwarding function to execute this pass. LLVM_NODISCARD - LogicalResult run(Function *fn, FunctionAnalysisManager &fam); + LogicalResult run(Function fn, FunctionAnalysisManager &fam); /// The current execution state for the pass. llvm::Optional passState; @@ -140,7 +138,8 @@ private: /// Pass to transform a module. Derived passes should not inherit from this /// class directly, and instead should use the CRTP ModulePass class. class ModulePassBase : public Pass { - using PassStateT = detail::PassExecutionState; + using PassStateT = + detail::PassExecutionState; public: static bool classof(const Pass *pass) { @@ -272,7 +271,7 @@ struct FunctionPass : public detail::PassModel { template struct ModulePass : public detail::PassModel { /// Returns the analysis for a child function. - template AnalysisT &getFunctionAnalysis(Function *f) { + template AnalysisT &getFunctionAnalysis(Function f) { return this->getAnalysisManager().template getFunctionAnalysis( f); } @@ -280,7 +279,7 @@ struct ModulePass : public detail::PassModel { /// Returns an existing analysis for a child function if it exists. template llvm::Optional> - getCachedFunctionAnalysis(Function *f) { + getCachedFunctionAnalysis(Function f) { return this->getAnalysisManager() .template getCachedFunctionAnalysis(f); } diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h index 0f427066296..40358329f45 100644 --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -77,29 +77,29 @@ public: ~PassInstrumentor(); /// See PassInstrumentation::runBeforePass for details. - template void runBeforePass(Pass *pass, IRUnitT *ir) { + template void runBeforePass(Pass *pass, IRUnitT ir) { runBeforePass(pass, llvm::Any(ir)); } /// See PassInstrumentation::runAfterPass for details. - template void runAfterPass(Pass *pass, IRUnitT *ir) { + template void runAfterPass(Pass *pass, IRUnitT ir) { runAfterPass(pass, llvm::Any(ir)); } /// See PassInstrumentation::runAfterPassFailed for details. - template void runAfterPassFailed(Pass *pass, IRUnitT *ir) { + template void runAfterPassFailed(Pass *pass, IRUnitT ir) { runAfterPassFailed(pass, llvm::Any(ir)); } /// See PassInstrumentation::runBeforeAnalysis for details. template - void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) { + void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { runBeforeAnalysis(name, id, llvm::Any(ir)); } /// See PassInstrumentation::runAfterAnalysis for details. 
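With PassExecutionState now storing the IR unit by value (hence the PointerIntPair over IRUnitT above), a function pass receives its unit as a Function handle from getFunction(). A sketch of a trivial pass under the new signature; the CRTP base FunctionPass<CountOpsPass> follows the usual pattern for passes of this era but is an assumption here, since the template arguments are not visible in the hunk.

struct CountOpsPass : public mlir::FunctionPass<CountOpsPass> {
  void runOnFunction() override {
    mlir::Function f = getFunction();   // by value, no longer a reference.
    unsigned numOps = 0;
    f.walk([&](mlir::Operation *op) { ++numOps; });
    llvm::errs() << f.getName() << ": " << numOps << " operations\n";
  }
};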
template - void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) { + void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { runAfterAnalysis(name, id, llvm::Any(ir)); } diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 1b14e2a2a9c..a7afe1f9e7c 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -214,11 +214,11 @@ def CallOp : Std_Op<"call"> { let results = (outs Variadic); let builders = [OpBuilder< - "Builder *builder, OperationState *result, Function *callee," + "Builder *builder, OperationState *result, Function callee," "ArrayRef operands = {}", [{ result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType().getResults()); + result->addTypes(callee.getType().getResults()); }]>, OpBuilder< "Builder *builder, OperationState *result, StringRef callee," "ArrayRef results, ArrayRef operands = {}", [{ diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 00da0d5fcc0..c8ede78ec20 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns( /// Convert the given functions with the provided conversion patterns. This /// function returns failure if a type conversion failed. LLVM_NODISCARD -LogicalResult applyConversionPatterns(ArrayRef fns, +LogicalResult applyConversionPatterns(MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); @@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(ArrayRef fns, /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LLVM_NODISCARD -LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target, +LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h index d77b35a8044..09aa7dc8acd 100644 --- a/mlir/include/mlir/Transforms/LowerAffine.h +++ b/mlir/include/mlir/Transforms/LowerAffine.h @@ -37,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, /// Convert from the Affine dialect to the Standard dialect, in particular /// convert structured affine control flow into CFG branch-based control flow. -LogicalResult lowerAffineConstructs(Function &function); +LogicalResult lowerAffineConstructs(Function function); /// Emit code that computes the lower bound of the given affine loop using /// standard arithmetic operations. diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h index c1da5ef9638..5780df5c21b 100644 --- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h +++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h @@ -33,11 +33,11 @@ class FunctionPassBase; /// Displays the CFG in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. 
-void viewGraph(Function &function, const Twine &name, bool shortNames = false, +void viewGraph(Function function, const Twine &name, bool shortNames = false, const Twine &title = "", llvm::GraphProgram::Name program = llvm::GraphProgram::DOT); -llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function, +llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function function, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 016ef43a84a..d7650dcb127 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getFunction()->getContext()) + return getAffineDimExpr(iterPos->second, v->getFunction().getContext()) .cast(); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 954a01b4843..b4cdeb7d886 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase; /// Recalculate the dominance info. template -void DominanceInfoBase::recalculate(Function *function) { +void DominanceInfoBase::recalculate(Function function) { dominanceInfos.clear(); // Build the top level function dominance. auto functionDominance = llvm::make_unique(); - functionDominance->recalculate(function->getBody()); - dominanceInfos.try_emplace(&function->getBody(), - std::move(functionDominance)); + functionDominance->recalculate(function.getBody()); + dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance)); /// Build the dominance for each of the operation regions. - function->walk([&](Operation *op) { + function.walk([&](Operation *op) { for (auto ®ion : op->getRegions()) { // Don't compute dominance if the region is empty. if (region.empty()) diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 5177afcee67..75a2fc1a5dc 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &fn : getModule()) + for (auto fn : getModule()) fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index cbda6d40224..473d253cfa2 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. void TestParallelismDetection::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder b(f.getBody()); f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 1330fe0fb94..0d0525145ef 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -53,7 +53,7 @@ public: : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} /// Verify the body of the given function. - LogicalResult verify(Function &fn); + LogicalResult verify(Function fn); /// Verify the given operation. 
LogicalResult verify(Operation &op); @@ -104,7 +104,7 @@ private: } // end anonymous namespace /// Verify the body of the given function. -LogicalResult OperationVerifier::verify(Function &fn) { +LogicalResult OperationVerifier::verify(Function fn) { // Verify the body first. if (failed(verifyRegion(fn.getBody()))) return failure(); @@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) { // check. We do this as a second pass since malformed CFG's can cause // dominator analysis constructure to crash and we want the verifier to be // resilient to malformed code. - DominanceInfo theDomInfo(&fn); + DominanceInfo theDomInfo(fn); domInfo = &theDomInfo; if (failed(verifyDominance(fn.getBody()))) return failure(); @@ -313,7 +313,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionAttribute(this, attr))) + if (failed(dialect->verifyFunctionAttribute(*this, attr))) return failure(); } @@ -331,7 +331,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionArgAttribute(this, i, attr))) + if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr))) return failure(); } } @@ -369,7 +369,7 @@ LogicalResult Operation::verify() { LogicalResult Module::verify() { // Check that all functions are uniquely named. llvm::StringMap nameToOrigLoc; - for (auto &fn : *this) { + for (auto fn : *this) { auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc()); if (!it.second) return fn.emitError() @@ -379,7 +379,7 @@ LogicalResult Module::verify() { } // Check that each function is correct. - for (auto &fn : *this) + for (auto fn : *this) if (failed(fn.verify())) return failure(); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 9d7aeeb6321..022d8c70cc6 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -64,8 +64,8 @@ public: LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); - for (auto &function : getModule()) { - if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) { + for (auto function : getModule()) { + if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) { continue; } if (failed(translateGpuKernelToCubinAnnotation(function))) @@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { std::unique_ptr module(builder.createModule()); // TODO(herhut): Also handle called functions. 
- module->getFunctions().push_back(function.clone()); + module->push_back(function.clone()); auto llvmModule = translateModuleToNVVMIR(*module); auto cubin = convertModuleToCubin(*llvmModule, function); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index bd96f396b22..f9d5899456a 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -118,7 +118,7 @@ private: void declareCudaFunctions(Location loc); Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value *generateKernelNameConstant(Function *kernelFunction, Location &loc, + Value *generateKernelNameConstant(Function kernelFunction, Location &loc, OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); @@ -130,7 +130,7 @@ public: // Cache the used LLVM types. initializeCachedTypes(); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk( [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); } @@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { Module &module = getModule(); Builder builder(&module); if (!module.getNamedFunction(cuModuleLoadName)) { - module.getFunctions().push_back( - new Function(loc, cuModuleLoadName, - builder.getFunctionType( - { - getPointerPointerType(), /* CUmodule *module */ - getPointerType() /* void *cubin */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleLoadName, + builder.getFunctionType( + { + getPointerPointerType(), /* CUmodule *module */ + getPointerType() /* void *cubin */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuModuleGetFunctionName)) { // The helper uses void* instead of CUDA's opaque CUmodule and // CUfunction. - module.getFunctions().push_back( - new Function(loc, cuModuleGetFunctionName, - builder.getFunctionType( - { - getPointerPointerType(), /* void **function */ - getPointerType(), /* void *module */ - getPointerType() /* char *name */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleGetFunctionName, + builder.getFunctionType( + { + getPointerPointerType(), /* void **function */ + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuLaunchKernelName)) { // Other than the CUDA api, the wrappers use uintptr_t to match the // LLVM type if MLIR's index type, which the GPU dialect uses. // Furthermore, they use void* instead of CUDA's opaque CUfunction and // CUstream. 
- module.getFunctions().push_back( - new Function(loc, cuLaunchKernelName, - builder.getFunctionType( - { - getPointerType(), /* void* f */ - getIntPtrType(), /* intptr_t gridXDim */ - getIntPtrType(), /* intptr_t gridyDim */ - getIntPtrType(), /* intptr_t gridZDim */ - getIntPtrType(), /* intptr_t blockXDim */ - getIntPtrType(), /* intptr_t blockYDim */ - getIntPtrType(), /* intptr_t blockZDim */ - getInt32Type(), /* unsigned int sharedMemBytes */ - getPointerType(), /* void *hstream */ - getPointerPointerType(), /* void **kernelParams */ - getPointerPointerType() /* void **extra */ - }, - getCUResultType()))); + module.push_back(Function::create( + loc, cuLaunchKernelName, + builder.getFunctionType( + { + getPointerType(), /* void* f */ + getIntPtrType(), /* intptr_t gridXDim */ + getIntPtrType(), /* intptr_t gridyDim */ + getIntPtrType(), /* intptr_t gridZDim */ + getIntPtrType(), /* intptr_t blockXDim */ + getIntPtrType(), /* intptr_t blockYDim */ + getIntPtrType(), /* intptr_t blockZDim */ + getInt32Type(), /* unsigned int sharedMemBytes */ + getPointerType(), /* void *hstream */ + getPointerPointerType(), /* void **kernelParams */ + getPointerPointerType() /* void **extra */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuGetStreamHelperName)) { // Helper function to get the current CUDA stream. Uses void* instead of // CUDAs opaque CUstream. - module.getFunctions().push_back(new Function( + module.push_back(Function::create( loc, cuGetStreamHelperName, builder.getFunctionType({}, getPointerType() /* void *stream */))); } if (!module.getNamedFunction(cuStreamSynchronizeName)) { - module.getFunctions().push_back( - new Function(loc, cuStreamSynchronizeName, - builder.getFunctionType( - { - getPointerType() /* CUstream stream */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuStreamSynchronizeName, + builder.getFunctionType( + { + getPointerType() /* CUstream stream */ + }, + getCUResultType()))); } } @@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %0[n] = constant name[n] // %0[n+1] = 0 Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( - Function *kernelFunction, Location &loc, OpBuilder &builder) { + Function kernelFunction, Location &loc, OpBuilder &builder) { // TODO(herhut): Make this a constant once this is supported. auto kernelNameSize = builder.create( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size() + 1)); + builder.getI32IntegerAttr(kernelFunction.getName().size() + 1)); auto kernelName = builder.create(loc, getPointerType(), kernelNameSize); - for (auto byte : llvm::enumerate(kernelFunction->getName())) { + for (auto byte : llvm::enumerate(kernelFunction.getName())) { auto index = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(byte.index())); auto gep = builder.create(loc, getPointerType(), kernelName, @@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( // Add trailing zero to terminate string. auto index = builder.create( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size())); + builder.getI32IntegerAttr(kernelFunction.getName().size())); auto gep = builder.create(loc, getPointerType(), kernelName, ArrayRef{index}); auto value = builder.create( @@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // TODO(herhut): This should rather be a static global once supported. 
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel()); auto cubinGetter = - kernelFunction->getAttrOfType(kCubinGetterAnnotation); + kernelFunction.getAttrOfType(kCubinGetterAnnotation); if (!cubinGetter) { - kernelFunction->emitError("Missing ") + kernelFunction.emitError("Missing ") << kCubinGetterAnnotation << " attribute."; return signalPassFailure(); } @@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // Emit the load module call to load the module data. Error checking is done // in the called helper function. auto cuModule = allocatePointer(builder, loc); - Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); + Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleLoad), ArrayRef{cuModule, data.getResult(0)}); @@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create(loc, getPointerType(), cuModule); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto cuFunction = allocatePointer(builder, loc); - Function *cuModuleGetFunction = + Function cuModuleGetFunction = getModule().getNamedFunction(cuModuleGetFunctionName); builder.create( loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleGetFunction), ArrayRef{cuFunction, cuModuleRef, kernelName}); // Grab the global stream needed for execution. - Function *cuGetStreamHelper = + Function cuGetStreamHelper = getModule().getNamedFunction(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index c1d4af380ce..97790a5afce 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc"; class GpuGenerateCubinAccessorsPass : public ModulePass { private: - Function *getMallocHelper(Location loc, Builder &builder) { - Function *result = getModule().getNamedFunction(kMallocHelperName); + Function getMallocHelper(Location loc, Builder &builder) { + Function result = getModule().getNamedFunction(kMallocHelperName); if (!result) { - result = new Function( + result = Function::create( loc, kMallocHelperName, builder.getFunctionType( ArrayRef{LLVM::LLVMType::getInt32Ty(llvmDialect)}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); - getModule().getFunctions().push_back(result); + getModule().push_back(result); } return result; } @@ -70,18 +70,18 @@ private: // data from blob. As there are currently no global constants, this uses a // sequence of store operations. // TODO(herhut): Use global constants instead. - Function *generateCubinAccessor(Builder &builder, Function &orig, - StringAttr blob) { + Function generateCubinAccessor(Builder &builder, Function &orig, + StringAttr blob) { Location loc = orig.getLoc(); SmallString<128> nameBuffer(orig.getName()); nameBuffer.append(kCubinGetterSuffix); // Generate a function that returns void*. - Function *result = new Function( + Function result = Function::create( loc, mlir::Identifier::get(nameBuffer, &getContext()), builder.getFunctionType(ArrayRef{}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); // Insert a body block that just returns the constant. 
- OpBuilder ob(result->getBody()); + OpBuilder ob(result.getBody()); ob.createBlock(); auto sizeConstant = ob.create( loc, LLVM::LLVMType::getInt32Ty(llvmDialect), @@ -115,18 +115,18 @@ public: void runOnModule() override { llvmDialect = getModule().getContext()->getRegisteredDialect(); - Builder builder(getModule().getContext()); + auto &module = getModule(); + Builder builder(&getContext()); - auto &functions = getModule().getFunctions(); + auto functions = module.getFunctions(); for (auto it = functions.begin(); it != functions.end();) { // Move iterator to after the current function so that potential insertion // of the accessor is after the kernel with cubin iself. - Function &orig = *it++; + Function orig = *it++; StringAttr cubinBlob = orig.getAttrOfType(kCubinAnnotation); if (!cubinBlob) continue; - it = - functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); + module.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); } } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 872707842d7..e849f6fd023 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern { createIndexConstant(rewriter, op->getLoc(), elementSize)}); // Insert the `malloc` declaration if it is not already present. - Function *mallocFunc = - op->getFunction()->getModule()->getNamedFunction("malloc"); + Function mallocFunc = + op->getFunction().getModule()->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction()->getModule()->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + op->getFunction().getModule()->push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. 
- Function *freeFunc = - op->getFunction()->getModule()->getNamedFunction("free"); + Function freeFunc = op->getFunction().getModule()->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction()->getModule()->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + op->getFunction().getModule()->push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); @@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) { } void mlir::LLVM::ensureDistinctSuccessors(Module *m) { - for (auto &f : *m) { + for (auto f : *m) { for (auto &bb : f.getBlocks()) { ::ensureDistinctSuccessors(bb); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index ff198217bb7..dafc8e711f5 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern { //===----------------------------------------------------------------------===// void LowerUniformRealMathPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); @@ -386,7 +386,7 @@ static PassRegistration lowerUniformRealMathPass( //===----------------------------------------------------------------------===// void LowerUniformCastsPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 9dcc6df6bea..8469fa2ea70 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back(llvm::make_unique(context)); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index ea8095b791c..0c93146a232 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -95,7 +95,7 @@ public: void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique(context, &hadFailure)); diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp index 51636037382..f13b743de0c 100644 --- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp +++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp @@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true, } llvm::Expected> -mlir::allocateMemRefArguments(Function *func, float initialValue) { +mlir::allocateMemRefArguments(Function func, float initialValue) { 
SmallVector args; - args.reserve(func->getNumArguments()); - for (const auto &arg : func->getArguments()) { + args.reserve(func.getNumArguments()); + for (const auto &arg : func.getArguments()) { auto descriptor = allocMemRefDescriptor(arg->getType(), /*allocateData=*/true, initialValue); @@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) { args.push_back(*descriptor); } - if (func->getType().getNumResults() > 1) + if (func.getType().getNumResults() > 1) return make_string_error("functions with more than 1 result not supported"); - for (Type resType : func->getType().getResults()) { + for (Type resType : func.getType().getResults()) { auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false); if (!descriptor) return descriptor.takeError(); diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index e39860bddda..5e8090b42b4 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -30,9 +30,9 @@ using namespace mlir::gpu; StringRef GPUDialect::getDialectName() { return "gpu"; } -bool GPUDialect::isKernel(Function *function) { +bool GPUDialect::isKernel(Function function) { UnitAttr isKernelAttr = - function->getAttrOfType(getKernelFuncAttrName()); + function.getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); } @@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, Value *gridSizeX, + Function kernelFunc, Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, ArrayRef kernelOperands) { @@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result, } void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, KernelDim3 gridSize, + Function kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef kernelOperands) { build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, @@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto *module = getOperation()->getFunction()->getModule(); - Function *kernelFunc = module->getNamedFunction(kernel()); + auto *module = getOperation()->getFunction().getModule(); + Function kernelFunc = module->getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; - if (!kernelFunc->getAttrOfType( + if (!kernelFunc.getAttrOfType( GPUDialect::getKernelFuncAttrName())) { return emitError("kernel function is missing the '") << GPUDialect::getKernelFuncAttrName() << "' attribute"; } - unsigned numKernelFuncArgs = kernelFunc->getNumArguments(); + unsigned numKernelFuncArgs = kernelFunc.getNumArguments(); if (getNumKernelOperands() != numKernelFuncArgs) { return emitOpError("got ") << getNumKernelOperands() << " kernel operands but expected " << numKernelFuncArgs; } - auto functionType = kernelFunc->getType(); + auto functionType = kernelFunc.getType(); for (unsigned i = 0; i < numKernelFuncArgs; ++i) { if (getKernelOperand(i)->getType() != functionType.getInput(i)) { return emitOpError("type of function argument ") diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 46363f06f72..f93febcf5da 100644 --- 
a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc, // Add operations generating block/thread ids and gird/block dimensions at the // beginning of `kernelFunc` and replace uses of the respective function args. -static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { +static void injectGpuIndexOperations(Location loc, Function kernelFunc) { OpBuilder OpBuilder(kernelFunc.getBody()); SmallVector indexOps; createForAllDimensions(OpBuilder, loc, indexOps); @@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { // Outline the `gpu.launch` operation body into a kernel function. Replace // `gpu.return` operations by `std.return` in the generated functions. -static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { +static Function outlineKernelFunc(gpu::LaunchOp launchOp) { Location loc = launchOp.getLoc(); SmallVector kernelOperandTypes(launchOp.getKernelOperandTypes()); FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = - Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str(); - Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type); - outlinedFunc->getBody().takeBody(launchOp.getBody()); + Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str(); + Function outlinedFunc = Function::create(loc, kernelFuncName, type); + outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); - outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), - builder.getUnitAttr()); - injectGpuIndexOperations(loc, *outlinedFunc); - outlinedFunc->walk([](mlir::gpu::Return op) { + outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + injectGpuIndexOperations(loc, outlinedFunc); + outlinedFunc.walk([](mlir::gpu::Return op) { OpBuilder replacer(op); replacer.create(op.getLoc()); op.erase(); @@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching // `kernelFunc`. static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, - Function &kernelFunc) { + Function kernelFunc) { OpBuilder builder(launchOp); SmallVector kernelOperandValues( launchOp.getKernelOperandValues()); builder.create( - launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), launchOp.getBlockSizeOperandValues(), kernelOperandValues); launchOp.erase(); } @@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass { public: void runOnModule() override { ModuleManager moduleManager(&getModule()); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk([&](mlir::gpu::LaunchOp op) { - Function *outlinedFunc = outlineKernelFunc(op); + Function outlinedFunc = outlineKernelFunc(op); moduleManager.insert(outlinedFunc); - convertToLaunchFuncOp(op, *outlinedFunc); + convertToLaunchFuncOp(op, outlinedFunc); }); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8e3d5788bb1..346d35af231 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) { initializeSymbolAliases(); // Walk the module and visit each operation. 
- for (auto &fn : *module) { + for (auto fn : *module) { visitType(fn.getType()); for (auto attr : fn.getAttrs()) ModuleState::visitAttribute(attr.second); @@ -342,7 +342,7 @@ public: void printAttribute(Attribute attr, bool mayElideType = false); void printType(Type type); - void print(Function *fn); + void print(Function fn); void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); @@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) { state.printTypeAliases(os); // Print the module. - for (auto &fn : *module) - print(&fn); + for (auto fn : *module) + print(fn); } /// Print a floating point value in a way that the parser will be able to @@ -1186,7 +1186,7 @@ namespace { // CFG and ML functions. class FunctionPrinter : public ModulePrinter, private OpAsmPrinter { public: - FunctionPrinter(Function *function, ModulePrinter &other); + FunctionPrinter(Function function, ModulePrinter &other); // Prints the function as a whole. void print(); @@ -1275,7 +1275,7 @@ protected: void printValueID(Value *value, bool printResultNo = true) const; private: - Function *function; + Function function; /// This is the value ID for each SSA value in the current function. If this /// returns ~0, then the valueID has an entry in valueNames. @@ -1305,10 +1305,10 @@ private: }; } // end anonymous namespace -FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other) +FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other) : ModulePrinter(other), function(function) { - for (auto &block : *function) + for (auto &block : function) numberValuesInBlock(block); } @@ -1419,17 +1419,17 @@ void FunctionPrinter::print() { printFunctionSignature(); // Print out function attributes, if present. - auto attrs = function->getAttrs(); + auto attrs = function.getAttrs(); if (!attrs.empty()) { os << "\n attributes "; printOptionalAttrDict(attrs); } // Print the trailing location. - printTrailingLocation(function->getLoc()); + printTrailingLocation(function.getLoc()); - if (!function->empty()) { - printRegion(function->getBody(), /*printEntryBlockArgs=*/false, + if (!function.empty()) { + printRegion(function.getBody(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); os << "\n"; } @@ -1437,24 +1437,24 @@ void FunctionPrinter::print() { } void FunctionPrinter::printFunctionSignature() { - os << "func @" << function->getName() << '('; + os << "func @" << function.getName() << '('; - auto fnType = function->getType(); - bool isExternal = function->isExternal(); - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { + auto fnType = function.getType(); + bool isExternal = function.isExternal(); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) { if (i > 0) os << ", "; // If this is an external function, don't print argument labels. if (!isExternal) { - printOperand(function->getArgument(i)); + printOperand(function.getArgument(i)); os << ": "; } printType(fnType.getInput(i)); // Print the attributes for this argument. - printOptionalAttrDict(function->getArgAttrs(i)); + printOptionalAttrDict(function.getArgAttrs(i)); } os << ')'; @@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term, } // Prints function with initialized module state. 
-void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); } +void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); } //===----------------------------------------------------------------------===// // print and dump methods @@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) { void Value::dump() { print(llvm::errs()); } void Operation::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1754,13 +1754,13 @@ void Operation::dump() { } void Block::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) { os << "<>\n"; return; } - ModuleState state(getFunction()->getContext()); + ModuleState state(getFunction().getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(getFunction(), modulePrinter).printBlockName(this); } void Function::print(raw_ostream &os) { ModuleState state(getContext()); - ModulePrinter(os, state).print(this); + ModulePrinter(os, state).print(*this); } void Function::dump() { print(llvm::errs()); } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 01f9a060bd9..9cbba0fe429 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional loc, // FunctionAttr //===----------------------------------------------------------------------===// -FunctionAttr FunctionAttr::get(Function *value) { - assert(value && "Cannot get FunctionAttr for a null function"); - return get(value->getName(), value->getContext()); -} - FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { return Base::get(ctx, StandardAttributes::Function, value, NoneType::get(ctx)); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e7616f6d7d0..134f6e468a0 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -50,7 +50,7 @@ Operation *Block::getContainingOp() { return getParent() ? 
getParent()->getContainingOp() : nullptr; } -Function *Block::getFunction() { +Function Block::getFunction() { Block *block = this; while (auto *op = block->getContainingOp()) { block = op->getBlock(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 9b30205abdb..89df64260d3 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } -FunctionAttr Builder::getFunctionAttr(Function *value) { - return FunctionAttr::get(value); +FunctionAttr Builder::getFunctionAttr(Function value) { + return getFunctionAttr(value.getName()); } FunctionAttr Builder::getFunctionAttr(StringRef value) { return FunctionAttr::get(value, getContext()); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 4547452eb55..e38b95ff0f7 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" +#include "mlir/IR/Function.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" @@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context) Dialect::~Dialect() {} +/// Verify an attribute from this dialect on the given function. Returns +/// failure if the verification failed, success otherwise. +LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) { + return success(); +} + +/// Verify an attribute from this dialect on the argument at 'argIndex' for +/// the given function. Returns failure if the verification failed, success +/// otherwise. +LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex, + NamedAttribute) { + return success(); +} + /// Parse an attribute registered to this dialect. Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const { emitError(loc) << "dialect '" << getNamespace() diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 7d17ed1d705..f8835f02c26 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -27,45 +27,50 @@ #include "llvm/ADT/Twine.h" using namespace mlir; +using namespace mlir::detail; -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef attrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {} -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(argAttrs), body(this) {} MLIRContext *Function::getContext() { return getType().getContext(); } -Module *llvm::ilist_traits::getContainingModule() { +Module *llvm::ilist_traits::getContainingModule() { size_t Offset( size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr)))); - iplist *Anchor(static_cast *>(this)); + iplist *Anchor(static_cast *>(this)); return reinterpret_cast(reinterpret_cast(Anchor) - Offset); } /// This is a trait method invoked when a Function is added to a Module. We /// keep the module pointer and module symbol table up to date. 
-void llvm::ilist_traits::addNodeToList(Function *function) { - assert(!function->getModule() && "already in a module!"); +void llvm::ilist_traits::addNodeToList( + FunctionStorage *function) { + assert(!function->module && "already in a module!"); function->module = getContainingModule(); } /// This is a trait method invoked when a Function is removed from a Module. /// We keep the module pointer up to date. -void llvm::ilist_traits::removeNodeFromList(Function *function) { +void llvm::ilist_traits::removeNodeFromList( + FunctionStorage *function) { assert(function->module && "not already in a module!"); function->module = nullptr; } /// This is a trait method invoked when an operation is moved from one block /// to another. We keep the block pointer up to date. -void llvm::ilist_traits::transferNodesFromList( - ilist_traits &otherList, function_iterator first, +void llvm::ilist_traits::transferNodesFromList( + ilist_traits &otherList, function_iterator first, function_iterator last) { // If we are transferring functions within the same module, the Module // pointer doesn't need to be updated. @@ -82,8 +87,10 @@ void llvm::ilist_traits::transferNodesFromList( /// Unlink this function from its Module and delete it. void Function::erase() { - assert(getModule() && "Function has no parent"); - getModule()->getFunctions().erase(this); + if (auto *module = getModule()) + getModule()->functions.erase(impl); + else + delete impl; } /// Emit an error about fatal conditions with this function, reporting up to @@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) { /// Clone the internal blocks from this function into dest and all attributes /// from this function to dest. -void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { +void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) { // Add the attributes of this function to dest. llvm::MapVector newAttrs; - for (auto &attr : dest->getAttrs()) + for (auto &attr : dest.getAttrs()) newAttrs.insert(attr); for (auto &attr : getAttrs()) { auto insertPair = newAttrs.insert(attr); @@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { assert((insertPair.second || insertPair.first->second == attr.second) && "the two functions have incompatible attributes"); } - dest->setAttrs(newAttrs.takeVector()); + dest.setAttrs(newAttrs.takeVector()); // Clone the body. - body.cloneInto(&dest->body, mapper); + impl->body.cloneInto(&dest.impl->body, mapper); } /// Create a deep copy of this function and all of its blocks, remapping @@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { /// provided (leaving them alone if no entry is present). Replaces references /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. -Function *Function::clone(BlockAndValueMapping &mapper) { - FunctionType newType = type; +Function Function::clone(BlockAndValueMapping &mapper) { + FunctionType newType = impl->type; // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. 
If so, we don't add the @@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) { SmallVector inputTypes; for (unsigned i = 0, e = getNumArguments(); i != e; ++i) if (!mapper.contains(getArgument(i))) - inputTypes.push_back(type.getInput(i)); - newType = FunctionType::get(inputTypes, type.getResults(), getContext()); + inputTypes.push_back(newType.getInput(i)); + newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); } // Create the new function. - Function *newFunc = new Function(getLoc(), getName(), newType); + Function newFunc = Function::create(getLoc(), getName(), newType); /// Set the argument attributes for arguments that aren't being replaced. for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) if (isExternalFn || !mapper.contains(getArgument(i))) - newFunc->setArgAttrs(destI++, getArgAttrs(i)); + newFunc.setArgAttrs(destI++, getArgAttrs(i)); /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; } -Function *Function::clone() { +Function Function::clone() { BlockAndValueMapping mapper; return clone(mapper); } @@ -178,7 +185,7 @@ void Function::addEntryBlock() { assert(empty() && "function already has an entry block"); auto *entry = new Block(); push_back(entry); - entry->addArguments(type.getInputs()); + entry->addArguments(impl->type.getInputs()); } void Function::walk(const std::function &callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 83171f12d1d..f953cd27a56 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } -Function *Operation::getFunction() { +Function Operation::getFunction() { return block ? block->getFunction() : nullptr; } @@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands, } static LogicalResult verifyTerminatorSuccessors(Operation *op) { + auto *parent = op->getContainingRegion(); + // Verify that the operands lines up with the BB arguments in the successor. - Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { auto *succ = op->getSuccessor(i); - if (succ->getFunction() != fn) - return op->emitError("reference to block defined in another function"); + if (succ->getParent() != parent) + return op->emitError("reference to block defined in another region"); if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op))) return failure(); } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 992d9112beb..74c71b7aeac 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -21,7 +21,7 @@ #include "mlir/IR/Operation.h" using namespace mlir; -Region::Region(Function *container) : container(container) {} +Region::Region(Function container) : container(container.impl) {} Region::Region(Operation *container) : container(container) {} @@ -38,7 +38,7 @@ MLIRContext *Region::getContext() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getContext(); - return getContainingFunction()->getContext(); + return getContainingFunction().getContext(); } /// Return a location for this region. 
This is the location attached to the @@ -47,7 +47,7 @@ Location Region::getLoc() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getLoc(); - return getContainingFunction()->getLoc(); + return getContainingFunction().getLoc(); } Region *Region::getContainingRegion() { @@ -60,8 +60,8 @@ Operation *Region::getContainingOp() { return container.dyn_cast(); } -Function *Region::getContainingFunction() { - return container.dyn_cast(); +Function Region::getContainingFunction() { + return container.dyn_cast(); } bool Region::isProperAncestor(Region *other) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index a0819a78fc1..dafbd48f513 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -22,8 +22,8 @@ using namespace mlir; /// Build a symbol table with the symbols within the given module. SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { - for (auto &func : *module) { - auto inserted = symbolTable.insert({func.getName(), &func}); + for (auto func : *module) { + auto inserted = symbolTable.insert({func.getName(), func}); (void)inserted; assert(inserted.second && "expected module to contain uniquely named functions"); @@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(StringRef name) const { +Function SymbolTable::lookup(StringRef name) const { return lookup(Identifier::get(name, context)); } /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(Identifier name) const { +Function SymbolTable::lookup(Identifier name) const { return symbolTable.lookup(name); } /// Erase the given symbol from the table. -void SymbolTable::erase(Function *symbol) { - auto it = symbolTable.find(symbol->getName()); +void SymbolTable::erase(Function symbol) { + auto it = symbolTable.find(symbol.getName()); if (it != symbolTable.end() && it->second == symbol) symbolTable.erase(it); } /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. -void SymbolTable::insert(Function *symbol) { +void SymbolTable::insert(Function symbol) { // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - if (symbolTable.insert({symbol->getName(), symbol}).second) + if (symbolTable.insert({symbol.getName(), symbol}).second) return; // If a conflict was detected, then the function will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(symbol->getName()); + SmallString<128> nameBuffer(symbol.getName()); unsigned originalLength = nameBuffer.size(); // Iteratively try suffixes until we find one that isn't used. 
We use a @@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) { nameBuffer.resize(originalLength); nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); - symbol->setName(Identifier::get(nameBuffer, context)); - } while (!symbolTable.insert({symbol->getName(), symbol}).second); + symbol.setName(Identifier::get(nameBuffer, context)); + } while (!symbolTable.insert({symbol.getName(), symbol}).second); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 073c3b369c6..65a98f7ee59 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() { } /// Return the function that this Value is defined in. -Function *Value::getFunction() { +Function Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast(this)->getFunction(); @@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -Function *BlockArgument::getFunction() { +Function BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; @@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() { /// Returns if the current argument is a function argument. bool BlockArgument::isFunctionArgument() { - auto *containingFn = getFunction(); - return containingFn && &containingFn->front() == getOwner(); + auto containingFn = getFunction(); + return containingFn && &containingFn.front() == getOwner(); } diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 0d3a5ca2756..0dbf63a3ce7 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const { } /// Verify LLVMIR function argument attributes. -LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func, +LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func, unsigned argIdx, NamedAttribute argAttr) { // Check that llvm.noalias is a boolean attribute. if (argAttr.first == "llvm.noalias" && !argAttr.second.isa()) - return func->emitError() + return func.emitError() << "llvm.noalias argument attribute of non boolean type"; return success(); } diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 7ddb7b0c19f..5761cc637b7 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } -static void fuseLinalgOps(Function &f, ArrayRef tileSizes) { +static void fuseLinalgOps(Function f, ArrayRef tileSizes) { OperationFolder state; DenseSet eraseSet; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index a8099aaff99..5fe4f07613a 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,12 +170,13 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. 
- auto *module = op->getFunction()->getModule(); - Function *mallocFunc = module->getNamedFunction("malloc"); + auto *module = op->getFunction().getModule(); + Function mallocFunc = module->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - module->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + module->push_back(mallocFunc); } // Get MLIR types for injecting element pointer. @@ -230,12 +231,12 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *freeFunc = module->getNamedFunction("free"); + auto *module = op->getFunction().getModule(); + Function freeFunc = module->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - module->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + module->push_back(freeFunc); } // Get MLIR types for extracting element pointer. @@ -572,37 +573,37 @@ public: // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. -static Function *getLLVMLibraryCallImplDefinition(Function *libFn) { - auto implFnName = (libFn->getName().str() + "_impl"); - auto module = libFn->getModule(); - if (auto *f = module->getNamedFunction(implFnName)) { +static Function getLLVMLibraryCallImplDefinition(Function libFn) { + auto implFnName = (libFn.getName().str() + "_impl"); + auto module = libFn.getModule(); + if (auto f = module->getNamedFunction(implFnName)) { return f; } SmallVector fnArgTypes; - for (auto t : libFn->getType().getInputs()) { + for (auto t : libFn.getType().getInputs()) { assert(t.isa() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast().getPointerTo()); } - auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext()); + auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. - auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType); - module->getFunctions().push_back(implFnDefn); + auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); + module->push_back(implFnDefn); return implFnDefn; } // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. 
template -static Function *getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static Function getLLVMLibraryCallDeclaration(Operation *op, + LLVMTypeConverter &lowering, + PatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction()->getModule(); - if (auto *f = module->getNamedFunction(fnName)) { + auto module = op->getFunction().getModule(); + if (auto f = module->getNamedFunction(fnName)) { return f; } @@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op, "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = new Function(op->getLoc(), fnName, libFnType); - module->getFunctions().push_back(libFn); + auto libFn = Function::create(op->getLoc(), fnName, libFnType); + module->push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } -static void getLLVMLibraryCallDefinition(Function *fn, +static void getLLVMLibraryCallDefinition(Function fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); // Generate the function body. - fn->addEntryBlock(); + fn.addEntryBlock(); - OpBuilder builder(fn->getBody()); - edsc::ScopedContext scope(builder, fn->getLoc()); + OpBuilder builder(fn.getBody()); + edsc::ScopedContext scope(builder, fn.getLoc()); SmallVector implFnArgs; // Create a constant 1. auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()), - IntegerAttr::get(IndexType::get(fn->getContext()), 1)); - for (auto arg : fn->getArguments()) { + IntegerAttr::get(IndexType::get(fn.getContext()), 1)); + for (auto arg : fn.getArguments()) { // Allocate a stack for storing the argument value. The stack is passed to // the implementation function. auto alloca = @@ -665,17 +666,17 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(Function *fn) { + void addLibraryFnDeclaration(Function fn) { libraryFnDeclarations.push_back(fn); } - ArrayRef getLibraryFnDeclarations() { + ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations; } private: /// List of library functions declarations needed during dialect conversion - SmallVector libraryFnDeclarations; + SmallVector libraryFnDeclarations; }; } // end anonymous namespace @@ -692,7 +693,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. 
- auto *f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); + auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); static_cast(lowering).addLibraryFnDeclaration(f); auto fAttr = rewriter.getFunctionAttr(f); @@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) { void LowerLinalgToLLVMPass::runOnModule() { auto &module = getModule(); - for (auto &f : module.getFunctions()) { + for (auto f : module.getFunctions()) { lowerLinalgSubViewOps(f); lowerLinalgForToCFG(f); if (failed(lowerAffineConstructs(f))) diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index d31ba5bf22d..2e616c35f1d 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass { } // namespace void LowerLinalgToLoopsPass::runOnFunction() { - auto &f = getFunction(); OperationFolder state; - f.walk([&state](LinalgOp linalgOp) { + getFunction().walk([&state](LinalgOp linalgOp) { emitLinalgOpAsLoops(linalgOp, state); linalgOp.getOperation()->erase(); }); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index c63e1cf197d..2f752b2b637 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef tileSizes, return tileLinalgOp(op, tileSizeValues, state); } -static void tileLinalgOps(Function &f, ArrayRef tileSizes) { +static void tileLinalgOps(Function f, ArrayRef tileSizes) { OperationFolder state; f.walk([tileSizes, &state](LinalgOp op) { auto opLoopsPair = tileLinalgOp(op, tileSizes, state); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 44f05963727..4af2f093daf 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -254,7 +254,7 @@ public: /// trailing-location ::= location? /// template - ParseResult parseOptionalTrailingLocation(Owner *owner) { + ParseResult parseOptionalTrailingLocation(Owner &owner) { // If there is a 'loc' we parse a trailing location. if (!getToken().is(Token::kw_loc)) return success(); @@ -263,7 +263,7 @@ public: LocationAttr directLoc; if (parseLocation(directLoc)) return failure(); - owner->setLoc(directLoc); + owner.setLoc(directLoc); return success(); } @@ -2472,8 +2472,8 @@ namespace { /// operations. class OperationParser : public Parser { public: - OperationParser(ParserState &state, Function *function) - : Parser(state), function(function), opBuilder(function->getBody()) {} + OperationParser(ParserState &state, Function function) + : Parser(state), function(function), opBuilder(function.getBody()) {} ~OperationParser(); @@ -2588,7 +2588,7 @@ public: Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); private: - Function *function; + Function function; /// Returns the info for a block at the current scope for the given name. std::pair &getBlockInfoByName(StringRef name) { @@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() { for (auto entry : forwardRefInCurrentScope) { errors.push_back({entry.second.getPointer(), entry.first}); // Add this block to the top-level region to allow for automatic cleanup. - function->push_back(entry.first); + function.push_back(entry.first); } llvm::array_pod_sort(errors.begin(), errors.end()); @@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() { } // Try to parse the optional trailing location. 
- if (parseOptionalTrailingLocation(op)) + if (parseOptionalTrailingLocation(*op)) return failure(); return success(); @@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) { } // Okay, the function signature was parsed correctly, create the function now. - auto *function = - new Function(getEncodedSourceLocation(loc), name, type, attrs); - module->getFunctions().push_back(function); + auto function = + Function::create(getEncodedSourceLocation(loc), name, type, attrs); + module->push_back(function); // Parse an optional trailing location. if (parseOptionalTrailingLocation(function)) return failure(); // Add the attributes to the function arguments. - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) - function->setArgAttrs(i, argAttrs[i]); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) + function.setArgAttrs(i, argAttrs[i]); // External functions have no body. if (getToken().isNot(Token::l_brace)) @@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) { // Parse the function body. auto parser = OperationParser(getState(), function); - if (parser.parseRegion(function->getBody(), entryArgs)) + if (parser.parseRegion(function.getBody(), entryArgs)) return failure(); // Verify that a valid function body was parsed. - if (function->empty()) + if (function.empty()) return emitError(braceLoc, "function must have a body"); return parser.finalize(braceLoc); diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 868d492e094..057f2655207 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -61,12 +61,12 @@ private: static void printIR(const llvm::Any &ir, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. - if (printModuleScope && llvm::any_isa(ir)) { - Function *function = llvm::any_cast(ir); + if (printModuleScope && llvm::any_isa(ir)) { + Function function = llvm::any_cast(ir); // Print the function name and a newline before the Module. - out << " (function: " << function->getName() << ")\n"; - function->getModule()->print(out); + out << " (function: " << function.getName() << ")\n"; + function.getModule()->print(out); return; } @@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa(ir)) { - llvm::any_cast(ir)->print(out); + if (llvm::any_isa(ir)) { + llvm::any_cast(ir).print(out); return; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 2f605b6690b..27ec74c23c2 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -46,8 +46,7 @@ static llvm::cl::opt void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(Function *fn, - FunctionAnalysisManager &fam) { +LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) { // Initialize the pass state. passState.emplace(fn, fam); @@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) } /// Run all of the passes in this manager over the current function. -LogicalResult detail::FunctionPassExecutor::run(Function *function, +LogicalResult detail::FunctionPassExecutor::run(Function function, FunctionAnalysisManager &fam) { // Run each of the held passes. for (auto &pass : passes) @@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module, /// Utility to run the given function and analysis manager on a provided /// function pass executor. 
static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, - Function *func, + Function func, FunctionAnalysisManager &fam) { // Run the function pipeline over the provided function. auto result = fpe.run(func, fam); @@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, /// module. void ModuleToFunctionPassAdaptor::runOnModule() { ModuleAnalysisManager &mam = getAnalysisManager(); - for (auto &func : getModule()) { + for (auto func : getModule()) { // Skip external functions. if (func.isExternal()) continue; // Run the held function pipeline over the current function. - auto fam = mam.slice(&func); - if (failed(runFunctionPipeline(fpe, &func, fam))) + auto fam = mam.slice(func); + if (failed(runFunctionPipeline(fpe, func, fam))) return signalPassFailure(); // Clear out any computed function analyses. These analyses won't be used @@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() { // Run a prepass over the module to collect the functions to execute a over. // This ensures that an analysis manager exists for each function, as well as // providing a queue of functions to execute over. - std::vector> funcAMPairs; - for (auto &func : getModule()) + std::vector> funcAMPairs; + for (auto func : getModule()) if (!func.isExternal()) - funcAMPairs.emplace_back(&func, mam.slice(&func)); + funcAMPairs.emplace_back(func, mam.slice(func)); // A parallel diagnostic handler that provides deterministic diagnostic // ordering. @@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const { } /// Create an analysis slice for the given child function. -FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) { - assert(func->getModule() == moduleAnalyses.getIRUnit() && +FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) { + assert(func.getModule() == moduleAnalyses.getIRUnit() && "function has a different parent module"); auto it = functionAnalyses.find(func); if (it == functionAnalyses.end()) { diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 46addfb8e9c..d2563fb62cd 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -48,7 +48,7 @@ public: FunctionPassExecutor(const FunctionPassExecutor &rhs); /// Run the executor on the given function. - LogicalResult run(Function *function, FunctionAnalysisManager &fam); + LogicalResult run(Function function, FunctionAnalysisManager &fam); /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 375a64d8f2d..3f26bf075af 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() { void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { - auto &func = getFunction(); + auto func = getFunction(); // Insert stats for each argument. 
for (auto *arg : func.getArguments()) { diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index dec4ea90db8..169fec3b39a 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() { void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { CAGSlice cag(solverContext); - for (auto &f : getModule()) { + for (auto f : getModule()) { f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); }); } config.finalizeAnchors(cag); diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index ed3b0956a16..6b376db8516 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -58,7 +58,7 @@ public: void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique>(context)); diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp index 3add211fdd5..543b7300af0 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp @@ -36,11 +36,11 @@ using namespace mlir; // block. The created block will be terminated by `std.return`. Block *createOneBlockFunction(Builder builder, Module *module) { auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); - auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType); - module->getFunctions().push_back(fn); + auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType); + module->push_back(fn); auto *block = new Block(); - fn->push_back(block); + fn.push_back(block); OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName()); ReturnOp::build(&builder, &state); diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp index ebdcaf73717..33572d5adbe 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp @@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) { // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed // to take in the SPIR-V ModuleOp directly after module and function are // migrated to be general ops. 
- for (auto &fn : *module) { + for (auto fn : *module) { fn.walk([&](spirv::ModuleOp spirvModule) { if (done) { spirvModule.emitError("found more than one 'spv.module' op"); diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp index 1a8d79c1790..1ce2b69f055 100644 --- a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp +++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp @@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass void StdOpsToSPIRVConversionPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); populateWithGenerated(func.getContext(), &patterns); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 6d5073f1c37..9fc216eef25 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. - auto fnType = fn->getType(); + auto fnType = fn.getType(); if (fnType.getNumInputs() != op.getNumOperands()) return op.emitOpError("incorrect number of operands for callee"); @@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); // Check that the referenced function has the correct type. - if (fn->getType() != type) + if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); @@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto *function = op.getOperation()->getFunction(); + auto function = op.getOperation()->getFunction(); // The operand number and types must match the function signature. - const auto &results = function->getType().getResults(); + const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 74ade942fc7..1e8409246ef 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -69,7 +69,7 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. 
- for (Function &func : m) { + for (Function func : m) { if (!func.getAttrOfType(gpu::GPUDialect::getKernelFuncAttrName())) continue; @@ -89,20 +89,21 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { return llvmModule; } -static TranslateFromMLIRRegistration registration( - "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) { - if (!module) - return true; +static TranslateFromMLIRRegistration + registration("mlir-to-nvvmir", + [](Module *module, llvm::StringRef outputFilename) { + if (!module) + return true; - auto llvmModule = mlir::translateModuleToNVVMIR(*module); - if (!llvmModule) - return true; + auto llvmModule = mlir::translateModuleToNVVMIR(*module); + if (!llvmModule) + return true; - auto file = openOutputFile(outputFilename); - if (!file) - return true; + auto file = openOutputFile(outputFilename); + if (!file) + return true; - llvmModule->print(file->os(), nullptr); - file->keep(); - return false; - }); + llvmModule->print(file->os(), nullptr); + file->keep(); + return false; + }); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ef286cb64fd..4a68ac71ee0 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { bool ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { mlir::BoolAttr isVarArgsAttr = function.getAttrOfType("std.varargs"); bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue(); @@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() { } // Convert functions. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { // Ignore external functions. if (function.isExternal()) continue; diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 8a2002ce368..394b3ef8db5 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index be60ada6a43..84f00b97e38 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -849,7 +849,7 @@ struct FunctionConverter { /// error, success otherwise. If 'signatureConversion' is provided, the /// arguments of the entry block are updated accordingly. LogicalResult - convertFunction(Function *f, + convertFunction(Function f, TypeConverter::SignatureConversion *signatureConversion); /// Converts the given region starting from the entry block and following the @@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, } LogicalResult FunctionConverter::convertFunction( - Function *f, TypeConverter::SignatureConversion *signatureConversion) { + Function f, TypeConverter::SignatureConversion *signatureConversion) { // If this is an external function, there is nothing else to do. 
- if (f->isExternal()) + if (f.isExternal()) return success(); - DialectConversionRewriter rewriter(f->getBody(), typeConverter); + DialectConversionRewriter rewriter(f.getBody(), typeConverter); // Update the signature of the entry block. if (signatureConversion) { rewriter.argConverter.convertSignature( - &f->getBody().front(), *signatureConversion, rewriter.mapping); + &f.getBody().front(), *signatureConversion, rewriter.mapping); } // Rewrite the function body. if (failed( - convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) { + convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) { // Reset any of the generated rewrites. rewriter.discardRewrites(); return failure(); @@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const // applyConversionPatterns //===----------------------------------------------------------------------===// -namespace { -/// This class represents a function to be converted. It allows for converting -/// the body of functions and the signature in two phases. -struct ConvertedFunction { - ConvertedFunction(Function *fn, FunctionType newType, - ArrayRef newFunctionArgAttrs) - : fn(fn), newType(newType), - newFunctionArgAttrs(newFunctionArgAttrs.begin(), - newFunctionArgAttrs.end()) {} - - /// The function to convert. - Function *fn; - /// The new type and argument attributes for the function. - FunctionType newType; - SmallVector newFunctionArgAttrs; -}; -} // end anonymous namespace - /// Convert the given module with the provided conversion patterns and type /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. @@ -1149,37 +1131,33 @@ LogicalResult mlir::applyConversionPatterns(Module &module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { - std::vector allFunctions; - allFunctions.reserve(module.getFunctions().size()); - for (auto &func : module) - allFunctions.push_back(&func); + SmallVector allFunctions(module.getFunctions()); return applyConversionPatterns(allFunctions, target, converter, std::move(patterns)); } /// Convert the given functions with the provided conversion patterns. LogicalResult mlir::applyConversionPatterns( - ArrayRef fns, ConversionTarget &target, + MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { if (fns.empty()) return success(); // Build the function converter. - FunctionConverter funcConverter(fns.front()->getContext(), target, patterns, - &converter); + auto *ctx = fns.front().getContext(); + FunctionConverter funcConverter(ctx, target, patterns, &converter); // Try to convert each of the functions within the module. - auto *ctx = fns.front()->getContext(); - for (auto *func : fns) { + for (auto func : fns) { // Convert the function type using the type converter. auto conversion = - converter.convertSignature(func->getType(), func->getAllArgAttrs()); + converter.convertSignature(func.getType(), func.getAllArgAttrs()); if (!conversion) return failure(); // Update the function signature. - func->setType(conversion->getConvertedType(ctx)); - func->setAllArgAttrs(conversion->getConvertedArgAttrs()); + func.setType(conversion->getConvertedType(ctx)); + func.setAllArgAttrs(conversion->getConvertedArgAttrs()); // Convert the body of this function. 
if (failed(funcConverter.convertFunction(func, &*conversion))) @@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns( /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function &fn, ConversionTarget &target, +mlir::applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. FunctionConverter converter(fn.getContext(), target, patterns); - return converter.convertFunction(&fn, /*signatureConversion=*/nullptr); + return converter.convertFunction(fn, /*signatureConversion=*/nullptr); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5a926ceaa92..a3aa092b0ec 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED emitRemarkForBlock(Block &block) { auto *op = block.getContainingOp(); - return op ? op->emitRemark() : block.getFunction()->emitRemark(); + return op ? op->emitRemark() : block.getFunction().emitRemark(); } /// Creates a buffer in the faster memory space for the specified region; @@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, OpBuilder &b = region.isWrite() ? epilogue : prologue; // Builder to create constants at the top level. - auto *func = block->getFunction(); - OpBuilder top(func->getBody()); + auto func = block->getFunction(); + OpBuilder top(func.getBody()); auto loc = region.loc; auto *memref = region.memref; @@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (auto *op = block->getContainingOp()) op->emitError(str); else - block->getFunction()->emitError(str); + block->getFunction().emitError(str); } return totalDmaBuffersSizeInBytes; } void DmaGeneration::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder topBuilder(f.getBody()); zeroIndex = topBuilder.create(f.getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8d2e75b2dca..77b944f3e01 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function &f); + bool init(Function f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -637,7 +637,7 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(Function &f) { +bool MemRefDependenceGraph::init(Function f) { DenseMap> memrefAccesses; // TODO: support multi-block functions. @@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getFunction()->getBody()); + OpBuilder top(forInst->getFunction().getBody()); // Create new memref type based on slice bounds. 
auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); @@ -1750,9 +1750,9 @@ public: }; // Search for siblings which load the same memref function argument. - auto *fn = dstNode->op->getFunction(); - for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { - for (auto *user : fn->getArgument(i)->getUsers()) { + auto fn = dstNode->op->getFunction(); + for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { + for (auto *user : fn.getArgument(i)->getUsers()) { if (auto loadOp = dyn_cast(user)) { // Gather loops surrounding 'use'. SmallVector loops; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index c1be6e8f6b1..2744e5ca05c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function &f, +static void getTileableBands(Function f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 05953926376..6f13f623fe8 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() { // Store innermost loops as we walk. std::vector loops; - void walkPostOrder(Function *f) { - for (auto &b : *f) + void walkPostOrder(Function f) { + for (auto &b : f) walkPostOrder(b.begin(), b.end()); } @@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function &func = getFunction(); + Function func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(&func); + ilg.walkPostOrder(func); auto &loops = ilg.loops; if (loops.empty()) break; diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 77a23b156b0..df30e270fe6 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -726,7 +726,7 @@ public: } // end namespace -LogicalResult mlir::lowerAffineConstructs(Function &function) { +LogicalResult mlir::lowerAffineConstructs(Function function) { OwningRewritePatternList patterns; RewriteListBuildergetFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state, /// because we currently disallow vectorization of defs that come from another /// scope. /// TODO(ntv): please document return value. 
-static bool materialize(Function *f, const SetVector &terminators, +static bool materialize(Function f, const SetVector &terminators, MaterializationState *state) { DenseSet seen; DominanceInfo domInfo(f); @@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector &terminators, return true; } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); } return false; } @@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = &getFunction(); - if (f->getBlocks().size() != 1) + Function f = getFunction(); + if (f.getBlocks().size() != 1) return; using matcher::Op; LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); MaterializationState state(hwVectorSize); // Get the hardware vector type. diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index c5676afaf63..1208e2fdd15 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function &f = getFunction(); + Function f = getFunction(); if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index f97f549c93e..c7c3621781a 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function &func = getFunction(); + Function func = getFunction(); auto unknownLoc = UnknownLoc::get(&getContext()); // Strip the debug info from the function and its operations. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 47ca378f324..e185f702d27 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,7 +44,7 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function &fn, + explicit GreedyPatternRewriteDriver(Function fn, OwningRewritePatternList &&patterns) : PatternRewriter(fn.getBody()), matcher(std::move(patterns)) { worklist.reserve(64); @@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. 
/// -bool mlir::applyPatternsGreedily(Function &fn, +bool mlir::applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); bool converged = driver.simplifyFunction(maxPatternMatchIterations); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 728123f71a5..4ddf93c2232 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - OpBuilder topBuilder(op->getFunction()->getBody()); + OpBuilder topBuilder(op->getFunction().getBody()); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 39a05d8c300..3fca26bef19 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. void Vectorize::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); if (!fastestVaryingPattern.empty() && fastestVaryingPattern.size() != vectorSizes.size()) { f.emitRemark("Fastest varying pattern specified with different size than " @@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() { unsigned patternDepth = pat.getDepth(); SmallVector matches; - pat.match(&f, &matches); + pat.match(f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. 
for (auto m : matches) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 1f2ab69409e..3c1a1b3b481 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -53,13 +53,13 @@ std::string DOTGraphTraits::getNodeLabel(Block *Block, Function *) { } // end namespace llvm -void mlir::viewGraph(Function &function, const llvm::Twine &name, +void mlir::viewGraph(Function function, const llvm::Twine &name, bool shortNames, const llvm::Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&function, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function, +llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function, bool shortNames, const llvm::Twine &title) { return llvm::WriteGraph(os, &function, shortNames, title); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 834e7c98228..a88312dba9b 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -43,13 +43,12 @@ static MLIRContext &globalContext() { return context; } -static std::unique_ptr makeFunction(StringRef name, - ArrayRef results = {}, - ArrayRef args = {}) { +static Function makeFunction(StringRef name, ArrayRef results = {}, + ArrayRef args = {}) { auto &ctx = globalContext(); - auto function = llvm::make_unique( - UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx)); - function->addEntryBlock(); + auto function = Function::create(UnknownLoc::get(&ctx), name, + FunctionType::get(args, results, &ctx)); + function.addEntryBlock(); return function; } @@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) { auto f = makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)), - ub(f->getArgument(1)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)), + ub(f.getArgument(1)); ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)); ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type)); ValueHandle i7(constant_int(7, 32)); @@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) { // CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32 // CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_dynamic_for) { @@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) { auto f = makeFunction("builder_dynamic_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)), - c(f->getArgument(2)), d(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), + c(f.getArgument(2)), d(f.getArgument(3)); LoopBuilder(&i, a - b, c + d, 2)(); // clang-format off @@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) { // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3] // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 { // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } 
TEST_FUNC(builder_max_min_for) { @@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) { auto f = makeFunction("builder_max_min_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)), - ub1(f->getArgument(2)), ub2(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)), + ub1(f.getArgument(2)), ub2(f.getArgument(3)); LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)(); ret(); @@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) { // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { // CHECK: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks) { @@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) { using namespace edsc::op; auto f = makeFunction("builder_blocks"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), arg4(c1.getType()), r(c1.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1, &arg2})( // b2 has not yet been constructed, need to come back later. // This is a byproduct of non-structured control-flow. @@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks_eager) { @@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) { using namespace edsc::op; auto f = makeFunction("builder_blocks_eager"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create(42, 32)), c2(ValueHandle::create(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), @@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch) { @@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), c64(ValueHandle::create(64, 64)), c42(ValueHandle::create(42, 32)); ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1})([&] { ret(); }); BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); }); // Get back to entry block and add a conditional branch @@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + 
f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch_eager) { @@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) { auto f = makeFunction("builder_cond_branch_eager", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create(32, 32)), c64(ValueHandle::create(64, 64)), c42(ValueHandle::create(42, 32)); @@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_helpers) { @@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) { auto f = makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle f7( ValueHandle::create(llvm::APFloat(7.0f), f32Type)); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), + vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2; int64_t step0, step1, step2; std::tie(lb0, ub0, step0) = vA.range(0); @@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) { // CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32 // CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(custom_ops) { @@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("custom_ops", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); CustomOperation MY_CUSTOM_OP("my_custom_op"); CustomOperation MY_CUSTOM_OP_0("my_custom_op_0"); CustomOperation MY_CUSTOM_OP_2("my_custom_op_2"); @@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) { // clang-format off ValueHandle vh(indexType), vh20(indexType), vh21(indexType); OperationHandle ih0, ih2; - IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1)); + IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1)); IndexHandle ten(index_t(10)), twenty(index_t(20)); LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{ vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}); @@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) { // CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index) // CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(insertion_in_block) { @@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("insertion_in_block", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); BlockHandle b1; // clang-format off 
ValueHandle::create(0, 32); @@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) { // CHECK: ^bb1: // no predecessors // CHECK: {{.*}} = constant 1 : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(select_op) { @@ -438,12 +447,12 @@ TEST_FUNC(select_op) { auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle zero = constant_index(0), one = constant_index(1); - MemRefView vA(f->getArgument(0)); - IndexedValue A(f->getArgument(0)); + MemRefView vA(f.getArgument(0)); + IndexedValue A(f.getArgument(0)); IndexHandle i, j; LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ // This test exercises IndexedValue::operator Value*. @@ -461,7 +470,8 @@ TEST_FUNC(select_op) { // CHECK-DAG: {{.*}} = load // CHECK-NEXT: {{.*}} = select // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise imperfectly nested 2-d @@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) { MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0); auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); // clang-format off @@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) { // CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32 // CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise 2-d vectorization. 
@@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) { auto owningF = makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType}); - mlir::Function *f = owningF.release(); + mlir::Function f = owningF; mlir::Module module(&globalContext()); - module.getFunctions().push_back(f); + module.push_back(f); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2)); // clang-format off @@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) { pm.addPass(mlir::createCanonicalizerPass()); SmallVector vectorSizes{4, 4}; pm.addPass(mlir::createVectorizePass(vectorSizes)); - auto result = pm.run(f->getModule()); + auto result = pm.run(f.getModule()); if (succeeded(result)) - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } int main() { diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 4767e3367be..7bfb5564064 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -97,12 +97,12 @@ struct VectorizerTestPass : public FunctionPass { } // end anonymous namespace void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); using matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); auto subVectorType = - VectorType::get(shape, FloatType::getF32(f->getContext())); + VectorType::get(shape, FloatType::getF32(f.getContext())); // Only filter operations that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [&](Operation &op) { @@ -148,7 +148,7 @@ static NestedPattern patternTestSlicingOps() { } void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); @@ -163,7 +163,7 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); for (auto m : matches) { @@ -177,7 +177,7 @@ void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); SmallVector matches; patternTestSlicingOps().match(f, &matches); @@ -195,7 +195,7 @@ static bool customOpWithAffineMapAttribute(Operation &op) { } void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { - auto *f = &getFunction(); + auto f = getFunction(); using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); @@ -227,7 +227,7 @@ static bool singleResultAffineApplyOpWithoutUses(Operation &op) { void VectorizerTestPass::testNormalizeMaps() { using matcher::Op; - auto *f = &getFunction(); + auto f = getFunction(); // Save matched AffineApplyOp that all need to be erased in the end. 
auto pattern = Op(affineApplyOp); @@ -256,7 +256,7 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. - Function &f = getFunction(); + Function f = getFunction(); if (f.getBlocks().size() != 1) return; diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp index 54a9c6ce95c..1ac6c402630 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp @@ -163,8 +163,8 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) { static Error compileAndExecuteFunctionWithMemRefs( Module *module, StringRef entryPoint, std::function transformer) { - Function *mainFunction = module->getNamedFunction(entryPoint); - if (!mainFunction || mainFunction->getBlocks().empty()) { + Function mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction.getBlocks().empty()) { return make_string_error("entry point not found"); } @@ -172,9 +172,9 @@ static Error compileAndExecuteFunctionWithMemRefs( // pretty print the results, because the function itself will be rewritten // to use the LLVM dialect. SmallVector argTypes = - llvm::to_vector<8>(mainFunction->getType().getInputs()); + llvm::to_vector<8>(mainFunction.getType().getInputs()); SmallVector resTypes = - llvm::to_vector<8>(mainFunction->getType().getResults()); + llvm::to_vector<8>(mainFunction.getType().getResults()); float init = std::stof(initValue.getValue()); @@ -206,18 +206,18 @@ static Error compileAndExecuteFunctionWithMemRefs( static Error compileAndExecuteSingleFloatReturnFunction( Module *module, StringRef entryPoint, std::function transformer) { - Function *mainFunction = module->getNamedFunction(entryPoint); - if (!mainFunction || mainFunction->isExternal()) { + Function mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction.isExternal()) { return make_string_error("entry point not found"); } - if (!mainFunction->getType().getInputs().empty()) + if (!mainFunction.getType().getInputs().empty()) return make_string_error("function inputs not supported"); - if (mainFunction->getType().getResults().size() != 1) + if (mainFunction.getType().getResults().size() != 1) return make_string_error("only single f32 function result supported"); - auto t = mainFunction->getType().getResults()[0].dyn_cast(); + auto t = mainFunction.getType().getResults()[0].dyn_cast(); if (!t) return make_string_error("only single llvm.f32 function result supported"); auto *llvmTy = t.getUnderlyingType(); diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index 38a059b3ba5..d2a82374124 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -25,11 +25,11 @@ using namespace mlir::detail; namespace { /// Minimal class definitions for two analyses. struct MyAnalysis { - MyAnalysis(Function *) {} + MyAnalysis(Function) {} MyAnalysis(Module *) {} }; struct OtherAnalysis { - OtherAnalysis(Function *) {} + OtherAnalysis(Function) {} OtherAnalysis(Module *) {} }; @@ -59,10 +59,10 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { // Create a function and a module. 
std::unique_ptr module(new Module(&context)); - Function *func1 = - new Function(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); - module->getFunctions().push_back(func1); + Function func1 = + Function::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); + module->push_back(func1); // Test fine grain invalidation of the function analysis manager. ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr); @@ -87,10 +87,10 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { // Create a function and a module. std::unique_ptr module(new Module(&context)); - Function *func1 = - new Function(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); - module->getFunctions().push_back(func1); + Function func1 = + Function::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); + module->push_back(func1); // Test fine grain invalidation of a function analysis from within a module // analysis manager. -- cgit v1.2.3 From 2e1187dd25aee1c364e5273108c14e591656f19a Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Wed, 3 Jul 2019 10:35:03 -0700 Subject: Globally change load/store/dma_start/dma_wait operations over to affine.load/store/dma_start/dma_wait. In most places, this is just a name change (with the exception of affine.dma_start swapping the operand positions of its tag memref and num_elements operands). Significant code changes occur here: *) Vectorization: LoopAnalysis.cpp, Vectorize.cpp *) Affine Transforms: Transforms/Utils/Utils.cpp PiperOrigin-RevId: 256395088 --- mlir/include/mlir/AffineOps/AffineOps.h | 29 +- mlir/include/mlir/Analysis/Utils.h | 14 +- mlir/include/mlir/Analysis/VectorAnalysis.h | 2 +- mlir/include/mlir/Transforms/Passes.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 167 ++- mlir/lib/Analysis/AffineAnalysis.cpp | 18 +- mlir/lib/Analysis/LoopAnalysis.cpp | 46 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 5 +- mlir/lib/Analysis/NestedMatcher.cpp | 2 +- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 3 +- mlir/lib/Analysis/Utils.cpp | 36 +- mlir/lib/Analysis/VectorAnalysis.cpp | 15 +- mlir/lib/Parser/Parser.cpp | 20 +- mlir/lib/Transforms/DmaGeneration.cpp | 53 +- mlir/lib/Transforms/LoopFusion.cpp | 54 +- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 14 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 14 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 25 +- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 16 +- mlir/lib/Transforms/Utils/Utils.cpp | 124 ++- mlir/lib/Transforms/Vectorize.cpp | 83 +- mlir/test/EDSC/builder-api-test.cpp | 32 +- .../Vectorize/materialize_vectors_1d_to_1d.mlir | 12 +- .../Vectorize/materialize_vectors_2d_to_1d.mlir | 12 +- .../Vectorize/materialize_vectors_2d_to_2d.mlir | 12 +- mlir/test/Transforms/Vectorize/vectorize_1d.mlir | 89 +- mlir/test/Transforms/Vectorize/vectorize_2d.mlir | 26 +- mlir/test/Transforms/Vectorize/vectorize_3d.mlir | 2 +- .../Vectorize/vectorize_outer_loop_2d.mlir | 4 +- .../vectorize_outer_loop_transpose_2d.mlir | 10 +- .../Vectorize/vectorize_transpose_2d.mlir | 10 +- mlir/test/Transforms/dma-generate.mlir | 288 +++-- .../Transforms/loop-fusion-dependence-check.mlir | 98 +- .../Transforms/loop-fusion-slice-computation.mlir | 28 +- mlir/test/Transforms/loop-fusion.mlir | 1117 ++++++++------------ .../Transforms/loop-invariant-code-motion.mlir | 140 +-- mlir/test/Transforms/loop-tiling.mlir | 42 +- mlir/test/Transforms/memref-bound-check.mlir | 76 +- 
mlir/test/Transforms/memref-dataflow-opt.mlir | 82 +- mlir/test/Transforms/memref-dependence-check.mlir | 150 +-- mlir/test/Transforms/parallelism-detection.mlir | 8 +- mlir/test/Transforms/pipeline-data-transfer.mlir | 189 ++-- 42 files changed, 1588 insertions(+), 1581 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 91b0d0e7d98..b8bf3685c9b 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -284,6 +284,8 @@ public: static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); /// Returns true if this DMA operation is strided, returns false otherwise. bool isStrided() { @@ -367,6 +369,8 @@ public: static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); }; /// The "affine.for" operation represents an affine loop nest, defining an SSA @@ -649,10 +653,16 @@ public: /// Builds an affine load op with the specified map and operands. static void build(Builder *builder, OperationState *result, AffineMap map, ArrayRef operands); + /// Builds an affine load op an identify map and operands. + static void build(Builder *builder, OperationState *result, Value *memref, + ArrayRef indices = {}); + + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 0; } /// Get memref operand. - Value *getMemRef() { return getOperand(0); } - void setMemRef(Value *value) { setOperand(0, value); } + Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -680,6 +690,8 @@ public: static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); }; /// The "affine.store" op writes an element to a memref, where the index @@ -707,13 +719,20 @@ public: static void build(Builder *builder, OperationState *result, Value *valueToStore, AffineMap map, ArrayRef operands); + /// Builds an affine store operation with an identity map and operands. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef operands); /// Get value to be stored by store operation. Value *getValueToStore() { return getOperand(0); } + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 1; } + /// Get memref operand. 
- Value *getMemRef() { return getOperand(1); } - void setMemRef(Value *value) { setOperand(1, value); } + Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); @@ -742,6 +761,8 @@ public: static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); }; /// Returns true if the given Value can be used as a dimension id. diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 5c1f47a1348..b012cc1e60e 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -103,29 +103,29 @@ struct ComputationSliceState { // Backward slice example: // // affine.for %i0 = 0 to 10 { -// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' // } // affine.for %i1 = 0 to 10 { -// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' // } // // // Backward computation slice of loop nest '%i0'. // affine.for %i0 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 1)(%i1) { -// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' // } // // Forward slice example: // // affine.for %i0 = 0 to 10 { -// store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' +// affine.store %cst, %0[%i0] : memref<100xf32> // 'depSourceAccess' // } // affine.for %i1 = 0 to 10 { -// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' // } // // // Forward computation slice of loop nest '%i1'. // affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) { -// %v = load %0[%i1] : memref<100xf32> // 'depSinkAccess' +// %v = affine.load %0[%i1] : memref<100xf32> // 'depSinkAccess' // } // void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, @@ -172,7 +172,7 @@ AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, // // affine.for %i = 0 to 32 { // affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { -// load %A[%ii] +// affine.load %A[%ii] // } // } // diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 1f4e50c1178..8b9992da90e 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -122,7 +122,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. /// AffineMap makePermutationMap( - Operation *op, + Operation *op, ArrayRef indices, const llvm::DenseMap &loopToVectorDim); namespace matcher { diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 48822cdac86..a253871bc29 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -103,7 +103,7 @@ FunctionPassBase *createLoopTilingPass(uint64_t cacheSizeBytes); /// while generating DMAs to move data. 
FunctionPassBase *createDmaGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, - int minDmaTransferSize = 1024, + unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index d7650dcb127..04a34627a62 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -696,6 +696,38 @@ void AffineApplyOp::getCanonicalizationPatterns( results.push_back(llvm::make_unique(context)); } +//===----------------------------------------------------------------------===// +// Common canonicalization pattern support logic +//===----------------------------------------------------------------------===// + +namespace { +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref_cast +/// into the root operation directly. +struct MemRefCastFolder : public RewritePattern { + /// The rootOpName is the name of the root operation to match against. + MemRefCastFolder(StringRef rootOpName, MLIRContext *context) + : RewritePattern(rootOpName, 1, context) {} + + PatternMatchResult match(Operation *op) const override { + for (auto *operand : op->getOperands()) + if (matchPattern(operand, m_Op())) + return matchSuccess(); + + return matchFailure(); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + if (auto *memref = op->getOperand(i)->getDefiningOp()) + if (auto cast = dyn_cast(memref)) + op->setOperand(i, cast.getOperand()); + rewriter.updatedRootInPlace(op); + } +}; + +} // end anonymous namespace. + //===----------------------------------------------------------------------===// // AffineDmaStartOp //===----------------------------------------------------------------------===// @@ -770,19 +802,16 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, // *) src memref followed by its affine map operands (in square brackets). // *) tag memref followed by its affine map operands (in square brackets). // *) number of elements transferred by DMA operation. - if (parser->parseOperand(srcMemRefInfo) || parser->parseLSquare() || + if (parser->parseOperand(srcMemRefInfo) || parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, getSrcMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseComma() || - parser->parseOperand(dstMemRefInfo) || parser->parseLSquare() || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, getDstMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseComma() || - parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() || + parser->parseComma() || parser->parseOperand(tagMemRefInfo) || parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseComma() || - parser->parseOperand(numElementsInfo)) + parser->parseComma() || parser->parseOperand(numElementsInfo)) return failure(); // Parse optional stride and elements per stride. 
@@ -846,6 +875,13 @@ LogicalResult AffineDmaStartOp::verify() { return success(); } +void AffineDmaStartOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_start(memrefcast) -> dma_start + results.push_back( + llvm::make_unique(getOperationName(), context)); +} + //===----------------------------------------------------------------------===// // AffineDmaWaitOp //===----------------------------------------------------------------------===// @@ -884,11 +920,11 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its map operands, and dma size. - if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() || + if (parser->parseOperand(tagMemRefInfo) || parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseComma() || - parser->parseOperand(numElementsInfo) || parser->parseColonType(type) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseColonType(type) || parser->resolveOperand(tagMemRefInfo, type, result->operands) || parser->resolveOperands(tagMapOperands, indexType, result->operands) || parser->resolveOperand(numElementsInfo, indexType, result->operands)) @@ -910,6 +946,13 @@ LogicalResult AffineDmaWaitOp::verify() { return success(); } +void AffineDmaWaitOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_wait(memrefcast) -> dma_wait + results.push_back( + llvm::make_unique(getOperationName(), context)); +} + //===----------------------------------------------------------------------===// // AffineForOp //===----------------------------------------------------------------------===// @@ -1556,7 +1599,20 @@ void AffineLoadOp::build(Builder *builder, OperationState *result, AffineMap map, ArrayRef operands) { // TODO(b/133776335) Check that map operands are loop IVs or symbols. 
result->addOperands(operands); - result->addAttribute("map", builder->getAffineMapAttr(map)); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + auto memrefType = operands[0]->getType().cast(); + result->types.push_back(memrefType.getElementType()); +} + +void AffineLoadOp::build(Builder *builder, OperationState *result, + Value *memref, ArrayRef indices) { + result->addOperands(memref); + result->addOperands(indices); + auto memrefType = memref->getType().cast(); + auto map = builder->getMultiDimIdentityMap(memrefType.getRank()); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result->types.push_back(memrefType.getElementType()); } ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { @@ -1568,10 +1624,11 @@ ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { AffineMapAttr mapAttr; SmallVector mapOperands; return failure( - parser->parseOperand(memrefInfo) || parser->parseLSquare() || - parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map", + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseColonType(type) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperands(mapOperands, affineIntTy, result->operands) || parser->addTypeToList(type.getElementType(), result->types)); @@ -1579,20 +1636,27 @@ ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { void AffineLoadOp::print(OpAsmPrinter *p) { *p << "affine.load " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType("map"); - SmallVector operands(getIndices()); - p->printAffineMapOfSSAIds(mapAttr, operands); - *p << "] : " << getMemRefType(); + AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); } LogicalResult AffineLoadOp::verify() { if (getType() != getMemRefType().getElementType()) return emitOpError("result type must match element type of memref"); - AffineMap map = getAttrOfType("map").getValue(); - if (map.getNumResults() != getMemRefType().getRank()) - return emitOpError("affine.load affine map num results must equal memref " - "rank"); + auto mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + AffineMap map = getAttrOfType(getMapAttrName()).getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.load affine map num results must equal" + " memref rank"); + } for (auto *idx : getIndices()) if (!idx->getType().isIndex()) @@ -1601,6 +1665,13 @@ LogicalResult AffineLoadOp::verify() { return success(); } +void AffineLoadOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// load(memrefcast) -> load + results.push_back( + llvm::make_unique(getOperationName(), context)); +} + //===----------------------------------------------------------------------===// // AffineStoreOp //===----------------------------------------------------------------------===// @@ -1611,7 +1682,19 @@ void AffineStoreOp::build(Builder *builder, OperationState *result, // TODO(b/133776335) Check that map operands are loop IVs or symbols. 
result->addOperands(valueToStore); result->addOperands(operands); - result->addAttribute("map", builder->getAffineMapAttr(map)); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); +} + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef operands) { + result->addOperands(valueToStore); + result->addOperands(memref); + result->addOperands(operands); + auto memrefType = memref->getType().cast(); + auto map = builder->getMultiDimIdentityMap(memrefType.getRank()); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); } ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { @@ -1624,10 +1707,11 @@ ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector mapOperands; return failure( parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || parser->parseLSquare() || - parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map", + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), result->attributes) || - parser->parseRSquare() || parser->parseColonType(type) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || parser->resolveOperand(storeValueInfo, type.getElementType(), result->operands) || parser->resolveOperand(memrefInfo, type, result->operands) || @@ -1637,10 +1721,14 @@ ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { void AffineStoreOp::print(OpAsmPrinter *p) { *p << "affine.store " << *getValueToStore(); *p << ", " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType("map"); - SmallVector operands(getIndices()); - p->printAffineMapOfSSAIds(mapAttr, operands); - *p << "] : " << getMemRefType(); + AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); } LogicalResult AffineStoreOp::verify() { @@ -1648,14 +1736,23 @@ LogicalResult AffineStoreOp::verify() { if (getValueToStore()->getType() != getMemRefType().getElementType()) return emitOpError("first operand must have same type memref element type"); - AffineMap map = getAttrOfType("map").getValue(); - if (map.getNumResults() != getMemRefType().getRank()) - return emitOpError("affine.store affine map num results must equal memref " - "rank"); - + auto mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + AffineMap map = mapAttr.getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.store affine map num results must equal" + " memref rank"); + } for (auto *idx : getIndices()) if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO(b/133776335) Verify that map operands are loop IVs or symbols. 
 return success();
 }
+
+void AffineStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// store(memrefcast) -> store
+  results.push_back(
+      llvm::make_unique(getOperationName(), context));
+}
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index fc8c712e8b0..28ee5d6e58b 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -668,10 +668,12 @@ static void computeDirectionVector(
 // Populates 'accessMap' with composition of AffineApplyOps reachable from
 // indices of MemRefAccess.
 void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
-  auto memrefType = memref->getType().cast();
-  // Create identity map with same number of dimensions as 'memrefType' rank.
-  auto map = AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
-                                               memref->getType().getContext());
+  // Get affine map from AffineLoad/Store.
+  AffineMap map;
+  if (auto loadOp = dyn_cast(opInst))
+    map = loadOp.getAffineMap();
+  else if (auto storeOp = dyn_cast(opInst))
+    map = storeOp.getAffineMap();
   SmallVector operands(indices.begin(), indices.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
   map = simplifyAffineMap(map);
@@ -780,9 +782,9 @@ DependenceResult mlir::checkMemrefAccessDependence(
   if (srcAccess.memref != dstAccess.memref)
     return DependenceResult::NoDependence;
-  // Return 'NoDependence' if one of these accesses is not a StoreOp.
-  if (!allowRAR && !isa(srcAccess.opInst) &&
-      !isa(dstAccess.opInst))
+  // Return 'NoDependence' if one of these accesses is not an AffineStoreOp.
+  if (!allowRAR && !isa(srcAccess.opInst) &&
+      !isa(dstAccess.opInst))
     return DependenceResult::NoDependence;
   // Get composed access function for 'srcAccess'.
@@ -866,7 +868,7 @@ void mlir::getDependenceComponents(
   // Collect all load and store ops in loop nest rooted at 'forOp'.
   SmallVector loadAndStoreOpInsts;
   forOp.getOperation()->walk([&](Operation *opInst) {
-    if (isa(opInst) || isa(opInst))
+    if (isa(opInst) || isa(opInst))
       loadAndStoreOpInsts.push_back(opInst);
   });
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 16e092b8205..0b487bac0ef 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -232,8 +232,8 @@ mlir::getInvariantAccesses(Value *iv, llvm::ArrayRef indices) {
 template
 static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
                                int *memRefDim) {
-  static_assert(std::is_same::value ||
-                    std::is_same::value,
+  static_assert(std::is_same::value ||
+                    std::is_same::value,
                 "Must be called on either const LoadOp & or const StoreOp &");
   assert(memRefDim && "memRefDim == nullptr");
   auto memRefType = memoryOp.getMemRefType();
@@ -250,25 +250,35 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
   }
   int uniqueVaryingIndexAlongIv = -1;
-  auto indices = memoryOp.getIndices();
-  unsigned numIndices = llvm::size(indices);
-  unsigned dim = 0;
-  for (auto *index : indices) {
-    if (!isAccessInvariant(iv, index)) {
-      if (uniqueVaryingIndexAlongIv != -1) {
-        // 2+ varying indices -> do not vectorize along iv.
-        return false;
+  auto accessMap = memoryOp.getAffineMap();
+  SmallVector mapOperands(memoryOp.getIndices());
+  unsigned numDims = accessMap.getNumDims();
+  for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
+    // Gather map operands used in result expr 'i' into 'exprOperands'.
+ SmallVector exprOperands; + auto resultExpr = accessMap.getResult(i); + resultExpr.walk([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) + exprOperands.push_back(mapOperands[dimExpr.getPosition()]); + else if (auto symExpr = expr.dyn_cast()) + exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]); + }); + // Check access invariance of each operand in 'exprOperands'. + for (auto *exprOperand : exprOperands) { + if (!isAccessInvariant(iv, exprOperand)) { + if (uniqueVaryingIndexAlongIv != -1) { + // 2+ varying indices -> do not vectorize along iv. + return false; + } + uniqueVaryingIndexAlongIv = i; } - uniqueVaryingIndexAlongIv = dim; } - ++dim; } if (uniqueVaryingIndexAlongIv == -1) *memRefDim = -1; else - *memRefDim = numIndices - (uniqueVaryingIndexAlongIv + 1); - + *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1); return true; } @@ -320,8 +330,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, loadAndStores.match(forOp, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { auto *op = ls.getMatchedOperation(); - auto load = dyn_cast(op); - auto store = dyn_cast(op); + auto load = dyn_cast(op); + auto store = dyn_cast(op); // Only scalar types are considered vectorizable, all load/store must be // vectorizable for a loop to qualify as vectorizable. // TODO(ntv): ponder whether we want to be more general here. @@ -338,8 +348,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { - auto load = dyn_cast(op); - auto store = dyn_cast(op); + auto load = dyn_cast(op); + auto store = dyn_cast(op); return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 0f5edc7b25a..b043d4734fd 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" @@ -48,9 +49,9 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = dyn_cast(opInst)) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. 
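As a concrete illustration of the renaming applied throughout the tests listed above (shapes and SSA names are illustrative, mirroring the examples in the Analysis/Utils.h comments, not taken from any specific test):

  // Before:
  affine.for %i0 = 0 to 10 {
    store %cst, %0[%i0] : memref<100xf32>
    %v = load %0[%i0] : memref<100xf32>
  }

  // After:
  affine.for %i0 = 0 to 10 {
    affine.store %cst, %0[%i0] : memref<100xf32>
    %v = affine.load %0[%i0] : memref<100xf32>
  }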
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index dc6f939a59c..18be6cf3bc9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -154,7 +154,7 @@ NestedPattern For(FilterFunctionType filter, ArrayRef nested) { } bool isLoadOrStore(Operation &op) { - return isa(op) || isa(op); + return isa(op) || isa(op); } } // end namespace matcher diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 4456ac2b50b..1802b736fad 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" @@ -116,7 +117,7 @@ void TestMemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); getFunction().walk([&](Operation *op) { - if (isa(op) || isa(op)) + if (isa(op) || isa(op)) loadsAndStores.push_back(op); }); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index ae991f796e0..486c265525a 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -173,7 +173,8 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, ComputationSliceState *sliceState, bool addMemRefDimBounds) { - assert((isa(op) || isa(op)) && "load/store op expected"); + assert((isa(op) || isa(op)) && + "affine load/store op expected"); MemRefAccess access(op); memref = access.memref; @@ -381,12 +382,11 @@ Optional mlir::getMemRefSizeInBytes(MemRefType memRefType) { template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError) { - static_assert(std::is_same::value || - std::is_same::value, - "argument should be either a LoadOp or a StoreOp"); + static_assert(std::is_same::value || + std::is_same::value, + "argument should be either a AffineLoadOp or a AffineStoreOp"); Operation *opInst = loadOrStoreOp.getOperation(); - MemRefRegion region(opInst->getLoc()); if (failed(region.compute(opInst, /*loopDepth=*/0, /*sliceState=*/nullptr, /*addMemRefDimBounds=*/false))) @@ -434,9 +434,9 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, } // Explicitly instantiate the template so that the compiler knows we need them! -template LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOp loadOp, +template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp, bool emitError); -template LogicalResult mlir::boundCheckLoadOrStoreOp(StoreOp storeOp, +template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp, bool emitError); // Returns in 'positions' the Block positions of 'op' in each ancestor @@ -484,9 +484,9 @@ static Operation *getInstAtPosition(ArrayRef positions, // Returns the MemRef accessed by load or store 'op'. static Value *getLoadOrStoreMemRef(Operation *op) { - if (auto loadOp = dyn_cast(op)) + if (auto loadOp = dyn_cast(op)) return loadOp.getMemRef(); - return cast(op).getMemRef(); + return cast(op).getMemRef(); } // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. 
@@ -560,8 +560,8 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, return failure(); } - bool readReadAccesses = - isa(srcAccess.opInst) && isa(dstAccess.opInst); + bool readReadAccesses = isa(srcAccess.opInst) && + isa(dstAccess.opInst); FlatAffineConstraints dependenceConstraints; // Check dependence between 'srcAccess' and 'dstAccess'. DependenceResult result = checkMemrefAccessDependence( @@ -752,7 +752,7 @@ void mlir::getComputationSliceState( : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); llvm::SmallDenseSet sequentialLoops; - if (isa(depSourceOp) && isa(depSinkOp)) { + if (isa(depSourceOp) && isa(depSinkOp)) { // For read-read access pairs, clear any slice bounds on sequential loops. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0], @@ -849,7 +849,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { + if (auto loadOp = dyn_cast(loadOrStoreOpInst)) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -858,8 +858,8 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { indices.push_back(index); } } else { - assert(isa(loadOrStoreOpInst) && "load/store op expected"); - auto storeOp = dyn_cast(loadOrStoreOpInst); + assert(isa(loadOrStoreOpInst) && "load/store op expected"); + auto storeOp = dyn_cast(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -874,7 +874,7 @@ unsigned MemRefAccess::getRank() const { return memref->getType().cast().getRank(); } -bool MemRefAccess::isStore() const { return isa(opInst); } +bool MemRefAccess::isStore() const { return isa(opInst); } /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. @@ -914,7 +914,7 @@ static Optional getMemoryFootprintBytes(Block &block, // Walk this 'affine.for' operation to gather all memory regions. bool error = false; block.walk(start, end, [&](Operation *opInst) { - if (!isa(opInst) && !isa(opInst)) { + if (!isa(opInst) && !isa(opInst)) { // Neither load nor a store op. return; } @@ -977,7 +977,7 @@ bool mlir::isLoopParallel(AffineForOp forOp) { // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector loadAndStoreOpInsts; forOp.getOperation()->walk([&](Operation *opInst) { - if (isa(opInst) || isa(opInst)) + if (isa(opInst) || isa(opInst)) loadAndStoreOpInsts.push_back(opInst); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 0d1e2c0f416..7bb28e9893e 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -19,6 +19,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Operation.h" #include "mlir/StandardOps/Ops.h" @@ -108,7 +109,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, /// Examples can be found in the documentation of `makePermutationMap`, in the /// header file. 
static AffineMap makePermutationMap( - Operation::operand_range operands, + ArrayRef indices, const DenseMap &enclosingLoopToVectorDim) { if (enclosingLoopToVectorDim.empty()) return AffineMap(); @@ -116,7 +117,6 @@ static AffineMap makePermutationMap( enclosingLoopToVectorDim.begin()->getFirst()->getContext(); using functional::makePtrDynCaster; using functional::map; - SmallVector indices(operands); SmallVector perm(enclosingLoopToVectorDim.size(), getAffineConstantExpr(0, context)); @@ -167,7 +167,8 @@ static SetVector getEnclosingforOps(Operation *op) { } AffineMap mlir::makePermutationMap( - Operation *op, const DenseMap &loopToVectorDim) { + Operation *op, ArrayRef indices, + const DenseMap &loopToVectorDim) { DenseMap enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingforOps(op); for (auto *forInst : enclosingLoops) { @@ -176,13 +177,7 @@ AffineMap mlir::makePermutationMap( enclosingLoopToVectorDim.insert(*it); } } - - if (auto load = dyn_cast(op)) { - return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); - } - - auto store = cast(op); - return ::makePermutationMap(store.getIndices(), enclosingLoopToVectorDim); + return ::makePermutationMap(indices, enclosingLoopToVectorDim); } bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 5d8a0dd54d5..57832aafe45 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1035,7 +1035,6 @@ Attribute Parser::parseAttribute(Type type) { case Token::string: { auto val = getToken().getStringValue(); consumeToken(Token::string); - // Parse the optional trailing colon type if one wasn't explicitly provided. if (!type && consumeIf(Token::colon) && !(type = parseType())) return Attribute(); @@ -2326,6 +2325,9 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { + if (!consumeIf(Token::l_square)) + return failure(); + SmallVector exprs; auto parseElt = [&]() -> ParseResult { auto elt = parseAffineExpr(); @@ -2336,11 +2338,15 @@ ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { // Parse a multi-dimensional affine expression (a comma-separated list of // 1-d affine expressions); the list cannot be empty. Grammar: // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) - if (parseCommaSeparatedList(parseElt)) + if (parseCommaSeparatedListUntil(Token::r_square, parseElt, + /*allowEmptyList=*/true)) return failure(); // Parsed a valid affine map. - map = builder.getAffineMap(numDimOperands, - dimsAndSymbols.size() - numDimOperands, exprs); + if (exprs.empty()) + map = AffineMap(); + else + map = builder.getAffineMap(numDimOperands, + dimsAndSymbols.size() - numDimOperands, exprs); return success(); } @@ -3452,8 +3458,10 @@ public: if (parser.parseAffineMapOfSSAIds(map, parseElement)) return failure(); // Add AffineMap attribute. - mapAttr = parser.builder.getAffineMapAttr(map); - attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); + if (map) { + mapAttr = parser.builder.getAffineMapAttr(map); + attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); + } // Add dim operands before symbol operands in 'operands'. 
operands.assign(dimOperands.begin(), dimOperands.end()); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index a3aa092b0ec..e867dc70ed3 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -75,16 +75,17 @@ namespace { struct DmaGeneration : public FunctionPass { explicit DmaGeneration( unsigned slowMemorySpace = 0, - unsigned fastMemorySpace = clFastMemorySpace, + unsigned fastMemorySpace = clFastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()) : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace), - minDmaTransferSize(minDmaTransferSize), + tagMemorySpace(tagMemorySpace), minDmaTransferSize(minDmaTransferSize), fastMemCapacityBytes(fastMemCapacityBytes) {} explicit DmaGeneration(const DmaGeneration &other) : slowMemorySpace(other.slowMemorySpace), fastMemorySpace(other.fastMemorySpace), + tagMemorySpace(other.tagMemorySpace), minDmaTransferSize(other.minDmaTransferSize), fastMemCapacityBytes(other.fastMemCapacityBytes) {} @@ -111,6 +112,8 @@ struct DmaGeneration : public FunctionPass { const unsigned slowMemorySpace; // Fast memory space associated with DMAs. unsigned fastMemorySpace; + // Tag memory space associated with DMAs. + unsigned tagMemorySpace; // Minimum DMA transfer size supported by the target in bytes. const int minDmaTransferSize; // Capacity of the faster memory space. @@ -128,10 +131,11 @@ struct DmaGeneration : public FunctionPass { /// TODO(bondhugula): extend this to store op's. FunctionPassBase *mlir::createDmaGenerationPass(unsigned slowMemorySpace, unsigned fastMemorySpace, + unsigned tagMemorySpace, int minDmaTransferSize, uint64_t fastMemCapacityBytes) { - return new DmaGeneration(slowMemorySpace, fastMemorySpace, minDmaTransferSize, - fastMemCapacityBytes); + return new DmaGeneration(slowMemorySpace, fastMemorySpace, tagMemorySpace, + minDmaTransferSize, fastMemCapacityBytes); } // Info comprising stride and number of elements transferred every stride. @@ -173,11 +177,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; - if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = dyn_cast(opInst)) { rank = loadOp.getMemRefType().getRank(); region->memref = loadOp.getMemRef(); region->setWrite(false); - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = dyn_cast(opInst)) { rank = storeOp.getMemRefType().getRank(); region->memref = storeOp.getMemRef(); region->setWrite(true); @@ -363,7 +367,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, *sizeInBytes = 0; } // Create a tag (single element 1-d memref) for the DMA. - auto tagMemRefType = top.getMemRefType({1}, top.getIntegerType(32)); + auto tagMemRefType = + top.getMemRefType({1}, top.getIntegerType(32), {}, tagMemorySpace); auto tagMemRef = prologue.create(loc, tagMemRefType); auto numElementsSSA = @@ -393,23 +398,34 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, // don't get replaced. auto postDomFilter = std::prev(end); + // Create fully composed affine maps for each memref. 
+ auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size()); + fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices); + auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size()); + fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices); + SmallVector tagIndices({zeroIndex}); + auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); + fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { // DMA non-blocking read from original buffer to fast buffer. - b.create(loc, memref, memIndices, fastMemRef, bufIndices, - numElementsSSA, tagMemRef, zeroIndex, stride, - numEltPerStride); + b.create(loc, memref, memAffineMap, memIndices, + fastMemRef, bufAffineMap, bufIndices, tagMemRef, + tagAffineMap, tagIndices, numElementsSSA, stride, + numEltPerStride); } else { // DMA non-blocking write from fast buffer to the original memref. - auto op = b.create(loc, fastMemRef, bufIndices, memref, - memIndices, numElementsSSA, tagMemRef, - zeroIndex, stride, numEltPerStride); + auto op = b.create( + loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, + memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, stride, + numEltPerStride); // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. *nEnd = Block::iterator(op.getOperation()); } // Matching DMA wait to block on completion; tag always has a 0 index. - b.create(loc, tagMemRef, zeroIndex, numElementsSSA); + b.create(loc, tagMemRef, tagAffineMap, zeroIndex, + numElementsSSA); // Generate dealloc for the tag. auto tagDeallocOp = epilogue.create(loc, tagMemRef); @@ -479,7 +495,8 @@ bool DmaGeneration::runOnBlock(Block *block) { // Get to the first load, store, or for op. auto curBegin = std::find_if(block->begin(), block->end(), [&](Operation &op) { - return isa(op) || isa(op) || isa(op); + return isa(op) || isa(op) || + isa(op); }); for (auto it = curBegin; it != block->end(); ++it) { @@ -522,7 +539,7 @@ bool DmaGeneration::runOnBlock(Block *block) { runOnBlock(/*begin=*/it, /*end=*/std::next(it)); curBegin = std::next(it); } - } else if (!isa(&*it) && !isa(&*it)) { + } else if (!isa(&*it) && !isa(&*it)) { runOnBlock(/*begin=*/curBegin, /*end=*/it); curBegin = std::next(it); } @@ -607,10 +624,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. 
- if (auto loadOp = dyn_cast(opInst)) { + if (auto loadOp = dyn_cast(opInst)) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = dyn_cast(opInst)) { + } else if (auto storeOp = dyn_cast(opInst)) { if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 77b944f3e01..1eee40b88da 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -133,9 +133,9 @@ struct LoopNestStateCollector { forOps.push_back(cast(op)); else if (op->getNumRegions() != 0) hasNonForRegion = true; - else if (isa(op)) + else if (isa(op)) loadOpInsts.push_back(op); - else if (isa(op)) + else if (isa(op)) storeOpInsts.push_back(op); }); } @@ -143,8 +143,8 @@ struct LoopNestStateCollector { // TODO(b/117228571) Replace when this is modeled through side-effects/op traits static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || isa(op) || - isa(op)) + if (isa(op) || isa(op) || + isa(op) || isa(op)) return true; return false; } @@ -174,7 +174,7 @@ public: unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { - if (memref == cast(loadOpInst).getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) ++loadOpCount; } return loadOpCount; @@ -184,7 +184,7 @@ public: unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { - if (memref == cast(storeOpInst).getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) ++storeOpCount; } return storeOpCount; @@ -194,7 +194,7 @@ public: void getStoreOpsForMemref(Value *memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { - if (memref == cast(storeOpInst).getMemRef()) + if (memref == cast(storeOpInst).getMemRef()) storeOps->push_back(storeOpInst); } } @@ -203,7 +203,7 @@ public: void getLoadOpsForMemref(Value *memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { - if (memref == cast(loadOpInst).getMemRef()) + if (memref == cast(loadOpInst).getMemRef()) loadOps->push_back(loadOpInst); } } @@ -213,10 +213,10 @@ public: void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { - loadMemrefs.insert(cast(loadOpInst).getMemRef()); + loadMemrefs.insert(cast(loadOpInst).getMemRef()); } for (auto *storeOpInst : stores) { - auto *memref = cast(storeOpInst).getMemRef(); + auto *memref = cast(storeOpInst).getMemRef(); if (loadMemrefs.count(memref) > 0) loadAndStoreMemrefSet->insert(memref); } @@ -308,7 +308,7 @@ public: bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { - auto *memref = cast(storeOpInst).getMemRef(); + auto *memref = cast(storeOpInst).getMemRef(); auto *op = memref->getDefiningOp(); // Return true if 'memref' is a block argument. if (!op) @@ -333,7 +333,7 @@ public: Node *node = getNode(id); for (auto *storeOpInst : node->stores) { // Return false if there exist out edges from 'id' on 'memref'. 
- if (getOutEdgeCount(id, cast(storeOpInst).getMemRef()) > 0) + if (getOutEdgeCount(id, cast(storeOpInst).getMemRef()) > 0) return false; } return true; @@ -658,28 +658,28 @@ bool MemRefDependenceGraph::init(Function f) { Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); - auto *memref = cast(opInst).getMemRef(); + auto *memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); - auto *memref = cast(opInst).getMemRef(); + auto *memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = dyn_cast(op)) { + } else if (auto loadOp = dyn_cast(op)) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); - auto *memref = cast(op).getMemRef(); + auto *memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. Node node(nextNodeId++, &op); node.stores.push_back(&op); - auto *memref = cast(op).getMemRef(); + auto *memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (op.getNumRegions() != 0) { @@ -740,7 +740,7 @@ static void moveLoadsAccessingMemrefTo(Value *memref, dstLoads->clear(); SmallVector srcLoadsToKeep; for (auto *load : *srcLoads) { - if (cast(load).getMemRef() == memref) + if (cast(load).getMemRef() == memref) dstLoads->push_back(load); else srcLoadsToKeep.push_back(load); @@ -861,7 +861,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Builder to create constants at the top level. OpBuilder top(forInst->getFunction().getBody()); // Create new memref type based on slice bounds. - auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); + auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); unsigned rank = oldMemRefType.getRank(); @@ -976,7 +976,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // Gather all memrefs from 'srcNode' store ops. DenseSet storeMemrefs; for (auto *storeOpInst : srcNode->stores) { - storeMemrefs.insert(cast(storeOpInst).getMemRef()); + storeMemrefs.insert(cast(storeOpInst).getMemRef()); } // Return false if any of the following are true: // *) 'srcNode' writes to a live in/out memref other than 'memref'. @@ -1461,7 +1461,7 @@ public: DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. - auto *memref = cast(loads.back()).getMemRef(); + auto *memref = cast(loads.back()).getMemRef(); if (visitedMemrefs.count(memref) > 0) continue; visitedMemrefs.insert(memref); @@ -1517,7 +1517,7 @@ public: // Gather 'dstNode' store ops to 'memref'. SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) - if (cast(storeOpInst).getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); unsigned bestDstLoopDepth; @@ -1562,7 +1562,7 @@ public: // Create private memref for 'memref' in 'dstAffineForOp'. 
SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (cast(storeOpInst).getMemRef() == memref) + if (cast(storeOpInst).getMemRef() == memref) storesForMemref.push_back(storeOpInst); } assert(storesForMemref.size() == 1); @@ -1584,7 +1584,7 @@ public: // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - auto *loadMemRef = cast(loadOpInst).getMemRef(); + auto *loadMemRef = cast(loadOpInst).getMemRef(); if (visitedMemrefs.count(loadMemRef) == 0) loads.push_back(loadOpInst); } @@ -1742,7 +1742,7 @@ public: // Check that all stores are to the same memref. DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(cast(storeOpInst).getMemRef()); + storeMemrefs.insert(cast(storeOpInst).getMemRef()); } if (storeMemrefs.size() != 1) return false; @@ -1753,7 +1753,7 @@ public: auto fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { for (auto *user : fn.getArgument(i)->getUsers()) { - if (auto loadOp = dyn_cast(user)) { + if (auto loadOp = dyn_cast(user)) { // Gather loops surrounding 'use'. SmallVector loops; getLoopIVs(*user, &loops); diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index c4c1184fa82..48e97f44436 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -70,7 +70,7 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar, static bool isMemRefDereferencingOp(Operation &op) { // TODO(asabne): Support DMA Ops. - if (isa(op) || isa(op)) { + if (isa(op) || isa(op)) { return true; } return false; @@ -94,23 +94,25 @@ bool isOpLoopInvariant(Operation &op, Value *indVar, // If the body of a predicated region has a for loop, we don't hoist the // 'affine.if'. return false; - } else if (isa(op) || isa(op)) { + } else if (isa(op) || isa(op)) { // TODO(asabne): Support DMA ops. return false; } else if (!isa(op)) { if (isMemRefDereferencingOp(op)) { - Value *memref = isa(op) ? cast(op).getMemRef() - : cast(op).getMemRef(); + Value *memref = isa(op) + ? cast(op).getMemRef() + : cast(op).getMemRef(); for (auto *user : memref->getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. - if (isa(op) || isa(op)) { + if (isa(op) || isa(op)) { return false; } // If the memref used by the load/store is used in a store elsewhere in // the loop nest, we do not hoist. Similarly, if the memref used in a // load is also being stored too, we do not hoist the load. - if (isa(user) || (isa(user) && isa(op))) { + if (isa(user) || + (isa(user) && isa(op))) { if (&op != user) { SmallVector userIVs; getLoopIVs(*user, &userIVs); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 1208e2fdd15..13a53e3a944 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -22,6 +22,7 @@ // SSA scalars live out of 'affine.for'/'affine.if' statements is available. 
//===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" @@ -72,7 +73,7 @@ namespace { struct MemRefDataFlowOpt : public FunctionPass { void runOnFunction() override; - void forwardStoreToLoad(LoadOp loadOp); + void forwardStoreToLoad(AffineLoadOp loadOp); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; @@ -93,7 +94,7 @@ FunctionPassBase *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { +void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { Operation *lastWriteStoreOp = nullptr; Operation *loadOpInst = loadOp.getOperation(); @@ -103,7 +104,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { SmallVector storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (auto *user : loadOp.getMemRef()->getUsers()) { - auto storeOp = dyn_cast(user); + auto storeOp = dyn_cast(user); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); @@ -202,7 +203,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { return; // Perform the actual store to load forwarding. - Value *storeVal = cast(lastWriteStoreOp).getValueToStore(); + Value *storeVal = cast(lastWriteStoreOp).getValueToStore(); loadOp.getResult()->replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); @@ -225,7 +226,8 @@ void MemRefDataFlowOpt::runOnFunction() { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f.walk([&](LoadOp loadOp) { forwardStoreToLoad(loadOp); }); + f.walk( + [&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { @@ -243,7 +245,7 @@ void MemRefDataFlowOpt::runOnFunction() { // could still erase it if the call had no side-effects. continue; if (llvm::any_of(memref->getUsers(), [&](Operation *ownerInst) { - return (!isa(ownerInst) && !isa(ownerInst)); + return (!isa(ownerInst) && !isa(ownerInst)); })) continue; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d0e0d18d586..af456c31408 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -57,10 +57,9 @@ FunctionPassBase *mlir::createPipelineDataTransferPass() { // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) static unsigned getTagMemRefPos(Operation &dmaInst) { - assert(isa(dmaInst) || isa(dmaInst)); - if (isa(dmaInst)) { - // Second to last operand. - return dmaInst.getNumOperands() - 2; + assert(isa(dmaInst) || isa(dmaInst)); + if (auto dmaStartOp = dyn_cast(dmaInst)) { + return dmaStartOp.getTagMemRefOperandIndex(); } // First operand for a dma finish operation. return 0; @@ -151,7 +150,7 @@ void PipelineDataTransfer::runOnFunction() { } // Check if tags of the dma start op and dma wait op match. 
-static bool checkTagMatch(DmaStartOp startOp, DmaWaitOp waitOp) { +static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) { if (startOp.getTagMemRef() != waitOp.getTagMemRef()) return false; auto startIndices = startOp.getTagIndices(); @@ -179,9 +178,9 @@ static void findMatchingStartFinishInsts( SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA operations - needed to check for dependences below. - SmallVector outgoingDmaOps; + SmallVector outgoingDmaOps; for (auto &op : *forOp.getBody()) { - auto dmaStartOp = dyn_cast(op); + auto dmaStartOp = dyn_cast(op); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -189,11 +188,11 @@ static void findMatchingStartFinishInsts( SmallVector dmaStartInsts, dmaFinishInsts; for (auto &op : *forOp.getBody()) { // Collect DMA finish operations. - if (isa(op)) { + if (isa(op)) { dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = dyn_cast(op); + auto dmaStartOp = dyn_cast(op); if (!dmaStartOp) continue; @@ -234,8 +233,8 @@ static void findMatchingStartFinishInsts( // For each start operation, we look for a matching finish operation. for (auto *dmaStartInst : dmaStartInsts) { for (auto *dmaFinishInst : dmaFinishInsts) { - if (checkTagMatch(cast(dmaStartInst), - cast(dmaFinishInst))) { + if (checkTagMatch(cast(dmaStartInst), + cast(dmaFinishInst))) { startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); break; } @@ -273,7 +272,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( - cast(dmaStartInst).getFasterMemPos()); + cast(dmaStartInst).getFasterMemPos()); if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. @@ -324,7 +323,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { DenseMap instShiftMap; for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; - assert(isa(dmaStartInst)); + assert(isa(dmaStartInst)); instShiftMap[dmaStartInst] = 0; // Set shifts for DMA start op's affine operand computation slices to 0. SmallVector sliceOps; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 93503d11e0a..a87883d12da 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -46,10 +46,10 @@ using namespace mlir; static void getLoadAndStoreMemRefAccesses(Operation *opA, DenseMap &values) { opA->walk([&](Operation *op) { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) values[loadOp.getMemRef()] = false; - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { values[storeOp.getMemRef()] = true; } }); @@ -60,10 +60,10 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA, // Returns false otherwise. 
static bool isDependentLoadOrStoreOp(Operation *op, DenseMap &values) { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()] == true; - } else if (auto storeOp = dyn_cast(op)) { + } else if (auto storeOp = dyn_cast(op)) { return values.count(storeOp.getMemRef()) > 0; } return false; @@ -115,7 +115,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { opX->walk([&](Operation *op) { if (lastDepOp) return; - if (isa(op) || isa(op)) { + if (isa(op) || isa(op)) { if (isDependentLoadOrStoreOp(op, values)) lastDepOp = opX; return; @@ -185,7 +185,7 @@ gatherLoadsAndStores(AffineForOp forOp, SmallVectorImpl &loadAndStoreOps) { bool hasIfOp = false; forOp.getOperation()->walk([&](Operation *op) { - if (isa(op) || isa(op)) + if (isa(op) || isa(op)) loadAndStoreOps.push_back(op); else if (isa(op)) hasIfOp = true; @@ -442,7 +442,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, unsigned storeCount = 0; llvm::SmallDenseSet storeMemrefs; srcForOp.getOperation()->walk([&](Operation *op) { - if (auto storeOp = dyn_cast(op)) { + if (auto storeOp = dyn_cast(op)) { storeMemrefs.insert(storeOp.getMemRef()); ++storeCount; } @@ -454,7 +454,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, // 'insertPointParent'. for (auto *value : storeMemrefs) { for (auto *user : value->getUsers()) { - if (auto loadOp = dyn_cast(user)) { + if (auto loadOp = dyn_cast(user)) { SmallVector loops; // Check if any loop in loop nest surrounding 'user' is // 'insertPointParent'. diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 876f44b8604..16f4effca15 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -38,12 +38,24 @@ using namespace mlir; // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || isa(op) || - isa(op)) + if (isa(op) || isa(op) || + isa(op) || isa(op)) return true; return false; } +/// Return the AffineMapAttr associated with memory 'op' on 'memref'. +static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { + if (auto loadOp = dyn_cast(op)) + return loadOp.getAffineMapAttrForMemRef(memref); + else if (auto storeOp = dyn_cast(op)) + return storeOp.getAffineMapAttrForMemRef(memref); + else if (auto dmaStart = dyn_cast(op)) + return dmaStart.getAffineMapAttrForMemRef(memref); + assert(isa(op)); + return cast(op).getAffineMapAttrForMemRef(memref); +} + bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, @@ -111,24 +123,32 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, assert(i < opInst->getNumOperands() && "operand guaranteed to be found"); return i; }; - unsigned memRefOperandPos = getMemRefOperandPos(); - - // Construct the new operation using this memref. - OperationState state(opInst->getLoc(), opInst->getName()); - state.setOperandListToResizable(opInst->hasResizableOperandsList()); - state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); - // Insert the non-memref operands. 
- state.operands.append(opInst->operand_begin(), - opInst->operand_begin() + memRefOperandPos); - state.operands.push_back(newMemRef); OpBuilder builder(opInst); - for (auto *extraIndex : extraIndices) { - assert(extraIndex->getDefiningOp()->getNumResults() == 1 && - "single result op's expected to generate these indices"); - assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && - "invalid memory op index"); - state.operands.push_back(extraIndex); + unsigned memRefOperandPos = getMemRefOperandPos(); + NamedAttribute oldMapAttrPair = + getAffineMapAttrForMemRef(opInst, oldMemRef); + AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); + unsigned oldMapNumInputs = oldMap.getNumInputs(); + SmallVector oldMapOperands( + opInst->operand_begin() + memRefOperandPos + 1, + opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); + SmallVector affineApplyOps; + + // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. + SmallVector oldMemRefOperands; + oldMemRefOperands.reserve(oldMemRefRank); + if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { + for (auto resultExpr : oldMap.getResults()) { + auto singleResMap = builder.getAffineMap( + oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); + auto afOp = builder.create(opInst->getLoc(), + singleResMap, oldMapOperands); + oldMemRefOperands.push_back(afOp); + affineApplyOps.push_back(afOp); + } + } else { + oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); } // Construct new indices as a remap of the old ones if a remapping has been @@ -137,28 +157,70 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, SmallVector remapOperands; remapOperands.reserve(extraOperands.size() + oldMemRefRank); remapOperands.append(extraOperands.begin(), extraOperands.end()); - remapOperands.append(opInst->operand_begin() + memRefOperandPos + 1, - opInst->operand_begin() + memRefOperandPos + 1 + - oldMemRefRank); + remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + + SmallVector remapOutputs; + remapOutputs.reserve(oldMemRefRank); + if (indexRemap && indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { - // Remapped indices. for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = builder.getAffineMap( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); auto afOp = builder.create(opInst->getLoc(), singleResMap, remapOperands); - state.operands.push_back(afOp); + remapOutputs.push_back(afOp); + affineApplyOps.push_back(afOp); } } else { // No remapping specified. - state.operands.append(remapOperands.begin(), remapOperands.end()); + remapOutputs.append(remapOperands.begin(), remapOperands.end()); + } + + SmallVector newMapOperands; + newMapOperands.reserve(newMemRefRank); + + // Prepend 'extraIndices' in 'newMapOperands'. + for (auto *extraIndex : extraIndices) { + assert(extraIndex->getDefiningOp()->getNumResults() == 1 && + "single result op's expected to generate these indices"); + assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && + "invalid memory op index"); + newMapOperands.push_back(extraIndex); } + // Append 'remapOutputs' to 'newMapOperands'. + newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); + + // Create new fully composed AffineMap for new op to be created. + assert(newMapOperands.size() == newMemRefRank); + auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); + // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here. 
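+  // Illustrative sketch (hypothetical maps, not taken from any caller): with
+  // 'oldMap' = (d0, d1) -> (d0 + d1) applied to operands (%i, %j) and
+  // 'indexRemap' = (d0) -> (d0 floordiv 2), 'newMapOperands' ends up holding
+  // the single value (%i + %j) floordiv 2, so once the identity 'newMap' is
+  // fully composed below, the rewritten op accesses
+  // newMemRef[(%i + %j) floordiv 2].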
+ fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); + newMap = simplifyAffineMap(newMap); + canonicalizeMapAndOperands(&newMap, &newMapOperands); + // Remove any affine.apply's that became dead as a result of composition. + for (auto *value : affineApplyOps) + if (value->use_empty()) + value->getDefiningOp()->erase(); + + // Construct the new operation using this memref. + OperationState state(opInst->getLoc(), opInst->getName()); + state.setOperandListToResizable(opInst->hasResizableOperandsList()); + state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); + // Insert the non-memref operands. + state.operands.append(opInst->operand_begin(), + opInst->operand_begin() + memRefOperandPos); + // Insert the new memref value. + state.operands.push_back(newMemRef); + + // Insert the new memref map operands. + state.operands.append(newMapOperands.begin(), newMapOperands.end()); + // Insert the remaining operands unmodified. state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 + - oldMemRefRank, + oldMapNumInputs, opInst->operand_end()); // Result types don't change. Both memref's are of the same elemental type. @@ -166,9 +228,15 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, for (auto *result : opInst->getResults()) state.types.push_back(result->getType()); - // Attributes also do not change. - state.attributes.append(opInst->getAttrs().begin(), - opInst->getAttrs().end()); + // Add attribute for 'newMap', other Attributes do not change. + auto newMapAttr = builder.getAffineMapAttr(newMap); + for (auto namedAttr : opInst->getAttrs()) { + if (namedAttr.first == oldMapAttrPair.first) { + state.attributes.push_back({namedAttr.first, newMapAttr}); + } else { + state.attributes.push_back(namedAttr); + } + } // Create the new operation. auto *repOp = builder.createOperation(state); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 3fca26bef19..4aff2ac4d13 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -44,7 +44,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -717,7 +716,8 @@ struct VectorizationState { // do not necessarily belong to use-def chains starting from loads (e.g // storing a constant), we need to handle them in a post-pass. DenseSet terminals; - // Checks that the type of `op` is StoreOp and adds it to the terminals set. + // Checks that the type of `op` is AffineStoreOp and adds it to the terminals + // set. void registerTerminal(Operation *op); private: @@ -739,14 +739,14 @@ void VectorizationState::registerReplacement(Operation *key, Operation *value) { vectorizedSet.insert(value); vectorizationMap.insert(std::make_pair(key, value)); registerReplacement(key->getResult(0), value->getResult(0)); - if (isa(key)) { + if (isa(key)) { assert(roots.count(key) == 0 && "root was already inserted previously"); roots.insert(key); } } void VectorizationState::registerTerminal(Operation *op) { - assert(isa(op) && "terminal must be a StoreOp"); + assert(isa(op) && "terminal must be a AffineStoreOp"); assert(terminals.count(op) == 0 && "terminal was already inserted previously"); terminals.insert(op); @@ -766,16 +766,31 @@ void VectorizationState::registerReplacement(Value *key, Value *value) { replacementMap.insert(std::make_pair(key, value)); } +// Apply 'map' with 'mapOperands' returning resulting values in 'results'. 
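+// Each result expression of 'map' is materialized as its own single-result
+// AffineApplyOp, so callers get back plain index Values they can hand to
+// makePermutationMap and the vector.transfer builders. The call sites below
+// only invoke this when the access map is not the multi-dim identity;
+// identity accesses keep their original loop-IV indices.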
+static void computeMemoryOpIndices(Operation *op, AffineMap map, + ArrayRef mapOperands, + SmallVectorImpl &results) { + OpBuilder builder(op); + for (auto resultExpr : map.getResults()) { + auto singleResMap = + builder.getAffineMap(map.getNumDims(), map.getNumSymbols(), resultExpr); + auto afOp = + builder.create(op->getLoc(), singleResMap, mapOperands); + results.push_back(afOp); + } +} + ////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. //// /// Handles the vectorization of load and store MLIR operations. /// -/// LoadOp operations are the roots of the vectorizeNonTerminals call. They are -/// vectorized immediately. The resulting vector.transfer_read is immediately -/// registered to replace all uses of the LoadOp in this pattern's scope. +/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call. +/// They are vectorized immediately. The resulting vector.transfer_read is +/// immediately registered to replace all uses of the AffineLoadOp in this +/// pattern's scope. /// -/// StoreOp are the terminals of the vectorizeNonTerminals call. They need to be -/// vectorized late once all the use-def chains have been traversed. +/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They +/// need to be vectorized late once all the use-def chains have been traversed. /// Additionally, they may have ssa-values operands which come from outside the /// scope of the current pattern. /// Such special cases force us to delay the vectorization of the stores until @@ -798,17 +813,26 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector.transfer operations // as needed by various targets. - if (isa(opInst)) { + if (auto load = dyn_cast(opInst)) { + OpBuilder b(opInst); + SmallVector mapOperands(load.getIndices()); + SmallVector indices; + indices.reserve(load.getMemRefType().getRank()); + if (load.getAffineMap() != + b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { + computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices); + } else { + indices.append(load.getIndices().begin(), load.getIndices().end()); + } auto permutationMap = - makePermutationMap(opInst, state->strategy->loopToVectorDim); + makePermutationMap(opInst, indices, state->strategy->loopToVectorDim); if (!permutationMap) return LogicalResult::Failure; LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - OpBuilder b(opInst); auto transfer = b.create( opInst->getLoc(), vectorType, memoryOp.getMemRef(), - map(makePtrDynCaster(), memoryOp.getIndices()), permutationMap); + map(makePtrDynCaster(), indices), permutationMap); state->registerReplacement(opInst, transfer.getOperation()); } else { state->registerTerminal(opInst); @@ -837,8 +861,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedOperation(); - auto load = dyn_cast(opInst); - auto store = dyn_cast(opInst); + auto load = dyn_cast(opInst); + auto store = dyn_cast(opInst); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = load ? 
vectorizeRootOrTerminal(loop.getInductionVar(), load, state) @@ -1002,21 +1026,32 @@ static Value *vectorizeOperand(Value *operand, Operation *op, static Operation *vectorizeOneOperation(Operation *opInst, VectorizationState *state) { // Sanity checks. - assert(!isa(opInst) && + assert(!isa(opInst) && "all loads must have already been fully vectorized independently"); assert(!isa(opInst) && "vector.transfer_read cannot be further vectorized"); assert(!isa(opInst) && "vector.transfer_write cannot be further vectorized"); - if (auto store = dyn_cast(opInst)) { + if (auto store = dyn_cast(opInst)) { + OpBuilder b(opInst); auto *memRef = store.getMemRef(); auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); - auto indices = map(makePtrDynCaster(), store.getIndices()); - OpBuilder b(opInst); + + SmallVector mapOperands(store.getIndices()); + SmallVector indices; + indices.reserve(store.getMemRefType().getRank()); + if (store.getAffineMap() != + b.getMultiDimIdentityMap(store.getMemRefType().getRank())) { + computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands, + indices); + } else { + indices.append(store.getIndices().begin(), store.getIndices().end()); + } + auto permutationMap = - makePermutationMap(opInst, state->strategy->loopToVectorDim); + makePermutationMap(opInst, indices, state->strategy->loopToVectorDim); if (!permutationMap) return nullptr; LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); @@ -1025,7 +1060,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, opInst->getLoc(), vectorValue, memRef, indices, permutationMap); auto *res = transfer.getOperation(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); - // "Terminals" (i.e. StoreOps) are erased on the spot. + // "Terminals" (i.e. AffineStoreOps) are erased on the spot. opInst->erase(); return res; } @@ -1156,9 +1191,9 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, // From now on, any error triggers the scope guard above. ////////////////////////////////////////////////////////////////////////////// // 1. Vectorize all the loops matched by the pattern, recursively. - // This also vectorizes the roots (LoadOp) as well as registers the terminals - // (StoreOp) for post-processing vectorization (we need to wait for all - // use-def chains into them to be vectorized first). + // This also vectorizes the roots (AffineLoadOp) as well as registers the + // terminals (AffineStoreOp) for post-processing vectorization (we need to + // wait for all use-def chains into them to be vectorized first). if (failed(vectorizeLoopsAndLoadsRecursively(m, &state))) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop"); return guard.failure(); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 788857c1a10..2f01085165c 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -545,6 +545,8 @@ TEST_FUNC(tile_2d) { } // Inject an EDSC-constructed computation to exercise 2-d vectorization. +// TODO(ntv,andydavis) Convert EDSC to use AffineLoad/Store. 
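+// The vectorize_2d test below is disabled for now (body commented out and its
+// CHECK directives neutered to xCHECK): the vectorizer now matches
+// affine.load/affine.store as roots and terminals, while the EDSC builders
+// used here presumably still emit std load/store, so the pattern no longer
+// fires. It can be re-enabled once the TODO above is addressed.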
+/* TEST_FUNC(vectorize_2d) { using namespace edsc; using namespace edsc::intrinsics; @@ -572,17 +574,23 @@ TEST_FUNC(vectorize_2d) { }); ret(); - // CHECK-LABEL: func @vectorize_2d - // CHECK-NEXT: %[[M:.*]] = dim %arg0, 0 : memref - // CHECK-NEXT: %[[N:.*]] = dim %arg0, 1 : memref - // CHECK-NEXT: %[[P:.*]] = dim %arg0, 2 : memref - // CHECK-NEXT: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) { - // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 4 { - // CHECK-NEXT: affine.for %i2 = 0 to (d0) -> (d0)(%[[P]]) step 4 { - // CHECK-NEXT: %[[vA:.*]] = "vector.transfer_read"(%arg1, %i0, %i1, %i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (memref, index, index, index) -> vector<4x4xf32> - // CHECK-NEXT: %[[vB:.*]] = "vector.transfer_read"(%arg0, %i0, %i1, %i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (memref, index, index, index) -> vector<4x4xf32> - // CHECK-NEXT: %[[vRES:.*]] = addf %[[vB]], %[[vA]] : vector<4x4xf32> - // CHECK-NEXT: "vector.transfer_write"(%[[vRES:.*]], %arg2, %i0, %i1, %i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (vector<4x4xf32>, memref, index, index, index) -> () + // xCHECK-LABEL: func @vectorize_2d + // xCHECK-NEXT: %[[M:.*]] = dim %arg0, 0 : memref + // xCHECK-NEXT: %[[N:.*]] = dim %arg0, 1 : memref + // xCHECK-NEXT: %[[P:.*]] = dim %arg0, 2 : memref + // xCHECK-NEXT: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) { + // xCHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 4 { + // xCHECK-NEXT: affine.for %i2 = 0 to (d0) -> (d0)(%[[P]]) step 4 { + // xCHECK-NEXT: %[[vA:.*]] = "vector.transfer_read"(%arg1, %i0, %i1, +%i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (memref, index, +index, index) -> vector<4x4xf32> + // xCHECK-NEXT: %[[vB:.*]] = "vector.transfer_read"(%arg0, %i0, %i1, +%i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (memref, index, +index, index) -> vector<4x4xf32> + // xCHECK-NEXT: %[[vRES:.*]] = addf %[[vB]], %[[vA]] : vector<4x4xf32> + // xCHECK-NEXT: "vector.transfer_write"(%[[vRES:.*]], %arg2, %i0, %i1, +%i2) {permutation_map = (d0, d1, d2) -> (d1, d2)} : (vector<4x4xf32>, +memref, index, index, index) -> () // clang-format on mlir::PassManager pm; @@ -594,7 +602,7 @@ TEST_FUNC(vectorize_2d) { f.print(llvm::outs()); f.erase(); } - +*/ int main() { RUN_TESTS(); return 0; diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir index 6d365eda414..88a62eab0b9 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_1d_to_1d.mlir @@ -37,7 +37,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { // non-scoped %f1 - store %f1, %A[%i0, %i1] : memref + affine.store %f1, %A[%i0, %i1] : memref } } // 4x unroll (jammed by construction). @@ -63,7 +63,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i2 = 0 to %M { affine.for %i3 = 0 to %N { // non-scoped %f2 - store %f2, %B[%i2, %i3] : memref + affine.store %f2, %B[%i2, %i3] : memref } } // 4x unroll (jammed by construction). 
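// Note on the materialize_vectors tests in this change: only the access ops
// move to their affine-dialect forms (load -> affine.load, store ->
// affine.store). The subscripts are plain loop IVs, i.e. identity access maps,
// so the vectorization/unrolling CHECK lines in these hunks stay unchanged.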
@@ -112,14 +112,14 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // affine.for %i4 = 0 to %M { affine.for %i5 = 0 to %N { - %a5 = load %A[%i4, %i5] : memref - %b5 = load %B[%i4, %i5] : memref + %a5 = affine.load %A[%i4, %i5] : memref + %b5 = affine.load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 - store %s5, %C[%i4, %i5] : memref + affine.store %s5, %C[%i4, %i5] : memref } } %c7 = constant 7 : index %c42 = constant 42 : index - %res = load %C[%c7, %c42] : memref + %res = affine.load %C[%c7, %c42] : memref return %res : f32 } diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir index 28059f39e3a..93e42ecfbc7 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_1d.mlir @@ -44,7 +44,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { // non-scoped %f1 - store %f1, %A[%i0, %i1] : memref + affine.store %f1, %A[%i0, %i1] : memref } } // (3x2)x unroll (jammed by construction). @@ -55,7 +55,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i3 = 0 to %N { // non-scoped %f2 // CHECK does (3x4)x unrolling. - store %f2, %B[%i2, %i3] : memref + affine.store %f2, %B[%i2, %i3] : memref } } // (3x2)x unroll (jammed by construction). @@ -124,14 +124,14 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // affine.for %i4 = 0 to %M { affine.for %i5 = 0 to %N { - %a5 = load %A[%i4, %i5] : memref - %b5 = load %B[%i4, %i5] : memref + %a5 = affine.load %A[%i4, %i5] : memref + %b5 = affine.load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 - store %s5, %C[%i4, %i5] : memref + affine.store %s5, %C[%i4, %i5] : memref } } %c7 = constant 7 : index %c42 = constant 42 : index - %res = load %C[%c7, %c42] : memref + %res = affine.load %C[%c7, %c42] : memref return %res : f32 } diff --git a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir index 29b99f87d0e..ad6452f349c 100644 --- a/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir +++ b/mlir/test/Transforms/Vectorize/materialize_vectors_2d_to_2d.mlir @@ -27,7 +27,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { // non-scoped %f1 - store %f1, %A[%i0, %i1] : memref + affine.store %f1, %A[%i0, %i1] : memref } } // 2x unroll (jammed by construction). @@ -45,7 +45,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { affine.for %i2 = 0 to %M { affine.for %i3 = 0 to %N { // non-scoped %f2 - store %f2, %B[%i2, %i3] : memref + affine.store %f2, %B[%i2, %i3] : memref } } // 2x unroll (jammed by construction). 
@@ -74,14 +74,14 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // affine.for %i4 = 0 to %M { affine.for %i5 = 0 to %N { - %a5 = load %A[%i4, %i5] : memref - %b5 = load %B[%i4, %i5] : memref + %a5 = affine.load %A[%i4, %i5] : memref + %b5 = affine.load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 - store %s5, %C[%i4, %i5] : memref + affine.store %s5, %C[%i4, %i5] : memref } } %c7 = constant 7 : index %c42 = constant 42 : index - %res = load %C[%c7, %c42] : memref + %res = affine.load %C[%c7, %c42] : memref return %res : f32 } diff --git a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir index 71f92b96466..48b0ca63661 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_1d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_1d.mlir @@ -23,9 +23,11 @@ func @vec1d_1(%A : memref, %B : memref) { %cst0 = constant 0 : index // // CHECK: for {{.*}} step 128 -// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%[[C0]], %[[C0]]] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> +// CHECK-NEXT: %3 = affine.apply #map0(%[[C0]]) +// CHECK-NEXT: %4 = affine.apply #map0(%[[C0]]) +// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%3, %4] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector - %a0 = load %A[%cst0, %cst0] : memref + %a0 = affine.load %A[%cst0, %cst0] : memref } return } @@ -42,11 +44,9 @@ func @vec1d_2(%A : memref, %B : memref) { %cst0 = constant 0 : index // // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// CHECK-NEXT: %[[APP3:[a-zA-Z0-9]+]] = affine.apply {{.*}}[[IV3]] -// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%[[C0]], %[[APP3]]] {permutation_map = #[[map_proj_d0d1_d1]]} : memref, vector<128xf32> +// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%c0, %i0] {permutation_map = #[[map_proj_d0d1_d1]]} : memref, vector<128xf32> affine.for %i3 = 0 to %M { // vectorized - %r3 = affine.apply (d0) -> (d0) (%i3) - %a3 = load %A[%cst0, %r3] : memref + %a3 = affine.load %A[%cst0, %i3] : memref } return } @@ -64,14 +64,12 @@ func @vec1d_3(%A : memref, %B : memref) { // // CHECK:for [[IV8:%[i0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK-NEXT: for [[IV9:%[i0-9]*]] = 0 to [[ARG_N]] { -// CHECK-NEXT: %[[APP9_0:[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) -// CHECK-NEXT: %[[APP9_1:[0-9]+]] = affine.apply {{.*}}([[IV8]], [[IV9]]) +// CHECK-NEXT: %[[APP9_0:[0-9]+]] = affine.apply {{.*}}([[IV9]], [[IV8]]) +// CHECK-NEXT: %[[APP9_1:[0-9]+]] = affine.apply {{.*}}([[IV9]], [[IV8]]) // CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%[[APP9_0]], %[[APP9_1]]] {permutation_map = #[[map_proj_d0d1_d1]]} : memref, vector<128xf32> affine.for %i8 = 0 to %M { // vectorized affine.for %i9 = 0 to %N { - %r90 = affine.apply (d0, d1) -> (d1) (%i8, %i9) - %r91 = affine.apply (d0, d1) -> (d0 + d1) (%i8, %i9) - %a9 = load %A[%r90, %r91] : memref + %a9 = affine.load %A[%i9, %i8 + %i9] : memref } } return @@ -89,7 +87,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[C1:%.*]] = constant dense<1.000000e+00> : vector<128xf32> // CHECK: vector.transfer_write [[C1]], {{.*}} {permutation_map = #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref // non-scoped %f1 - store %f1, %A[%i0, %i1] : memref + affine.store %f1, %A[%i0, %i1] : memref } } affine.for %i2 = 0 to %M { @@ -97,7 +95,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[C3:%.*]] = constant dense<2.000000e+00> : vector<128xf32> // CHECK: 
vector.transfer_write [[C3]], {{.*}} {permutation_map = #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref // non-scoped %f2 - store %f2, %B[%i2, %i3] : memref + affine.store %f2, %B[%i2, %i3] : memref } } affine.for %i4 = 0 to %M { @@ -111,8 +109,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<128xf32> // CHECK: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<128xf32> // CHECK: vector.transfer_write [[S8]], {{.*}} {permutation_map = #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref - %a5 = load %A[%i4, %i5] : memref - %b5 = load %B[%i4, %i5] : memref + %a5 = affine.load %A[%i4, %i5] : memref + %b5 = affine.load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 // non-scoped %f1 %s6 = addf %s5, %f1 : f32 @@ -120,12 +118,12 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %s7 = addf %s5, %f2 : f32 // diamond dependency. %s8 = addf %s7, %s6 : f32 - store %s8, %C[%i4, %i5] : memref + affine.store %s8, %C[%i4, %i5] : memref } } %c7 = constant 7 : index %c42 = constant 42 : index - %res = load %C[%c7, %c42] : memref + %res = affine.load %C[%c7, %c42] : memref return %res : f32 } @@ -142,7 +140,7 @@ func @vec_rejected_1(%A : memref, %B : memref) { // // CHECK:for {{.*}} [[ARG_M]] { affine.for %i1 = 0 to %M { // not vectorized - %a1 = load %A[%i1, %i1] : memref + %a1 = affine.load %A[%i1, %i1] : memref } return } @@ -160,8 +158,7 @@ func @vec_rejected_2(%A : memref, %B : memref) { // // CHECK: affine.for %i{{[0-9]*}} = 0 to [[ARG_M]] { affine.for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 - %r2 = affine.apply (d0) -> (d0) (%i2) - %a2 = load %A[%r2, %cst0] : memref + %a2 = affine.load %A[%i2, %cst0] : memref } return } @@ -179,14 +176,10 @@ func @vec_rejected_3(%A : memref, %B : memref) { // // CHECK:for [[IV4:%[i0-9]+]] = 0 to [[ARG_M]] step 128 { // CHECK-NEXT: for [[IV5:%[i0-9]*]] = 0 to [[ARG_N]] { -// CHECK-NEXT: %[[APP50:[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) -// CHECK-NEXT: %[[APP51:[0-9]+]] = affine.apply {{.*}}([[IV4]], [[IV5]]) -// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%[[APP50]], %[[APP51]]] {permutation_map = #[[map_proj_d0d1_d1]]} : memref, vector<128xf32> +// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%i1, %i0] {permutation_map = #[[map_proj_d0d1_d1]]} : memref, vector<128xf32> affine.for %i4 = 0 to %M { // vectorized affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 - %r50 = affine.apply (d0, d1) -> (d1) (%i4, %i5) - %r51 = affine.apply (d0, d1) -> (d0) (%i4, %i5) - %a5 = load %A[%r50, %r51] : memref + %a5 = affine.load %A[%i5, %i4] : memref } } return @@ -207,9 +200,7 @@ func @vec_rejected_4(%A : memref, %B : memref) { // CHECK-NEXT: for [[IV7:%[i0-9]*]] = 0 to [[ARG_N]] { affine.for %i6 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 affine.for %i7 = 0 to %N { // not vectorized, can never vectorize - %r70 = affine.apply (d0, d1) -> (d1 + d0) (%i6, %i7) - %r71 = affine.apply (d0, d1) -> (d0) (%i6, %i7) - %a7 = load %A[%r70, %r71] : memref + %a7 = affine.load %A[%i6 + %i7, %i6] : memref } } return @@ -230,12 +221,8 @@ func @vec_rejected_5(%A : memref, %B : memref) { // CHECK: for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} { affine.for %i10 = 0 to %M { // not vectorized, need per load transposes affine.for %i11 = 0 to %N { // not vectorized, need per load transposes - %r11_0 = affine.apply (d0, d1) -> (d0) (%i10, %i11) - %r11_1 = affine.apply (d0, d1) -> (d1) (%i10, %i11) - %a11 = load %A[%r11_0, 
%r11_1] : memref - %r12_0 = affine.apply (d0, d1) -> (d1) (%i10, %i11) - %r12_1 = affine.apply (d0, d1) -> (d0) (%i10, %i11) - store %a11, %A[%r12_0, %r12_1] : memref + %a11 = affine.load %A[%i10, %i11] : memref + affine.store %a11, %A[%i11, %i10] : memref } } return @@ -258,10 +245,7 @@ func @vec_rejected_6(%A : memref, %B : memref) { affine.for %i12 = 0 to %M { // not vectorized, can never vectorize affine.for %i13 = 0 to %N { // not vectorized, can never vectorize affine.for %i14 = 0 to %P { // vectorized - %r14_0 = affine.apply (d0, d1, d2) -> (d1) (%i12, %i13, %i14) - %r14_1 = affine.apply (d0, d1, d2) -> (d0 + d1) (%i12, %i13, %i14) - %r14_2 = affine.apply (d0, d1, d2) -> (d0 + d2) (%i12, %i13, %i14) - %a14 = load %B[%r14_0, %r14_1, %r14_2] : memref + %a14 = affine.load %B[%i13, %i12 + %i13, %i12 + %i14] : memref } } } @@ -282,7 +266,7 @@ func @vec_rejected_7(%A : memref, %B : memref) { // CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { affine.for %i16 = 0 to %M { // not vectorized, can't vectorize a vector load %a16 = alloc(%M) : memref> - %l16 = load %a16[%i16] : memref> + %l16 = affine.load %a16[%i16] : memref> } return } @@ -300,10 +284,12 @@ func @vec_rejected_8(%A : memref, %B : memref) { // // CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// CHECK: {{.*}} = vector.transfer_read %arg0[%[[C0]], %[[C0]]] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> +// CHECK: %3 = affine.apply #map0(%c0) +// CHECK: %4 = affine.apply #map0(%c0) +// CHECK: {{.*}} = vector.transfer_read %arg0[%3, %4] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector - %a18 = load %A[%cst0, %cst0] : memref + %a18 = affine.load %A[%cst0, %cst0] : memref } } return @@ -322,10 +308,12 @@ func @vec_rejected_9(%A : memref, %B : memref) { // // CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// CHECK: {{.*}} = vector.transfer_read %arg0[%[[C0]], %[[C0]]] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> +// CHECK: %3 = affine.apply #map0(%c0) +// CHECK-NEXT: %4 = affine.apply #map0(%c0) +// CHECK-NEXT: {{.*}} = vector.transfer_read %arg0[%3, %4] {permutation_map = #[[map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector - %a18 = load %A[%cst0, %cst0] : memref + %a18 = affine.load %A[%cst0, %cst0] : memref } } return @@ -345,7 +333,7 @@ func @vec_rejected_10(%A : memref, %B : memref) { // CHECK: affine.for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { affine.for %i15 = 0 to %M { // not vectorized due to condition below affine.if #set0(%i15) { - %a15 = load %A[%cst0, %cst0] : memref + %a15 = affine.load %A[%cst0, %cst0] : memref } } return @@ -357,13 +345,13 @@ func @vec_rejected_11(%A : memref, %C : memref) { %N = dim %A, 0 : memref affine.for %i = 0 to %N { // CHECK-NOT: vector - %a = load %A[%i, %i] : memref // not vectorized + %a = affine.load %A[%i, %i] : memref // not vectorized affine.for %j = 0 to %N { - %b = load %A[%i, %j] : memref // may be vectorized + %b = affine.load %A[%i, %j] : memref // may be vectorized // CHECK-NOT: vector %c = addf %a, %b : f32 // not 
vectorized because %a wasn't // CHECK-NOT: vector - store %c, %C[%i, %j] : memref // not vectorized because %c wasn't + affine.store %c, %C[%i, %j] : memref // not vectorized because %c wasn't } } return @@ -375,10 +363,9 @@ func @vec_rejected_sequential(%A : memref) { %N = dim %A, 0 : memref affine.for %i = 0 to %N { // CHECK-NOT: vector - %a = load %A[%i] : memref + %a = affine.load %A[%i] : memref // CHECK-NOT: vector - %ip1 = affine.apply (d0)->(d0 + 1) (%i) - store %a, %A[%ip1] : memref + affine.store %a, %A[%i + 1] : memref } return } diff --git a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir index b4b16117ecf..a44dc5446ed 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_2d.mlir @@ -26,7 +26,7 @@ func @vec2d(%A : memref) { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { affine.for %i2 = 0 to %P { - %a2 = load %A[%i0, %i1, %i2] : memref + %a2 = affine.load %A[%i0, %i1, %i2] : memref } } } @@ -38,7 +38,7 @@ func @vec2d(%A : memref) { affine.for %i3 = 0 to %M { affine.for %i4 = 0 to %N { affine.for %i5 = 0 to %P { - %a5 = load %A[%i4, %i5, %i3] : memref + %a5 = affine.load %A[%i4, %i5, %i3] : memref } } } @@ -56,7 +56,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[C1:%.*]] = constant dense<1.000000e+00> : vector<32x256xf32> // CHECK: vector.transfer_write [[C1]], {{.*}} {permutation_map = #[[map_id2]]} : vector<32x256xf32>, memref // non-scoped %f1 - store %f1, %A[%i0, %i1] : memref + affine.store %f1, %A[%i0, %i1] : memref } } affine.for %i2 = 0 to %M { @@ -64,7 +64,7 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[C3:%.*]] = constant dense<2.000000e+00> : vector<32x256xf32> // CHECK: vector.transfer_write [[C3]], {{.*}} {permutation_map = #[[map_id2]]} : vector<32x256xf32>, memref // non-scoped %f2 - store %f2, %B[%i2, %i3] : memref + affine.store %f2, %B[%i2, %i3] : memref } } affine.for %i4 = 0 to %M { @@ -79,8 +79,8 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { // CHECK: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<32x256xf32> // CHECK: vector.transfer_write [[S8]], {{.*}} {permutation_map = #[[map_id2]]} : vector<32x256xf32>, memref // - %a5 = load %A[%i4, %i5] : memref - %b5 = load %B[%i4, %i5] : memref + %a5 = affine.load %A[%i4, %i5] : memref + %b5 = affine.load %B[%i4, %i5] : memref %s5 = addf %a5, %b5 : f32 // non-scoped %f1 %s6 = addf %s5, %f1 : f32 @@ -88,12 +88,12 @@ func @vector_add_2d(%M : index, %N : index) -> f32 { %s7 = addf %s5, %f2 : f32 // diamond dependency. 
%s8 = addf %s7, %s6 : f32 - store %s8, %C[%i4, %i5] : memref + affine.store %s8, %C[%i4, %i5] : memref } } %c7 = constant 7 : index %c42 = constant 42 : index - %res = load %C[%c7, %c42] : memref + %res = affine.load %C[%c7, %c42] : memref return %res : f32 } @@ -114,7 +114,7 @@ func @vectorize_matmul(%arg0: memref, %arg1: memref, %arg2: me affine.for %i0 = (d0) -> (d0)(%c0) to (d0) -> (d0)(%M) { affine.for %i1 = (d0) -> (d0)(%c0) to (d0) -> (d0)(%N) { %cst = constant 0.000000e+00 : f32 - store %cst, %arg2[%i0, %i1] : memref + affine.store %cst, %arg2[%i0, %i1] : memref } } // VECT: affine.for %[[I2:.*]] = #[[map_id1]](%[[C0]]) to #[[map_id1]](%[[M]]) step 4 { @@ -129,12 +129,12 @@ func @vectorize_matmul(%arg0: memref, %arg1: memref, %arg2: me affine.for %i2 = (d0) -> (d0)(%c0) to (d0) -> (d0)(%M) { affine.for %i3 = (d0) -> (d0)(%c0) to (d0) -> (d0)(%N) { affine.for %i4 = (d0) -> (d0)(%c0) to (d0) -> (d0)(%K) { - %6 = load %arg1[%i4, %i3] : memref - %7 = load %arg0[%i2, %i4] : memref + %6 = affine.load %arg1[%i4, %i3] : memref + %7 = affine.load %arg0[%i2, %i4] : memref %8 = mulf %7, %6 : f32 - %9 = load %arg2[%i2, %i3] : memref + %9 = affine.load %arg2[%i2, %i3] : memref %10 = addf %9, %8 : f32 - store %10, %arg2[%i2, %i3] : memref + affine.store %10, %arg2[%i2, %i3] : memref } } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir index 34db2255ff9..98d8ebccf79 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_3d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_3d.mlir @@ -18,7 +18,7 @@ func @vec3d(%A : memref) { affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 { affine.for %i2 = 0 to %2 { - %a2 = load %A[%i0, %i1, %i2] : memref + %a2 = affine.load %A[%i0, %i1, %i2] : memref } } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir index 00f76d1d3d7..b1257d1e4fa 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_2d.mlir @@ -14,7 +14,7 @@ func @vec2d(%A : memref) { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { affine.for %i2 = 0 to %P { - %a2 = load %A[%i0, %i1, %i2] : memref + %a2 = affine.load %A[%i0, %i1, %i2] : memref } } } @@ -26,7 +26,7 @@ func @vec2d(%A : memref) { affine.for %i3 = 0 to %M { affine.for %i4 = 0 to %N { affine.for %i5 = 0 to %P { - %a5 = load %A[%i4, %i5, %i3] : memref + %a5 = affine.load %A[%i4, %i5, %i3] : memref } } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir index 813fef027bf..7d30162e468 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_outer_loop_transpose_2d.mlir @@ -15,7 +15,7 @@ func @vec2d(%A : memref) { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { affine.for %i2 = 0 to %P { - %a2 = load %A[%i0, %i1, %i2] : memref + %a2 = affine.load %A[%i0, %i1, %i2] : memref } } } @@ -26,7 +26,7 @@ func @vec2d(%A : memref) { affine.for %i3 = 0 to %M { affine.for %i4 = 0 to %N { affine.for %i5 = 0 to %P { - %a5 = load %A[%i4, %i5, %i3] : memref + %a5 = affine.load %A[%i4, %i5, %i3] : memref } } } @@ -49,15 +49,15 @@ func @vec2d_imperfectly_nested(%A : memref) { affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 { affine.for %i2 = 0 to %2 { - %a2 = load %A[%i2, %i1, %i0] : memref + %a2 = affine.load %A[%i2, %i1, %i0] : memref } } 
affine.for %i3 = 0 to %1 { affine.for %i4 = 0 to %2 { - %a4 = load %A[%i3, %i4, %i0] : memref + %a4 = affine.load %A[%i3, %i4, %i0] : memref } affine.for %i5 = 0 to %2 { - %a5 = load %A[%i3, %i5, %i0] : memref + %a5 = affine.load %A[%i3, %i5, %i0] : memref } } } diff --git a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir index 99b9bde7c8d..f33e434696e 100644 --- a/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir +++ b/mlir/test/Transforms/Vectorize/vectorize_transpose_2d.mlir @@ -15,7 +15,7 @@ func @vec2d(%A : memref) { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to %N { affine.for %i2 = 0 to %P { - %a2 = load %A[%i0, %i1, %i2] : memref + %a2 = affine.load %A[%i0, %i1, %i2] : memref } } } @@ -26,7 +26,7 @@ func @vec2d(%A : memref) { affine.for %i3 = 0 to %M { affine.for %i4 = 0 to %N { affine.for %i5 = 0 to %P { - %a5 = load %A[%i4, %i5, %i3] : memref + %a5 = affine.load %A[%i4, %i5, %i3] : memref } } } @@ -49,15 +49,15 @@ func @vec2d_imperfectly_nested(%A : memref) { affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 { affine.for %i2 = 0 to %2 { - %a2 = load %A[%i2, %i1, %i0] : memref + %a2 = affine.load %A[%i2, %i1, %i0] : memref } } affine.for %i3 = 0 to %1 { affine.for %i4 = 0 to %2 { - %a4 = load %A[%i3, %i4, %i0] : memref + %a4 = affine.load %A[%i3, %i4, %i0] : memref } affine.for %i5 = 0 to %2 { - %a5 = load %A[%i3, %i5, %i0] : memref + %a5 = affine.load %A[%i3, %i5, %i0] : memref } } } diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index 98405801bd2..6275f2ef10c 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -13,8 +13,8 @@ // ----- // Index of the buffer for the second DMA is remapped. -// CHECK-DAG: [[MAP_MINUS_256:#map[0-9]+]] = (d0) -> (d0 - 256) // CHECK-DAG: [[MAP_PLUS_256:#map[0-9]+]] = (d0) -> (d0 + 256) +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0) // CHECK-LABEL: func @loop_nest_1d() { func @loop_nest_1d() { @@ -27,22 +27,23 @@ func @loop_nest_1d() { // Tag for first DMA. // CHECK: %4 = alloc() : memref<1xi32> // First DMA transfer. - // CHECK: dma_start %0[%c0], %3[%c0], %c256_1, %4[%c0] : memref<256xf32>, memref<256xf32, 2>, memref<1xi32> - // CHECK: dma_wait %4[%c0], %c256_1 : memref<1xi32> + // CHECK: affine.dma_start %0[%c0], %3[%c0], %4[%c0], %c256_1 : memref<256xf32>, memref<256xf32, 2>, memref<1xi32> + // CHECK: affine.dma_wait %4[%c0], %c256_1 : memref<1xi32> // Second DMA buffer. // CHECK: %5 = alloc() : memref<256xf32, 2> // Tag for second DMA. // CHECK: %6 = alloc() : memref<1xi32> // Second DMA transfer. - // CHECK: dma_start %1[%c256], %5[%c0], %c256_0, %6[%c0] : memref<512xf32>, memref<256xf32, 2>, memref<1xi32> - // CHECK-NEXT: dma_wait %6[%c0], %c256_0 : memref<1xi32> + // CHECK: affine.dma_start %1[%c256], %5[%c0], %6[%c0], %c256_0 : memref<512xf32>, memref<256xf32, 2>, memref<1xi32> + // CHECK-NEXT: affine.dma_wait %6[%c0], %c256_0 : memref<1xi32> // CHECK: affine.for %i0 = 0 to 256 { - // CHECK-NEXT: %7 = load %3[%i0] : memref<256xf32, 2> + // CHECK-NEXT: %7 = affine.load %3[%i0] : memref<256xf32, 2> // CHECK: %8 = affine.apply [[MAP_PLUS_256]](%i0) - // CHECK: %9 = affine.apply [[MAP_MINUS_256]](%8) - // CHECK-NEXT: %10 = load %5[%9] : memref<256xf32, 2> + // Buffer for '%B' in faster memref space is smaller size: 256xf32 + // Affine map for 'affine.load %5' is composed: %i0 + 256 - 256 = %i0. 
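  + // Worked through: the original access is %B[%i + 256] and the DMA buffer %5
  + // only covers the accessed window [256, 512) of the 512-element %B, so the
  + // buffer index is (%i + 256) - 256 = %i. The composed map makes the
  + // standalone subtracting affine.apply unnecessary, which is why the next
  + // CHECK loads %5[%i0] directly.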
+ // CHECK-NEXT: %9 = affine.load %5[%i0] : memref<256xf32, 2> // Already in faster memory space. - // CHECK: %11 = load %2[%i0] : memref<256xf32, 2> + // CHECK: %10 = affine.load %2[%i0] : memref<256xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: dealloc %6 : memref<1xi32> // CHECK-NEXT: dealloc %5 : memref<256xf32, 2> @@ -50,10 +51,10 @@ func @loop_nest_1d() { // CHECK-NEXT: dealloc %3 : memref<256xf32, 2> // CHECK-NEXT: return affine.for %i = 0 to 256 { - load %A[%i] : memref<256 x f32> + affine.load %A[%i] : memref<256 x f32> %idx = affine.apply (d0) -> (d0 + 256)(%i) - load %B[%idx] : memref<512 x f32> - load %F[%i] : memref<256 x f32, 2> + affine.load %B[%idx] : memref<512 x f32> + affine.load %F[%i] : memref<256 x f32, 2> } return } @@ -70,41 +71,41 @@ func @loop_nest_1d() { // CHECK-DAG: [[TAGC:%[0-9]+]] = alloc() : memref<1xi32> // CHECK-DAG: [[TAGC_W:%[0-9]+]] = alloc() : memref<1xi32> // INCOMING DMA for B -// CHECK-DAG: dma_start %arg1[%c0, %c0], [[BUFB]][%c0, %c0], %c16384_2, [[TAGB]][%c0] : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> -// CHECK-DAG: dma_wait [[TAGB]][%c0], %c16384_2 : memref<1xi32> +// CHECK-DAG: affine.dma_start %arg1[%c0, %c0], [[BUFB]][%c0, %c0], [[TAGB]][%c0], %c16384_2 : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> +// CHECK-DAG: affine.dma_wait [[TAGB]][%c0], %c16384_2 : memref<1xi32> // INCOMING DMA for A. -// CHECK-DAG: dma_start %arg0[%c0, %c0], [[BUFA]][%c0, %c0], %c16384_1, [[TAGA]][%c0] : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> -// CHECK-DAG: dma_wait [[TAGA]][%c0], %c16384_1 : memref<1xi32> +// CHECK-DAG: affine.dma_start %arg0[%c0, %c0], [[BUFA]][%c0, %c0], [[TAGA]][%c0], %c16384_1 : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> +// CHECK-DAG: affine.dma_wait [[TAGA]][%c0], %c16384_1 : memref<1xi32> // INCOMING DMA for C. 
-// CHECK-DAG: dma_start %arg2[%c0, %c0], [[BUFC]][%c0, %c0], %c16384_0, [[TAGC]][%c0] : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> -// CHECK-DAG: dma_wait [[TAGC]][%c0], %c16384_0 : memref<1xi32> +// CHECK-DAG: affine.dma_start %arg2[%c0, %c0], [[BUFC]][%c0, %c0], [[TAGC]][%c0], %c16384_0 : memref<512x32xf32>, memref<512x32xf32, 2>, memref<1xi32> +// CHECK-DAG: affine.dma_wait [[TAGC]][%c0], %c16384_0 : memref<1xi32> // CHECK-NEXT: affine.for %i0 = 0 to 32 { // CHECK-NEXT: affine.for %i1 = 0 to 32 { // CHECK-NEXT: affine.for %i2 = 0 to 32 { // CHECK-NEXT: affine.for %i3 = 0 to 16 { // CHECK-NEXT: %7 = affine.apply #map{{[0-9]+}}(%i1, %i3) -// CHECK-NEXT: %8 = load [[BUFB]][%7, %i0] : memref<512x32xf32, 2> +// CHECK-NEXT: %8 = affine.load [[BUFB]][%i1 * 16 + %i3, %i0] : memref<512x32xf32, 2> // CHECK-NEXT: "foo"(%8) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: affine.for %i4 = 0 to 16 { // CHECK-NEXT: %9 = affine.apply #map{{[0-9]+}}(%i2, %i4) -// CHECK-NEXT: %10 = load [[BUFA]][%9, %i1] : memref<512x32xf32, 2> +// CHECK-NEXT: %10 = affine.load [[BUFA]][%i2 * 16 + %i4, %i1] : memref<512x32xf32, 2> // CHECK-NEXT: "bar"(%10) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %11 = "abc_compute"() : () -> f32 // CHECK-NEXT: %12 = affine.apply #map{{[0-9]+}}(%i2, %i5) -// CHECK-NEXT: %13 = load [[BUFC]][%12, %i0] : memref<512x32xf32, 2> +// CHECK-NEXT: %13 = affine.load [[BUFC]][%i2 * 16 + %i5, %i0] : memref<512x32xf32, 2> // CHECK-NEXT: %14 = "addf32"(%11, %13) : (f32, f32) -> f32 -// CHECK-NEXT: store %14, [[BUFC]][%12, %i0] : memref<512x32xf32, 2> +// CHECK-NEXT: affine.store %14, [[BUFC]][%i2 * 16 + %i5, %i0] : memref<512x32xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: "foobar"() : () -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // OUTGOING DMA for C. -// CHECK-NEXT: dma_start [[BUFC]][%c0, %c0], %arg2[%c0, %c0], %c16384, [[TAGC_W]][%c0] : memref<512x32xf32, 2>, memref<512x32xf32>, memref<1xi32> -// CHECK-NEXT: dma_wait [[TAGC_W]][%c0], %c16384 : memref<1xi32> +// CHECK-NEXT: affine.dma_start [[BUFC]][%c0, %c0], %arg2[%c0, %c0], [[TAGC_W]][%c0], %c16384 : memref<512x32xf32, 2>, memref<512x32xf32>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait [[TAGC_W]][%c0], %c16384 : memref<1xi32> // CHECK-NEXT: dealloc [[TAGC_W]] : memref<1xi32> // CHECK-NEXT: dealloc [[TAGC]] : memref<1xi32> // CHECK-NEXT: dealloc [[BUFC]] : memref<512x32xf32, 2> @@ -124,20 +125,20 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, affine.for %iT = 0 to 32 { affine.for %kk = 0 to 16 { // k intratile %k = affine.apply (d0, d1) -> (16*d0 + d1) (%kT, %kk) - %v0 = load %B[%k, %jT] : memref<512 x 32 x f32> + %v0 = affine.load %B[%k, %jT] : memref<512 x 32 x f32> "foo"(%v0) : (f32) -> () } affine.for %ii = 0 to 16 { // i intratile. %i = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii) - %v1 = load %A[%i, %kT] : memref<512 x 32 x f32> + %v1 = affine.load %A[%i, %kT] : memref<512 x 32 x f32> "bar"(%v1) : (f32) -> () } affine.for %ii_ = 0 to 16 { // i intratile. 
%v2 = "abc_compute"() : () -> f32 %i_ = affine.apply (d0, d1) -> (16*d0 + d1)(%iT, %ii_) - %v3 = load %C[%i_, %jT] : memref<512 x 32 x f32> + %v3 = affine.load %C[%i_, %jT] : memref<512 x 32 x f32> %v4 = "addf32"(%v2, %v3) : (f32, f32) -> (f32) - store %v4, %C[%i_, %jT] : memref<512 x 32 x f32> + affine.store %v4, %C[%i_, %jT] : memref<512 x 32 x f32> } "foobar"() : () -> () } @@ -157,8 +158,9 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>, // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%i0) // CHECK-NEXT: %2 = alloc() : memref<1x2xf32, 2> // CHECK-NEXT: %3 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%1, %c0], %2[%c0, %c0], %c2, %3[%c0] : memref<256x8xf32>, memref<1x2xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %3[%c0], %c2 : memref<1xi32> +// Composition of the affine map for '%0' causes '%c0' to be added as a symbol. +// CHECK-NEXT: affine.dma_start %0[%i0, symbol(%c0)], %2[%c0, %c0], %3[%c0], %c2 : memref<256x8xf32>, memref<1x2xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %3[%c0], %c2 : memref<1xi32> // CHECK-NEXT: affine.for %i1 = 0 to 8 { // ... // ... @@ -174,7 +176,7 @@ func @loop_nest_modulo() { affine.for %j = 0 to 8 { %idx = affine.apply (d0) -> (d0 mod 2) (%j) // A buffer of size 32 x 2 will be allocated (original buffer was 256 x 8). - %v = load %A[%i, %idx] : memref<256 x 8 x f32> + %v = affine.load %A[%i, %idx] : memref<256 x 8 x f32> } } return @@ -182,9 +184,6 @@ func @loop_nest_modulo() { // ----- -// CHECK-DAG: [[MAP_INDEX_DIFF_EVEN:#map[0-9]+]] = (d0, d1, d2, d3) -> (d2 - d0) -// CHECK-DAG: [[MAP_INDEX_DIFF_ODD:#map[0-9]+]] = (d0, d1, d2, d3) -> (d3 - d1) - // DMA on tiled loop nest. This also tests the case where the bounds are // dependent on outer loop IVs. // CHECK-LABEL: func @loop_nest_tiled() -> memref<256x1024xf32> { @@ -195,16 +194,14 @@ func @loop_nest_tiled() -> memref<256x1024xf32> { // CHECK: %3 = alloc() : memref<32x32xf32, 2> // CHECK-NEXT: %4 = alloc() : memref<1xi32> // Strided DMA here: 32 x 32 tile in a 256 x 1024 memref. -// CHECK-NEXT: dma_start %0[%1, %2], %3[%c0, %c0], %c1024, %4[%c0], %c1024_0, %c32 : memref<256x1024xf32>, memref<32x32xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait +// CHECK-NEXT: affine.dma_start %0[%i0, %i1], %3[%c0, %c0], %4[%c0], %c1024, %c1024_0, %c32 : memref<256x1024xf32>, memref<32x32xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait // CHECK-NEXT: affine.for %i2 = #map // CHECK-NEXT: affine.for %i3 = #map affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 32)(%i0) { affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { - // CHECK-NEXT: %5 = affine.apply [[MAP_INDEX_DIFF_EVEN]](%i0, %i1, %i2, %i3) - // CHECK-NEXT: %6 = affine.apply [[MAP_INDEX_DIFF_ODD]](%i0, %i1, %i2, %i3) - // CHECK-NEXT: %7 = load %3[%5, %6] : memref<32x32xf32, 2> - %1 = load %0[%i2, %i3] : memref<256x1024xf32> + // CHECK: %5 = affine.load %3[-%i0 + %i2, -%i1 + %i3] : memref<32x32xf32, 2> + %1 = affine.load %0[%i2, %i3] : memref<256x1024xf32> } // CHECK-NEXT: } } } @@ -214,9 +211,6 @@ func @loop_nest_tiled() -> memref<256x1024xf32> { // ----- -// CHECK-DAG: [[MAP_D0_MINUS_ONE:#map[0-9]+]] = (d0, d1) -> (d0 - 1) -// CHECK-DAG: [[MAP_D1:#map[0-9]+]] = (d0, d1) -> (d1) - // CHECK-LABEL: func @dma_constant_dim_access func @dma_constant_dim_access(%A : memref<100x100xf32>) { %one = constant 1 : index @@ -224,14 +218,12 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // CHECK: %0 = alloc() : memref<1x100xf32, 2> // CHECK-NEXT: %1 = alloc() : memref<1xi32> // No strided DMA needed here. 
- // CHECK: dma_start %arg0[%c1, %c0], %0[%c0, %c0], %c100, %1[%c0] : memref<100x100xf32>, memref<1x100xf32, 2>, - // CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32> + // CHECK: affine.dma_start %arg0[%c1, %c0], %0[%c0, %c0], %1[%c0], %c100 : memref<100x100xf32>, memref<1x100xf32, 2>, + // CHECK-NEXT: affine.dma_wait %1[%c0], %c100 : memref<1xi32> affine.for %i = 0 to 100 { affine.for %j = 0 to ()[s0] -> (s0) ()[%N] { - // CHECK: %2 = affine.apply [[MAP_D0_MINUS_ONE]](%c1_0, %i1) - // CHECK: %3 = affine.apply [[MAP_D1]](%c1_0, %i1) - // CHECK-NEXT: %4 = load %0[%2, %3] : memref<1x100xf32, 2> - load %A[%one, %j] : memref<100 x 100 x f32> + // CHECK: %2 = affine.load %0[symbol(%c1_0) - 1, %i1] : memref<1x100xf32, 2> + affine.load %A[%one, %j] : memref<100 x 100 x f32> } } return @@ -240,8 +232,6 @@ func @dma_constant_dim_access(%A : memref<100x100xf32>) { // ----- // CHECK-DAG: [[MAP_SYM_SHIFT:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d1 + s0 + s1) -// CHECK-DAG: [[MAP_3D_D1:#map[0-9]+]] = (d0, d1, d2) -> (d1) -// CHECK-DAG: [[MAP_SUB_OFFSET:#map[0-9]+]] = (d0, d1, d2) -> (d2 - (d0 + 9)) // CHECK-LABEL: func @dma_with_symbolic_accesses func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { @@ -249,20 +239,18 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) { affine.for %i = 0 to 100 { affine.for %j = 0 to 100 { %idy = affine.apply (d0, d1) [s0, s1] -> (d1 + s0 + s1)(%i, %j)[%M, %N] - load %A[%i, %idy] : memref<100 x 100 x f32> + affine.load %A[%i, %idy] : memref<100 x 100 x f32> } } return // CHECK: %1 = alloc() : memref<100x100xf32, 2> // CHECK-NEXT: %2 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %arg0[%c0, %0], %1[%c0, %c0], %c10000, %2[%c0] -// CHECK-NEXT: dma_wait %2[%c0], %c10000 +// CHECK-NEXT: affine.dma_start %arg0[symbol(%c0), symbol(%arg1) + 9], %1[%c0, %c0], %2[%c0], %c10000 +// CHECK-NEXT: affine.dma_wait %2[%c0], %c10000 // CHECK-NEXT: affine.for %i0 = 0 to 100 { // CHECK-NEXT: affine.for %i1 = 0 to 100 { // CHECK-NEXT: %3 = affine.apply [[MAP_SYM_SHIFT]](%i0, %i1)[%arg1, %c9] -// CHECK-NEXT: %4 = affine.apply [[MAP_3D_D1]](%arg1, %i0, %3) -// CHECK-NEXT: %5 = affine.apply [[MAP_SUB_OFFSET]](%arg1, %i0, %3) -// CHECK-NEXT: %6 = load %1[%4, %5] : memref<100x100xf32, 2> +// CHECK-NEXT: %4 = affine.load %1[%i0, %i1 + symbol(%c9) - 9] : memref<100x100xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK: return @@ -277,12 +265,12 @@ func @dma_with_symbolic_loop_bounds(%A : memref<100x100xf32>, %M : index, %N: in // memref size; so the DMA buffer is the entire 100x100. // CHECK: %0 = alloc() : memref<100x100xf32, 2> // CHECK-NEXT: %1 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %arg0[%c0, %c0], %0[%c0, %c0], %c10000, %1[%c0] : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %1[%c0], %c10000 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %arg0[%c0, %c0], %0[%c0, %c0], %1[%c0], %c10000 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %1[%c0], %c10000 : memref<1xi32> affine.for %i = 0 to 100 { affine.for %j = %M to %N { %idy = affine.apply (d1) [s0] -> (d1 + s0)(%j)[%K] - load %A[%i, %idy] : memref<100 x 100 x f32> + affine.load %A[%i, %idy] : memref<100 x 100 x f32> } } return @@ -298,8 +286,8 @@ func @dma_unknown_size(%arg0: memref) { affine.for %j = 0 to %N { // If this loop nest isn't tiled, the access requires a non-constant DMA // size -- not yet implemented. 
- // CHECK: %2 = load %arg0[%i0, %i1] : memref - load %arg0[%i, %j] : memref + // CHECK: %2 = affine.load %arg0[%i0, %i1] : memref + affine.load %arg0[%i, %j] : memref // expected-error@-6 {{DMA generation failed for one or more memref's in this block}} } } @@ -318,8 +306,8 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { %idz = affine.apply (d0) -> (d0 mod 128)(%k) // DMA with nested striding (or emulating with loop around strided DMA) // not yet implemented. - // CHECK: %5 = load %arg0[%2, %3, %4] : memref<1024x1024x1024xf32> - %v = load %arg0[%idx, %idy, %idz] : memref<1024 x 1024 x 1024 x f32> + // CHECK: %5 = affine.load %arg0[%2, %3, %4] : memref<1024x1024x1024xf32> + %v = affine.load %arg0[%idx, %idy, %idz] : memref<1024 x 1024 x 1024 x f32> // expected-error@-10 {{DMA generation failed for one or more memref's in this block}} } } @@ -332,8 +320,6 @@ func @dma_memref_3d(%arg0: memref<1024x1024x1024xf32>) { // CHECK-DAG: [[MAP_PLUS_64:#map[0-9]+]] = (d0) -> (d0 + 64) // CHECK-DAG: [[MAP_PLUS_128:#map[0-9]+]] = (d0) -> (d0 + 128) // CHECK-DAG: [[MAP_PLUS_2:#map[0-9]+]] = (d0) -> (d0 + 2) -// CHECK-DAG: [[MAP_D0_MINUS_2:#map[0-9]+]] = (d0, d1) -> (d0 - 2) -// CHECK-DAG: [[MAP_D1_MINUS_2:#map[0-9]+]] = (d0, d1) -> (d1 - 2) // CHECK-DAG: [[MAP_PLUS_192:#map[0-9]+]] = (d0) -> (d0 + 192) // The first load accesses ([2,258), [128,384)) @@ -353,14 +339,14 @@ func @multi_load_store_union() { %ishift = affine.apply (d0) -> (d0 + 2)(%i) %jshift = affine.apply (d0) -> (d0 + 2)(%j) - %u = load %A[%ishift, %idy] : memref<512 x 512 x f32> - %v = load %A[%idx, %jshift] : memref<512 x 512 x f32> + %u = affine.load %A[%ishift, %idy] : memref<512 x 512 x f32> + %v = affine.load %A[%idx, %jshift] : memref<512 x 512 x f32> %sidx = affine.apply (d0) -> (d0 + 128)(%i) %sidy = affine.apply (d0) -> (d0 + 192)(%j) - store %u, %A[%ishift, %sidy] : memref<512 x 512 x f32> - store %v, %A[%sidx, %jshift] : memref<512 x 512 x f32> + affine.store %u, %A[%ishift, %sidy] : memref<512 x 512 x f32> + affine.store %v, %A[%sidx, %jshift] : memref<512 x 512 x f32> } } return @@ -368,8 +354,8 @@ func @multi_load_store_union() { // CHECK: %0 = alloc() : memref<512x512xf32> // CHECK-NEXT: %1 = alloc() : memref<382x446xf32, 2> // CHECK-NEXT: %2 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%c2_1, %c2_2], %1[%c0, %c0], %c170372_3, %2[%c0], %c512_4, %c446_5 : memref<512x512xf32>, memref<382x446xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %2[%c0], %c170372_3 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %0[%c2_1, %c2_2], %1[%c0, %c0], %2[%c0], %c170372_3, %c512_4, %c446_5 : memref<512x512xf32>, memref<382x446xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %2[%c0], %c170372_3 : memref<1xi32> // CHECK-NEXT: %3 = alloc() : memref<1xi32> // CHECK-NEXT: affine.for %i0 = 0 to 256 { // CHECK-NEXT: affine.for %i1 = 0 to 256 { @@ -377,24 +363,16 @@ func @multi_load_store_union() { // CHECK-NEXT: %5 = affine.apply [[MAP_PLUS_128]](%i1) // CHECK-NEXT: %6 = affine.apply [[MAP_PLUS_2]](%i0) // CHECK-NEXT: %7 = affine.apply [[MAP_PLUS_2]](%i1) -// CHECK-NEXT: %8 = affine.apply [[MAP_D0_MINUS_2]](%6, %5) -// CHECK-NEXT: %9 = affine.apply [[MAP_D1_MINUS_2]](%6, %5) -// CHECK-NEXT: %10 = load %1[%8, %9] : memref<382x446xf32, 2> -// CHECK-NEXT: %11 = affine.apply [[MAP_D0_MINUS_2]](%4, %7) -// CHECK-NEXT: %12 = affine.apply [[MAP_D1_MINUS_2]](%4, %7) -// CHECK-NEXT: %13 = load %1[%11, %12] : memref<382x446xf32, 2> -// CHECK-NEXT: %14 = affine.apply [[MAP_PLUS_128]](%i0) -// CHECK-NEXT: %15 = affine.apply 
[[MAP_PLUS_192]](%i1) -// CHECK-NEXT: %16 = affine.apply [[MAP_D0_MINUS_2]](%6, %15) -// CHECK-NEXT: %17 = affine.apply [[MAP_D1_MINUS_2]](%6, %15) -// CHECK-NEXT: store %10, %1[%16, %17] : memref<382x446xf32, 2> -// CHECK-NEXT: %18 = affine.apply [[MAP_D0_MINUS_2]](%14, %7) -// CHECK-NEXT: %19 = affine.apply [[MAP_D1_MINUS_2]](%14, %7) -// CHECK-NEXT: store %13, %1[%18, %19] : memref<382x446xf32, 2> +// CHECK-NEXT: %8 = affine.load %1[%i0, %i1 + 126] : memref<382x446xf32, 2> +// CHECK-NEXT: %9 = affine.load %1[%i0 + 62, %i1] : memref<382x446xf32, 2> +// CHECK-NEXT: %10 = affine.apply [[MAP_PLUS_128]](%i0) +// CHECK-NEXT: %11 = affine.apply [[MAP_PLUS_192]](%i1) +// CHECK-NEXT: affine.store %8, %1[%i0, %i1 + 190] : memref<382x446xf32, 2> +// CHECK-NEXT: affine.store %9, %1[%i0 + 126, %i1] : memref<382x446xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: dma_start %1[%c0, %c0], %0[%c2, %c2_0], %c170372, %3[%c0], %c512, %c446 : memref<382x446xf32, 2>, memref<512x512xf32>, memref<1xi32> -// CHECK-NEXT: dma_wait %3[%c0], %c170372 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %1[%c0, %c0], %0[%c2, %c2_0], %3[%c0], %c170372, %c512, %c446 : memref<382x446xf32, 2>, memref<512x512xf32>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %3[%c0], %c170372 : memref<1xi32> // CHECK-NEXT: dealloc %3 : memref<1xi32> // CHECK-NEXT: dealloc %2 : memref<1xi32> // CHECK-NEXT: dealloc %1 : memref<382x446xf32, 2> @@ -403,19 +381,17 @@ func @multi_load_store_union() { // ----- -// CHECK-DAG: [[MAP_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1) - // CHECK-LABEL: func @dma_loop_straightline_interspersed() { func @dma_loop_straightline_interspersed() { %c0 = constant 0 : index %c255 = constant 255 : index %A = alloc() : memref<256 x f32> - %v = load %A[%c0] : memref<256 x f32> + %v = affine.load %A[%c0] : memref<256 x f32> affine.for %i = 1 to 255 { - load %A[%i] : memref<256 x f32> + affine.load %A[%i] : memref<256 x f32> } - %l = load %A[%c255] : memref<256 x f32> - store %l, %A[%c0] : memref<256 x f32> + %l = affine.load %A[%c255] : memref<256 x f32> + affine.store %l, %A[%c0] : memref<256 x f32> return } // There are three regions here - the 'load' preceding the loop, the loop @@ -423,33 +399,32 @@ func @dma_loop_straightline_interspersed() { // CHECK: %0 = alloc() : memref<256xf32> // CHECK-NEXT: %1 = alloc() : memref<1xf32, 2> // CHECK-NEXT: %2 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c1_1, %2[%c0] : memref<256xf32>, memref<1xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %2[%c0], %c1_1 : memref<1xi32> -// CHECK-NEXT: %3 = load %1[%c0_2] : memref<1xf32, 2> +// CHECK-NEXT: affine.dma_start %0[%c0], %1[%c0], %2[%c0], %c1_1 : memref<256xf32>, memref<1xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %2[%c0], %c1_1 : memref<1xi32> +// CHECK-NEXT: %3 = affine.load %1[symbol(%c0_2)] : memref<1xf32, 2> // CHECK-NEXT: dealloc %2 : memref<1xi32> // CHECK-NEXT: dealloc %1 : memref<1xf32, 2> // CHECK-NEXT: %4 = alloc() : memref<254xf32, 2> // CHECK-NEXT: %5 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%c1], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %5[%c0], %c254 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %0[%c1], %4[%c0], %5[%c0], %c254 : memref<256xf32>, memref<254xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %5[%c0], %c254 : memref<1xi32> // CHECK-NEXT: affine.for %i0 = 1 to 255 { -// CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0) -// CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 2> 
+// CHECK-NEXT: %6 = affine.load %4[%i0 - 1] : memref<254xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: dealloc %5 : memref<1xi32> // CHECK-NEXT: dealloc %4 : memref<254xf32, 2> -// CHECK-NEXT: %8 = alloc() : memref<256xf32, 2> +// CHECK-NEXT: %7 = alloc() : memref<256xf32, 2> +// CHECK-NEXT: %8 = alloc() : memref<1xi32> +// CHECK-NEXT: affine.dma_start %0[%c0], %7[%c0], %8[%c0], %c256_0 : memref<256xf32>, memref<256xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %8[%c0], %c256_0 : memref<1xi32> // CHECK-NEXT: %9 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%c0], %8[%c0], %c256_0, %9[%c0] : memref<256xf32>, memref<256xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %9[%c0], %c256_0 : memref<1xi32> -// CHECK-NEXT: %10 = alloc() : memref<1xi32> -// CHECK-NEXT: %11 = load %8[%c255] : memref<256xf32, 2> -// CHECK-NEXT: store %11, %8[%c0_2] : memref<256xf32, 2> -// CHECK-NEXT: dma_start %8[%c0], %0[%c0], %c256, %10[%c0] : memref<256xf32, 2>, memref<256xf32>, memref<1xi32> -// CHECK-NEXT: dma_wait %10[%c0], %c256 : memref<1xi32> -// CHECK-NEXT: dealloc %10 : memref<1xi32> +// CHECK-NEXT: %10 = affine.load %7[symbol(%c255)] : memref<256xf32, 2> +// CHECK-NEXT: affine.store %10, %7[symbol(%c0_2)] : memref<256xf32, 2> +// CHECK-NEXT: affine.dma_start %7[%c0], %0[%c0], %9[%c0], %c256 : memref<256xf32, 2>, memref<256xf32>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %9[%c0], %c256 : memref<1xi32> // CHECK-NEXT: dealloc %9 : memref<1xi32> -// CHECK-NEXT: dealloc %8 : memref<256xf32, 2> +// CHECK-NEXT: dealloc %8 : memref<1xi32> +// CHECK-NEXT: dealloc %7 : memref<256xf32, 2> // CHECK-NEXT: return // ----- @@ -459,10 +434,10 @@ func @dma_mixed_loop_blocks() { %c0 = constant 0 : index %A = alloc() : memref<256 x 256 x vector<8 x f32>> affine.for %i = 0 to 256 { - %v = load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> + %v = affine.load %A[%c0, %c0] : memref<256 x 256 x vector<8 x f32>> "foo"(%v) : (vector<8 x f32>) -> () affine.for %j = 0 to 256 { - %w = load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> + %w = affine.load %A[%i, %j] : memref<256 x 256 x vector<8 x f32>> "bar"(%w) : (vector<8 x f32>) -> () } } @@ -471,12 +446,12 @@ func @dma_mixed_loop_blocks() { // CHECK-DAG: [[MEM:%[0-9]+]] = alloc() : memref<256x256xvector<8xf32>> // CHECK-DAG: [[BUF:%[0-9]+]] = alloc() : memref<256x256xvector<8xf32>, 2> // CHECK-DAG: [[TAG:%[0-9]+]] = alloc() : memref<1xi32> -// CHECK: dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], %c65536, [[TAG]][%c0] : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> +// CHECK: affine.dma_start [[MEM]][%c0, %c0], [[BUF]][%c0, %c0], [[TAG]][%c0], %c65536 : memref<256x256xvector<8xf32>>, memref<256x256xvector<8xf32>, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait [[TAG]][%c0], %c65536 : memref<1xi32> // CHECK-NEXT: affine.for %i0 = 0 to 256 { -// CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 2> +// CHECK: %3 = affine.load [[BUF]][symbol(%c0_0), symbol(%c0_0)] : memref<256x256xvector<8xf32>, 2> // CHECK: affine.for %i1 = 0 to 256 { -// CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 2> +// CHECK-NEXT: %4 = affine.load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 2> // ----- @@ -485,7 +460,7 @@ func @relative_loop_bounds(%arg0: memref<1027xf32>) { affine.for %i0 = 0 to 1024 { affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { %0 = constant 0.0 : f32 - store %0, %arg0[%i2] : memref<1027xf32> + 
affine.store %0, %arg0[%i2] : memref<1027xf32> } } return @@ -495,17 +470,16 @@ func @relative_loop_bounds(%arg0: memref<1027xf32>) { // CHECK-NEXT: affine.for %i0 = 0 to 1024 { // CHECK-NEXT: affine.for %i1 = {{#map[0-9]+}}(%i0) to {{#map[0-9]+}}(%i0) { // CHECK-NEXT: %cst = constant 0.000000e+00 : f32 -// CHECK-NEXT: store %cst, [[BUF]][%i1] : memref<1027xf32, 2> +// CHECK-NEXT: affine.store %cst, [[BUF]][%i1] : memref<1027xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: dma_start [[BUF]][%c0], %arg0[%c0], %c1027, [[MEM]][%c0] : memref<1027xf32, 2>, memref<1027xf32>, memref<1xi32> -// CHECK-NEXT: dma_wait [[MEM]][%c0], %c1027 : memref<1xi32> +// CHECK-NEXT: affine.dma_start [[BUF]][%c0], %arg0[%c0], [[MEM]][%c0], %c1027 : memref<1027xf32, 2>, memref<1027xf32>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait [[MEM]][%c0], %c1027 : memref<1xi32> // ----- // CHECK-DAG: [[MAP_READ_OFFSET:#map[0-9]+]] = (d0) -> (d0 + 100) // CHECK-DAG: [[MAP_WRITE_OFFSET:#map[0-9]+]] = (d0) -> (d0 + 25) -// CHECK-DAG: [[MAP_BUFFER_OFFSET:#map[0-9]+]] = (d0) -> (d0 - 25) func @test_read_write_region_union() { %0 = alloc() : memref<256xf32> @@ -516,8 +490,8 @@ func @test_read_write_region_union() { // union region: [25, 110) %a0 = affine.apply (d0) -> (d0 + 100)(%i0) %a1 = affine.apply (d0) -> (d0 + 25)(%i0) - %1 = load %0[%a0] : memref<256xf32> - store %1, %0[%a1] : memref<256xf32> + %1 = affine.load %0[%a0] : memref<256xf32> + affine.store %1, %0[%a1] : memref<256xf32> } return } @@ -525,19 +499,17 @@ func @test_read_write_region_union() { // CHECK: %0 = alloc() : memref<256xf32> // CHECK-NEXT: %1 = alloc() : memref<85xf32, 2> // CHECK-NEXT: %2 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %0[%c25_0], %1[%c0], %c85_1, %2[%c0] : memref<256xf32>, memref<85xf32, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %2[%c0], %c85_1 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %0[%c25_0], %1[%c0], %2[%c0], %c85_1 : memref<256xf32>, memref<85xf32, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %2[%c0], %c85_1 : memref<1xi32> // CHECK-NEXT: %3 = alloc() : memref<1xi32> // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: %4 = affine.apply [[MAP_READ_OFFSET]](%i0) // CHECK-NEXT: %5 = affine.apply [[MAP_WRITE_OFFSET]](%i0) -// CHECK-NEXT: %6 = affine.apply [[MAP_BUFFER_OFFSET]](%4) -// CHECK-NEXT: %7 = load %1[%6] : memref<85xf32, 2> -// CHECK-NEXT: %8 = affine.apply [[MAP_BUFFER_OFFSET]](%5) -// CHECK-NEXT: store %7, %1[%8] : memref<85xf32, 2> +// CHECK-NEXT: %6 = affine.load %1[%i0 + 75] : memref<85xf32, 2> +// CHECK-NEXT: affine.store %6, %1[%i0] : memref<85xf32, 2> // CHECK-NEXT: } -// CHECK-NEXT: dma_start %1[%c0], %0[%c25], %c85, %3[%c0] : memref<85xf32, 2>, memref<256xf32>, memref<1xi32> -// CHECK-NEXT: dma_wait %3[%c0], %c85 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %1[%c0], %0[%c25], %3[%c0], %c85 : memref<85xf32, 2>, memref<256xf32>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %3[%c0], %c85 : memref<1xi32> // ----- @@ -556,10 +528,10 @@ func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, affine.for %i9 = #map_lb(%i8) to #map_ub(%i8) { affine.for %i17 = 0 to 64 { %23 = affine.apply #map_acc(%i9) - %25 = load %arg2[%23] : memref<2xf32> + %25 = affine.load %arg2[%23] : memref<2xf32> %26 = affine.apply #map_lb(%i17) - %27 = load %0[%26, %c0] : memref<64x1xf32> - store %27, %arg2[%23] : memref<2xf32> + %27 = affine.load %0[%26, %c0] : memref<64x1xf32> + affine.store %27, %arg2[%23] : memref<2xf32> } } } @@ -567,8 +539,8 @@ func @test_analysis_util(%arg0: 
memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, } // CHECK: affine.for %i0 = 0 to 9 step 3 { // CHECK: [[BUF:%[0-9]+]] = alloc() : memref<2xf32, 2> -// CHECK: dma_start %arg2[%4], [[BUF]] -// CHECK: dma_wait %6[%c0], %c2_0 : memref<1xi32> +// CHECK: affine.dma_start %arg2[%i0 floordiv 8], [[BUF]] +// CHECK: affine.dma_wait %6[%c0], %c2_0 : memref<1xi32> // CHECK: affine.for %i1 = // ---- @@ -587,7 +559,7 @@ func @test_memref_bounds(%arg0: memref<4x4x16x1xvector<8x128xf32>>, %arg1: memre %10 = affine.apply #map14(%i9, %i10) %11 = affine.apply #map15(%i9, %i10) %12 = affine.apply #map16(%i9, %i10) - %13 = load %arg0[%10, %11, %12, %c0] : memref<4x4x16x1xvector<8x128xf32>> + %13 = affine.load %arg0[%10, %11, %12, %c0] : memref<4x4x16x1xvector<8x128xf32>> } } } @@ -596,8 +568,8 @@ func @test_memref_bounds(%arg0: memref<4x4x16x1xvector<8x128xf32>>, %arg1: memre // CHECK: %0 = alloc() : memref<4x4x16x1xvector<8x128xf32>, 2> // CHECK-NEXT: %1 = alloc() : memref<1xi32> -// CHECK-NEXT: dma_start %arg0[%c0, %c0, %c0, %c0], %0[%c0, %c0, %c0, %c0], %c256, %1[%c0] : memref<4x4x16x1xvector<8x128xf32>>, memref<4x4x16x1xvector<8x128xf32>, 2>, memref<1xi32> -// CHECK-NEXT: dma_wait %1[%c0], %c256 : memref<1xi32> +// CHECK-NEXT: affine.dma_start %arg0[%c0, %c0, %c0, %c0], %0[%c0, %c0, %c0, %c0], %1[%c0], %c256 : memref<4x4x16x1xvector<8x128xf32>>, memref<4x4x16x1xvector<8x128xf32>, 2>, memref<1xi32> +// CHECK-NEXT: affine.dma_wait %1[%c0], %c256 : memref<1xi32> // ----- @@ -609,22 +581,22 @@ func @load_store_same_memref(%arg0: memref<256x1024xf32>) { // FAST-MEM-16KB: affine.for %i0 = 0 to 256 step 4 affine.for %i0 = 0 to 256 step 4 { // FAST-MEM-16KB: [[BUF:%[0-9]+]] = alloc() : memref<4x1024xf32, 2> - // FAST-MEM-16KB: dma_start %arg0 - // FAST-MEM-16KB-NEXT: dma_wait + // FAST-MEM-16KB: affine.dma_start %arg0 + // FAST-MEM-16KB-NEXT: affine.dma_wait // FAST-MEM-16KB: affine.for %i1 affine.for %i1 = 0 to 1024 step 4 { // FAST-MEM-16KB: affine.for %i2 affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { // FAST-MEM-16KB: affine.for %i3 affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 4)(%i1) { - %3 = load %arg0[%i2, %i3] : memref<256x1024xf32> + %3 = affine.load %arg0[%i2, %i3] : memref<256x1024xf32> %4 = mulf %3, %3 : f32 - store %4, %arg0[%i2, %i3] : memref<256x1024xf32> + affine.store %4, %arg0[%i2, %i3] : memref<256x1024xf32> } // FAST-MEM-16KB: } } // FAST-MEM-16KB: } } // FAST-MEM-16KB: } - // FAST-MEM-16KB: dma_start [[BUF]] - // FAST-MEM-16KB-NEXT: dma_wait + // FAST-MEM-16KB: affine.dma_start [[BUF]] + // FAST-MEM-16KB-NEXT: affine.dma_wait } return } @@ -648,12 +620,12 @@ func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector< affine.for %ii = #map0(%i) to #map1(%i) { affine.for %jj = #map0(%j) to #map1(%j) { affine.for %kk = #map0(%k) to #map1(%k) { - %5 = load %arg0[%ii, %kk] : memref<8x8xvector<64xf32>> - %6 = load %arg1[%kk, %jj] : memref<8x8xvector<64xf32>> - %7 = load %arg2[%ii, %jj] : memref<8x8xvector<64xf32>> + %5 = affine.load %arg0[%ii, %kk] : memref<8x8xvector<64xf32>> + %6 = affine.load %arg1[%kk, %jj] : memref<8x8xvector<64xf32>> + %7 = affine.load %arg2[%ii, %jj] : memref<8x8xvector<64xf32>> %8 = mulf %5, %6 : vector<64xf32> %9 = addf %7, %8 : vector<64xf32> - store %9, %arg2[%ii, %jj] : memref<8x8xvector<64xf32>> + affine.store %9, %arg2[%ii, %jj] : memref<8x8xvector<64xf32>> } } } @@ -664,13 +636,13 @@ func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector< } // FAST-MEM-16KB: affine.for %i0 = 0 to 8 step 4 { // 
FAST-MEM-16KB: affine.for %i1 = 0 to 8 step 4 { -// FAST-MEM-16KB: dma_start %arg2 -// FAST-MEM-16KB: dma_wait +// FAST-MEM-16KB: affine.dma_start %arg2 +// FAST-MEM-16KB: affine.dma_wait // FAST-MEM-16KB: affine.for %i2 = 0 to 8 step 4 { -// FAST-MEM-16KB: dma_start %arg0 -// FAST-MEM-16KB: dma_wait -// FAST-MEM-16KB: dma_start %arg1 -// FAST-MEM-16KB: dma_wait +// FAST-MEM-16KB: affine.dma_start %arg0 +// FAST-MEM-16KB: affine.dma_wait +// FAST-MEM-16KB: affine.dma_start %arg1 +// FAST-MEM-16KB: affine.dma_wait // FAST-MEM-16KB: affine.for %i3 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) { // FAST-MEM-16KB-NEXT: affine.for %i4 = #map{{[0-9]+}}(%i1) to #map{{[0-9]+}}(%i1) { // FAST-MEM-16KB-NEXT: affine.for %i5 = #map{{[0-9]+}}(%i2) to #map{{[0-9]+}}(%i2) { @@ -678,5 +650,5 @@ func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector< // FAST-MEM-16KB: } // FAST-MEM-16KB: } // FAST-MEM-16KB: } -// FAST-MEM-16KB: dma_start %2[%c0, %c0], %arg2 -// FAST-MEM-16KB: dma_wait +// FAST-MEM-16KB: affine.dma_start %2[%c0, %c0], %arg2 +// FAST-MEM-16KB: affine.dma_wait diff --git a/mlir/test/Transforms/loop-fusion-dependence-check.mlir b/mlir/test/Transforms/loop-fusion-dependence-check.mlir index 697ee9eeaa0..4b5c77839fb 100644 --- a/mlir/test/Transforms/loop-fusion-dependence-check.mlir +++ b/mlir/test/Transforms/loop-fusion-dependence-check.mlir @@ -18,17 +18,17 @@ func @cannot_fuse_would_create_cycle() { // Fusing loop nest '%i0' and loop nest '%i2' would create a cycle. affine.for %i0 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 0}} - %v0 = load %a[%i0] : memref<10xf32> - store %cf7, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %cf7, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %a[%i1] : memref<10xf32> - %v1 = load %c[%i1] : memref<10xf32> + affine.store %cf7, %a[%i1] : memref<10xf32> + %v1 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 0}} - %v2 = load %b[%i2] : memref<10xf32> - store %cf7, %c[%i2] : memref<10xf32> + %v2 = affine.load %b[%i2] : memref<10xf32> + affine.store %cf7, %c[%i2] : memref<10xf32> } return } @@ -51,16 +51,16 @@ func @can_fuse_rar_dependence() { // Should fuse: no fusion preventing remarks should be emitted for this test. affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %cf7, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %cf7, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v1 = load %a[%i1] : memref<10xf32> - %v2 = load %c[%i1] : memref<10xf32> + %v1 = affine.load %a[%i1] : memref<10xf32> + %v2 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v3 = load %b[%i2] : memref<10xf32> - store %cf7, %c[%i2] : memref<10xf32> + %v3 = affine.load %b[%i2] : memref<10xf32> + affine.store %cf7, %c[%i2] : memref<10xf32> } return } @@ -84,16 +84,16 @@ func @can_fuse_different_memrefs() { // Should fuse: no fusion preventing remarks should be emitted for this test. 
affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %cf7, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %cf7, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %d[%i1] : memref<10xf32> - %v1 = load %c[%i1] : memref<10xf32> + affine.store %cf7, %d[%i1] : memref<10xf32> + %v1 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v2 = load %b[%i2] : memref<10xf32> - store %cf7, %c[%i2] : memref<10xf32> + %v2 = affine.load %b[%i2] : memref<10xf32> + affine.store %cf7, %c[%i2] : memref<10xf32> } return } @@ -108,16 +108,16 @@ func @should_not_fuse_across_intermediate_store() { affine.for %i0 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} - %v0 = load %0[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> "op0"(%v0) : (f32) -> () } // Should not fuse loop nests '%i0' and '%i1' across top-level store. - store %cf7, %0[%c0] : memref<10xf32> + affine.store %cf7, %0[%c0] : memref<10xf32> affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} - %v1 = load %0[%i1] : memref<10xf32> + %v1 = affine.load %0[%i1] : memref<10xf32> "op1"(%v1) : (f32) -> () } return @@ -133,16 +133,16 @@ func @should_not_fuse_across_intermediate_load() { affine.for %i0 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} - store %cf7, %0[%i0] : memref<10xf32> + affine.store %cf7, %0[%i0] : memref<10xf32> } // Should not fuse loop nests '%i0' and '%i1' across top-level load. - %v0 = load %0[%c0] : memref<10xf32> + %v0 = affine.load %0[%c0] : memref<10xf32> "op0"(%v0) : (f32) -> () affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} - store %cf7, %0[%i1] : memref<10xf32> + affine.store %cf7, %0[%i1] : memref<10xf32> } return @@ -159,12 +159,12 @@ func @should_not_fuse_across_ssa_value_def() { affine.for %i0 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 0}} - %v0 = load %0[%i0] : memref<10xf32> - store %v0, %1[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> + affine.store %v0, %1[%i0] : memref<10xf32> } // Loop nest '%i0' cannot be fused past load from '%1' due to RAW dependence. - %v1 = load %1[%c0] : memref<10xf32> + %v1 = affine.load %1[%c0] : memref<10xf32> "op0"(%v1) : (f32) -> () // Loop nest '%i1' cannot be fused past SSA value def '%c2' which it uses.
@@ -172,7 +172,7 @@ func @should_not_fuse_across_ssa_value_def() { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 0}} - store %cf7, %0[%c2] : memref<10xf32> + affine.store %cf7, %0[%c2] : memref<10xf32> } return @@ -188,18 +188,18 @@ func @should_not_fuse_store_before_load() { affine.for %i0 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 0}} - store %cf7, %0[%i0] : memref<10xf32> - %v0 = load %0[%i0] : memref<10xf32> + affine.store %cf7, %0[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v1 = load %0[%i1] : memref<10xf32> + %v1 = affine.load %0[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 0}} - store %cf7, %0[%i2] : memref<10xf32> - %v2 = load %0[%i2] : memref<10xf32> + affine.store %cf7, %0[%i2] : memref<10xf32> + %v2 = affine.load %0[%i2] : memref<10xf32> } return } @@ -215,14 +215,14 @@ func @should_not_fuse_across_load_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} - store %cf7, %0[%i0, %i1] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %i1] : memref<10x10xf32> } - %v1 = load %0[%i0, %c0] : memref<10x10xf32> + %v1 = affine.load %0[%i0, %c0] : memref<10x10xf32> affine.for %i3 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} - store %cf7, %0[%i0, %i3] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %i3] : memref<10x10xf32> } } return @@ -239,16 +239,16 @@ func @should_not_fuse_across_load_in_loop_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 1}} - store %cf7, %0[%i0, %i1] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %i1] : memref<10x10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %0[%i0, %i2] : memref<10x10xf32> + %v1 = affine.load %0[%i0, %i2] : memref<10x10xf32> } affine.for %i3 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 1}} - store %cf7, %0[%i0, %i3] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %i3] : memref<10x10xf32> } } return @@ -265,14 +265,14 @@ func @should_not_fuse_across_store_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} - %v0 = load %0[%i0, %i1] : memref<10x10xf32> + %v0 = affine.load %0[%i0, %i1] : memref<10x10xf32> } - store %cf7, %0[%i0, %c0] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %c0] : memref<10x10xf32> affine.for %i3 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} - %v1 = load %0[%i0, %i3] : memref<10x10xf32> + %v1 = affine.load %0[%i0, %i3] : memref<10x10xf32> } } return @@ -289,16 +289,16 @@ func @should_not_fuse_across_store_in_loop_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 2 at depth 1}} - %v0 = load %0[%i0, %i1] : memref<10x10xf32> + %v0 = affine.load %0[%i0, %i1] : memref<10x10xf32> } affine.for %i2 = 0 
to 10 { - store %cf7, %0[%i0, %i2] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %i2] : memref<10x10xf32> } affine.for %i3 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 2 into loop nest 0 at depth 1}} - %v1 = load %0[%i0, %i3] : memref<10x10xf32> + %v1 = affine.load %0[%i0, %i3] : memref<10x10xf32> } } return @@ -316,13 +316,13 @@ func @should_not_fuse_across_ssa_value_def_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 0 into loop nest 1 at depth 1}} - %v0 = load %0[%i0, %i1] : memref<10x10xf32> - store %v0, %1[%i0, %i1] : memref<10x10xf32> + %v0 = affine.load %0[%i0, %i1] : memref<10x10xf32> + affine.store %v0, %1[%i0, %i1] : memref<10x10xf32> } // RAW dependence from store in loop nest '%i1' to 'load %1' prevents // fusion loop nest '%i1' into loops after load. - %v1 = load %1[%i0, %c0] : memref<10x10xf32> + %v1 = affine.load %1[%i0, %c0] : memref<10x10xf32> "op0"(%v1) : (f32) -> () // Loop nest '%i2' cannot be fused past SSA value def '%c2' which it uses. @@ -330,7 +330,7 @@ func @should_not_fuse_across_ssa_value_def_at_depth1() { affine.for %i2 = 0 to 10 { // expected-remark@-1 {{block-level dependence preventing fusion of loop nest 1 into loop nest 0 at depth 1}} - store %cf7, %0[%i0, %c2] : memref<10x10xf32> + affine.store %cf7, %0[%i0, %c2] : memref<10x10xf32> } } return diff --git a/mlir/test/Transforms/loop-fusion-slice-computation.mlir b/mlir/test/Transforms/loop-fusion-slice-computation.mlir index 859b750b710..1e5e4486d66 100644 --- a/mlir/test/Transforms/loop-fusion-slice-computation.mlir +++ b/mlir/test/Transforms/loop-fusion-slice-computation.mlir @@ -8,11 +8,11 @@ func @slice_depth1_loop_nest() { %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} - store %cst, %0[%i0] : memref<100xf32> + affine.store %cst, %0[%i0] : memref<100xf32> } affine.for %i1 = 0 to 5 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} - %1 = load %0[%i1] : memref<100xf32> + %1 = affine.load %0[%i1] : memref<100xf32> } return } @@ -29,12 +29,12 @@ func @slice_depth1_loop_nest_with_offsets() { affine.for %i0 = 0 to 16 { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}} %a0 = affine.apply (d0) -> (d0 + 2)(%i0) - store %cst, %0[%a0] : memref<100xf32> + affine.store %cst, %0[%a0] : memref<100xf32> } affine.for %i1 = 4 to 8 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0 - 3), (d0) -> (d0 - 2)] )}} %a1 = affine.apply (d0) -> (d0 - 1)(%i1) - %1 = load %0[%a1] : memref<100xf32> + %1 = affine.load %0[%a1] : memref<100xf32> } return } @@ -51,14 +51,14 @@ func @slice_depth2_loop_nest() { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} affine.for %i1 = 0 to 16 { - store %cst, %0[%i0, %i1] : memref<100x100xf32> + affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } } 
affine.for %i2 = 0 to 10 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} affine.for %i3 = 0 to 8 { - %1 = load %0[%i2, %i3] : memref<100x100xf32> + %1 = affine.load %0[%i2, %i3] : memref<100x100xf32> } } return @@ -78,15 +78,15 @@ func @slice_depth2_loop_nest_two_loads() { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1)[s0] -> (d0), (d0, d1)[s0] -> (d0 + 1)] [(d0, d1)[s0] -> (0), (d0, d1)[s0] -> (8)] )}} affine.for %i1 = 0 to 16 { - store %cst, %0[%i0, %i1] : memref<100x100xf32> + affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } } affine.for %i2 = 0 to 10 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} affine.for %i3 = 0 to 8 { - %1 = load %0[%i2, %i3] : memref<100x100xf32> + %1 = affine.load %0[%i2, %i3] : memref<100x100xf32> } - %2 = load %0[%i2, %c0] : memref<100x100xf32> + %2 = affine.load %0[%i2, %c0] : memref<100x100xf32> } return } @@ -105,15 +105,15 @@ func @slice_depth2_loop_nest_two_stores() { affine.for %i0 = 0 to 16 { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (8)] )}} affine.for %i1 = 0 to 16 { - store %cst, %0[%i0, %i1] : memref<100x100xf32> + affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } - store %cst, %0[%i0, %c0] : memref<100x100xf32> + affine.store %cst, %0[%i0, %c0] : memref<100x100xf32> } affine.for %i2 = 0 to 10 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0)[s0] -> (d0), (d0)[s0] -> (d0 + 1)] [(d0)[s0] -> (0), (d0)[s0] -> (16)] )}} // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1)[s0] -> (d0), (d0, d1)[s0] -> (d0 + 1)] [(d0, d1)[s0] -> (0), (d0, d1)[s0] -> (16)] )}} affine.for %i3 = 0 to 8 { - %1 = load %0[%i2, %i3] : memref<100x100xf32> + %1 = affine.load %0[%i2, %i3] : memref<100x100xf32> } } return @@ -131,14 +131,14 @@ func @slice_loop_nest_with_smaller_outer_trip_count() { // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} affine.for %i1 = 0 to 16 { - store %cst, %0[%i0, %i1] : memref<100x100xf32> + affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } } affine.for %i2 = 0 to 8 { // expected-remark@-1 {{slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} // expected-remark@-2 {{slice ( src loop: 0, dst loop: 1, depth: 2 : insert point: (2, 0) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, 
d1) -> (d1 + 1)] )}} affine.for %i3 = 0 to 10 { - %1 = load %0[%i2, %i3] : memref<100x100xf32> + %1 = affine.load %0[%i2, %i3] : memref<100x100xf32> } } return diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 84b953d2bb6..a8caff40bb7 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -10,24 +10,20 @@ // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_raw_dep_for_locality() { func @should_fuse_raw_dep_for_locality() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %1 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -35,8 +31,6 @@ func @should_fuse_raw_dep_for_locality() { // ----- -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_reduction_to_pointwise() { func @should_fuse_reduction_to_pointwise() { %a = alloc() : memref<10x10xf32> @@ -47,31 +41,28 @@ func @should_fuse_reduction_to_pointwise() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { - %v0 = load %b[%i0] : memref<10xf32> - %v1 = load %a[%i0, %i1] : memref<10x10xf32> + %v0 = affine.load %b[%i0] : memref<10xf32> + %v1 = affine.load %a[%i0, %i1] : memref<10x10xf32> %v3 = addf %v0, %v1 : f32 - store %v3, %b[%i0] : memref<10xf32> + affine.store %v3, %b[%i0] : memref<10xf32> } } affine.for %i2 = 0 to 10 { - %v4 = load %b[%i2] : memref<10xf32> - store %v4, %c[%i2] : memref<10xf32> + %v4 = affine.load %b[%i2] : memref<10xf32> + affine.store %v4, %c[%i2] : memref<10xf32> } // Should fuse in entire inner loop on %i1 from source loop nest, as %i1 // is not used in the access function of the store/load on %b. 
// CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> - // CHECK-NEXT: %5 = load %1[%i0, %i1] : memref<10x10xf32> - // CHECK-NEXT: %6 = addf %4, %5 : f32 - // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %6, %0[%7] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: %4 = affine.load %1[%i0, %i1] : memref<10x10xf32> + // CHECK-NEXT: %5 = addf %3, %4 : f32 + // CHECK-NEXT: affine.store %5, %0[0] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: %8 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %9 = load %0[%8] : memref<1xf32> - // CHECK-NEXT: store %9, %2[%i0] : memref<10xf32> + // CHECK-NEXT: %6 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %6, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -81,8 +72,6 @@ func @should_fuse_reduction_to_pointwise() { // CHECK-DAG: [[MAP_SHIFT_MINUS_ONE_R1:#map[0-9]+]] = (d0) -> (d0 - 1) // CHECK-DAG: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0) -> (d0 + 1) -// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2_EVEN:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) -// CHECK-DAG: [[MAP_SHIFT_MINUS_IV_R2_ODD:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) // CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() { func @should_fuse_loop_nests_with_shifts() { @@ -93,12 +82,12 @@ func @should_fuse_loop_nests_with_shifts() { affine.for %i1 = 0 to 9 { %idx = affine.apply (d0) -> (d0 + 1) (%i0) %idy = affine.apply (d0) -> (d0 + 1) (%i1) - store %cf7, %a[%idx, %idy] : memref<10x10xf32> + affine.store %cf7, %a[%idx, %idy] : memref<10x10xf32> } } affine.for %i2 = 1 to 10 { affine.for %i3 = 1 to 10 { - %v0 = load %a[%i2, %i3] : memref<10x10xf32> + %v0 = affine.load %a[%i2, %i3] : memref<10x10xf32> } } @@ -116,12 +105,8 @@ func @should_fuse_loop_nests_with_shifts() { // CHECK-NEXT: %2 = affine.apply [[MAP_SHIFT_MINUS_ONE_R1]](%i1) // CHECK-NEXT: %3 = affine.apply [[MAP_SHIFT_BY_ONE]](%1) // CHECK-NEXT: %4 = affine.apply [[MAP_SHIFT_BY_ONE]](%2) - // CHECK-NEXT: %5 = affine.apply [[MAP_SHIFT_MINUS_IV_R2_EVEN]](%i0, %i1, %3, %4) - // CHECK-NEXT: %6 = affine.apply [[MAP_SHIFT_MINUS_IV_R2_ODD]](%i0, %i1, %3, %4) - // CHECK-NEXT: store %cst, %0[%5, %6] : memref<1x1xf32> - // CHECK-NEXT: %7 = affine.apply [[MAP_SHIFT_MINUS_IV_R2_EVEN]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %8 = affine.apply [[MAP_SHIFT_MINUS_IV_R2_ODD]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %9 = load %0[%7, %8] : memref<1x1xf32> + // CHECK-NEXT: affine.store %cst, %0[0, 0] : memref<1x1xf32> + // CHECK-NEXT: %5 = affine.load %0[0, 0] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -130,9 +115,6 @@ func @should_fuse_loop_nests_with_shifts() { // ----- -// CHECK-DAG: [[MAP_D2_D0_DIFF:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) -// CHECK-DAG: [[MAP_D3_D1_DIFF:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) - // CHECK-LABEL: func @should_fuse_loop_nest() { func @should_fuse_loop_nest() { %a = alloc() : memref<10x10xf32> @@ -141,18 +123,18 @@ func @should_fuse_loop_nest() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { - store %cf7, %a[%i0, %i1] : memref<10x10xf32> + affine.store %cf7, %a[%i0, %i1] : memref<10x10xf32> } } affine.for %i2 = 0 to 10 { affine.for %i3 = 0 to 10 { - %v0 = load %a[%i3, %i2] : memref<10x10xf32> - store %v0, %b[%i2, %i3] : memref<10x10xf32> + %v0 = affine.load %a[%i3, %i2] : memref<10x10xf32> + affine.store %v0, %b[%i2, %i3] : 
memref<10x10xf32> } } affine.for %i4 = 0 to 10 { affine.for %i5 = 0 to 10 { - %v1 = load %b[%i4, %i5] : memref<10x10xf32> + %v1 = affine.load %b[%i4, %i5] : memref<10x10xf32> } } // Expecting private memref for '%a' first, then private memref for '%b'. @@ -160,18 +142,10 @@ func @should_fuse_loop_nest() { // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1x1xf32> // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine.apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) - // CHECK-NEXT: %3 = affine.apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) - // CHECK-NEXT: store %cst, [[NEWA]][%2, %3] : memref<1x1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP_D2_D0_DIFF]](%i1, %i0, %i1, %i0) - // CHECK-NEXT: %5 = affine.apply [[MAP_D3_D1_DIFF]](%i1, %i0, %i1, %i0) - // CHECK-NEXT: %6 = load [[NEWA]][%4, %5] : memref<1x1xf32> - // CHECK-NEXT: %7 = affine.apply [[MAP_D2_D0_DIFF]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %8 = affine.apply [[MAP_D3_D1_DIFF]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: store %6, [[NEWB]][%7, %8] : memref<1x1xf32> - // CHECK-NEXT: %9 = affine.apply [[MAP_D2_D0_DIFF]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %10 = affine.apply [[MAP_D3_D1_DIFF]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %11 = load [[NEWB]][%9, %10] : memref<1x1xf32> + // CHECK-NEXT: affine.store %cst, [[NEWA]][0, 0] : memref<1x1xf32> + // CHECK-NEXT: %2 = affine.load [[NEWA]][0, 0] : memref<1x1xf32> + // CHECK-NEXT: affine.store %2, [[NEWB]][0, 0] : memref<1x1xf32> + // CHECK-NEXT: %3 = affine.load [[NEWB]][0, 0] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -180,8 +154,6 @@ func @should_fuse_loop_nest() { // ----- -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_across_intermediate_loop_with_no_deps() { func @should_fuse_across_intermediate_loop_with_no_deps() { %a = alloc() : memref<10xf32> @@ -191,27 +163,25 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %v0, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %v0, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %c[%i1] : memref<10xf32> + affine.store %cf7, %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %b[%i2] : memref<10xf32> + %v1 = affine.load %b[%i2] : memref<10xf32> } // Should fuse first loop (past second loop with no dependences) into third. // Note that fusion creates a private memref '%2' for the fused loop nest. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> - // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %1[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %3, %0[0] : memref<1xf32> + // CHECK-NEXT: %4 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -219,8 +189,6 @@ func @should_fuse_across_intermediate_loop_with_no_deps() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_all_loops() { func @should_fuse_all_loops() { %a = alloc() : memref<10xf32> @@ -229,14 +197,14 @@ func @should_fuse_all_loops() { // Set up flow dependences from first and second loops to third. affine.for %i0 = 0 to 10 { - store %cf7, %a[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %b[%i1] : memref<10xf32> + affine.store %cf7, %b[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v0 = load %a[%i2] : memref<10xf32> - %v1 = load %b[%i2] : memref<10xf32> + %v0 = affine.load %a[%i2] : memref<10xf32> + %v1 = affine.load %b[%i2] : memref<10xf32> } // Should fuse first and second loops into third. @@ -244,14 +212,10 @@ func @should_fuse_all_loops() { // CHECK-DAG: [[NEWA:%[0-9]+]] = alloc() : memref<1xf32> // CHECK-DAG: [[NEWB:%[0-9]+]] = alloc() : memref<1xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, [[NEWA]][%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, [[NEWB]][%3] : memref<1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %5 = load [[NEWA]][%4] : memref<1xf32> - // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %7 = load [[NEWB]][%6] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, [[NEWA]][0] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, [[NEWB]][0] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load [[NEWA]][0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load [[NEWB]][0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -259,8 +223,6 @@ func @should_fuse_all_loops() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_first_and_second_loops() { func @should_fuse_first_and_second_loops() { %a = alloc() : memref<10xf32> @@ -270,27 +232,25 @@ func @should_fuse_first_and_second_loops() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %a[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %a[%i1] : memref<10xf32> - store %cf7, %b[%i1] : memref<10xf32> + %v0 = affine.load %a[%i1] : memref<10xf32> + affine.store %cf7, %b[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %c[%i2] : memref<10xf32> + %v1 = affine.load %c[%i2] : memref<10xf32> } // Should fuse first loop into the second (last loop should not be fused). // Should create private memref '%2' for fused loop. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%3] : memref<1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %6 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -312,29 +272,29 @@ func @should_not_fuse_would_create_cycle() { // 2) loop0 -> loop2 on memref '%b' // 3) loop1 -> loop2 on memref '%c' affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %cf7, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %cf7, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %a[%i1] : memref<10xf32> - %v1 = load %c[%i1] : memref<10xf32> + affine.store %cf7, %a[%i1] : memref<10xf32> + %v1 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v2 = load %b[%i2] : memref<10xf32> - store %cf7, %c[%i2] : memref<10xf32> + %v2 = affine.load %b[%i2] : memref<10xf32> + affine.store %cf7, %c[%i2] : memref<10xf32> } // Should not fuse: fusing the first loop into the last would create a cycle. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> - // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> - // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32> + // CHECK-NEXT: %5 = affine.load %1[%i2] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %2[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -342,21 +302,19 @@ func @should_not_fuse_would_create_cycle() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_producer_consumer() { func @should_fuse_producer_consumer() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %m[%i1] : memref<10xf32> + affine.store %cf7, %m[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %m[%i2] : memref<10xf32> + %v1 = affine.load %m[%i2] : memref<10xf32> } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and // %i1, but OK to fuse %i1 into %i2.
@@ -365,13 +323,11 @@ func @should_fuse_producer_consumer() { // CHECK: %0 = alloc() : memref<1xf32> // CHECK: %1 = alloc() : memref<10xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -379,8 +335,6 @@ func @should_fuse_producer_consumer() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_and_move_to_preserve_war_dep() { func @should_fuse_and_move_to_preserve_war_dep() { %a = alloc() : memref<10xf32> @@ -388,27 +342,25 @@ func @should_fuse_and_move_to_preserve_war_dep() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %v0, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %v0, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %a[%i1] : memref<10xf32> + affine.store %cf7, %a[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %b[%i2] : memref<10xf32> + %v1 = affine.load %b[%i2] : memref<10xf32> } // Loops '%i1' and '%i2' have no dependences. We can fuse a slice of '%i0' // into '%i2' if we move the fused loop nest before '%i1', which preserves // the WAR dependence from load '%a' in '%i0' to the store '%a' in loop '%i1'. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %2, %0[%3] : memref<1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %2, %0[0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -416,55 +368,47 @@ func @should_fuse_and_move_to_preserve_war_dep() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() { func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index - %v1 = load %m[%c0] : memref<10xf32> + %v1 = affine.load %m[%c0] : memref<10xf32> // Top-level load to '%m' should prevent fusion. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } return } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_no_top_level_access() { func @should_fuse_no_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %1 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -480,20 +424,20 @@ func @should_not_fuse_if_inst_at_top_level() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } %c0 = constant 4 : index affine.if #set0(%c0) { } // Top-level IfOp should prevent fusion. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %1 = affine.load %0[%i1] : memref<10xf32> // CHECK-NEXT: } return } @@ -509,32 +453,28 @@ func @should_not_fuse_if_inst_in_loop_nest() { %c4 = constant 4 : index affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { affine.if #set0(%c4) { } - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } // IfOp in ForInst should prevent fusion. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%c4) { // CHECK-NEXT: } - // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %1 = affine.load %0[%i1] : memref<10xf32> // CHECK-NEXT: } return } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d0 + d3) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d1 + d4) -// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5) -> (-d2 + d5) - // CHECK-LABEL: func @permute_and_fuse() { func @permute_and_fuse() { %m = alloc() : memref<10x20x30xf32> @@ -543,14 +483,14 @@ func @permute_and_fuse() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 20 { affine.for %i2 = 0 to 30 { - store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> + affine.store %cf7, %m[%i0, %i1, %i2] : memref<10x20x30xf32> } } } affine.for %i3 = 0 to 30 { affine.for %i4 = 0 to 10 { affine.for %i5 = 0 to 20 { - %v0 = load %m[%i4, %i5, %i3] : memref<10x20x30xf32> + %v0 = affine.load %m[%i4, %i5, %i3] : memref<10x20x30xf32> "foo"(%v0) : (f32) -> () } } @@ -558,15 +498,9 @@ func @permute_and_fuse() { // CHECK: affine.for %i0 = 0 to 30 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.for %i2 = 0 to 20 { -// CHECK-NEXT: %1 = affine.apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: %2 = affine.apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: %3 = affine.apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: store %cst, %0[%1, %2, %3] : memref<1x1x1xf32> -// CHECK-NEXT: %4 = affine.apply [[MAP0]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: %5 = affine.apply [[MAP1]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: %6 = affine.apply [[MAP2]](%i1, %i2, %i0, %i1, %i2, %i0) -// CHECK-NEXT: %7 = load %0[%4, %5, %6] : memref<1x1x1xf32> -// CHECK-NEXT: "foo"(%7) : (f32) -> () +// CHECK-NEXT: affine.store %cst, %0[0, 0, 0] : memref<1x1x1xf32> +// CHECK-NEXT: %1 = affine.load %0[0, 0, 0] : memref<1x1x1xf32> +// CHECK-NEXT: "foo"(%1) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -587,15 +521,15 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) { %out = alloc() : memref<16x4xf32> affine.for %i0 = 0 to 64 { - %v = load %in[%i0] : memref<64xf32> + %v = affine.load %in[%i0] : memref<64xf32> %idx = affine.apply (d0) -> (d0 floordiv 4) (%i0) %idy = affine.apply (d0) -> (d0 mod 4) (%i0) - store %v, %out[%idx, %idy] : memref<16x4xf32> + affine.store %v, %out[%idx, %idy] : memref<16x4xf32> } affine.for %i1 = 0 to 16 { affine.for %i2 = 0 to 4 { - %w = load %out[%i1, %i2] : memref<16x4xf32> + %w = affine.load %out[%i1, %i2] : memref<16x4xf32> "foo"(%w) : (f32) -> () } } @@ -612,7 +546,6 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) { // CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 floordiv 4) // CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0) -> (d0 mod 4) // CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1) -// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // Reshape a 16x4xf32 to 64xf32. 
// CHECK-LABEL: func @fuse_reshape_16_4_64 @@ -622,26 +555,24 @@ func @fuse_reshape_16_4_64() { affine.for %i0 = 0 to 16 { affine.for %i1 = 0 to 4 { - %v = load %in[%i0, %i1] : memref<16x4xf32> + %v = affine.load %in[%i0, %i1] : memref<16x4xf32> %idx = affine.apply (d0, d1) -> (4*d0 + d1) (%i0, %i1) - store %v, %out[%idx] : memref<64xf32> + affine.store %v, %out[%idx] : memref<64xf32> } } affine.for %i2 = 0 to 64 { - %w = load %out[%i2] : memref<64xf32> + %w = affine.load %out[%i2] : memref<64xf32> "foo"(%w) : (f32) -> () } // CHECK: affine.for %i0 = 0 to 64 { // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0) // CHECK-NEXT: %3 = affine.apply [[MAP1]](%i0) -// CHECK-NEXT: %4 = load %1[%2, %3] : memref<16x4xf32> +// CHECK-NEXT: %4 = affine.load %1[%2, %3] : memref<16x4xf32> // CHECK-NEXT: %5 = affine.apply [[MAP2]](%2, %3) -// CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %5) -// CHECK-NEXT: store %4, %0[%6] : memref<1xf32> -// CHECK-NEXT: %7 = affine.apply [[MAP3]](%i0, %i0) -// CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> -// CHECK-NEXT: "foo"(%8) : (f32) -> () +// CHECK-NEXT: affine.store %4, %0[0] : memref<1xf32> +// CHECK-NEXT: %6 = affine.load %0[0] : memref<1xf32> +// CHECK-NEXT: "foo"(%6) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: return return @@ -665,7 +596,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { affine.for %i4 = 0 to 16 { affine.for %i5 = 0 to 1 { %val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32 - store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> + affine.store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> } } } @@ -683,16 +614,16 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { %3 = affine.apply (d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv (16 * 1))(%a0) %4 = affine.apply (d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16)(%a0) %5 = affine.apply (d0) -> (((((d0 mod 144) mod 144) mod 48) mod 16) mod 1)(%a0) - %v = load %in[%0, %1, %2, %3, %4, %5] : memref<2x2x3x3x16x1xi32> - store %v, %out[%ii, %jj] : memref<64x9xi32> + %v = affine.load %in[%0, %1, %2, %3, %4, %5] : memref<2x2x3x3x16x1xi32> + affine.store %v, %out[%ii, %jj] : memref<64x9xi32> } } affine.for %i = 0 to 64 { affine.for %j = 0 to 9 { - %a = load %out[%i, %j] : memref<64x9xi32> + %a = affine.load %out[%i, %j] : memref<64x9xi32> %b = muli %a, %a : i32 - store %b, %live_out[%i, %j] : memref<64x9xi32> + affine.store %b, %live_out[%i, %j] : memref<64x9xi32> } } return %live_out : memref<64x9xi32> @@ -705,12 +636,6 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1) -> ((((d0 * 9 + d1) mod 288) mod 144) floordiv 48) // CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) floordiv 16) // CHECK-DAG: [[MAP4:#map[0-9]+]] = (d0, d1) -> (((((d0 * 9 + d1) mod 288) mod 144) mod 48) mod 16) -// CHECK-DAG: [[MAP5:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2 - (d0 * 9 + d1) floordiv 288) -// CHECK-DAG: [[MAP6:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d3) -// CHECK-DAG: [[MAP7:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d4) -// CHECK-DAG: [[MAP8:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d5) -// CHECK-DAG: [[MAP9:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d6) -// CHECK-DAG: [[MAP10:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7) -> (d7) // CHECK-DAG: [[MAP11:#map[0-9]+]] = (d0, d1) -> (d0 * 9 + d1) // CHECK-DAG: [[MAP12:#map[0-9]+]] = (d0) -> (d0 floordiv 288) // CHECK-DAG: 
[[MAP13:#map[0-9]+]] = (d0) -> ((d0 mod 288) floordiv 144) @@ -718,8 +643,7 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-DAG: [[MAP15:#map[0-9]+]] = (d0) -> ((((d0 mod 288) mod 144) mod 48) floordiv 16) // CHECK-DAG: [[MAP16:#map[0-9]+]] = (d0) -> ((((d0 mod 288) mod 144) mod 48) mod 16) // CHECK-DAG: [[MAP17:#map[0-9]+]] = (d0) -> (0) -// CHECK-DAG: [[MAP18:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) -// CHECK-DAG: [[MAP19:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) + // // CHECK-LABEL: func @R6_to_R2_reshape // CHECK: %0 = alloc() : memref<1x2x3x3x16x1xi32> @@ -733,35 +657,19 @@ func @R6_to_R2_reshape_square() -> memref<64x9xi32> { // CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1) // CHECK-NEXT: %7 = affine.apply [[MAP4]](%i0, %i1) // CHECK-NEXT: %8 = "foo"(%3, %4, %5, %6, %7, %c0) : (index, index, index, index, index, index) -> i32 -// CHECK-NEXT: %9 = affine.apply [[MAP5]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: %10 = affine.apply [[MAP6]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: %11 = affine.apply [[MAP7]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: %12 = affine.apply [[MAP8]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: %13 = affine.apply [[MAP9]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: %14 = affine.apply [[MAP10]](%i0, %i1, %3, %4, %5, %6, %7, %c0) -// CHECK-NEXT: store %8, %0[%9, %10, %11, %12, %13, %14] : memref<1x2x3x3x16x1xi32> -// CHECK-NEXT: %15 = affine.apply [[MAP11]](%i0, %i1) -// CHECK-NEXT: %16 = affine.apply [[MAP12]](%15) -// CHECK-NEXT: %17 = affine.apply [[MAP13]](%15) -// CHECK-NEXT: %18 = affine.apply [[MAP14]](%15) -// CHECK-NEXT: %19 = affine.apply [[MAP15]](%15) -// CHECK-NEXT: %20 = affine.apply [[MAP16]](%15) -// CHECK-NEXT: %21 = affine.apply [[MAP17]](%15) -// CHECK-NEXT: %22 = affine.apply [[MAP5]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %23 = affine.apply [[MAP6]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %24 = affine.apply [[MAP7]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %25 = affine.apply [[MAP8]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %26 = affine.apply [[MAP9]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %27 = affine.apply [[MAP10]](%i0, %i1, %16, %17, %18, %19, %20, %21) -// CHECK-NEXT: %28 = load %0[%22, %23, %24, %25, %26, %27] : memref<1x2x3x3x16x1xi32> -// CHECK-NEXT: %29 = affine.apply [[MAP18]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %30 = affine.apply [[MAP19]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: store %28, %1[%29, %30] : memref<1x1xi32> -// CHECK-NEXT: %31 = affine.apply [[MAP18]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %32 = affine.apply [[MAP19]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %33 = load %1[%31, %32] : memref<1x1xi32> -// CHECK-NEXT: %34 = muli %33, %33 : i32 -// CHECK-NEXT: store %34, %2[%i0, %i1] : memref<64x9xi32> +// CHECK-NEXT: affine.store %8, %0[0, ((%i0 * 9 + %i1) mod 288) floordiv 144, (((%i0 * 9 + %i1) mod 288) mod 144) floordiv 48, ((((%i0 * 9 + %i1) mod 288) mod 144) mod 48) floordiv 16, ((((%i0 * 9 + %i1) mod 288) mod 144) mod 48) mod 16, symbol(%c0)] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: %9 = affine.apply [[MAP11]](%i0, %i1) +// CHECK-NEXT: %10 = affine.apply [[MAP12]](%9) +// CHECK-NEXT: %11 = affine.apply [[MAP13]](%9) +// CHECK-NEXT: %12 = affine.apply [[MAP14]](%9) +// CHECK-NEXT: %13 = affine.apply [[MAP15]](%9) +// CHECK-NEXT: %14 = affine.apply [[MAP16]](%9) +// CHECK-NEXT: %15 = affine.apply [[MAP17]](%9) +// CHECK-NEXT: %16 = affine.load %0[0, ((%i0 * 9 + %i1) mod 288) 
floordiv 144, (((%i0 * 9 + %i1) mod 288) mod 144) floordiv 48, ((((%i0 * 9 + %i1) mod 288) mod 144) mod 48) floordiv 16, ((((%i0 * 9 + %i1) mod 288) mod 144) mod 48) mod 16, 0] : memref<1x2x3x3x16x1xi32> +// CHECK-NEXT: affine.store %16, %1[0, 0] : memref<1x1xi32> +// CHECK-NEXT: %17 = affine.load %1[0, 0] : memref<1x1xi32> +// CHECK-NEXT: %18 = muli %17, %17 : i32 +// CHECK-NEXT: affine.store %18, %2[%i0, %i1] : memref<64x9xi32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return %2 : memref<64x9xi32> @@ -778,14 +686,14 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { affine.for %i0 = 0 to %M { affine.for %i1 = 0 to (d0) -> (d0 + 5) (%N) { - store %c0, %m[%i0, %i1] : memref + affine.store %c0, %m[%i0, %i1] : memref } } affine.for %i2 = 0 to %M { affine.for %i3 = 0 to %N { %idy = affine.apply (d0)[s0] -> (d0 + s0) (%i3)[%s] - %v = load %m[%i2, %idy] : memref + %v = affine.load %m[%i2, %idy] : memref } } @@ -793,7 +701,6 @@ func @fuse_symbolic_bounds(%M : index, %N : index) { } // ----- -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_reduction_at_depth1 func @should_fuse_reduction_at_depth1() { @@ -802,18 +709,18 @@ func @should_fuse_reduction_at_depth1() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 100 { - %v0 = load %b[%i0] : memref<10xf32> - %v1 = load %a[%i0, %i1] : memref<10x100xf32> + %v0 = affine.load %b[%i0] : memref<10xf32> + %v1 = affine.load %a[%i0, %i1] : memref<10x100xf32> %v2 = "maxf"(%v0, %v1) : (f32, f32) -> f32 - store %v2, %b[%i0] : memref<10xf32> + affine.store %v2, %b[%i0] : memref<10xf32> } } affine.for %i2 = 0 to 10 { affine.for %i3 = 0 to 100 { - %v3 = load %b[%i2] : memref<10xf32> - %v4 = load %a[%i2, %i3] : memref<10x100xf32> + %v3 = affine.load %b[%i2] : memref<10xf32> + %v4 = affine.load %a[%i2, %i3] : memref<10x100xf32> %v5 = subf %v4, %v3 : f32 - store %v5, %b[%i2] : memref<10xf32> + affine.store %v5, %b[%i2] : memref<10xf32> } } // This test should fuse the src reduction loop at depth 1 in the destination @@ -822,20 +729,16 @@ func @should_fuse_reduction_at_depth1() { // memory space. 
// CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 100 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> - // CHECK-NEXT: %4 = load %1[%i0, %i1] : memref<10x100xf32> - // CHECK-NEXT: %5 = "maxf"(%3, %4) : (f32, f32) -> f32 - // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %5, %0[%6] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %1[%i0, %i1] : memref<10x100xf32> + // CHECK-NEXT: %4 = "maxf"(%2, %3) : (f32, f32) -> f32 + // CHECK-NEXT: affine.store %4, %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 100 { - // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> - // CHECK-NEXT: %9 = load %1[%i0, %i2] : memref<10x100xf32> - // CHECK-NEXT: %10 = subf %9, %8 : f32 - // CHECK-NEXT: %11 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %10, %0[%11] : memref<1xf32> + // CHECK-NEXT: %5 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: %6 = affine.load %1[%i0, %i2] : memref<10x100xf32> + // CHECK-NEXT: %7 = subf %6, %5 : f32 + // CHECK-NEXT: affine.store %7, %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -843,8 +746,6 @@ func @should_fuse_reduction_at_depth1() { } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d1) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d2) // CHECK-LABEL: func @should_fuse_at_src_depth1_and_dst_depth1 func @should_fuse_at_src_depth1_and_dst_depth1() { @@ -853,18 +754,18 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { affine.for %i0 = 0 to 100 { affine.for %i1 = 0 to 16 { - %v0 = load %a[%i0, %i1] : memref<100x16xf32> + %v0 = affine.load %a[%i0, %i1] : memref<100x16xf32> "op0"(%v0) : (f32) -> () } affine.for %i2 = 0 to 16 { %v1 = "op1"() : () -> (f32) - store %v1, %b[%i0, %i2] : memref<100x16xf32> + affine.store %v1, %b[%i0, %i2] : memref<100x16xf32> } } affine.for %i3 = 0 to 100 { affine.for %i4 = 0 to 16 { - %v2 = load %b[%i3, %i4] : memref<100x16xf32> + %v2 = affine.load %b[%i3, %i4] : memref<100x16xf32> "op2"(%v2) : (f32) -> () } } @@ -875,20 +776,16 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // at depth 1 and the slice should be inserted at depth 1. 
// CHECK: affine.for %i0 = 0 to 100 { // CHECK-NEXT: affine.for %i1 = 0 to 16 { - // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<100x16xf32> + // CHECK-NEXT: %2 = affine.load %1[%i0, %i1] : memref<100x16xf32> // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 16 { // CHECK-NEXT: %3 = "op1"() : () -> f32 - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0, %i2) - // CHECK-NEXT: %5 = affine.apply [[MAP1]](%i0, %i0, %i2) - // CHECK-NEXT: store %3, %0[%4, %5] : memref<1x16xf32> + // CHECK-NEXT: affine.store %3, %0[0, %i2] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 16 { - // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i0, %i0, %i3) - // CHECK-NEXT: %7 = affine.apply [[MAP1]](%i0, %i0, %i3) - // CHECK-NEXT: %8 = load %0[%6, %7] : memref<1x16xf32> - // CHECK-NEXT: "op2"(%8) : (f32) -> () + // CHECK-NEXT: %4 = affine.load %0[0, %i3] : memref<1x16xf32> + // CHECK-NEXT: "op2"(%4) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -897,7 +794,6 @@ func @should_fuse_at_src_depth1_and_dst_depth1() { // ----- // CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0 * 10 + d1) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0 * -10 - d1 + d2) // CHECK-LABEL: func @should_fuse_src_depth1_at_dst_depth2 func @should_fuse_src_depth1_at_dst_depth2() { @@ -905,13 +801,13 @@ func @should_fuse_src_depth1_at_dst_depth2() { %c0 = constant 0.0 : f32 affine.for %i0 = 0 to 100 { - store %c0, %a[%i0] : memref<100xf32> + affine.store %c0, %a[%i0] : memref<100xf32> } affine.for %i1 = 0 to 10 { affine.for %i2 = 0 to 10 { %a0 = affine.apply (d0, d1) -> (d0 * 10 + d1) (%i1, %i2) - %v0 = load %a[%a0] : memref<100xf32> + %v0 = affine.load %a[%a0] : memref<100xf32> } } // The source loop nest slice loop bound is a function of both destination @@ -919,11 +815,9 @@ func @should_fuse_src_depth1_at_dst_depth2() { // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) - // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %1) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1) - // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %3) - // CHECK-NEXT: %5 = load %0[%4] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i1) + // CHECK-NEXT: %3 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -938,17 +832,17 @@ func @fusion_at_depth0_not_currently_supported() { %c0 = constant 0 : index %cst = constant 0.000000e+00 : f32 affine.for %i0 = 0 to 10 { - store %cst, %0[%i0] : memref<10xf32> + affine.store %cst, %0[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %1 = load %0[%c0] : memref<10xf32> + %1 = affine.load %0[%c0] : memref<10xf32> } // NOTE: Should shrink memref size to 1 element access by load in dst loop // nest, and make the store in the slice store to the same element. 
// CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%c0] : memref<1xf32> - // CHECK-NEXT: %1 = load %0[%c0_0] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[symbol(%c0)] : memref<1xf32> + // CHECK-NEXT: %1 = affine.load %0[symbol(%c0_0)] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -956,13 +850,6 @@ func @fusion_at_depth0_not_currently_supported() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d0 + d4) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d1 + d5) -// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d2 + d6) -// CHECK: [[MAP3:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (-d3 + d7) -// CHECK: [[MAP4:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d8) -// CHECK: [[MAP5:#map[0-9]+]] = (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d9) - // CHECK-LABEL: func @should_fuse_deep_loop_nests func @should_fuse_deep_loop_nests() { %0 = alloc() : memref<2x2x3x3x16x10xf32, 2> @@ -978,13 +865,13 @@ func @should_fuse_deep_loop_nests() { affine.for %i3 = 0 to 3 { affine.for %i4 = 0 to 16 { affine.for %i5 = 0 to 10 { - %3 = load %0[%i0, %i1, %i2, %i3, %i4, %i5] + %3 = affine.load %0[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> } } affine.for %i6 = 0 to 16 { affine.for %i7 = 0 to 10 { - store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] + affine.store %cst, %1[%i0, %i1, %i2, %i3, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> } } @@ -1002,13 +889,13 @@ func @should_fuse_deep_loop_nests() { affine.for %i15 = 0 to 2 { affine.for %i16 = 0 to 16 { affine.for %i17 = 0 to 10 { - %5 = load %0[%i14, %i15, %i12, %i13, %i16, %i17] + %5 = affine.load %0[%i14, %i15, %i12, %i13, %i16, %i17] : memref<2x2x3x3x16x10xf32, 2> } } affine.for %i18 = 0 to 16 { affine.for %i19 = 0 to 10 { - %6 = load %1[%i10, %i11, %i8, %i9, %i18, %i19] + %6 = affine.load %1[%i10, %i11, %i8, %i9, %i18, %i19] : memref<2x2x3x3x16x10xf32, 2> } } @@ -1033,36 +920,24 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: affine.for %i5 = 0 to 3 { // CHECK-NEXT: affine.for %i6 = 0 to 16 { // CHECK-NEXT: affine.for %i7 = 0 to 10 { -// CHECK-NEXT: %3 = load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %3 = affine.load %1[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.for %i8 = 0 to 16 { // CHECK-NEXT: affine.for %i9 = 0 to 10 { -// CHECK-NEXT: %4 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: %5 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: %6 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: %7 = affine.apply [[MAP3]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: %8 = affine.apply [[MAP4]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: %9 = affine.apply [[MAP5]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) -// CHECK-NEXT: store %cst, %0[%4, %5, %6, %7, %8, %9] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: affine.store %cst, %0[0, 0, 0, 0, %i8, %i9] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.for %i10 = 0 to 2 { // CHECK-NEXT: affine.for %i11 = 0 to 2 { // CHECK-NEXT: affine.for %i12 = 0 to 16 { // CHECK-NEXT: affine.for %i13 = 0 to 10 { -// CHECK-NEXT: %10 = load %1[%i10, %i11, %i4, %i5, %i12, %i13] : 
memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %4 = affine.load %1[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.for %i14 = 0 to 16 { // CHECK-NEXT: affine.for %i15 = 0 to 10 { -// CHECK-NEXT: %11 = affine.apply [[MAP0]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %12 = affine.apply [[MAP1]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %13 = affine.apply [[MAP2]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %14 = affine.apply [[MAP3]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %15 = affine.apply [[MAP4]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %16 = affine.apply [[MAP5]](%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i14, %i15) -// CHECK-NEXT: %17 = load %0[%11, %12, %13, %14, %15, %16] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: %5 = affine.load %0[0, 0, 0, 0, %i14, %i15] : memref<1x1x1x1x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1078,8 +953,6 @@ func @should_fuse_deep_loop_nests() { } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d1) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d2) // CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count func @should_fuse_at_depth1_and_reduce_slice_trip_count() { @@ -1091,16 +964,16 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { affine.for %i0 = 0 to 4 { affine.for %i1 = 0 to 256 { - %v0 = load %b[%i0, %i1] : memref<4x256xf32> + %v0 = affine.load %b[%i0, %i1] : memref<4x256xf32> } affine.for %i2 = 0 to 256 { - store %cf0, %a[%i0, %i2] : memref<4x256xf32> + affine.store %cf0, %a[%i0, %i2] : memref<4x256xf32> } } affine.for %d0 = 0 to 4 { affine.for %d1 = 0 to 16 { - %v1 = load %a[%d0, %d1] : memref<4x256xf32> + %v1 = affine.load %a[%d0, %d1] : memref<4x256xf32> } } // The cost of fusing at depth 2 is greater than the cost of fusing at depth 1 @@ -1115,17 +988,13 @@ func @should_fuse_at_depth1_and_reduce_slice_trip_count() { // CHECK-DAG: %0 = alloc() : memref<1x16xf32> // CHECK: affine.for %i0 = 0 to 4 { // CHECK-NEXT: affine.for %i1 = 0 to 256 { - // CHECK-NEXT: %2 = load %1[%i0, %i1] : memref<4x256xf32> + // CHECK-NEXT: %2 = affine.load %1[%i0, %i1] : memref<4x256xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 16 { - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0, %i2) - // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i0, %i2) - // CHECK-NEXT: store %cst, %0[%3, %4] : memref<1x16xf32> + // CHECK-NEXT: affine.store %cst, %0[0, %i2] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 16 { - // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i0, %i0, %i3) - // CHECK-NEXT: %6 = affine.apply [[MAP1]](%i0, %i0, %i3) - // CHECK-NEXT: %7 = load %0[%5, %6] : memref<1x16xf32> + // CHECK-NEXT: %3 = affine.load %0[0, %i3] : memref<1x16xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1141,16 +1010,16 @@ func @should_fuse_at_depth1_with_trip_count_20() { %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 100 { - store %cf0, %a[%i0]: memref<100xf32> + affine.store %cf0, %a[%i0]: memref<100xf32> } affine.for %i1 = 0 to 5 { affine.for %i2 = 0 to 10 { - %v0 = load %a[%i2]: memref<100xf32> + %v0 = affine.load %a[%i2]: memref<100xf32> } affine.for %i3 = 0 to 10 { affine.for %i4 = 0 to 20 { - %v1 = load %a[%i4]: memref<100xf32> + %v1 = affine.load %a[%i4]: memref<100xf32> } } } @@ -1158,14 +1027,14 @@ func @should_fuse_at_depth1_with_trip_count_20() { // 
CHECK-DAG: %0 = alloc() : memref<20xf32> // CHECK: affine.for %i0 = 0 to 5 { // CHECK-NEXT: affine.for %i1 = 0 to 20 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<20xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<20xf32> + // CHECK-NEXT: %1 = affine.load %0[%i2] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 10 { // CHECK-NEXT: affine.for %i4 = 0 to 20 { - // CHECK-NEXT: %2 = load %0[%i4] : memref<20xf32> + // CHECK-NEXT: %2 = affine.load %0[%i4] : memref<20xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1182,16 +1051,16 @@ func @should_fuse_at_depth1_with_trip_count_19() { %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 100 { - store %cf0, %a[%i0]: memref<100xf32> + affine.store %cf0, %a[%i0]: memref<100xf32> } affine.for %i1 = 0 to 5 { affine.for %i2 = 0 to 19 { - %v0 = load %a[%i2]: memref<100xf32> + %v0 = affine.load %a[%i2]: memref<100xf32> } affine.for %i3 = 0 to 10 { affine.for %i4 = 0 to 10 { - %v1 = load %a[%i4]: memref<100xf32> + %v1 = affine.load %a[%i4]: memref<100xf32> } } } @@ -1199,14 +1068,14 @@ func @should_fuse_at_depth1_with_trip_count_19() { // CHECK-DAG: %0 = alloc() : memref<19xf32> // CHECK: affine.for %i0 = 0 to 5 { // CHECK-NEXT: affine.for %i1 = 0 to 19 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<19xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 19 { - // CHECK-NEXT: %1 = load %0[%i2] : memref<19xf32> + // CHECK-NEXT: %1 = affine.load %0[%i2] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 10 { // CHECK-NEXT: affine.for %i4 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i4] : memref<19xf32> + // CHECK-NEXT: %2 = affine.load %0[%i4] : memref<19xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1216,7 +1085,6 @@ func @should_fuse_at_depth1_with_trip_count_19() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) // CHECK-LABEL: func @should_fuse_with_private_memrefs_with_diff_shapes() { func @should_fuse_with_private_memrefs_with_diff_shapes() { @@ -1224,29 +1092,25 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 100 { - store %cf7, %m[%i0] : memref<100xf32> + affine.store %cf7, %m[%i0] : memref<100xf32> } affine.for %i1 = 0 to 17 { - %v0 = load %m[%i1] : memref<100xf32> + %v0 = affine.load %m[%i1] : memref<100xf32> } affine.for %i2 = 0 to 82 { - %v1 = load %m[%i2] : memref<100xf32> + %v1 = affine.load %m[%i2] : memref<100xf32> } // Should create two new private memrefs customized to the shapes accessed // by loops %i1 and %i2. 
// CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> // CHECK: affine.for %i0 = 0 to 17 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %1[0] : memref<1xf32> + // CHECK-NEXT: %2 = affine.load %1[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 82 { - // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> - // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1259,10 +1123,10 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %arg0[%i0] : memref<10xf32> + affine.store %cf7, %arg0[%i0] : memref<10xf32> } affine.for %i1 = 0 to 9 { - %v0 = load %arg0[%i1] : memref<10xf32> + %v0 = affine.load %arg0[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref argument '%arg0', and its read region @@ -1270,10 +1134,10 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { // in the fused loop nest, so complete live out data region would not // be written). // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 9 { - // CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32> + // CHECK-NEXT: %0 = affine.load %arg0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1286,17 +1150,17 @@ func @should_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %arg0[%i0] : memref<10xf32> + affine.store %cf7, %arg0[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %arg0[%i1] : memref<10xf32> + %v0 = affine.load %arg0[%i1] : memref<10xf32> } // The read/write regions for memref '%arg0' are the same for both // loops, so they should fuse. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> - // CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: %0 = affine.load %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1309,19 +1173,19 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> } affine.for %i1 = 0 to 9 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> } // This tests that the loop nest '%i0' should not be removed after fusion // because it writes to memref '%m' which is returned by the function. 
// CHECK-DAG: %0 = alloc() : memref<10xf32> // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 9 { - // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %1 = affine.load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> return %m : memref<10xf32> @@ -1339,7 +1203,7 @@ func @R3_to_R2_reshape() { affine.for %i1 = 0 to 3 { affine.for %i2 = 0 to 16 { %val = "foo"(%i0, %i1, %i2) : (index, index, index) -> i32 - store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> + affine.store %val, %in[%i0, %i1, %i2] : memref<2x3x16xi32> } } } @@ -1348,18 +1212,15 @@ func @R3_to_R2_reshape() { affine.for %jj = 0 to 3 { %a0 = affine.apply (d0, d1) -> (d0 * 3 + d1) (%ii, %jj) %idx = affine.apply (d0) -> (d0 floordiv (3 * 16)) (%a0) - %v = load %in[%idx, %jj, %c0] + %v = affine.load %in[%idx, %jj, %c0] : memref<2x3x16xi32> } } return } -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> ((d0 * 3 + d1) floordiv 48) -// CHECK-NEXT: [[MAP2:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (d2 - (d0 * 3 + d1) floordiv 48) -// CHECK-NEXT: [[MAP3:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (-d1 + d3) -// CHECK-NEXT: [[MAP4:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (d4) -// CHECK-NEXT: [[MAP5:#map[0-9]+]] = (d0, d1) -> (d0 * 3 + d1) -// CHECK-NEXT: [[MAP6:#map[0-9]+]] = (d0) -> (d0 floordiv 48) +// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1) -> ((d0 * 3 + d1) floordiv 48) +// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0, d1) -> (d0 * 3 + d1) +// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0) -> (d0 floordiv 48) // CHECK-LABEL: func @R3_to_R2_reshape() // CHECK-DAG: %0 = alloc() : memref<1x1x1xi32> @@ -1367,16 +1228,10 @@ func @R3_to_R2_reshape() { // CHECK-NEXT: affine.for %i1 = 0 to 3 { // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1) // CHECK-NEXT: %2 = "foo"(%1, %i1, %c0) : (index, index, index) -> i32 -// CHECK-NEXT: %3 = affine.apply [[MAP2]](%i0, %i1, %1, %i1, %c0) -// CHECK-NEXT: %4 = affine.apply [[MAP3]](%i0, %i1, %1, %i1, %c0) -// CHECK-NEXT: %5 = affine.apply [[MAP4]](%i0, %i1, %1, %i1, %c0) -// CHECK-NEXT: store %2, %0[%3, %4, %5] : memref<1x1x1xi32> -// CHECK-NEXT: %6 = affine.apply [[MAP5]](%i0, %i1) -// CHECK-NEXT: %7 = affine.apply [[MAP6]](%6) -// CHECK-NEXT: %8 = affine.apply [[MAP2]](%i0, %i1, %7, %i1, %c0_0) -// CHECK-NEXT: %9 = affine.apply [[MAP3]](%i0, %i1, %7, %i1, %c0_0) -// CHECK-NEXT: %10 = affine.apply [[MAP4]](%i0, %i1, %7, %i1, %c0_0) -// CHECK-NEXT: %11 = load %0[%8, %9, %10] : memref<1x1x1xi32> +// CHECK-NEXT: affine.store %2, %0[0, 0, symbol(%c0)] : memref<1x1x1xi32> +// CHECK-NEXT: %3 = affine.apply [[MAP1]](%i0, %i1) +// CHECK-NEXT: %4 = affine.apply [[MAP2]](%3) +// CHECK-NEXT: %5 = affine.load %0[0, 0, symbol(%c0_0)] : memref<1x1x1xi32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1391,19 +1246,19 @@ func @should_not_fuse_multi_output_producer() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %a[%i0] : memref<10xf32> - store %cf7, %b[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> + affine.store %cf7, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %a[%i1] : memref<10xf32> + %v0 = affine.load %a[%i1] : memref<10xf32> } // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store 
%cst, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %2 = affine.load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1420,31 +1275,31 @@ func @fusion_preventing_deps_on_middle_loop() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %a[%i0] : memref<10xf32> - store %v0, %b[%i0] : memref<10xf32> + %v0 = affine.load %a[%i0] : memref<10xf32> + affine.store %v0, %b[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %a[%i1] : memref<10xf32> - %v1 = load %c[%i1] : memref<10xf32> + affine.store %cf7, %a[%i1] : memref<10xf32> + %v1 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v2 = load %b[%i2] : memref<10xf32> - store %v2, %c[%i2] : memref<10xf32> + %v2 = affine.load %b[%i2] : memref<10xf32> + affine.store %v2, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%b', because of the WAR dep from '%i0' to '%i1' on memref '%a' and // because of the WAR dep from '%i1' to '%i2' on memref '%c'. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %3, %1[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %3, %1[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> - // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32> - // CHECK-NEXT: store %5, %2[%i2] : memref<10xf32> + // CHECK-NEXT: %5 = affine.load %1[%i2] : memref<10xf32> + // CHECK-NEXT: affine.store %5, %2[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1452,8 +1307,6 @@ func @fusion_preventing_deps_on_middle_loop() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_and_move_to_preserve_war_dep() { func @should_fuse_and_move_to_preserve_war_dep() { %a = alloc() : memref<10xf32> @@ -1463,18 +1316,18 @@ func @should_fuse_and_move_to_preserve_war_dep() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %b[%i0] : memref<10xf32> - store %v0, %a[%i0] : memref<10xf32> + %v0 = affine.load %b[%i0] : memref<10xf32> + affine.store %v0, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 3 { - %v2 = load %c[%i1] : memref<10xf32> + %v2 = affine.load %c[%i1] : memref<10xf32> } affine.for %i2 = 0 to 5 { - store %cf7, %b[%i2] : memref<10xf32> + affine.store %cf7, %b[%i2] : memref<10xf32> } affine.for %i3 = 0 to 10 { - %v1 = load %a[%i3] : memref<10xf32> - store %cf7, %c[%i3] : memref<10xf32> + %v1 = affine.load %a[%i3] : memref<10xf32> + affine.store %cf7, %c[%i3] : memref<10xf32> } // Dependence graph: @@ -1492,18 +1345,16 @@ func @should_fuse_and_move_to_preserve_war_dep() { // CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK: affine.for %i0 = 0 to 3 { - // CHECK-NEXT: %3 = load %2[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %4 = load %1[%i1] : memref<10xf32> - // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %4, %0[%5] : memref<1xf32> - // CHECK-NEXT: %6 = affine.apply 
[[MAP0]](%i1, %i1) - // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> - // CHECK-NEXT: store %cst, %2[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %1[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %4, %0[0] : memref<1xf32> + // CHECK-NEXT: %5 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 5 { - // CHECK-NEXT: store %cst, %1[%i2] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1520,31 +1371,31 @@ func @fusion_preventing_dep_on_constant() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %b[%i0] : memref<10xf32> - store %cf7, %a[%i0] : memref<10xf32> + %v0 = affine.load %b[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %b[%i1] : memref<10xf32> + affine.store %cf7, %b[%i1] : memref<10xf32> } %cf11 = constant 11.0 : f32 affine.for %i2 = 0 to 10 { - %v2 = load %a[%i2] : memref<10xf32> - store %cf11, %c[%i2] : memref<10xf32> + %v2 = affine.load %a[%i2] : memref<10xf32> + affine.store %cf11, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' cannot fuse along producer/consumer edge on memref // '%a', because of the WAR dep from '%i0' to '%i1' on memref '%b' and // because of the SSA value dep from '%cf11' def to use in '%i2'. // CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: %cst_0 = constant 1.100000e+01 : f32 // CHECK-NEXT: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: %4 = load %0[%i2] : memref<10xf32> - // CHECK-NEXT: store %cst_0, %2[%i2] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %0[%i2] : memref<10xf32> + // CHECK-NEXT: affine.store %cst_0, %2[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1552,8 +1403,6 @@ func @fusion_preventing_dep_on_constant() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_and_preserve_dep_on_constant() { func @should_fuse_and_preserve_dep_on_constant() { %a = alloc() : memref<10xf32> @@ -1563,15 +1412,15 @@ func @should_fuse_and_preserve_dep_on_constant() { %cf7 = constant 7.0 : f32 %cf11 = constant 11.0 : f32 affine.for %i0 = 0 to 10 { - %v0 = load %b[%i0] : memref<10xf32> - store %cf7, %a[%i0] : memref<10xf32> + %v0 = affine.load %b[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - store %cf7, %b[%i1] : memref<10xf32> + affine.store %cf7, %b[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v2 = load %a[%i2] : memref<10xf32> - store %cf11, %c[%i2] : memref<10xf32> + %v2 = affine.load %a[%i2] : memref<10xf32> + affine.store %cf11, %c[%i2] : memref<10xf32> } // Loops '%i0' and '%i2' can fuse along producer/consumer edge on memref @@ -1580,15 +1429,13 @@ func @should_fuse_and_preserve_dep_on_constant() { // CHECK: %cst_0 = constant 1.100000e+01 : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%4] : 
memref<1xf32> - // CHECK-NEXT: %5 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %6 = load %0[%5] : memref<1xf32> - // CHECK-NEXT: store %cst_0, %2[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %4 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %cst_0, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1596,8 +1443,6 @@ func @should_fuse_and_preserve_dep_on_constant() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d1) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2) // CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 * 16 - d1 + 15) // CHECK: [[MAP3:#map[0-9]+]] = (d0, d1) -> (d0 * 16 + d1) @@ -1607,28 +1452,28 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> %0 = constant 0.0 : f32 affine.for %i0 = 0 to 64 { affine.for %i1 = 0 to 4 { - store %0, %out[%i0, %i1] : memref<64x4xf32> + affine.store %0, %out[%i0, %i1] : memref<64x4xf32> } } affine.for %i2 = 0 to 4 { affine.for %i3 = 0 to 4 { affine.for %i4 = 0 to 16 { %1 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i3, %i4) - %2 = load %arg1[%1, %i2] : memref<64x4xf32> + %2 = affine.load %arg1[%1, %i2] : memref<64x4xf32> "op0"(%2) : (f32) -> () } affine.for %i5 = 0 to 4 { affine.for %i6 = 0 to 16 { %3 = affine.apply (d0, d1) -> (d0 * 16 - d1 + 15)(%i5, %i6) - %4 = load %arg0[%3, %i3] : memref<64x4xf32> + %4 = affine.load %arg0[%3, %i3] : memref<64x4xf32> "op1"(%4) : (f32) -> () } affine.for %i7 = 0 to 16 { %5 = "op2"() : () -> (f32) %6 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i5, %i7) - %7 = load %out[%6, %i2] : memref<64x4xf32> + %7 = affine.load %out[%6, %i2] : memref<64x4xf32> %8 = addf %7, %5 : f32 - store %8, %out[%6, %i2] : memref<64x4xf32> + affine.store %8, %out[%6, %i2] : memref<64x4xf32> } } } @@ -1645,32 +1490,26 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // CHECK: %0 = alloc() : memref<64x1xf32> // CHECK: affine.for %i0 = 0 to 4 { // CHECK-NEXT: affine.for %i1 = 0 to 64 { - // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0) - // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0) - // CHECK-NEXT: store %cst, %0[%1, %2] : memref<64x1xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1, 0] : memref<64x1xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 4 { // CHECK-NEXT: affine.for %i3 = 0 to 16 { - // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i2, %i3) - // CHECK-NEXT: %4 = load %arg1[%3, %i0] : memref<64x4xf32> - // CHECK-NEXT: "op0"(%4) : (f32) -> () + // CHECK-NEXT: %1 = affine.apply [[MAP2]](%i2, %i3) + // CHECK-NEXT: %2 = affine.load %arg1[%1, %i0] : memref<64x4xf32> + // CHECK-NEXT: "op0"(%2) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: affine.for %i4 = 0 to 4 { // CHECK-NEXT: affine.for %i5 = 0 to 16 { - // CHECK-NEXT: %5 = affine.apply [[MAP2]](%i4, %i5) - // CHECK-NEXT: %6 = load %arg0[%5, %i2] : memref<64x4xf32> - // CHECK-NEXT: "op1"(%6) : (f32) -> () + // CHECK-NEXT: %3 = affine.apply [[MAP2]](%i4, %i5) + // CHECK-NEXT: %4 = affine.load %arg0[%3, %i2] : memref<64x4xf32> + // CHECK-NEXT: "op1"(%4) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: affine.for %i6 = 0 to 16 { - // CHECK-NEXT: %7 = "op2"() : () -> f32 - // CHECK-NEXT: %8 = affine.apply [[MAP3]](%i4, %i6) - // CHECK-NEXT: %9 = 
affine.apply [[MAP0]](%i0, %8, %i0) - // CHECK-NEXT: %10 = affine.apply [[MAP1]](%i0, %8, %i0) - // CHECK-NEXT: %11 = load %0[%9, %10] : memref<64x1xf32> - // CHECK-NEXT: %12 = addf %11, %7 : f32 - // CHECK-NEXT: %13 = affine.apply [[MAP0]](%i0, %8, %i0) - // CHECK-NEXT: %14 = affine.apply [[MAP1]](%i0, %8, %i0) - // CHECK-NEXT: store %12, %0[%13, %14] : memref<64x1xf32> + // CHECK-NEXT: %5 = "op2"() : () -> f32 + // CHECK-NEXT: %6 = affine.apply [[MAP3]](%i4, %i6) + // CHECK-NEXT: %7 = affine.load %0[%i4 * 16 + %i6, 0] : memref<64x1xf32> + // CHECK-NEXT: %8 = addf %7, %5 : f32 + // CHECK-NEXT: affine.store %8, %0[%i4 * 16 + %i6, 0] : memref<64x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1681,8 +1520,6 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_after_private_memref_creation() { func @should_fuse_after_private_memref_creation() { %a = alloc() : memref<10xf32> @@ -1691,15 +1528,15 @@ func @should_fuse_after_private_memref_creation() { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { - store %cf7, %a[%i0] : memref<10xf32> + affine.store %cf7, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %v0 = load %a[%i1] : memref<10xf32> - store %v0, %b[%i1] : memref<10xf32> + %v0 = affine.load %a[%i1] : memref<10xf32> + affine.store %v0, %b[%i1] : memref<10xf32> } affine.for %i2 = 0 to 10 { - %v1 = load %a[%i2] : memref<10xf32> - store %v1, %b[%i2] : memref<10xf32> + %v1 = affine.load %a[%i2] : memref<10xf32> + affine.store %v1, %b[%i2] : memref<10xf32> } // On the first visit to '%i2', the fusion algorithm can not fuse loop nest @@ -1709,18 +1546,14 @@ func @should_fuse_after_private_memref_creation() { // longer exists, so '%i0' can now be fused into '%i2'. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> - // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %1[0] : memref<1xf32> + // CHECK-NEXT: %3 = affine.load %1[0] : memref<1xf32> + // CHECK-NEXT: affine.store %3, %2[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> - // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> - // CHECK-NEXT: store %8, %2[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> + // CHECK-NEXT: %4 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %4, %2[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1728,38 +1561,33 @@ func @should_fuse_after_private_memref_creation() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) - // CHECK-LABEL: func @should_fuse_after_one_loop_interchange() { func @should_fuse_after_one_loop_interchange() { %a = alloc() : memref<10xf32> %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 10 { - store %cf0, %a[%i0] : memref<10xf32> + affine.store %cf0, %a[%i0] : memref<10xf32> } affine.for %i1 = 0 to 5 { affine.for %i2 = 0 to 10 { - %v0 = load %a[%i2] : memref<10xf32> - store %v0, %a[%i2] : memref<10xf32> + %v0 = affine.load %a[%i2] : memref<10xf32> + affine.store %v0, %a[%i2] : memref<10xf32> } } - // The dependence between the load and store is carried on loop '%i1', and + // The dependence between the load and affine.store is carried on loop '%i1', and // cannot be fused with loop '%i0' without violating this dependence. // Once loops '%i1' and %i2' are interchanged, loop '%i0' can be fused // at loop depth 1, because the loop carrying the dependence has been // interchanged and is now at depth 2. 
// CHECK: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%1] : memref<1xf32> + // CHECK-NEXT: affine.store %cst, %0[0] : memref<1xf32> // CHECK-NEXT: affine.for %i1 = 0 to 5 { - // CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %3 = load %0[%2] : memref<1xf32> - // CHECK-NEXT: %4 = affine.apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %3, %0[%4] : memref<1xf32> + // CHECK-NEXT: %1 = affine.load %0[0] : memref<1xf32> + // CHECK-NEXT: affine.store %1, %0[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -1768,9 +1596,6 @@ func @should_fuse_after_one_loop_interchange() { // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) -// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) - // CHECK-LABEL: func @should_fuse_after_two_loop_interchanges() { func @should_fuse_after_two_loop_interchanges() { %a = alloc() : memref<6x8xf32> @@ -1778,7 +1603,7 @@ func @should_fuse_after_two_loop_interchanges() { %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 6 { affine.for %i1 = 0 to 8 { - store %cf0, %a[%i0, %i1] : memref<6x8xf32> + affine.store %cf0, %a[%i0, %i1] : memref<6x8xf32> } } @@ -1786,15 +1611,15 @@ func @should_fuse_after_two_loop_interchanges() { affine.for %i3 = 0 to 6 { affine.for %i4 = 0 to 2 { affine.for %i5 = 0 to 8 { - %v0 = load %a[%i3, %i5] : memref<6x8xf32> + %v0 = affine.load %a[%i3, %i5] : memref<6x8xf32> %v1 = addf %v0, %v0 : f32 - store %v1, %a[%i3, %i5] : memref<6x8xf32> + affine.store %v1, %a[%i3, %i5] : memref<6x8xf32> } } } } - // The dependence between the load and store is carried on loops '%i2' and + // The dependence between the load and affine.store is carried on loops '%i2' and // '%i4', and cannot be fused with loop '%i0' without violating this // dependence. 
// Once loop '%i2' is interchanged with loop '%i3', and again with loop @@ -1803,18 +1628,12 @@ func @should_fuse_after_two_loop_interchanges() { // CHECK: affine.for %i0 = 0 to 6 { // CHECK-NEXT: affine.for %i1 = 0 to 8 { - // CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: store %cst, %0[%1, %2] : memref<1x1xf32> + // CHECK-NEXT: affine.store %cst, %0[0, 0] : memref<1x1xf32> // CHECK-NEXT: affine.for %i2 = 0 to 4 { // CHECK-NEXT: affine.for %i3 = 0 to 2 { - // CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32> - // CHECK-NEXT: %6 = addf %5, %5 : f32 - // CHECK-NEXT: %7 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: %8 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) - // CHECK-NEXT: store %6, %0[%7, %8] : memref<1x1xf32> + // CHECK-NEXT: %1 = affine.load %0[0, 0] : memref<1x1xf32> + // CHECK-NEXT: %2 = addf %1, %1 : f32 + // CHECK-NEXT: affine.store %2, %0[0, 0] : memref<1x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1828,19 +1647,19 @@ func @should_fuse_after_two_loop_interchanges() { func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { %cst = constant 0.000000e+00 : f32 affine.for %i0 = 0 to 10 { - store %cst, %arg0[%i0] : memref<10xf32> + affine.store %cst, %arg0[%i0] : memref<10xf32> } affine.for %i1 = 0 to 10 { - %1 = load %arg0[%i1] : memref<10xf32> - store %1, %arg0[%i1] : memref<10xf32> + %1 = affine.load %arg0[%i1] : memref<10xf32> + affine.store %1, %arg0[%i1] : memref<10xf32> } return %arg0 : memref<10xf32> // CHECK: %cst = constant 0.000000e+00 : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32> - // CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32> - // CHECK-NEXT: store %0, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: %0 = affine.load %arg0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %0, %arg0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %arg0 : memref<10xf32> } @@ -1858,7 +1677,7 @@ func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> { func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> { affine.for %i0 = 0 to 32 { affine.for %i1 = 0 to 8 { - store %0, %arg1[%i0, %i1] : memref<32x8xf32> + affine.store %0, %arg1[%i0, %i1] : memref<32x8xf32> } } affine.for %i = 0 to 2 { @@ -1866,14 +1685,14 @@ func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> affine.for %k = 0 to 8 { affine.for %kk = 0 to 16 { %1 = affine.apply #map(%k, %kk) - %2 = load %arg0[%1, %j] : memref<128x8xf32> + %2 = affine.load %arg0[%1, %j] : memref<128x8xf32> %3 = "foo"(%2) : (f32) -> f32 } affine.for %ii = 0 to 16 { %6 = affine.apply #map(%i, %ii) - %7 = load %arg1[%6, %j] : memref<32x8xf32> + %7 = affine.load %arg1[%6, %j] : memref<32x8xf32> %8 = addf %7, %7 : f32 - store %8, %arg1[%6, %j] : memref<32x8xf32> + affine.store %8, %arg1[%6, %j] : memref<32x8xf32> } } } @@ -1883,19 +1702,19 @@ func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> // CHECK: affine.for %i0 = 0 to 2 { // CHECK-NEXT: affine.for %i1 = 0 to 8 { // CHECK-NEXT: affine.for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) { -// CHECK-NEXT: store %arg2, %arg1[%i2, %i1] : memref<32x8xf32> +// CHECK-NEXT: affine.store 
%arg2, %arg1[%i2, %i1] : memref<32x8xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 8 { // CHECK-NEXT: affine.for %i4 = 0 to 16 { // CHECK-NEXT: %0 = affine.apply #map{{[0-9]+}}(%i3, %i4) -// CHECK-NEXT: %1 = load %arg0[%0, %i1] : memref<128x8xf32> +// CHECK-NEXT: %1 = affine.load %arg0[%0, %i1] : memref<128x8xf32> // CHECK-NEXT: %2 = "foo"(%1) : (f32) -> f32 // CHECK-NEXT: } // CHECK-NEXT: affine.for %i5 = 0 to 16 { // CHECK-NEXT: %3 = affine.apply #map{{[0-9]+}}(%i0, %i5) -// CHECK-NEXT: %4 = load %arg1[%3, %i1] : memref<32x8xf32> +// CHECK-NEXT: %4 = affine.load %arg1[%3, %i1] : memref<32x8xf32> // CHECK-NEXT: %5 = addf %4, %4 : f32 -// CHECK-NEXT: store %5, %arg1[%3, %i1] : memref<32x8xf32> +// CHECK-NEXT: affine.store %5, %arg1[%3, %i1] : memref<32x8xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1918,14 +1737,14 @@ func @test_add_slice_bounds() { %a0 = affine.apply (d0) -> (d0) (%i0) %a1 = affine.apply (d0) -> (d0) (%i0) %a2 = affine.apply (d0, d1) -> (d0 - d1) (%a0, %a1) - store %cf7, %a[%a2] : memref<10xf32> + affine.store %cf7, %a[%a2] : memref<10xf32> } } } affine.for %i3 = 0 to 10 { affine.for %i4 = 0 to 10 { affine.for %i5 = 0 to 10 { - %v0 = load %a[%c0] : memref<10xf32> + %v0 = affine.load %a[%c0] : memref<10xf32> } } } @@ -1936,14 +1755,14 @@ func @test_add_slice_bounds() { // CHECK-NEXT: %2 = affine.apply #map0(%i0) // CHECK-NEXT: %3 = affine.apply #map0(%i0) // CHECK-NEXT: %4 = affine.apply #map1(%2, %3) -// CHECK-NEXT: store %cst, %0[%4] : memref<10xf32> +// CHECK-NEXT: affine.store %cst, %0[%4] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 10 { // CHECK-NEXT: affine.for %i4 = 0 to 10 { // CHECK-NEXT: affine.for %i5 = 0 to 10 { -// CHECK-NEXT: %5 = load %0[%c0] : memref<10xf32> +// CHECK-NEXT: %5 = affine.load %0[%c0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -1951,8 +1770,6 @@ func @test_add_slice_bounds() { } // ----- -// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2) -// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3) func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) { %0 = alloc() : memref<10x10xf32> @@ -1961,33 +1778,33 @@ func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf %cst_1 = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { - store %cst_1, %0[%i0, %i1] : memref<10x10xf32> + affine.store %cst_1, %0[%i0, %i1] : memref<10x10xf32> } } affine.for %i2 = 0 to 3 { affine.for %i3 = 0 to 3 { - store %cst, %arg0[%i2, %i3] : memref<10x10xf32> + affine.store %cst, %arg0[%i2, %i3] : memref<10x10xf32> } } affine.for %i4 = 0 to 3 { affine.for %i5 = 0 to 3 { - %1 = load %0[%i4, %i5] : memref<10x10xf32> - %2 = load %arg0[%i4, %i5] : memref<10x10xf32> + %1 = affine.load %0[%i4, %i5] : memref<10x10xf32> + %2 = affine.load %arg0[%i4, %i5] : memref<10x10xf32> %3 = mulf %1, %2 : f32 - store %3, %arg0[%i4, %i5] : memref<10x10xf32> + affine.store %3, %arg0[%i4, %i5] : memref<10x10xf32> } } affine.for %i6 = 0 to 3 { affine.for %i7 = 0 to 3 { - store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32> + affine.store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32> } } affine.for %i8 = 0 to 3 { affine.for %i9 = 0 to 3 { - %4 = load %0[%i8, %i9] : memref<10x10xf32> - %5 = load %arg1[%i8, %i9] : memref<10x10xf32> + %4 = affine.load %0[%i8, %i9] : memref<10x10xf32> + %5 = affine.load %arg1[%i8, %i9] : memref<10x10xf32> %6 = addf %4, %5 : 
f32 - store %6, %arg1[%i8, %i9] : memref<10x10xf32> + affine.store %6, %arg1[%i8, %i9] : memref<10x10xf32> } } @@ -2001,23 +1818,17 @@ func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf // CHECK: affine.for %i0 = 0 to 3 { // CHECK-NEXT: affine.for %i1 = 0 to 3 { -// CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: store %cst_1, %0[%1, %2] : memref<1x1xf32> -// CHECK-NEXT: store %cst, %arg0[%i0, %i1] : memref<10x10xf32> -// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32> -// CHECK-NEXT: %6 = load %arg0[%i0, %i1] : memref<10x10xf32> -// CHECK-NEXT: %7 = mulf %5, %6 : f32 -// CHECK-NEXT: store %7, %arg0[%i0, %i1] : memref<10x10xf32> -// CHECK-NEXT: store %cst_0, %arg1[%i0, %i1] : memref<10x10xf32> -// CHECK-NEXT: %8 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %9 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1) -// CHECK-NEXT: %10 = load %0[%8, %9] : memref<1x1xf32> -// CHECK-NEXT: %11 = load %arg1[%i0, %i1] : memref<10x10xf32> -// CHECK-NEXT: %12 = addf %10, %11 : f32 -// CHECK-NEXT: store %12, %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: affine.store %cst_1, %0[0, 0] : memref<1x1xf32> +// CHECK-NEXT: affine.store %cst, %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %1 = affine.load %0[0, 0] : memref<1x1xf32> +// CHECK-NEXT: %2 = affine.load %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %3 = mulf %1, %2 : f32 +// CHECK-NEXT: affine.store %3, %arg0[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: affine.store %cst_0, %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %4 = affine.load %0[0, 0] : memref<1x1xf32> +// CHECK-NEXT: %5 = affine.load %arg1[%i0, %i1] : memref<10x10xf32> +// CHECK-NEXT: %6 = addf %4, %5 : f32 +// CHECK-NEXT: affine.store %6, %arg1[%i0, %i1] : memref<10x10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -2026,8 +1837,6 @@ func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf } // ----- -// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1) -// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2) func @two_matrix_vector_products() { %in_matrix = alloc() : memref<10x10xf32> @@ -2040,57 +1849,51 @@ func @two_matrix_vector_products() { // Populate input matrix. 
affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { - store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32> + affine.store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32> } } // out_vec0 = in_matrix x in_vec0 affine.for %i2 = 0 to 10 { affine.for %i3 = 0 to 10 { - %v0 = load %in_matrix[%i2, %i3] : memref<10x10xf32> - %v1 = load %in_vec0[%i3] : memref<10xf32> + %v0 = affine.load %in_matrix[%i2, %i3] : memref<10x10xf32> + %v1 = affine.load %in_vec0[%i3] : memref<10xf32> %v2 = mulf %v0, %v1 : f32 - %v3 = load %out_vec0[%i3] : memref<10xf32> + %v3 = affine.load %out_vec0[%i3] : memref<10xf32> %v4 = addf %v2, %v3 : f32 - store %v4, %out_vec0[%i3] : memref<10xf32> + affine.store %v4, %out_vec0[%i3] : memref<10xf32> } } // out_vec1 = in_matrix x in_vec1 affine.for %i4 = 0 to 10 { affine.for %i5 = 0 to 10 { - %v5 = load %in_matrix[%i4, %i5] : memref<10x10xf32> - %v6 = load %in_vec1[%i5] : memref<10xf32> + %v5 = affine.load %in_matrix[%i4, %i5] : memref<10x10xf32> + %v6 = affine.load %in_vec1[%i5] : memref<10xf32> %v7 = mulf %v5, %v6 : f32 - %v8 = load %out_vec1[%i5] : memref<10xf32> + %v8 = affine.load %out_vec1[%i5] : memref<10xf32> %v9 = addf %v7, %v8 : f32 - store %v9, %out_vec1[%i5] : memref<10xf32> + affine.store %v9, %out_vec1[%i5] : memref<10xf32> } } // CHECK: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { -// CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1, %i0) -// CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1, %i0) -// CHECK-NEXT: store %cst, %0[%5, %6] : memref<10x1xf32> +// CHECK-NEXT: affine.store %cst, %0[%i1, 0] : memref<10x1xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i2 = 0 to 10 { -// CHECK-NEXT: %7 = affine.apply [[MAP2]](%i0, %i2, %i0) -// CHECK-NEXT: %8 = affine.apply [[MAP3]](%i0, %i2, %i0) -// CHECK-NEXT: %9 = load %0[%7, %8] : memref<10x1xf32> -// CHECK-NEXT: %10 = load %1[%i0] : memref<10xf32> -// CHECK-NEXT: %11 = mulf %9, %10 : f32 -// CHECK-NEXT: %12 = load %3[%i0] : memref<10xf32> -// CHECK-NEXT: %13 = addf %11, %12 : f32 -// CHECK-NEXT: store %13, %3[%i0] : memref<10xf32> +// CHECK-NEXT: %5 = affine.load %0[%i2, 0] : memref<10x1xf32> +// CHECK-NEXT: %6 = affine.load %1[%i0] : memref<10xf32> +// CHECK-NEXT: %7 = mulf %5, %6 : f32 +// CHECK-NEXT: %8 = affine.load %3[%i0] : memref<10xf32> +// CHECK-NEXT: %9 = addf %7, %8 : f32 +// CHECK-NEXT: affine.store %9, %3[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 10 { -// CHECK-NEXT: %14 = affine.apply [[MAP2]](%i0, %i3, %i0) -// CHECK-NEXT: %15 = affine.apply [[MAP3]](%i0, %i3, %i0) -// CHECK-NEXT: %16 = load %0[%14, %15] : memref<10x1xf32> -// CHECK-NEXT: %17 = load %2[%i0] : memref<10xf32> -// CHECK-NEXT: %18 = mulf %16, %17 : f32 -// CHECK-NEXT: %19 = load %4[%i0] : memref<10xf32> -// CHECK-NEXT: %20 = addf %18, %19 : f32 -// CHECK-NEXT: store %20, %4[%i0] : memref<10xf32> +// CHECK-NEXT: %10 = affine.load %0[%i3, 0] : memref<10x1xf32> +// CHECK-NEXT: %11 = affine.load %2[%i0] : memref<10xf32> +// CHECK-NEXT: %12 = mulf %10, %11 : f32 +// CHECK-NEXT: %13 = affine.load %4[%i0] : memref<10xf32> +// CHECK-NEXT: %14 = addf %12, %13 : f32 +// CHECK-NEXT: affine.store %14, %4[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -2098,20 +1901,18 @@ func @two_matrix_vector_products() { } // ----- -// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d1) -// CHECK-DAG: [[MAP4:#map[0-9]+]] = (d0, d1, d2) -> (d2) func @should_not_slice_past_slice_barrier() { %0 = alloc() : memref<100x16xf32> affine.for %i0 = 0 to 100 { affine.for %i1 = 0 to 
16 { %1 = "op1"() : () -> f32 - store %1, %0[%i0, %i1] : memref<100x16xf32> + affine.store %1, %0[%i0, %i1] : memref<100x16xf32> } {slice_fusion_barrier = true} } affine.for %i2 = 0 to 100 { affine.for %i3 = 0 to 16 { - %2 = load %0[%i2, %i3] : memref<100x16xf32> + %2 = affine.load %0[%i2, %i3] : memref<100x16xf32> "op2"(%2) : (f32) -> () } } @@ -2120,15 +1921,11 @@ func @should_not_slice_past_slice_barrier() { // CHECK: affine.for %i0 = 0 to 100 { // CHECK-NEXT: affine.for %i1 = 0 to 16 { // CHECK-NEXT: %1 = "op1"() : () -> f32 -// CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0, %i0, %i1) -// CHECK-NEXT: %3 = affine.apply [[MAP4]](%i0, %i0, %i1) -// CHECK-NEXT: store %1, %0[%2, %3] : memref<1x16xf32> +// CHECK-NEXT: affine.store %1, %0[0, %i1] : memref<1x16xf32> // CHECK-NEXT: } {slice_fusion_barrier = true} // CHECK-NEXT: affine.for %i2 = 0 to 16 { -// CHECK-NEXT: %4 = affine.apply [[MAP3]](%i0, %i0, %i2) -// CHECK-NEXT: %5 = affine.apply [[MAP4]](%i0, %i0, %i2) -// CHECK-NEXT: %6 = load %0[%4, %5] : memref<1x16xf32> -// CHECK-NEXT: "op2"(%6) : (f32) -> () +// CHECK-NEXT: %2 = affine.load %0[0, %i2] : memref<1x16xf32> +// CHECK-NEXT: "op2"(%2) : (f32) -> () // CHECK-NEXT: } // CHECK-NEXT: } return @@ -2144,7 +1941,7 @@ func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9x affine.for %i3 = 0 to 4 { affine.for %i5 = 0 to 16 { %7 = affine.apply #map0(%i2, %i5) - store %2, %1[%7, %i3] : memref<144x4xf32> + affine.store %2, %1[%7, %i3] : memref<144x4xf32> } } } @@ -2153,7 +1950,7 @@ func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9x affine.for %i8 = 0 to 4 { affine.for %i10 = 0 to 16 { %10 = affine.apply #map0(%i6, %i10) - %11 = load %1[%10, %i8] : memref<144x4xf32> + %11 = affine.load %1[%10, %i8] : memref<144x4xf32> } } } @@ -2161,8 +1958,6 @@ func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9x return } // MAXIMAL: #map0 = (d0, d1) -> (d0 * 16 + d1) -// MAXIMAL-NEXT: #map1 = (d0, d1, d2, d3, d4) -> (d0 * -16 - d1 + d3) -// MAXIMAL-NEXT: #map2 = (d0, d1, d2, d3, d4) -> (-d2 + d4) // MAXIMAL-LABEL: func @fuse_across_dim_mismatch // MAXIMAL: %0 = alloc() : memref<1x1xf32> // MAXIMAL: affine.for %i0 = 0 to 9 { @@ -2170,13 +1965,9 @@ func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9x // MAXIMAL-NEXT: affine.for %i2 = 0 to 4 { // MAXIMAL-NEXT: affine.for %i3 = 0 to 16 { // MAXIMAL-NEXT: %1 = affine.apply #map0(%i0, %i3) -// MAXIMAL-NEXT: %2 = affine.apply #map1(%i0, %i3, %i2, %1, %i2) -// MAXIMAL-NEXT: %3 = affine.apply #map2(%i0, %i3, %i2, %1, %i2) -// MAXIMAL-NEXT: store %cst, %0[%2, %3] : memref<1x1xf32> -// MAXIMAL-NEXT: %4 = affine.apply #map0(%i0, %i3) -// MAXIMAL-NEXT: %5 = affine.apply #map1(%i0, %i3, %i2, %4, %i2) -// MAXIMAL-NEXT: %6 = affine.apply #map2(%i0, %i3, %i2, %4, %i2) -// MAXIMAL-NEXT: %7 = load %0[%5, %6] : memref<1x1xf32> +// MAXIMAL-NEXT: affine.store %cst, %0[0, 0] : memref<1x1xf32> +// MAXIMAL-NEXT: %2 = affine.apply #map0(%i0, %i3) +// MAXIMAL-NEXT: %3 = affine.load %0[0, 0] : memref<1x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } @@ -2204,20 +1995,20 @@ func @fuse_across_varying_dims_complex() { %6 = affine.apply #map5(%i0, %i1) %7 = affine.apply #map6(%i0, %i1) %8 = affine.apply #map7(%i0, %i1) - %9 = load %0[%4, %5, %7, %8, %6, %c0] : memref<2x2x3x3x16x1xf32> - store %9, %1[%i0, %i1] : memref<64x9xf32> + %9 = affine.load %0[%4, %5, %7, %8, %6, %c0] : memref<2x2x3x3x16x1xf32> + affine.store %9, %1[%i0, %i1] : memref<64x9xf32> } } affine.for 
%i2 = 0 to 9 { affine.for %i3 = 0 to 4 { affine.for %i4 = 0 to 16 { %10 = affine.apply #map10(%i3, %i4) - %11 = load %1[%10, %i2] : memref<64x9xf32> + %11 = affine.load %1[%10, %i2] : memref<64x9xf32> } affine.for %i5 = 0 to 16 { %13 = "bar"() : () -> f32 %14 = affine.apply #map11(%i2, %i5) - store %13, %2[%14, %i3] : memref<144x4xf32> + affine.store %13, %2[%14, %i3] : memref<144x4xf32> } } } @@ -2226,7 +2017,7 @@ func @fuse_across_varying_dims_complex() { affine.for %i8 = 0 to 4 { affine.for %i9 = 0 to 16 { %15 = affine.apply #map12(%i8, %i9) - %16 = load %1[%15, %i7] : memref<64x9xf32> + %16 = affine.load %1[%15, %i7] : memref<64x9xf32> } } } @@ -2238,8 +2029,6 @@ func @fuse_across_varying_dims_complex() { // MAXIMAL-DAG: [[MAP2:#map[0-9]+]] = (d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8) // MAXIMAL-DAG: [[MAP3:#map[0-9]+]] = (d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3) // MAXIMAL-DAG: [[MAP4:#map[0-9]+]] = (d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3) -// MAXIMAL-DAG: [[MAP5:#map[0-9]+]] = (d0, d1, d2) -> (d1) -// MAXIMAL-DAG: [[MAP6:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2) // MAXIMAL-DAG: [[MAP7:#map[0-9]+]] = (d0, d1) -> (d0 * 16 + d1) // MAXIMAL-DAG: [[MAP8:#map[0-9]+]] = (d0, d1) -> (d0 * 16 - d1 + 15) // MAXIMAL-LABEL: func @fuse_across_varying_dims_complex @@ -2257,35 +2046,28 @@ func @fuse_across_varying_dims_complex() { // MAXIMAL-NEXT: %5 = affine.apply [[MAP2]](%i4, %i0) // MAXIMAL-NEXT: %6 = affine.apply [[MAP3]](%i4, %i0) // MAXIMAL-NEXT: %7 = affine.apply [[MAP4]](%i4, %i0) -// MAXIMAL-NEXT: %8 = load %1[%3, %4, %6, %7, %5, %c0] : memref<2x2x3x3x16x1xf32> -// MAXIMAL-NEXT: %9 = affine.apply [[MAP5]](%i0, %i4, %i0) -// MAXIMAL-NEXT: %10 = affine.apply [[MAP6]](%i0, %i4, %i0) -// MAXIMAL-NEXT: store %8, %0[%9, %10] : memref<64x1xf32> +// MAXIMAL-NEXT: %8 = affine.load %1[%3, %4, %6, %7, %5, %c0] : memref<2x2x3x3x16x1xf32> +// MAXIMAL-NEXT: affine.store %8, %0[%i4, 0] : memref<64x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: affine.for %i5 = 0 to 4 { // MAXIMAL-NEXT: affine.for %i6 = 0 to 16 { -// MAXIMAL-NEXT: %11 = affine.apply [[MAP7]](%i5, %i6) -// MAXIMAL-NEXT: %12 = affine.apply [[MAP5]](%i0, %11, %i0) -// MAXIMAL-NEXT: %13 = affine.apply [[MAP6]](%i0, %11, %i0) -// MAXIMAL-NEXT: %14 = load %0[%12, %13] : memref<64x1xf32> +// MAXIMAL-NEXT: %9 = affine.apply [[MAP7]](%i5, %i6) +// MAXIMAL-NEXT: %10 = affine.load %0[%i5 * 16 + %i6, 0] : memref<64x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: affine.for %i7 = 0 to 16 { -// MAXIMAL-NEXT: %15 = "bar"() : () -> f32 -// MAXIMAL-NEXT: %16 = affine.apply [[MAP7]](%i0, %i7) -// MAXIMAL-NEXT: store %15, %2[%16, %i5] : memref<144x4xf32> +// MAXIMAL-NEXT: %11 = "bar"() : () -> f32 +// MAXIMAL-NEXT: %12 = affine.apply [[MAP7]](%i0, %i7) +// MAXIMAL-NEXT: affine.store %11, %2[%12, %i5] : memref<144x4xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } -// MAXIMAL-NEXT: %17 = affine.apply [[MAP8]](%i2, %i3) -// MAXIMAL-NEXT: %18 = affine.apply [[MAP5]](%i0, %17, %i0) -// MAXIMAL-NEXT: %19 = affine.apply [[MAP6]](%i0, %17, %i0) -// MAXIMAL-NEXT: %20 = load %0[%18, %19] : memref<64x1xf32> +// MAXIMAL-NEXT: %13 = affine.apply [[MAP8]](%i2, %i3) +// MAXIMAL-NEXT: %14 = affine.load %0[%i2 * 16 - %i3 + 15, 0] : memref<64x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // ----- -// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0) -> (d0 - 10) func @should_fuse_with_slice_union() { %a = alloc() : memref<100xf32> @@ -2293,13 +2075,13 @@ func 
@should_fuse_with_slice_union() { %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 100 { - store %cf0, %a[%i0]: memref<100xf32> + affine.store %cf0, %a[%i0]: memref<100xf32> } affine.for %i1 = 10 to 20 { - %v0 = load %a[%i1]: memref<100xf32> + %v0 = affine.load %a[%i1]: memref<100xf32> affine.for %i2 = 15 to 25 { - %v1 = load %a[%i2]: memref<100xf32> + %v1 = affine.load %a[%i2]: memref<100xf32> } } // The union of two slice bounds (calculated between the store and each of @@ -2309,14 +2091,11 @@ func @should_fuse_with_slice_union() { // the fused loops based on the union calculation. // CHECK: affine.for %i0 = 10 to 20 { // CHECK-NEXT: affine.for %i1 = 10 to 25 { -// CHECK-NEXT: %1 = affine.apply [[MAP3]](%i1) -// CHECK-NEXT: store %cst, %0[%1] : memref<15xf32> +// CHECK-NEXT: affine.store %cst, %0[%i1 - 10] : memref<15xf32> // CHECK-NEXT: } -// CHECK-NEXT: %2 = affine.apply [[MAP3]](%i0) -// CHECK-NEXT: %3 = load %0[%2] : memref<15xf32> +// CHECK-NEXT: %1 = affine.load %0[%i0 - 10] : memref<15xf32> // CHECK-NEXT: affine.for %i2 = 15 to 25 { -// CHECK-NEXT: %4 = affine.apply [[MAP3]](%i2) -// CHECK-NEXT: %5 = load %0[%4] : memref<15xf32> +// CHECK-NEXT: %2 = affine.load %0[%i2 - 10] : memref<15xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -2328,21 +2107,21 @@ func @should_fuse_with_slice_union() { func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) { affine.for %i2 = 0 to 1024 { affine.for %i3 = 0 to 1024 { - %0 = load %arg3[%i2, %i3] : memref<1024x1024xf32> - %1 = load %arg2[%i2, %i3] : memref<1024x1024xf32> + %0 = affine.load %arg3[%i2, %i3] : memref<1024x1024xf32> + %1 = affine.load %arg2[%i2, %i3] : memref<1024x1024xf32> %2 = addf %1, %0 : f32 - store %2, %arg2[%i2, %i3] : memref<1024x1024xf32> + affine.store %2, %arg2[%i2, %i3] : memref<1024x1024xf32> } } affine.for %i4 = 0 to 1024 { affine.for %i5 = 0 to 1024 { affine.for %i6 = 0 to 1024 { - %3 = load %arg1[%i6, %i5] : memref<1024x1024xf32> - %4 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %3 = affine.load %arg1[%i6, %i5] : memref<1024x1024xf32> + %4 = affine.load %arg0[%i4, %i6] : memref<1024x1024xf32> %5 = mulf %4, %3 : f32 - %6 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %6 = affine.load %arg2[%i4, %i5] : memref<1024x1024xf32> %7 = addf %6, %5 : f32 - store %7, %arg2[%i4, %i5] : memref<1024x1024xf32> + affine.store %7, %arg2[%i4, %i5] : memref<1024x1024xf32> } } } @@ -2350,17 +2129,17 @@ func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024x // dependence between load/store on '%arg2', carried on reduction loop %i6. 
// CHECK: affine.for %i0 = 0 to 1024 { // CHECK-NEXT: affine.for %i1 = 0 to 1024 { - // CHECK-NEXT: %0 = load %arg3[%i0, %i1] : memref<1024x1024xf32> - // CHECK-NEXT: %1 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %0 = affine.load %arg3[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = affine.load %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: %2 = addf %1, %0 : f32 - // CHECK-NEXT: store %2, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %2, %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: affine.for %i2 = 0 to 1024 { - // CHECK-NEXT: %3 = load %arg1[%i2, %i1] : memref<1024x1024xf32> - // CHECK-NEXT: %4 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %3 = affine.load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = affine.load %arg0[%i0, %i2] : memref<1024x1024xf32> // CHECK-NEXT: %5 = mulf %4, %3 : f32 - // CHECK-NEXT: %6 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = affine.load %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: %7 = addf %6, %5 : f32 - // CHECK-NEXT: store %7, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %7, %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -2373,35 +2152,35 @@ func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32 %cst = constant 0.000000e+00 : f32 affine.for %i0 = 0 to 1024 { affine.for %i1 = 0 to 1024 { - store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32> + affine.store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32> } } affine.for %i2 = 0 to 1024 { affine.for %i3 = 0 to 1024 { - store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32> + affine.store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32> } } affine.for %i4 = 0 to 1024 { affine.for %i5 = 0 to 1024 { affine.for %i6 = 0 to 1024 { - %0 = load %arg1[%i6, %i5] : memref<1024x1024xf32> - %1 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %0 = affine.load %arg1[%i6, %i5] : memref<1024x1024xf32> + %1 = affine.load %arg0[%i4, %i6] : memref<1024x1024xf32> %2 = mulf %1, %0 : f32 - %3 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %3 = affine.load %arg2[%i4, %i5] : memref<1024x1024xf32> %4 = addf %3, %2 : f32 - store %4, %arg2[%i4, %i5] : memref<1024x1024xf32> + affine.store %4, %arg2[%i4, %i5] : memref<1024x1024xf32> } } } affine.for %i7 = 0 to 1024 { affine.for %i8 = 0 to 1024 { affine.for %i9 = 0 to 1024 { - %5 = load %arg1[%i9, %i8] : memref<1024x1024xf32> - %6 = load %arg0[%i7, %i9] : memref<1024x1024xf32> + %5 = affine.load %arg1[%i9, %i8] : memref<1024x1024xf32> + %6 = affine.load %arg0[%i7, %i9] : memref<1024x1024xf32> %7 = mulf %6, %5 : f32 - %8 = load %arg4[%i7, %i8] : memref<1024x1024xf32> + %8 = affine.load %arg4[%i7, %i8] : memref<1024x1024xf32> %9 = addf %8, %7 : f32 - store %9, %arg4[%i7, %i8] : memref<1024x1024xf32> + affine.store %9, %arg4[%i7, %i8] : memref<1024x1024xf32> } } } @@ -2411,25 +2190,25 @@ func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32 // CHECK: affine.for %i0 = 0 to 1024 { // CHECK-NEXT: affine.for %i1 = 0 to 1024 { - // CHECK-NEXT: store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: affine.for %i2 = 0 to 1024 { - // CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> - // CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %0 = affine.load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = 
affine.load %arg0[%i0, %i2] : memref<1024x1024xf32> // CHECK-NEXT: %2 = mulf %1, %0 : f32 - // CHECK-NEXT: %3 = load %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %3 = affine.load %arg4[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: %4 = addf %3, %2 : f32 - // CHECK-NEXT: store %4, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %4, %arg4[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.for %i3 = 0 to 1024 { - // CHECK-NEXT: store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32> // CHECK-NEXT: affine.for %i4 = 0 to 1024 { - // CHECK-NEXT: %5 = load %arg1[%i4, %i3] : memref<1024x1024xf32> - // CHECK-NEXT: %6 = load %arg0[%i0, %i4] : memref<1024x1024xf32> + // CHECK-NEXT: %5 = affine.load %arg1[%i4, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = affine.load %arg0[%i0, %i4] : memref<1024x1024xf32> // CHECK-NEXT: %7 = mulf %6, %5 : f32 - // CHECK-NEXT: %8 = load %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %8 = affine.load %arg2[%i0, %i3] : memref<1024x1024xf32> // CHECK-NEXT: %9 = addf %8, %7 : f32 - // CHECK-NEXT: store %9, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %9, %arg2[%i0, %i3] : memref<1024x1024xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -2443,24 +2222,24 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10 affine.for %i0 = 0 to 1024 { affine.for %i1 = 0 to 1024 { affine.for %i2 = 0 to 1024 { - %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> - %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + %0 = affine.load %arg1[%i2, %i1] : memref<1024x1024xf32> + %1 = affine.load %arg0[%i0, %i2] : memref<1024x1024xf32> %2 = mulf %1, %0 : f32 - %3 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + %3 = affine.load %arg2[%i0, %i1] : memref<1024x1024xf32> %4 = addf %3, %2 : f32 - store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> + affine.store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> } } } affine.for %i3 = 0 to 1024 { affine.for %i4 = 0 to 1024 { affine.for %i5 = 0 to 1024 { - %5 = load %arg3[%i5, %i4] : memref<1024x1024xf32> - %6 = load %arg2[%i3, %i5] : memref<1024x1024xf32> + %5 = affine.load %arg3[%i5, %i4] : memref<1024x1024xf32> + %6 = affine.load %arg2[%i3, %i5] : memref<1024x1024xf32> %7 = mulf %6, %5 : f32 - %8 = load %arg4[%i3, %i4] : memref<1024x1024xf32> + %8 = affine.load %arg4[%i3, %i4] : memref<1024x1024xf32> %9 = addf %8, %7 : f32 - store %9, %arg4[%i3, %i4] : memref<1024x1024xf32> + affine.store %9, %arg4[%i3, %i4] : memref<1024x1024xf32> } } } @@ -2468,22 +2247,22 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10 // CHECK: affine.for %i0 = 0 to 1024 { // CHECK-NEXT: affine.for %i1 = 0 to 1024 { // CHECK-NEXT: affine.for %i2 = 0 to 1024 { - // CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> - // CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %0 = affine.load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = affine.load %arg0[%i0, %i2] : memref<1024x1024xf32> // CHECK-NEXT: %2 = mulf %1, %0 : f32 - // CHECK-NEXT: %3 = load %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %3 = affine.load %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: %4 = addf %3, %2 : f32 - // CHECK-NEXT: store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %4, %arg2[%i0, %i1] : memref<1024x1024xf32> // CHECK-NEXT: } // CHECK-NEXT: } // 
CHECK-NEXT: affine.for %i3 = 0 to 1024 { // CHECK-NEXT: affine.for %i4 = 0 to 1024 { - // CHECK-NEXT: %5 = load %arg3[%i4, %i3] : memref<1024x1024xf32> - // CHECK-NEXT: %6 = load %arg2[%i0, %i4] : memref<1024x1024xf32> + // CHECK-NEXT: %5 = affine.load %arg3[%i4, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = affine.load %arg2[%i0, %i4] : memref<1024x1024xf32> // CHECK-NEXT: %7 = mulf %6, %5 : f32 - // CHECK-NEXT: %8 = load %arg4[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %8 = affine.load %arg4[%i0, %i3] : memref<1024x1024xf32> // CHECK-NEXT: %9 = addf %8, %7 : f32 - // CHECK-NEXT: store %9, %arg4[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: affine.store %9, %arg4[%i0, %i3] : memref<1024x1024xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index 4173386d585..0493ed18ec3 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -8,7 +8,7 @@ func @nested_loops_both_having_invariant_code() { affine.for %i0 = 0 to 10 { %v0 = addf %cf7, %cf8 : f32 affine.for %i1 = 0 to 10 { - store %v0, %m[%i0] : memref<10xf32> + affine.store %v0, %m[%i0] : memref<10xf32> } } @@ -17,7 +17,7 @@ func @nested_loops_both_having_invariant_code() { // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %1 = addf %cst, %cst_0 : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -31,14 +31,14 @@ func @store_affine_apply() -> memref<10xf32> { %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { %t0 = affine.apply (d1) -> (d1 + 1)(%i0) - store %cf7, %m[%t0] : memref<10xf32> + affine.store %cf7, %m[%t0] : memref<10xf32> } return %m : memref<10xf32> // CHECK: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %0 = alloc() : memref<10xf32> // CHECK-NEXT: affine.for %i0 = 0 to 10 { -// CHECK-NEXT: %1 = affine.apply #map2(%i0) -// CHECK-NEXT: store %cst, %0[%1] : memref<10xf32> +// CHECK-NEXT: %1 = affine.apply #map3(%i0) +// CHECK-NEXT: affine.store %cst, %0[%1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> } @@ -66,19 +66,19 @@ func @single_loop_nothing_invariant() { %m1 = alloc() : memref<10xf32> %m2 = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - %v0 = load %m1[%i0] : memref<10xf32> - %v1 = load %m2[%i0] : memref<10xf32> + %v0 = affine.load %m1[%i0] : memref<10xf32> + %v1 = affine.load %m2[%i0] : memref<10xf32> %v2 = addf %v0, %v1 : f32 - store %v2, %m1[%i0] : memref<10xf32> + affine.store %v2, %m1[%i0] : memref<10xf32> } // CHECK: %0 = alloc() : memref<10xf32> // CHECK-NEXT: %1 = alloc() : memref<10xf32> // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> - // CHECK-NEXT: %3 = load %1[%i0] : memref<10xf32> + // CHECK-NEXT: %2 = affine.load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %1[%i0] : memref<10xf32> // CHECK-NEXT: %4 = addf %2, %3 : f32 - // CHECK-NEXT: store %4, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %4, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -93,7 +93,7 @@ func @invariant_code_inside_affine_if() { %t0 = affine.apply (d1) -> (d1 + 1)(%i0) affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %t0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : 
memref<10xf32> } } @@ -101,10 +101,10 @@ func @invariant_code_inside_affine_if() { // CHECK: %0 = alloc() : memref<10xf32> // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = affine.apply #map2(%i0) + // CHECK-NEXT: %1 = affine.apply #map3(%i0) // CHECK-NEXT: affine.if #set0(%i0, %1) { // CHECK-NEXT: %2 = addf %cst, %cst : f32 - // CHECK-NEXT: store %2, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %2, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -122,8 +122,8 @@ func @dependent_stores() { %v0 = addf %cf7, %cf8 : f32 affine.for %i1 = 0 to 10 { %v1 = addf %cf7, %cf7 : f32 - store %v1, %m[%i1] : memref<10xf32> - store %v0, %m[%i0] : memref<10xf32> + affine.store %v1, %m[%i1] : memref<10xf32> + affine.store %v0, %m[%i0] : memref<10xf32> } } @@ -135,8 +135,8 @@ func @dependent_stores() { // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %2, %0[%i1] : memref<10xf32> - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %2, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -152,8 +152,8 @@ func @independent_stores() { %v0 = addf %cf7, %cf8 : f32 affine.for %i1 = 0 to 10 { %v1 = addf %cf7, %cf7 : f32 - store %v0, %m[%i0] : memref<10xf32> - store %v1, %m[%i1] : memref<10xf32> + affine.store %v0, %m[%i0] : memref<10xf32> + affine.store %v1, %m[%i1] : memref<10xf32> } } @@ -164,8 +164,8 @@ func @independent_stores() { // CHECK-NEXT: %2 = addf %cst, %cst : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> - // CHECK-NEXT: store %2, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %2, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -181,8 +181,8 @@ func @load_dependent_store() { %v0 = addf %cf7, %cf8 : f32 affine.for %i1 = 0 to 10 { %v1 = addf %cf7, %cf7 : f32 - store %v0, %m[%i1] : memref<10xf32> - %v2 = load %m[%i0] : memref<10xf32> + affine.store %v0, %m[%i1] : memref<10xf32> + %v2 = affine.load %m[%i0] : memref<10xf32> } } @@ -193,8 +193,8 @@ func @load_dependent_store() { // CHECK-NEXT: %2 = addf %cst, %cst : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %1, %0[%i1] : memref<10xf32> - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i1] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -210,8 +210,8 @@ func @load_after_load() { %v0 = addf %cf7, %cf8 : f32 affine.for %i1 = 0 to 10 { %v1 = addf %cf7, %cf7 : f32 - %v3 = load %m[%i1] : memref<10xf32> - %v2 = load %m[%i0] : memref<10xf32> + %v3 = affine.load %m[%i1] : memref<10xf32> + %v2 = affine.load %m[%i0] : memref<10xf32> } } @@ -221,9 +221,9 @@ func @load_after_load() { // CHECK-NEXT: %1 = addf %cst, %cst_0 : f32 // CHECK-NEXT: %2 = addf %cst, %cst : f32 // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: %4 = load %0[%i1] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // 
CHECK-NEXT: return @@ -237,7 +237,7 @@ func @invariant_affine_if() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> } } @@ -248,7 +248,7 @@ func @invariant_affine_if() { // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -263,7 +263,7 @@ func @invariant_affine_if2() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i1] : memref<10xf32> + affine.store %cf9, %m[%i1] : memref<10xf32> } } @@ -275,7 +275,7 @@ func @invariant_affine_if2() { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: store %1, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -291,9 +291,9 @@ func @invariant_affine_nested_if() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { - store %cf9, %m[%i1] : memref<10xf32> + affine.store %cf9, %m[%i1] : memref<10xf32> } } } @@ -305,9 +305,9 @@ func @invariant_affine_nested_if() { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.if #set0(%i0, %i0) { - // CHECK-NEXT: store %1, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -324,11 +324,11 @@ func @invariant_affine_nested_if_else() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> } else { - store %cf9, %m[%i1] : memref<10xf32> + affine.store %cf9, %m[%i1] : memref<10xf32> } } } @@ -340,11 +340,11 @@ func @invariant_affine_nested_if_else() { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.if #set0(%i0, %i0) { - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: } else { - // CHECK-NEXT: store %1, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -362,11 +362,11 @@ func @invariant_affine_nested_if_else2() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - %tload1 = load %m[%i0] : memref<10xf32> + %tload1 = affine.load %m[%i0] : memref<10xf32> affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { - store %cf9, %m2[%i0] : 
memref<10xf32> + affine.store %cf9, %m2[%i0] : memref<10xf32> } else { - %tload2 = load %m[%i0] : memref<10xf32> + %tload2 = affine.load %m[%i0] : memref<10xf32> } } } @@ -378,11 +378,11 @@ func @invariant_affine_nested_if_else2() { // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %2 = addf %cst, %cst : f32 - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.if #set0(%i0, %i0) { - // CHECK-NEXT: store %2, %1[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %2, %1[%i0] : memref<10xf32> // CHECK-NEXT: } else { - // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %4 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -399,9 +399,9 @@ func @invariant_affine_nested_if2() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - %v1 = load %m[%i0] : memref<10xf32> + %v1 = affine.load %m[%i0] : memref<10xf32> affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { - %v2 = load %m[%i0] : memref<10xf32> + %v2 = affine.load %m[%i0] : memref<10xf32> } } } @@ -412,9 +412,9 @@ func @invariant_affine_nested_if2() { // CHECK-NEXT: affine.for %i0 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %2 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.if #set0(%i0, %i0) { - // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: %3 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -430,9 +430,9 @@ func @invariant_affine_for_inside_affine_if() { affine.for %i1 = 0 to 10 { affine.if (d0, d1) : (d1 - d0 >= 0) (%i0, %i0) { %cf9 = addf %cf8, %cf8 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> affine.for %i2 = 0 to 10 { - store %cf9, %m[%i2] : memref<10xf32> + affine.store %cf9, %m[%i2] : memref<10xf32> } } } @@ -444,9 +444,9 @@ func @invariant_affine_for_inside_affine_if() { // CHECK-NEXT: affine.for %i1 = 0 to 10 { // CHECK-NEXT: affine.if #set0(%i0, %i0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 - // CHECK-NEXT: store %1, %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.for %i2 = 0 to 10 { - // CHECK-NEXT: store %1, %0[%i2] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -462,16 +462,16 @@ func @invariant_constant_and_load() { %m2 = alloc() : memref<100xf32> affine.for %i0 = 0 to 5 { %c0 = constant 0 : index - %v = load %m2[%c0] : memref<100xf32> - store %v, %m[%i0] : memref<100xf32> + %v = affine.load %m2[%c0] : memref<100xf32> + affine.store %v, %m[%i0] : memref<100xf32> } // CHECK: %0 = alloc() : memref<100xf32> // CHECK-NEXT: %1 = alloc() : memref<100xf32> // CHECK-NEXT: %c0 = constant 0 : index - // CHECK-NEXT: %2 = load %1[%c0] : memref<100xf32> + // CHECK-NEXT: %2 = affine.load %1[%c0] : memref<100xf32> // CHECK-NEXT: affine.for %i0 = 0 to 5 { - // CHECK-NEXT: store %2, %0[%i0] : memref<100xf32> + // CHECK-NEXT: affine.store %2, %0[%i0] : memref<100xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -484,9 +484,9 @@ func @nested_load_store_same_memref() { %cst = constant 8.0 : f32 %c0 = constant 0 : index affine.for %i0 = 0 to 10 { - %v0 = load %m[%c0] : memref<10xf32> + %v0 = affine.load %m[%c0] : memref<10xf32> 
affine.for %i1 = 0 to 10 { - store %cst, %m[%i1] : memref<10xf32> + affine.store %cst, %m[%i1] : memref<10xf32> } } @@ -494,9 +494,9 @@ func @nested_load_store_same_memref() { // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: %c0 = constant 0 : index // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: %1 = load %0[%c0] : memref<10xf32> + // CHECK-NEXT: %1 = affine.load %0[%c0] : memref<10xf32> // CHECK-NEXT: affine.for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return @@ -510,9 +510,9 @@ func @nested_load_store_same_memref2() { %cst = constant 8.0 : f32 %c0 = constant 0 : index affine.for %i0 = 0 to 10 { - store %cst, %m[%c0] : memref<10xf32> + affine.store %cst, %m[%c0] : memref<10xf32> affine.for %i1 = 0 to 10 { - %v0 = load %m[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> } } @@ -520,8 +520,8 @@ func @nested_load_store_same_memref2() { // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: %c0 = constant 0 : index // CHECK-NEXT: affine.for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %0[%c0] : memref<10xf32> - // CHECK-NEXT: %1 = load %0[%i0] : memref<10xf32> + // CHECK-NEXT: affine.store %cst, %0[%c0] : memref<10xf32> + // CHECK-NEXT: %1 = affine.load %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return diff --git a/mlir/test/Transforms/loop-tiling.mlir b/mlir/test/Transforms/loop-tiling.mlir index b57ee2d7e35..bf1b0ac6dfd 100644 --- a/mlir/test/Transforms/loop-tiling.mlir +++ b/mlir/test/Transforms/loop-tiling.mlir @@ -87,12 +87,12 @@ func @simple_matmul(%arg0: memref<256x256xvector<64xf32>>, %arg1: memref<256x256 affine.for %i = 0 to 256 { affine.for %j = 0 to 256 { affine.for %k = 0 to 250 { - %l = load %arg0[%i, %k] : memref<256x256xvector<64xf32>> - %r = load %arg1[%k, %j] : memref<256x256xvector<64xf32>> - %o = load %arg2[%i, %j] : memref<256x256xvector<64xf32>> + %l = affine.load %arg0[%i, %k] : memref<256x256xvector<64xf32>> + %r = affine.load %arg1[%k, %j] : memref<256x256xvector<64xf32>> + %o = affine.load %arg2[%i, %j] : memref<256x256xvector<64xf32>> %m = mulf %l, %r : vector<64xf32> %a = addf %o, %m : vector<64xf32> - store %a, %arg2[%i, %j] : memref<256x256xvector<64xf32>> + affine.store %a, %arg2[%i, %j] : memref<256x256xvector<64xf32>> } } } @@ -112,14 +112,14 @@ func @tile_with_symbolic_loop_upper_bounds(%arg0: memref, %arg1: memref %0 = dim %arg0, 0 : memref affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %0 { - store %cst, %arg2[%i0, %i1] : memref + affine.store %cst, %arg2[%i0, %i1] : memref affine.for %i2 = 0 to %0 { - %1 = load %arg0[%i0, %i2] : memref - %2 = load %arg1[%i2, %i1] : memref + %1 = affine.load %arg0[%i0, %i2] : memref + %2 = affine.load %arg1[%i2, %i1] : memref %3 = mulf %1, %2 : f32 - %4 = load %arg2[%i0, %i1] : memref + %4 = affine.load %arg2[%i0, %i1] : memref %5 = addf %4, %3 : f32 - store %5, %arg2[%i0, %i1] : memref + affine.store %5, %arg2[%i0, %i1] : memref } } } @@ -129,16 +129,16 @@ func @tile_with_symbolic_loop_upper_bounds(%arg0: memref, %arg1: memref // CHECK: %0 = dim %arg0, 0 : memref // CHECK-NEXT: affine.for %i0 = 0 to %0 step 32 { // CHECK-NEXT: affine.for %i1 = 0 to %0 step 32 { -// CHECK-NEXT: affine.for %i2 = #map2(%i0) to min [[UBMAP]](%i0)[%0] { -// CHECK-NEXT: affine.for %i3 = #map2(%i1) to min [[UBMAP]](%i1)[%0] { -// CHECK-NEXT: store %cst, %arg2[%i2, %i3] : memref +// CHECK-NEXT: affine.for %i2 = #map3(%i0) to min 
[[UBMAP]](%i0)[%0] { +// CHECK-NEXT: affine.for %i3 = #map3(%i1) to min [[UBMAP]](%i1)[%0] { +// CHECK-NEXT: affine.store %cst, %arg2[%i2, %i3] : memref // CHECK-NEXT: affine.for %i4 = 0 to %0 { -// CHECK-NEXT: %1 = load %arg0[%i2, %i4] : memref -// CHECK-NEXT: %2 = load %arg1[%i4, %i3] : memref +// CHECK-NEXT: %1 = affine.load %arg0[%i2, %i4] : memref +// CHECK-NEXT: %2 = affine.load %arg1[%i4, %i3] : memref // CHECK-NEXT: %3 = mulf %1, %2 : f32 -// CHECK-NEXT: %4 = load %arg2[%i2, %i3] : memref +// CHECK-NEXT: %4 = affine.load %arg2[%i2, %i3] : memref // CHECK-NEXT: %5 = addf %4, %3 : f32 -// CHECK-NEXT: store %5, %arg2[%i2, %i3] : memref +// CHECK-NEXT: affine.store %5, %arg2[%i2, %i3] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -155,7 +155,7 @@ func @tile_with_symbolic_loop_upper_bounds(%arg0: memref, %arg1: memref func @tile_with_loop_upper_bounds_in_two_symbols(%arg0: memref, %limit: index) { %dim0 = dim %arg0, 0 : memref affine.for %i0 = 0 to ()[s0, s1] -> (s0 + s1) ()[%dim0, %limit] { - %v0 = load %arg0[%i0] : memref + %v0 = affine.load %arg0[%i0] : memref } return } @@ -163,7 +163,7 @@ func @tile_with_loop_upper_bounds_in_two_symbols(%arg0: memref, %limit: i // CHECK: %0 = dim %arg0, 0 : memref // CHECK-NEXT: affine.for %i0 = 0 to [[MAP1]]()[%0, %arg1] step 32 { // CHECK-NEXT: affine.for %i1 = [[MAP0]](%i0) to min [[UBMAP]](%i0)[%0, %arg1] { -// CHECK-NEXT: %1 = load %arg0[%i1] : memref +// CHECK-NEXT: %1 = affine.load %arg0[%i1] : memref // CHECK-NEXT: } // CHECK-NEXT: } @@ -173,12 +173,12 @@ func @trip_count_1(%arg0: memref<196608x1xf32>, %arg1: memref<196608x1xf32>) -> memref<196608x1xf32> { affine.for %i1 = 0 to 196608 { affine.for %i3 = 0 to 1 { - %4 = load %arg0[%i1, %i3] : memref<196608x1xf32> - store %4, %arg1[%i1, %i3] : memref<196608x1xf32> + %4 = affine.load %arg0[%i1, %i3] : memref<196608x1xf32> + affine.store %4, %arg1[%i1, %i3] : memref<196608x1xf32> } } return %arg1 : memref<196608x1xf32> } -// CHECK: %0 = load %arg0[%i2, %i3] : memref<196608x1xf32> +// CHECK: %0 = affine.load %arg0[%i2, %i3] : memref<196608x1xf32> diff --git a/mlir/test/Transforms/memref-bound-check.mlir b/mlir/test/Transforms/memref-bound-check.mlir index cbad0aeb12d..b83e9c9e71a 100644 --- a/mlir/test/Transforms/memref-bound-check.mlir +++ b/mlir/test/Transforms/memref-bound-check.mlir @@ -16,24 +16,24 @@ func @test() { %idx0 = affine.apply (d0, d1) -> (d0)(%i, %j) %idx1 = affine.apply (d0, d1) -> (d1)(%i, %j) // Out of bound access. - %x = load %A[%idx0, %idx1] : memref<9 x 9 x i32> - // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} - // expected-error@-2 {{'std.load' op memref out of lower bound access along dimension #1}} - // expected-error@-3 {{'std.load' op memref out of upper bound access along dimension #2}} - // expected-error@-4 {{'std.load' op memref out of lower bound access along dimension #2}} + %x = affine.load %A[%idx0, %idx1] : memref<9 x 9 x i32> + // expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}} + // expected-error@-2 {{'affine.load' op memref out of lower bound access along dimension #1}} + // expected-error@-3 {{'affine.load' op memref out of upper bound access along dimension #2}} + // expected-error@-4 {{'affine.load' op memref out of lower bound access along dimension #2}} // This will access 0 to 110 - hence an overflow. 
%idy = affine.apply (d0, d1) -> (10*d0 - d1 + 19)(%i, %j) - %y = load %B[%idy] : memref<111 x i32> + %y = affine.load %B[%idy] : memref<111 x i32> } } affine.for %k = 0 to 10 { // In bound. - %u = load %B[%zero] : memref<111 x i32> + %u = affine.load %B[%zero] : memref<111 x i32> // Out of bounds. - %v = load %B[%sym] : memref<111 x i32> // expected-error {{'std.load' op memref out of upper bound access along dimension #1}} + %v = affine.load %B[%sym] : memref<111 x i32> // expected-error {{'affine.load' op memref out of upper bound access along dimension #1}} // Out of bounds. - store %v, %B[%minusone] : memref<111 x i32> // expected-error {{'std.store' op memref out of lower bound access along dimension #1}} + affine.store %v, %B[%minusone] : memref<111 x i32> // expected-error {{'affine.store' op memref out of lower bound access along dimension #1}} } return } @@ -48,14 +48,14 @@ func @test_mod_floordiv_ceildiv() { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) - %x = load %A[%idx0, %idx1, %idx2] : memref<128 x 64 x 64 x i32> - // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} - // expected-error@-2 {{'std.load' op memref out of upper bound access along dimension #2}} - // expected-error@-3 {{'std.load' op memref out of upper bound access along dimension #3}} + %x = affine.load %A[%idx0, %idx1, %idx2] : memref<128 x 64 x 64 x i32> + // expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}} + // expected-error@-2 {{'affine.load' op memref out of upper bound access along dimension #2}} + // expected-error@-3 {{'affine.load' op memref out of upper bound access along dimension #3}} %idy0 = affine.apply (d0, d1, d2) -> (d0 mod 128)(%i, %j, %j) %idy1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4)(%i, %j, %j) %idy2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4 - 1)(%i, %j, %j) - store %x, %A[%idy0, %idy1, %idy2] : memref<128 x 64 x 64 x i32> // expected-error {{'std.store' op memref out of lower bound access along dimension #3}} + affine.store %x, %A[%idy0, %idy1, %idy2] : memref<128 x 64 x 64 x i32> // expected-error {{'affine.store' op memref out of lower bound access along dimension #3}} } // CHECK } } // CHECK } return @@ -72,16 +72,16 @@ func @test_no_out_of_bounds() { affine.for %j = 0 to 256 { // All of these accesses are in bound; check that no errors are emitted. // CHECK: %3 = affine.apply {{#map.*}}(%i0, %i1) - // CHECK-NEXT: %4 = load %0[%3, %c0] : memref<257x256xi32> + // CHECK-NEXT: %4 = affine.load %0[%3, %c0] : memref<257x256xi32> // CHECK-NEXT: %5 = affine.apply {{#map.*}}(%i0, %i0) - // CHECK-NEXT: %6 = load %2[%5] : memref<1xi32> + // CHECK-NEXT: %6 = affine.load %2[%5] : memref<1xi32> %idx0 = affine.apply (d0, d1) -> ( 64 * (d0 ceildiv 64))(%i, %j) // Without GCDTightenInequalities(), the upper bound on the region // accessed along first memref dimension would have come out as d0 <= 318 // (instead of d0 <= 256), and led to a false positive out of bounds. 
- %x = load %A[%idx0, %zero] : memref<257 x 256 x i32> + %x = affine.load %A[%idx0, %zero] : memref<257 x 256 x i32> %idy = affine.apply (d0, d1) -> (d0 floordiv 256)(%i, %i) - %y = load %B[%idy] : memref<1 x i32> + %y = affine.load %B[%idy] : memref<1 x i32> } // CHECK-NEXT } } return @@ -97,14 +97,14 @@ func @mod_div() { %idx0 = affine.apply (d0, d1, d2) -> (d0 mod 128 + 1)(%i, %j, %j) %idx1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4 + 1)(%i, %j, %j) %idx2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4)(%i, %j, %j) - %x = load %A[%idx0, %idx1, %idx2] : memref<128 x 64 x 64 x i32> - // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} - // expected-error@-2 {{'std.load' op memref out of upper bound access along dimension #2}} - // expected-error@-3 {{'std.load' op memref out of upper bound access along dimension #3}} + %x = affine.load %A[%idx0, %idx1, %idx2] : memref<128 x 64 x 64 x i32> + // expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}} + // expected-error@-2 {{'affine.load' op memref out of upper bound access along dimension #2}} + // expected-error@-3 {{'affine.load' op memref out of upper bound access along dimension #3}} %idy0 = affine.apply (d0, d1, d2) -> (d0 mod 128)(%i, %j, %j) %idy1 = affine.apply (d0, d1, d2) -> (d1 floordiv 4)(%i, %j, %j) %idy2 = affine.apply (d0, d1, d2) -> (d2 ceildiv 4 - 1)(%i, %j, %j) - store %x, %A[%idy0, %idy1, %idy2] : memref<128 x 64 x 64 x i32> // expected-error {{'std.store' op memref out of lower bound access along dimension #3}} + affine.store %x, %A[%idy0, %idy1, %idy2] : memref<128 x 64 x 64 x i32> // expected-error {{'affine.store' op memref out of lower bound access along dimension #3}} } } return @@ -118,7 +118,7 @@ func @mod_floordiv_nested() { affine.for %j = 0 to 256 { %idx0 = affine.apply (d0, d1) -> ((d0 mod 1024) floordiv 4)(%i, %j) %idx1 = affine.apply (d0, d1) -> ((((d1 mod 128) mod 32) ceildiv 4) * 32)(%i, %j) - load %A[%idx0, %idx1] : memref<256 x 256 x i32> // expected-error {{'std.load' op memref out of upper bound access along dimension #2}} + affine.load %A[%idx0, %idx1] : memref<256 x 256 x i32> // expected-error {{'affine.load' op memref out of upper bound access along dimension #2}} } } return @@ -129,7 +129,7 @@ func @test_semi_affine_bailout(%N : index) { %B = alloc() : memref<10 x i32> affine.for %i = 0 to 10 { %idx = affine.apply (d0)[s0] -> (d0 * s0)(%i)[%N] - %y = load %B[%idx] : memref<10 x i32> + %y = affine.load %B[%idx] : memref<10 x i32> // expected-error@-1 {{getMemRefRegion: compose affine map failed}} } return @@ -141,7 +141,7 @@ func @multi_mod_floordiv() { affine.for %ii = 0 to 64 { %idx0 = affine.apply (d0) -> ((d0 mod 147456) floordiv 1152) (%ii) %idx1 = affine.apply (d0) -> (((d0 mod 147456) mod 1152) floordiv 384) (%ii) - %v = load %A[%idx0, %idx1] : memref<2x2xi32> + %v = affine.load %A[%idx0, %idx1] : memref<2x2xi32> } return } @@ -169,7 +169,7 @@ func @delinearize_mod_floordiv() { %a15 = affine.apply (d0) -> ((((((d0 mod 294912) mod 147456) mod 1152) mod 384) mod 128) floordiv 128) (%a0) - %v0 = load %in[%a10, %a11, %a13, %a14, %a12, %a15] + %v0 = affine.load %in[%a10, %a11, %a13, %a14, %a12, %a15] : memref<2x2x3x3x16x1xi32> } } @@ -180,7 +180,7 @@ func @delinearize_mod_floordiv() { func @zero_d_memref(%arg0: memref) { %c0 = constant 0 : i32 // A 0-d memref always has in-bound accesses! 
- store %c0, %arg0[] : memref + affine.store %c0, %arg0[] : memref return } @@ -191,7 +191,7 @@ func @out_of_bounds() { affine.for %i0 = 10 to 11 { %idy = affine.apply (d0) -> (100 * d0 floordiv 1000) (%i0) - store %c9, %in[%idy] : memref<1xi32> // expected-error {{'std.store' op memref out of upper bound access along dimension #1}} + affine.store %c9, %in[%idy] : memref<1xi32> // expected-error {{'affine.store' op memref out of upper bound access along dimension #1}} } return } @@ -214,7 +214,7 @@ func @test_complex_mod_floordiv(%arg0: memref<4x4x16x1xf32>) { %2 = affine.apply #map3(%i0, %i1) %3 = affine.apply #map4(%i0, %i1) %4 = affine.apply #map5(%i0, %i1) - %5 = load %arg0[%2, %c0, %4, %c0] : memref<4x4x16x1xf32> + %5 = affine.load %arg0[%2, %c0, %4, %c0] : memref<4x4x16x1xf32> } } return @@ -232,9 +232,9 @@ func @test_mod_bound() { %1 = alloc() : memref<6 x f32> affine.for %i0 = 0 to 4096 { affine.for %i1 = #map0(%i0) to #map1(%i0) { - load %0[%i1] : memref<7 x f32> - load %1[%i1] : memref<6 x f32> - // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} + affine.load %0[%i1] : memref<7 x f32> + affine.load %1[%i1] : memref<6 x f32> + // expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}} } } return @@ -254,13 +254,13 @@ func @test_floordiv_bound() { %N = constant 2048 : index affine.for %i0 = 0 to 4096 { affine.for %i1 = #map0(%i0) to #map1(%i0) { - load %0[%i1] : memref<1027 x f32> - load %1[%i1] : memref<1026 x f32> - // expected-error@-1 {{'std.load' op memref out of upper bound access along dimension #1}} + affine.load %0[%i1] : memref<1027 x f32> + affine.load %1[%i1] : memref<1026 x f32> + // expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}} } affine.for %i2 = 0 to #map2(%N) { // Within bounds. 
- %v = load %2[%i2] : memref<4096 x f32> + %v = affine.load %2[%i2] : memref<4096 x f32> } } return @@ -279,7 +279,7 @@ func @non_composed_bound_operand(%arg0: memref<1024xf32>) { affine.for %i0 = 4 to 1028 step 4 { %i1 = affine.apply (d0) -> (d0 - 4) (%i0) affine.for %i2 = #map_lb(%i1) to #map_ub(%i1) { - %0 = load %arg0[%i2] : memref<1024xf32> + %0 = affine.load %arg0[%i2] : memref<1024xf32> } } return diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir index 979f4867d9b..764d5246b69 100644 --- a/mlir/test/Transforms/memref-dataflow-opt.mlir +++ b/mlir/test/Transforms/memref-dataflow-opt.mlir @@ -11,8 +11,8 @@ func @simple_store_load() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> - %v0 = load %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } return @@ -31,13 +31,13 @@ func @multi_store_load() { %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> - %v0 = load %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 - store %cf8, %m[%i0] : memref<10xf32> - store %cf9, %m[%i0] : memref<10xf32> - %v2 = load %m[%i0] : memref<10xf32> - %v3 = load %m[%i0] : memref<10xf32> + affine.store %cf8, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> + %v2 = affine.load %m[%i0] : memref<10xf32> + %v3 = affine.load %m[%i0] : memref<10xf32> %v4 = mulf %v2, %v3 : f32 } return @@ -65,9 +65,9 @@ func @store_load_affine_apply() -> memref<10x10xf32> { %t1 = affine.apply (d0, d1) -> (d0)(%i0, %i1) %idx0 = affine.apply (d0, d1) -> (d1) (%t0, %t1) %idx1 = affine.apply (d0, d1) -> (d0 - 1) (%t0, %t1) - store %cf7, %m[%idx0, %idx1] : memref<10x10xf32> - // CHECK-NOT: load %{{[0-9]+}} - %v0 = load %m[%i0, %i1] : memref<10x10xf32> + affine.store %cf7, %m[%idx0, %idx1] : memref<10x10xf32> + // CHECK-NOT: affine.load %{{[0-9]+}} + %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -81,7 +81,7 @@ func @store_load_affine_apply() -> memref<10x10xf32> { // CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1) // CHECK-NEXT: %3 = affine.apply [[MAP2]](%1, %2) // CHECK-NEXT: %4 = affine.apply [[MAP3]](%1, %2) -// CHECK-NEXT: store %cst, %0[%3, %4] : memref<10x10xf32> +// CHECK-NEXT: affine.store %cst, %0[%3, %4] : memref<10x10xf32> // CHECK-NEXT: %5 = addf %cst, %cst : f32 // CHECK-NEXT: } // CHECK-NEXT: } @@ -93,9 +93,9 @@ func @store_load_nested(%N : index) { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - %v0 = load %m[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -118,13 +118,13 @@ func @multi_store_load_nested_no_fwd(%N : index) { %cf8 = constant 8.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - store %cf8, %m[%i1] : memref<10xf32> + affine.store %cf8, %m[%i1] : memref<10xf32> } affine.for %i2 = 0 to %N { - // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> - %v0 = load %m[%i0] : memref<10xf32> + // CHECK: %{{[0-9]+}} = affine.load %0[%i0] : memref<10xf32> + %v0 = 
affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -139,12 +139,12 @@ func @store_load_store_nested_no_fwd(%N : index) { %cf9 = constant 9.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32> - %v0 = load %m[%i0] : memref<10xf32> + // CHECK: %{{[0-9]+}} = affine.load %0[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 - store %cf9, %m[%i0] : memref<10xf32> + affine.store %cf9, %m[%i0] : memref<10xf32> } } return @@ -160,17 +160,17 @@ func @multi_store_load_nested_fwd(%N : index) { %cf10 = constant 10.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - store %cf8, %m[%i1] : memref<10xf32> + affine.store %cf8, %m[%i1] : memref<10xf32> } affine.for %i2 = 0 to %N { - store %cf9, %m[%i2] : memref<10xf32> + affine.store %cf9, %m[%i2] : memref<10xf32> } - store %cf10, %m[%i0] : memref<10xf32> + affine.store %cf10, %m[%i0] : memref<10xf32> affine.for %i3 = 0 to %N { - // CHECK-NOT: %{{[0-9]+}} = load - %v0 = load %m[%i0] : memref<10xf32> + // CHECK-NOT: %{{[0-9]+}} = affine.load + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -183,11 +183,11 @@ func @store_load_no_fwd() { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to 10 { affine.for %i2 = 0 to 10 { - // CHECK: load %{{[0-9]+}} - %v0 = load %m[%i2] : memref<10xf32> + // CHECK: affine.load %{{[0-9]+}} + %v0 = affine.load %m[%i2] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -201,12 +201,12 @@ func @store_load_fwd() { %cf7 = constant 7.0 : f32 %c0 = constant 0 : index %m = alloc() : memref<10xf32> - store %cf7, %m[%c0] : memref<10xf32> + affine.store %cf7, %m[%c0] : memref<10xf32> affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { affine.for %i2 = 0 to 10 { - // CHECK-NOT: load %{{[0-9]}}+ - %v0 = load %m[%c0] : memref<10xf32> + // CHECK-NOT: affine.load %{{[0-9]}}+ + %v0 = affine.load %m[%c0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 } } @@ -224,26 +224,26 @@ func @store_load_store_nested_fwd(%N : index) -> f32 { %c1 = constant 1 : index %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - %v0 = load %m[%i0] : memref<10xf32> + %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = addf %v0, %v0 : f32 %idx = affine.apply (d0) -> (d0 + 1) (%i0) - store %cf9, %m[%idx] : memref<10xf32> + affine.store %cf9, %m[%idx] : memref<10xf32> } } // Due to this load, the memref isn't optimized away. 
- %v3 = load %m[%c1] : memref<10xf32> + %v3 = affine.load %m[%c1] : memref<10xf32> return %v3 : f32 // CHECK: %0 = alloc() : memref<10xf32> // CHECK-NEXT: affine.for %i0 = 0 to 10 { -// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> +// CHECK-NEXT: affine.store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: affine.for %i1 = 0 to %arg0 { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: %2 = affine.apply [[MAP4]](%i0) -// CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32> +// CHECK-NEXT: affine.store %cst_0, %0[%2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: %3 = load %0[%c1] : memref<10xf32> +// CHECK-NEXT: %3 = affine.load %0[%c1] : memref<10xf32> // CHECK-NEXT: return %3 : f32 } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 84acac49c36..3efc134e627 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -14,14 +14,14 @@ func @store_may_execute_before_load() { // and thus the store "may" conditionally execute before the load. affine.if #set0(%c0) { affine.for %i0 = 0 to 10 { - store %cf7, %m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = true}} } } affine.for %i1 = 0 to 10 { - %v0 = load %m[%i1] : memref<10xf32> + %v0 = affine.load %m[%i1] : memref<10xf32> // expected-remark@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 1 = false}} @@ -38,13 +38,13 @@ func @dependent_loops() { // There is a dependence from 0 to 1 at depth 1 (common surrounding loops 0) // because the first loop with the store dominates the second loop. 
affine.for %i0 = 0 to 10 { - store %cst, %0[%i0] : memref<10xf32> + affine.store %cst, %0[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = true}} } affine.for %i1 = 0 to 10 { - %1 = load %0[%i1] : memref<10xf32> + %1 = affine.load %0[%i1] : memref<10xf32> // expected-remark@-1 {{dependence from 1 to 1 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 1 = false}} @@ -59,10 +59,10 @@ func @different_memrefs() { %m.b = alloc() : memref<100xf32> %c0 = constant 0 : index %c1 = constant 1.0 : f32 - store %c1, %m.a[%c0] : memref<100xf32> + affine.store %c1, %m.a[%c0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} - %v0 = load %m.b[%c0] : memref<100xf32> + %v0 = affine.load %m.b[%c0] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -75,10 +75,10 @@ func @store_load_different_elements() { %c0 = constant 0 : index %c1 = constant 1 : index %c7 = constant 7.0 : f32 - store %c7, %m[%c0] : memref<100xf32> + affine.store %c7, %m[%c0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} - %v0 = load %m[%c1] : memref<100xf32> + %v0 = affine.load %m[%c1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -91,10 +91,10 @@ func @load_store_different_elements() { %c0 = constant 0 : index %c1 = constant 1 : index %c7 = constant 7.0 : f32 - %v0 = load %m[%c1] : memref<100xf32> + %v0 = affine.load %m[%c1] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} - store %c7, %m[%c0] : memref<100xf32> + affine.store %c7, %m[%c0] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -106,10 +106,10 @@ func @store_load_same_element() { %m = alloc() : memref<100xf32> %c11 = constant 11 : index %c7 = constant 7.0 : f32 - store %c7, %m[%c11] : memref<100xf32> + affine.store %c7, %m[%c11] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = true}} - %v0 = load %m[%c11] : memref<100xf32> + %v0 = affine.load %m[%c11] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -121,10 +121,10 @@ func @load_load_same_element() { %m = alloc() : memref<100xf32> %c11 = constant 11 : index %c7 = constant 7.0 : f32 - %v0 = load %m[%c11] : memref<100xf32> + %v0 = affine.load %m[%c11] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} - %v1 = load %m[%c11] : memref<100xf32> + %v1 = affine.load %m[%c11] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = 
false}} return @@ -135,10 +135,10 @@ func @load_load_same_element() { func @store_load_same_symbol(%arg0: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - store %c7, %m[%arg0] : memref<100xf32> + affine.store %c7, %m[%arg0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = true}} - %v0 = load %m[%arg0] : memref<100xf32> + %v0 = affine.load %m[%arg0] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -149,10 +149,10 @@ func @store_load_same_symbol(%arg0: index) { func @store_load_different_symbols(%arg0: index, %arg1: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 - store %c7, %m[%arg0] : memref<100xf32> + affine.store %c7, %m[%arg0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = true}} - %v0 = load %m[%arg1] : memref<100xf32> + %v0 = affine.load %m[%arg1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -165,11 +165,11 @@ func @store_load_diff_element_affine_apply_const() { %c1 = constant 1 : index %c8 = constant 8.0 : f32 %a0 = affine.apply (d0) -> (d0) (%c1) - store %c8, %m[%a0] : memref<100xf32> + affine.store %c8, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} %a1 = affine.apply (d0) -> (d0 + 1) (%c1) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -183,11 +183,11 @@ func @store_load_same_element_affine_apply_const() { %c9 = constant 9 : index %c11 = constant 11 : index %a0 = affine.apply (d0) -> (d0 + 1) (%c9) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = true}} %a1 = affine.apply (d0) -> (d0 - 1) (%c11) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -199,11 +199,11 @@ func @store_load_affine_apply_symbol(%arg0: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %a0 = affine.apply (d0) -> (d0) (%arg0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = true}} %a1 = affine.apply (d0) -> (d0) (%arg0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -215,11 +215,11 @@ func @store_load_affine_apply_symbol_offset(%arg0: index) { %m = alloc() : memref<100xf32> %c7 = constant 7.0 : f32 %a0 = affine.apply (d0) -> (d0) (%arg0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 
at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 1 at depth 1 = false}} %a1 = affine.apply (d0) -> (d0 + 1) (%arg0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} return @@ -233,13 +233,13 @@ func @store_range_load_after_range() { %c10 = constant 10 : index affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0) (%i0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine.apply (d0) -> (d0) (%c10) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -256,13 +256,13 @@ func @store_load_func_symbol(%arg0: index, %arg1: index) { %c10 = constant 10 : index affine.for %i0 = 0 to %arg1 { %a0 = affine.apply (d0) -> (d0) (%arg0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = [1, +inf]}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [1, +inf]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = true}} %a1 = affine.apply (d0) -> (d0) (%arg0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, +inf]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -282,7 +282,7 @@ func @store_range_load_last_in_range() { // For dependence from 0 to 1, we do not have a loop carried dependence // because only the final write in the loop accesses the same element as the // load, so this dependence appears only at depth 2 (loop independent). - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} @@ -290,7 +290,7 @@ func @store_range_load_last_in_range() { %a1 = affine.apply (d0) -> (d0 - 1) (%c10) // For dependence from 1 to 0, we have write-after-read (WAR) dependences // for all loads in the loop to the store on the last iteration. 
- %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -307,13 +307,13 @@ func @store_range_load_before_range() { %c0 = constant 0 : index affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine.apply (d0) -> (d0) (%c0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -333,13 +333,13 @@ func @store_range_load_first_in_range() { // Dependence from 0 to 1 at depth 1 is a range because all loads at // constant index zero are reads after first store at index zero during // first iteration of the loop. - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = true}} %a1 = affine.apply (d0) -> (d0 + 1) (%c0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -355,13 +355,13 @@ func @store_plus_3() { %c7 = constant 7.0 : f32 affine.for %i0 = 1 to 11 { %a0 = affine.apply (d0) -> (d0 + 3) (%i0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [3, 3]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine.apply (d0) -> (d0) (%i0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -377,13 +377,13 @@ func @load_minus_2() { %c7 = constant 7.0 : f32 affine.for %i0 = 2 to 11 { %a0 = affine.apply (d0) -> (d0) (%i0) - store %c7, %m[%a0] : memref<100xf32> + affine.store %c7, %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [2, 2]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine.apply (d0) -> (d0 - 2) (%i0) - %v0 = load %m[%a1] : memref<100xf32> + %v0 = affine.load %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = 
false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -402,7 +402,7 @@ func @perfectly_nested_loops_loop_independent() { // Dependence from access 0 to 1 is loop independent at depth = 3. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - store %c7, %m[%a00, %a01] : memref<10x10xf32> + affine.store %c7, %m[%a00, %a01] : memref<10x10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -411,7 +411,7 @@ func @perfectly_nested_loops_loop_independent() { // expected-remark@-6 {{dependence from 0 to 1 at depth 3 = true}} %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - %v0 = load %m[%a10, %a11] : memref<10x10xf32> + %v0 = affine.load %m[%a10, %a11] : memref<10x10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 3 = false}} @@ -433,7 +433,7 @@ func @perfectly_nested_loops_loop_carried_at_depth1() { // Dependence from access 0 to 1 is loop carried at depth 1. %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - store %c7, %m[%a00, %a01] : memref<10x10xf32> + affine.store %c7, %m[%a00, %a01] : memref<10x10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -442,7 +442,7 @@ func @perfectly_nested_loops_loop_carried_at_depth1() { // expected-remark@-6 {{dependence from 0 to 1 at depth 3 = false}} %a10 = affine.apply (d0, d1) -> (d0 - 2) (%i0, %i1) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - %v0 = load %m[%a10, %a11] : memref<10x10xf32> + %v0 = affine.load %m[%a10, %a11] : memref<10x10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 3 = false}} @@ -464,7 +464,7 @@ func @perfectly_nested_loops_loop_carried_at_depth2() { // Dependence from access 0 to 1 is loop carried at depth 2. 
%a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - store %c7, %m[%a00, %a01] : memref<10x10xf32> + affine.store %c7, %m[%a00, %a01] : memref<10x10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -473,7 +473,7 @@ func @perfectly_nested_loops_loop_carried_at_depth2() { // expected-remark@-6 {{dependence from 0 to 1 at depth 3 = false}} %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a11 = affine.apply (d0, d1) -> (d1 - 3) (%i0, %i1) - %v0 = load %m[%a10, %a11] : memref<10x10xf32> + %v0 = affine.load %m[%a10, %a11] : memref<10x10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 3 = false}} @@ -495,7 +495,7 @@ func @one_common_loop() { affine.for %i1 = 0 to 10 { %a00 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - store %c7, %m[%a00, %a01] : memref<10x10xf32> + affine.store %c7, %m[%a00, %a01] : memref<10x10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -505,7 +505,7 @@ func @one_common_loop() { affine.for %i2 = 0 to 9 { %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i2) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i2) - %v0 = load %m[%a10, %a11] : memref<10x10xf32> + %v0 = affine.load %m[%a10, %a11] : memref<10x10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -527,7 +527,7 @@ func @dependence_cycle() { // *) loop-carried dependence from access 3 to 0 at depth 1. 
affine.for %i0 = 0 to 9 { %a0 = affine.apply (d0) -> (d0) (%i0) - %v0 = load %m.a[%a0] : memref<100xf32> + %v0 = affine.load %m.a[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} @@ -537,7 +537,7 @@ func @dependence_cycle() { // expected-remark@-7 {{dependence from 0 to 3 at depth 1 = false}} // expected-remark@-8 {{dependence from 0 to 3 at depth 2 = false}} %a1 = affine.apply (d0) -> (d0) (%i0) - store %v0, %m.b[%a1] : memref<100xf32> + affine.store %v0, %m.b[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -547,7 +547,7 @@ func @dependence_cycle() { // expected-remark@-7 {{dependence from 1 to 3 at depth 1 = false}} // expected-remark@-8 {{dependence from 1 to 3 at depth 2 = false}} %a2 = affine.apply (d0) -> (d0) (%i0) - %v1 = load %m.b[%a2] : memref<100xf32> + %v1 = affine.load %m.b[%a2] : memref<100xf32> // expected-remark@-1 {{dependence from 2 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 2 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 2 to 1 at depth 1 = false}} @@ -557,7 +557,7 @@ func @dependence_cycle() { // expected-remark@-7 {{dependence from 2 to 3 at depth 1 = false}} // expected-remark@-8 {{dependence from 2 to 3 at depth 2 = false}} %a3 = affine.apply (d0) -> (d0 + 1) (%i0) - store %v1, %m.a[%a3] : memref<100xf32> + affine.store %v1, %m.a[%a3] : memref<100xf32> // expected-remark@-1 {{dependence from 3 to 0 at depth 1 = [1, 1]}} // expected-remark@-2 {{dependence from 3 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 3 to 1 at depth 1 = false}} @@ -579,7 +579,7 @@ func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { affine.for %i1 = 0 to %arg1 { %a00 = affine.apply (d0, d1) -> (d0 - 1) (%i0, %i1) %a01 = affine.apply (d0, d1) -> (d1 + 1) (%i0, %i1) - %v0 = load %m[%a00, %a01] : memref<10x10xf32> + %v0 = affine.load %m[%a00, %a01] : memref<10x10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -588,7 +588,7 @@ func @negative_and_positive_direction_vectors(%arg0: index, %arg1: index) { // expected-remark@-6 {{dependence from 0 to 1 at depth 3 = false}} %a10 = affine.apply (d0, d1) -> (d0) (%i0, %i1) %a11 = affine.apply (d0, d1) -> (d1) (%i0, %i1) - store %c7, %m[%a10, %a11] : memref<10x10xf32> + affine.store %c7, %m[%a10, %a11] : memref<10x10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, 1][-1, -1]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 3 = false}} @@ -608,7 +608,7 @@ func @war_raw_waw_deps() { affine.for %i0 = 0 to 10 { affine.for %i1 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 + 1) (%i1) - %v0 = load %m[%a0] : memref<100xf32> + %v0 = affine.load %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -616,7 +616,7 @@ func @war_raw_waw_deps() { // expected-remark@-5 {{dependence from 0 to 1 at depth 2 
= [0, 0][1, 1]}} // expected-remark@-6 {{dependence from 0 to 1 at depth 3 = false}} %a1 = affine.apply (d0) -> (d0) (%i1) - store %c7, %m[%a1] : memref<100xf32> + affine.store %c7, %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, 9][-1, -1]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 0 at depth 3 = false}} @@ -637,13 +637,13 @@ func @mod_deps() { %a0 = affine.apply (d0) -> (d0 mod 2) (%i0) // Results are conservative here since we currently don't have a way to // represent strided sets in FlatAffineConstraints. - %v0 = load %m[%a0] : memref<100xf32> + %v0 = affine.load %m[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} %a1 = affine.apply (d0) -> ( (d0 + 1) mod 2) (%i0) - store %c7, %m[%a1] : memref<100xf32> + affine.store %c7, %m[%a1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = [2, 9]}} @@ -660,7 +660,7 @@ func @loop_nest_depth() { affine.for %i0 = 0 to 128 { affine.for %i1 = 0 to 8 { - store %c7, %0[%i0, %i1] : memref<100x100xf32> + affine.store %c7, %0[%i0, %i1] : memref<100x100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -672,7 +672,7 @@ func @loop_nest_depth() { affine.for %i4 = 0 to 8 { affine.for %i5 = 0 to 16 { %8 = affine.apply (d0, d1) -> (d0 * 16 + d1)(%i4, %i5) - %9 = load %0[%8, %i3] : memref<100x100xf32> + %9 = affine.load %0[%8, %i3] : memref<100x100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 2 = false}} @@ -699,7 +699,7 @@ func @mod_div_3d() { %idx0 = affine.apply (d0, d1, d2) -> (d0 floordiv 4) (%i0, %i1, %i2) %idx1 = affine.apply (d0, d1, d2) -> (d1 mod 2) (%i0, %i1, %i2) %idx2 = affine.apply (d0, d1, d2) -> (d2 floordiv 4) (%i0, %i1, %i2) - store %c0, %M[%idx0, %idx1, %idx2] : memref<2 x 2 x 2 x i32> + affine.store %c0, %M[%idx0, %idx1, %idx2] : memref<2 x 2 x 2 x i32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = [1, 3][-7, 7][-3, 3]}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = [0, 0][2, 7][-3, 3]}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = [0, 0][0, 0][1, 3]}} @@ -725,7 +725,7 @@ func @delinearize_mod_floordiv() { affine.for %i3 = 0 to 3 { affine.for %i4 = 0 to 16 { affine.for %i5 = 0 to 1 { - store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> + affine.store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -758,7 +758,7 @@ func @delinearize_mod_floordiv() { %a15 = affine.apply (d0) -> ((((((d0 mod 294912) mod 147456) mod 1152) mod 384) mod 128) floordiv 128) (%a0) - %v0 = load %in[%a10, %a11, %a13, %a14, %a12, %a15] : memref<2x2x3x3x16x1xi32> 
+ %v0 = affine.load %in[%a10, %a11, %a13, %a14, %a12, %a15] : memref<2x2x3x3x16x1xi32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 1 at depth 1 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 2 = false}} @@ -768,7 +768,7 @@ func @delinearize_mod_floordiv() { // expected-remark@-7 {{dependence from 1 to 2 at depth 3 = false}} // TODO(andydavis): the dep tester shouldn't be printing out these messages // below; they are redundant. - store %v0, %out[%ii, %jj] : memref<64x9xi32> + affine.store %v0, %out[%ii, %jj] : memref<64x9xi32> // expected-remark@-1 {{dependence from 2 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 2 to 1 at depth 1 = false}} // expected-remark@-3 {{dependence from 2 to 1 at depth 2 = false}} @@ -791,12 +791,12 @@ func @strided_loop_with_dependence_at_depth2() { %0 = alloc() : memref<10xf32> %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 8 step 2 { - store %cf0, %0[%i0] : memref<10xf32> + affine.store %cf0, %0[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = true}} - %v0 = load %0[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -814,12 +814,12 @@ func @strided_loop_with_no_dependence() { %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 8 step 2 { %a0 = affine.apply (d0) -> (d0 + 1)(%i0) - store %cf0, %0[%a0] : memref<10xf32> + affine.store %cf0, %0[%a0] : memref<10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} - %v0 = load %0[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -830,19 +830,19 @@ func @strided_loop_with_no_dependence() { // ----- -// Store op accesses memref elements at offset causing loop-carried dependence. +// Affine.Store op accesses memref elements at offset causing loop-carried dependence. 
// CHECK-LABEL: func @strided_loop_with_loop_carried_dependence_at_depth1 func @strided_loop_with_loop_carried_dependence_at_depth1() { %0 = alloc() : memref<10xf32> %cf0 = constant 0.0 : f32 affine.for %i0 = 0 to 8 step 2 { %a0 = affine.apply (d0) -> (d0 + 4)(%i0) - store %cf0, %0[%a0] : memref<10xf32> + affine.store %cf0, %0[%a0] : memref<10xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = [4, 4]}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} - %v0 = load %0[%i0] : memref<10xf32> + %v0 = affine.load %0[%i0] : memref<10xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -861,13 +861,13 @@ func @test_dep_store_depth1_load_depth2() { %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 10 { %a0 = affine.apply (d0) -> (d0 - 1)(%i0) - store %cst, %0[%a0] : memref<100xf32> + affine.store %cst, %0[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 1 at depth 1 = false}} // expected-remark@-4 {{dependence from 0 to 1 at depth 2 = false}} affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) { - %1 = load %0[%i1] : memref<100xf32> + %1 = affine.load %0[%i1] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = [1, 1]}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} @@ -888,7 +888,7 @@ func @test_dep_store_depth2_load_depth1() { %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 10 { affine.for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) { - store %cst, %0[%i1] : memref<100xf32> + affine.store %cst, %0[%i1] : memref<100xf32> // expected-remark@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 0 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 0 to 0 at depth 3 = false}} @@ -896,7 +896,7 @@ func @test_dep_store_depth2_load_depth1() { // expected-remark@-5 {{dependence from 0 to 1 at depth 2 = false}} } %a0 = affine.apply (d0) -> (d0 - 2)(%i0) - %1 = load %0[%a0] : memref<100xf32> + %1 = affine.load %0[%a0] : memref<100xf32> // expected-remark@-1 {{dependence from 1 to 0 at depth 1 = false}} // expected-remark@-2 {{dependence from 1 to 0 at depth 2 = false}} // expected-remark@-3 {{dependence from 1 to 1 at depth 1 = false}} diff --git a/mlir/test/Transforms/parallelism-detection.mlir b/mlir/test/Transforms/parallelism-detection.mlir index 6ea6cb5c4a2..c6aa4bad382 100644 --- a/mlir/test/Transforms/parallelism-detection.mlir +++ b/mlir/test/Transforms/parallelism-detection.mlir @@ -10,12 +10,12 @@ func @loop_nest_3d_outer_two_parallel(%N : index) { affine.for %j = 0 to %N { // expected-remark@-1 {{parallel loop}} affine.for %k = 0 to %N { - %5 = load %0[%i, %k] : memref<1024x1024xvector<64xf32>> - %6 = load %1[%k, %j] : memref<1024x1024xvector<64xf32>> - %7 = load %2[%i, %j] : memref<1024x1024xvector<64xf32>> + %5 = affine.load %0[%i, %k] : memref<1024x1024xvector<64xf32>> + %6 = affine.load %1[%k, %j] : memref<1024x1024xvector<64xf32>> + %7 = affine.load %2[%i, %j] : memref<1024x1024xvector<64xf32>> %8 = mulf %5, %6 : vector<64xf32> 
%9 = addf %7, %8 : vector<64xf32> - store %9, %2[%i, %j] : memref<1024x1024xvector<64xf32>> + affine.store %9, %2[%i, %j] : memref<1024x1024xvector<64xf32>> } } } diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 30e6be82e2a..6708282610a 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt %s -affine-pipeline-data-transfer | FileCheck %s +// RUN: mlir-opt %s -split-input-file -affine-pipeline-data-transfer | FileCheck %s + +// ----- // CHECK-DAG: [[MOD_2:#map[0-9]+]] = (d0) -> (d0 mod 2) -// CHECK-DAG: [[FLOOR_MOD_2:#map[0-9]+]] = (d0) -> ((d0 floordiv 4) mod 2) -// CHECK-DAG: [[REMAP_SHIFT_MINUS_4:#map[0-9]+]] = (d0) -> (d0 - 4) // CHECK-DAG: [[MAP_MINUS_1:#map[0-9]+]] = (d0) -> (d0 - 1) // CHECK-LABEL: func @loop_nest_dma() { @@ -17,11 +17,11 @@ func @loop_nest_dma() { %num_elts = constant 128 : index affine.for %i = 0 to 8 { - dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> - dma_wait %tag[%zero], %num_elts : memref<1 x f32> - %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> + affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> + affine.dma_wait %tag[%zero], %num_elts : memref<1 x f32> + %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) - store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> + affine.store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> affine.for %j = 0 to 128 { "do_more_compute"(%i, %j) : (index, index) -> () } @@ -31,39 +31,39 @@ func @loop_nest_dma() { // CHECK: %0 = alloc() : memref<256xf32> // CHECK: %1 = alloc() : memref<2x32xf32, 1> // CHECK-NEXT: %2 = alloc() : memref<2x1xf32> -// CHECK-NEXT: %3 = affine.apply [[MOD_2]](%c0) -// CHECK-NEXT: %4 = affine.apply [[MOD_2]](%c0) -// CHECK-NEXT: dma_start %0[%c0], %1[%3, %c0], %c128, %2[%4, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> +// CHECK-NEXT: affine.dma_start %0[%c0], %1[%c0 mod 2, %c0], %2[%c0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> // CHECK-NEXT: affine.for %i0 = 1 to 8 { -// CHECK-NEXT: %5 = affine.apply [[MOD_2]](%i0) -// CHECK-NEXT: %6 = affine.apply [[MOD_2]](%i0) -// CHECK-NEXT: dma_start %0[%i0], %1[%5, %i0], %c128, %2[%6, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> -// CHECK-NEXT: %7 = affine.apply [[MAP_MINUS_1]](%i0) -// CHECK-NEXT: %8 = affine.apply [[MOD_2]](%7) -// CHECK-NEXT: %9 = affine.apply [[MOD_2]](%7) -// CHECK-NEXT: dma_wait %2[%8, %c0_0], %c128 : memref<2x1xf32> -// CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1> -// CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32 -// CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1> +// CHECK-NEXT: affine.dma_start %0[%i0], %1[%i0 mod 2, %i0], %2[%i0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32> +// CHECK-NEXT: %3 = affine.apply [[MAP_MINUS_1]](%i0) +// CHECK-NEXT: %4 = affine.apply [[MOD_2]](%3) +// CHECK-NEXT: %5 = affine.apply [[MOD_2]](%3) +// CHECK-NEXT: affine.dma_wait %2[%3 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32> +// CHECK-NEXT: %6 = affine.load %1[%3 mod 2, %3] : memref<2x32xf32, 1> +// CHECK-NEXT: %7 = "compute"(%6) : (f32) -> f32 +// CHECK-NEXT: affine.store %7, %1[%3 mod 2, %3] : memref<2x32xf32, 1> // CHECK-NEXT: affine.for %i1 = 0 to 128 { -// CHECK-NEXT: 
"do_more_compute"(%7, %i1) : (index, index) -> () +// CHECK-NEXT: "do_more_compute"(%3, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: %12 = affine.apply [[MAP_MINUS_1]](%c8) -// CHECK-NEXT: %13 = affine.apply [[MOD_2]](%12) -// CHECK-NEXT: %14 = affine.apply [[MOD_2]](%12) -// CHECK-NEXT: dma_wait %2[%13, %c0_0], %c128 : memref<2x1xf32> -// CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1> -// CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32 -// CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1> +// CHECK-NEXT: %8 = affine.apply [[MAP_MINUS_1]](%c8) +// CHECK-NEXT: %9 = affine.apply [[MOD_2]](%8) +// CHECK-NEXT: %10 = affine.apply [[MOD_2]](%8) +// CHECK-NEXT: affine.dma_wait %2[%8 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32> +// CHECK-NEXT: %11 = affine.load %1[%8 mod 2, %8] : memref<2x32xf32, 1> +// CHECK-NEXT: %12 = "compute"(%11) : (f32) -> f32 +// CHECK-NEXT: affine.store %12, %1[%8 mod 2, %8] : memref<2x32xf32, 1> // CHECK-NEXT: affine.for %i2 = 0 to 128 { -// CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> () +// CHECK-NEXT: "do_more_compute"(%8, %i2) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: dealloc %2 : memref<2x1xf32> // CHECK-NEXT: dealloc %1 : memref<2x32xf32, 1> // CHECK-NEXT: return // CHECK-NEXT:} +// ----- + +// CHECK-DAG: [[FLOOR_MOD_2:#map[0-9]+]] = (d0) -> ((d0 floordiv 4) mod 2) +// CHECK-DAG: [[REMAP_SHIFT_MINUS_4:#map[0-9]+]] = (d0) -> (d0 - 4) // CHECK-LABEL: @loop_step func @loop_step(%arg0: memref<512xf32>, @@ -73,33 +73,31 @@ func @loop_step(%arg0: memref<512xf32>, affine.for %i0 = 0 to 512 step 4 { %1 = alloc() : memref<4xf32, 1> %2 = alloc() : memref<1xi32> - dma_start %arg0[%i0], %1[%c0], %c4, %2[%c0] + affine.dma_start %arg0[%i0], %1[%c0], %2[%c0], %c4, : memref<512xf32>, memref<4xf32, 1>, memref<1xi32> - dma_wait %2[%c0], %c4 : memref<1xi32> + affine.dma_wait %2[%c0], %c4 : memref<1xi32> "compute"(%i0) : (index) -> () } return } // CHECK: [[TAG:%[0-9]+]] = alloc() : memref<2x1xi32> -// CHECK: %2 = affine.apply [[FLOOR_MOD_2]](%c0) -// CHECK: %3 = affine.apply [[FLOOR_MOD_2]](%c0) -// CHECK-NEXT: dma_start %arg0[%c0], %0[%2, %c0_0], %c4, [[TAG]][%3, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> +// CHECK-NEXT: affine.dma_start %arg0[%c0], %0[(%c0 floordiv 4) mod 2, symbol(%c0_0)], [[TAG]][(%c0 floordiv 4) mod 2, symbol(%c0_0)], %c4 : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> // CHECK-NEXT: affine.for %i0 = 4 to 512 step 4 { -// CHECK-NEXT: %4 = affine.apply [[FLOOR_MOD_2]](%i0) -// CHECK-NEXT: %5 = affine.apply [[FLOOR_MOD_2]](%i0) -// CHECK-NEXT: dma_start %arg0[%i0], %0[%4, %c0_0], %c4, [[TAG]][%5, %c0_0] : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> -// CHECK-NEXT: %6 = affine.apply [[REMAP_SHIFT_MINUS_4]](%i0) -// CHECK-NEXT: %7 = affine.apply [[FLOOR_MOD_2]](%6) -// CHECK: dma_wait [[TAG]][%7, %c0_0], %c4 : memref<2x1xi32> -// CHECK-NEXT: "compute"(%6) : (index) -> () +// CHECK-NEXT: affine.dma_start %arg0[%i0], %0[(%i0 floordiv 4) mod 2, symbol(%c0_0)], [[TAG]][(%i0 floordiv 4) mod 2, symbol(%c0_0)], %c4 : memref<512xf32>, memref<2x4xf32, 1>, memref<2x1xi32> +// CHECK-NEXT: %2 = affine.apply [[REMAP_SHIFT_MINUS_4]](%i0) +// CHECK-NEXT: %3 = affine.apply [[FLOOR_MOD_2]](%2) +// CHECK: affine.dma_wait [[TAG]][(%2 floordiv 4) mod 2, symbol(%c0_0)], %c4 : memref<2x1xi32> +// CHECK-NEXT: "compute"(%2) : (index) -> () // CHECK-NEXT: } // CHECK-NEXT: [[SHIFTED:%[0-9]+]] = affine.apply [[REMAP_SHIFT_MINUS_4]](%c512) -// CHECK-NEXT: %10 = affine.apply 
[[FLOOR_MOD_2]]([[SHIFTED]]) -// CHECK: dma_wait [[TAG]][%10, %c0_0], %c4 : memref<2x1xi32> -// CHECK-NEXT: "compute"(%9) : (index) -> () +// CHECK-NEXT: %6 = affine.apply [[FLOOR_MOD_2]]([[SHIFTED]]) +// CHECK: affine.dma_wait [[TAG]][(%5 floordiv 4) mod 2, symbol(%c0_0)], %c4 : memref<2x1xi32> +// CHECK-NEXT: "compute"(%5) : (index) -> () // CHECK: return // CHECK-NEXT: } +// ----- + #map0 = (d0, d1) -> (d0, d1) #map1 = (d0, d1) -> ((d0 * 2048 + d1 * 256) floordiv 32) #map2 = (d0) -> ((d0 * 2048) floordiv 32) @@ -116,65 +114,65 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // Prologue for DMA overlap on arg2. // CHECK-DAG: [[BUF_ARG2:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK-DAG: [[TAG_ARG2:%[0-9]+]] = alloc() : memref<2x2xi32> - // CHECK: dma_start %arg2[ + // CHECK: affine.dma_start %arg2[ // CHECK: affine.for %i0 = 1 to 8 { affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) - dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> - dma_wait %5[%c0], %num_elts : memref<2xi32> + affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> + affine.dma_wait %5[%c0], %num_elts : memref<2xi32> // Steady state for DMA overlap on arg2 - // CHECK: dma_start %arg2[ - // CHECK: dma_wait [[TAG_ARG2]] + // CHECK: affine.dma_start %arg2[ + // CHECK: affine.dma_wait [[TAG_ARG2]] // Prologue for DMA overlap on arg0, arg1 nested within i0 // CHECK: [[BUF_ARG0:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK: [[BUF_ARG1:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK: [[TAG_ARG0:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: [[TAG_ARG1:%[0-9]+]] = alloc() : memref<2x2xi32> - // CHECK: dma_start %arg0[ - // CHECK: dma_start %arg1[ + // CHECK: affine.dma_start %arg0[ + // CHECK: affine.dma_start %arg1[ // CHECK-NEXT affine.for %i1 = 1 to 8 { affine.for %i1 = 0 to 8 { %7 = affine.apply #map1(%i0, %i1) %8 = affine.apply #map2(%i1) - dma_start %arg0[%7, %c0], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> - dma_start %arg1[%8, %c0], %1[%c0, %c0], %num_elts, %4[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> - dma_wait %3[%c0], %num_elts : memref<2xi32> - dma_wait %4[%c0], %num_elts : memref<2xi32> + affine.dma_start %arg0[%7, %c0], %0[%c0, %c0], %3[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> + affine.dma_start %arg1[%8, %c0], %1[%c0, %c0], %4[%c0], %num_elts : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32> + affine.dma_wait %3[%c0], %num_elts : memref<2xi32> + affine.dma_wait %4[%c0], %num_elts : memref<2xi32> // Steady state for DMA overlap on arg0, arg1 - // CHECK: dma_start %arg0[ - // CHECK: dma_start %arg1[ - // CHECK: dma_wait [[TAG_ARG0]] - // CHECK: dma_wait [[TAG_ARG1]] + // CHECK: affine.dma_start %arg0[ + // CHECK: affine.dma_start %arg1[ + // CHECK: affine.dma_wait [[TAG_ARG0]] + // CHECK: affine.dma_wait [[TAG_ARG1]] // CHECK-NEXT: affine.for %i2 = 0 to 4 { affine.for %i2 = 0 to 4 { "foo"() : () -> () } } // epilogue for arg0, arg1 - // CHECK: dma_wait [[TAG_ARG0]] - // CHECK: dma_wait [[TAG_ARG1]] + // CHECK: affine.dma_wait [[TAG_ARG0]] + // CHECK: affine.dma_wait [[TAG_ARG1]] // 
CHECK-DAG: dealloc [[TAG_ARG1]] : memref<2x2xi32> // CHECK-DAG: dealloc [[TAG_ARG0]] : memref<2x2xi32> // CHECK-DAG: dealloc [[BUF_ARG1]] : memref<2x64x4xvector<8xf32>, 2> // CHECK-DAG: dealloc [[BUF_ARG0]] : memref<2x64x4xvector<8xf32>, 2> // epilogue for DMA overlap on %arg2 - // CHECK: dma_wait [[TAG_ARG2]] + // CHECK: affine.dma_wait [[TAG_ARG2]] // Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested. // CHECK: [[BUF_ARG0_NESTED:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK: [[BUF_ARG1_NESTED:%[0-9]+]] = alloc() : memref<2x64x4xvector<8xf32>, 2> // CHECK: [[TAG_ARG0_NESTED:%[0-9]+]] = alloc() : memref<2x2xi32> // CHECK: [[TAG_ARG1_NESTED:%[0-9]+]] = alloc() : memref<2x2xi32> - // CHECK: dma_start %arg0[ - // CHECK: dma_start %arg1[ + // CHECK: affine.dma_start %arg0[ + // CHECK: affine.dma_start %arg1[ // CHECK: affine.for %i4 = 1 to 8 { - // CHECK: dma_start %arg0[ - // CHECK: dma_start %arg1[ - // CHECK: dma_wait [[TAG_ARG0_NESTED]] - // CHECK: dma_wait [[TAG_ARG1_NESTED]] + // CHECK: affine.dma_start %arg0[ + // CHECK: affine.dma_start %arg1[ + // CHECK: affine.dma_wait [[TAG_ARG0_NESTED]] + // CHECK: affine.dma_wait [[TAG_ARG1_NESTED]] // CHECK: affine.for %i5 = 0 to 4 { // CHECK: "foo"() : () -> () - // CHECK: dma_wait [[TAG_ARG0_NESTED]] - // CHECK: dma_wait [[TAG_ARG1_NESTED]] + // CHECK: affine.dma_wait [[TAG_ARG0_NESTED]] + // CHECK: affine.dma_wait [[TAG_ARG1_NESTED]] // CHECK: affine.for %i6 = 0 to 4 { } return @@ -188,6 +186,9 @@ func @loop_dma_nested(%arg0: memref<512x32xvector<8xf32>, #map0>, %arg1: memref< // CHECK: return } +// ----- +#map2 = (d0) -> ((d0 * 2048) floordiv 32) + // CHECK: func @loop_dma_dependent func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) { %num_elts = constant 256 : index @@ -201,19 +202,21 @@ func @loop_dma_dependent(%arg2: memref<512x32xvector<8xf32>>) { // The two DMAs below are dependent (incoming and outgoing on the same // memref) in the same iteration; so no pipelining here. 
- // CHECK-NOT: dma_start + // CHECK-NOT: affine.dma_start // CHECK: affine.for %i0 = 0 to 8 { affine.for %i0 = 0 to 8 { %6 = affine.apply #map2(%i0) - dma_start %arg2[%6, %c0], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32> - dma_wait %5[%c0], %num_elts : memref<2xi32> + affine.dma_start %arg2[%6, %c0], %2[%c0, %c0], %5[%c0], %num_elts : memref<512x32xvector<8xf32>>, memref<64x4xvector<8xf32>, 2>, memref<2xi32> + affine.dma_wait %5[%c0], %num_elts : memref<2xi32> - dma_start %2[%c0, %c0], %arg2[%6, %c0], %num_elts, %5[%c0] : memref<64x4xvector<8xf32>, 2>, memref<512x32xvector<8xf32>>, memref<2xi32> - dma_wait %5[%c0], %num_elts : memref<2xi32> + affine.dma_start %2[%c0, %c0], %arg2[%6, %c0], %5[%c0], %num_elts : memref<64x4xvector<8xf32>, 2>, memref<512x32xvector<8xf32>>, memref<2xi32> + affine.dma_wait %5[%c0], %num_elts : memref<2xi32> } // CHECK: } return // CHECK: return } +// ----- + // CHECK-LABEL: func @escaping_use func @escaping_use(%arg0: memref<512 x 32 x f32>) { %c32 = constant 32 : index @@ -222,13 +225,13 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) { %Av = alloc() : memref<32 x 32 x f32, 2> %tag = alloc() : memref<1 x i32> - // CHECK-NOT: dma_start + // CHECK-NOT: affine.dma_start // CHECK: affine.for %i0 = 0 to 16 { affine.for %kTT = 0 to 16 { - dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : + affine.dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %tag[%zero], %num_elt : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> - dma_wait %tag[%zero], %num_elt : memref<1 x i32> + affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32> // escaping use; no DMA pipelining / double buffering will be done. "foo"(%Av) : (memref<32 x 32 x f32, 2>) -> () } @@ -238,6 +241,8 @@ func @escaping_use(%arg0: memref<512 x 32 x f32>) { // CHECK: return } +// ----- + // CHECK-LABEL: func @live_out_use func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { %c32 = constant 32 : index @@ -246,21 +251,23 @@ func @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 { %Av = alloc() : memref<32 x 32 x f32, 2> %tag = alloc() : memref<1 x i32> - // CHECK-NOT: dma_start + // CHECK-NOT: affine.dma_start // CHECK: affine.for %i0 = 0 to 16 { affine.for %kTT = 0 to 16 { - dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : + affine.dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %tag[%zero], %num_elt : memref<512 x 32 x f32>, memref<32 x 32 x f32, 2>, memref<1 x i32> - dma_wait %tag[%zero], %num_elt : memref<1 x i32> + affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32> } // Use live out of 'affine.for' op; no DMA pipelining will be done. 
- %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> + %v = affine.load %Av[%zero, %zero] : memref<32 x 32 x f32, 2> return %v : f32 -// CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> +// CHECK: %{{[0-9]+}} = affine.load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2> // CHECK: return } +// ----- + // CHECK-LABEL: func @dynamic_shape_dma_buffer func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { %c32 = constant 32 : index @@ -275,22 +282,18 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { // CHECK-NEXT: %1 = dim %0, 0 : memref // CHECK-NEXT: %2 = dim %0, 1 : memref // CHECK-NEXT: %3 = alloc(%1, %2) : memref<2x?x?xf32, 2> -// CHECK: %5 = affine.apply [[MOD_2]](%c0) -// CHECK: %6 = affine.apply [[MOD_2]](%c0) -// CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%5, %c0_0, %c0_0], %c512, %4[%6, %c0_0] +// CHECK: affine.dma_start %arg0[%c0_0, %c0_0], %3[%c0 mod 2, symbol(%c0_0), symbol(%c0_0)], %4[%c0 mod 2, symbol(%c0_0)], %c512 affine.for %kTT = 0 to 16 { - dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] : + affine.dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %tag[%zero], %num_elt : memref<512 x 32 x f32>, memref, memref<1 x i32> - dma_wait %tag[%zero], %num_elt : memref<1 x i32> + affine.dma_wait %tag[%zero], %num_elt : memref<1 x i32> } return // CHECK-NEXT: affine.for %i0 = 1 to 16 { -// CHECK: %7 = affine.apply [[MOD_2]](%i0) -// CHECK: %8 = affine.apply [[MOD_2]](%i0) -// CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%7, %c0_0, %c0_0], %c512, %4[%8, %c0_0] -// CHECK: dma_wait %4[%10, %c0_0], %c512 : memref<2x1xi32> +// CHECK: affine.dma_start %arg0[%c0_0, %c0_0], %3[%i0 mod 2, symbol(%c0_0), symbol(%c0_0)], %4[%i0 mod 2, symbol(%c0_0)], %c512 +// CHECK: affine.dma_wait %4[%5 mod 2, symbol(%c0_0)], %c512 : memref<2x1xi32> // CHECK: } -// CHECK: dma_wait %4[%13, %c0_0], %c512 : memref<2x1xi32> +// CHECK: affine.dma_wait %4[%8 mod 2, symbol(%c0_0)], %c512 : memref<2x1xi32> // CHECK: return } -- cgit v1.2.3 From ce502af9cd6f0fff04f0f98c8a71fa4a00fa0de7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 8 Jul 2019 11:20:26 -0700 Subject: NFC: Remove the various "::getFunction" methods. These methods assume that a function is a valid builtin top-level operation, and removing these methods allows for decoupling FuncOp and IR/. Utility "getParentOfType" methods have been added to Operation/OpState to allow for querying the first parent operation of a given type. 
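As a minimal sketch of the new utility (illustrative only, not part of this change's diff; the FuncOp spelling, the headers, and the helper function below are assumptions of the sketch, not code from the patch), a caller can walk to the nearest enclosing operation of a given kind like this:

  #include "mlir/IR/Function.h"
  #include "mlir/IR/Operation.h"
  #include "llvm/Support/raw_ostream.h"

  // Prints the name of the function enclosing 'op', if any.
  static void printEnclosingFunction(mlir::Operation *op) {
    // getParentOfType<OpTy>() walks getParentOp() until an operation of
    // type OpTy is found; it returns a null OpTy() when there is none.
    if (auto fn = op->getParentOfType<mlir::FuncOp>())
      llvm::outs() << "enclosed by @" << fn.getName() << "\n";
    else
      llvm::outs() << "op is not nested under a function\n";
  }

The null op returned on failure plays the same role as the nullptr that the removed getFunction() accessors returned, so call sites keep the same check-before-use shape.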
PiperOrigin-RevId: 257018913 --- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 2 +- mlir/include/mlir/IR/Block.h | 4 --- mlir/include/mlir/IR/OpDefinition.h | 5 ++++ mlir/include/mlir/IR/Operation.h | 12 ++++++--- mlir/include/mlir/IR/Value.h | 9 ------- mlir/lib/AffineOps/AffineOps.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 9 ++++--- mlir/lib/EDSC/CoreAPIs.cpp | 4 +-- mlir/lib/GPU/IR/GPUDialect.cpp | 2 +- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 4 +-- mlir/lib/IR/Block.cpp | 13 +++------ mlir/lib/IR/Operation.cpp | 4 --- mlir/lib/IR/Value.cpp | 31 +--------------------- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 6 ++--- mlir/lib/StandardOps/Ops.cpp | 11 ++++---- mlir/lib/Transforms/DmaGeneration.cpp | 10 +++---- mlir/lib/Transforms/LoopFusion.cpp | 4 +-- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 7 ++--- 21 files changed, 52 insertions(+), 95 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 1c220806266..2ea1c6cad9a 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,7 +136,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function printfFunc = getPrintf(op->getFunction().getModule()); + Function printfFunc = getPrintf(op->getParentOfType()); auto print = cast(op); auto loc = print.getLoc(); diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 4f5ca5e495a..92772669639 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -99,10 +99,6 @@ public: /// nullptr if this is a top-level block. Operation *getContainingOp(); - /// Returns the function that this block is part of, even if the block is - /// nested under an operation region. - Function getFunction(); - /// Insert this block (which must not already be in a function) right before /// the specified block. void insertBefore(Block *block); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 70cd9b41ebb..6913b7638d7 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -71,6 +71,11 @@ public: /// Return the operation that this refers to. Operation *getOperation() { return state; } + /// Return the closest surrounding parent operation that is of type 'OpTy'. + template <typename OpTy> OpTy getParentOfType() { return getOperation()->getParentOfType<OpTy>(); } + /// Return the context this operation belongs to. MLIRContext *getContext() { return getOperation()->getContext(); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index b31dbda34f1..6e17ef063f8 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -125,10 +125,14 @@ public: /// or nullptr if this is a top-level operation. Operation *getParentOp(); - /// Returns the function that this operation is part of. - /// The function is determined by traversing the chain of parent operations. - /// Returns nullptr if the operation is unlinked. - Function getFunction(); + /// Return the closest surrounding parent operation that is of type 'OpTy'.
+ template <typename OpTy> OpTy getParentOfType() { + auto *op = this; + while ((op = op->getParentOp())) + if (auto parentOp = llvm::dyn_cast<OpTy>(op)) + return parentOp; + return OpTy(); + } /// Replace any uses of 'from' with 'to' within this operation. void replaceUsesOfWith(Value *from, Value *to); diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 37aebbdc496..b5dbd539eb0 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -72,9 +72,6 @@ public: IRObjectWithUseList::replaceAllUsesWith(newValue); } - /// Return the function that this Value is defined in. - Function getFunction(); - /// If this value is the result of an operation, return the operation that /// defines it. Operation *getDefiningOp(); @@ -128,17 +125,11 @@ public: return const_cast(value)->getKind() == Kind::BlockArgument; } - /// Return the function that this argument is defined in. - Function getFunction(); - Block *getOwner() { return owner; } /// Returns the number of this argument. unsigned getArgNumber(); - /// Returns if the current argument is a function argument. - bool isFunctionArgument(); - private: friend class Block; // For access to private constructor. BlockArgument(Type type, Block *owner) diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index d98904346e5..d11d525dce6 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -307,7 +307,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getFunction().getContext()) + return getAffineDimExpr(iterPos->second, v->getContext()) .cast(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index f915c487826..2a52706c277 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Insert the `malloc` declaration if it is not already present. Function mallocFunc = - op->getFunction().getModule().getNamedFunction("malloc"); + op->getParentOfType().getModule().getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction().getModule().push_back(mallocFunc); + op->getParentOfType().getModule().push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -503,11 +503,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present.
- Function freeFunc = op->getFunction().getModule().getNamedFunction("free"); + Function freeFunc = + op->getParentOfType().getModule().getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction().getModule().push_back(freeFunc); + op->getParentOfType().getModule().push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index 8a94dad8ae6..578b8673658 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -98,6 +98,6 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) { } unsigned getFunctionArity(mlir_func_t function) { - auto *f = reinterpret_cast(function); - return f->getNumArguments(); + auto f = mlir::Function::getFromOpaquePointer(function); + return f.getNumArguments(); } diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index 6cf57b42f45..92034c5d288 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto module = getOperation()->getFunction().getModule(); + auto module = getParentOfType(); Function kernelFunc = module.getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 6cb920a10b3..4f110ac286a 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -64,7 +64,7 @@ static Function outlineKernelFunc(gpu::LaunchOp launchOp) { FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = - Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str(); + Twine(launchOp.getParentOfType().getName(), "_kernel").str(); Function outlinedFunc = Function::create(loc, kernelFuncName, type); outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 065462273b4..bd6137b41b0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1421,8 +1421,8 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, os << ':'; // Print out some context information about the predecessors of this block. - if (!block->getFunction()) { - os << "\t// block is not in a function!"; + if (!block->getParent()) { + os << "\t// block is not in a region!"; } else if (block->hasNoPredecessors()) { os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e17b13d98fe..93f5fe6e976 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -50,13 +50,8 @@ Operation *Block::getContainingOp() { return getParent() ? getParent()->getContainingOp() : nullptr; } -Function Block::getFunction() { - auto *parent = getParent(); - return parent ? parent->getParentOfType() : nullptr; -} - -/// Insert this block (which must not already be in a function) right before -/// the specified block. +/// Insert this block (which must not already be in a region) right before the +/// specified block. 
void Block::insertBefore(Block *block) { assert(!getParent() && "already inserted into a block!"); assert(block->getParent() && "cannot insert before a block without a parent"); @@ -254,11 +249,11 @@ void Block::walk(Block::iterator begin, Block::iterator end, /// invalidated. Block *Block::splitBlock(iterator splitBefore) { // Start by creating a new basic block, and insert it immediate after this - // one in the containing function. + // one in the containing region. auto newBB = new Block(); getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB); - // Move all of the operations from the split point to the end of the function + // Move all of the operations from the split point to the end of the region // into the new block. newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore, end()); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 40b759fcfd5..ba9e3cf17b9 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,10 +281,6 @@ Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } -Function Operation::getFunction() { - return block ? block->getFunction() : nullptr; -} - /// Replace any uses of 'from' with 'to' within this operation. void Operation::replaceUsesOfWith(Value *from, Value *to) { if (from == to) diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 65a98f7ee59..669f641b734 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -29,21 +29,9 @@ Operation *Value::getDefiningOp() { return nullptr; } -/// Return the function that this Value is defined in. -Function Value::getFunction() { - switch (getKind()) { - case Value::Kind::BlockArgument: - return cast(this)->getFunction(); - case Value::Kind::OpResult: - return getDefiningOp()->getFunction(); - } - llvm_unreachable("Unknown Value Kind"); -} - Location Value::getLoc() { - if (auto *op = getDefiningOp()) { + if (auto *op = getDefiningOp()) return op->getLoc(); - } return UnknownLoc::get(getContext()); } @@ -78,20 +66,3 @@ void IRObjectWithUseList::dropAllUses() { use_begin()->drop(); } } - -//===----------------------------------------------------------------------===// -// BlockArgument implementation. -//===----------------------------------------------------------------------===// - -/// Return the function that this argument is defined in. -Function BlockArgument::getFunction() { - if (auto *owner = getOwner()) - return owner->getFunction(); - return nullptr; -} - -/// Returns if the current argument is a function argument. -bool BlockArgument::isFunctionArgument() { - auto containingFn = getFunction(); - return containingFn && &containingFn.front() == getOwner(); -} diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index f42f2860d5d..a3d89c3c42b 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,7 +170,7 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. 
- auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); Function mallocFunc = module.getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); @@ -231,7 +231,7 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); Function freeFunc = module.getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); @@ -602,7 +602,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op, PatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); if (auto f = module.getNamedFunction(fnName)) { return f; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 3be51e67186..63a01e254cd 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -431,8 +431,7 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( - fnAttr.getValue()); + auto fn = op.getParentOfType().getNamedFunction(fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; @@ -1098,8 +1097,8 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( - fnAttr.getValue()); + auto fn = + op.getParentOfType().getNamedFunction(fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); @@ -2029,7 +2028,9 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto function = op.getOperation()->getFunction(); + // TODO(b/137008268): Return op should verify that it is nested directly + // within a function operation. + auto function = op.getParentOfType(); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index e867dc70ed3..830546db497 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -217,8 +217,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED emitRemarkForBlock(Block &block) { - auto *op = block.getContainingOp(); - return op ? op->emitRemark() : block.getFunction().emitRemark(); + return block.getContainingOp()->emitRemark(); } /// Creates a buffer in the faster memory space for the specified region; @@ -250,7 +249,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, OpBuilder &b = region.isWrite() ? epilogue : prologue; // Builder to create constants at the top level. 
- auto func = block->getFunction(); + auto func = block->getParent()->getParentOfType(); OpBuilder top(func.getBody()); auto loc = region.loc; @@ -765,10 +764,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) { StringRef str = "Total size of all DMA buffers' for this block " "exceeds fast memory capacity\n"; - if (auto *op = block->getContainingOp()) - op->emitError(str); - else - block->getFunction().emitError(str); + block->getContainingOp()->emitError(str); } return totalDmaBuffersSizeInBytes; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1eee40b88da..b2557a6c6fd 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getFunction().getBody()); + OpBuilder top(forInst->getParentOfType().getBody()); // Create new memref type based on slice bounds. auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); @@ -1750,7 +1750,7 @@ public: }; // Search for siblings which load the same memref function argument. - auto fn = dstNode->op->getFunction(); + auto fn = dstNode->op->getParentOfType(); for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { for (auto *user : fn.getArgument(i)->getUsers()) { if (auto loadOp = dyn_cast(user)) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index f59f1006ec5..fcac60c6a92 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -635,8 +635,8 @@ static bool emitSlice(MaterializationState *state, } } - LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs())); + LLVM_DEBUG(dbgs() << "\nFunction is now\n"); + LLVM_DEBUG((*slice)[0]->getParentOfType().print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. 
Reverse iterator does not just work simply with an operator* diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4ddf93c2232..65847fc8bee 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - OpBuilder topBuilder(op->getFunction().getBody()); + OpBuilder topBuilder(op->getParentOfType().getBody()); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 16f4effca15..c1a4dcb7ebb 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -81,11 +81,12 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) - domInfo = llvm::make_unique(domInstFilter->getFunction()); + domInfo = llvm::make_unique( + domInstFilter->getParentOfType()); if (postDomInstFilter) - postDomInfo = - llvm::make_unique(postDomInstFilter->getFunction()); + postDomInfo = llvm::make_unique( + postDomInstFilter->getParentOfType()); // The ops where memref replacement succeeds are replaced with new ones. SmallVector opsToErase; -- cgit v1.2.3 From 8c443678918b29f1191f51da4501f53f7a0ccffd Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 9 Jul 2019 16:17:55 -0700 Subject: NFC: Rename Function to FuncOp. PiperOrigin-RevId: 257293379 --- mlir/LICENSE.TXT | 227 ++++++++++++++++++--- mlir/bindings/python/pybind.cpp | 23 +-- .../Linalg/Linalg1/include/linalg1/Common.h | 10 +- mlir/examples/Linalg/Linalg2/Example.cpp | 8 +- mlir/examples/Linalg/Linalg3/Conversion.cpp | 6 +- mlir/examples/Linalg/Linalg3/Example.cpp | 14 +- mlir/examples/Linalg/Linalg3/Execution.cpp | 6 +- .../Linalg/Linalg3/include/linalg3/Transforms.h | 7 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 6 +- mlir/examples/Linalg/Linalg4/Example.cpp | 10 +- .../Linalg/Linalg4/include/linalg4/Transforms.h | 5 +- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 4 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 10 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 10 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 10 +- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 12 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 6 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 10 +- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 16 +- mlir/g3doc/Tutorials/Toy/Ch-4.md | 2 +- mlir/g3doc/WritingAPass.md | 6 +- mlir/include/mlir-c/Core.h | 2 +- mlir/include/mlir/Analysis/AffineStructures.h | 4 +- mlir/include/mlir/Analysis/NestedMatcher.h | 2 +- mlir/include/mlir/Analysis/Passes.h | 4 +- .../mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h | 4 +- mlir/include/mlir/ExecutionEngine/MemRefUtils.h | 4 +- mlir/include/mlir/GPU/GPUDialect.h | 16 +- mlir/include/mlir/IR/Attributes.h | 2 - mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Dialect.h | 5 +- mlir/include/mlir/IR/Function.h | 3 - mlir/include/mlir/IR/Module.h | 4 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/IR/Value.h | 2 - mlir/include/mlir/LLVMIR/LLVMDialect.h | 2 +- mlir/include/mlir/Pass/AnalysisManager.h | 12 +- mlir/include/mlir/Pass/Pass.h | 12 +- mlir/include/mlir/StandardOps/Ops.td | 2 +- .../include/mlir/Target/LLVMIR/ModuleTranslation.h | 4 +- 
mlir/include/mlir/Transforms/DialectConversion.h | 4 +- mlir/include/mlir/Transforms/FoldUtils.h | 2 - mlir/include/mlir/Transforms/LoopUtils.h | 3 +- mlir/include/mlir/Transforms/LowerAffine.h | 3 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 12 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 60 +++--- .../GPUToCUDA/GenerateCubinAccessors.cpp | 14 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 8 +- mlir/lib/EDSC/CoreAPIs.cpp | 2 +- mlir/lib/ExecutionEngine/MemRefUtils.cpp | 2 +- mlir/lib/GPU/IR/GPUDialect.cpp | 21 +- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 11 +- mlir/lib/IR/Builders.cpp | 8 +- mlir/lib/IR/Dialect.cpp | 4 +- mlir/lib/IR/Function.cpp | 20 +- mlir/lib/IR/Region.cpp | 1 - mlir/lib/IR/Value.cpp | 1 - mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 2 +- mlir/lib/Linalg/Transforms/Fusion.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 34 ++- mlir/lib/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Pass/IRPrinting.cpp | 8 +- mlir/lib/Pass/Pass.cpp | 15 +- mlir/lib/Pass/PassDetail.h | 2 +- mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp | 2 +- mlir/lib/Support/MlirOptMain.cpp | 1 - mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 10 +- mlir/lib/Transforms/DialectConversion.cpp | 10 +- mlir/lib/Transforms/DmaGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 6 +- mlir/lib/Transforms/LoopParametricTiling.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 6 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/EDSC/builder-api-test.cpp | 10 +- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp | 4 +- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 4 +- mlir/tools/mlir-tblgen/ReferenceImplGen.cpp | 2 +- mlir/unittests/Pass/AnalysisManagerTest.cpp | 16 +- 88 files changed, 486 insertions(+), 354 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/LICENSE.TXT b/mlir/LICENSE.TXT index 5af756e1e01..a4b160b6e33 100644 --- a/mlir/LICENSE.TXT +++ b/mlir/LICENSE.TXT @@ -1,44 +1,205 @@ -============================================================================== -LLVM Release License -============================================================================== -University of Illinois/NCSA -Open Source License +Copyright 2019 The MLIR Authors. -Copyright (c) 2003-2018 University of Illinois at Urbana-Champaign. -All rights reserved. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Developed by: + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - LLVM Team + 1. Definitions. - University of Illinois at Urbana-Champaign + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. - http://llvm.org + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
-Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal with -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimers. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimers in the - documentation and/or other materials provided with the distribution. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. - * Neither the names of the LLVM Team, University of Illinois at - Urbana-Champaign, nor the names of its contributors may be used to - endorse or promote products derived from this Software without specific - prior written permission. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE -SOFTWARE. diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index f730f8e48bb..f58e854f20a 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -111,11 +111,11 @@ struct PythonValueHandle { struct PythonFunction { PythonFunction() : function{nullptr} {} PythonFunction(mlir_func_t f) : function{f} {} - PythonFunction(mlir::Function f) + PythonFunction(mlir::FuncOp f) : function(const_cast(f.getAsOpaquePointer())) {} operator mlir_func_t() { return function; } std::string str() { - mlir::Function f = mlir::Function::getFromOpaquePointer(function); + mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function); std::string res; llvm::raw_string_ostream os(res); f.print(os); @@ -126,7 +126,7 @@ struct PythonFunction { // declaration, add the entry block, transforming the declaration into a // definition. Return true if the block was added, false otherwise. 
bool define() { - auto f = mlir::Function::getFromOpaquePointer(function); + auto f = mlir::FuncOp::getFromOpaquePointer(function); if (!f.getBlocks().empty()) return false; @@ -135,7 +135,7 @@ struct PythonFunction { } PythonValueHandle arg(unsigned index) { - auto f = mlir::Function::getFromOpaquePointer(function); + auto f = mlir::FuncOp::getFromOpaquePointer(function); assert(index < f.getNumArguments() && "argument index out of bounds"); return PythonValueHandle(ValueHandle(f.getArgument(index))); } @@ -252,7 +252,7 @@ struct PythonFunctionContext { PythonFunction enter() { assert(function.function && "function is not set up"); - auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function); + auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function); contextBuilder.emplace(mlirFunc.getBody()); context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc()); return function; @@ -595,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name, } // Create the function itself. - auto func = mlir::Function::create( + auto func = mlir::FuncOp::create( UnknownLoc::get(&mlirContext), name, mlir::Type::getFromOpaquePointer(funcType).cast(), attrs, inputAttrs); @@ -653,7 +653,7 @@ PYBIND11_MODULE(pybind, m) { return ValueHandle::create(value, floatType); }); m.def("constant_function", [](PythonFunction func) -> PythonValueHandle { - auto function = Function::getFromOpaquePointer(func.function); + auto function = FuncOp::getFromOpaquePointer(func.function); auto attr = FunctionAttr::get(function.getName(), function.getContext()); return ValueHandle::create(function.getType(), attr); }); @@ -723,8 +723,7 @@ PYBIND11_MODULE(pybind, m) { return ValueHandle::create(name, operandHandles, types, attrs); }); - py::class_(m, "Function", - "Wrapping class for mlir::Function.") + py::class_(m, "Function", "Wrapping class for mlir::FuncOp.") .def(py::init()) .def("__str__", &PythonFunction::str) .def("define", &PythonFunction::define, @@ -773,13 +772,13 @@ PYBIND11_MODULE(pybind, m) { "Creates an mlir::IntegerAttr of the given type with the given value " "in the context associated with this MLIR module.") .def("declare_function", &PythonMLIRModule::declareFunction, - "Declares a new mlir::Function in the current mlir::Module. The " + "Declares a new mlir::FuncOp in the current mlir::Module. The " "function arguments can have attributes. 
The function has no " "definition and can be linked to an external library.") .def("make_function", &PythonMLIRModule::makeFunction, - "Defines a new mlir::Function in the current mlir::Module.") + "Defines a new mlir::FuncOp in the current mlir::Module.") .def("function_context", &PythonMLIRModule::makeFunctionContext, - "Defines a new mlir::Function in the mlir::Module and creates the " + "Defines a new mlir::FuncOp in the mlir::Module and creates the " "function context for building the body of the function.") .def("get_function", &PythonMLIRModule::getNamedFunction, "Looks up the function with the given name in the module.") diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index 104139005d6..73ccbe67968 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -58,11 +58,11 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, } /// A basic function builder -inline mlir::Function makeFunction(mlir::Module module, llvm::StringRef name, - llvm::ArrayRef types, - llvm::ArrayRef resultTypes) { +inline mlir::FuncOp makeFunction(mlir::Module module, llvm::StringRef name, + llvm::ArrayRef types, + llvm::ArrayRef resultTypes) { auto *context = module.getContext(); - auto function = mlir::Function::create( + auto function = mlir::FuncOp::create( mlir::UnknownLoc::get(context), name, mlir::FunctionType::get({types}, resultTypes, context)); function.addEntryBlock(); @@ -84,7 +84,7 @@ inline std::unique_ptr cleanupPassManager() { /// llvm::outs() for FileCheck'ing. /// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() /// which will make the associated FileCheck test fail. 
-inline void cleanupAndPrintFunction(mlir::Function f) { +inline void cleanupAndPrintFunction(mlir::FuncOp f) { bool printToOuts = true; auto check = [&f, &printToOuts](mlir::LogicalResult result) { if (failed(result)) { diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index cb93b96cc58..57e6a0f21f0 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -36,8 +36,8 @@ TEST_FUNC(linalg_ops) { MLIRContext context; OwningModuleRef module = Module::create(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function f = makeFunction(*module, "linalg_ops", - {indexType, indexType, indexType}, {}); + mlir::FuncOp f = makeFunction(*module, "linalg_ops", + {indexType, indexType, indexType}, {}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); @@ -75,8 +75,8 @@ TEST_FUNC(linalg_ops_folded_slices) { MLIRContext context; OwningModuleRef module = Module::create(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function f = makeFunction(*module, "linalg_ops_folded_slices", - {indexType, indexType, indexType}, {}); + mlir::FuncOp f = makeFunction(*module, "linalg_ops_folded_slices", + {indexType, indexType, indexType}, {}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 6bd428fe8a5..7ba6dfc42fd 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -37,10 +37,10 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module module, StringRef name) { +FuncOp makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function f = linalg::common::makeFunction( + mlir::FuncOp f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); @@ -67,7 +67,7 @@ Function makeFunctionWithAMatmulOp(Module module, StringRef name) { TEST_FUNC(foo) { MLIRContext context; OwningModuleRef module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(*module); diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index e68acf2f983..2b5540fbe32 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -34,10 +34,10 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module module, StringRef name) { +FuncOp makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function f = linalg::common::makeFunction( + mlir::FuncOp f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); @@ -64,7 +64,7 @@ Function makeFunctionWithAMatmulOp(Module module, StringRef name) { TEST_FUNC(matmul_as_matvec) { MLIRContext context; Module module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); + mlir::FuncOp f = 
makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); // clang-format off @@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) { TEST_FUNC(matmul_as_dot) { MLIRContext context; Module module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); @@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) { TEST_FUNC(matmul_as_loops) { MLIRContext context; Module module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); composeSliceOps(f); // clang-format off @@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) { TEST_FUNC(matmul_as_matvec_as_loops) { MLIRContext context; Module module = Module::create(&context); - mlir::Function f = + mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); lowerToFinerGrainedTensorContraction(f); lowerToLoops(f); @@ -166,7 +166,7 @@ TEST_FUNC(matmul_as_matvec_as_loops) { TEST_FUNC(matmul_as_matvec_as_affine) { MLIRContext context; Module module = Module::create(&context); - mlir::Function f = + mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 4b1078791b7..a70cad8259b 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -37,10 +37,10 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module module, StringRef name) { +FuncOp makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function f = linalg::common::makeFunction( + mlir::FuncOp f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); @@ -110,7 +110,7 @@ TEST_FUNC(execution) { // dialect through partial conversions. MLIRContext context; OwningModuleRef module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(*module); diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 81905d3aed7..4346b47cf49 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -27,7 +27,6 @@ namespace mlir { class AffineForOp; class AffineMap; class FuncOp; -using Function = FuncOp; class FunctionPassBase; class Operation; class Value; @@ -56,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, /// Traverses `f` and rewrites linalg.slice, and the operations it depends on, /// to only use linalg.view operations. -void composeSliceOps(mlir::Function f); +void composeSliceOps(mlir::FuncOp f); /// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) /// as linalg.matvec(resp. linalg.dot). 
-void lowerToFinerGrainedTensorContraction(mlir::Function f); +void lowerToFinerGrainedTensorContraction(mlir::FuncOp f); /// Operation-wise writing of linalg operations to loop form. /// It is the caller's responsibility to erase the `op` if necessary. @@ -70,7 +69,7 @@ llvm::Optional> writeAsLoops(mlir::Operation *op); /// Traverses `f` and rewrites linalg operations in loop form. -void lowerToLoops(mlir::Function f); +void lowerToLoops(mlir::FuncOp f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 7b9e5ffee96..d81eec0a370 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -35,7 +35,7 @@ using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; -void linalg::composeSliceOps(mlir::Function f) { +void linalg::composeSliceOps(mlir::FuncOp f) { f.walk([](SliceOp sliceOp) { auto *sliceResult = sliceOp.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); @@ -44,7 +44,7 @@ void linalg::composeSliceOps(mlir::Function f) { }); } -void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) { +void linalg::lowerToFinerGrainedTensorContraction(mlir::FuncOp f) { f.walk([](Operation *op) { if (auto matmulOp = dyn_cast(op)) { matmulOp.writeAsFinerGrainTensorContraction(); @@ -211,7 +211,7 @@ linalg::writeAsLoops(Operation *op) { return llvm::None; } -void linalg::lowerToLoops(mlir::Function f) { +void linalg::lowerToLoops(mlir::FuncOp f) { f.walk([](Operation *op) { if (writeAsLoops(op)) op->erase(); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 90af11ee8d9..ef8097d3265 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -34,10 +34,10 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module module, StringRef name) { +FuncOp makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function f = linalg::common::makeFunction( + mlir::FuncOp f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); @@ -65,7 +65,7 @@ Function makeFunctionWithAMatmulOp(Module module, StringRef name) { TEST_FUNC(matmul_tiled_loops) { MLIRContext context; OwningModuleRef module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops"); lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); @@ -96,7 +96,7 @@ TEST_FUNC(matmul_tiled_loops) { TEST_FUNC(matmul_tiled_views) { MLIRContext context; OwningModuleRef module = Module::create(&context); - mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views"); + mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views"); OpBuilder b(f.getBody()); lowerToTiledViews(f, {b.create(f.getLoc(), 8), b.create(f.getLoc(), 9)}); @@ -125,7 +125,7 @@ TEST_FUNC(matmul_tiled_views) { TEST_FUNC(matmul_tiled_views_as_loops) { MLIRContext context; OwningModuleRef module = Module::create(&context); - mlir::Function f = + mlir::FuncOp f = 
makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops"); OpBuilder b(f.getBody()); lowerToTiledViews(f, {b.create(f.getLoc(), 8), diff --git a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h index ba7273e409d..259ac79ebf2 100644 --- a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h +++ b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h @@ -34,13 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef tileSizes); +void lowerToTiledLoops(mlir::FuncOp f, llvm::ArrayRef tileSizes); /// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledViews(mlir::Function f, - llvm::ArrayRef tileSizes); +void lowerToTiledViews(mlir::FuncOp f, llvm::ArrayRef tileSizes); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 16b395da506..fcb4c20704d 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -43,7 +43,7 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef tileSizes) { +void linalg::lowerToTiledLoops(mlir::FuncOp f, ArrayRef tileSizes) { f.walk([tileSizes](Operation *op) { if (writeAsTiledLoops(op, tileSizes).hasValue()) op->erase(); @@ -184,7 +184,7 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { return llvm::None; } -void linalg::lowerToTiledViews(mlir::Function f, ArrayRef tileSizes) { +void linalg::lowerToTiledViews(mlir::FuncOp f, ArrayRef tileSizes) { f.walk([tileSizes](Operation *op) { if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 396997ef608..831e2ab542e 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -130,15 +130,15 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function mlirGen(PrototypeAST &proto) { + mlir::FuncOp mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. @@ -149,12 +149,12 @@ private: } /// Emit a new function and add it to the MLIR module. - mlir::Function mlirGen(FunctionAST &funcAST) { + mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. 
ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - mlir::Function function(mlirGen(*funcAST.getProto())); + mlir::FuncOp function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 0312d179a33..fb106d8ee97 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -131,15 +131,15 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function mlirGen(PrototypeAST &proto) { + mlir::FuncOp mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. @@ -150,12 +150,12 @@ private: } /// Emit a new function and add it to the MLIR module. - mlir::Function mlirGen(FunctionAST &funcAST) { + mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - mlir::Function function(mlirGen(*funcAST.getProto())); + mlir::FuncOp function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 069e391f60b..0fabbfd2316 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -131,15 +131,15 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function mlirGen(PrototypeAST &proto) { + mlir::FuncOp mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. @@ -150,12 +150,12 @@ private: } /// Emit a new function and add it to the MLIR module. - mlir::Function mlirGen(FunctionAST &funcAST) { + mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. 
- mlir::Function function(mlirGen(*funcAST.getProto())); + mlir::FuncOp function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index fe24f0fcd3e..4814fe72f78 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -114,7 +114,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function function; + mlir::FuncOp function; std::string mangledName; SmallVector argumentsType; }; @@ -140,8 +140,8 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function function : - llvm::make_early_inc_range(module.getOps())) { + for (mlir::FuncOp function : + llvm::make_early_inc_range(module.getOps())) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -155,7 +155,7 @@ public: mlir::LogicalResult specialize(SmallVectorImpl &funcWorklist) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function f = functionToSpecialize.function; + mlir::FuncOp f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -172,8 +172,8 @@ public: {ToyArrayType::get(&getContext())}, &getContext()); auto newFunction = - mlir::Function::create(f.getLoc(), functionToSpecialize.mangledName, - type, f.getDialectAttrs()); + mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName, + type, f.getDialectAttrs()); getModule().push_back(newFunction); // Clone the function body diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 5267ae3d5db..2f158f4326e 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,7 +136,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function printfFunc = getPrintf(op->getParentOfType()); + FuncOp printfFunc = getPrintf(op->getParentOfType()); auto print = cast(op); auto loc = print.getLoc(); @@ -205,7 +205,7 @@ private: /// Return the prototype declaration for printf in the module, create it if /// necessary. - Function getPrintf(Module module) const { + FuncOp getPrintf(Module module) const { auto printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; @@ -218,7 +218,7 @@ private: auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); - printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy); + printfFunc = FuncOp::create(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. 
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true)); module.push_back(printfFunc); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 6b0f67dd226..0fc1fc47408 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -131,15 +131,15 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function mlirGen(PrototypeAST &proto) { + mlir::FuncOp mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. @@ -150,12 +150,12 @@ private: } /// Emit a new function and add it to the MLIR module. - mlir::Function mlirGen(FunctionAST &funcAST) { + mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. - mlir::Function function(mlirGen(*funcAST.getProto())); + mlir::FuncOp function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index a94b261b30e..8ffc44fd46e 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -104,7 +104,7 @@ namespace { /// a) Take the last inserted function in the worklist. /// b) Run the intra-procedural shape inference on this function. /// c) If the intra-procedural shape inference can't complete, it returns -/// a Function that needs to be inferred first. In this case, queue this +/// a FuncOp that needs to be inferred first. In this case, queue this /// new function and continue. Otherwise the inference succeeded and we /// can pop from the queue. /// @@ -114,7 +114,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function function; + mlir::FuncOp function; std::string mangledName; SmallVector argumentsType; }; @@ -141,8 +141,8 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. 
- for (mlir::Function function : - llvm::make_early_inc_range(module.getOps())) { + for (mlir::FuncOp function : + llvm::make_early_inc_range(module.getOps())) { if (auto genericAttr = function.getAttrOfType("toy.generic")) { if (genericAttr.getValue()) @@ -157,7 +157,7 @@ public: specialize(SmallVectorImpl &funcWorklist, mlir::ModuleManager &moduleManager) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function f = functionToSpecialize.function; + mlir::FuncOp f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -165,7 +165,7 @@ public: if (!functionToSpecialize.mangledName.empty()) { if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) { funcWorklist.pop_back(); - // Function already specialized, move on. + // FuncOp already specialized, move on. return mlir::success(); } // Create a new function with a generic array return type, it will be @@ -174,8 +174,8 @@ public: {ToyArrayType::get(&getContext())}, &getContext()); auto newFunction = - mlir::Function::create(f.getLoc(), functionToSpecialize.mangledName, - type, f.getDialectAttrs()); + mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName, + type, f.getDialectAttrs()); moduleManager.insert(newFunction); // Clone the function body diff --git a/mlir/g3doc/Tutorials/Toy/Ch-4.md b/mlir/g3doc/Tutorials/Toy/Ch-4.md index 1a725db5bd3..343d8f9e00b 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-4.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-4.md @@ -229,7 +229,7 @@ inter-procedural flow that wraps the intra-procedural inference: - Take the last inserted function in the worklist. - Run the intra-procedural shape inference on this function. - If the intra-procedural shape inference can't complete, it returns a - Function that needs to be inferred first. In this case, queue this new + FuncOp that needs to be inferred first. In this case, queue this new function and continue. Otherwise the inference succeeded and we can pop from the queue. diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md index c432ed378c7..597fcd2cf08 100644 --- a/mlir/g3doc/WritingAPass.md +++ b/mlir/g3doc/WritingAPass.md @@ -54,7 +54,7 @@ namespace { struct MyFunctionPass : public FunctionPass { void runOnFunction() override { // Get the current function being operated on. - Function f = getFunction(); + FuncOp f = getFunction(); // Operate on the operations within the function. f.walk([](Operation *inst) { @@ -114,7 +114,7 @@ static PassRegistration pass( An important concept, along with transformation passes, are analyses. These are conceptually similar to transformation passes, except that they compute -information on a specific Function, or Module, without modifying it. In MLIR, +information on a specific FuncOp, or Module, without modifying it. In MLIR, analyses are not passes but free standing classes that are computed lazily on-demand and cached to avoid unnecessary recomputation. An analysis in MLIR must adhere to the following: @@ -143,7 +143,7 @@ above, let's see some examples: /// An interesting function analysis. struct MyFunctionAnalysis { // Compute this analysis with the provided function. - MyFunctionAnalysis(Function function); + MyFunctionAnalysis(FuncOp function); }; /// An interesting module analysis. 
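As a rough illustration of the new spelling (not taken from this patch), a user-written function pass against the post-rename API might look like the sketch below. FunctionPass, FuncOp, getFunction(), walk(), and PassRegistration follow the WritingAPass.md example above; the pass name and the op-counting behavior are made-up placeholders.

#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_ostream.h"

namespace {
// Hypothetical pass: counts the operations nested under the current function.
struct CountOpsPass : public mlir::FunctionPass<CountOpsPass> {
  void runOnFunction() override {
    mlir::FuncOp f = getFunction(); // previously spelled mlir::Function
    unsigned numOps = 0;
    f.walk([&](mlir::Operation *op) { ++numOps; });
    llvm::outs() << f.getName() << ": " << numOps << " operations\n";
  }
};
} // namespace

// Illustrative registration; the command-line flag is a placeholder.
static mlir::PassRegistration<CountOpsPass> pass("count-ops",
                                                 "Count operations per function");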
diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 0eaf9bb991f..918ccdf60ec 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -34,7 +34,7 @@ extern "C" { typedef void *mlir_context_t; /// Opaque C type for mlir::Type. typedef const void *mlir_type_t; -/// Opaque C type for mlir::Function*. +/// Opaque C type for mlir::FuncOp. typedef void *mlir_func_t; /// Opaque C type for mlir::Attribute. typedef const void *mlir_attr_t; diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 3e2b90d6557..968ffb1a791 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -379,8 +379,8 @@ public: /// constraint system. Returns failure for the yet unimplemented/unsupported /// cases. Any new identifiers that are found in the bound operands of the /// 'affine.for' operation are added as trailing identifiers (either - /// dimensional or symbolic depending on whether the operand is a valid ML - /// Function symbol). + /// dimensional or symbolic depending on whether the operand is a valid + /// symbol). // TODO(bondhugula): add support for non-unit strides. LogicalResult addAffineForOpDomain(AffineForOp forOp); diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index b89011a28e3..b07b73a023a 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -104,7 +104,7 @@ struct NestedPattern { NestedPattern &operator=(const NestedPattern &) = default; /// Returns all the top-level matches in `func`. - void match(Function func, SmallVectorImpl *matches) { + void match(FuncOp func, SmallVectorImpl *matches) { func.walk([&](Operation *op) { matchOne(op, matches); }); } diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index 4790f6bc160..9eafcd35576 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -29,10 +29,10 @@ namespace mlir { class FunctionPassBase; -/// Creates a pass to check memref accesses in an ML Function. +/// Creates a pass to check memref accesses in a Function. FunctionPassBase *createMemRefBoundCheckPass(); -/// Creates a pass to check memref access dependences in an ML Function. +/// Creates a pass to check memref access dependences in a Function. FunctionPassBase *createTestMemRefDependenceCheckPass(); /// Creates a pass to test parallelism detection; emits note for parallel loops. diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index e3d9ab607e5..b19fb53e3e2 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -26,11 +26,9 @@ namespace mlir { class ModulePassBase; class FuncOp; -using Function = FuncOp; using OwnedCubin = std::unique_ptr>; -using CubinGenerator = - std::function; +using CubinGenerator = std::function; /// Creates a pass to convert kernel functions into CUBIN blobs. 
/// diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h index df6b128ff43..694686467a9 100644 --- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -30,9 +30,7 @@ template class Expected; } namespace mlir { - class FuncOp; -using Function = FuncOp; /// Simple memref descriptor class compatible with the ABI of functions emitted /// by MLIR to LLVM IR conversion for statically-shaped memrefs of float type. @@ -45,7 +43,7 @@ struct StaticFloatMemRef { /// each of the arguments, initialize the storage with `initialValue`, and /// return a list of type-erased descriptor pointers. llvm::Expected> -allocateMemRefArguments(Function func, float initialValue = 0.0); +allocateMemRefArguments(FuncOp func, float initialValue = 0.0); /// Free a list of type-erased descriptors to statically-shaped memrefs with /// element type f32. diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index 4a304fdc77e..381810140ca 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -44,7 +44,7 @@ public: /// Returns whether the given function is a kernel function, i.e., has the /// 'gpu.kernel' attribute. - static bool isKernel(Function function); + static bool isKernel(FuncOp function); }; /// Utility class for the GPU dialect to represent triples of `Value`s @@ -129,14 +129,14 @@ class LaunchFuncOp : public Op::Impl, public: using Op::Op; - static void build(Builder *builder, OperationState *result, - Function kernelFunc, Value *gridSizeX, Value *gridSizeY, - Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, - Value *blockSizeZ, ArrayRef kernelOperands); + static void build(Builder *builder, OperationState *result, FuncOp kernelFunc, + Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, + Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, + ArrayRef kernelOperands); - static void build(Builder *builder, OperationState *result, - Function kernelFunc, KernelDim3 gridSize, - KernelDim3 blockSize, ArrayRef kernelOperands); + static void build(Builder *builder, OperationState *result, FuncOp kernelFunc, + KernelDim3 gridSize, KernelDim3 blockSize, + ArrayRef kernelOperands); /// The kernel function specified by the operation's `kernel` attribute. 
StringRef kernel(); diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 7c86e9bbe9e..f7d14abf063 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -24,8 +24,6 @@ namespace mlir { class AffineMap; class Dialect; -class FuncOp; -using Function = FuncOp; class FunctionType; class Identifier; class IntegerSet; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index ba77472224a..07ed9912e31 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,7 +113,7 @@ public: AffineMapAttr getAffineMapAttr(AffineMap map); IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type type); - FunctionAttr getFunctionAttr(Function value); + FunctionAttr getFunctionAttr(FuncOp value); FunctionAttr getFunctionAttr(StringRef value); ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 4e82689efff..5062b6f8c8c 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -25,6 +25,7 @@ #include "mlir/IR/OperationSupport.h" namespace mlir { +class FuncOp; class OpBuilder; class Type; @@ -145,12 +146,12 @@ public: /// Verify an attribute from this dialect on the given function. Returns /// failure if the verification failed, success otherwise. - virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute); + virtual LogicalResult verifyFunctionAttribute(FuncOp, NamedAttribute); /// Verify an attribute from this dialect on the argument at 'argIndex' for /// the given function. Returns failure if the verification failed, success /// otherwise. - virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex, + virtual LogicalResult verifyFunctionArgAttribute(FuncOp, unsigned argIndex, NamedAttribute); /// Verify an attribute from this dialect on the given operation. Returns diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index a7a2c68a992..ae13328bf7f 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -236,9 +236,6 @@ private: return getAttrOfType(getArgAttrName(index, nameOut)); } }; - -/// Temporary forward declaration of FuncOp to the legacy Function. -using Function = FuncOp; } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index aac60a2e6e9..59b28fdd014 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -97,9 +97,7 @@ public: /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Function getNamedFunction(StringRef name) { - return lookupSymbol(name); - } + FuncOp getNamedFunction(StringRef name) { return lookupSymbol(name); } }; /// The ModuleTerminatorOp is a special terminator operation for the body of a diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 921437601e1..34daa3ec0ba 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -420,7 +420,7 @@ private: /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. 
/// -bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns); +bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns); /// Helper class to create a list of rewrite patterns given a list of their /// types and a list of attributes perfect-forwarded to each of the conversion diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index b5dbd539eb0..1bad41f4c4c 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -28,8 +28,6 @@ namespace mlir { class Block; -class FuncOp; -using Function = FuncOp; class Operation; class Region; class Value; diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index a28aa719965..60876cdcd06 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -153,7 +153,7 @@ public: /// Verify a function argument attribute registered to this dialect. /// Returns failure if the verification failed, success otherwise. - LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx, + LogicalResult verifyFunctionArgAttribute(FuncOp func, unsigned argIdx, NamedAttribute argAttr) override; private: diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index 18ba7a826cc..58b9644bd0c 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -207,14 +207,14 @@ public: private: FunctionAnalysisManager(const ModuleAnalysisManager *parent, - detail::AnalysisMap *impl) + detail::AnalysisMap *impl) : parent(parent), impl(impl) {} /// A reference to the parent analysis manager. const ModuleAnalysisManager *parent; /// A reference to the impl analysis map within the owning analysis manager. - detail::AnalysisMap *impl; + detail::AnalysisMap *impl; /// Allow access to the constructor. friend class ModuleAnalysisManager; @@ -231,14 +231,14 @@ public: /// Query for the analysis of a function. The analysis is computed if it does /// not exist. template - AnalysisT &getFunctionAnalysis(Function function) { + AnalysisT &getFunctionAnalysis(FuncOp function) { return slice(function).getAnalysis(); } /// Query for a cached analysis of a child function, or return null. template llvm::Optional> - getCachedFunctionAnalysis(Function function) const { + getCachedFunctionAnalysis(FuncOp function) const { auto it = functionAnalyses.find(function); if (it == functionAnalyses.end()) return llvm::None; @@ -258,7 +258,7 @@ public: } /// Create an analysis slice for the given child function. - FunctionAnalysisManager slice(Function function); + FunctionAnalysisManager slice(FuncOp function); /// Invalidate any non preserved analyses. void invalidate(const detail::PreservedAnalyses &pa); @@ -269,7 +269,7 @@ public: private: /// The cached analyses for functions within the current module. - llvm::DenseMap>> + llvm::DenseMap>> functionAnalyses; /// The analyses for the owning module. diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 6ee78c52a25..c8b2e2ad563 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -90,7 +90,7 @@ struct PassExecutionState { /// FunctionPass class. class FunctionPassBase : public Pass { using PassStateT = - detail::PassExecutionState; + detail::PassExecutionState; public: static bool classof(const Pass *pass) { @@ -107,7 +107,7 @@ protected: virtual FunctionPassBase *clone() const = 0; /// Return the current function being transformed. 
- Function getFunction() { return getPassState().irAndPassFailed.getPointer(); } + FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); } /// Return the MLIR context for the current function being transformed. MLIRContext &getContext() { return *getFunction().getContext(); } @@ -126,7 +126,7 @@ protected: private: /// Forwarding function to execute this pass. LLVM_NODISCARD - LogicalResult run(Function fn, FunctionAnalysisManager &fam); + LogicalResult run(FuncOp fn, FunctionAnalysisManager &fam); /// The current execution state for the pass. llvm::Optional passState; @@ -249,7 +249,7 @@ protected: /// Derived function passes are expected to provide the following: /// - A 'void runOnFunction()' method. template -struct FunctionPass : public detail::PassModel { +struct FunctionPass : public detail::PassModel { /// Returns the analysis for the parent module if it exists. template llvm::Optional> getCachedModuleAnalysis() { @@ -270,7 +270,7 @@ struct FunctionPass : public detail::PassModel { template struct ModulePass : public detail::PassModel { /// Returns the analysis for a child function. - template AnalysisT &getFunctionAnalysis(Function f) { + template AnalysisT &getFunctionAnalysis(FuncOp f) { return this->getAnalysisManager().template getFunctionAnalysis( f); } @@ -278,7 +278,7 @@ struct ModulePass : public detail::PassModel { /// Returns an existing analysis for a child function if it exists. template llvm::Optional> - getCachedFunctionAnalysis(Function f) { + getCachedFunctionAnalysis(FuncOp f) { return this->getAnalysisManager() .template getCachedFunctionAnalysis(f); } diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 8e37a58cd14..87168fd2e7d 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -214,7 +214,7 @@ def CallOp : Std_Op<"call"> { let results = (outs Variadic); let builders = [OpBuilder< - "Builder *builder, OperationState *result, Function callee," + "Builder *builder, OperationState *result, FuncOp callee," "ArrayRef operands = {}", [{ result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index f72901875b1..662ceaba7a9 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -73,8 +73,8 @@ protected: private: bool convertFunctions(); - bool convertOneFunction(Function func); - void connectPHINodes(Function func); + bool convertOneFunction(FuncOp func); + void connectPHINodes(FuncOp func); bool convertBlock(Block &bb, bool ignoreArguments); template diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 0101673b500..45aaf722c25 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns( /// Convert the given functions with the provided conversion patterns. This /// function returns failure if a type conversion failed. 
LLVM_NODISCARD -LogicalResult applyConversionPatterns(MutableArrayRef fns, +LogicalResult applyConversionPatterns(MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); @@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(MutableArrayRef fns, /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LLVM_NODISCARD -LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target, +LogicalResult applyConversionPatterns(FuncOp fn, ConversionTarget &target, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h index c56bef03bf1..87a3e13c0cd 100644 --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -27,8 +27,6 @@ #include "mlir/IR/Dialect.h" namespace mlir { -class FuncOp; -using Function = FuncOp; class Operation; class Value; diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 654555830ec..aa124f33dfe 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -32,7 +32,6 @@ class AffineMap; class AffineForOp; class ForOp; class FuncOp; -using Function = FuncOp; class OpBuilder; class Value; @@ -72,7 +71,7 @@ LogicalResult promoteIfSingleIteration(AffineForOp forOp); /// Promotes all single iteration AffineForOp's in the Function, i.e., moves /// their body into the containing Block. -void promoteSingleIterationLoops(Function *f); +void promoteSingleIterationLoops(FuncOp f); /// Computes the cleanup loop lower bound of the loop being unrolled with /// the specified unroll factor; this bound will also be upper bound of the main diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h index f9c32c9063a..1711bdd467b 100644 --- a/mlir/include/mlir/Transforms/LowerAffine.h +++ b/mlir/include/mlir/Transforms/LowerAffine.h @@ -24,7 +24,6 @@ namespace mlir { class AffineExpr; class AffineForOp; class FuncOp; -using Function = FuncOp; class Location; struct LogicalResult; class OpBuilder; @@ -38,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, /// Convert from the Affine dialect to the Standard dialect, in particular /// convert structured affine control flow into CFG branch-based control flow. -LogicalResult lowerAffineConstructs(Function function); +LogicalResult lowerAffineConstructs(FuncOp function); /// Emit code that computes the lower bound of the given affine loop using /// standard arithmetic operations. diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index 473d253cfa2..9ae8a311f6d 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. 
void TestParallelismDetection::runOnFunction() { - Function f = getFunction(); + FuncOp f = getFunction(); OpBuilder b(f.getBody()); f.walk([&](AffineForOp forOp) { if (isLoopParallel(forOp)) diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 1dbedf9fcee..8f381604bb5 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -75,12 +75,12 @@ public: private: static OwnedCubin compilePtxToCubinForTesting(const std::string &ptx, - Function &function); + FuncOp &function); std::string translateModuleToPtx(llvm::Module &module, llvm::TargetMachine &target_machine); - OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Function &function); - LogicalResult translateGpuKernelToCubinAnnotation(Function &function); + OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, FuncOp &function); + LogicalResult translateGpuKernelToCubinAnnotation(FuncOp &function); CubinGenerator cubinGenerator; }; @@ -104,13 +104,13 @@ std::string GpuKernelToCubinPass::translateModuleToPtx( OwnedCubin GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx, - Function &function) { + FuncOp &function) { const char data[] = "CUBIN"; return llvm::make_unique>(data, data + sizeof(data) - 1); } OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, - Function &function) { + FuncOp &function) { std::unique_ptr targetMachine; { std::string error; @@ -136,7 +136,7 @@ OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, } LogicalResult -GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { +GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) { Builder builder(function.getContext()); OwningModuleRef module = builder.createModule(); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index dafc5fa5730..24fc706d82f 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -88,9 +88,7 @@ private: LLVM::LLVMType getPointerType() { return llvmPointerType; } - LLVM::LLVMType getPointerPointerType() { - return llvmPointerPointerType; - } + LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; } LLVM::LLVMType getInt8Type() { return llvmInt8Type; } @@ -118,7 +116,7 @@ private: void declareCudaFunctions(Location loc); Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value *generateKernelNameConstant(Function kernelFunction, Location &loc, + Value *generateKernelNameConstant(FuncOp kernelFunction, Location &loc, OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); @@ -156,33 +154,33 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { Builder builder(module); if (!module.getNamedFunction(cuModuleLoadName)) { module.push_back( - Function::create(loc, cuModuleLoadName, - builder.getFunctionType( - { - getPointerPointerType(), /* CUmodule *module */ - getPointerType() /* void *cubin */ - }, - getCUResultType()))); + FuncOp::create(loc, cuModuleLoadName, + builder.getFunctionType( + { + getPointerPointerType(), /* CUmodule *module */ + getPointerType() /* void *cubin */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuModuleGetFunctionName)) { // The helper uses void* 
instead of CUDA's opaque CUmodule and // CUfunction. module.push_back( - Function::create(loc, cuModuleGetFunctionName, - builder.getFunctionType( - { - getPointerPointerType(), /* void **function */ - getPointerType(), /* void *module */ - getPointerType() /* char *name */ - }, - getCUResultType()))); + FuncOp::create(loc, cuModuleGetFunctionName, + builder.getFunctionType( + { + getPointerPointerType(), /* void **function */ + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuLaunchKernelName)) { // Other than the CUDA api, the wrappers use uintptr_t to match the // LLVM type if MLIR's index type, which the GPU dialect uses. // Furthermore, they use void* instead of CUDA's opaque CUfunction and // CUstream. - module.push_back(Function::create( + module.push_back(FuncOp::create( loc, cuLaunchKernelName, builder.getFunctionType( { @@ -203,18 +201,18 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { if (!module.getNamedFunction(cuGetStreamHelperName)) { // Helper function to get the current CUDA stream. Uses void* instead of // CUDAs opaque CUstream. - module.push_back(Function::create( + module.push_back(FuncOp::create( loc, cuGetStreamHelperName, builder.getFunctionType({}, getPointerType() /* void *stream */))); } if (!module.getNamedFunction(cuStreamSynchronizeName)) { module.push_back( - Function::create(loc, cuStreamSynchronizeName, - builder.getFunctionType( - { - getPointerType() /* CUstream stream */ - }, - getCUResultType()))); + FuncOp::create(loc, cuStreamSynchronizeName, + builder.getFunctionType( + { + getPointerType() /* CUstream stream */ + }, + getCUResultType()))); } } @@ -264,7 +262,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %0[n] = constant name[n] // %0[n+1] = 0 Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( - Function kernelFunction, Location &loc, OpBuilder &builder) { + FuncOp kernelFunction, Location &loc, OpBuilder &builder) { // TODO(herhut): Make this a constant once this is supported. auto kernelNameSize = builder.create( loc, getInt32Type(), @@ -337,7 +335,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // Emit the load module call to load the module data. Error checking is done // in the called helper function. auto cuModule = allocatePointer(builder, loc); - Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); + FuncOp cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleLoad), ArrayRef{cuModule, data.getResult(0)}); @@ -347,14 +345,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create(loc, getPointerType(), cuModule); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto cuFunction = allocatePointer(builder, loc); - Function cuModuleGetFunction = + FuncOp cuModuleGetFunction = getModule().getNamedFunction(cuModuleGetFunctionName); builder.create( loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleGetFunction), ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. 
- Function cuGetStreamHelper = + FuncOp cuGetStreamHelper = getModule().getNamedFunction(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 6306c567907..e19e2de99fb 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -59,10 +59,10 @@ private: return LLVM::LLVMType::getIntNTy(llvmDialect, bits); } - Function getMallocHelper(Location loc, Builder &builder) { - Function result = getModule().getNamedFunction(kMallocHelperName); + FuncOp getMallocHelper(Location loc, Builder &builder) { + FuncOp result = getModule().getNamedFunction(kMallocHelperName); if (!result) { - result = Function::create( + result = FuncOp::create( loc, kMallocHelperName, builder.getFunctionType(ArrayRef{getIndexType()}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); @@ -75,13 +75,13 @@ private: // data from blob. As there are currently no global constants, this uses a // sequence of store operations. // TODO(herhut): Use global constants instead. - Function generateCubinAccessor(Builder &builder, Function &orig, - StringAttr blob) { + FuncOp generateCubinAccessor(Builder &builder, FuncOp &orig, + StringAttr blob) { Location loc = orig.getLoc(); SmallString<128> nameBuffer(orig.getName()); nameBuffer.append(kCubinGetterSuffix); // Generate a function that returns void*. - Function result = Function::create( + FuncOp result = FuncOp::create( loc, mlir::Identifier::get(nameBuffer, &getContext()), builder.getFunctionType(ArrayRef{}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); @@ -127,7 +127,7 @@ public: for (auto it = functions.begin(); it != functions.end();) { // Move iterator to after the current function so that potential insertion // of the accessor is after the kernel with cubin iself. - Function orig = *it++; + FuncOp orig = *it++; StringAttr cubinBlob = orig.getAttrOfType(kCubinAnnotation); if (!cubinBlob) continue; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 01d473e7f59..dc783f94865 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -441,13 +441,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern { createIndexConstant(rewriter, op->getLoc(), elementSize)}); // Insert the `malloc` declaration if it is not already present. - Function mallocFunc = + FuncOp mallocFunc = op->getParentOfType().getModule().getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = - Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType); op->getParentOfType().getModule().push_back(mallocFunc); } @@ -503,11 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. 
- Function freeFunc = + FuncOp freeFunc = op->getParentOfType().getModule().getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); - freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); op->getParentOfType().getModule().push_back(freeFunc); } diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index 578b8673658..8b1831342b8 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -98,6 +98,6 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) { } unsigned getFunctionArity(mlir_func_t function) { - auto f = mlir::Function::getFromOpaquePointer(function); + auto f = mlir::FuncOp::getFromOpaquePointer(function); return f.getNumArguments(); } diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp index f13b743de0c..e34bf4455ab 100644 --- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp +++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp @@ -67,7 +67,7 @@ allocMemRefDescriptor(Type type, bool allocateData = true, } llvm::Expected> -mlir::allocateMemRefArguments(Function func, float initialValue) { +mlir::allocateMemRefArguments(FuncOp func, float initialValue) { SmallVector args; args.reserve(func.getNumArguments()); for (const auto &arg : func.getArguments()) { diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index 92034c5d288..22f87a9911f 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -32,7 +32,7 @@ using namespace mlir::gpu; StringRef GPUDialect::getDialectName() { return "gpu"; } -bool GPUDialect::isKernel(Function function) { +bool GPUDialect::isKernel(FuncOp function) { UnitAttr isKernelAttr = function.getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); @@ -84,25 +84,25 @@ void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX, Region &LaunchOp::getBody() { return getOperation()->getRegion(0); } KernelDim3 LaunchOp::getBlockIds() { - assert(!getBody().getBlocks().empty() && "Function body must not be empty."); + assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); auto args = getBody().getBlocks().front().getArguments(); return KernelDim3{args[0], args[1], args[2]}; } KernelDim3 LaunchOp::getThreadIds() { - assert(!getBody().getBlocks().empty() && "Function body must not be empty."); + assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); auto args = getBody().getBlocks().front().getArguments(); return KernelDim3{args[3], args[4], args[5]}; } KernelDim3 LaunchOp::getGridSize() { - assert(!getBody().getBlocks().empty() && "Function body must not be empty."); + assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); auto args = getBody().getBlocks().front().getArguments(); return KernelDim3{args[6], args[7], args[8]}; } KernelDim3 LaunchOp::getBlockSize() { - assert(!getBody().getBlocks().empty() && "Function body must not be empty."); + assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); auto args = getBody().getBlocks().front().getArguments(); return KernelDim3{args[9], args[10], args[11]}; } @@ -378,10 +378,9 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function kernelFunc, Value 
*gridSizeX, - Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, - Value *blockSizeY, Value *blockSizeZ, - ArrayRef kernelOperands) { + FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, + Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, + Value *blockSizeZ, ArrayRef kernelOperands) { // Add grid and block sizes as op operands, followed by the data operands. result->addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -391,7 +390,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result, } void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function kernelFunc, KernelDim3 gridSize, + FuncOp kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef kernelOperands) { build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, @@ -427,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() { } auto module = getParentOfType(); - Function kernelFunc = module.getNamedFunction(kernel()); + FuncOp kernelFunc = module.getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 0bc7041bd6e..75b98fdb9c5 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc, // Add operations generating block/thread ids and gird/block dimensions at the // beginning of `kernelFunc` and replace uses of the respective function args. -static void injectGpuIndexOperations(Location loc, Function kernelFunc) { +static void injectGpuIndexOperations(Location loc, FuncOp kernelFunc) { OpBuilder OpBuilder(kernelFunc.getBody()); SmallVector indexOps; createForAllDimensions(OpBuilder, loc, indexOps); @@ -58,14 +58,14 @@ static void injectGpuIndexOperations(Location loc, Function kernelFunc) { // Outline the `gpu.launch` operation body into a kernel function. Replace // `gpu.return` operations by `std.return` in the generated functions. -static Function outlineKernelFunc(gpu::LaunchOp launchOp) { +static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) { Location loc = launchOp.getLoc(); SmallVector kernelOperandTypes(launchOp.getKernelOperandTypes()); FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = Twine(launchOp.getParentOfType().getName(), "_kernel").str(); - Function outlinedFunc = Function::create(loc, kernelFuncName, type); + FuncOp outlinedFunc = FuncOp::create(loc, kernelFuncName, type); outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), @@ -81,8 +81,7 @@ static Function outlineKernelFunc(gpu::LaunchOp launchOp) { // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching // `kernelFunc`. 
-static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, - Function kernelFunc) { +static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) { OpBuilder builder(launchOp); SmallVector kernelOperandValues( launchOp.getKernelOperandValues()); @@ -100,7 +99,7 @@ public: ModuleManager moduleManager(getModule()); for (auto func : getModule().getOps()) { func.walk([&](mlir::gpu::LaunchOp op) { - Function outlinedFunc = outlineKernelFunc(op); + FuncOp outlinedFunc = outlineKernelFunc(op); moduleManager.insert(outlinedFunc); convertToLaunchFuncOp(op, outlinedFunc); }); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 6d0df6ded8e..ddead2e01a7 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -177,7 +177,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } -FunctionAttr Builder::getFunctionAttr(Function value) { +FunctionAttr Builder::getFunctionAttr(FuncOp value) { return getFunctionAttr(value.getName()); } FunctionAttr Builder::getFunctionAttr(StringRef value) { @@ -337,14 +337,14 @@ OpBuilder::~OpBuilder() {} /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the -/// current function. +/// current region. Block *OpBuilder::createBlock(Block *insertBefore) { Block *b = new Block(); // If we are supposed to insert before a specific block, do so, otherwise add - // the block to the end of the function. + // the block to the end of the region. if (insertBefore) - region->getBlocks().insert(Function::iterator(insertBefore), b); + region->getBlocks().insert(Region::iterator(insertBefore), b); else region->push_back(b); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index e38b95ff0f7..1e042cbf893 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -71,14 +71,14 @@ Dialect::~Dialect() {} /// Verify an attribute from this dialect on the given function. Returns /// failure if the verification failed, success otherwise. -LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) { +LogicalResult Dialect::verifyFunctionAttribute(FuncOp, NamedAttribute) { return success(); } /// Verify an attribute from this dialect on the argument at 'argIndex' for /// the given function. Returns failure if the verification failed, success /// otherwise. -LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex, +LogicalResult Dialect::verifyFunctionArgAttribute(FuncOp, unsigned argIndex, NamedAttribute) { return success(); } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 973a0910f42..b471010d0af 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -33,22 +33,22 @@ using namespace mlir; // Function Operation. 
//===----------------------------------------------------------------------===// -Function FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs) { +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + ArrayRef attrs) { OperationState state(location, "func"); Builder builder(location->getContext()); - Function::build(&builder, &state, name, type, attrs); + FuncOp::build(&builder, &state, name, type, attrs); return llvm::cast(Operation::create(state)); } -Function FuncOp::create(Location location, StringRef name, FunctionType type, - llvm::iterator_range attrs) { +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + llvm::iterator_range attrs) { SmallVector attrRef(attrs); return create(location, name, type, llvm::makeArrayRef(attrRef)); } -Function FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs) { - Function func = create(location, name, type, attrs); +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) { + FuncOp func = create(location, name, type, attrs); func.setAllArgAttrs(argAttrs); return func; } @@ -74,7 +74,7 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name, } /// Get the parent module. -ModuleOp Function::getModule() { +ModuleOp FuncOp::getModule() { auto *parent = getOperation()->getContainingRegion(); return parent ? parent->getParentOfType() : nullptr; } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 2818b1ce207..d6ed2102fb3 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/Region.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" using namespace mlir; diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 669f641b734..4fa49213a3f 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/Value.h" #include "mlir/IR/Block.h" -#include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" using namespace mlir; diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 0dbf63a3ce7..7a73d89f77e 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -816,7 +816,7 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const { } /// Verify LLVMIR function argument attributes. -LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func, +LogicalResult LLVMDialect::verifyFunctionArgAttribute(FuncOp func, unsigned argIdx, NamedAttribute argAttr) { // Check that llvm.noalias is a boolean attribute. 
diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 5761cc637b7..cb9509ed54e 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } -static void fuseLinalgOps(Function f, ArrayRef tileSizes) { +static void fuseLinalgOps(FuncOp f, ArrayRef tileSizes) { OperationFolder state; DenseSet eraseSet; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 2b9c893276a..0cda24722e2 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -171,11 +171,11 @@ public: auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. auto module = op->getParentOfType(); - Function mallocFunc = module.getNamedFunction("malloc"); + FuncOp mallocFunc = module.getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); mallocFunc = - Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType); module.push_back(mallocFunc); } @@ -232,10 +232,10 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. auto module = op->getParentOfType(); - Function freeFunc = module.getNamedFunction("free"); + FuncOp freeFunc = module.getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); - freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); module.push_back(freeFunc); } @@ -573,7 +573,7 @@ public: // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. -static Function getLLVMLibraryCallImplDefinition(Function libFn) { +static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { auto implFnName = (libFn.getName().str() + "_impl"); auto module = libFn.getModule(); if (auto f = module.getNamedFunction(implFnName)) { @@ -589,7 +589,7 @@ static Function getLLVMLibraryCallImplDefinition(Function libFn) { auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. - auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); + auto implFnDefn = FuncOp::create(libFn.getLoc(), implFnName, implFnType); module.push_back(implFnDefn); return implFnDefn; } @@ -597,9 +597,9 @@ static Function getLLVMLibraryCallImplDefinition(Function libFn) { // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. 
template -static Function getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static FuncOp getLLVMLibraryCallDeclaration(Operation *op, + LLVMTypeConverter &lowering, + PatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); auto module = op->getParentOfType(); @@ -619,14 +619,14 @@ static Function getLLVMLibraryCallDeclaration(Operation *op, "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = Function::create(op->getLoc(), fnName, libFnType); + auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType); module.push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } -static void getLLVMLibraryCallDefinition(Function fn, +static void getLLVMLibraryCallDefinition(FuncOp fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); @@ -666,17 +666,15 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(Function fn) { + void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.push_back(fn); } - ArrayRef getLibraryFnDeclarations() { - return libraryFnDeclarations; - } + ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations; } private: /// List of library functions declarations needed during dialect conversion - SmallVector libraryFnDeclarations; + SmallVector libraryFnDeclarations; }; } // end anonymous namespace @@ -727,7 +725,7 @@ struct LowerLinalgToLLVMPass : public ModulePass { // This is currently written as a standalone function because the lowering to // affine will look different than lowering to LLVM and it is still unclear how // everything will be eventually structured. -static void lowerLinalgSubViewOps(Function &f) { +static void lowerLinalgSubViewOps(FuncOp &f) { f.walk([&](SubViewOp op) { OpBuilder b(op); ScopedContext scope(b, op.getLoc()); @@ -750,7 +748,7 @@ static void lowerLinalgSubViewOps(Function &f) { // Converts a `linalg.for` op to CFG form before actual conversion to the LLVM // dialect starts. -static void lowerLinalgForToCFG(Function &f) { +static void lowerLinalgForToCFG(FuncOp &f) { // Collect all the For operations. We do this as a prepass to avoid // invalidating the walker with our rewrite. SmallVector instsToRewrite; diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 4955a80ef5f..e6bb6c302f7 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -482,7 +482,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef tileSizes, return tileLinalgOp(op, tileSizeValues, folder, viewsToPromote); } -static void tileLinalgOps(Function f, ArrayRef tileSizes, +static void tileLinalgOps(FuncOp f, ArrayRef tileSizes, bool promoteViews) { OperationFolder folder; f.walk([promoteViews, tileSizes, &folder](LinalgOp op) { diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index aef16ff231a..232068fad08 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -61,8 +61,8 @@ private: static void printIR(const llvm::Any &ir, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. 
- if (printModuleScope && llvm::any_isa(ir)) { - Function function = llvm::any_cast(ir); + if (printModuleScope && llvm::any_isa(ir)) { + FuncOp function = llvm::any_cast(ir); // Print the function name and a newline before the Module. out << " (function: " << function.getName() << ")\n"; @@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa(ir)) { - llvm::any_cast(ir).print(out); + if (llvm::any_isa(ir)) { + llvm::any_cast(ir).print(out); return; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index a35efbdcb7a..0357fb3a8d4 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -42,7 +42,7 @@ using namespace mlir::detail; void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) { +LogicalResult FunctionPassBase::run(FuncOp fn, FunctionAnalysisManager &fam) { // Initialize the pass state. passState.emplace(fn, fam); @@ -110,7 +110,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) } /// Run all of the passes in this manager over the current function. -LogicalResult detail::FunctionPassExecutor::run(Function function, +LogicalResult detail::FunctionPassExecutor::run(FuncOp function, FunctionAnalysisManager &fam) { // Run each of the held passes. for (auto &pass : passes) @@ -135,8 +135,7 @@ LogicalResult detail::ModulePassExecutor::run(Module module, /// Utility to run the given function and analysis manager on a provided /// function pass executor. -static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, - Function func, +static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, FuncOp func, FunctionAnalysisManager &fam) { // Run the function pipeline over the provided function. auto result = fpe.run(func, fam); @@ -184,7 +183,7 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() { // Run a prepass over the module to collect the functions to execute a over. // This ensures that an analysis manager exists for each function, as well as // providing a queue of functions to execute over. - std::vector> funcAMPairs; + std::vector> funcAMPairs; for (auto func : getModule().getOps()) if (!func.isExternal()) funcAMPairs.emplace_back(func, mam.slice(func)); @@ -340,13 +339,13 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const { } /// Create an analysis slice for the given child function. -FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) { +FunctionAnalysisManager ModuleAnalysisManager::slice(FuncOp func) { assert(func.getModule() == moduleAnalyses.getIRUnit() && "function has a different parent module"); auto it = functionAnalyses.find(func); if (it == functionAnalyses.end()) { - it = functionAnalyses.try_emplace(func, new AnalysisMap(func)) - .first; + it = + functionAnalyses.try_emplace(func, new AnalysisMap(func)).first; } return {this, it->second.get()}; } diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index b0cd22820a3..46674eabb5d 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -48,7 +48,7 @@ public: FunctionPassExecutor(const FunctionPassExecutor &rhs); /// Run the executor on the given function. - LogicalResult run(Function function, FunctionAnalysisManager &fam); + LogicalResult run(FuncOp function, FunctionAnalysisManager &fam); /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. 
diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp index d3eec3e9c13..d1efd56b97f 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp @@ -36,7 +36,7 @@ using namespace mlir; // block. The created block will be terminated by `std.return`. Block *createOneBlockFunction(Builder builder, Module module) { auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); - auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType); + auto fn = FuncOp::create(builder.getUnknownLoc(), "spirv_module", fnType); module.push_back(fn); fn.addEntryBlock(); diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp index ab7b021f920..80cba5ad73f 100644 --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -24,7 +24,6 @@ #include "mlir/Analysis/Passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Function.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 0fef7606812..00502306362 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -68,7 +68,7 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. - for (Function func : m.getOps()) { + for (FuncOp func : m.getOps()) { if (!func.getAttrOfType(gpu::GPUDialect::getKernelFuncAttrName())) continue; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index a358e8363f4..9388f2318d2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -275,7 +275,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, : terminator.getSuccessorOperand(1, index); } -void ModuleTranslation::connectPHINodes(Function func) { +void ModuleTranslation::connectPHINodes(FuncOp func) { // Skip the first block, it cannot be branched to and its arguments correspond // to the arguments of the LLVM function. for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { @@ -306,7 +306,7 @@ static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { } // Sort function blocks topologically. -static llvm::SetVector topologicalSort(Function f) { +static llvm::SetVector topologicalSort(FuncOp f) { // For each blocks that has not been visited yet (i.e. that has no // predecessors), add it to the list and traverse its successors in DFS // preorder. @@ -320,7 +320,7 @@ static llvm::SetVector topologicalSort(Function f) { return blocks; } -bool ModuleTranslation::convertOneFunction(Function func) { +bool ModuleTranslation::convertOneFunction(FuncOp func) { // Clear the block and value mappings, they are only relevant within one // function. blockMapping.clear(); @@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function func) { bool ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. 
- for (Function function : mlirModule.getOps()) { + for (FuncOp function : mlirModule.getOps()) { mlir::BoolAttr isVarArgsAttr = function.getAttrOfType("std.varargs"); bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue(); @@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() { } // Convert functions. - for (Function function : mlirModule.getOps()) { + for (FuncOp function : mlirModule.getOps()) { // Ignore external functions. if (function.isExternal()) continue; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 42683afc468..23c12a0e7c4 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -849,7 +849,7 @@ struct FunctionConverter { /// error, success otherwise. If 'signatureConversion' is provided, the /// arguments of the entry block are updated accordingly. LogicalResult - convertFunction(Function f, + convertFunction(FuncOp f, TypeConverter::SignatureConversion *signatureConversion); /// Converts the given region starting from the entry block and following the @@ -957,7 +957,7 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, } LogicalResult FunctionConverter::convertFunction( - Function f, TypeConverter::SignatureConversion *signatureConversion) { + FuncOp f, TypeConverter::SignatureConversion *signatureConversion) { // If this is an external function, there is nothing else to do. if (f.isExternal()) return success(); @@ -1131,14 +1131,14 @@ LogicalResult mlir::applyConversionPatterns(Module module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { - SmallVector allFunctions(module.getOps()); + SmallVector allFunctions(module.getOps()); return applyConversionPatterns(allFunctions, target, converter, std::move(patterns)); } /// Convert the given functions with the provided conversion patterns. LogicalResult mlir::applyConversionPatterns( - MutableArrayRef fns, ConversionTarget &target, + MutableArrayRef fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { if (fns.empty()) return success(); @@ -1174,7 +1174,7 @@ LogicalResult mlir::applyConversionPatterns( /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function fn, ConversionTarget &target, +mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. 
FunctionConverter converter(fn.getContext(), target, patterns); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 830546db497..f78c941f923 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -771,7 +771,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { } void DmaGeneration::runOnFunction() { - Function f = getFunction(); + FuncOp f = getFunction(); OpBuilder topBuilder(f.getBody()); zeroIndex = topBuilder.create(f.getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index b2557a6c6fd..ea1a03f09a3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -150,7 +150,7 @@ static bool isMemRefDereferencingOp(Operation &op) { } // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level operations in a Function which contain load/store ops, and edges +// top-level operations in a FuncOp which contain load/store ops, and edges // are memref dependences between the nodes. // TODO(andydavis) Add a more flexible dependece graph representation. // TODO(andydavis) Add a depth parameter to dependence graph construction. @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function f); + bool init(FuncOp f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -637,7 +637,7 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(Function f) { +bool MemRefDependenceGraph::init(FuncOp f) { DenseMap> memrefAccesses; // TODO: support multi-block functions. diff --git a/mlir/lib/Transforms/LoopParametricTiling.cpp b/mlir/lib/Transforms/LoopParametricTiling.cpp index c2b23943794..77626f54a3c 100644 --- a/mlir/lib/Transforms/LoopParametricTiling.cpp +++ b/mlir/lib/Transforms/LoopParametricTiling.cpp @@ -43,7 +43,7 @@ public: : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {} void runOnFunction() override { - Function func = getFunction(); + FuncOp func = getFunction(); func.walk([this](ForOp op) { // Ignore nested loops. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2744e5ca05c..0a331cae100 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function f, +static void getTileableBands(FuncOp f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6f13f623fe8..1c7f3393ada 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -92,7 +92,7 @@ void LoopUnroll::runOnFunction() { // Store innermost loops as we walk. std::vector loops; - void walkPostOrder(Function f) { + void walkPostOrder(FuncOp f) { for (auto &b : f) walkPostOrder(b.begin(), b.end()); } @@ -142,7 +142,7 @@ void LoopUnroll::runOnFunction() { ? 
clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function func = getFunction(); + FuncOp func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; ilg.walkPostOrder(func); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index df30e270fe6..2edf2a29386 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -726,7 +726,7 @@ public: } // end namespace -LogicalResult mlir::lowerAffineConstructs(Function function) { +LogicalResult mlir::lowerAffineConstructs(FuncOp function) { OwningRewritePatternList patterns; RewriteListBuilder &terminators, +static bool materialize(FuncOp f, const SetVector &terminators, MaterializationState *state) { DenseSet seen; DominanceInfo domInfo(f); @@ -731,7 +731,7 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function f = getFunction(); + FuncOp f = getFunction(); if (f.getBlocks().size() != 1) return; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 13a53e3a944..93f7331f7a3 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -213,7 +213,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function f = getFunction(); + FuncOp f = getFunction(); if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index c7c3621781a..c82354ed49e 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function func = getFunction(); + FuncOp func = getFunction(); auto unknownLoc = UnknownLoc::get(&getContext()); // Strip the debug info from the function and its operations. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index e185f702d27..c65370233da 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,7 +44,7 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function fn, + explicit GreedyPatternRewriteDriver(FuncOp fn, OwningRewritePatternList &&patterns) : PatternRewriter(fn.getBody()), matcher(std::move(patterns)) { worklist.reserve(64); @@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. 
/// -bool mlir::applyPatternsGreedily(Function fn, +bool mlir::applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); bool converged = driver.simplifyFunction(maxPatternMatchIterations); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 5a0fb1f49fb..1f823391c3a 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -153,11 +153,11 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { return success(); } -/// Promotes all single iteration for op's in the Function, i.e., moves +/// Promotes all single iteration for op's in the FuncOp, i.e., moves /// their body into the containing Block. -void mlir::promoteSingleIterationLoops(Function *f) { +void mlir::promoteSingleIterationLoops(FuncOp f) { // Gathers all innermost loops through a post order pruned walk. - f->walk( + f.walk( [](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 4aff2ac4d13..43a6a2f7a82 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1229,7 +1229,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. void Vectorize::runOnFunction() { - Function f = getFunction(); + FuncOp f = getFunction(); if (!fastestVaryingPattern.empty() && fastestVaryingPattern.size() != vectorSizes.size()) { f.emitRemark("Fastest varying pattern specified with different size than " diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index b58170a3557..e9894b6a2f1 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -43,11 +43,11 @@ static MLIRContext &globalContext() { return context; } -static Function makeFunction(StringRef name, ArrayRef results = {}, - ArrayRef args = {}) { +static FuncOp makeFunction(StringRef name, ArrayRef results = {}, + ArrayRef args = {}) { auto &ctx = globalContext(); - auto function = Function::create(UnknownLoc::get(&ctx), name, - FunctionType::get(args, results, &ctx)); + auto function = FuncOp::create(UnknownLoc::get(&ctx), name, + FunctionType::get(args, results, &ctx)); function.addEntryBlock(); return function; } @@ -556,7 +556,7 @@ TEST_FUNC(vectorize_2d) { auto owningF = makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType}); - mlir::Function f = owningF; + mlir::FuncOp f = owningF; mlir::OwningModuleRef module = Module::create(&globalContext()); module->push_back(f); diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 7bfb5564064..4fd77c88ae0 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -256,7 +256,7 @@ void VectorizerTestPass::runOnFunction() { NestedPatternContext mlContext; // Only support single block functions at this point. 
- Function f = getFunction(); + FuncOp f = getFunction(); if (f.getBlocks().size() != 1) return; diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp index 86e673b1362..9194b4e0d63 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp @@ -164,7 +164,7 @@ static LogicalResult convertAffineStandardToLLVMIR(Module module) { static Error compileAndExecuteFunctionWithMemRefs( Module module, StringRef entryPoint, std::function transformer) { - Function mainFunction = module.getNamedFunction(entryPoint); + FuncOp mainFunction = module.getNamedFunction(entryPoint); if (!mainFunction || mainFunction.getBlocks().empty()) { return make_string_error("entry point not found"); } @@ -207,7 +207,7 @@ static Error compileAndExecuteFunctionWithMemRefs( static Error compileAndExecuteSingleFloatReturnFunction( Module module, StringRef entryPoint, std::function transformer) { - Function mainFunction = module.getNamedFunction(entryPoint); + FuncOp mainFunction = module.getNamedFunction(entryPoint); if (!mainFunction || mainFunction.isExternal()) { return make_string_error("entry point not found"); } diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index 62e7cfc307e..32238ba34a2 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -45,7 +45,7 @@ extern int run(int argc, char **argv, llvm::function_ref); inline void emit_cuda_error(const llvm::Twine &message, const char *buffer, - CUresult error, Function &function) { + CUresult error, FuncOp &function) { function.emitError(message.concat(" failed with error code ") .concat(llvm::Twine{error}) .concat("[") @@ -62,7 +62,7 @@ inline void emit_cuda_error(const llvm::Twine &message, const char *buffer, } \ } -OwnedCubin compilePtxToCubin(const std::string ptx, Function &function) { +OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) { char jitErrorBuffer[4096] = {0}; RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit"); diff --git a/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp b/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp index ec92b0e3e9b..3e6893a23d3 100644 --- a/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp +++ b/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp @@ -39,7 +39,7 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper, emitSourceFileHeader("Reference implementation file", os); const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); - os << "void printRefImplementation(StringRef opName, mlir::Function *f) {\n" + os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n" << " using namespace ::mlir::edsc;\n" << "if (false) {}"; for (auto *def : defs) { diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index 0464498b361..b6de02e7e80 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -25,11 +25,11 @@ using namespace mlir::detail; namespace { /// Minimal class definitions for two analyses. struct MyAnalysis { - MyAnalysis(Function) {} + MyAnalysis(FuncOp) {} MyAnalysis(Module) {} }; struct OtherAnalysis { - OtherAnalysis(Function) {} + OtherAnalysis(FuncOp) {} OtherAnalysis(Module) {} }; @@ -59,9 +59,9 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { // Create a function and a module. 
OwningModuleRef module(Module::create(&context)); - Function func1 = - Function::create(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); + FuncOp func1 = + FuncOp::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); module->push_back(func1); // Test fine grain invalidation of the function analysis manager. @@ -87,9 +87,9 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { // Create a function and a module. OwningModuleRef module(Module::create(&context)); - Function func1 = - Function::create(builder.getUnknownLoc(), "foo", - builder.getFunctionType(llvm::None, llvm::None)); + FuncOp func1 = + FuncOp::create(builder.getUnknownLoc(), "foo", + builder.getFunctionType(llvm::None, llvm::None)); module->push_back(func1); // Test fine grain invalidation of a function analysis from within a module
-- cgit v1.2.3

From 926fb685deadfed2042163145ac52311914bf5c2 Mon Sep 17 00:00:00 2001
From: Mehdi Amini
Date: Mon, 12 Aug 2019 19:12:42 -0700
Subject: Express ownership transfer in PassManager API through std::unique_ptr (NFC)

Since raw pointers are always passed around for IR constructs without implying
any ownership transfer, it can be error-prone to have ownership transferred
implicitly in the same way. For example, this code can seem harmless, yet it
hands ownership of the same pass to the pass manager twice:

  Pass *pass = ....
  pm.addPass(pass);
  pm.addPass(pass);
  pm.run(module);

PiperOrigin-RevId: 263053082
---
 .../Linalg/Linalg3/include/linalg3/Transforms.h | 2 +-
 mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 4 +-
 mlir/examples/toy/Ch4/include/toy/Passes.h | 4 +-
 mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 4 +-
 mlir/examples/toy/Ch4/toyc.cpp | 1 +
 mlir/examples/toy/Ch5/include/toy/Lowering.h | 4 +-
 mlir/examples/toy/Ch5/include/toy/Passes.h | 4 +-
 mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 4 +-
 mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 4 +-
 mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 4 +-
 mlir/examples/toy/Ch5/toyc.cpp | 1 +
 .../mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h | 6 +--
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 4 +-
 .../mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h | 6 ++-
 .../StandardToLLVM/ConvertStandardToLLVMPass.h | 6 +--
 mlir/include/mlir/Dialect/GPU/Passes.h | 4 +-
 mlir/include/mlir/Dialect/QuantOps/Passes.h | 6 ++-
 mlir/include/mlir/Dialect/SPIRV/Passes.h | 2 +-
 mlir/include/mlir/Linalg/Passes.h | 12 +++---
 mlir/include/mlir/Pass/Pass.h | 6 +--
 mlir/include/mlir/Pass/PassManager.h | 6 +--
 mlir/include/mlir/Pass/PassRegistry.h | 7 +++-
 mlir/include/mlir/Quantizer/Transforms/Passes.h | 6 +--
 mlir/include/mlir/Transforms/Passes.h | 48 +++++++++++-----------
 .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 4 +-
 .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 5 ++-
 .../GPUToCUDA/GenerateCubinAccessors.cpp | 4 +-
 .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 4 +-
 mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 11 ++---
 .../StandardToLLVM/ConvertStandardToLLVM.cpp | 9 ++--
 .../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 5 ++-
 .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 4 +-
 .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 4 +-
 .../QuantOps/Transforms/ConvertSimQuant.cpp | 5 ++-
 mlir/lib/Linalg/Transforms/Fusion.cpp | 6 +--
 mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 4 +-
 mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 4 +-
 mlir/lib/Linalg/Transforms/Tiling.cpp | 6 +--
 mlir/lib/Pass/Pass.cpp | 26 ++++++------
 mlir/lib/Pass/PassDetail.h | 8 +++-
 .../Transforms/AddDefaultStatsTestPass.cpp | 4 +-
.../Transforms/InferQuantizedTypesPass.cpp | 4 +- .../Transforms/RemoveInstrumentationPass.cpp | 5 ++- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 8 ++-- mlir/lib/Transforms/CSE.cpp | 4 +- mlir/lib/Transforms/Canonicalizer.cpp | 4 +- mlir/lib/Transforms/LoopCoalescing.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 9 ++-- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 4 +- mlir/lib/Transforms/LoopTiling.cpp | 5 ++- mlir/lib/Transforms/LoopUnroll.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 5 ++- mlir/lib/Transforms/LowerAffine.cpp | 4 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 4 +- mlir/lib/Transforms/MaterializeVectors.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 4 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 4 +- mlir/lib/Transforms/StripDebugInfo.cpp | 4 +- mlir/lib/Transforms/Vectorize.cpp | 4 +- mlir/test/lib/TestDialect/TestPatterns.cpp | 9 ++-- mlir/test/lib/Transforms/TestConstantFold.cpp | 4 +- mlir/test/lib/Transforms/TestLoopFusion.cpp | 4 +- mlir/test/lib/Transforms/TestLoopMapping.cpp | 2 +- .../lib/Transforms/TestLoopParametricTiling.cpp | 7 ++-- .../test/lib/Transforms/TestVectorizationUtils.cpp | 4 +- 66 files changed, 217 insertions(+), 169 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 4346b47cf49..123d6afba08 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -73,7 +73,7 @@ void lowerToLoops(mlir::FuncOp f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. 
-mlir::FunctionPassBase *createLowerLinalgLoadStorePass(); +std::unique_ptr createLowerLinalgLoadStorePass(); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 7fc4bb5c897..79fa4ca34f2 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -300,6 +300,6 @@ Rewriter::matchAndRewrite(linalg::StoreOp store, } } // namespace -FunctionPassBase *linalg::createLowerLinalgLoadStorePass() { - return new LowerLinalgLoadStorePass(); +std::unique_ptr linalg::createLowerLinalgLoadStorePass() { + return llvm::make_unique(); } diff --git a/mlir/examples/toy/Ch4/include/toy/Passes.h b/mlir/examples/toy/Ch4/include/toy/Passes.h index dd73b95f9c2..93cf0d5ba15 100644 --- a/mlir/examples/toy/Ch4/include/toy/Passes.h +++ b/mlir/examples/toy/Ch4/include/toy/Passes.h @@ -22,12 +22,14 @@ #ifndef MLIR_TUTORIAL_TOY_PASSES_H #define MLIR_TUTORIAL_TOY_PASSES_H +#include + namespace mlir { class Pass; } // namespace mlir namespace toy { -mlir::Pass *createShapeInferencePass(); +std::unique_ptr createShapeInferencePass(); } // namespace toy #endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 5c258f1ef5b..4a6bf8790e0 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -375,5 +375,7 @@ public: } // end anonymous namespace namespace toy { -mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); } +std::unique_ptr createShapeInferencePass() { + return llvm::make_unique(); +} } // namespace toy diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp index a273c9301ce..9e7a8a39e0a 100644 --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/examples/toy/Ch5/include/toy/Lowering.h b/mlir/examples/toy/Ch5/include/toy/Lowering.h index 362a3428346..4788ea3fbeb 100644 --- a/mlir/examples/toy/Ch5/include/toy/Lowering.h +++ b/mlir/examples/toy/Ch5/include/toy/Lowering.h @@ -35,10 +35,10 @@ class DialectConversion; namespace toy { /// Create a pass for lowering operations in the `Linalg` dialects, for a subset /// of the Toy IR (matmul). -mlir::Pass *createEarlyLoweringPass(); +std::unique_ptr createEarlyLoweringPass(); /// Create a pass for the late lowering toward LLVM dialect. 
-mlir::Pass *createLateLoweringPass(); +std::unique_ptr createLateLoweringPass(); } // namespace toy diff --git a/mlir/examples/toy/Ch5/include/toy/Passes.h b/mlir/examples/toy/Ch5/include/toy/Passes.h index dd73b95f9c2..93cf0d5ba15 100644 --- a/mlir/examples/toy/Ch5/include/toy/Passes.h +++ b/mlir/examples/toy/Ch5/include/toy/Passes.h @@ -22,12 +22,14 @@ #ifndef MLIR_TUTORIAL_TOY_PASSES_H #define MLIR_TUTORIAL_TOY_PASSES_H +#include + namespace mlir { class Pass; } // namespace mlir namespace toy { -mlir::Pass *createShapeInferencePass(); +std::unique_ptr createShapeInferencePass(); } // namespace toy #endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 015a3fd64c2..96230fdfbea 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -142,5 +142,7 @@ struct EarlyLoweringPass : public FunctionPass { } // end anonymous namespace namespace toy { -Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); } +std::unique_ptr createEarlyLoweringPass() { + return llvm::make_unique(); +} } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 3b6bfc9df5d..6135e275a75 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -458,5 +458,7 @@ struct LateLoweringPass : public ModulePass { } // end anonymous namespace namespace toy { -Pass *createLateLoweringPass() { return new LateLoweringPass(); } +std::unique_ptr createLateLoweringPass() { + return llvm::make_unique(); +} } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index cef2939788c..6437c0b3f73 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -375,5 +375,7 @@ public: } // end anonymous namespace namespace toy { -mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); } +std::unique_ptr createShapeInferencePass() { + return llvm::make_unique(); +} } // namespace toy diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 1d80c3c018d..a21eda74d82 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index b19fb53e3e2..bd1a3fea0ff 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -39,7 +39,7 @@ using CubinGenerator = std::function; /// attached as a string attribute named 'nvvm.cubin' to the kernel function. /// After the transformation, the body of the kernel function is removed (i.e., /// it is turned into a declaration). 
-ModulePassBase * +std::unique_ptr createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// Creates a pass to convert a gpu.launch_func operation into a sequence of @@ -48,11 +48,11 @@ createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// This pass does not generate code to call CUDA directly but instead uses a /// small wrapper library that exports a stable and conveniently typed ABI /// ontop of CUDA. -ModulePassBase *createConvertGpuLaunchFuncToCudaCallsPass(); +std::unique_ptr createConvertGpuLaunchFuncToCudaCallsPass(); /// Creates a pass to augment a module with getter functions for all contained /// cubins as encoded via the 'nvvm.cubin' attribute. -ModulePassBase *createGenerateCubinAccessorPass(); +std::unique_ptr createGenerateCubinAccessorPass(); } // namespace mlir #endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index b53549fb275..f1c8601795c 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -17,11 +17,13 @@ #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ +#include + namespace mlir { struct FunctionPassBase; /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. -FunctionPassBase *createLowerGpuOpsToNVVMOpsPass(); +std::unique_ptr createLowerGpuOpsToNVVMOpsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 52f0dd4babb..3d32c36c43c 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -17,6 +17,8 @@ #ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ #define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ +#include + namespace mlir { class FunctionPassBase; @@ -28,8 +30,8 @@ class FunctionPassBase; /// parallelization is performed, it is under the responsibility of the caller /// to strip-mine the loops and to perform the dependence analysis before /// calling the conversion. -FunctionPassBase *createSimpleLoopsToGPUPass(unsigned numBlockDims, - unsigned numThreadDims); +std::unique_ptr +createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims); } // namespace mlir #endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 941e382905f..a08b2fb45d6 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -57,12 +57,12 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. -ModulePassBase *createConvertToLLVMIRPass(); +std::unique_ptr createConvertToLLVMIRPass(); /// Creates a pass to convert operations to the LLVMIR dialect. The conversion /// is defined by a list of patterns and a type converter that will be obtained /// during the pass using the provided callbacks. 
-ModulePassBase * +std::unique_ptr createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker); @@ -71,7 +71,7 @@ createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, /// callback and an optional type conversion class, an instance is created /// during the pass. template -ModulePassBase * +std::unique_ptr createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller) { return createConvertToLLVMIRPass(patternListFiller, [](MLIRContext *context) { return llvm::make_unique(context); diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h index f9b569d50af..d562b5835c7 100644 --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -22,11 +22,13 @@ #ifndef MLIR_DIALECT_GPU_PASSES_H_ #define MLIR_DIALECT_GPU_PASSES_H_ +#include + namespace mlir { class ModulePassBase; -ModulePassBase *createGpuKernelOutliningPass(); +std::unique_ptr createGpuKernelOutliningPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/QuantOps/Passes.h b/mlir/include/mlir/Dialect/QuantOps/Passes.h index 6b647a87f4a..1d43f7087db 100644 --- a/mlir/include/mlir/Dialect/QuantOps/Passes.h +++ b/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -25,6 +25,8 @@ #ifndef MLIR_DIALECT_QUANTOPS_PASSES_H #define MLIR_DIALECT_QUANTOPS_PASSES_H +#include + namespace mlir { class FunctionPassBase; @@ -32,14 +34,14 @@ namespace quant { /// Creates a pass that converts quantization simulation operations (i.e. /// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. -FunctionPassBase *createConvertSimulatedQuantPass(); +std::unique_ptr createConvertSimulatedQuantPass(); /// Creates a pass that converts constants followed by a qbarrier to a /// constant whose value is quantized. This is typically one of the last /// passes done when lowering to express actual quantized arithmetic in a /// low level representation. Because it modifies the constant, it is /// destructive and cannot be undone. 
-FunctionPassBase *createConvertConstPass(); +std::unique_ptr createConvertConstPass(); } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h index e896da7ae8a..85f4f79ed59 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -27,7 +27,7 @@ namespace mlir { namespace spirv { -ModulePassBase *createConvertStandardToSPIRVPass(); +std::unique_ptr createConvertStandardToSPIRVPass(); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Linalg/Passes.h b/mlir/include/mlir/Linalg/Passes.h index 02941492059..57dd09cfc63 100644 --- a/mlir/include/mlir/Linalg/Passes.h +++ b/mlir/include/mlir/Linalg/Passes.h @@ -30,14 +30,16 @@ class FunctionPassBase; class ModulePassBase; namespace linalg { -FunctionPassBase *createLinalgFusionPass(ArrayRef tileSizes = {}); +std::unique_ptr +createLinalgFusionPass(ArrayRef tileSizes = {}); -FunctionPassBase *createLinalgTilingPass(ArrayRef tileSizes = {}, - bool promoteViews = false); +std::unique_ptr +createLinalgTilingPass(ArrayRef tileSizes = {}, + bool promoteViews = false); -FunctionPassBase *createLowerLinalgToLoopsPass(); +std::unique_ptr createLowerLinalgToLoopsPass(); -ModulePassBase *createLowerLinalgToLLVMPass(); +std::unique_ptr createLowerLinalgToLLVMPass(); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index b1531a357e5..f5c8d8bd1a6 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -104,7 +104,7 @@ protected: virtual void runOnFunction() = 0; /// A clone method to create a copy of this pass. - virtual FunctionPassBase *clone() const = 0; + virtual std::unique_ptr clone() const = 0; /// Return the current function being transformed. FuncOp getFunction() { return getPassState().irAndPassFailed.getPointer(); } @@ -259,8 +259,8 @@ struct FunctionPass : public detail::PassModel { } /// A clone method to create a copy of this pass. - FunctionPassBase *clone() const override { - return new T(*static_cast(this)); + std::unique_ptr clone() const override { + return llvm::make_unique(*static_cast(this)); } }; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 68dfeb099bc..b01445eae4c 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -71,16 +71,16 @@ public: /// Add an opaque pass pointer to the current manager. This takes ownership /// over the provided pass pointer. - void addPass(Pass *pass); + void addPass(std::unique_ptr pass); /// Add a module pass to the current manager. This takes ownership over the /// provided pass pointer. - void addPass(ModulePassBase *pass); + void addPass(std::unique_ptr pass); /// Add a function pass to the current manager. This takes ownership over the /// provided pass pointer. This will automatically create a function pass /// executor if necessary. 
- void addPass(FunctionPassBase *pass); + void addPass(std::unique_ptr pass); //===--------------------------------------------------------------------===// // Instrumentations diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index ea0fbbe39db..bd108f3e77f 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -29,6 +29,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include +#include namespace mlir { class Pass; @@ -37,7 +38,7 @@ class PassManager; /// A registry function that adds passes to the given pass manager. using PassRegistryFunction = std::function; -using PassAllocatorFunction = std::function; +using PassAllocatorFunction = std::function()>; /// A special type used by transformation passes to provide an address that can /// act as a unique identifier during pass registration. @@ -120,7 +121,9 @@ template struct PassRegistration { } PassRegistration(StringRef arg, StringRef description) { - PassAllocatorFunction constructor = [] { return new ConcretePass(); }; + PassAllocatorFunction constructor = [] { + return llvm::make_unique(); + }; registerPass(arg, description, PassID::getID(), constructor); } }; diff --git a/mlir/include/mlir/Quantizer/Transforms/Passes.h b/mlir/include/mlir/Quantizer/Transforms/Passes.h index 0d7b4cb55b3..f894ea801e0 100644 --- a/mlir/include/mlir/Quantizer/Transforms/Passes.h +++ b/mlir/include/mlir/Quantizer/Transforms/Passes.h @@ -33,17 +33,17 @@ class TargetConfiguration; /// Creates a pass that infers quantized types based on metadata discovered /// in the computation. -ModulePassBase * +std::unique_ptr createInferQuantizedTypesPass(SolverContext &solverContext, const TargetConfiguration &config); /// Creates a pass which removes any instrumentation and hint ops which have /// no effect on final runtime. -FunctionPassBase *createRemoveInstrumentationPass(); +std::unique_ptr createRemoveInstrumentationPass(); /// Adds default (dummy) statistics to ops that can benefit from runtime stats. /// Meant for testing. -FunctionPassBase *createAddDefaultStatsPass(); +std::unique_ptr createAddDefaultStatsPass(); } // namespace quantizer } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index ee36517cea7..693c7b0ae00 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -37,25 +37,25 @@ class ModulePassBase; /// top-down constant folding functionality; it is intended to be used for /// testing purpose. Use Canonicalizer pass, which exploits more simplification /// opportunties exposed by constant folding, for the general cases. -FunctionPassBase *createTestConstantFoldPass(); +std::unique_ptr createTestConstantFoldPass(); /// Creates an instance of the Canonicalizer pass. -FunctionPassBase *createCanonicalizerPass(); +std::unique_ptr createCanonicalizerPass(); /// Creates a pass to perform common sub expression elimination. -FunctionPassBase *createCSEPass(); +std::unique_ptr createCSEPass(); /// Creates a pass to vectorize loops, operations and data types using a /// target-independent, n-D super-vector abstraction. -FunctionPassBase * +std::unique_ptr createVectorizePass(llvm::ArrayRef virtualVectorSize); /// Creates a pass to allow independent testing of vectorizer functionality with /// FileCheck. 
-FunctionPassBase *createVectorizerTestPass(); +std::unique_ptr createVectorizerTestPass(); /// Creates a pass to lower super-vectors to target-dependent HW vectors. -FunctionPassBase * +std::unique_ptr createMaterializeVectorsPass(llvm::ArrayRef vectorSize); /// Creates a loop unrolling pass with the provided parameters. @@ -64,71 +64,73 @@ createMaterializeVectorsPass(llvm::ArrayRef vectorSize); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -FunctionPassBase *createLoopUnrollPass( +std::unique_ptr createLoopUnrollPass( int unrollFactor = -1, int unrollFull = -1, const std::function &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. -FunctionPassBase *createLoopUnrollAndJamPass(int unrollJamFactor = -1); +std::unique_ptr +createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates an simplification pass for affine structures. -FunctionPassBase *createSimplifyAffineStructuresPass(); +std::unique_ptr createSimplifyAffineStructuresPass(); /// Creates a loop fusion pass which fuses loops. Buffers of size less than or /// equal to `localBufSizeThreshold` are promoted to memory space /// `fastMemorySpace'. -FunctionPassBase *createLoopFusionPass(unsigned fastMemorySpace = 0, - uint64_t localBufSizeThreshold = 0, - bool maximalFusion = false); +std::unique_ptr +createLoopFusionPass(unsigned fastMemorySpace = 0, + uint64_t localBufSizeThreshold = 0, + bool maximalFusion = false); /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. -FunctionPassBase *createLoopInvariantCodeMotionPass(); +std::unique_ptr createLoopInvariantCodeMotionPass(); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -FunctionPassBase *createPipelineDataTransferPass(); +std::unique_ptr createPipelineDataTransferPass(); /// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). -FunctionPassBase *createLowerAffinePass(); +std::unique_ptr createLowerAffinePass(); /// Creates a pass to perform tiling on loop nests. -FunctionPassBase *createLoopTilingPass(uint64_t cacheSizeBytes); +std::unique_ptr createLoopTilingPass(uint64_t cacheSizeBytes); /// Creates a pass that performs parametric tiling so that the outermost loops /// have the given fixed number of iterations. Assumes outermost loop nests /// are permutable. -FunctionPassBase * +std::unique_ptr createSimpleParametricTilingPass(ArrayRef outerLoopSizes); /// Creates a pass that transforms perfectly nested loops with independent /// bounds into a single loop. -FunctionPassBase *createLoopCoalescingPass(); +std::unique_ptr createLoopCoalescingPass(); /// Performs packing (or explicit copying) of accessed memref regions into /// buffers in the specified faster memory space through either pointwise copies /// or DMA operations. 
-FunctionPassBase *createAffineDataCopyGenerationPass( +std::unique_ptr createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. -FunctionPassBase *createLowerVectorTransfersPass(); +std::unique_ptr createLowerVectorTransfersPass(); /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -FunctionPassBase *createMemRefDataFlowOptPass(); +std::unique_ptr createMemRefDataFlowOptPass(); /// Creates a pass to strip debug information from a function. -FunctionPassBase *createStripDebugInfoPass(); +std::unique_ptr createStripDebugInfoPass(); /// Creates a pass which tests loop fusion utilities. -FunctionPassBase *createTestLoopFusionPass(); +std::unique_ptr createTestLoopFusionPass(); } // end namespace mlir diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 766377528a1..0223dee9ede 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -163,9 +163,9 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) { return success(); } -ModulePassBase * +std::unique_ptr mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) { - return new GpuKernelToCubinPass(cubinGenerator); + return llvm::make_unique(cubinGenerator); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index bf7577856db..bf0816c8b71 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -382,8 +382,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( launchOp.erase(); } -mlir::ModulePassBase *mlir::createConvertGpuLaunchFuncToCudaCallsPass() { - return new GpuLaunchFuncToCudaCallsPass(); +std::unique_ptr +mlir::createConvertGpuLaunchFuncToCudaCallsPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 813a3bee0ad..fa481632e29 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -141,8 +141,8 @@ private: } // anonymous namespace -ModulePassBase *createGenerateCubinAccessorPass() { - return new GpuGenerateCubinAccessorsPass(); +std::unique_ptr createGenerateCubinAccessorPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index e4a6f964f50..91671489f2d 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -128,8 +128,8 @@ public: } // anonymous namespace -FunctionPassBase *createLowerGpuOpsToNVVMOpsPass() { - return new LowerGpuOpsToNVVMOpsPass(); +std::unique_ptr createLowerGpuOpsToNVVMOpsPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp 
b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 7c785b5c995..36869b87f1a 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -66,13 +66,14 @@ struct ForLoopMapper : public FunctionPass { }; } // namespace -FunctionPassBase *mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, - unsigned numThreadDims) { - return new ForLoopMapper(numBlockDims, numThreadDims); +std::unique_ptr +mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, + unsigned numThreadDims) { + return llvm::make_unique(numBlockDims, numThreadDims); } static PassRegistration registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] { - return new ForLoopMapper(clNumBlockDims.getValue(), - clNumThreadDims.getValue()); + return llvm::make_unique(clNumBlockDims.getValue(), + clNumThreadDims.getValue()); }); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index c62a5d8719d..731c07e22c3 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1132,14 +1132,15 @@ struct LLVMLoweringPass : public ModulePass { }; } // end namespace -ModulePassBase *mlir::createConvertToLLVMIRPass() { - return new LLVMLoweringPass; +std::unique_ptr mlir::createConvertToLLVMIRPass() { + return llvm::make_unique(); } -ModulePassBase * +std::unique_ptr mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker) { - return new LLVMLoweringPass(patternListFiller, typeConverterMaker); + return llvm::make_unique(patternListFiller, + typeConverterMaker); } static PassRegistration diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index ad2c4b57fb4..3d4ef639cfa 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -48,8 +48,9 @@ void ConvertStandardToSPIRVPass::runOnModule() { } } -ModulePassBase *mlir::spirv::createConvertStandardToSPIRVPass() { - return new ConvertStandardToSPIRVPass(); +std::unique_ptr +mlir::spirv::createConvertStandardToSPIRVPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 01decce28ac..b7be427be1b 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -109,8 +109,8 @@ public: } // namespace -ModulePassBase *mlir::createGpuKernelOutliningPass() { - return new GpuKernelOutliningPass(); +std::unique_ptr mlir::createGpuKernelOutliningPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 120d0cf0e56..9c48c672300 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -112,8 +112,8 @@ void ConvertConstPass::runOnFunction() { applyPatternsGreedily(func, patterns); } -FunctionPassBase *mlir::quant::createConvertConstPass() { - return new ConvertConstPass(); +std::unique_ptr mlir::quant::createConvertConstPass() { + return llvm::make_unique(); } static PassRegistration diff --git 
a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index dfdce8964ba..924e6390d88 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -103,8 +103,9 @@ void ConvertSimulatedQuantPass::runOnFunction() { signalPassFailure(); } -FunctionPassBase *mlir::quant::createConvertSimulatedQuantPass() { - return new ConvertSimulatedQuantPass(); +std::unique_ptr +mlir::quant::createConvertSimulatedQuantPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 4864f394c88..992c4664b10 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -350,14 +350,14 @@ LinalgFusionPass::LinalgFusionPass(ArrayRef sizes) this->tileSizes.assign(sizes.begin(), sizes.end()); } -FunctionPassBase * +std::unique_ptr mlir::linalg::createLinalgFusionPass(ArrayRef tileSizes) { - return new LinalgFusionPass(tileSizes); + return llvm::make_unique(tileSizes); } static PassRegistration pass("linalg-fusion", "Fuse operations in the linalg dialect", [] { - auto *pass = new LinalgFusionPass(); + auto pass = llvm::make_unique(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); return pass; }); diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 84452a2ec2c..49af61e33eb 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -741,8 +741,8 @@ void LowerLinalgToLLVMPass::runOnModule() { } } -ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() { - return new LowerLinalgToLLVMPass(); +std::unique_ptr mlir::linalg::createLowerLinalgToLLVMPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index afeb5c43f91..24e56b11063 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -390,8 +390,8 @@ void LowerLinalgToLoopsPass::runOnFunction() { } } -FunctionPassBase *mlir::linalg::createLowerLinalgToLoopsPass() { - return new LowerLinalgToLoopsPass(); +std::unique_ptr mlir::linalg::createLowerLinalgToLoopsPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 8090a587d42..48c0da8f88f 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -527,15 +527,15 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef sizes, bool promoteViews) { this->promoteViews = promoteViews; } -FunctionPassBase * +std::unique_ptr mlir::linalg::createLinalgTilingPass(ArrayRef tileSizes, bool promoteViews) { - return new LinalgTilingPass(tileSizes, promoteViews); + return llvm::make_unique(tileSizes, promoteViews); } static PassRegistration pass("linalg-tile", "Tile operations in the linalg dialect", [] { - auto *pass = new LinalgTilingPass(); + auto pass = llvm::make_unique(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); pass->promoteViews = clPromoteFullTileViews; return pass; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 3ed7b248042..35d96634cf1 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -264,44 +264,44 @@ void PassManager::disableMultithreading(bool 
disable) { /// Add an opaque pass pointer to the current manager. This takes ownership /// over the provided pass pointer. -void PassManager::addPass(Pass *pass) { +void PassManager::addPass(std::unique_ptr pass) { switch (pass->getKind()) { case Pass::Kind::FunctionPass: - addPass(cast(pass)); + addPass(cast(std::move(pass))); break; case Pass::Kind::ModulePass: - addPass(cast(pass)); + addPass(cast(std::move(pass))); break; } } /// Add a module pass to the current manager. This takes ownership over the /// provided pass pointer. -void PassManager::addPass(ModulePassBase *pass) { +void PassManager::addPass(std::unique_ptr pass) { nestedExecutorStack.clear(); - mpe->addPass(pass); + mpe->addPass(std::move(pass)); // Add a verifier run if requested. if (verifyPasses) - mpe->addPass(new ModuleVerifierPass()); + mpe->addPass(llvm::make_unique()); } /// Add a function pass to the current manager. This takes ownership over the /// provided pass pointer. This will automatically create a function pass /// executor if necessary. -void PassManager::addPass(FunctionPassBase *pass) { +void PassManager::addPass(std::unique_ptr pass) { detail::FunctionPassExecutor *fpe; if (nestedExecutorStack.empty()) { /// Create an executor adaptor for this pass. if (disableThreads || !llvm::llvm_is_multithreaded()) { // If multi-threading is disabled, then create a synchronous adaptor. - auto *adaptor = new ModuleToFunctionPassAdaptor(); - addPass(adaptor); + auto adaptor = llvm::make_unique(); fpe = &adaptor->getFunctionExecutor(); + addPass(std::unique_ptr{adaptor.release()}); } else { - auto *adaptor = new ModuleToFunctionPassAdaptorParallel(); - addPass(adaptor); + auto adaptor = llvm::make_unique(); fpe = &adaptor->getFunctionExecutor(); + addPass(std::unique_ptr{adaptor.release()}); } /// Add the executor to the stack. @@ -309,11 +309,11 @@ void PassManager::addPass(FunctionPassBase *pass) { } else { fpe = cast(nestedExecutorStack.back()); } - fpe->addPass(pass); + fpe->addPass(std::move(pass)); // Add a verifier run if requested. if (verifyPasses) - fpe->addPass(new FunctionVerifierPass()); + fpe->addPass(llvm::make_unique()); } /// Add the provided instrumentation to the pass manager. This takes ownership diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 0b41c44ef14..bb482a2bc65 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -66,7 +66,9 @@ public: /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. - void addPass(FunctionPassBase *pass) { passes.emplace_back(pass); } + void addPass(std::unique_ptr pass) { + passes.push_back(std::move(pass)); + } /// Returns the number of passes held by this executor. size_t size() const { return passes.size(); } @@ -94,7 +96,9 @@ public: /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. 
- void addPass(ModulePassBase *pass) { passes.emplace_back(pass); } + void addPass(std::unique_ptr pass) { + passes.push_back(std::move(pass)); + } static bool classof(const PassExecutor *pe) { return pe->getKind() == Kind::ModuleExecutor; diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 3f26bf075af..4868d3be291 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -118,8 +118,8 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, }); } -FunctionPassBase *mlir::quantizer::createAddDefaultStatsPass() { - return new AddDefaultStatsPass(); +std::unique_ptr mlir::quantizer::createAddDefaultStatsPass() { + return llvm::make_unique(); } static PassRegistration pass( diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index 765a36e791a..e1365e769b3 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -286,9 +286,9 @@ void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, } } -ModulePassBase *mlir::quantizer::createInferQuantizedTypesPass( +std::unique_ptr mlir::quantizer::createInferQuantizedTypesPass( SolverContext &solverContext, const TargetConfiguration &config) { - return new InferQuantizedTypesPass(solverContext, config); + return llvm::make_unique(solverContext, config); } static PassRegistration diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index d5fb28463d6..104a3b60404 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -66,8 +66,9 @@ void RemoveInstrumentationPass::runOnFunction() { applyPatternsGreedily(func, patterns); } -FunctionPassBase *mlir::quantizer::createRemoveInstrumentationPass() { - return new RemoveInstrumentationPass(); +std::unique_ptr +mlir::quantizer::createRemoveInstrumentationPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 522ed4a4c09..e422bd24425 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -162,12 +162,12 @@ struct AffineDataCopyGeneration /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. -FunctionPassBase *mlir::createAffineDataCopyGenerationPass( +std::unique_ptr mlir::createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, int minDmaTransferSize, uint64_t fastMemCapacityBytes) { - return new AffineDataCopyGeneration(slowMemorySpace, fastMemorySpace, - tagMemorySpace, minDmaTransferSize, - fastMemCapacityBytes); + return llvm::make_unique( + slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize, + fastMemCapacityBytes); } // Info comprising stride and number of elements transferred every stride. 
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index eeb63e7f9eb..59658526c25 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -258,7 +258,9 @@ void CSE::runOnFunction() { markAnalysesPreserved(); } -FunctionPassBase *mlir::createCSEPass() { return new CSE(); } +std::unique_ptr mlir::createCSEPass() { + return llvm::make_unique(); +} static PassRegistration pass("cse", "Eliminate common sub-expressions in functions"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 80d8ea92b03..6f4a40f86f3 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -53,8 +53,8 @@ void Canonicalizer::runOnFunction() { } /// Create a Canonicalizer pass. -FunctionPassBase *mlir::createCanonicalizerPass() { - return new Canonicalizer(); +std::unique_ptr mlir::createCanonicalizerPass() { + return llvm::make_unique(); } static PassRegistration pass("canonicalize", diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index f47433c52c0..eb52e8d5802 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -96,8 +96,8 @@ public: } // namespace -FunctionPassBase *mlir::createLoopCoalescingPass() { - return new LoopCoalescingPass; +std::unique_ptr mlir::createLoopCoalescingPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index ea1a03f09a3..2736ebc0f55 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -111,10 +111,11 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold, - bool maximalFusion) { - return new LoopFusion(fastMemorySpace, localBufSizeThreshold, maximalFusion); +std::unique_ptr +mlir::createLoopFusionPass(unsigned fastMemorySpace, + uint64_t localBufSizeThreshold, bool maximalFusion) { + return llvm::make_unique(fastMemorySpace, localBufSizeThreshold, + maximalFusion); } namespace { diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index d8b5b2d8b2c..09fe9afe808 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -76,8 +76,8 @@ static bool isMemRefDereferencingOp(Operation &op) { return false; } -FunctionPassBase *mlir::createLoopInvariantCodeMotionPass() { - return new LoopInvariantCodeMotion(); +std::unique_ptr mlir::createLoopInvariantCodeMotionPass() { + return llvm::make_unique(); } // Returns true if the individual op is loop invariant. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 0a331cae100..d6ff9a94234 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -81,8 +81,9 @@ struct LoopTiling : public FunctionPass { /// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. 
-FunctionPassBase *mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { - return new LoopTiling(cacheSizeBytes); +std::unique_ptr +mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { + return llvm::make_unique(cacheSizeBytes); } // Move the loop body of AffineForOp 'src' from 'src' into the specified diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 1c7f3393ada..c3db90e4b3a 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -180,10 +180,10 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } -FunctionPassBase *mlir::createLoopUnrollPass( +std::unique_ptr mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function &getUnrollFactor) { - return new LoopUnroll( + return llvm::make_unique( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7650db1ce27..362aa8683cc 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -82,8 +82,9 @@ struct LoopUnrollAndJam : public FunctionPass { }; } // end anonymous namespace -FunctionPassBase *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { - return new LoopUnrollAndJam( +std::unique_ptr +mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { + return llvm::make_unique( unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 062134dea9c..f24bc6d88da 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -529,8 +529,8 @@ class LowerAffinePass : public FunctionPass { /// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. 
-FunctionPassBase *mlir::createLowerAffinePass() { - return new LowerAffinePass(); +std::unique_ptr mlir::createLowerAffinePass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index e2d5920f1dd..e941850b5b1 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -373,8 +373,8 @@ struct LowerVectorTransfersPass } // end anonymous namespace -FunctionPassBase *mlir::createLowerVectorTransfersPass() { - return new LowerVectorTransfersPass(); +std::unique_ptr mlir::createLowerVectorTransfersPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 17acc92f49a..24b1f77c939 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -766,9 +766,9 @@ void MaterializeVectorsPass::runOnFunction() { signalPassFailure(); } -FunctionPassBase * +std::unique_ptr mlir::createMaterializeVectorsPass(llvm::ArrayRef vectorSize) { - return new MaterializeVectorsPass(vectorSize); + return llvm::make_unique(vectorSize); } static PassRegistration diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 4f8b1c61cbf..b16dff93ee3 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -88,8 +88,8 @@ struct MemRefDataFlowOpt : public FunctionPass { /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -FunctionPassBase *mlir::createMemRefDataFlowOptPass() { - return new MemRefDataFlowOpt(); +std::unique_ptr mlir::createMemRefDataFlowOptPass() { + return llvm::make_unique(); } // This is a straightforward implementation not optimized for speed. Optimize diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index af456c31408..d4d91c9b0e2 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -49,8 +49,8 @@ struct PipelineDataTransfer : public FunctionPass { /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -FunctionPassBase *mlir::createPipelineDataTransferPass() { - return new PipelineDataTransfer(); +std::unique_ptr mlir::createPipelineDataTransferPass() { + return llvm::make_unique(); } // Returns the position of the tag memref operand given a DMA operation. 
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 3b6c231d054..3cc9309a5d5 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -88,8 +88,8 @@ struct SimplifyAffineStructures } // end anonymous namespace -FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { - return new SimplifyAffineStructures(); +std::unique_ptr mlir::createSimplifyAffineStructuresPass() { + return llvm::make_unique(); } void SimplifyAffineStructures::runOnFunction() { diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index c82354ed49e..21d8ef15219 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -38,8 +38,8 @@ void StripDebugInfo::runOnFunction() { } /// Creates a pass to strip debug information from a function. -FunctionPassBase *mlir::createStripDebugInfoPass() { - return new StripDebugInfo(); +std::unique_ptr mlir::createStripDebugInfoPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ce254065332..932f00bfcbe 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1276,9 +1276,9 @@ void Vectorize::runOnFunction() { LLVM_DEBUG(dbgs() << "\n"); } -FunctionPassBase * +std::unique_ptr mlir::createVectorizePass(llvm::ArrayRef virtualVectorSize) { - return new Vectorize(virtualVectorSize); + return llvm::make_unique(virtualVectorSize); } static PassRegistration diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 584ff996fca..9b7fe8e94bf 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -247,6 +247,9 @@ static llvm::cl::opt clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, "partial", "Perform a partial conversion"))); -static mlir::PassRegistration legalizer_pass( - "test-legalize-patterns", "Run test dialect legalization patterns", - [] { return new TestLegalizePatternDriver(legalizerConversionMode); }); +static mlir::PassRegistration + legalizer_pass("test-legalize-patterns", + "Run test dialect legalization patterns", [] { + return llvm::make_unique( + legalizerConversionMode); + }); diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index 7d17f60c719..02c66ef86ac 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -74,8 +74,8 @@ void TestConstantFold::runOnFunction() { } /// Creates a constant folding pass. -FunctionPassBase *mlir::createTestConstantFoldPass() { - return new TestConstantFold(); +std::unique_ptr mlir::createTestConstantFoldPass() { + return llvm::make_unique(); } static PassRegistration diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 39990968a34..bcb050769a1 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -58,8 +58,8 @@ struct TestLoopFusion : public FunctionPass { } // end anonymous namespace -FunctionPassBase *mlir::createTestLoopFusionPass() { - return new TestLoopFusion; +std::unique_ptr mlir::createTestLoopFusionPass() { + return llvm::make_unique(); } // Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'. 
diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index bf354670f92..a9da70a6d5e 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -62,4 +62,4 @@ public: static PassRegistration reg("test-mapping-to-processing-elements", "test mapping a single loop on a virtual processor grid", - [] { return new TestLoopMappingPass(); }); + [] { return llvm::make_unique(); }); diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index d30eacc044d..e01ff66d825 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -55,9 +55,9 @@ public: }; } // end namespace -FunctionPassBase * +std::unique_ptr mlir::createSimpleParametricTilingPass(ArrayRef outerLoopSizes) { - return new SimpleParametricLoopTilingPass(outerLoopSizes); + return llvm::make_unique(outerLoopSizes); } static PassRegistration @@ -65,7 +65,8 @@ static PassRegistration "test application of parametric tiling to the outer loops so that the " "ranges of outer loops become static", [] { - auto *pass = new SimpleParametricLoopTilingPass({}); + auto pass = llvm::make_unique( + ArrayRef{}); pass->sizes.assign(clOuterLoopSizes.begin(), clOuterLoopSizes.end()); return pass; }); diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index b51de412306..3bfe6b6fce3 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -290,8 +290,8 @@ void VectorizerTestPass::runOnFunction() { } } -FunctionPassBase *mlir::createVectorizerTestPass() { - return new VectorizerTestPass(); +std::unique_ptr mlir::createVectorizerTestPass() { + return llvm::make_unique(); } static PassRegistration -- cgit v1.2.3 From 79f53b0cf1fd204af0a09c8e085dd09a1ce0b6d9 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 17 Aug 2019 11:05:35 -0700 Subject: Change from llvm::make_unique to std::make_unique Switch to C++14 standard method as llvm::make_unique has been removed ( https://reviews.llvm.org/D66259). Also mark some targets as c++14 to ease next integrates. 
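For illustration, a pass factory after this change looks roughly like the sketch
below; MyPass and createMyPass are placeholder names rather than code from this
commit:

  #include <memory>   // provides std::make_unique (C++14)

  std::unique_ptr<FunctionPassBase> createMyPass() {
    // Previously: return llvm::make_unique<MyPass>();
    return std::make_unique<MyPass>();
  }

std::make_unique requires compiling as C++14 or later, which is why some build
targets are marked accordingly.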
PiperOrigin-RevId: 263953918 --- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 2 +- mlir/examples/toy/Ch1/include/toy/Parser.h | 40 ++++++++++------------ mlir/examples/toy/Ch2/include/toy/Parser.h | 40 ++++++++++------------ mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 4 +-- mlir/examples/toy/Ch3/include/toy/Parser.h | 40 ++++++++++------------ mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 4 +-- mlir/examples/toy/Ch4/include/toy/Parser.h | 40 ++++++++++------------ mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 4 +-- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch5/include/toy/Parser.h | 40 ++++++++++------------ mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 2 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 2 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 4 +-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 2 +- mlir/g3doc/Tutorials/Toy/Ch-4.md | 2 +- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 2 +- mlir/include/mlir/IR/Dialect.h | 2 +- mlir/include/mlir/IR/PatternMatch.h | 2 +- mlir/include/mlir/Pass/AnalysisManager.h | 2 +- mlir/include/mlir/Pass/Pass.h | 2 +- mlir/include/mlir/Pass/PassRegistry.h | 2 +- .../Quantizer/Support/ConstraintAnalysisGraph.h | 6 ++-- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 4 +-- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 2 +- .../GPUToCUDA/GenerateCubinAccessors.cpp | 2 +- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 6 ++-- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 8 ++--- .../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 2 +- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 2 +- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 +- .../QuantOps/Transforms/ConvertSimQuant.cpp | 2 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 12 +++---- mlir/lib/IR/Diagnostics.cpp | 2 +- mlir/lib/Linalg/Transforms/Fusion.cpp | 4 +-- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 2 +- mlir/lib/Linalg/Transforms/Tiling.cpp | 4 +-- mlir/lib/Pass/Pass.cpp | 8 ++--- .../lib/Quantizer/Configurations/FxpMathConfig.cpp | 2 +- .../Quantizer/Support/ConstraintAnalysisGraph.cpp | 4 +-- .../Transforms/AddDefaultStatsTestPass.cpp | 2 +- .../Transforms/InferQuantizedTypesPass.cpp | 2 +- .../Transforms/RemoveInstrumentationPass.cpp | 2 +- mlir/lib/Support/FileUtilities.cpp | 4 +-- mlir/lib/TableGen/Pattern.cpp | 2 +- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 4 +-- mlir/lib/Transforms/CSE.cpp | 6 ++-- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- mlir/lib/Transforms/LoopCoalescing.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 4 +-- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 4 +-- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/lib/TestDialect/TestPatterns.cpp | 2 +- mlir/test/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/test/lib/Transforms/TestLoopFusion.cpp | 2 +- 
mlir/test/lib/Transforms/TestLoopMapping.cpp | 2 +- .../lib/Transforms/TestLoopParametricTiling.cpp | 4 +-- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 4 +-- 74 files changed, 195 insertions(+), 205 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 79fa4ca34f2..8731138bba5 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -301,5 +301,5 @@ Rewriter::matchAndRewrite(linalg::StoreOp store, } // namespace std::unique_ptr linalg::createLowerLinalgLoadStorePass() { - return llvm::make_unique(); + return std::make_unique(); } diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h index bc7aa520624..75c660b7c78 100644 --- a/mlir/examples/toy/Ch1/include/toy/Parser.h +++ b/mlir/examples/toy/Ch1/include/toy/Parser.h @@ -62,7 +62,7 @@ public: if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return llvm::make_unique(std::move(functions)); + return std::make_unique(std::move(functions)); } private: @@ -81,7 +81,7 @@ private: if (!expr) return nullptr; } - return llvm::make_unique(std::move(loc), std::move(expr)); + return std::make_unique(std::move(loc), std::move(expr)); } /// Parse a literal number. @@ -89,7 +89,7 @@ private: std::unique_ptr ParseNumberExpr() { auto loc = lexer.getLastLocation(); auto Result = - llvm::make_unique(std::move(loc), lexer.getValue()); + std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(Result); } @@ -157,8 +157,8 @@ private: "inside literal expession"); } } - return llvm::make_unique(std::move(loc), std::move(values), - std::move(dims)); + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); } /// parenexpr ::= '(' expression ')' @@ -184,7 +184,7 @@ private: lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. - return llvm::make_unique(std::move(loc), name); + return std::make_unique(std::move(loc), name); // This is a function call. lexer.consume(Token('(')); @@ -211,13 +211,11 @@ private: if (Args.size() != 1) return parseError("", "as argument to print()"); - return llvm::make_unique(std::move(loc), - std::move(Args[0])); + return std::make_unique(std::move(loc), std::move(Args[0])); } // Call to a user-defined function - return llvm::make_unique(std::move(loc), name, - std::move(Args)); + return std::make_unique(std::move(loc), name, std::move(Args)); } /// primary @@ -281,8 +279,8 @@ private: } // Merge LHS/RHS. 
- LHS = llvm::make_unique(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + LHS = std::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); } } @@ -302,7 +300,7 @@ private: return parseError("<", "to begin type"); lexer.getNextToken(); // eat < - auto type = llvm::make_unique(); + auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); @@ -341,11 +339,11 @@ private: } if (!type) - type = llvm::make_unique(); + type = std::make_unique(); lexer.consume(Token('=')); auto expr = ParseExpression(); - return llvm::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in @@ -359,7 +357,7 @@ private: return parseError("{", "to begin block"); lexer.consume(Token('{')); - auto exprList = llvm::make_unique(); + auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') @@ -422,7 +420,7 @@ private: std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = llvm::make_unique(std::move(loc), name); + auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; @@ -437,8 +435,8 @@ private: // success. lexer.consume(Token(')')); - return llvm::make_unique(std::move(loc), FnName, - std::move(args)); + return std::make_unique(std::move(loc), FnName, + std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the @@ -451,7 +449,7 @@ private: return nullptr; if (auto block = ParseBlock()) - return llvm::make_unique(std::move(Proto), std::move(block)); + return std::make_unique(std::move(Proto), std::move(block)); return nullptr; } diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h index bc7aa520624..75c660b7c78 100644 --- a/mlir/examples/toy/Ch2/include/toy/Parser.h +++ b/mlir/examples/toy/Ch2/include/toy/Parser.h @@ -62,7 +62,7 @@ public: if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return llvm::make_unique(std::move(functions)); + return std::make_unique(std::move(functions)); } private: @@ -81,7 +81,7 @@ private: if (!expr) return nullptr; } - return llvm::make_unique(std::move(loc), std::move(expr)); + return std::make_unique(std::move(loc), std::move(expr)); } /// Parse a literal number. @@ -89,7 +89,7 @@ private: std::unique_ptr ParseNumberExpr() { auto loc = lexer.getLastLocation(); auto Result = - llvm::make_unique(std::move(loc), lexer.getValue()); + std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(Result); } @@ -157,8 +157,8 @@ private: "inside literal expession"); } } - return llvm::make_unique(std::move(loc), std::move(values), - std::move(dims)); + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); } /// parenexpr ::= '(' expression ')' @@ -184,7 +184,7 @@ private: lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. - return llvm::make_unique(std::move(loc), name); + return std::make_unique(std::move(loc), name); // This is a function call. 
lexer.consume(Token('(')); @@ -211,13 +211,11 @@ private: if (Args.size() != 1) return parseError("", "as argument to print()"); - return llvm::make_unique(std::move(loc), - std::move(Args[0])); + return std::make_unique(std::move(loc), std::move(Args[0])); } // Call to a user-defined function - return llvm::make_unique(std::move(loc), name, - std::move(Args)); + return std::make_unique(std::move(loc), name, std::move(Args)); } /// primary @@ -281,8 +279,8 @@ private: } // Merge LHS/RHS. - LHS = llvm::make_unique(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + LHS = std::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); } } @@ -302,7 +300,7 @@ private: return parseError("<", "to begin type"); lexer.getNextToken(); // eat < - auto type = llvm::make_unique(); + auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); @@ -341,11 +339,11 @@ private: } if (!type) - type = llvm::make_unique(); + type = std::make_unique(); lexer.consume(Token('=')); auto expr = ParseExpression(); - return llvm::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in @@ -359,7 +357,7 @@ private: return parseError("{", "to begin block"); lexer.consume(Token('{')); - auto exprList = llvm::make_unique(); + auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') @@ -422,7 +420,7 @@ private: std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = llvm::make_unique(std::move(loc), name); + auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; @@ -437,8 +435,8 @@ private: // success. lexer.consume(Token(')')); - return llvm::make_unique(std::move(loc), FnName, - std::move(args)); + return std::make_unique(std::move(loc), FnName, + std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the @@ -451,7 +449,7 @@ private: return nullptr; if (auto block = ParseBlock()) - return llvm::make_unique(std::move(Proto), std::move(block)); + return std::make_unique(std::move(Proto), std::move(block)); return nullptr; } diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 7b874b92cc4..c09c4ad679c 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -43,11 +43,11 @@ using namespace toy; using llvm::cast; using llvm::dyn_cast; using llvm::isa; -using llvm::make_unique; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; +using std::make_unique; namespace { @@ -172,7 +172,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.getBody()); + builder = std::make_unique(function.getBody()); // Emit the body of the function. 
if (!mlirGen(*funcAST.getBody())) { diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h index bc7aa520624..75c660b7c78 100644 --- a/mlir/examples/toy/Ch3/include/toy/Parser.h +++ b/mlir/examples/toy/Ch3/include/toy/Parser.h @@ -62,7 +62,7 @@ public: if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return llvm::make_unique(std::move(functions)); + return std::make_unique(std::move(functions)); } private: @@ -81,7 +81,7 @@ private: if (!expr) return nullptr; } - return llvm::make_unique(std::move(loc), std::move(expr)); + return std::make_unique(std::move(loc), std::move(expr)); } /// Parse a literal number. @@ -89,7 +89,7 @@ private: std::unique_ptr ParseNumberExpr() { auto loc = lexer.getLastLocation(); auto Result = - llvm::make_unique(std::move(loc), lexer.getValue()); + std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(Result); } @@ -157,8 +157,8 @@ private: "inside literal expession"); } } - return llvm::make_unique(std::move(loc), std::move(values), - std::move(dims)); + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); } /// parenexpr ::= '(' expression ')' @@ -184,7 +184,7 @@ private: lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. - return llvm::make_unique(std::move(loc), name); + return std::make_unique(std::move(loc), name); // This is a function call. lexer.consume(Token('(')); @@ -211,13 +211,11 @@ private: if (Args.size() != 1) return parseError("", "as argument to print()"); - return llvm::make_unique(std::move(loc), - std::move(Args[0])); + return std::make_unique(std::move(loc), std::move(Args[0])); } // Call to a user-defined function - return llvm::make_unique(std::move(loc), name, - std::move(Args)); + return std::make_unique(std::move(loc), name, std::move(Args)); } /// primary @@ -281,8 +279,8 @@ private: } // Merge LHS/RHS. - LHS = llvm::make_unique(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + LHS = std::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); } } @@ -302,7 +300,7 @@ private: return parseError("<", "to begin type"); lexer.getNextToken(); // eat < - auto type = llvm::make_unique(); + auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); @@ -341,11 +339,11 @@ private: } if (!type) - type = llvm::make_unique(); + type = std::make_unique(); lexer.consume(Token('=')); auto expr = ParseExpression(); - return llvm::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in @@ -359,7 +357,7 @@ private: return parseError("{", "to begin block"); lexer.consume(Token('{')); - auto exprList = llvm::make_unique(); + auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') @@ -422,7 +420,7 @@ private: std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = llvm::make_unique(std::move(loc), name); + auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; @@ -437,8 +435,8 @@ private: // success. 
lexer.consume(Token(')')); - return llvm::make_unique(std::move(loc), FnName, - std::move(args)); + return std::make_unique(std::move(loc), FnName, + std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the @@ -451,7 +449,7 @@ private: return nullptr; if (auto block = ParseBlock()) - return llvm::make_unique(std::move(Proto), std::move(block)); + return std::make_unique(std::move(Proto), std::move(block)); return nullptr; } diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index e3b06a7f7df..b3ba2f9281c 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -44,11 +44,11 @@ using namespace toy; using llvm::cast; using llvm::dyn_cast; using llvm::isa; -using llvm::make_unique; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; +using std::make_unique; namespace { @@ -173,7 +173,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.getBody()); + builder = std::make_unique(function.getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) { diff --git a/mlir/examples/toy/Ch4/include/toy/Parser.h b/mlir/examples/toy/Ch4/include/toy/Parser.h index bc7aa520624..75c660b7c78 100644 --- a/mlir/examples/toy/Ch4/include/toy/Parser.h +++ b/mlir/examples/toy/Ch4/include/toy/Parser.h @@ -62,7 +62,7 @@ public: if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return llvm::make_unique(std::move(functions)); + return std::make_unique(std::move(functions)); } private: @@ -81,7 +81,7 @@ private: if (!expr) return nullptr; } - return llvm::make_unique(std::move(loc), std::move(expr)); + return std::make_unique(std::move(loc), std::move(expr)); } /// Parse a literal number. @@ -89,7 +89,7 @@ private: std::unique_ptr ParseNumberExpr() { auto loc = lexer.getLastLocation(); auto Result = - llvm::make_unique(std::move(loc), lexer.getValue()); + std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(Result); } @@ -157,8 +157,8 @@ private: "inside literal expession"); } } - return llvm::make_unique(std::move(loc), std::move(values), - std::move(dims)); + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); } /// parenexpr ::= '(' expression ')' @@ -184,7 +184,7 @@ private: lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. - return llvm::make_unique(std::move(loc), name); + return std::make_unique(std::move(loc), name); // This is a function call. lexer.consume(Token('(')); @@ -211,13 +211,11 @@ private: if (Args.size() != 1) return parseError("", "as argument to print()"); - return llvm::make_unique(std::move(loc), - std::move(Args[0])); + return std::make_unique(std::move(loc), std::move(Args[0])); } // Call to a user-defined function - return llvm::make_unique(std::move(loc), name, - std::move(Args)); + return std::make_unique(std::move(loc), name, std::move(Args)); } /// primary @@ -281,8 +279,8 @@ private: } // Merge LHS/RHS. 
- LHS = llvm::make_unique(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + LHS = std::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); } } @@ -302,7 +300,7 @@ private: return parseError("<", "to begin type"); lexer.getNextToken(); // eat < - auto type = llvm::make_unique(); + auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); @@ -341,11 +339,11 @@ private: } if (!type) - type = llvm::make_unique(); + type = std::make_unique(); lexer.consume(Token('=')); auto expr = ParseExpression(); - return llvm::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in @@ -359,7 +357,7 @@ private: return parseError("{", "to begin block"); lexer.consume(Token('{')); - auto exprList = llvm::make_unique(); + auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. while (lexer.getCurToken() == ';') @@ -422,7 +420,7 @@ private: std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = llvm::make_unique(std::move(loc), name); + auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; @@ -437,8 +435,8 @@ private: // success. lexer.consume(Token(')')); - return llvm::make_unique(std::move(loc), FnName, - std::move(args)); + return std::make_unique(std::move(loc), FnName, + std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the @@ -451,7 +449,7 @@ private: return nullptr; if (auto block = ParseBlock()) - return llvm::make_unique(std::move(Proto), std::move(block)); + return std::make_unique(std::move(Proto), std::move(block)); return nullptr; } diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index e61d1aaa99d..fd385a47004 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -44,11 +44,11 @@ using namespace toy; using llvm::cast; using llvm::dyn_cast; using llvm::isa; -using llvm::make_unique; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; +using std::make_unique; namespace { @@ -173,7 +173,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.getBody()); + builder = std::make_unique(function.getBody()); // Emit the body of the function. 
if (!mlirGen(*funcAST.getBody())) { diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 4a6bf8790e0..793f153291e 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -376,6 +376,6 @@ public: namespace toy { std::unique_ptr createShapeInferencePass() { - return llvm::make_unique(); + return std::make_unique(); } } // namespace toy diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h index bc7aa520624..75c660b7c78 100644 --- a/mlir/examples/toy/Ch5/include/toy/Parser.h +++ b/mlir/examples/toy/Ch5/include/toy/Parser.h @@ -62,7 +62,7 @@ public: if (lexer.getCurToken() != tok_eof) return parseError("nothing", "at end of module"); - return llvm::make_unique(std::move(functions)); + return std::make_unique(std::move(functions)); } private: @@ -81,7 +81,7 @@ private: if (!expr) return nullptr; } - return llvm::make_unique(std::move(loc), std::move(expr)); + return std::make_unique(std::move(loc), std::move(expr)); } /// Parse a literal number. @@ -89,7 +89,7 @@ private: std::unique_ptr ParseNumberExpr() { auto loc = lexer.getLastLocation(); auto Result = - llvm::make_unique(std::move(loc), lexer.getValue()); + std::make_unique(std::move(loc), lexer.getValue()); lexer.consume(tok_number); return std::move(Result); } @@ -157,8 +157,8 @@ private: "inside literal expession"); } } - return llvm::make_unique(std::move(loc), std::move(values), - std::move(dims)); + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); } /// parenexpr ::= '(' expression ')' @@ -184,7 +184,7 @@ private: lexer.getNextToken(); // eat identifier. if (lexer.getCurToken() != '(') // Simple variable ref. - return llvm::make_unique(std::move(loc), name); + return std::make_unique(std::move(loc), name); // This is a function call. lexer.consume(Token('(')); @@ -211,13 +211,11 @@ private: if (Args.size() != 1) return parseError("", "as argument to print()"); - return llvm::make_unique(std::move(loc), - std::move(Args[0])); + return std::make_unique(std::move(loc), std::move(Args[0])); } // Call to a user-defined function - return llvm::make_unique(std::move(loc), name, - std::move(Args)); + return std::make_unique(std::move(loc), name, std::move(Args)); } /// primary @@ -281,8 +279,8 @@ private: } // Merge LHS/RHS. - LHS = llvm::make_unique(std::move(loc), BinOp, - std::move(LHS), std::move(RHS)); + LHS = std::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); } } @@ -302,7 +300,7 @@ private: return parseError("<", "to begin type"); lexer.getNextToken(); // eat < - auto type = llvm::make_unique(); + auto type = std::make_unique(); while (lexer.getCurToken() == tok_number) { type->shape.push_back(lexer.getValue()); @@ -341,11 +339,11 @@ private: } if (!type) - type = llvm::make_unique(); + type = std::make_unique(); lexer.consume(Token('=')); auto expr = ParseExpression(); - return llvm::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); } /// Parse a block: a list of expression separated by semicolons and wrapped in @@ -359,7 +357,7 @@ private: return parseError("{", "to begin block"); lexer.consume(Token('{')); - auto exprList = llvm::make_unique(); + auto exprList = std::make_unique(); // Ignore empty expressions: swallow sequences of semicolons. 
while (lexer.getCurToken() == ';') @@ -422,7 +420,7 @@ private: std::string name = lexer.getId(); auto loc = lexer.getLastLocation(); lexer.consume(tok_identifier); - auto decl = llvm::make_unique(std::move(loc), name); + auto decl = std::make_unique(std::move(loc), name); args.push_back(std::move(decl)); if (lexer.getCurToken() != ',') break; @@ -437,8 +435,8 @@ private: // success. lexer.consume(Token(')')); - return llvm::make_unique(std::move(loc), FnName, - std::move(args)); + return std::make_unique(std::move(loc), FnName, + std::move(args)); } /// Parse a function definition, we expect a prototype initiated with the @@ -451,7 +449,7 @@ private: return nullptr; if (auto block = ParseBlock()) - return llvm::make_unique(std::move(Proto), std::move(block)); + return std::make_unique(std::move(Proto), std::move(block)); return nullptr; } diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 96230fdfbea..c55a0dbd949 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -143,6 +143,6 @@ struct EarlyLoweringPass : public FunctionPass { namespace toy { std::unique_ptr createEarlyLoweringPass() { - return llvm::make_unique(); + return std::make_unique(); } } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 29d83aeb663..8146e763303 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -455,6 +455,6 @@ struct LateLoweringPass : public ModulePass { namespace toy { std::unique_ptr createLateLoweringPass() { - return llvm::make_unique(); + return std::make_unique(); } } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 8d7d169c7d2..88fb95048da 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -44,11 +44,11 @@ using namespace toy; using llvm::cast; using llvm::dyn_cast; using llvm::isa; -using llvm::make_unique; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; +using std::make_unique; namespace { @@ -173,7 +173,7 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique(function.getBody()); + builder = std::make_unique(function.getBody()); // Emit the body of the function. if (!mlirGen(*funcAST.getBody())) { diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 6437c0b3f73..b6808d713eb 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -376,6 +376,6 @@ public: namespace toy { std::unique_ptr createShapeInferencePass() { - return llvm::make_unique(); + return std::make_unique(); } } // namespace toy diff --git a/mlir/g3doc/Tutorials/Toy/Ch-4.md b/mlir/g3doc/Tutorials/Toy/Ch-4.md index 343d8f9e00b..1551e129379 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-4.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-4.md @@ -112,7 +112,7 @@ method: /// supports, for use by the canonicalization pass. 
static void getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.push_back(std::make_unique(context)); } ``` diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index a08b2fb45d6..d2f416b35fe 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -74,7 +74,7 @@ template std::unique_ptr createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller) { return createConvertToLLVMIRPass(patternListFiller, [](MLIRContext *context) { - return llvm::make_unique(context); + return std::make_unique(context); }); } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 683701f3bc4..7ed647b61f9 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -262,7 +262,7 @@ protected: addInterfaces(); } template void addInterfaces() { - addInterface(llvm::make_unique(this)); + addInterface(std::make_unique(this)); } private: diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index d47b924d888..5e4fe60a7bd 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -422,7 +422,7 @@ public: // FIXME: In c++17 this can be simplified by using 'fold expressions'. using dummy = int[]; (void)dummy{ - 0, (patterns.emplace_back(llvm::make_unique(arg, args...)), 0)...}; + 0, (patterns.emplace_back(std::make_unique(arg, args...)), 0)...}; } private: diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index 1f44515ceb1..ae98831f2b1 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -123,7 +123,7 @@ public: if (pi) pi->runBeforeAnalysis(getAnalysisName(), id, ir); - it->second = llvm::make_unique>(ir); + it->second = std::make_unique>(ir); if (pi) pi->runAfterAnalysis(getAnalysisName(), id, ir); diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index f5c8d8bd1a6..3a3444af532 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -260,7 +260,7 @@ struct FunctionPass : public detail::PassModel { /// A clone method to create a copy of this pass. std::unique_ptr clone() const override { - return llvm::make_unique(*static_cast(this)); + return std::make_unique(*static_cast(this)); } }; diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index bd108f3e77f..eea3778d8b1 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -122,7 +122,7 @@ template struct PassRegistration { PassRegistration(StringRef arg, StringRef description) { PassAllocatorFunction constructor = [] { - return llvm::make_unique(); + return std::make_unique(); }; registerPass(arg, description, PassID::getID(), constructor); } diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h index 8f2a0e52b30..63f62dbeeeb 100644 --- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h @@ -279,7 +279,7 @@ public: Args... 
args) { static_assert(std::is_convertible(), "T must be a CAGConstraingNode"); - T *constraintNode = addNode(llvm::make_unique(args...)); + T *constraintNode = addNode(std::make_unique(args...)); for (auto *anchor : anchors) anchor->addOutgoing(constraintNode); return constraintNode; @@ -292,7 +292,7 @@ public: Args... args) { static_assert(std::is_convertible(), "T must be a CAGConstraingNode"); - T *constraintNode = addNode(llvm::make_unique(args...)); + T *constraintNode = addNode(std::make_unique(args...)); fromAnchor->addOutgoing(constraintNode); for (auto *toAnchor : toAnchors) { constraintNode->addOutgoing(toAnchor); @@ -312,7 +312,7 @@ public: T *constraintNode; if (cluster.empty()) { // Create new. - constraintNode = addNode(llvm::make_unique()); + constraintNode = addNode(std::make_unique()); } else { // Merge existing. constraintNode = cluster[0]; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 46e45351d54..b2b2c6970b9 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -303,7 +303,7 @@ FlatAffineConstraints::FlatAffineConstraints( // Clones this object. std::unique_ptr FlatAffineConstraints::clone() const { - return llvm::make_unique(*this); + return std::make_unique(*this); } // Construct from an IntegerSet. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index e384a56a71d..ead8d7e070c 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -45,7 +45,7 @@ void DominanceInfoBase::recalculate(Operation *op) { // Don't compute dominance if the region is empty. if (region.empty()) continue; - auto opDominance = llvm::make_unique(); + auto opDominance = std::make_unique(); opDominance->recalculate(region); dominanceInfos.try_emplace(®ion, std::move(opDominance)); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index fc36cc58f8e..85e39e37f65 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -913,7 +913,7 @@ static Optional getMemoryFootprintBytes(Block &block, } // Compute the memref region symbolic in any IVs enclosing this block. 
- auto region = llvm::make_unique(opInst->getLoc()); + auto region = std::make_unique(opInst->getLoc()); if (failed( region->compute(opInst, /*loopDepth=*/getNestingDepth(*block.begin())))) { diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 0223dee9ede..29771fe7ea5 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -106,7 +106,7 @@ OwnedCubin GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx, FuncOp &function) { const char data[] = "CUBIN"; - return llvm::make_unique>(data, data + sizeof(data) - 1); + return std::make_unique>(data, data + sizeof(data) - 1); } OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, @@ -165,7 +165,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) { std::unique_ptr mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) { - return llvm::make_unique(cubinGenerator); + return std::make_unique(cubinGenerator); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index bf0816c8b71..b3864a39560 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -384,7 +384,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( std::unique_ptr mlir::createConvertGpuLaunchFuncToCudaCallsPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 332a1324865..b819de2471e 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -120,7 +120,7 @@ private: } // anonymous namespace std::unique_ptr createGenerateCubinAccessorPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 91671489f2d..32b0caf180a 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -129,7 +129,7 @@ public: } // anonymous namespace std::unique_ptr createLowerGpuOpsToNVVMOpsPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 36869b87f1a..4b241e497c6 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -69,11 +69,11 @@ struct ForLoopMapper : public FunctionPass { std::unique_ptr mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims) { - return llvm::make_unique(numBlockDims, numThreadDims); + return std::make_unique(numBlockDims, numThreadDims); } static PassRegistration registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] { - return llvm::make_unique(clNumBlockDims.getValue(), - clNumThreadDims.getValue()); + return std::make_unique(clNumBlockDims.getValue(), + clNumThreadDims.getValue()); }); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp 
b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 731c07e22c3..9ba06db7aba 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1082,7 +1082,7 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { /// Create an instance of LLVMTypeConverter in the given context. static std::unique_ptr makeStandardToLLVMTypeConverter(MLIRContext *context) { - return llvm::make_unique(context); + return std::make_unique(context); } namespace { @@ -1133,14 +1133,14 @@ struct LLVMLoweringPass : public ModulePass { } // end namespace std::unique_ptr mlir::createConvertToLLVMIRPass() { - return llvm::make_unique(); + return std::make_unique(); } std::unique_ptr mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker) { - return llvm::make_unique(patternListFiller, - typeConverterMaker); + return std::make_unique(patternListFiller, + typeConverterMaker); } static PassRegistration diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 3d4ef639cfa..174a4477560 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -50,7 +50,7 @@ void ConvertStandardToSPIRVPass::runOnModule() { std::unique_ptr mlir::spirv::createConvertStandardToSPIRVPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index b7be427be1b..ea64ea8058b 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -110,7 +110,7 @@ public: } // namespace std::unique_ptr mlir::createGpuKernelOutliningPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 9c48c672300..efb202b7491 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -113,7 +113,7 @@ void ConvertConstPass::runOnFunction() { } std::unique_ptr mlir::quant::createConvertConstPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 924e6390d88..129671979ca 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -105,7 +105,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { std::unique_ptr mlir::quant::createConvertSimulatedQuantPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 99bf43de8c1..4450bf4d403 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -132,13 +132,13 @@ public: : irTransformer(transform), objectLayer( session, - [this]() { return llvm::make_unique(session); }), + [this]() { return std::make_unique(session); }), compileLayer( session, objectLayer, 
llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))), transformLayer(session, compileLayer, makeIRTransformFunction()), dataLayout(layout), mangler(session, this->dataLayout), - threadSafeCtx(llvm::make_unique()) { + threadSafeCtx(std::make_unique()) { session.getMainJITDylib().addGenerator( cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( layout.getGlobalPrefix()))); @@ -156,9 +156,9 @@ public: if (!dataLayout) return dataLayout.takeError(); - return llvm::make_unique(std::move(*machineBuilder), - std::move(*dataLayout), transformer, - sharedLibPaths); + return std::make_unique(std::move(*machineBuilder), + std::move(*dataLayout), transformer, + sharedLibPaths); } // Add an LLVM module to the main library managed by the JIT engine. @@ -328,7 +328,7 @@ Expected> ExecutionEngine::create(ModuleOp m, std::function transformer, ArrayRef sharedLibPaths) { - auto engine = llvm::make_unique(); + auto engine = std::make_unique(); auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths); if (!expectedJIT) return expectedJIT.takeError(); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 28894066023..e9963ece379 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -160,7 +160,7 @@ Diagnostic &Diagnostic::attachNote(llvm::Optional noteLoc) { /// Append and return a new note. notes.push_back( - llvm::make_unique(*noteLoc, DiagnosticSeverity::Note)); + std::make_unique(*noteLoc, DiagnosticSeverity::Note)); return *notes.back(); } diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 992c4664b10..a2a63d5bedf 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -352,12 +352,12 @@ LinalgFusionPass::LinalgFusionPass(ArrayRef sizes) std::unique_ptr mlir::linalg::createLinalgFusionPass(ArrayRef tileSizes) { - return llvm::make_unique(tileSizes); + return std::make_unique(tileSizes); } static PassRegistration pass("linalg-fusion", "Fuse operations in the linalg dialect", [] { - auto pass = llvm::make_unique(); + auto pass = std::make_unique(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); return pass; }); diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 908191ccd66..de183f8f76e 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -735,7 +735,7 @@ void LowerLinalgToLLVMPass::runOnModule() { } std::unique_ptr mlir::linalg::createLowerLinalgToLLVMPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index 24e56b11063..faef51f5c8c 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -391,7 +391,7 @@ void LowerLinalgToLoopsPass::runOnFunction() { } std::unique_ptr mlir::linalg::createLowerLinalgToLoopsPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 48c0da8f88f..051278e12f4 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -530,12 +530,12 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef sizes, bool promoteViews) { std::unique_ptr mlir::linalg::createLinalgTilingPass(ArrayRef tileSizes, bool 
promoteViews) { - return llvm::make_unique(tileSizes, promoteViews); + return std::make_unique(tileSizes, promoteViews); } static PassRegistration pass("linalg-tile", "Tile operations in the linalg dialect", [] { - auto pass = llvm::make_unique(); + auto pass = std::make_unique(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); pass->promoteViews = clPromoteFullTileViews; return pass; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index ba3b4742cc7..13f2738b002 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -283,7 +283,7 @@ void PassManager::addPass(std::unique_ptr pass) { // Add a verifier run if requested. if (verifyPasses) - mpe->addPass(llvm::make_unique()); + mpe->addPass(std::make_unique()); } /// Add a function pass to the current manager. This takes ownership over the @@ -295,11 +295,11 @@ void PassManager::addPass(std::unique_ptr pass) { /// Create an executor adaptor for this pass. if (disableThreads || !llvm::llvm_is_multithreaded()) { // If multi-threading is disabled, then create a synchronous adaptor. - auto adaptor = llvm::make_unique(); + auto adaptor = std::make_unique(); fpe = &adaptor->getFunctionExecutor(); addPass(std::unique_ptr{adaptor.release()}); } else { - auto adaptor = llvm::make_unique(); + auto adaptor = std::make_unique(); fpe = &adaptor->getFunctionExecutor(); addPass(std::unique_ptr{adaptor.release()}); } @@ -313,7 +313,7 @@ void PassManager::addPass(std::unique_ptr pass) { // Add a verifier run if requested. if (verifyPasses) - fpe->addPass(llvm::make_unique()); + fpe->addPass(std::make_unique()); } /// Add the provided instrumentation to the pass manager. This takes ownership diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp index 6a0cff83ced..4119bde5ac1 100644 --- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -283,5 +283,5 @@ struct FxpMathTargetConfigImpl : public FxpMathTargetConfig { std::unique_ptr FxpMathTargetConfig::create(SolverContext &context) { - return llvm::make_unique(context); + return std::make_unique(context); } diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp index b4d48b78025..cfed2a2647c 100644 --- a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -68,7 +68,7 @@ CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op, } // Create. - auto anchor = llvm::make_unique(op, operandIdx); + auto anchor = std::make_unique(op, operandIdx); auto *unowned = anchor.release(); unowned->nodeId = allNodes.size(); allNodes.push_back(unowned); @@ -87,7 +87,7 @@ CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) { } // Create. 
- auto anchor = llvm::make_unique(op, resultIdx); + auto anchor = std::make_unique(op, resultIdx); auto *unowned = anchor.release(); unowned->nodeId = allNodes.size(); allNodes.push_back(unowned); diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 4868d3be291..a2d38ce211d 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -119,7 +119,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, } std::unique_ptr mlir::quantizer::createAddDefaultStatsPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration pass( diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index e1365e769b3..ff293fc93aa 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -288,7 +288,7 @@ void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, std::unique_ptr mlir::quantizer::createInferQuantizedTypesPass( SolverContext &solverContext, const TargetConfiguration &config) { - return llvm::make_unique(solverContext, config); + return std::make_unique(solverContext, config); } static PassRegistration diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index 104a3b60404..b9fbf27d24f 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -68,7 +68,7 @@ void RemoveInstrumentationPass::runOnFunction() { std::unique_ptr mlir::quantizer::createRemoveInstrumentationPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Support/FileUtilities.cpp b/mlir/lib/Support/FileUtilities.cpp index fb9f5cf86da..6f0dc93b235 100644 --- a/mlir/lib/Support/FileUtilities.cpp +++ b/mlir/lib/Support/FileUtilities.cpp @@ -43,8 +43,8 @@ mlir::openInputFile(StringRef inputFilename, std::string *errorMessage) { std::unique_ptr mlir::openOutputFile(StringRef outputFilename, std::string *errorMessage) { std::error_code error; - auto result = llvm::make_unique(outputFilename, error, - llvm::sys::fs::F_None); + auto result = std::make_unique(outputFilename, error, + llvm::sys::fs::F_None); if (error) { if (errorMessage) *errorMessage = "cannot open output file '" + outputFilename.str() + diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 344bcaa94b8..7fe3f6272d9 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -122,7 +122,7 @@ Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { auto it = mapper->find(opDef); if (it != mapper->end()) return *it->second; - return *mapper->try_emplace(opDef, llvm::make_unique(opDef)) + return *mapper->try_emplace(opDef, std::make_unique(opDef)) .first->second; } diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index e422bd24425..5030f722519 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -165,7 +165,7 @@ struct AffineDataCopyGeneration std::unique_ptr mlir::createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, int 
minDmaTransferSize, uint64_t fastMemCapacityBytes) { - return llvm::make_unique( + return std::make_unique( slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize, fastMemCapacityBytes); } @@ -743,7 +743,7 @@ uint64_t AffineDataCopyGeneration::runOnBlock(Block::iterator begin, } // Compute the MemRefRegion accessed. - auto region = llvm::make_unique(opInst->getLoc()); + auto region = std::make_unique(opInst->getLoc()); if (failed(region->compute(opInst, copyDepth))) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region: semi-affine maps?\n"); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 59658526c25..bb89aef7fef 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -213,7 +213,7 @@ void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo, std::deque> stack; // Process the nodes of the dom tree for this region. - stack.emplace_back(llvm::make_unique( + stack.emplace_back(std::make_unique( knownValues, domInfo.getRootNode(®ion))); while (!stack.empty()) { @@ -229,7 +229,7 @@ void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo, if (currentNode->childIterator != currentNode->node->end()) { auto *childNode = *(currentNode->childIterator++); stack.emplace_back( - llvm::make_unique(knownValues, childNode)); + std::make_unique(knownValues, childNode)); } else { // Finally, if the node and all of its children have been processed // then we delete the node. @@ -259,7 +259,7 @@ void CSE::runOnFunction() { } std::unique_ptr mlir::createCSEPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 6f4a40f86f3..db6c8ee26e6 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -54,7 +54,7 @@ void Canonicalizer::runOnFunction() { /// Create a Canonicalizer pass. 
std::unique_ptr mlir::createCanonicalizerPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration pass("canonicalize", diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index eb52e8d5802..2ce0fbd011b 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -97,7 +97,7 @@ public: } // namespace std::unique_ptr mlir::createLoopCoalescingPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2736ebc0f55..98d01b24be0 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -114,8 +114,8 @@ struct LoopFusion : public FunctionPass { std::unique_ptr mlir::createLoopFusionPass(unsigned fastMemorySpace, uint64_t localBufSizeThreshold, bool maximalFusion) { - return llvm::make_unique(fastMemorySpace, localBufSizeThreshold, - maximalFusion); + return std::make_unique(fastMemorySpace, localBufSizeThreshold, + maximalFusion); } namespace { diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 09fe9afe808..fddc890edcf 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -77,7 +77,7 @@ static bool isMemRefDereferencingOp(Operation &op) { } std::unique_ptr mlir::createLoopInvariantCodeMotionPass() { - return llvm::make_unique(); + return std::make_unique(); } // Returns true if the individual op is loop invariant. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index d6ff9a94234..c521a8f6f5d 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -83,7 +83,7 @@ struct LoopTiling : public FunctionPass { /// Function. std::unique_ptr mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { - return llvm::make_unique(cacheSizeBytes); + return std::make_unique(cacheSizeBytes); } // Move the loop body of AffineForOp 'src' from 'src' into the specified diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index c3db90e4b3a..fbe1dcc09f9 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -183,7 +183,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { std::unique_ptr mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function &getUnrollFactor) { - return llvm::make_unique( + return std::make_unique( unrollFactor == -1 ? None : Optional(unrollFactor), unrollFull == -1 ? None : Optional(unrollFull), getUnrollFactor); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 362aa8683cc..ef92861adf9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -84,7 +84,7 @@ struct LoopUnrollAndJam : public FunctionPass { std::unique_ptr mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { - return llvm::make_unique( + return std::make_unique( unrollJamFactor == -1 ? 
None : Optional(unrollJamFactor)); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f24bc6d88da..1879ff63af2 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -530,7 +530,7 @@ class LowerAffinePass : public FunctionPass { /// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. std::unique_ptr mlir::createLowerAffinePass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index e941850b5b1..8cb50e805f8 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -374,7 +374,7 @@ struct LowerVectorTransfersPass } // end anonymous namespace std::unique_ptr mlir::createLowerVectorTransfersPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 24b1f77c939..811c6fc7ad5 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -768,7 +768,7 @@ void MaterializeVectorsPass::runOnFunction() { std::unique_ptr mlir::createMaterializeVectorsPass(llvm::ArrayRef vectorSize) { - return llvm::make_unique(vectorSize); + return std::make_unique(vectorSize); } static PassRegistration diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index b16dff93ee3..59a4fbe93ab 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -89,7 +89,7 @@ struct MemRefDataFlowOpt : public FunctionPass { /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. std::unique_ptr mlir::createMemRefDataFlowOptPass() { - return llvm::make_unique(); + return std::make_unique(); } // This is a straightforward implementation not optimized for speed. Optimize diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d4d91c9b0e2..db78f500867 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -50,7 +50,7 @@ struct PipelineDataTransfer : public FunctionPass { /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. std::unique_ptr mlir::createPipelineDataTransferPass() { - return llvm::make_unique(); + return std::make_unique(); } // Returns the position of the tag memref operand given a DMA operation. 
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 3cc9309a5d5..97193b49a74 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -89,7 +89,7 @@ struct SimplifyAffineStructures } // end anonymous namespace std::unique_ptr mlir::createSimplifyAffineStructuresPass() { - return llvm::make_unique(); + return std::make_unique(); } void SimplifyAffineStructures::runOnFunction() { diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 21d8ef15219..15db8b58e88 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -39,7 +39,7 @@ void StripDebugInfo::runOnFunction() { /// Creates a pass to strip debug information from a function. std::unique_ptr mlir::createStripDebugInfoPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 250c76913c2..ffc19d1a1d3 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -82,11 +82,11 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) - domInfo = llvm::make_unique( + domInfo = std::make_unique( domInstFilter->getParentOfType()); if (postDomInstFilter) - postDomInfo = llvm::make_unique( + postDomInfo = std::make_unique( postDomInstFilter->getParentOfType()); // The ops where memref replacement succeeds are replaced with new ones. diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 932f00bfcbe..d00174ba2fa 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1278,7 +1278,7 @@ void Vectorize::runOnFunction() { std::unique_ptr mlir::createVectorizePass(llvm::ArrayRef virtualVectorSize) { - return llvm::make_unique(virtualVectorSize); + return std::make_unique(virtualVectorSize); } static PassRegistration diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 9b7fe8e94bf..bde640b2691 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -250,6 +250,6 @@ static llvm::cl::opt static mlir::PassRegistration legalizer_pass("test-legalize-patterns", "Run test dialect legalization patterns", [] { - return llvm::make_unique( + return std::make_unique( legalizerConversionMode); }); diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index 02c66ef86ac..34480f09f57 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -75,7 +75,7 @@ void TestConstantFold::runOnFunction() { /// Creates a constant folding pass. 
std::unique_ptr mlir::createTestConstantFoldPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index bcb050769a1..8b55d351bdc 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -59,7 +59,7 @@ struct TestLoopFusion : public FunctionPass { } // end anonymous namespace std::unique_ptr mlir::createTestLoopFusionPass() { - return llvm::make_unique(); + return std::make_unique(); } // Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'. diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index a9da70a6d5e..f4aa6469a99 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -62,4 +62,4 @@ public: static PassRegistration reg("test-mapping-to-processing-elements", "test mapping a single loop on a virtual processor grid", - [] { return llvm::make_unique(); }); + [] { return std::make_unique(); }); diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index e01ff66d825..cf68ec1b9a7 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -57,7 +57,7 @@ public: std::unique_ptr mlir::createSimpleParametricTilingPass(ArrayRef outerLoopSizes) { - return llvm::make_unique(outerLoopSizes); + return std::make_unique(outerLoopSizes); } static PassRegistration @@ -65,7 +65,7 @@ static PassRegistration "test application of parametric tiling to the outer loops so that the " "ranges of outer loops become static", [] { - auto pass = llvm::make_unique( + auto pass = std::make_unique( ArrayRef{}); pass->sizes.assign(clOuterLoopSizes.begin(), clOuterLoopSizes.end()); return pass; diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 3bfe6b6fce3..6fe277dcfcb 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -291,7 +291,7 @@ void VectorizerTestPass::runOnFunction() { } std::unique_ptr mlir::createVectorizerTestPass() { - return llvm::make_unique(); + return std::make_unique(); } static PassRegistration diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index f75413fdaed..1d174eb8395 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -98,8 +98,8 @@ OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) { "cuLinkComplete"); char *cubinAsChar = static_cast(cubinData); - OwnedCubin result = llvm::make_unique>( - cubinAsChar, cubinAsChar + cubinSize); + OwnedCubin result = + std::make_unique>(cubinAsChar, cubinAsChar + cubinSize); // This will also destroy the cubin data. RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState), "cuLinkDestroy"); -- cgit v1.2.3 From ba0fa92524ce0aea2385858016bdb08bd941a10d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 19 Aug 2019 11:00:47 -0700 Subject: NFC: Move LLVMIR, SDBM, and StandardOps to the Dialect/ directory. 
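For client code the move is purely an include-path update; the per-file hunks further below rewrite header paths along these lines (an illustrative grouping, the complete list of relocated headers is in the diffstat that follows):

  // Old locations (removed by this change):
  //   #include "mlir/LLVMIR/LLVMDialect.h"
  //   #include "mlir/SDBM/SDBM.h"
  //   #include "mlir/StandardOps/Ops.h"
  // New locations under Dialect/:
  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
  #include "mlir/Dialect/SDBM/SDBM.h"
  #include "mlir/Dialect/StandardOps/Ops.h"
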
PiperOrigin-RevId: 264193915 --- .../Linalg/Linalg1/include/linalg1/Common.h | 2 +- .../Linalg1/include/linalg1/LLVMIntrinsics.h | 2 +- mlir/examples/Linalg/Linalg1/lib/Common.cpp | 2 +- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 2 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 2 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 2 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 2 +- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 2 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 2 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 2 +- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 2 +- mlir/include/mlir/CMakeLists.txt | 2 - mlir/include/mlir/Dialect/CMakeLists.txt | 2 + mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt | 16 + mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 180 ++ mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 59 + mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 553 +++++ mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h | 43 + mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 60 + mlir/include/mlir/Dialect/SDBM/SDBM.h | 206 ++ mlir/include/mlir/Dialect/SDBM/SDBMDialect.h | 41 + mlir/include/mlir/Dialect/SDBM/SDBMExpr.h | 530 +++++ .../mlir/Dialect/StandardOps/CMakeLists.txt | 4 + mlir/include/mlir/Dialect/StandardOps/Ops.h | 363 ++++ mlir/include/mlir/Dialect/StandardOps/Ops.td | 905 +++++++++ mlir/include/mlir/EDSC/Builders.h | 2 +- mlir/include/mlir/LLVMIR/CMakeLists.txt | 16 - mlir/include/mlir/LLVMIR/LLVMDialect.h | 180 -- mlir/include/mlir/LLVMIR/LLVMOpBase.td | 59 - mlir/include/mlir/LLVMIR/LLVMOps.td | 553 ----- mlir/include/mlir/LLVMIR/NVVMDialect.h | 43 - mlir/include/mlir/LLVMIR/NVVMOps.td | 60 - mlir/include/mlir/SDBM/SDBM.h | 206 -- mlir/include/mlir/SDBM/SDBMDialect.h | 41 - mlir/include/mlir/SDBM/SDBMExpr.h | 530 ----- mlir/include/mlir/StandardOps/CMakeLists.txt | 4 - mlir/include/mlir/StandardOps/Ops.h | 363 ---- mlir/include/mlir/StandardOps/Ops.td | 905 --------- mlir/include/mlir/Transforms/Utils.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 2 +- mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/NestedMatcher.cpp | 2 +- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/CMakeLists.txt | 3 - .../ControlFlowToCFG/ConvertControlFlowToCFG.cpp | 2 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 2 +- .../GPUToCUDA/GenerateCubinAccessors.cpp | 2 +- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 4 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 4 +- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 2 +- .../Conversion/StandardToSPIRV/StandardToSPIRV.td | 2 +- mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 2 +- mlir/lib/Dialect/CMakeLists.txt | 3 + .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 2 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 2 +- mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 17 + mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 1394 +++++++++++++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 88 + mlir/lib/Dialect/LoopOps/LoopOps.cpp | 2 +- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 +- mlir/lib/Dialect/SDBM/CMakeLists.txt | 10 + mlir/lib/Dialect/SDBM/SDBM.cpp | 561 ++++++ 
mlir/lib/Dialect/SDBM/SDBMDialect.cpp | 20 + mlir/lib/Dialect/SDBM/SDBMExpr.cpp | 647 ++++++ mlir/lib/Dialect/SDBM/SDBMExprDetail.h | 138 ++ .../SPIRV/Serialization/ConvertFromBinary.cpp | 2 +- mlir/lib/Dialect/StandardOps/CMakeLists.txt | 9 + .../Dialect/StandardOps/DialectRegistration.cpp | 22 + mlir/lib/Dialect/StandardOps/Ops.cpp | 2102 ++++++++++++++++++++ mlir/lib/EDSC/Builders.cpp | 2 +- mlir/lib/EDSC/Helpers.cpp | 2 +- mlir/lib/LLVMIR/CMakeLists.txt | 17 - mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 1394 ------------- mlir/lib/LLVMIR/IR/NVVMDialect.cpp | 88 - mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 2 +- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 2 +- mlir/lib/Linalg/Utils/Utils.cpp | 2 +- .../lib/Quantizer/Configurations/FxpMathConfig.cpp | 2 +- mlir/lib/SDBM/CMakeLists.txt | 10 - mlir/lib/SDBM/SDBM.cpp | 561 ------ mlir/lib/SDBM/SDBMDialect.cpp | 20 - mlir/lib/SDBM/SDBMExpr.cpp | 647 ------ mlir/lib/SDBM/SDBMExprDetail.h | 138 -- mlir/lib/StandardOps/CMakeLists.txt | 9 - mlir/lib/StandardOps/DialectRegistration.cpp | 22 - mlir/lib/StandardOps/Ops.cpp | 2102 -------------------- mlir/lib/Support/JitRunner.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 4 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 4 +- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 2 +- mlir/lib/Transforms/LoopCoalescing.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/FoldUtils.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/EDSC/builder-api-test.cpp | 2 +- mlir/test/SDBM/sdbm-api-test.cpp | 6 +- mlir/test/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/test/lib/Transforms/TestLoopFusion.cpp | 2 +- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 2 +- mlir/unittests/SDBM/SDBMTest.cpp | 6 +- 119 files changed, 8050 insertions(+), 8050 deletions(-) create mode 100644 mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h create mode 100644 mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td create mode 100644 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h create mode 100644 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td create mode 100644 mlir/include/mlir/Dialect/SDBM/SDBM.h create mode 100644 mlir/include/mlir/Dialect/SDBM/SDBMDialect.h create mode 100644 mlir/include/mlir/Dialect/SDBM/SDBMExpr.h create mode 100644 mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/StandardOps/Ops.h create mode 100644 mlir/include/mlir/Dialect/StandardOps/Ops.td delete mode 100644 mlir/include/mlir/LLVMIR/CMakeLists.txt delete mode 100644 mlir/include/mlir/LLVMIR/LLVMDialect.h delete mode 100644 mlir/include/mlir/LLVMIR/LLVMOpBase.td delete mode 100644 mlir/include/mlir/LLVMIR/LLVMOps.td delete mode 100644 mlir/include/mlir/LLVMIR/NVVMDialect.h delete mode 100644 mlir/include/mlir/LLVMIR/NVVMOps.td delete mode 100644 mlir/include/mlir/SDBM/SDBM.h delete mode 100644 mlir/include/mlir/SDBM/SDBMDialect.h delete mode 100644 
mlir/include/mlir/SDBM/SDBMExpr.h delete mode 100644 mlir/include/mlir/StandardOps/CMakeLists.txt delete mode 100644 mlir/include/mlir/StandardOps/Ops.h delete mode 100644 mlir/include/mlir/StandardOps/Ops.td create mode 100644 mlir/lib/Dialect/LLVMIR/CMakeLists.txt create mode 100644 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp create mode 100644 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp create mode 100644 mlir/lib/Dialect/SDBM/CMakeLists.txt create mode 100644 mlir/lib/Dialect/SDBM/SDBM.cpp create mode 100644 mlir/lib/Dialect/SDBM/SDBMDialect.cpp create mode 100644 mlir/lib/Dialect/SDBM/SDBMExpr.cpp create mode 100644 mlir/lib/Dialect/SDBM/SDBMExprDetail.h create mode 100644 mlir/lib/Dialect/StandardOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/StandardOps/DialectRegistration.cpp create mode 100644 mlir/lib/Dialect/StandardOps/Ops.cpp delete mode 100644 mlir/lib/LLVMIR/CMakeLists.txt delete mode 100644 mlir/lib/LLVMIR/IR/LLVMDialect.cpp delete mode 100644 mlir/lib/LLVMIR/IR/NVVMDialect.cpp delete mode 100644 mlir/lib/SDBM/CMakeLists.txt delete mode 100644 mlir/lib/SDBM/SDBM.cpp delete mode 100644 mlir/lib/SDBM/SDBMDialect.cpp delete mode 100644 mlir/lib/SDBM/SDBMExpr.cpp delete mode 100644 mlir/lib/SDBM/SDBMExprDetail.h delete mode 100644 mlir/lib/StandardOps/CMakeLists.txt delete mode 100644 mlir/lib/StandardOps/DialectRegistration.cpp delete mode 100644 mlir/lib/StandardOps/Ops.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index 5501158eaab..29ff9bd2d3e 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -21,6 +21,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" @@ -35,7 +36,6 @@ #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h b/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h index 577981b85ed..fbab091f51b 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h @@ -18,9 +18,9 @@ #ifndef LINALG1_LLVMINTRINSICS_H_ #define LINALG1_LLVMINTRINSICS_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" -#include "mlir/LLVMIR/LLVMDialect.h" // Expose some LLVM IR instructions to declarative builders. 
namespace intrinsics { diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index 9ce661364d3..da96a57063b 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -23,8 +23,8 @@ #include "linalg1/Common.h" #include "linalg1/Ops.h" #include "linalg1/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Intrinsics.h" -#include "mlir/StandardOps/Ops.h" using llvm::ArrayRef; using mlir::ConstantIndexOp; diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index ec7e0fe99a7..9073169b260 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Attributes.h" @@ -28,7 +29,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 8de1ad6386c..ea3f700e733 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Attributes.h" @@ -28,7 +29,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LowerAffine.h" diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index c09c4ad679c..82e2bfa4c68 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -24,6 +24,7 @@ #include "toy/AST.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -32,7 +33,6 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index b3ba2f9281c..572f0eeb2ae 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -25,6 +25,7 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -33,7 +34,6 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include 
"mlir/StandardOps/Ops.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index fd385a47004..bd0ad01e7e3 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -25,6 +25,7 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -33,7 +34,6 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 793f153291e..6c0d8dfb1df 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -23,11 +23,11 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/DenseSet.h" diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index c55a0dbd949..13832f0dae0 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -31,13 +31,13 @@ #include "linalg1/Intrinsics.h" #include "linalg1/ViewOp.h" #include "linalg3/TensorOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 8146e763303..cbcf1491c19 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -30,13 +30,13 @@ #include "linalg3/ConvertToLLVMDialect.h" #include "linalg3/TensorOps.h" #include "linalg3/Transforms.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 88fb95048da..eba2933bd0b 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -25,6 +25,7 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -33,7 +34,6 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" diff 
--git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index b6808d713eb..057c7ec1a27 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -23,11 +23,11 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt index 202b40b7b2c..fc690a05910 100644 --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -2,6 +2,4 @@ add_subdirectory(AffineOps) add_subdirectory(Dialect) add_subdirectory(EDSC) add_subdirectory(Linalg) -add_subdirectory(LLVMIR) -add_subdirectory(StandardOps) add_subdirectory(VectorOps) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 5ae314a9984..128c04d867a 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,5 +1,7 @@ add_subdirectory(FxpMathOps) add_subdirectory(GPU) +add_subdirectory(LLVMIR) add_subdirectory(LoopOps) add_subdirectory(QuantOps) add_subdirectory(SPIRV) +add_subdirectory(StandardOps) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 index 00000000000..1d7d06bc25c --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LLVM_TARGET_DEFINITIONS LLVMOps.td) +mlir_tablegen(LLVMOps.h.inc -gen-op-decls) +mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) +mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRLLVMOpsIncGen) +set(LLVM_TARGET_DEFINITIONS NVVMOps.td) +mlir_tablegen(NVVMOps.h.inc -gen-op-decls) +mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRNVVMOpsIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMOps.td) +mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS NVVMOps.td) +mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRNVVMConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h new file mode 100644 index 00000000000..7318c006692 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -0,0 +1,180 @@ +//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +// +// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and +// LLVM type system. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" + +#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" + +namespace llvm { +class Type; +class LLVMContext; +} // end namespace llvm + +namespace mlir { +namespace LLVM { +class LLVMDialect; + +namespace detail { +struct LLVMTypeStorage; +struct LLVMDialectImpl; +} // namespace detail + +class LLVMType : public mlir::Type::TypeBase { +public: + enum Kind { + LLVM_TYPE = FIRST_LLVM_TYPE, + }; + + using Base::Base; + + static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } + + LLVMDialect &getDialect(); + llvm::Type *getUnderlyingType() const; + + /// Array type utilities. + LLVMType getArrayElementType(); + unsigned getArrayNumElements(); + + /// Vector type utilities. + LLVMType getVectorElementType(); + + /// Function type utilities. + LLVMType getFunctionParamType(unsigned argIdx); + unsigned getFunctionNumParams(); + LLVMType getFunctionResultType(); + + /// Pointer type utilities. + LLVMType getPointerTo(unsigned addrSpace = 0); + LLVMType getPointerElementTy(); + + /// Struct type utilities. + LLVMType getStructElementType(unsigned i); + + /// Utilities used to generate floating point types. + static LLVMType getDoubleTy(LLVMDialect *dialect); + static LLVMType getFloatTy(LLVMDialect *dialect); + static LLVMType getHalfTy(LLVMDialect *dialect); + + /// Utilities used to generate integer types. + static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits); + static LLVMType getInt1Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/1); + } + static LLVMType getInt8Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/8); + } + static LLVMType getInt8PtrTy(LLVMDialect *dialect) { + return getInt8Ty(dialect).getPointerTo(); + } + static LLVMType getInt16Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/16); + } + static LLVMType getInt32Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/32); + } + static LLVMType getInt64Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/64); + } + + /// Utilities used to generate other miscellaneous types. + static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements); + static LLVMType getFunctionTy(LLVMType result, ArrayRef params, + bool isVarArg); + static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { + return getFunctionTy(result, llvm::None, isVarArg); + } + static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef elements, + bool isPacked = false); + static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) { + return getStructTy(dialect, llvm::None, isPacked); + } + template + static typename std::enable_if::value, + LLVMType>::type + getStructTy(LLVMType elt1, Args... 
elts) { + SmallVector fields({elt1, elts...}); + return getStructTy(&elt1.getDialect(), fields); + } + static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); + static LLVMType getVoidTy(LLVMDialect *dialect); + +private: + friend LLVMDialect; + + /// Get an LLVMType with a pre-existing llvm type. + static LLVMType get(MLIRContext *context, llvm::Type *llvmType); + + /// Get an LLVMType with an llvm type that may cause changes to the underlying + /// llvm context when constructed. + static LLVMType getLocked(LLVMDialect *dialect, + llvm::function_ref typeBuilder); +}; + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" + +class LLVMDialect : public Dialect { +public: + explicit LLVMDialect(MLIRContext *context); + ~LLVMDialect(); + static StringRef getDialectNamespace() { return "llvm"; } + + llvm::LLVMContext &getLLVMContext(); + llvm::Module &getLLVMModule(); + + /// Parse a type registered to this dialect. + Type parseType(StringRef tyData, Location loc) const override; + + /// Print a type registered to this dialect. + void printType(Type type, raw_ostream &os) const override; + + /// Verify a region argument attribute registered to this dialect. + /// Returns failure if the verification failed, success otherwise. + LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx, + unsigned argIdx, + NamedAttribute argAttr) override; + +private: + friend LLVMType; + + std::unique_ptr impl; +}; + +} // end namespace LLVM +} // end namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td new file mode 100644 index 00000000000..a68cdbf3da0 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -0,0 +1,59 @@ +//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains shared definitions for the LLVM IR dialect and its +// subdialects. +// +//===----------------------------------------------------------------------===// + +#ifdef LLVMIR_OP_BASE +#else +#define LLVMIR_OP_BASE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def LLVM_Dialect : Dialect { + let name = "llvm"; + let cppNamespace = "LLVM"; +} + +// LLVM IR type wrapped in MLIR. +def LLVM_Type : Type()">, + "LLVM dialect type">; + +// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder +// used to translate to LLVM IR proper. +class LLVM_OpBase traits = []> : + Op { + // A pattern for constructing the LLVM IR Instruction (or other Value) that + // corresponds to this op. 
This pattern can use `builder` to refer to an + // `llvm::IRBuilder<>` instance, $-names of arguments and results and the + // following special variable names: + // - $_resultType - substituted with the LLVM IR type of the result; + // - $_numOperands - substituted with the number of operands (including + // the variadic ones); + // - $_hasResult - substituted with a check that a variadic-result op does + // have a result (LLVM ops can have 0 or 1 result); + // - $_location - mlir::Location object of the instruction. + // Additionally, `$$` can be used to produce the dollar character. + string llvmBuilder = ""; +} + +#endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td new file mode 100644 index 00000000000..be96d855174 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -0,0 +1,553 @@ +//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the LLVM IR operation definition file. +// +//===----------------------------------------------------------------------===// + +#ifdef LLVMIR_OPS +#else +#define LLVMIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +// Base class for LLVM operations. All operations get an "llvm." prefix in +// their name automatically. LLVM operations have either zero or one result, +// this class is specialized below for both cases and should not be used +// directly. +class LLVM_Op traits = []> : + LLVM_OpBase { +} + +class LLVM_Builder { + string llvmBuilder = builder; +} + +def LLVM_OneResultOpBuilder : OpBuilder< + "Builder *, OperationState *result, Type resultType, " + "ArrayRef operands, ArrayRef attributes = {}", + [{ + if (resultType) result->addTypes(resultType); + result->addOperands(operands); + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + }]>; + +def LLVM_ZeroResultOpBuilder : OpBuilder< + "Builder *, OperationState *result, ArrayRef operands, " + "ArrayRef attributes = {}", + [{ + result->addOperands(operands); + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + }]>; + +class LLVM_TwoBuilders { + list builders = [b1, b2]; +} + +// Base class for LLVM operations with one result. +class LLVM_OneResultOp traits = []> : + LLVM_Op, Results<(outs LLVM_Type:$res)> { + let builders = [LLVM_OneResultOpBuilder]; +} + +// Compatibility builder that takes an instance of wrapped llvm::VoidType +// to indicate no result. 
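As an aside on the LLVMType helpers declared in LLVMDialect.h above: they compose wrapped LLVM types without touching llvm::Type directly. A hedged sketch of their use (buildExampleFnType is an illustrative name; obtaining the LLVMDialect pointer from the MLIRContext is outside the scope of this sketch):

  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"

  using namespace mlir::LLVM;

  // Builds the wrapped LLVM function type i8* (i8*, i64), assuming a valid
  // LLVMDialect pointer is available.
  LLVMType buildExampleFnType(LLVMDialect *dialect) {
    LLVMType i8Ptr = LLVMType::getInt8PtrTy(dialect);
    LLVMType i64 = LLVMType::getInt64Ty(dialect);
    return LLVMType::getFunctionTy(i8Ptr, {i8Ptr, i64}, /*isVarArg=*/false);
  }
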
+def LLVM_VoidResultTypeOpBuilder : OpBuilder< + "Builder *builder, OperationState *result, Type resultType, " + "ArrayRef operands, ArrayRef attributes = {}", + [{ + auto llvmType = resultType.dyn_cast(); (void)llvmType; + assert(llvmType && "result must be an LLVM type"); + assert(llvmType.getUnderlyingType() && + llvmType.getUnderlyingType()->isVoidTy() && + "for zero-result operands, only 'void' is accepted as result type"); + build(builder, result, operands, attributes); + }]>; + +// Base class for LLVM operations with zero results. +class LLVM_ZeroResultOp traits = []> : + LLVM_Op, Results<(outs)>, + LLVM_TwoBuilders; + +// Base class for LLVM terminator operations. All terminator operations have +// zero results and an optional list of successors. +class LLVM_TerminatorOp traits = []> : + LLVM_Op, + Arguments<(ins Variadic:$args)>, Results<(outs)> { + let builders = [OpBuilder< + "Builder *, OperationState *result, " + "ArrayRef properOperands, " + "ArrayRef destinations, " + "ArrayRef> operands = {}, " + "ArrayRef attributes = {}", + [{ + result->addOperands(properOperands); + for (auto kvp : llvm::zip(destinations, operands)) { + result->addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); + } + for (auto namedAttr : attributes) { + result->addAttribute(namedAttr.first, namedAttr.second); + } + }] + >]; +} + +// Class for arithmetic binary operations. +class LLVM_ArithmeticOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>, + LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { + let parser = [{ return impl::parseBinaryOp(parser, result); }]; + let printer = [{ mlir::impl::printBinaryOp(this->getOperation(), p); }]; +} + +// Integer binary operations. +def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>; +def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">; +def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>; +def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">; +def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">; +def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">; +def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; +def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; +def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; +def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; + +// Predicate for integer comparisons. +def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; +def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; +def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; +def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; +def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; +def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; +def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; +def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; +def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; +def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; +def ICmpPredicate : I64EnumAttr< + "ICmpPredicate", + "llvm.icmp comparison predicate", + [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, + ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, + ICmpPredicateUGT, ICmpPredicateUGE]> { + let cppNamespace = "mlir::LLVM"; + + let returnType = "ICmpPredicate"; + let convertFromStorage = + "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; +} + +// Other integer operations. 
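The llvm.icmp op defined next carries its predicate as an enum attribute; during translation its llvmBuilder string reduces to a plain IRBuilder comparison, roughly as in this sketch (emitSignedLessThan is an illustrative helper name, and the predicate is shown already converted to the llvm::CmpInst enum, which in the dialect is done by getLLVMCmpPredicate):

  #include "llvm/IR/IRBuilder.h"

  // Rough C++ counterpart of the llvm.icmp llvmBuilder string below, for the
  // "slt" predicate.
  llvm::Value *emitSignedLessThan(llvm::IRBuilder<> &builder, llvm::Value *lhs,
                                  llvm::Value *rhs) {
    return builder.CreateICmp(llvm::CmpInst::ICMP_SLT, lhs, rhs);
  }
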
+def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, + Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, + LLVM_Type:$rhs)> { + let llvmBuilder = [{ + $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); + }]; + let parser = [{ return parseCmpOp(parser, result); }]; + let printer = [{ printICmpOp(p, *this); }]; +} + +// Predicate for float comparisons +def FCmpPredicateFALSE : I64EnumAttrCase<"_false", 0>; +def FCmpPredicateOEQ : I64EnumAttrCase<"oeq", 1>; +def FCmpPredicateOGT : I64EnumAttrCase<"ogt", 2>; +def FCmpPredicateOGE : I64EnumAttrCase<"oge", 3>; +def FCmpPredicateOLT : I64EnumAttrCase<"olt", 4>; +def FCmpPredicateOLE : I64EnumAttrCase<"ole", 5>; +def FCmpPredicateONE : I64EnumAttrCase<"one", 6>; +def FCmpPredicateORD : I64EnumAttrCase<"ord", 7>; +def FCmpPredicateUEQ : I64EnumAttrCase<"ueq", 8>; +def FCmpPredicateUGT : I64EnumAttrCase<"ugt", 9>; +def FCmpPredicateUGE : I64EnumAttrCase<"uge", 10>; +def FCmpPredicateULT : I64EnumAttrCase<"ult", 11>; +def FCmpPredicateULE : I64EnumAttrCase<"ule", 12>; +def FCmpPredicateUNE : I64EnumAttrCase<"une", 13>; +def FCmpPredicateUNO : I64EnumAttrCase<"uno", 14>; +def FCmpPredicateTRUE : I64EnumAttrCase<"_true", 15>; + +def FCmpPredicate : I64EnumAttr< + "FCmpPredicate", + "llvm.fcmp comparison predicate", + [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE, + FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD, + FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT, + FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE + ]> { + let cppNamespace = "mlir::LLVM"; + + let returnType = "FCmpPredicate"; + let convertFromStorage = + "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; +} + +// Other integer operations. +def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, + Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs, + LLVM_Type:$rhs)> { + let llvmBuilder = [{ + $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); + }]; + let parser = [{ return parseCmpOp(parser, result); }]; + let printer = [{ printFCmpOp(p, *this); }]; +} + +// Floating point binary operations. +def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">; +def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">; +def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">; +def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">; +def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; + +// Memory-related operations. 
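The memory ops defined below (llvm.alloca, llvm.getelementptr, llvm.load, llvm.store) likewise map one-to-one onto IRBuilder calls. A sketch of roughly what the load/store builder strings produce (emitStoreThenLoad is an illustrative helper name):

  #include "llvm/IR/IRBuilder.h"

  // Store a value and load it back through the same address, mirroring the
  // llvm.store and llvm.load builder strings below.
  llvm::Value *emitStoreThenLoad(llvm::IRBuilder<> &builder, llvm::Value *value,
                                 llvm::Value *addr) {
    builder.CreateStore(value, addr); // llvm.store
    return builder.CreateLoad(addr);  // llvm.load
  }
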
+def LLVM_AllocaOp : + LLVM_OneResultOp<"alloca">, + Arguments<(ins LLVM_Type:$arraySize, OptionalAttr:$alignment)> { + string llvmBuilder = [{ + auto *alloca = builder.CreateAlloca( + $_resultType->getPointerElementType(), $arraySize); + if ($alignment.hasValue()) { + auto align = $alignment.getValue().getZExtValue(); + if (align != 0) + alloca->setAlignment(align); + } + $res = alloca; + }]; + let builders = [OpBuilder< + "Builder *b, OperationState *result, Type resultType, Value *arraySize, " + "unsigned alignment", + [{ + if (alignment == 0) + return build(b, result, resultType, arraySize, IntegerAttr()); + build(b, result, resultType, arraySize, b->getI64IntegerAttr(alignment)); + }]>]; + let parser = [{ return parseAllocaOp(parser, result); }]; + let printer = [{ printAllocaOp(p, *this); }]; + let verifier = [{ + if (alignment().hasValue()) { + auto align = alignment().getValue().getSExtValue(); + if (align < 0) + return emitOpError("expected positive alignment"); + } + return success(); + }]; +} +def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$base, Variadic:$indices)>, + LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { + let parser = [{ return parseGEPOp(parser, result); }]; + let printer = [{ printGEPOp(p, *this); }]; +} +def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, + LLVM_Builder<"$res = builder.CreateLoad($addr);"> { + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *addr", + [{ + auto type = addr->getType().cast().getPointerElementTy(); + build(b, result, type, addr); + }]>]; + let parser = [{ return parseLoadOp(parser, result); }]; + let printer = [{ printLoadOp(p, *this); }]; +} +def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, + Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>, + LLVM_Builder<"builder.CreateStore($value, $addr);"> { + let parser = [{ return parseStoreOp(parser, result); }]; + let printer = [{ printStoreOp(p, *this); }]; +} + +// Casts. +class LLVM_CastOp traits = []> : + LLVM_OneResultOp, + Arguments<(ins LLVM_Type:$arg)>, + LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> { + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; +} +def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">; +def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">; +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">; +def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">; +def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">; +def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">; +def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">; + +// Call-related operations. 
+def LLVM_CallOp : LLVM_Op<"call">, + Arguments<(ins OptionalAttr:$callee, + // TODO(b/133216756): fix test failure and + // change to LLVM_Type + Variadic)>, + Results<(outs Variadic)>, + LLVM_TwoBuilders { + let verifier = [{ + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); + return success(); + }]; + let parser = [{ return parseCallOp(parser, result); }]; + let printer = [{ printCallOp(p, *this); }]; +} +def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateExtractElement($vector, $position); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *vector, Value *position," + "ArrayRef attrs = {}">]; + let parser = [{ return parseExtractElementOp(parser, result); }]; + let printer = [{ printExtractElementOp(p, *this); }]; +} +def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$container, + ArrayAttr:$position)> { + string llvmBuilder = [{ + $res = builder.CreateExtractValue($container, extractPosition($position)); + }]; + let parser = [{ return parseExtractValueOp(parser, result); }]; + let printer = [{ printExtractValueOp(p, *this); }]; +} +def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateInsertElement($vector, $value, $position); + }]; + let parser = [{ return parseInsertElementOp(parser, result); }]; + let printer = [{ printInsertElementOp(p, *this); }]; +} +def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, + ArrayAttr:$position)> { + string llvmBuilder = [{ + $res = builder.CreateInsertValue($container, $value, + extractPosition($position)); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *container, Value *value, " + "ArrayAttr position", + [{ + build(b, result, container->getType(), container, value, position); + }]>]; + let parser = [{ return parseInsertValueOp(parser, result); }]; + let printer = [{ printInsertValueOp(p, *this); }]; +} +def LLVM_ShuffleVectorOp + : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, I32ArrayAttr:$mask)>, + LLVM_Builder< + "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *v1, Value *v2, " + "ArrayAttr mask, ArrayRef attrs = {}">]; + let verifier = [{ + auto wrappedVectorType1 = v1()->getType().cast(); + auto wrappedVectorType2 = v2()->getType().cast(); + if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) + return emitOpError("expected LLVM IR Dialect vector type for operand #2"); + if (wrappedVectorType1.getVectorElementType() != + wrappedVectorType2.getVectorElementType()) + return emitOpError("expected matching LLVM IR Dialect element types"); + return success(); + }]; + let parser = [{ return parseShuffleVectorOp(parser, result); }]; + let printer = [{ printShuffleVectorOp(p, *this); }]; +} + +// Misc operations. 
+def LLVM_SelectOp + : LLVM_OneResultOp<"select", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, + LLVM_Type:$falseValue)>, + LLVM_Builder< + "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { + let parser = [{ return parseSelectOp(parser, result); }]; + let printer = [{ printSelectOp(p, *this); }]; +} + +// Terminators. +def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { + let parser = [{ return parseBrOp(parser, result); }]; + let printer = [{ printBrOp(p, *this); }]; +} +def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { + let verifier = [{ + if (getNumSuccessors() != 2) + return emitOpError("expected exactly two successors"); + return success(); + }]; + let parser = [{ return parseCondBrOp(parser, result); }]; + let printer = [{ printCondBrOp(p, *this); }]; +} +def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> { + string llvmBuilder = [{ + if ($_numOperands != 0) + builder.CreateRet($args[0]); + else + builder.CreateRetVoid(); + }]; + + let verifier = [{ + if (getNumOperands() > 1) + return emitOpError("expects at most 1 operand"); + return success(); + }]; + + let parser = [{ return parseReturnOp(parser, result); }]; + let printer = [{ printReturnOp(p, *this); }]; +} +def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { + string llvmBuilder = [{ builder.CreateUnreachable(); }]; + let parser = [{ return success(); }]; + let printer = [{ *p << getOperationName(); }]; +} + +// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to +// work correctly). +def LLVM_AddressOfOp + : LLVM_OneResultOp<"addressof">, + Arguments<(ins SymbolRefAttr:$global_name)> { + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, " + "StringRef name, ArrayRef attrs = {}", [{ + result->addAttribute("global_name", builder->getSymbolRefAttr(name)); + result->addAttributes(attrs); + result->addTypes(resType);}]>, + + OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, " + "ArrayRef attrs = {}", [{ + build(builder, result, global.getType().getPointerTo(), global.sym_name(), + attrs);}]> + ]; + + let extraClassDeclaration = [{ + /// Return the llvm.global operation that defined the value referenced here. + GlobalOp getGlobal(); + }]; + + let printer = "printAddressOfOp(p, *this);"; + let parser = "return parseAddressOfOp(parser, result);"; + let verifier = "return ::verify(*this);"; +} + +def LLVM_GlobalOp + : LLVM_ZeroResultOp<"global">, + Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, + AnyAttr:$value)> { + + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, LLVMType type, " + "bool isConstant, StringRef name, Attribute value, " + "ArrayRef attrs = {}"> + ]; + + let extraClassDeclaration = [{ + /// Return the LLVM type of the global. 
+ LLVMType getType() { + return type().cast(); + } + }]; + + let printer = "printGlobalOp(p, *this);"; + let parser = "return parseGlobalOp(parser, result);"; + let verifier = "return ::verify(*this);"; +} + +def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", + [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">]> { + let summary = "LLVM dialect function, has wrapped LLVM IR function type"; + + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, StringRef name, " + "LLVMType type, ArrayRef attrs, " + "ArrayRef argAttrs = {}"> + ]; + + let extraClassDeclaration = [{ + LLVMType getType() { + return getAttrOfType(getTypeAttrName()) + .getValue().cast(); + } + bool isVarArg() { + return getType().getUnderlyingType()->isFunctionVarArg(); + } + + // Hook for OpTrait::FunctionLike, returns the number of function arguments. + // Depends on the type attribute being correct as checked by verifyType. + unsigned getNumFuncArguments(); + + // Hook for OpTrait::FunctionLike, called after verifying that the 'type' + // attribute is present. This can check for preconditions of the + // getNumArguments hook not failing. + LogicalResult verifyType(); + }]; + + let verifier = [{ return ::verify(*this); }]; + let printer = [{ printLLVMFuncOp(p, *this); }]; + let parser = [{ + return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true, + buildLLVMFunctionType); + }]; +} + +def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>, + LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> { + let parser = [{ return parseUndefOp(parser, result); }]; + let printer = [{ printUndefOp(p, *this); }]; +} +def LLVM_ConstantOp + : LLVM_OneResultOp<"constant", [NoSideEffect]>, + Arguments<(ins AnyAttr:$value)>, + LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> +{ + let parser = [{ return parseConstantOp(parser, result); }]; + let printer = [{ printConstantOp(p, *this); }]; +} + +// Operations that correspond to LLVM intrinsics. With MLIR operation set being +// extendable, there is no reason to introduce a hard boundary between "core" +// operations and intrinsics. + +def LLVM_fmuladd : LLVM_Op<"fmuladd", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c)>, + Results<(outs LLVM_Type:$res)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, llvm::Intrinsic::fmuladd, + {$a->getType(), $b->getType(), $c->getType()}); + $res = builder.CreateCall(fn, {$a, $b, $c}); + }]; +} + + +#endif // LLVMIR_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h new file mode 100644 index 00000000000..4c39794557b --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -0,0 +1,43 @@ +//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and +// NVVM specific extensions to the LLVM type system. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +namespace mlir { +namespace NVVM { + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc" + +class NVVMDialect : public Dialect { +public: + explicit NVVMDialect(MLIRContext *context); +}; + +} // namespace NVVM +} // namespace mlir + +#endif /* MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ */ diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td new file mode 100644 index 00000000000..72bbb13570a --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -0,0 +1,60 @@ +//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the NVVM IR operation definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifdef NVVMIR_OPS +#else +#define NVVMIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +def NVVM_Dialect : Dialect { + let name = "nvvm"; + let cppNamespace = "NVVM"; +} + +class NVVM_Op traits = []> : + LLVM_OpBase { +} + +class NVVM_SpecialRegisterOp traits = []> : + NVVM_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createIntrinsicCall(builder," + # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");"; + let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }]; + let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }]; +} + +def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">; +def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">; +def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">; +def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">; +def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">; +def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">; +def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">; +def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">; +def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">; +def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; +def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; +def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; + +#endif // NVVMIR_OPS diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h new file mode 100644 index 00000000000..3115805bb5f --- /dev/null +++ b/mlir/include/mlir/Dialect/SDBM/SDBM.h @@ -0,0 +1,206 @@ +//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined +// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SDBM_SDBM_H +#define MLIR_DIALECT_SDBM_SDBM_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { + +class MLIRContext; +class SDBMDialect; +class SDBMExpr; +class SDBMPositiveExpr; + +/// A utility class for SDBM to represent an integer with potentially infinite +/// positive value. This uses the largest value of int64_t to represent infinity +/// and redefines the arithmetic operators so that the infinity "saturates": +/// inf + x = inf, +/// inf - x = inf. +/// If a sum of two finite values reaches the largest value of int64_t, the +/// behavior of IntInfty is undefined (in practice, it asserts), similarly to +/// regular signed integer overflow. 
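A small usage sketch of the saturating behaviour described above, in terms of the IntInfty class defined just below (intInftyExample is an illustrative name):

  #include "mlir/Dialect/SDBM/SDBM.h"
  #include <cassert>

  void intInftyExample() {
    mlir::IntInfty a = 40;
    mlir::IntInfty sum = a + mlir::IntInfty::infinity();
    assert(!sum.isFinite()); // addition saturates at infinity
    mlir::IntInfty b = mlir::IntInfty(1) + mlir::IntInfty(2);
    assert(b.getValue() == 3); // finite sums behave like plain int64_t
  }
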
+class IntInfty { +public: + constexpr static int64_t infty = std::numeric_limits::max(); + + /*implicit*/ IntInfty(int64_t v) : value(v) {} + + IntInfty &operator=(int64_t v) { + value = v; + return *this; + } + + static IntInfty infinity() { return IntInfty(infty); } + + int64_t getValue() const { return value; } + explicit operator int64_t() const { return value; } + + bool isFinite() { return value != infty; } + +private: + int64_t value; +}; + +inline IntInfty operator+(IntInfty lhs, IntInfty rhs) { + if (!lhs.isFinite() || !rhs.isFinite()) + return IntInfty::infty; + + // Check for overflows, treating the sum of two values adding up to INT_MAX as + // overflow. Convert values to unsigned to get an extra bit and avoid the + // undefined behavior of signed integer overflows. + assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 || + static_cast(lhs.getValue()) + + static_cast(rhs.getValue()) < + static_cast(std::numeric_limits::max())) && + "IntInfty overflow"); + // Check for underflows by converting values to unsigned to avoid undefined + // behavior of signed integers perform the addition (bitwise result is same + // because numbers are required to be two's complement in C++) and check if + // the sign bit remains negative. + assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 || + ((static_cast(lhs.getValue()) + + static_cast(rhs.getValue())) >> + 63) == 1) && + "IntInfty underflow"); + + return lhs.getValue() + rhs.getValue(); +} + +inline bool operator<(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() < rhs.getValue(); +} + +inline bool operator<=(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() <= rhs.getValue(); +} + +inline bool operator==(IntInfty lhs, IntInfty rhs) { + return lhs.getValue() == rhs.getValue(); +} + +inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); } + +/// Striped difference-bound matrix is a representation of an integer set bound +/// by a system of SDBMExprs interpreted as inequalities "expr <= 0". +class SDBM { +public: + /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and + /// equalities with zero. + static SDBM get(ArrayRef inequalities, + ArrayRef equalities); + + void getSDBMExpressions(SDBMDialect *dialect, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities); + + void print(llvm::raw_ostream &os); + void dump(); + + IntInfty operator()(int i, int j) { return at(i, j); } + +private: + /// Get the given element of the difference bounds matrix. First index + /// corresponds to the negative term of the difference, second index + /// corresponds to the positive term of the difference. + IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; } + + /// Populate `inequalities` and `equalities` based on the values at(row,col) + /// and at(col,row) of the DBM. Depending on the values being finite and + /// being subsumed by stripe expressions, this may or may not add elements to + /// the lists of equalities and inequalities. + void convertDBMElement(unsigned row, unsigned col, SDBMPositiveExpr rowExpr, + SDBMPositiveExpr colExpr, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities); + + /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only + /// adds new inequalities if the inequality is not trivially true. + void convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, + SmallVectorImpl &inequalities); + + /// Get the total number of elements in the matrix. 
+  unsigned getNumVariables() const {
+    return 1 + numDims + numSymbols + numTemporaries;
+  }
+
+  /// Get the position in the matrix that corresponds to the given dimension.
+  unsigned getDimPosition(unsigned position) const { return 1 + position; }
+
+  /// Get the position in the matrix that corresponds to the given symbol.
+  unsigned getSymbolPosition(unsigned position) const {
+    return 1 + numDims + position;
+  }
+
+  /// Get the position in the matrix that corresponds to the given temporary.
+  unsigned getTemporaryPosition(unsigned position) const {
+    return 1 + numDims + numSymbols + position;
+  }
+
+  /// Number of dimensions in the system.
+  unsigned numDims;
+  /// Number of symbols in the system.
+  unsigned numSymbols;
+  /// Number of temporary variables in the system.
+  unsigned numTemporaries;
+
+  /// Difference bounds matrix, stored as a linearized row-major vector.
+  /// Each value in this matrix corresponds to an inequality
+  ///
+  ///   v@col - v@row <= at(row, col)
+  ///
+  /// where v@col and v@row are the variables that correspond to the linearized
+  /// position in the matrix. The positions correspond to
+  ///
+  ///   - constant 0 (producing constraints v@col <= X and -v@row <= Y);
+  ///   - SDBM expression dimensions (d0, d1, ...);
+  ///   - SDBM expression symbols (s0, s1, ...);
+  ///   - temporary variables (t0, t1, ...).
+  ///
+  /// Temporary variables are introduced to represent expressions that are not
+  /// trivially a difference between two variables. For example, if one side of
+  /// a difference expression is itself a stripe expression, it will be replaced
+  /// with a temporary variable assigned equal to this expression.
+  ///
+  /// Infinite entries in the matrix correspond to an absence of a constraint:
+  ///
+  ///   v@col - v@row <= infinity
+  ///
+  /// is trivially true. Negated values at symmetric positions in the matrix
+  /// allow one to couple two inequalities into a single equality.
+  std::vector<IntInfty> matrix;
+
+  /// The mapping between the indices of variables in the DBM and the stripe
+  /// expressions they are equal to. These expressions are stored as they
+  /// appeared when constructing an SDBM from SDBMExprs, in particular no
+  /// temporaries can appear in these expressions. This removes the need to
+  /// iteratively substitute definitions of the temporaries in the reverse
+  /// conversion.
+  llvm::DenseMap stripeToPoint;
+};
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SDBM_SDBM_H
diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
new file mode 100644
index 00000000000..e3573ba604d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
@@ -0,0 +1,41 @@
+//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================= + +#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H +#define MLIR_DIALECT_SDBM_SDBMDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/Support/StorageUniquer.h" + +namespace mlir { +class MLIRContext; + +class SDBMDialect : public Dialect { +public: + SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} + + static StringRef getDialectNamespace() { return "sdbm"; } + + /// Get the uniquer for SDBM expressions. This should not be used directly. + StorageUniquer &getUniquer() { return uniquer; } + +private: + StorageUniquer uniquer; +}; +} // namespace mlir + +#endif // MLIR_DIALECT_SDBM_SDBMDIALECT_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h new file mode 100644 index 00000000000..1e695b68f97 --- /dev/null +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -0,0 +1,530 @@ +//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) expression is a constant expression, +// an identifier, a binary expression with constant RHS and +, stripe operators +// or a difference expression between two identifiers. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SDBM_SDBMEXPR_H +#define MLIR_DIALECT_SDBM_SDBMEXPR_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" + +namespace mlir { + +class AffineExpr; +class MLIRContext; + +enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg }; + +namespace detail { +struct SDBMExprStorage; +struct SDBMBinaryExprStorage; +struct SDBMDiffExprStorage; +struct SDBMPositiveExprStorage; +struct SDBMConstantExprStorage; +struct SDBMNegExprStorage; +} // namespace detail + +class SDBMConstantExpr; +class SDBMDialect; +class SDBMDimExpr; +class SDBMSymbolExpr; + +/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side +/// expression for the SDBM framework. SDBM expressions are a subset of affine +/// expressions supporting low-complexity algorithms for the operations used in +/// loop transformations. In particular, are supported: +/// - constant expressions; +/// - single variables (dimensions and symbols) with +1 or -1 coefficient; +/// - stripe expressions: "x # C", where "x" is a single variable or another +/// stripe expression, "#" is the stripe operator, and "C" is a constant +/// expression; "#" is defined as x - x mod C. +/// - sum expressions between single variable/stripe expressions and constant +/// expressions; +/// - difference expressions between single variable/stripe expressions. +/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing +/// and operating on SDBM expressions. 
For example, it requires the LHS of a +/// sum expression to be a single variable or a stripe expression. These +/// restrictions are intended to force the caller to perform the necessary +/// simplifications to stay within the SDBM domain, because SDBM expressions do +/// not combine in more cases than they do. This choice may be reconsidered in +/// the future. +/// +/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by +/// an MLIRContext, and should be used by-value. They are uniqued in the +/// MLIRContext and immortal. +class SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + SDBMExpr() : impl(nullptr) {} + /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {} + + /// SDBM expressions are thin wrappers around a unique'ed immutable pointer, + /// which makes them trivially assignable and trivially copyable. + SDBMExpr(const SDBMExpr &) = default; + SDBMExpr &operator=(const SDBMExpr &) = default; + + /// SDBM expressions can be compared straight-forwardly. + bool operator==(const SDBMExpr &other) const { return impl == other.impl; } + bool operator!=(const SDBMExpr &other) const { return !(*this == other); } + + /// SDBM expressions are convertible to `bool`: null expressions are converted + /// to false, non-null expressions are converted to true. + explicit operator bool() const { return impl != nullptr; } + bool operator!() const { return !static_cast(*this); } + + /// Negate the given SDBM expression. + SDBMExpr operator-(); + + /// Prints the SDBM expression. + void print(raw_ostream &os) const; + void dump() const; + + /// LLVM-style casts. + template bool isa() const { return U::isClassFor(*this); } + template U dyn_cast() const { + if (!isa()) + return {}; + return U(const_cast(this)->impl); + } + template U cast() const { + assert(isa() && "cast to incorrect subtype"); + return U(const_cast(this)->impl); + } + + /// Support for LLVM hashing. + ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); } + + /// Returns the kind of the SDBM expression. + SDBMExprKind getKind() const; + + /// Returns the MLIR context in which this expression lives. + MLIRContext *getContext() const; + + /// Returns the SDBM dialect instance. + SDBMDialect *getDialect() const; + + /// Convert the SDBM expression into an Affine expression. This always + /// succeeds because SDBM are a subset of affine. + AffineExpr getAsAffineExpr() const; + + /// Try constructing an SDBM expression from the given affine expression. + /// This may fail if the affine expression is not representable as SDBM, in + /// which case llvm::None is returned. The conversion procedure recognizes + /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B) + /// patterns for the stripe expression. + static Optional tryConvertAffineExpr(AffineExpr affine); + +protected: + ImplType *impl; +}; + +/// SDBM constant expression, wraps a 64-bit integer. +class SDBMConstantExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMConstantExprStorage; + + using SDBMExpr::SDBMExpr; + + /// Obtain or create a constant expression unique'ed in the given dialect + /// (which belongs to a context). 
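+  /// A usage sketch (illustrative, assuming a valid `dialect` pointer):
+  ///   SDBMConstantExpr c = SDBMConstantExpr::get(dialect, 42);
+  ///   assert(c.getValue() == 42);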
+ static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Constant; + } + + int64_t getValue() const; +}; + +/// SDBM varying expression can be one of: +/// - input variable expression; +/// - stripe expression; +/// - negation (product with -1) of either of the above. +/// - sum of a varying and a constant expression +/// - difference between varying expressions +class SDBMVaryingExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + using SDBMExpr::SDBMExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Neg || + expr.getKind() == SDBMExprKind::Stripe || + expr.getKind() == SDBMExprKind::Add || + expr.getKind() == SDBMExprKind::Diff; + } +}; + +/// SDBM positive variable expression can be one of: +/// - single variable expression; +/// - stripe expression. +class SDBMPositiveExpr : public SDBMVaryingExpr { +public: + using SDBMVaryingExpr::SDBMVaryingExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Stripe; + } +}; + +/// SDBM sum expression. LHS is a varying expression and RHS is always a +/// constant expression. +class SDBMSumExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a sum expression unique'ed in the given context. + static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + SDBMExprKind kind = expr.getKind(); + return kind == SDBMExprKind::Add; + } + + SDBMVaryingExpr getLHS() const; + SDBMConstantExpr getRHS() const; +}; + +/// SDBM difference expression. Both LHS and RHS are positive variable +/// expressions. +class SDBMDiffExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMDiffExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a difference expression unique'ed in the given context. + static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Diff; + } + + SDBMPositiveExpr getLHS() const; + SDBMPositiveExpr getRHS() const; +}; + +/// SDBM stripe expression "x # C" where "x" is a positive variable expression, +/// "C" is a constant expression and "#" is the stripe operator defined as: +/// x # C = x - x mod C. +class SDBMStripeExpr : public SDBMPositiveExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMPositiveExpr::SDBMPositiveExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Stripe; + } + + static SDBMStripeExpr get(SDBMPositiveExpr var, + SDBMConstantExpr stripeFactor); + + SDBMPositiveExpr getVar() const; + SDBMConstantExpr getStripeFactor() const; +}; + +/// SDBM "input" variable expression can be either a dimension identifier or +/// a symbol identifier. When used to define SDBM functions, dimensions are +/// interpreted as function arguments while symbols are treated as unknown but +/// constant values, hence the name. 
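+/// For instance (illustrative, assuming a valid `dialect` pointer), the affine
+/// d0 and s0 would be constructed as:
+///   SDBMDimExpr d0 = SDBMDimExpr::get(dialect, /*position=*/0);
+///   SDBMSymbolExpr s0 = SDBMSymbolExpr::get(dialect, /*position=*/0);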
+class SDBMInputExpr : public SDBMPositiveExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMPositiveExpr::SDBMPositiveExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId; + } + + unsigned getPosition() const; +}; + +/// SDBM dimension expression. Dimensions correspond to function arguments +/// when defining functions using SDBM expressions. +class SDBMDimExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a dimension expression unique'ed in the given dialect + /// (which belongs to a context). + static SDBMDimExpr get(SDBMDialect *dialect, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId; + } +}; + +/// SDBM symbol expression. Symbols correspond to symbolic constants when +/// defining functions using SDBM expressions. +class SDBMSymbolExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a symbol expression unique'ed in the given dialect (which + /// belongs to a context). + static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::SymbolId; + } +}; + +/// Negation of an SDBM variable expression. Equivalent to multiplying the +/// expression with -1 (SDBM does not support other coefficients that 1 and -1). +class SDBMNegExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMNegExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a negation expression unique'ed in the given context. + static SDBMNegExpr get(SDBMPositiveExpr var); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Neg; + } + + SDBMPositiveExpr getVar() const; +}; + +/// A visitor class for SDBM expressions. Calls the kind-specific function +/// depending on the kind of expression it visits. +template class SDBMVisitor { +public: + /// Visit the given SDBM expression, dispatching to kind-specific functions. + Result visit(SDBMExpr expr) { + auto *derived = static_cast(this); + switch (expr.getKind()) { + case SDBMExprKind::Add: + case SDBMExprKind::Diff: + case SDBMExprKind::DimId: + case SDBMExprKind::SymbolId: + case SDBMExprKind::Neg: + case SDBMExprKind::Stripe: + return derived->visitVarying(expr.cast()); + case SDBMExprKind::Constant: + return derived->visitConstant(expr.cast()); + } + + llvm_unreachable("unsupported SDBM expression kind"); + } + + /// Traverse the SDBM expression tree calling `visit` on each node + /// in depth-first preorder. + void walkPreorder(SDBMExpr expr) { return walk(expr); } + + /// Traverse the SDBM expression tree calling `visit` on each node in + /// depth-first postorder. + void walkPostorder(SDBMExpr expr) { return walk(expr); } + +protected: + /// Default visitors do nothing. + void visitSum(SDBMSumExpr) {} + void visitDiff(SDBMDiffExpr) {} + void visitStripe(SDBMStripeExpr) {} + void visitDim(SDBMDimExpr) {} + void visitSymbol(SDBMSymbolExpr) {} + void visitNeg(SDBMNegExpr) {} + void visitConstant(SDBMConstantExpr) {} + + /// Default implementation of visitPositive dispatches to the special + /// functions for stripes and other variables. Concrete visitors can override + /// it. 
+ Result visitPositive(SDBMPositiveExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::Stripe) + return derived->visitStripe(expr.cast()); + else + return derived->visitInput(expr.cast()); + } + + /// Default implementation of visitInput dispatches to the special + /// functions for dimensions or symbols. Concrete visitors can override it to + /// visit all variables instead. + Result visitInput(SDBMInputExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::DimId) + return derived->visitDim(expr.cast()); + else + return derived->visitSymbol(expr.cast()); + } + + /// Default implementation of visitVarying dispatches to the special + /// functions for variables and negations thereof. Concerete visitors can + /// override it to visit all variables and negations instead. + Result visitVarying(SDBMVaryingExpr expr) { + auto *derived = static_cast(this); + if (auto var = expr.dyn_cast()) + return derived->visitPositive(var); + else if (auto neg = expr.dyn_cast()) + return derived->visitNeg(neg); + else if (auto sum = expr.dyn_cast()) + return derived->visitSum(sum); + else if (auto diff = expr.dyn_cast()) + return derived->visitDiff(diff); + + llvm_unreachable("unhandled subtype of varying SDBM expression"); + } + + template void walk(SDBMExpr expr) { + if (isPreorder) + visit(expr); + if (auto sumExpr = expr.dyn_cast()) { + walk(sumExpr.getLHS()); + walk(sumExpr.getRHS()); + } else if (auto diffExpr = expr.dyn_cast()) { + walk(diffExpr.getLHS()); + walk(diffExpr.getRHS()); + } else if (auto stripeExpr = expr.dyn_cast()) { + walk(stripeExpr.getVar()); + walk(stripeExpr.getStripeFactor()); + } else if (auto negExpr = expr.dyn_cast()) { + walk(negExpr.getVar()); + } + if (!isPreorder) + visit(expr); + } +}; + +/// Overloaded arithmetic operators for SDBM expressions asserting that their +/// arguments have the proper SDBM expression subtype. Perform canonicalization +/// and constant folding on these expressions. +namespace ops_assertions { + +/// Add two SDBM expressions. At least one of the expressions must be a +/// constant or a negation, but both expressions cannot be negations +/// simultaneously. +SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs); +inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) { + return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs); +} +inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) { + return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs; +} + +/// Subtract an SDBM expression from another SDBM expression. Both expressions +/// must not be difference expressions. +SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs); +inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) { + return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs); +} +inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) { + return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs; +} + +/// Construct a stripe expression from a positive expression and a positive +/// constant stripe factor. +SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor); +inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) { + return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor)); +} +} // namespace ops_assertions + +} // end namespace mlir + +namespace llvm { +// SDBMExpr hash just like pointers. 
+template <> struct DenseMapInfo { + static mlir::SDBMExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMExpr(static_cast(pointer)); + } + static mlir::SDBMExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMExpr(static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMVaryingExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMVaryingExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMVaryingExpr( + static_cast(pointer)); + } + static mlir::SDBMVaryingExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMVaryingExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMVaryingExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMPositiveExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMPositiveExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMPositiveExpr( + static_cast(pointer)); + } + static mlir::SDBMPositiveExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMPositiveExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMPositiveExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMConstantExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMConstantExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static mlir::SDBMConstantExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMConstantExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // MLIR_DIALECT_SDBM_SDBMEXPR_H diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt new file mode 100644 index 00000000000..670676f24db --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRStandardOpsIncGen) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h new file mode 100644 index 00000000000..3d2f34c40da --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -0,0 +1,363 @@ +//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with standard operations +// in the MLIR operation set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARDOPS_OPS_H +#define MLIR_DIALECT_STANDARDOPS_OPS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +class AffineMap; +class Builder; +class FuncOp; +class OpBuilder; + +class StandardOpsDialect : public Dialect { +public: + StandardOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "std"; } +}; + +/// The predicate indicates the type of the comparison to perform: +/// (in)equality; (un)signed less/greater than (or equal to). +enum class CmpIPredicate { + FirstValidValue, + // (In)equality comparisons. + EQ = FirstValidValue, + NE, + // Signed comparisons. + SLT, + SLE, + SGT, + SGE, + // Unsigned comparisons. + ULT, + ULE, + UGT, + UGE, + // Number of predicates. + NumPredicates +}; + +/// The predicate indicates the type of the comparison to perform: +/// (un)orderedness, (in)equality and less/greater than (or equal to) as +/// well as predicates that are always true or false. +enum class CmpFPredicate { + FirstValidValue, + // Always false + AlwaysFalse = FirstValidValue, + // Ordered comparisons + OEQ, + OGT, + OGE, + OLT, + OLE, + ONE, + // Both ordered + ORD, + // Unordered comparisons + UEQ, + UGT, + UGE, + ULT, + ULE, + UNE, + // Any unordered + UNO, + // Always true + AlwaysTrue, + // Number of predicates. + NumPredicates +}; + +#define GET_OP_CLASSES +#include "mlir/Dialect/StandardOps/Ops.h.inc" + +/// This is a refinement of the "constant" op for the case where it is +/// returning a float value of FloatType. +/// +/// %1 = "std.constant"(){value: 42.0} : bf16 +/// +class ConstantFloatOp : public ConstantOp { +public: + using ConstantOp::ConstantOp; + + /// Builds a constant float op producing a float of the specified type. + static void build(Builder *builder, OperationState *result, + const APFloat &value, FloatType type); + + APFloat getValue() { return getAttrOfType("value").getValue(); } + + static bool classof(Operation *op); +}; + +/// This is a refinement of the "constant" op for the case where it is +/// returning an integer value of IntegerType. +/// +/// %1 = "std.constant"(){value: 42} : i32 +/// +class ConstantIntOp : public ConstantOp { +public: + using ConstantOp::ConstantOp; + /// Build a constant int op producing an integer of the specified width. + static void build(Builder *builder, OperationState *result, int64_t value, + unsigned width); + + /// Build a constant int op producing an integer with the specified type, + /// which must be an integer type. 
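+  /// A usage sketch (illustrative; assumes an in-scope Builder `builder` and
+  /// OperationState `state`):
+  ///   ConstantIntOp::build(&builder, &state, /*value=*/1,
+  ///                        builder.getIntegerType(32));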
+  static void build(Builder *builder, OperationState *result, int64_t value,
+                    Type type);
+
+  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
+
+  static bool classof(Operation *op);
+};
+
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value of Index type.
+///
+///   %1 = "std.constant"(){value: 99} : () -> index
+///
+class ConstantIndexOp : public ConstantOp {
+public:
+  using ConstantOp::ConstantOp;
+
+  /// Build a constant int op producing an index.
+  static void build(Builder *builder, OperationState *result, int64_t value);
+
+  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
+
+  static bool classof(Operation *op);
+};
+
+// DmaStartOp starts a non-blocking DMA operation that transfers data from a
+// source memref to a destination memref. The source and destination memrefs
+// need not be of the same dimensionality, but need to have the same elemental
+// type. The operands include the source and destination memrefs, each followed
+// by its indices, the size of the data transfer in terms of the number of
+// elements (of the elemental type of the memref), a tag memref with its
+// indices, and optionally, at the end, stride and
+// number_of_elements_per_stride arguments. The tag location is used by a
+// DmaWaitOp to check for completion. The indices of the source memref,
+// destination memref, and the tag memref have the same restrictions as any
+// load/store. The optional stride arguments should be of 'index' type, and
+// specify a stride for the slower memory space (memory space with a lower
+// memory space id), transferring chunks of number_of_elements_per_stride
+// every stride until %num_elements are transferred. Either both or no stride
+// arguments should be specified.
+//
+// For example, a DmaStartOp operation that transfers 256 elements of a memref
+// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
+// space 1 at indices [%k, %l], would be specified as follows:
+//
+//   %num_elements = constant 256
+//   %idx = constant 0 : index
+//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+//     memref<40 x 128 x f32>, (d0) -> (d0), 0>,
+//     memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
+//     memref<1 x i32>, (d0) -> (d0), 2>
+//
+// If %stride and %num_elt_per_stride are specified, the DMA is expected to
+// transfer %num_elt_per_stride elements every %stride elements apart from
+// memory space 0 until %num_elements are transferred.
+//
+//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+//             %num_elt_per_stride :
+//
+// TODO(mlir-team): add additional operands to allow source and destination
+// striding, and multiple stride levels.
+// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
+class DmaStartOp
+    : public Op {
+public:
+  using Op::Op;
+
+  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
+                    ArrayRef<Value *> srcIndices, Value *destMemRef,
+                    ArrayRef<Value *> destIndices, Value *numElements,
+                    Value *tagMemRef, ArrayRef<Value *> tagIndices,
+                    Value *stride = nullptr,
+                    Value *elementsPerStride = nullptr);
+
+  // Returns the source MemRefType for this DMA operation.
+  Value *getSrcMemRef() { return getOperand(0); }
+  // Returns the rank (number of indices) of the source MemRefType.
+  unsigned getSrcMemRefRank() {
+    return getSrcMemRef()->getType().cast<MemRefType>().getRank();
+  }
+  // Returns the source memref indices for this DMA operation.
+ operand_range getSrcIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; + } + + // Returns the destination MemRefType for this DMA operations. + Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } + // Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef()->getType().cast().getRank(); + } + unsigned getSrcMemorySpace() { + return getSrcMemRef()->getType().cast().getMemorySpace(); + } + unsigned getDstMemorySpace() { + return getDstMemRef()->getType().cast().getMemorySpace(); + } + + // Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank()}; + } + + // Returns the number of elements being transferred by this DMA operation. + Value *getNumElements() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); + } + + // Returns the Tag MemRef for this DMA operation. + Value *getTagMemRef() { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); + } + // Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + // Returns the tag memref index for this DMA operation. + operand_range getTagIndices() { + unsigned tagIndexStartPos = + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; + return {getOperation()->operand_begin() + tagIndexStartPos, + getOperation()->operand_begin() + tagIndexStartPos + + getTagMemRefRank()}; + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. + bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; + } + + static StringRef getOperationName() { return "std.dma_start"; } + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + + bool isStrided() { + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + + 1 + 1 + getTagMemRefRank(); + } + + Value *getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + + Value *getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } +}; + +// DmaWaitOp blocks until the completion of a DMA operation associated with the +// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index +// with the same restrictions as any load/store index. %num_elements is the +// number of elements associated with the DMA operation. 
For example: +// +// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : +// memref<2048 x f32>, (d0) -> (d0), 0>, +// memref<256 x f32>, (d0) -> (d0), 1> +// memref<1 x i32>, (d0) -> (d0), 2> +// ... +// ... +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> +// +class DmaWaitOp + : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState *result, Value *tagMemRef, + ArrayRef tagIndices, Value *numElements); + + static StringRef getOperationName() { return "std.dma_wait"; } + + // Returns the Tag MemRef associated with the DMA operation being waited on. + Value *getTagMemRef() { return getOperand(0); } + + // Returns the tag memref index for this DMA operation. + operand_range getTagIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getTagMemRefRank()}; + } + + // Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + // Returns the number of elements transferred in the associated DMA operation. + Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); } + + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// Prints dimension and symbol list. +void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, unsigned numDims, + OpAsmPrinter *p); + +/// Parses dimension and symbol list and returns true if parsing failed. +ParseResult parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims); + +} // end namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_OPS_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td new file mode 100644 index 00000000000..b6bf2cfb40b --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -0,0 +1,905 @@ +//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines some MLIR standard operations. +// +//===----------------------------------------------------------------------===// + +#ifdef STANDARD_OPS +#else +#define STANDARD_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def Std_Dialect : Dialect { + let name = "std"; + let cppNamespace = ""; +} + +// Base class for Standard dialect ops. +class Std_Op traits = []> : + Op { + // For every standard op, there needs to be a: + // * void print(OpAsmPrinter *p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser *parser, + // OperationState *result) + // functions. 
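+  //
+  // For instance (illustrative), for an op whose C++ class is AddFOp, the
+  // parser hook below expands to ::parseAddFOp(parser, result) via the
+  // $cppClass substitution.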
+ let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// Base class for standard cast operations. Requires single operand and result, +// but does not constrain them to specific types. +class CastOp traits = []> : + Std_Op { + + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *source, Type destType", [{ + impl::buildCastOp(builder, result, source, destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; + let verifier = [{ return ::verifyCastOp(*this); }]; + + let hasFolder = 1; +} + +// Base class for standard arithmetic operations. Requires operands and +// results to be of the same type, but does not constrain them to specific +// types. Individual classes will have `lhs` and `rhs` accessor to operands. +class ArithmeticOp traits = []> : + Op { + + let results = (outs AnyType); + + let parser = [{ + return impl::parseBinaryOp(parser, result); + }]; + + let printer = [{ + return printStandardBinaryOp(this->getOperation(), p); + }]; +} + +// Base class for standard arithmetic operations on integers, vectors and +// tensors thereof. This operation takes two operands and returns one result, +// each of these is required to be of the same type. This type may be an +// integer scalar type, a vector whose element type is an integer type, or an +// integer tensor. The custom assembly form of the operaton is as follows +// +// i %0, %1 : i32 +class IntArithmeticOp traits = []> : + ArithmeticOp, + Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>; + +// Base class for standard arithmetic binary operations on floats, vectors and +// tensors thereof. This operation has two operands and returns one result, +// each of these is required to be of the same type. This type may be a +// floating point scalar type, a vector whose element type is a floating point +// type, or a floating point tensor. The custom assembly form of the operation +// is as follows +// +// f %0, %1 : f32 +class FloatArithmeticOp traits = []> : + ArithmeticOp, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; + +def AddFOp : FloatArithmeticOp<"addf"> { + let summary = "floating point addition operation"; + let hasFolder = 1; +} + +def AddIOp : IntArithmeticOp<"addi", [Commutative]> { + let summary = "integer addition operation"; + let hasFolder = 1; +} + +def AllocOp : Std_Op<"alloc"> { + let summary = "memory allocation operation"; + let description = [{ + The "alloc" operation allocates a region of memory, as specified by its + memref type. For example: + + %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + + The optional list of dimension operands are bound to the dynamic dimensions + specified in its memref type. In the example below, the ssa value '%d' is + bound to the second dimension of the memref (which is dynamic). + + %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> + + The optional list of symbol operands are bound to the symbols of the + memrefs affine map. In the example below, the ssa value '%s' is bound to + the symbol 's0' in the affine map specified in the allocs memref type. + + %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> + + This operation returns a single ssa value of memref type, which can be used + by subsequent load and store operations. 
+ }]; + + let arguments = (ins Variadic:$value); + let results = (outs AnyMemRef); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, MemRefType memrefType", [{ + result->types.push_back(memrefType); + }] + >]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult()->getType().cast(); } + }]; + + let hasCanonicalizer = 1; +} + +def AndOp : IntArithmeticOp<"and", [Commutative]> { + let summary = "integer binary and"; + let hasFolder = 1; +} + +def BranchOp : Std_Op<"br", [Terminator]> { + let summary = "branch operation"; + let description = [{ + The "br" operation represents a branch operation in a function. + The operation takes variable number of operands and produces no results. + The operand number and types for each successor must match the arguments of + the block successor. For example: + + ^bb2: + %2 = call @someFn() + br ^bb3(%2 : tensor<*xf32>) + ^bb3(%3: tensor<*xf32>): + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Block *dest," + "ArrayRef operands = {}", [{ + result->addSuccessor(dest, operands); + }]>]; + + // BranchOp is fully verified by traits. + let verifier = ?; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; +} + +def CallOp : Std_Op<"call"> { + let summary = "call operation"; + let description = [{ + The "call" operation represents a direct call to a function. The operands + and result types of the call must match the specified function type. The + callee is encoded as a function attribute named "callee". + + %2 = call @my_add(%0, %1) : (f32, f32) -> f32 + }]; + + let arguments = (ins SymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, FuncOp callee," + "ArrayRef operands = {}", [{ + result->addOperands(operands); + result->addAttribute("callee", builder->getSymbolRefAttr(callee)); + result->addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState *result, StringRef callee," + "ArrayRef results, ArrayRef operands = {}", [{ + result->addOperands(operands); + result->addAttribute("callee", builder->getSymbolRefAttr(callee)); + result->addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; +} + +def CallIndirectOp : Std_Op<"call_indirect"> { + let summary = "indirect call operation"; + let description = [{ + The "call_indirect" operation represents an indirect call to a value of + function type. Functions are first class types in MLIR, and may be passed + as arguments and merged together with block arguments. The operands + and result types of the call must match the specified function type. 
+ + %3 = call_indirect %2(%0, %1) : (f32, f32) -> f32 + }]; + + let arguments = (ins FunctionType:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *callee," + "ArrayRef operands = {}", [{ + result->operands.push_back(callee); + result->addOperands(operands); + result->addTypes(callee->getType().cast().getResults()); + }]>]; + + let extraClassDeclaration = [{ + Value *getCallee() { return getOperand(0); } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let hasCanonicalizer = 1; +} + +def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { + let summary = "integer comparison operation"; + let description = [{ + The "cmpi" operation compares its two operands according to the integer + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (in)equality, (un)signed + less/greater than (or equal to). The operands must have the same type, and + this type must be an integer type, a vector or a tensor thereof. The result + is an i1, or a vector/tensor thereof having the same shape as the inputs. + Since integers are signless, the predicate also explicitly indicates + whether to interpret the operands as signed or unsigned integers for + less/greater than comparisons. For the sake of readability by humans, + custom assembly form for the operation uses a string-typed attribute for + the predicate. The value of this attribute corresponds to lower-cased name + of the predicate constant, e.g., "slt" means "signed less than". The string + representation of the attribute is merely a syntactic sugar and is converted + to an integer attribute by the parser. + + %r1 = cmpi "eq" %0, %1 : i32 + %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> + %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 + }]; + + let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, CmpIPredicate predicate," + "Value *lhs, Value *rhs", [{ + ::buildCmpIOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpIPredicate getPredicateByName(StringRef name); + + CmpIPredicate getPredicate() { + return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let hasFolder = 1; +} + +def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { + let summary = "floating-point comparison operation"; + let description = [{ + The "cmpf" operation compares its two operands according to the float + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (un)orderedness, (in)equality + and signed less/greater than (or equal to) as well as predicates that are + always true or false. The operands must have the same type, and this type + must be a float type, or a vector or tensor thereof. The result is an i1, + or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, + the operands are always treated as signed. 
The u prefix indicates + *unordered* comparison, not unsigned comparison, so "une" means unordered or + not equal. For the sake of readability by humans, custom assembly form for + the operation uses a string-typed attribute for the predicate. The value of + this attribute corresponds to lower-cased name of the predicate constant, + e.g., "one" means "ordered not equal". The string representation of the + attribute is merely a syntactic sugar and is converted to an integer + attribute by the parser. + + %r1 = cmpf "oeq" %0, %1 : f32 + %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> + %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 + }]; + + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, CmpFPredicate predicate," + "Value *lhs, Value *rhs", [{ + ::buildCmpFOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpFPredicate getPredicateByName(StringRef name); + + CmpFPredicate getPredicate() { + return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let hasFolder = 1; +} + +def CondBranchOp : Std_Op<"cond_br", [Terminator]> { + let summary = "conditional branch operation"; + let description = [{ + The "cond_br" operation represents a conditional branch operation in a + function. The operation takes variable number of operands and produces + no results. The operand number and types for each successor must match the + arguments of the block successor. For example: + + ^bb0: + %0 = extract_element %arg0[] : tensor + cond_br %0, ^bb1, ^bb2 + ^bb1: + ... + ^bb2: + ... + }]; + + let arguments = (ins I1:$condition, Variadic:$branchOperands); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *condition," + "Block *trueDest, ArrayRef trueOperands," + "Block *falseDest, ArrayRef falseOperands", [{ + result->addOperands(condition); + result->addSuccessor(trueDest, trueOperands); + result->addSuccessor(falseDest, falseOperands); + }]>]; + + // CondBranchOp is fully verified by traits. + let verifier = ?; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. 
+ void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let hasCanonicalizer = 1; +} + +def ConstantOp : Std_Op<"constant", [NoSideEffect]> { + let summary = "constant"; + + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Attribute value", + [{ build(builder, result, value.getType(), value); }]>]; + + let extraClassDeclaration = [{ + Attribute getValue() { return getAttr("value"); } + + /// Returns true if a constant operation can be built with the given value + /// and result type. + static bool isBuildableWith(Attribute value, Type type); + }]; + + let hasFolder = 1; +} + +def DeallocOp : Std_Op<"dealloc"> { + let summary = "memory deallocation operation"; + let description = [{ + The "dealloc" operation frees the region of memory referenced by a memref + which was originally created by the "alloc" operation. + The "dealloc" operation should not be called on memrefs which alias an + alloc'd memref (i.e. memrefs returned by the "view" and "reshape" + operations). + + %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + }]; + + let arguments = (ins AnyMemRef:$memref); + + let hasCanonicalizer = 1; +} + +def DimOp : Std_Op<"dim", [NoSideEffect]> { + let summary = "dimension index operation"; + let description = [{ + The "dim" operation takes a memref or tensor operand and returns an "index". + It requires a single integer attribute named "index". It returns the size + of the specified dimension. 
For example: + + %1 = dim %0, 2 : tensor + }]; + + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor], + "any tensor or memref type">:$memrefOrTensor, + APIntAttr:$index); + let results = (outs Index); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *memrefOrTensor," + "unsigned index", [{ + auto indexType = builder->getIndexType(); + auto indexAttr = builder->getIntegerAttr(indexType, index); + build(builder, result, indexType, memrefOrTensor, indexAttr); + }]>]; + + let extraClassDeclaration = [{ + unsigned getIndex() { + return getAttrOfType("index").getValue().getZExtValue(); + } + }]; + + let hasFolder = 1; +} + +def DivFOp : FloatArithmeticOp<"divf"> { + let summary = "floating point division operation"; +} + +def DivISOp : IntArithmeticOp<"divis"> { + let summary = "signed integer division operation"; + let hasFolder = 1; +} + +def DivIUOp : IntArithmeticOp<"diviu"> { + let summary = "unsigned integer division operation"; + let hasFolder = 1; +} + +def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { + let summary = "element extract operation"; + let description = [{ + The "extract_element" op reads a tensor or vector and returns one element + from it specified by an index list. The output of extract is a new value + with the same type as the elements of the tensor or vector. The arity of + indices matches the rank of the accessed value (i.e., if a tensor is of rank + 3, then 3 indices are required for the extract). The indices should all be + of affine_int type. For example: + + %0 = extract_element %0[%1, %2] : vector<4x4xi32> + }]; + + let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, + Variadic:$indices); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *aggregate," + "ArrayRef indices = {}", [{ + auto resType = aggregate->getType().cast() + .getElementType(); + build(builder, result, resType, aggregate, indices); + }]>]; + + let extraClassDeclaration = [{ + Value *getAggregate() { return getOperand(0); } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_end()}; + } + }]; + + let hasFolder = 1; +} + +def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { + let summary = "cast between index and integer types"; + let description = [{ + Casts between integer scalars and 'index' scalars. Index is an integer of + platform-specific bit width. If casting to a wider integer, the value is + sign-extended. If casting to a narrower integer, the value is truncated. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { + let summary = "cast from integer type to floating-point"; + let description = [{ + Cast from a value interpreted as signed integer to the corresponding + floating-point value. If the value cannot be exactly represented, it is + rounded using the default rounding mode. Only scalars are currently + supported. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. 
+ static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 0; +} + +def LoadOp : Std_Op<"load"> { + let summary = "load operation"; + let description = [{ + The "load" op reads an element from a memref specified by an index list. The + output of load is a new value with the same type as the elements of the + memref. The arity of indices is the rank of the memref (i.e., if the memref + loaded from is of rank 3, then 3 indices are required for the load following + the memref identifier). For example: + + %3 = load %0[%1, %1] : memref<4x4xi32> + }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *memref," + "ArrayRef indices = {}", [{ + auto memrefType = memref->getType().cast(); + result->addOperands(memref); + result->addOperands(indices); + result->types.push_back(memrefType.getElementType()); + }]>]; + + let extraClassDeclaration = [{ + Value *getMemRef() { return getOperand(0); } + void setMemRef(Value *value) { setOperand(0, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; + } + }]; + + let hasCanonicalizer = 1; +} + +def MemRefCastOp : CastOp<"memref_cast"> { + let summary = "memref cast operation"; + let description = [{ + The "memref_cast" operation converts a memref from one type to an equivalent + type with a compatible shape. The source and destination types are + when both are memref types with the same element type, affine mappings, + address space, and rank but where the individual dimensions may add or + remove constant dimensions from the memref type. + + If the cast converts any dimensions from an unknown to a known size, then it + acts as an assertion that fails at runtime of the dynamic dimensions + disagree with resultant destination size. + + Assert that the input dynamic shape matches the destination static shape. + %2 = memref_cast %1 : memref to memref<4x4xf32> + Erase static shape information, replacing it with dynamic information. + %3 = memref_cast %1 : memref<4xf32> to memref + }]; + + let arguments = (ins AnyMemRef:$source); + let results = (outs AnyMemRef); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a memref_cast is always a memref. + MemRefType getType() { return getResult()->getType().cast(); } + }]; +} + +def MulFOp : FloatArithmeticOp<"mulf"> { + let summary = "foating point multiplication operation"; + let hasFolder = 1; +} + +def MulIOp : IntArithmeticOp<"muli", [Commutative]> { + let summary = "integer multiplication operation"; + let hasFolder = 1; +} + +def OrOp : IntArithmeticOp<"or", [Commutative]> { + let summary = "integer binary or"; + let hasFolder = 1; +} + +def RankOp : Std_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The "rank" operation takes a tensor operand and returns its rank. 
+ + %1 = rank %0 : index + }]; + + let arguments = (ins AnyTensor); + let results = (outs Index); + let verifier = ?; + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *tensor", [{ + auto indexType = builder->getIndexType(); + build(builder, result, indexType, tensor); + }]>]; + + let hasFolder = 1; +} + +def RemFOp : FloatArithmeticOp<"remf"> { + let summary = "floating point division remainder operation"; +} + +def RemISOp : IntArithmeticOp<"remis"> { + let summary = "signed integer division remainder operation"; + let hasFolder = 1; +} + +def RemIUOp : IntArithmeticOp<"remiu"> { + let summary = "unsigned integer division remainder operation"; + let hasFolder = 1; +} + +def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. For example: + + func @foo() : (i32, f8) { + ... + return %0, %1 : i32, f8 + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState *result", [{ build(b, result, llvm::None); }] + >]; +} + +def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "select operation"; + let description = [{ + The "select" operation chooses one value based on a binary condition + supplied as its first operand. If the value of the first operand is 1, the + second operand is chosen, otherwise the third operand is chosen. The second + and the third operand must have the same type. The operation applies + elementwise to vectors and tensors. The shape of all arguments must be + identical. For example, the maximum operation is obtained by combining + "select" with "cmpi" as follows. + + %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 + %3 = select %2, %0, %1 : i32 + }]; + + let arguments = (ins BoolLike:$condition, AnyType:$true_value, + AnyType:$false_value); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *condition," + "Value *trueValue, Value *falseValue", [{ + result->addOperands({condition, trueValue, falseValue}); + result->addTypes(trueValue->getType()); + }]>]; + + let extraClassDeclaration = [{ + Value *getCondition() { return condition(); } + Value *getTrueValue() { return true_value(); } + Value *getFalseValue() { return false_value(); } + }]; + + let hasFolder = 1; +} +def ShlISOp : IntArithmeticOp<"shlis"> { + let summary = "signed integer shift left"; +} + +def SubFOp : FloatArithmeticOp<"subf"> { + let summary = "floating point subtraction operation"; + let hasFolder = 1; +} + +def SubIOp : IntArithmeticOp<"subi"> { + let summary = "integer subtraction operation"; + let hasFolder = 1; +} + +def StoreOp : Std_Op<"store"> { + let summary = "store operation"; + let description = [{ + The "store" op writes an element to a memref specified by an index list. + The arity of indices is the rank of the memref (i.e. if the memref being + stored to is of rank 3, then 3 indices are required for the store following + the memref identifier). The store operation does not produce a result. 
+ + In the following example, the ssa value '%v' is stored in memref '%A' at + indices [%i, %j]: + store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> + }]; + + let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic:$indices); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{ + result->addOperands(valueToStore); + result->addOperands(memref); + }]>]; + + let extraClassDeclaration = [{ + Value *getValueToStore() { return getOperand(0); } + + Value *getMemRef() { return getOperand(1); } + void setMemRef(Value *value) { setOperand(1, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; + } + }]; + + let hasCanonicalizer = 1; +} + +def TensorCastOp : CastOp<"tensor_cast"> { + let summary = "tensor cast operation"; + let description = [{ + The "tensor_cast" operation converts a tensor from one type to an equivalent + type without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + + Convert from unknown rank to rank 2 with unknown dimension sizes. + %2 = tensor_cast %1 : tensor to tensor + }]; + + let arguments = (ins AnyTensor); + let results = (outs AnyTensor); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a tensor_cast is always a tensor. + TensorType getType() { return getResult()->getType().cast(); } + }]; +} + +def XOrOp : IntArithmeticOp<"xor", [Commutative]> { + let summary = "integer binary xor"; + let hasFolder = 1; +} + +#endif // STANDARD_OPS diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index c1df3cfa42e..c4728743f31 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -24,8 +24,8 @@ #define MLIR_EDSC_BUILDERS_H_ #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/VectorOps/VectorOps.h" diff --git a/mlir/include/mlir/LLVMIR/CMakeLists.txt b/mlir/include/mlir/LLVMIR/CMakeLists.txt deleted file mode 100644 index 1d7d06bc25c..00000000000 --- a/mlir/include/mlir/LLVMIR/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS LLVMOps.td) -mlir_tablegen(LLVMOps.h.inc -gen-op-decls) -mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) -mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRLLVMOpsIncGen) -set(LLVM_TARGET_DEFINITIONS NVVMOps.td) -mlir_tablegen(NVVMOps.h.inc -gen-op-decls) -mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRNVVMOpsIncGen) -set(LLVM_TARGET_DEFINITIONS LLVMOps.td) -mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRLLVMConversionsIncGen) -set(LLVM_TARGET_DEFINITIONS NVVMOps.td) -mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRNVVMConversionsIncGen) diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h 
deleted file mode 100644 index 00f5be4d8d6..00000000000 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ /dev/null @@ -1,180 +0,0 @@ -//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and -// LLVM type system. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TARGET_LLVMDIALECT_H_ -#define MLIR_TARGET_LLVMDIALECT_H_ - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" - -#include "mlir/LLVMIR/LLVMOpsEnums.h.inc" - -namespace llvm { -class Type; -class LLVMContext; -} // end namespace llvm - -namespace mlir { -namespace LLVM { -class LLVMDialect; - -namespace detail { -struct LLVMTypeStorage; -struct LLVMDialectImpl; -} // namespace detail - -class LLVMType : public mlir::Type::TypeBase { -public: - enum Kind { - LLVM_TYPE = FIRST_LLVM_TYPE, - }; - - using Base::Base; - - static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } - - LLVMDialect &getDialect(); - llvm::Type *getUnderlyingType() const; - - /// Array type utilities. - LLVMType getArrayElementType(); - unsigned getArrayNumElements(); - - /// Vector type utilities. - LLVMType getVectorElementType(); - - /// Function type utilities. - LLVMType getFunctionParamType(unsigned argIdx); - unsigned getFunctionNumParams(); - LLVMType getFunctionResultType(); - - /// Pointer type utilities. - LLVMType getPointerTo(unsigned addrSpace = 0); - LLVMType getPointerElementTy(); - - /// Struct type utilities. - LLVMType getStructElementType(unsigned i); - - /// Utilities used to generate floating point types. - static LLVMType getDoubleTy(LLVMDialect *dialect); - static LLVMType getFloatTy(LLVMDialect *dialect); - static LLVMType getHalfTy(LLVMDialect *dialect); - - /// Utilities used to generate integer types. - static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits); - static LLVMType getInt1Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/1); - } - static LLVMType getInt8Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/8); - } - static LLVMType getInt8PtrTy(LLVMDialect *dialect) { - return getInt8Ty(dialect).getPointerTo(); - } - static LLVMType getInt16Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/16); - } - static LLVMType getInt32Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/32); - } - static LLVMType getInt64Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/64); - } - - /// Utilities used to generate other miscellaneous types. 
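/// For instance (an illustrative sketch, not lines from the original header),
/// a function type taking two 32-bit integers and returning void could be
/// assembled from the utilities declared here and just below:
///   LLVMType i32Ty = LLVMType::getInt32Ty(dialect);
///   LLVMType fnTy  = LLVMType::getFunctionTy(LLVMType::getVoidTy(dialect),
///                                            {i32Ty, i32Ty},
///                                            /*isVarArg=*/false);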
- static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements); - static LLVMType getFunctionTy(LLVMType result, ArrayRef params, - bool isVarArg); - static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { - return getFunctionTy(result, llvm::None, isVarArg); - } - static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef elements, - bool isPacked = false); - static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) { - return getStructTy(dialect, llvm::None, isPacked); - } - template - static typename std::enable_if::value, - LLVMType>::type - getStructTy(LLVMType elt1, Args... elts) { - SmallVector fields({elt1, elts...}); - return getStructTy(&elt1.getDialect(), fields); - } - static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); - static LLVMType getVoidTy(LLVMDialect *dialect); - -private: - friend LLVMDialect; - - /// Get an LLVMType with a pre-existing llvm type. - static LLVMType get(MLIRContext *context, llvm::Type *llvmType); - - /// Get an LLVMType with an llvm type that may cause changes to the underlying - /// llvm context when constructed. - static LLVMType getLocked(LLVMDialect *dialect, - llvm::function_ref typeBuilder); -}; - -///// Ops ///// -#define GET_OP_CLASSES -#include "mlir/LLVMIR/LLVMOps.h.inc" - -class LLVMDialect : public Dialect { -public: - explicit LLVMDialect(MLIRContext *context); - ~LLVMDialect(); - static StringRef getDialectNamespace() { return "llvm"; } - - llvm::LLVMContext &getLLVMContext(); - llvm::Module &getLLVMModule(); - - /// Parse a type registered to this dialect. - Type parseType(StringRef tyData, Location loc) const override; - - /// Print a type registered to this dialect. - void printType(Type type, raw_ostream &os) const override; - - /// Verify a region argument attribute registered to this dialect. - /// Returns failure if the verification failed, success otherwise. - LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx, - unsigned argIdx, - NamedAttribute argAttr) override; - -private: - friend LLVMType; - - std::unique_ptr impl; -}; - -} // end namespace LLVM -} // end namespace mlir - -#endif // MLIR_TARGET_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/LLVMIR/LLVMOpBase.td deleted file mode 100644 index a68cdbf3da0..00000000000 --- a/mlir/include/mlir/LLVMIR/LLVMOpBase.td +++ /dev/null @@ -1,59 +0,0 @@ -//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file contains shared definitions for the LLVM IR dialect and its -// subdialects. 
-// -//===----------------------------------------------------------------------===// - -#ifdef LLVMIR_OP_BASE -#else -#define LLVMIR_OP_BASE - -#ifdef OP_BASE -#else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -def LLVM_Dialect : Dialect { - let name = "llvm"; - let cppNamespace = "LLVM"; -} - -// LLVM IR type wrapped in MLIR. -def LLVM_Type : Type()">, - "LLVM dialect type">; - -// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder -// used to translate to LLVM IR proper. -class LLVM_OpBase traits = []> : - Op { - // A pattern for constructing the LLVM IR Instruction (or other Value) that - // corresponds to this op. This pattern can use `builder` to refer to an - // `llvm::IRBuilder<>` instance, $-names of arguments and results and the - // following special variable names: - // - $_resultType - substituted with the LLVM IR type of the result; - // - $_numOperands - substituted with the number of operands (including - // the variadic ones); - // - $_hasResult - substituted with a check that a variadic-result op does - // have a result (LLVM ops can have 0 or 1 result); - // - $_location - mlir::Location object of the instruction. - // Additionally, `$$` can be used to produce the dollar character. - string llvmBuilder = ""; -} - -#endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td deleted file mode 100644 index cf456614442..00000000000 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ /dev/null @@ -1,553 +0,0 @@ -//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This is the LLVM IR operation definition file. -// -//===----------------------------------------------------------------------===// - -#ifdef LLVMIR_OPS -#else -#define LLVMIR_OPS - -include "mlir/LLVMIR/LLVMOpBase.td" - -// Base class for LLVM operations. All operations get an "llvm." prefix in -// their name automatically. LLVM operations have either zero or one result, -// this class is specialized below for both cases and should not be used -// directly. 
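// For instance (illustrative only, hypothetical op name), a concrete operation
// is normally declared through the zero- or one-result specializations defined
// below rather than through this class directly:
//   def LLVM_FooOp : LLVM_OneResultOp<"foo", [NoSideEffect]>,
//                    Arguments<(ins LLVM_Type:$arg)>;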
-class LLVM_Op traits = []> : - LLVM_OpBase { -} - -class LLVM_Builder { - string llvmBuilder = builder; -} - -def LLVM_OneResultOpBuilder : OpBuilder< - "Builder *, OperationState *result, Type resultType, " - "ArrayRef operands, ArrayRef attributes = {}", - [{ - if (resultType) result->addTypes(resultType); - result->addOperands(operands); - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - }]>; - -def LLVM_ZeroResultOpBuilder : OpBuilder< - "Builder *, OperationState *result, ArrayRef operands, " - "ArrayRef attributes = {}", - [{ - result->addOperands(operands); - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - }]>; - -class LLVM_TwoBuilders { - list builders = [b1, b2]; -} - -// Base class for LLVM operations with one result. -class LLVM_OneResultOp traits = []> : - LLVM_Op, Results<(outs LLVM_Type:$res)> { - let builders = [LLVM_OneResultOpBuilder]; -} - -// Compatibility builder that takes an instance of wrapped llvm::VoidType -// to indicate no result. -def LLVM_VoidResultTypeOpBuilder : OpBuilder< - "Builder *builder, OperationState *result, Type resultType, " - "ArrayRef operands, ArrayRef attributes = {}", - [{ - auto llvmType = resultType.dyn_cast(); (void)llvmType; - assert(llvmType && "result must be an LLVM type"); - assert(llvmType.getUnderlyingType() && - llvmType.getUnderlyingType()->isVoidTy() && - "for zero-result operands, only 'void' is accepted as result type"); - build(builder, result, operands, attributes); - }]>; - -// Base class for LLVM operations with zero results. -class LLVM_ZeroResultOp traits = []> : - LLVM_Op, Results<(outs)>, - LLVM_TwoBuilders; - -// Base class for LLVM terminator operations. All terminator operations have -// zero results and an optional list of successors. -class LLVM_TerminatorOp traits = []> : - LLVM_Op, - Arguments<(ins Variadic:$args)>, Results<(outs)> { - let builders = [OpBuilder< - "Builder *, OperationState *result, " - "ArrayRef properOperands, " - "ArrayRef destinations, " - "ArrayRef> operands = {}, " - "ArrayRef attributes = {}", - [{ - result->addOperands(properOperands); - for (auto kvp : llvm::zip(destinations, operands)) { - result->addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); - } - for (auto namedAttr : attributes) { - result->addAttribute(namedAttr.first, namedAttr.second); - } - }] - >]; -} - -// Class for arithmetic binary operations. -class LLVM_ArithmeticOp traits = []> : - LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>, - LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { - let parser = [{ return impl::parseBinaryOp(parser, result); }]; - let printer = [{ mlir::impl::printBinaryOp(this->getOperation(), p); }]; -} - -// Integer binary operations. -def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>; -def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">; -def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>; -def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">; -def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">; -def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">; -def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; -def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; -def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; -def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; - -// Predicate for integer comparisons. 
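// In the custom assembly form the predicate appears as its lower-case case
// name, e.g. (illustrative) an "slt" comparison of two integer operands, while
// the attribute itself is stored as the corresponding integer constant listed
// below.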
-def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; -def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; -def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; -def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; -def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; -def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; -def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; -def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; -def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; -def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; -def ICmpPredicate : I64EnumAttr< - "ICmpPredicate", - "llvm.icmp comparison predicate", - [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, - ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, - ICmpPredicateUGT, ICmpPredicateUGE]> { - let cppNamespace = "mlir::LLVM"; - - let returnType = "ICmpPredicate"; - let convertFromStorage = - "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; -} - -// Other integer operations. -def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { - let llvmBuilder = [{ - $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); - }]; - let parser = [{ return parseCmpOp(parser, result); }]; - let printer = [{ printICmpOp(p, *this); }]; -} - -// Predicate for float comparisons -def FCmpPredicateFALSE : I64EnumAttrCase<"_false", 0>; -def FCmpPredicateOEQ : I64EnumAttrCase<"oeq", 1>; -def FCmpPredicateOGT : I64EnumAttrCase<"ogt", 2>; -def FCmpPredicateOGE : I64EnumAttrCase<"oge", 3>; -def FCmpPredicateOLT : I64EnumAttrCase<"olt", 4>; -def FCmpPredicateOLE : I64EnumAttrCase<"ole", 5>; -def FCmpPredicateONE : I64EnumAttrCase<"one", 6>; -def FCmpPredicateORD : I64EnumAttrCase<"ord", 7>; -def FCmpPredicateUEQ : I64EnumAttrCase<"ueq", 8>; -def FCmpPredicateUGT : I64EnumAttrCase<"ugt", 9>; -def FCmpPredicateUGE : I64EnumAttrCase<"uge", 10>; -def FCmpPredicateULT : I64EnumAttrCase<"ult", 11>; -def FCmpPredicateULE : I64EnumAttrCase<"ule", 12>; -def FCmpPredicateUNE : I64EnumAttrCase<"une", 13>; -def FCmpPredicateUNO : I64EnumAttrCase<"uno", 14>; -def FCmpPredicateTRUE : I64EnumAttrCase<"_true", 15>; - -def FCmpPredicate : I64EnumAttr< - "FCmpPredicate", - "llvm.fcmp comparison predicate", - [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE, - FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD, - FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT, - FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE - ]> { - let cppNamespace = "mlir::LLVM"; - - let returnType = "FCmpPredicate"; - let convertFromStorage = - "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; -} - -// Other integer operations. -def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, - Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { - let llvmBuilder = [{ - $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); - }]; - let parser = [{ return parseCmpOp(parser, result); }]; - let printer = [{ printFCmpOp(p, *this); }]; -} - -// Floating point binary operations. -def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">; -def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">; -def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">; -def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">; -def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; - -// Memory-related operations. 
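// The ops below mirror LLVM's alloca/getelementptr/load/store. As a reading
// aid (summary of the definitions that follow): llvm.alloca derives the
// allocated element type from its pointer result type, see
// $_resultType->getPointerElementType() in its builder, and applies the
// optional integer alignment attribute only when it is present and non-zero.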
-def LLVM_AllocaOp : - LLVM_OneResultOp<"alloca">, - Arguments<(ins LLVM_Type:$arraySize, OptionalAttr:$alignment)> { - string llvmBuilder = [{ - auto *alloca = builder.CreateAlloca( - $_resultType->getPointerElementType(), $arraySize); - if ($alignment.hasValue()) { - auto align = $alignment.getValue().getZExtValue(); - if (align != 0) - alloca->setAlignment(align); - } - $res = alloca; - }]; - let builders = [OpBuilder< - "Builder *b, OperationState *result, Type resultType, Value *arraySize, " - "unsigned alignment", - [{ - if (alignment == 0) - return build(b, result, resultType, arraySize, IntegerAttr()); - build(b, result, resultType, arraySize, b->getI64IntegerAttr(alignment)); - }]>]; - let parser = [{ return parseAllocaOp(parser, result); }]; - let printer = [{ printAllocaOp(p, *this); }]; - let verifier = [{ - if (alignment().hasValue()) { - auto align = alignment().getValue().getSExtValue(); - if (align < 0) - return emitOpError("expected positive alignment"); - } - return success(); - }]; -} -def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$base, Variadic:$indices)>, - LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { - let parser = [{ return parseGEPOp(parser, result); }]; - let printer = [{ printGEPOp(p, *this); }]; -} -def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, - LLVM_Builder<"$res = builder.CreateLoad($addr);"> { - let builders = [OpBuilder< - "Builder *b, OperationState *result, Value *addr", - [{ - auto type = addr->getType().cast().getPointerElementTy(); - build(b, result, type, addr); - }]>]; - let parser = [{ return parseLoadOp(parser, result); }]; - let printer = [{ printLoadOp(p, *this); }]; -} -def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, - Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>, - LLVM_Builder<"builder.CreateStore($value, $addr);"> { - let parser = [{ return parseStoreOp(parser, result); }]; - let printer = [{ printStoreOp(p, *this); }]; -} - -// Casts. -class LLVM_CastOp traits = []> : - LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$arg)>, - LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> { - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; -} -def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">; -def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">; -def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">; -def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">; -def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">; -def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">; -def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">; - -// Call-related operations. 
-def LLVM_CallOp : LLVM_Op<"call">, - Arguments<(ins OptionalAttr:$callee, - // TODO(b/133216756): fix test failure and - // change to LLVM_Type - Variadic)>, - Results<(outs Variadic)>, - LLVM_TwoBuilders { - let verifier = [{ - if (getNumResults() > 1) - return emitOpError("must have 0 or 1 result"); - return success(); - }]; - let parser = [{ return parseCallOp(parser, result); }]; - let printer = [{ printCallOp(p, *this); }]; -} -def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, - LLVM_Type:$position)> { - string llvmBuilder = [{ - $res = builder.CreateExtractElement($vector, $position); - }]; - let builders = [OpBuilder< - "Builder *b, OperationState *result, Value *vector, Value *position," - "ArrayRef attrs = {}">]; - let parser = [{ return parseExtractElementOp(parser, result); }]; - let printer = [{ printExtractElementOp(p, *this); }]; -} -def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, - ArrayAttr:$position)> { - string llvmBuilder = [{ - $res = builder.CreateExtractValue($container, extractPosition($position)); - }]; - let parser = [{ return parseExtractValueOp(parser, result); }]; - let printer = [{ printExtractValueOp(p, *this); }]; -} -def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, - LLVM_Type:$position)> { - string llvmBuilder = [{ - $res = builder.CreateInsertElement($vector, $value, $position); - }]; - let parser = [{ return parseInsertElementOp(parser, result); }]; - let printer = [{ printInsertElementOp(p, *this); }]; -} -def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, - ArrayAttr:$position)> { - string llvmBuilder = [{ - $res = builder.CreateInsertValue($container, $value, - extractPosition($position)); - }]; - let builders = [OpBuilder< - "Builder *b, OperationState *result, Value *container, Value *value, " - "ArrayAttr position", - [{ - build(b, result, container->getType(), container, value, position); - }]>]; - let parser = [{ return parseInsertValueOp(parser, result); }]; - let printer = [{ printInsertValueOp(p, *this); }]; -} -def LLVM_ShuffleVectorOp - : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, I32ArrayAttr:$mask)>, - LLVM_Builder< - "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { - let builders = [OpBuilder< - "Builder *b, OperationState *result, Value *v1, Value *v2, " - "ArrayAttr mask, ArrayRef attrs = {}">]; - let verifier = [{ - auto wrappedVectorType1 = v1()->getType().cast(); - auto wrappedVectorType2 = v2()->getType().cast(); - if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) - return emitOpError("expected LLVM IR Dialect vector type for operand #2"); - if (wrappedVectorType1.getVectorElementType() != - wrappedVectorType2.getVectorElementType()) - return emitOpError("expected matching LLVM IR Dialect element types"); - return success(); - }]; - let parser = [{ return parseShuffleVectorOp(parser, result); }]; - let printer = [{ printShuffleVectorOp(p, *this); }]; -} - -// Misc operations. 
-def LLVM_SelectOp - : LLVM_OneResultOp<"select", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, - LLVM_Type:$falseValue)>, - LLVM_Builder< - "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { - let parser = [{ return parseSelectOp(parser, result); }]; - let printer = [{ printSelectOp(p, *this); }]; -} - -// Terminators. -def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { - let parser = [{ return parseBrOp(parser, result); }]; - let printer = [{ printBrOp(p, *this); }]; -} -def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { - let verifier = [{ - if (getNumSuccessors() != 2) - return emitOpError("expected exactly two successors"); - return success(); - }]; - let parser = [{ return parseCondBrOp(parser, result); }]; - let printer = [{ printCondBrOp(p, *this); }]; -} -def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> { - string llvmBuilder = [{ - if ($_numOperands != 0) - builder.CreateRet($args[0]); - else - builder.CreateRetVoid(); - }]; - - let verifier = [{ - if (getNumOperands() > 1) - return emitOpError("expects at most 1 operand"); - return success(); - }]; - - let parser = [{ return parseReturnOp(parser, result); }]; - let printer = [{ printReturnOp(p, *this); }]; -} -def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { - string llvmBuilder = [{ builder.CreateUnreachable(); }]; - let parser = [{ return success(); }]; - let printer = [{ *p << getOperationName(); }]; -} - -// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to -// work correctly). -def LLVM_AddressOfOp - : LLVM_OneResultOp<"addressof">, - Arguments<(ins SymbolRefAttr:$global_name)> { - let builders = [ - OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, " - "StringRef name, ArrayRef attrs = {}", [{ - result->addAttribute("global_name", builder->getSymbolRefAttr(name)); - result->addAttributes(attrs); - result->addTypes(resType);}]>, - - OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, " - "ArrayRef attrs = {}", [{ - build(builder, result, global.getType().getPointerTo(), global.sym_name(), - attrs);}]> - ]; - - let extraClassDeclaration = [{ - /// Return the llvm.global operation that defined the value referenced here. - GlobalOp getGlobal(); - }]; - - let printer = "printAddressOfOp(p, *this);"; - let parser = "return parseAddressOfOp(parser, result);"; - let verifier = "return ::verify(*this);"; -} - -def LLVM_GlobalOp - : LLVM_ZeroResultOp<"global">, - Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, - AnyAttr:$value)> { - - let builders = [ - OpBuilder<"Builder *builder, OperationState *result, LLVMType type, " - "bool isConstant, StringRef name, Attribute value, " - "ArrayRef attrs = {}"> - ]; - - let extraClassDeclaration = [{ - /// Return the LLVM type of the global. 
- LLVMType getType() { - return type().cast(); - } - }]; - - let printer = "printGlobalOp(p, *this);"; - let parser = "return parseGlobalOp(parser, result);"; - let verifier = "return ::verify(*this);"; -} - -def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", - [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">]> { - let summary = "LLVM dialect function, has wrapped LLVM IR function type"; - - let regions = (region AnyRegion:$body); - - let skipDefaultBuilders = 1; - - let builders = [ - OpBuilder<"Builder *builder, OperationState *result, StringRef name, " - "LLVMType type, ArrayRef attrs, " - "ArrayRef argAttrs = {}"> - ]; - - let extraClassDeclaration = [{ - LLVMType getType() { - return getAttrOfType(getTypeAttrName()) - .getValue().cast(); - } - bool isVarArg() { - return getType().getUnderlyingType()->isFunctionVarArg(); - } - - // Hook for OpTrait::FunctionLike, returns the number of function arguments. - // Depends on the type attribute being correct as checked by verifyType. - unsigned getNumFuncArguments(); - - // Hook for OpTrait::FunctionLike, called after verifying that the 'type' - // attribute is present. This can check for preconditions of the - // getNumArguments hook not failing. - LogicalResult verifyType(); - }]; - - let verifier = [{ return ::verify(*this); }]; - let printer = [{ printLLVMFuncOp(p, *this); }]; - let parser = [{ - return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true, - buildLLVMFunctionType); - }]; -} - -def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>, - LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> { - let parser = [{ return parseUndefOp(parser, result); }]; - let printer = [{ printUndefOp(p, *this); }]; -} -def LLVM_ConstantOp - : LLVM_OneResultOp<"constant", [NoSideEffect]>, - Arguments<(ins AnyAttr:$value)>, - LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> -{ - let parser = [{ return parseConstantOp(parser, result); }]; - let printer = [{ printConstantOp(p, *this); }]; -} - -// Operations that correspond to LLVM intrinsics. With MLIR operation set being -// extendable, there is no reason to introduce a hard boundary between "core" -// operations and intrinsics. - -def LLVM_fmuladd : LLVM_Op<"fmuladd", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c)>, - Results<(outs LLVM_Type:$res)> { - let llvmBuilder = [{ - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = llvm::Intrinsic::getDeclaration( - module, llvm::Intrinsic::fmuladd, - {$a->getType(), $b->getType(), $c->getType()}); - $res = builder.CreateCall(fn, {$a, $b, $c}); - }]; -} - - -#endif // LLVMIR_OPS diff --git a/mlir/include/mlir/LLVMIR/NVVMDialect.h b/mlir/include/mlir/LLVMIR/NVVMDialect.h deleted file mode 100644 index 206f86871c7..00000000000 --- a/mlir/include/mlir/LLVMIR/NVVMDialect.h +++ /dev/null @@ -1,43 +0,0 @@ -//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and -// NVVM specific extensions to the LLVM type system. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_LLVMIR_NVVMDIALECT_H_ -#define MLIR_LLVMIR_NVVMDIALECT_H_ - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -namespace mlir { -namespace NVVM { - -///// Ops ///// -#define GET_OP_CLASSES -#include "mlir/LLVMIR/NVVMOps.h.inc" - -class NVVMDialect : public Dialect { -public: - explicit NVVMDialect(MLIRContext *context); -}; - -} // namespace NVVM -} // namespace mlir - -#endif /* MLIR_LLVMIR_NVVMDIALECT_H_ */ diff --git a/mlir/include/mlir/LLVMIR/NVVMOps.td b/mlir/include/mlir/LLVMIR/NVVMOps.td deleted file mode 100644 index 18be59988da..00000000000 --- a/mlir/include/mlir/LLVMIR/NVVMOps.td +++ /dev/null @@ -1,60 +0,0 @@ -//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This is the NVVM IR operation definition file. 
-// -//===----------------------------------------------------------------------===// - -#ifdef NVVMIR_OPS -#else -#define NVVMIR_OPS - -include "mlir/LLVMIR/LLVMOpBase.td" - -def NVVM_Dialect : Dialect { - let name = "nvvm"; - let cppNamespace = "NVVM"; -} - -class NVVM_Op traits = []> : - LLVM_OpBase { -} - -class NVVM_SpecialRegisterOp traits = []> : - NVVM_Op, - Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { - string llvmBuilder = "$res = createIntrinsicCall(builder," - # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");"; - let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }]; - let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }]; -} - -def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">; -def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">; -def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">; -def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">; -def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">; -def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">; -def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">; -def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">; -def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">; -def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; -def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; -def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; - -#endif // NVVMIR_OPS diff --git a/mlir/include/mlir/SDBM/SDBM.h b/mlir/include/mlir/SDBM/SDBM.h deleted file mode 100644 index b1c272372b3..00000000000 --- a/mlir/include/mlir/SDBM/SDBM.h +++ /dev/null @@ -1,206 +0,0 @@ -//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined -// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_MLIR_IR_SDBM_H -#define INCLUDE_MLIR_IR_SDBM_H - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMap.h" - -namespace mlir { - -class MLIRContext; -class SDBMDialect; -class SDBMExpr; -class SDBMPositiveExpr; - -/// A utility class for SDBM to represent an integer with potentially infinite -/// positive value. This uses the largest value of int64_t to represent infinity -/// and redefines the arithmetic operators so that the infinity "saturates": -/// inf + x = inf, -/// inf - x = inf. -/// If a sum of two finite values reaches the largest value of int64_t, the -/// behavior of IntInfty is undefined (in practice, it asserts), similarly to -/// regular signed integer overflow. 
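/// For example (illustrative): IntInfty(3) + IntInfty(4) compares equal to
/// IntInfty(7), whereas IntInfty::infinity() + IntInfty(5) remains infinite,
/// i.e. isFinite() on the resulting sum returns false.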
-class IntInfty { -public: - constexpr static int64_t infty = std::numeric_limits::max(); - - /*implicit*/ IntInfty(int64_t v) : value(v) {} - - IntInfty &operator=(int64_t v) { - value = v; - return *this; - } - - static IntInfty infinity() { return IntInfty(infty); } - - int64_t getValue() const { return value; } - explicit operator int64_t() const { return value; } - - bool isFinite() { return value != infty; } - -private: - int64_t value; -}; - -inline IntInfty operator+(IntInfty lhs, IntInfty rhs) { - if (!lhs.isFinite() || !rhs.isFinite()) - return IntInfty::infty; - - // Check for overflows, treating the sum of two values adding up to INT_MAX as - // overflow. Convert values to unsigned to get an extra bit and avoid the - // undefined behavior of signed integer overflows. - assert((lhs.getValue() <= 0 || rhs.getValue() <= 0 || - static_cast(lhs.getValue()) + - static_cast(rhs.getValue()) < - static_cast(std::numeric_limits::max())) && - "IntInfty overflow"); - // Check for underflows by converting values to unsigned to avoid undefined - // behavior of signed integers perform the addition (bitwise result is same - // because numbers are required to be two's complement in C++) and check if - // the sign bit remains negative. - assert((lhs.getValue() >= 0 || rhs.getValue() >= 0 || - ((static_cast(lhs.getValue()) + - static_cast(rhs.getValue())) >> - 63) == 1) && - "IntInfty underflow"); - - return lhs.getValue() + rhs.getValue(); -} - -inline bool operator<(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() < rhs.getValue(); -} - -inline bool operator<=(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() <= rhs.getValue(); -} - -inline bool operator==(IntInfty lhs, IntInfty rhs) { - return lhs.getValue() == rhs.getValue(); -} - -inline bool operator!=(IntInfty lhs, IntInfty rhs) { return !(lhs == rhs); } - -/// Striped difference-bound matrix is a representation of an integer set bound -/// by a system of SDBMExprs interpreted as inequalities "expr <= 0". -class SDBM { -public: - /// Obtain an SDBM from a list of SDBM expressions treated as inequalities and - /// equalities with zero. - static SDBM get(ArrayRef inequalities, - ArrayRef equalities); - - void getSDBMExpressions(SDBMDialect *dialect, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities); - - void print(llvm::raw_ostream &os); - void dump(); - - IntInfty operator()(int i, int j) { return at(i, j); } - -private: - /// Get the given element of the difference bounds matrix. First index - /// corresponds to the negative term of the difference, second index - /// corresponds to the positive term of the difference. - IntInfty &at(int i, int j) { return matrix[i * getNumVariables() + j]; } - - /// Populate `inequalities` and `equalities` based on the values at(row,col) - /// and at(col,row) of the DBM. Depending on the values being finite and - /// being subsumed by stripe expressions, this may or may not add elements to - /// the lists of equalities and inequalities. - void convertDBMElement(unsigned row, unsigned col, SDBMPositiveExpr rowExpr, - SDBMPositiveExpr colExpr, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities); - - /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only - /// adds new inequalities if the inequality is not trivially true. - void convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, - SmallVectorImpl &inequalities); - - /// Get the total number of elements in the matrix. 
- unsigned getNumVariables() const { - return 1 + numDims + numSymbols + numTemporaries; - } - - /// Get the position in the matrix that corresponds to the given dimension. - unsigned getDimPosition(unsigned position) const { return 1 + position; } - - /// Get the position in the matrix that corresponds to the given symbol. - unsigned getSymbolPosition(unsigned position) const { - return 1 + numDims + position; - } - - /// Get the position in the matrix that corresponds to the given temporary. - unsigned getTemporaryPosition(unsigned position) const { - return 1 + numDims + numSymbols + position; - } - - /// Number of dimensions in the system, - unsigned numDims; - /// Number of symbols in the system. - unsigned numSymbols; - /// Number of temporary variables in the system. - unsigned numTemporaries; - - /// Difference bounds matrix, stored as a linearized row-major vector. - /// Each value in this matrix corresponds to an inequality - /// - /// v@col - v@row <= at(row, col) - /// - /// where v@col and v@row are the variables that correspond to the linearized - /// position in the matrix. The positions correspond to - /// - /// - constant 0 (producing constraints v@col <= X and -v@row <= Y); - /// - SDBM expression dimensions (d0, d1, ...); - /// - SDBM expression symbols (s0, s1, ...); - /// - temporary variables (t0, t1, ...). - /// - /// Temporary variables are introduced to represent expressions that are not - /// trivially a difference between two variables. For example, if one side of - /// a difference expression is itself a stripe expression, it will be replaced - /// with a temporary variable assigned equal to this expression. - /// - /// Infinite entries in the matrix correspond correspond to an absence of a - /// constraint: - /// - /// v@col - v@row <= infinity - /// - /// is trivially true. Negated values at symmetric positions in the matrix - /// allow one to couple two inequalities into a single equality. - std::vector matrix; - - /// The mapping between the indices of variables in the DBM and the stripe - /// expressions they are equal to. These expressions are stored as they - /// appeared when constructing an SDBM from a SDBMExprs, in particular no - /// temporaries can appear in these expressions. This removes the need to - /// iteratively substitute definitions of the temporaries in the reverse - /// conversion. - llvm::DenseMap stripeToPoint; -}; - -} // namespace mlir - -#endif // INCLUDE_MLIR_IR_SDBM_H diff --git a/mlir/include/mlir/SDBM/SDBMDialect.h b/mlir/include/mlir/SDBM/SDBMDialect.h deleted file mode 100644 index 12086dcd3b4..00000000000 --- a/mlir/include/mlir/SDBM/SDBMDialect.h +++ /dev/null @@ -1,41 +0,0 @@ -//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#ifndef MLIR_SDBM_SDBMDIALECT_H -#define MLIR_SDBM_SDBMDIALECT_H - -#include "mlir/IR/Dialect.h" -#include "mlir/Support/StorageUniquer.h" - -namespace mlir { -class MLIRContext; - -class SDBMDialect : public Dialect { -public: - SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} - - static StringRef getDialectNamespace() { return "sdbm"; } - - /// Get the uniquer for SDBM expressions. This should not be used directly. - StorageUniquer &getUniquer() { return uniquer; } - -private: - StorageUniquer uniquer; -}; -} // namespace mlir - -#endif // MLIR_SDBM_SDBMDIALECT_H diff --git a/mlir/include/mlir/SDBM/SDBMExpr.h b/mlir/include/mlir/SDBM/SDBMExpr.h deleted file mode 100644 index afbeda15fe6..00000000000 --- a/mlir/include/mlir/SDBM/SDBMExpr.h +++ /dev/null @@ -1,530 +0,0 @@ -//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// A striped difference-bound matrix (SDBM) expression is a constant expression, -// an identifier, a binary expression with constant RHS and +, stripe operators -// or a difference expression between two identifiers. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_SDBMEXPR_H -#define MLIR_IR_SDBMEXPR_H - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMapInfo.h" - -namespace mlir { - -class AffineExpr; -class MLIRContext; - -enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg }; - -namespace detail { -struct SDBMExprStorage; -struct SDBMBinaryExprStorage; -struct SDBMDiffExprStorage; -struct SDBMPositiveExprStorage; -struct SDBMConstantExprStorage; -struct SDBMNegExprStorage; -} // namespace detail - -class SDBMConstantExpr; -class SDBMDialect; -class SDBMDimExpr; -class SDBMSymbolExpr; - -/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side -/// expression for the SDBM framework. SDBM expressions are a subset of affine -/// expressions supporting low-complexity algorithms for the operations used in -/// loop transformations. In particular, are supported: -/// - constant expressions; -/// - single variables (dimensions and symbols) with +1 or -1 coefficient; -/// - stripe expressions: "x # C", where "x" is a single variable or another -/// stripe expression, "#" is the stripe operator, and "C" is a constant -/// expression; "#" is defined as x - x mod C. -/// - sum expressions between single variable/stripe expressions and constant -/// expressions; -/// - difference expressions between single variable/stripe expressions. -/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing -/// and operating on SDBM expressions. 
For example, it requires the LHS of a -/// sum expression to be a single variable or a stripe expression. These -/// restrictions are intended to force the caller to perform the necessary -/// simplifications to stay within the SDBM domain, because SDBM expressions do -/// not combine in more cases than they do. This choice may be reconsidered in -/// the future. -/// -/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by -/// an MLIRContext, and should be used by-value. They are uniqued in the -/// MLIRContext and immortal. -class SDBMExpr { -public: - using ImplType = detail::SDBMExprStorage; - SDBMExpr() : impl(nullptr) {} - /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {} - - /// SDBM expressions are thin wrappers around a unique'ed immutable pointer, - /// which makes them trivially assignable and trivially copyable. - SDBMExpr(const SDBMExpr &) = default; - SDBMExpr &operator=(const SDBMExpr &) = default; - - /// SDBM expressions can be compared straight-forwardly. - bool operator==(const SDBMExpr &other) const { return impl == other.impl; } - bool operator!=(const SDBMExpr &other) const { return !(*this == other); } - - /// SDBM expressions are convertible to `bool`: null expressions are converted - /// to false, non-null expressions are converted to true. - explicit operator bool() const { return impl != nullptr; } - bool operator!() const { return !static_cast(*this); } - - /// Negate the given SDBM expression. - SDBMExpr operator-(); - - /// Prints the SDBM expression. - void print(raw_ostream &os) const; - void dump() const; - - /// LLVM-style casts. - template bool isa() const { return U::isClassFor(*this); } - template U dyn_cast() const { - if (!isa()) - return {}; - return U(const_cast(this)->impl); - } - template U cast() const { - assert(isa() && "cast to incorrect subtype"); - return U(const_cast(this)->impl); - } - - /// Support for LLVM hashing. - ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); } - - /// Returns the kind of the SDBM expression. - SDBMExprKind getKind() const; - - /// Returns the MLIR context in which this expression lives. - MLIRContext *getContext() const; - - /// Returns the SDBM dialect instance. - SDBMDialect *getDialect() const; - - /// Convert the SDBM expression into an Affine expression. This always - /// succeeds because SDBM are a subset of affine. - AffineExpr getAsAffineExpr() const; - - /// Try constructing an SDBM expression from the given affine expression. - /// This may fail if the affine expression is not representable as SDBM, in - /// which case llvm::None is returned. The conversion procedure recognizes - /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B) - /// patterns for the stripe expression. - static Optional tryConvertAffineExpr(AffineExpr affine); - -protected: - ImplType *impl; -}; - -/// SDBM constant expression, wraps a 64-bit integer. -class SDBMConstantExpr : public SDBMExpr { -public: - using ImplType = detail::SDBMConstantExprStorage; - - using SDBMExpr::SDBMExpr; - - /// Obtain or create a constant expression unique'ed in the given dialect - /// (which belongs to a context). 
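A hedged usage sketch against this header's declarations, assuming tryConvertAffineExpr returns Optional<SDBMExpr> (its template argument is elided in this listing) and using the pre-removal include path from this patch: round-trip an affine expression through SDBM when it is representable, and query the value-typed wrapper with the LLVM-style casts.

#include "mlir/IR/AffineExpr.h"
#include "mlir/SDBM/SDBMExpr.h"

using namespace mlir;

// Returns the input unchanged if it is not representable as an SDBM expression.
AffineExpr roundTripViaSDBM(AffineExpr affine) {
  if (auto sdbm = SDBMExpr::tryConvertAffineExpr(affine))
    return sdbm->getAsAffineExpr(); // always succeeds: SDBM is a subset of affine
  return affine;
}

// True if `expr` is a stripe whose striped variable is a dimension.
bool isStripedDim(SDBMExpr expr) {
  auto stripe = expr.dyn_cast<SDBMStripeExpr>();
  return stripe && stripe.getVar().isa<SDBMDimExpr>();
}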
- static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Constant; - } - - int64_t getValue() const; -}; - -/// SDBM varying expression can be one of: -/// - input variable expression; -/// - stripe expression; -/// - negation (product with -1) of either of the above. -/// - sum of a varying and a constant expression -/// - difference between varying expressions -class SDBMVaryingExpr : public SDBMExpr { -public: - using ImplType = detail::SDBMExprStorage; - using SDBMExpr::SDBMExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId || - expr.getKind() == SDBMExprKind::Neg || - expr.getKind() == SDBMExprKind::Stripe || - expr.getKind() == SDBMExprKind::Add || - expr.getKind() == SDBMExprKind::Diff; - } -}; - -/// SDBM positive variable expression can be one of: -/// - single variable expression; -/// - stripe expression. -class SDBMPositiveExpr : public SDBMVaryingExpr { -public: - using SDBMVaryingExpr::SDBMVaryingExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId || - expr.getKind() == SDBMExprKind::Stripe; - } -}; - -/// SDBM sum expression. LHS is a varying expression and RHS is always a -/// constant expression. -class SDBMSumExpr : public SDBMVaryingExpr { -public: - using ImplType = detail::SDBMBinaryExprStorage; - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// Obtain or create a sum expression unique'ed in the given context. - static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs); - - static bool isClassFor(const SDBMExpr &expr) { - SDBMExprKind kind = expr.getKind(); - return kind == SDBMExprKind::Add; - } - - SDBMVaryingExpr getLHS() const; - SDBMConstantExpr getRHS() const; -}; - -/// SDBM difference expression. Both LHS and RHS are positive variable -/// expressions. -class SDBMDiffExpr : public SDBMVaryingExpr { -public: - using ImplType = detail::SDBMDiffExprStorage; - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// Obtain or create a difference expression unique'ed in the given context. - static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Diff; - } - - SDBMPositiveExpr getLHS() const; - SDBMPositiveExpr getRHS() const; -}; - -/// SDBM stripe expression "x # C" where "x" is a positive variable expression, -/// "C" is a constant expression and "#" is the stripe operator defined as: -/// x # C = x - x mod C. -class SDBMStripeExpr : public SDBMPositiveExpr { -public: - using ImplType = detail::SDBMBinaryExprStorage; - using SDBMPositiveExpr::SDBMPositiveExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Stripe; - } - - static SDBMStripeExpr get(SDBMPositiveExpr var, - SDBMConstantExpr stripeFactor); - - SDBMPositiveExpr getVar() const; - SDBMConstantExpr getStripeFactor() const; -}; - -/// SDBM "input" variable expression can be either a dimension identifier or -/// a symbol identifier. When used to define SDBM functions, dimensions are -/// interpreted as function arguments while symbols are treated as unknown but -/// constant values, hence the name. 
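A sketch using the static builders declared above: form "(var # 16) + 2" from an existing positive expression. The dialect pointer is assumed to have been obtained elsewhere; per the comments above, the builders unique the result in that dialect's context. The helper name is invented.

#include "mlir/SDBM/SDBMExpr.h"

using namespace mlir;

SDBMSumExpr makeStripedPlusTwo(SDBMDialect *dialect, SDBMPositiveExpr var) {
  auto sixteen = SDBMConstantExpr::get(dialect, 16);
  auto two = SDBMConstantExpr::get(dialect, 2);
  SDBMStripeExpr striped = SDBMStripeExpr::get(var, sixteen); // var # 16
  return SDBMSumExpr::get(striped, two); // varying LHS, constant RHS
}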
-class SDBMInputExpr : public SDBMPositiveExpr { -public: - using ImplType = detail::SDBMPositiveExprStorage; - using SDBMPositiveExpr::SDBMPositiveExpr; - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId || - expr.getKind() == SDBMExprKind::SymbolId; - } - - unsigned getPosition() const; -}; - -/// SDBM dimension expression. Dimensions correspond to function arguments -/// when defining functions using SDBM expressions. -class SDBMDimExpr : public SDBMInputExpr { -public: - using ImplType = detail::SDBMPositiveExprStorage; - using SDBMInputExpr::SDBMInputExpr; - - /// Obtain or create a dimension expression unique'ed in the given dialect - /// (which belongs to a context). - static SDBMDimExpr get(SDBMDialect *dialect, unsigned position); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::DimId; - } -}; - -/// SDBM symbol expression. Symbols correspond to symbolic constants when -/// defining functions using SDBM expressions. -class SDBMSymbolExpr : public SDBMInputExpr { -public: - using ImplType = detail::SDBMPositiveExprStorage; - using SDBMInputExpr::SDBMInputExpr; - - /// Obtain or create a symbol expression unique'ed in the given dialect (which - /// belongs to a context). - static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::SymbolId; - } -}; - -/// Negation of an SDBM variable expression. Equivalent to multiplying the -/// expression with -1 (SDBM does not support other coefficients that 1 and -1). -class SDBMNegExpr : public SDBMVaryingExpr { -public: - using ImplType = detail::SDBMNegExprStorage; - using SDBMVaryingExpr::SDBMVaryingExpr; - - /// Obtain or create a negation expression unique'ed in the given context. - static SDBMNegExpr get(SDBMPositiveExpr var); - - static bool isClassFor(const SDBMExpr &expr) { - return expr.getKind() == SDBMExprKind::Neg; - } - - SDBMPositiveExpr getVar() const; -}; - -/// A visitor class for SDBM expressions. Calls the kind-specific function -/// depending on the kind of expression it visits. -template class SDBMVisitor { -public: - /// Visit the given SDBM expression, dispatching to kind-specific functions. - Result visit(SDBMExpr expr) { - auto *derived = static_cast(this); - switch (expr.getKind()) { - case SDBMExprKind::Add: - case SDBMExprKind::Diff: - case SDBMExprKind::DimId: - case SDBMExprKind::SymbolId: - case SDBMExprKind::Neg: - case SDBMExprKind::Stripe: - return derived->visitVarying(expr.cast()); - case SDBMExprKind::Constant: - return derived->visitConstant(expr.cast()); - } - - llvm_unreachable("unsupported SDBM expression kind"); - } - - /// Traverse the SDBM expression tree calling `visit` on each node - /// in depth-first preorder. - void walkPreorder(SDBMExpr expr) { return walk(expr); } - - /// Traverse the SDBM expression tree calling `visit` on each node in - /// depth-first postorder. - void walkPostorder(SDBMExpr expr) { return walk(expr); } - -protected: - /// Default visitors do nothing. - void visitSum(SDBMSumExpr) {} - void visitDiff(SDBMDiffExpr) {} - void visitStripe(SDBMStripeExpr) {} - void visitDim(SDBMDimExpr) {} - void visitSymbol(SDBMSymbolExpr) {} - void visitNeg(SDBMNegExpr) {} - void visitConstant(SDBMConstantExpr) {} - - /// Default implementation of visitPositive dispatches to the special - /// functions for stripes and other variables. Concrete visitors can override - /// it. 
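A sketch of a concrete visitor over the SDBMVisitor class above, assuming its elided template parameters are <typename Derived, typename Result = void> (they are stripped in this listing). It counts stripe sub-expressions using the postorder walk; the derived class only overrides the hook it cares about, the other defaults are no-ops.

#include "mlir/SDBM/SDBMExpr.h"

using namespace mlir;

struct StripeCounter : public SDBMVisitor<StripeCounter> {
  unsigned numStripes = 0;
  void visitStripe(SDBMStripeExpr) { ++numStripes; }
};

unsigned countStripes(SDBMExpr expr) {
  StripeCounter counter;
  counter.walkPostorder(expr); // visits every node, leaves first
  return counter.numStripes;
}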
- Result visitPositive(SDBMPositiveExpr expr) { - auto *derived = static_cast(this); - if (expr.getKind() == SDBMExprKind::Stripe) - return derived->visitStripe(expr.cast()); - else - return derived->visitInput(expr.cast()); - } - - /// Default implementation of visitInput dispatches to the special - /// functions for dimensions or symbols. Concrete visitors can override it to - /// visit all variables instead. - Result visitInput(SDBMInputExpr expr) { - auto *derived = static_cast(this); - if (expr.getKind() == SDBMExprKind::DimId) - return derived->visitDim(expr.cast()); - else - return derived->visitSymbol(expr.cast()); - } - - /// Default implementation of visitVarying dispatches to the special - /// functions for variables and negations thereof. Concerete visitors can - /// override it to visit all variables and negations instead. - Result visitVarying(SDBMVaryingExpr expr) { - auto *derived = static_cast(this); - if (auto var = expr.dyn_cast()) - return derived->visitPositive(var); - else if (auto neg = expr.dyn_cast()) - return derived->visitNeg(neg); - else if (auto sum = expr.dyn_cast()) - return derived->visitSum(sum); - else if (auto diff = expr.dyn_cast()) - return derived->visitDiff(diff); - - llvm_unreachable("unhandled subtype of varying SDBM expression"); - } - - template void walk(SDBMExpr expr) { - if (isPreorder) - visit(expr); - if (auto sumExpr = expr.dyn_cast()) { - walk(sumExpr.getLHS()); - walk(sumExpr.getRHS()); - } else if (auto diffExpr = expr.dyn_cast()) { - walk(diffExpr.getLHS()); - walk(diffExpr.getRHS()); - } else if (auto stripeExpr = expr.dyn_cast()) { - walk(stripeExpr.getVar()); - walk(stripeExpr.getStripeFactor()); - } else if (auto negExpr = expr.dyn_cast()) { - walk(negExpr.getVar()); - } - if (!isPreorder) - visit(expr); - } -}; - -/// Overloaded arithmetic operators for SDBM expressions asserting that their -/// arguments have the proper SDBM expression subtype. Perform canonicalization -/// and constant folding on these expressions. -namespace ops_assertions { - -/// Add two SDBM expressions. At least one of the expressions must be a -/// constant or a negation, but both expressions cannot be negations -/// simultaneously. -SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs); -inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) { - return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs); -} -inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) { - return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs; -} - -/// Subtract an SDBM expression from another SDBM expression. Both expressions -/// must not be difference expressions. -SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs); -inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) { - return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs); -} -inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) { - return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs; -} - -/// Construct a stripe expression from a positive expression and a positive -/// constant stripe factor. -SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor); -inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) { - return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor)); -} -} // namespace ops_assertions - -} // end namespace mlir - -namespace llvm { -// SDBMExpr hash just like pointers. 
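A sketch of the asserting operator overloads declared in ops_assertions above; both helpers stay within the documented operand restrictions (at least one addend is a constant, and neither operand of the subtraction is itself a difference expression). Function names are invented and the dimension expression is assumed to be built elsewhere.

#include "mlir/SDBM/SDBMExpr.h"

using namespace mlir;
using namespace mlir::ops_assertions;

// d0 - d0 # 64: the offset of d0 within its 64-element tile.
SDBMExpr tileOffset(SDBMDimExpr d0) { return d0 - stripe(d0, 64); }

// d0 + 42: folded into a sum with a constant right-hand side.
SDBMExpr shiftedDim(SDBMDimExpr d0) { return d0 + 42; }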
-template <> struct DenseMapInfo { - static mlir::SDBMExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMExpr(static_cast(pointer)); - } - static mlir::SDBMExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMExpr(static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMExpr lhs, mlir::SDBMExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMVaryingExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMVaryingExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMVaryingExpr( - static_cast(pointer)); - } - static mlir::SDBMVaryingExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMVaryingExpr( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMVaryingExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMPositiveExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMPositiveExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMPositiveExpr( - static_cast(pointer)); - } - static mlir::SDBMPositiveExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMPositiveExpr( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMPositiveExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) { - return lhs == rhs; - } -}; - -// SDBMConstantExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMConstantExpr getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMConstantExpr( - static_cast(pointer)); - } - static mlir::SDBMConstantExpr getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMConstantExpr( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::SDBMConstantExpr expr) { - return expr.hash_value(); - } - static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) { - return lhs == rhs; - } -}; -} // namespace llvm - -#endif // MLIR_IR_SDBMEXPR_H diff --git a/mlir/include/mlir/StandardOps/CMakeLists.txt b/mlir/include/mlir/StandardOps/CMakeLists.txt deleted file mode 100644 index 670676f24db..00000000000 --- a/mlir/include/mlir/StandardOps/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Ops.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRStandardOpsIncGen) diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h deleted file mode 100644 index fbd6462938b..00000000000 --- a/mlir/include/mlir/StandardOps/Ops.h +++ /dev/null @@ -1,363 +0,0 @@ -//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
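The DenseMapInfo specializations above (their template arguments, e.g. DenseMapInfo<mlir::SDBMExpr>, are elided in this listing) let the value-typed expression wrappers key LLVM hashed containers directly. A small hedged sketch of the intended use, with an invented helper name:

#include "mlir/SDBM/SDBMExpr.h"
#include "llvm/ADT/DenseMap.h"

// Deduplicate expressions, assigning each distinct one a small id.
unsigned internExpr(llvm::DenseMap<mlir::SDBMExpr, unsigned> &ids,
                    mlir::SDBMExpr expr) {
  auto result = ids.insert({expr, static_cast<unsigned>(ids.size())});
  return result.first->second;
}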
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines convenience types for working with standard operations -// in the MLIR operation set. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_STANDARDOPS_OPS_H -#define MLIR_STANDARDOPS_OPS_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -class AffineMap; -class Builder; -class FuncOp; -class OpBuilder; - -class StandardOpsDialect : public Dialect { -public: - StandardOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "std"; } -}; - -/// The predicate indicates the type of the comparison to perform: -/// (in)equality; (un)signed less/greater than (or equal to). -enum class CmpIPredicate { - FirstValidValue, - // (In)equality comparisons. - EQ = FirstValidValue, - NE, - // Signed comparisons. - SLT, - SLE, - SGT, - SGE, - // Unsigned comparisons. - ULT, - ULE, - UGT, - UGE, - // Number of predicates. - NumPredicates -}; - -/// The predicate indicates the type of the comparison to perform: -/// (un)orderedness, (in)equality and less/greater than (or equal to) as -/// well as predicates that are always true or false. -enum class CmpFPredicate { - FirstValidValue, - // Always false - AlwaysFalse = FirstValidValue, - // Ordered comparisons - OEQ, - OGT, - OGE, - OLT, - OLE, - ONE, - // Both ordered - ORD, - // Unordered comparisons - UEQ, - UGT, - UGE, - ULT, - ULE, - UNE, - // Any unordered - UNO, - // Always true - AlwaysTrue, - // Number of predicates. - NumPredicates -}; - -#define GET_OP_CLASSES -#include "mlir/StandardOps/Ops.h.inc" - -/// This is a refinement of the "constant" op for the case where it is -/// returning a float value of FloatType. -/// -/// %1 = "std.constant"(){value: 42.0} : bf16 -/// -class ConstantFloatOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Builds a constant float op producing a float of the specified type. - static void build(Builder *builder, OperationState *result, - const APFloat &value, FloatType type); - - APFloat getValue() { return getAttrOfType("value").getValue(); } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of IntegerType. -/// -/// %1 = "std.constant"(){value: 42} : i32 -/// -class ConstantIntOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - /// Build a constant int op producing an integer of the specified width. - static void build(Builder *builder, OperationState *result, int64_t value, - unsigned width); - - /// Build a constant int op producing an integer with the specified type, - /// which must be an integer type. 
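The CmpIPredicate enum above distinguishes signed from unsigned orderings because MLIR integers are signless. A small helper sketch over that enum (the helper itself is illustrative, not part of the header), using the post-move include path this patch introduces:

#include "mlir/Dialect/StandardOps/Ops.h"

using namespace mlir;

// True for the signed orderings slt/sle/sgt/sge.
static bool isSignedComparison(CmpIPredicate pred) {
  switch (pred) {
  case CmpIPredicate::SLT:
  case CmpIPredicate::SLE:
  case CmpIPredicate::SGT:
  case CmpIPredicate::SGE:
    return true;
  default:
    return false;
  }
}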
- static void build(Builder *builder, OperationState *result, int64_t value, - Type type); - - int64_t getValue() { return getAttrOfType("value").getInt(); } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of Index type. -/// -/// %1 = "std.constant"(){value: 99} : () -> index -/// -class ConstantIndexOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Build a constant int op producing an index. - static void build(Builder *builder, OperationState *result, int64_t value); - - int64_t getValue() { return getAttrOfType("value").getInt(); } - - static bool classof(Operation *op); -}; - -// DmaStartOp starts a non-blocking DMA operation that transfers data from a -// source memref to a destination memref. The source and destination memref need -// not be of the same dimensionality, but need to have the same elemental type. -// The operands include the source and destination memref's each followed by its -// indices, size of the data transfer in terms of the number of elements (of the -// elemental type of the memref), a tag memref with its indices, and optionally -// at the end, a stride and a number_of_elements_per_stride arguments. The tag -// location is used by a DmaWaitOp to check for completion. The indices of the -// source memref, destination memref, and the tag memref have the same -// restrictions as any load/store. The optional stride arguments should be of -// 'index' type, and specify a stride for the slower memory space (memory space -// with a lower memory space id), tranferring chunks of -// number_of_elements_per_stride every stride until %num_elements are -// transferred. Either both or no stride arguments should be specified. -// -// For example, a DmaStartOp operation that transfers 256 elements of a memref -// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space -// 1 at indices [%k, %l], would be specified as follows: -// -// %num_elements = constant 256 -// %idx = constant 0 : index -// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : -// memref<40 x 128 x f32>, (d0) -> (d0), 0>, -// memref<2 x 1024 x f32>, (d0) -> (d0), 1>, -// memref<1 x i32>, (d0) -> (d0), 2> -// -// If %stride and %num_elt_per_stride are specified, the DMA is expected to -// transfer %num_elt_per_stride elements every %stride elements apart from -// memory space 0 until %num_elements are transferred. -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, -// %num_elt_per_stride : -// -// TODO(mlir-team): add additional operands to allow source and destination -// striding, and multiple stride levels. -// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. -class DmaStartOp - : public Op { -public: - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *srcMemRef, - ArrayRef srcIndices, Value *destMemRef, - ArrayRef destIndices, Value *numElements, - Value *tagMemRef, ArrayRef tagIndices, - Value *stride = nullptr, - Value *elementsPerStride = nullptr); - - // Returns the source MemRefType for this DMA operation. - Value *getSrcMemRef() { return getOperand(0); } - // Returns the rank (number of indices) of the source MemRefType. - unsigned getSrcMemRefRank() { - return getSrcMemRef()->getType().cast().getRank(); - } - // Returns the source memerf indices for this DMA operation. 
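The DmaStartOp operand layout described above (source memref and its indices, destination memref and its indices, %num_elements, tag memref and its indices, then the optional stride pair) fixes the operand positions that the accessors below decode. A standalone sketch of that arithmetic, with invented names, keyed only on the two ranks:

#include <cassert>

struct DmaOperandPositions {
  unsigned srcMemRef, dstMemRef, numElements, tagMemRef, firstTagIndex;
};

static DmaOperandPositions dmaStartOperandPositions(unsigned srcRank,
                                                    unsigned dstRank) {
  DmaOperandPositions pos;
  pos.srcMemRef = 0;                   // followed by srcRank source indices
  pos.dstMemRef = 1 + srcRank;         // followed by dstRank destination indices
  pos.numElements = 1 + srcRank + 1 + dstRank;
  pos.tagMemRef = pos.numElements + 1; // followed by the tag indices
  pos.firstTagIndex = pos.tagMemRef + 1;
  return pos;
}

int main() {
  // For the 2-D source / 2-D destination example above, %num_elements is
  // operand 6 and the tag memref is operand 7.
  assert(dmaStartOperandPositions(2, 2).numElements == 6);
  assert(dmaStartOperandPositions(2, 2).tagMemRef == 7);
  return 0;
}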
- operand_range getSrcIndices() { - return {getOperation()->operand_begin() + 1, - getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; - } - - // Returns the destination MemRefType for this DMA operations. - Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } - // Returns the rank (number of indices) of the destination MemRefType. - unsigned getDstMemRefRank() { - return getDstMemRef()->getType().cast().getRank(); - } - unsigned getSrcMemorySpace() { - return getSrcMemRef()->getType().cast().getMemorySpace(); - } - unsigned getDstMemorySpace() { - return getDstMemRef()->getType().cast().getMemorySpace(); - } - - // Returns the destination memref indices for this DMA operation. - operand_range getDstIndices() { - return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, - getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + - getDstMemRefRank()}; - } - - // Returns the number of elements being transferred by this DMA operation. - Value *getNumElements() { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); - } - - // Returns the Tag MemRef for this DMA operation. - Value *getTagMemRef() { - return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); - } - // Returns the rank (number of indices) of the tag MemRefType. - unsigned getTagMemRefRank() { - return getTagMemRef()->getType().cast().getRank(); - } - - // Returns the tag memref index for this DMA operation. - operand_range getTagIndices() { - unsigned tagIndexStartPos = - 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; - return {getOperation()->operand_begin() + tagIndexStartPos, - getOperation()->operand_begin() + tagIndexStartPos + - getTagMemRefRank()}; - } - - /// Returns true if this is a DMA from a faster memory space to a slower one. - bool isDestMemorySpaceFaster() { - return (getSrcMemorySpace() < getDstMemorySpace()); - } - - /// Returns true if this is a DMA from a slower memory space to a faster one. - bool isSrcMemorySpaceFaster() { - // Assumes that a lower number is for a slower memory space. - return (getDstMemorySpace() < getSrcMemorySpace()); - } - - /// Given a DMA start operation, returns the operand position of either the - /// source or destination memref depending on the one that is at the higher - /// level of the memory hierarchy. Asserts failure if neither is true. - unsigned getFasterMemPos() { - assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); - return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; - } - - static StringRef getOperationName() { return "std.dma_start"; } - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - - bool isStrided() { - return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + - 1 + 1 + getTagMemRefRank(); - } - - Value *getStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1 - 1); - } - - Value *getNumElementsPerStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1); - } -}; - -// DmaWaitOp blocks until the completion of a DMA operation associated with the -// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index -// with the same restrictions as any load/store index. %num_elements is the -// number of elements associated with the DMA operation. 
For example: -// -// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : -// memref<2048 x f32>, (d0) -> (d0), 0>, -// memref<256 x f32>, (d0) -> (d0), 1> -// memref<1 x i32>, (d0) -> (d0), 2> -// ... -// ... -// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> -// -class DmaWaitOp - : public Op { -public: - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *tagMemRef, - ArrayRef tagIndices, Value *numElements); - - static StringRef getOperationName() { return "std.dma_wait"; } - - // Returns the Tag MemRef associated with the DMA operation being waited on. - Value *getTagMemRef() { return getOperand(0); } - - // Returns the tag memref index for this DMA operation. - operand_range getTagIndices() { - return {getOperation()->operand_begin() + 1, - getOperation()->operand_begin() + 1 + getTagMemRefRank()}; - } - - // Returns the rank (number of indices) of the tag memref. - unsigned getTagMemRefRank() { - return getTagMemRef()->getType().cast().getRank(); - } - - // Returns the number of elements transferred in the associated DMA operation. - Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); } - - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// Prints dimension and symbol list. -void printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, unsigned numDims, - OpAsmPrinter *p); - -/// Parses dimension and symbol list and returns true if parsing failed. -ParseResult parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims); - -} // end namespace mlir - -#endif // MLIR_STANDARDOPS_OPS_H diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td deleted file mode 100644 index b6bf2cfb40b..00000000000 --- a/mlir/include/mlir/StandardOps/Ops.td +++ /dev/null @@ -1,905 +0,0 @@ -//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Defines some MLIR standard operations. -// -//===----------------------------------------------------------------------===// - -#ifdef STANDARD_OPS -#else -#define STANDARD_OPS - -#ifdef OP_BASE -#else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -def Std_Dialect : Dialect { - let name = "std"; - let cppNamespace = ""; -} - -// Base class for Standard dialect ops. -class Std_Op traits = []> : - Op { - // For every standard op, there needs to be a: - // * void print(OpAsmPrinter *p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser *parser, - // OperationState *result) - // functions. 
- let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} - -// Base class for standard cast operations. Requires single operand and result, -// but does not constrain them to specific types. -class CastOp traits = []> : - Std_Op { - - let results = (outs AnyType); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Value *source, Type destType", [{ - impl::buildCastOp(builder, result, source, destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - let verifier = [{ return ::verifyCastOp(*this); }]; - - let hasFolder = 1; -} - -// Base class for standard arithmetic operations. Requires operands and -// results to be of the same type, but does not constrain them to specific -// types. Individual classes will have `lhs` and `rhs` accessor to operands. -class ArithmeticOp traits = []> : - Op { - - let results = (outs AnyType); - - let parser = [{ - return impl::parseBinaryOp(parser, result); - }]; - - let printer = [{ - return printStandardBinaryOp(this->getOperation(), p); - }]; -} - -// Base class for standard arithmetic operations on integers, vectors and -// tensors thereof. This operation takes two operands and returns one result, -// each of these is required to be of the same type. This type may be an -// integer scalar type, a vector whose element type is an integer type, or an -// integer tensor. The custom assembly form of the operaton is as follows -// -// i %0, %1 : i32 -class IntArithmeticOp traits = []> : - ArithmeticOp, - Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>; - -// Base class for standard arithmetic binary operations on floats, vectors and -// tensors thereof. This operation has two operands and returns one result, -// each of these is required to be of the same type. This type may be a -// floating point scalar type, a vector whose element type is a floating point -// type, or a floating point tensor. The custom assembly form of the operation -// is as follows -// -// f %0, %1 : f32 -class FloatArithmeticOp traits = []> : - ArithmeticOp, - Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; - -def AddFOp : FloatArithmeticOp<"addf"> { - let summary = "floating point addition operation"; - let hasFolder = 1; -} - -def AddIOp : IntArithmeticOp<"addi", [Commutative]> { - let summary = "integer addition operation"; - let hasFolder = 1; -} - -def AllocOp : Std_Op<"alloc"> { - let summary = "memory allocation operation"; - let description = [{ - The "alloc" operation allocates a region of memory, as specified by its - memref type. For example: - - %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> - - The optional list of dimension operands are bound to the dynamic dimensions - specified in its memref type. In the example below, the ssa value '%d' is - bound to the second dimension of the memref (which is dynamic). - - %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> - - The optional list of symbol operands are bound to the symbols of the - memrefs affine map. In the example below, the ssa value '%s' is bound to - the symbol 's0' in the affine map specified in the allocs memref type. - - %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> - - This operation returns a single ssa value of memref type, which can be used - by subsequent load and store operations. 
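A hedged builder sketch for the "alloc" description above, covering only the static-shape case (no dynamic dimension or symbol operands). It assumes an OpBuilder already positioned inside a function and uses the post-move include path this patch introduces.

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

// %0 = alloc() : memref<8x64xf32>
Value *allocateStaticBuffer(OpBuilder &builder, Location loc) {
  auto bufferType = MemRefType::get({8, 64}, builder.getF32Type());
  auto alloc = builder.create<AllocOp>(loc, bufferType);
  return alloc.getResult();
}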
- }]; - - let arguments = (ins Variadic:$value); - let results = (outs AnyMemRef); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, MemRefType memrefType", [{ - result->types.push_back(memrefType); - }] - >]; - - let extraClassDeclaration = [{ - MemRefType getType() { return getResult()->getType().cast(); } - }]; - - let hasCanonicalizer = 1; -} - -def AndOp : IntArithmeticOp<"and", [Commutative]> { - let summary = "integer binary and"; - let hasFolder = 1; -} - -def BranchOp : Std_Op<"br", [Terminator]> { - let summary = "branch operation"; - let description = [{ - The "br" operation represents a branch operation in a function. - The operation takes variable number of operands and produces no results. - The operand number and types for each successor must match the arguments of - the block successor. For example: - - ^bb2: - %2 = call @someFn() - br ^bb3(%2 : tensor<*xf32>) - ^bb3(%3: tensor<*xf32>): - }]; - - let arguments = (ins Variadic:$operands); - - let builders = [OpBuilder< - "Builder *, OperationState *result, Block *dest," - "ArrayRef operands = {}", [{ - result->addSuccessor(dest, operands); - }]>]; - - // BranchOp is fully verified by traits. - let verifier = ?; - - let extraClassDeclaration = [{ - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; -} - -def CallOp : Std_Op<"call"> { - let summary = "call operation"; - let description = [{ - The "call" operation represents a direct call to a function. The operands - and result types of the call must match the specified function type. The - callee is encoded as a function attribute named "callee". - - %2 = call @my_add(%0, %1) : (f32, f32) -> f32 - }]; - - let arguments = (ins SymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, FuncOp callee," - "ArrayRef operands = {}", [{ - result->addOperands(operands); - result->addAttribute("callee", builder->getSymbolRefAttr(callee)); - result->addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState *result, StringRef callee," - "ArrayRef results, ArrayRef operands = {}", [{ - result->addOperands(operands); - result->addAttribute("callee", builder->getSymbolRefAttr(callee)); - result->addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; -} - -def CallIndirectOp : Std_Op<"call_indirect"> { - let summary = "indirect call operation"; - let description = [{ - The "call_indirect" operation represents an indirect call to a value of - function type. Functions are first class types in MLIR, and may be passed - as arguments and merged together with block arguments. The operands - and result types of the call must match the specified function type. 
- - %3 = call_indirect %2(%0, %1) : (f32, f32) -> f32 - }]; - - let arguments = (ins FunctionType:$callee, Variadic:$operands); - let results = (outs Variadic); - - let builders = [OpBuilder< - "Builder *, OperationState *result, Value *callee," - "ArrayRef operands = {}", [{ - result->operands.push_back(callee); - result->addOperands(operands); - result->addTypes(callee->getType().cast().getResults()); - }]>]; - - let extraClassDeclaration = [{ - Value *getCallee() { return getOperand(0); } - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let hasCanonicalizer = 1; -} - -def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { - let summary = "integer comparison operation"; - let description = [{ - The "cmpi" operation compares its two operands according to the integer - comparison rules and the predicate specified by the respective attribute. - The predicate defines the type of comparison: (in)equality, (un)signed - less/greater than (or equal to). The operands must have the same type, and - this type must be an integer type, a vector or a tensor thereof. The result - is an i1, or a vector/tensor thereof having the same shape as the inputs. - Since integers are signless, the predicate also explicitly indicates - whether to interpret the operands as signed or unsigned integers for - less/greater than comparisons. For the sake of readability by humans, - custom assembly form for the operation uses a string-typed attribute for - the predicate. The value of this attribute corresponds to lower-cased name - of the predicate constant, e.g., "slt" means "signed less than". The string - representation of the attribute is merely a syntactic sugar and is converted - to an integer attribute by the parser. - - %r1 = cmpi "eq" %0, %1 : i32 - %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> - %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 - }]; - - let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs); - let results = (outs BoolLike); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, CmpIPredicate predicate," - "Value *lhs, Value *rhs", [{ - ::buildCmpIOp(builder, result, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - CmpIPredicate getPredicate() { - return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) - .getInt(); - } - }]; - - let hasFolder = 1; -} - -def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { - let summary = "floating-point comparison operation"; - let description = [{ - The "cmpf" operation compares its two operands according to the float - comparison rules and the predicate specified by the respective attribute. - The predicate defines the type of comparison: (un)orderedness, (in)equality - and signed less/greater than (or equal to) as well as predicates that are - always true or false. The operands must have the same type, and this type - must be a float type, or a vector or tensor thereof. The result is an i1, - or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, - the operands are always treated as signed. 
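A hedged builder sketch for the "cmpi" description above, producing the signed-less-than form shown in its custom assembly example; it assumes an OpBuilder positioned at the insertion point and two operands of the same integer type.

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// %r = cmpi "slt" %lhs, %rhs : i32   (the result is i1)
Value *emitSignedLessThan(OpBuilder &builder, Location loc, Value *lhs,
                          Value *rhs) {
  auto cmp = builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, rhs);
  return cmp.getResult();
}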
The u prefix indicates - *unordered* comparison, not unsigned comparison, so "une" means unordered or - not equal. For the sake of readability by humans, custom assembly form for - the operation uses a string-typed attribute for the predicate. The value of - this attribute corresponds to lower-cased name of the predicate constant, - e.g., "one" means "ordered not equal". The string representation of the - attribute is merely a syntactic sugar and is converted to an integer - attribute by the parser. - - %r1 = cmpf "oeq" %0, %1 : f32 - %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> - %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 - }]; - - let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); - let results = (outs BoolLike); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, CmpFPredicate predicate," - "Value *lhs, Value *rhs", [{ - ::buildCmpFOp(builder, result, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); - - CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) - .getInt(); - } - }]; - - let hasFolder = 1; -} - -def CondBranchOp : Std_Op<"cond_br", [Terminator]> { - let summary = "conditional branch operation"; - let description = [{ - The "cond_br" operation represents a conditional branch operation in a - function. The operation takes variable number of operands and produces - no results. The operand number and types for each successor must match the - arguments of the block successor. For example: - - ^bb0: - %0 = extract_element %arg0[] : tensor - cond_br %0, ^bb1, ^bb2 - ^bb1: - ... - ^bb2: - ... - }]; - - let arguments = (ins I1:$condition, Variadic:$branchOperands); - - let builders = [OpBuilder< - "Builder *, OperationState *result, Value *condition," - "Block *trueDest, ArrayRef trueOperands," - "Block *falseDest, ArrayRef falseOperands", [{ - result->addOperands(condition); - result->addSuccessor(trueDest, trueOperands); - result->addSuccessor(falseDest, falseOperands); - }]>]; - - // CondBranchOp is fully verified by traits. - let verifier = ?; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getOperation()->getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getOperation()->getSuccessor(falseIndex); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); - } - - /// Erase the operand at 'index' from the true operand list. 
- void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); - } - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); - } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let hasCanonicalizer = 1; -} - -def ConstantOp : Std_Op<"constant", [NoSideEffect]> { - let summary = "constant"; - - let arguments = (ins AnyAttr:$value); - let results = (outs AnyType); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Attribute value", - [{ build(builder, result, value.getType(), value); }]>]; - - let extraClassDeclaration = [{ - Attribute getValue() { return getAttr("value"); } - - /// Returns true if a constant operation can be built with the given value - /// and result type. - static bool isBuildableWith(Attribute value, Type type); - }]; - - let hasFolder = 1; -} - -def DeallocOp : Std_Op<"dealloc"> { - let summary = "memory deallocation operation"; - let description = [{ - The "dealloc" operation frees the region of memory referenced by a memref - which was originally created by the "alloc" operation. - The "dealloc" operation should not be called on memrefs which alias an - alloc'd memref (i.e. memrefs returned by the "view" and "reshape" - operations). - - %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> - dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> - }]; - - let arguments = (ins AnyMemRef:$memref); - - let hasCanonicalizer = 1; -} - -def DimOp : Std_Op<"dim", [NoSideEffect]> { - let summary = "dimension index operation"; - let description = [{ - The "dim" operation takes a memref or tensor operand and returns an "index". - It requires a single integer attribute named "index". It returns the size - of the specified dimension. 
For example: - - %1 = dim %0, 2 : tensor - }]; - - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor], - "any tensor or memref type">:$memrefOrTensor, - APIntAttr:$index); - let results = (outs Index); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Value *memrefOrTensor," - "unsigned index", [{ - auto indexType = builder->getIndexType(); - auto indexAttr = builder->getIntegerAttr(indexType, index); - build(builder, result, indexType, memrefOrTensor, indexAttr); - }]>]; - - let extraClassDeclaration = [{ - unsigned getIndex() { - return getAttrOfType("index").getValue().getZExtValue(); - } - }]; - - let hasFolder = 1; -} - -def DivFOp : FloatArithmeticOp<"divf"> { - let summary = "floating point division operation"; -} - -def DivISOp : IntArithmeticOp<"divis"> { - let summary = "signed integer division operation"; - let hasFolder = 1; -} - -def DivIUOp : IntArithmeticOp<"diviu"> { - let summary = "unsigned integer division operation"; - let hasFolder = 1; -} - -def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { - let summary = "element extract operation"; - let description = [{ - The "extract_element" op reads a tensor or vector and returns one element - from it specified by an index list. The output of extract is a new value - with the same type as the elements of the tensor or vector. The arity of - indices matches the rank of the accessed value (i.e., if a tensor is of rank - 3, then 3 indices are required for the extract). The indices should all be - of affine_int type. For example: - - %0 = extract_element %0[%1, %2] : vector<4x4xi32> - }]; - - let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, - Variadic:$indices); - let results = (outs AnyType); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Value *aggregate," - "ArrayRef indices = {}", [{ - auto resType = aggregate->getType().cast() - .getElementType(); - build(builder, result, resType, aggregate, indices); - }]>]; - - let extraClassDeclaration = [{ - Value *getAggregate() { return getOperand(0); } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 1, - getOperation()->operand_end()}; - } - }]; - - let hasFolder = 1; -} - -def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { - let summary = "cast between index and integer types"; - let description = [{ - Casts between integer scalars and 'index' scalars. Index is an integer of - platform-specific bit width. If casting to a wider integer, the value is - sign-extended. If casting to a narrower integer, the value is truncated. - }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; -} - -def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { - let summary = "cast from integer type to floating-point"; - let description = [{ - Cast from a value interpreted as signed integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. Only scalars are currently - supported. - }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. 
- static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 0; -} - -def LoadOp : Std_Op<"load"> { - let summary = "load operation"; - let description = [{ - The "load" op reads an element from a memref specified by an index list. The - output of load is a new value with the same type as the elements of the - memref. The arity of indices is the rank of the memref (i.e., if the memref - loaded from is of rank 3, then 3 indices are required for the load following - the memref identifier). For example: - - %3 = load %0[%1, %1] : memref<4x4xi32> - }]; - - let arguments = (ins AnyMemRef:$memref, Variadic:$indices); - let results = (outs AnyType); - - let builders = [OpBuilder< - "Builder *, OperationState *result, Value *memref," - "ArrayRef indices = {}", [{ - auto memrefType = memref->getType().cast(); - result->addOperands(memref); - result->addOperands(indices); - result->types.push_back(memrefType.getElementType()); - }]>]; - - let extraClassDeclaration = [{ - Value *getMemRef() { return getOperand(0); } - void setMemRef(Value *value) { setOperand(0, value); } - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; - } - }]; - - let hasCanonicalizer = 1; -} - -def MemRefCastOp : CastOp<"memref_cast"> { - let summary = "memref cast operation"; - let description = [{ - The "memref_cast" operation converts a memref from one type to an equivalent - type with a compatible shape. The source and destination types are - when both are memref types with the same element type, affine mappings, - address space, and rank but where the individual dimensions may add or - remove constant dimensions from the memref type. - - If the cast converts any dimensions from an unknown to a known size, then it - acts as an assertion that fails at runtime of the dynamic dimensions - disagree with resultant destination size. - - Assert that the input dynamic shape matches the destination static shape. - %2 = memref_cast %1 : memref to memref<4x4xf32> - Erase static shape information, replacing it with dynamic information. - %3 = memref_cast %1 : memref<4xf32> to memref - }]; - - let arguments = (ins AnyMemRef:$source); - let results = (outs AnyMemRef); - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a memref_cast is always a memref. - MemRefType getType() { return getResult()->getType().cast(); } - }]; -} - -def MulFOp : FloatArithmeticOp<"mulf"> { - let summary = "foating point multiplication operation"; - let hasFolder = 1; -} - -def MulIOp : IntArithmeticOp<"muli", [Commutative]> { - let summary = "integer multiplication operation"; - let hasFolder = 1; -} - -def OrOp : IntArithmeticOp<"or", [Commutative]> { - let summary = "integer binary or"; - let hasFolder = 1; -} - -def RankOp : Std_Op<"rank", [NoSideEffect]> { - let summary = "rank operation"; - let description = [{ - The "rank" operation takes a tensor operand and returns its rank. 
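A hedged builder sketch for the "load" description above, reading one element of a rank-2 memref as in "%3 = load %0[%i, %j] : memref<4x4xi32>"; the OpBuilder, the memref value, and the index values are assumed to exist already.

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

Value *loadElement(OpBuilder &builder, Location loc, Value *memref, Value *i,
                   Value *j) {
  auto load = builder.create<LoadOp>(loc, memref, ArrayRef<Value *>{i, j});
  return load.getResult();
}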
- - %1 = rank %0 : index - }]; - - let arguments = (ins AnyTensor); - let results = (outs Index); - let verifier = ?; - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Value *tensor", [{ - auto indexType = builder->getIndexType(); - build(builder, result, indexType, tensor); - }]>]; - - let hasFolder = 1; -} - -def RemFOp : FloatArithmeticOp<"remf"> { - let summary = "floating point division remainder operation"; -} - -def RemISOp : IntArithmeticOp<"remis"> { - let summary = "signed integer division remainder operation"; - let hasFolder = 1; -} - -def RemIUOp : IntArithmeticOp<"remiu"> { - let summary = "unsigned integer division remainder operation"; - let hasFolder = 1; -} - -def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { - let summary = "return operation"; - let description = [{ - The "return" operation represents a return operation within a function. - The operation takes variable number of operands and produces no results. - The operand number and types must match the signature of the function - that contains the operation. For example: - - func @foo() : (i32, f8) { - ... - return %0, %1 : i32, f8 - }]; - - let arguments = (ins Variadic:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState *result", [{ build(b, result, llvm::None); }] - >]; -} - -def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> { - let summary = "select operation"; - let description = [{ - The "select" operation chooses one value based on a binary condition - supplied as its first operand. If the value of the first operand is 1, the - second operand is chosen, otherwise the third operand is chosen. The second - and the third operand must have the same type. The operation applies - elementwise to vectors and tensors. The shape of all arguments must be - identical. For example, the maximum operation is obtained by combining - "select" with "cmpi" as follows. - - %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 - %3 = select %2, %0, %1 : i32 - }]; - - let arguments = (ins BoolLike:$condition, AnyType:$true_value, - AnyType:$false_value); - let results = (outs AnyType); - - let builders = [OpBuilder< - "Builder *builder, OperationState *result, Value *condition," - "Value *trueValue, Value *falseValue", [{ - result->addOperands({condition, trueValue, falseValue}); - result->addTypes(trueValue->getType()); - }]>]; - - let extraClassDeclaration = [{ - Value *getCondition() { return condition(); } - Value *getTrueValue() { return true_value(); } - Value *getFalseValue() { return false_value(); } - }]; - - let hasFolder = 1; -} -def ShlISOp : IntArithmeticOp<"shlis"> { - let summary = "signed integer shift left"; -} - -def SubFOp : FloatArithmeticOp<"subf"> { - let summary = "floating point subtraction operation"; - let hasFolder = 1; -} - -def SubIOp : IntArithmeticOp<"subi"> { - let summary = "integer subtraction operation"; - let hasFolder = 1; -} - -def StoreOp : Std_Op<"store"> { - let summary = "store operation"; - let description = [{ - The "store" op writes an element to a memref specified by an index list. - The arity of indices is the rank of the memref (i.e. if the memref being - stored to is of rank 3, then 3 indices are required for the store following - the memref identifier). The store operation does not produce a result. 
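A hedged builder sketch combining "cmpi" and "select" as in the maximum idiom from the "select" description above; it assumes signed i32 operands and an OpBuilder positioned at the insertion point.

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// %1 = cmpi "sgt" %a, %b : i32
// %2 = select %1, %a, %b : i32
Value *emitSignedMax(OpBuilder &builder, Location loc, Value *a, Value *b) {
  auto gt = builder.create<CmpIOp>(loc, CmpIPredicate::SGT, a, b);
  auto max = builder.create<SelectOp>(loc, gt.getResult(), a, b);
  return max.getResult();
}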
- - In the following example, the ssa value '%v' is stored in memref '%A' at - indices [%i, %j]: - store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> - }]; - - let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic:$indices); - - let builders = [OpBuilder< - "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{ - result->addOperands(valueToStore); - result->addOperands(memref); - }]>]; - - let extraClassDeclaration = [{ - Value *getValueToStore() { return getOperand(0); } - - Value *getMemRef() { return getOperand(1); } - void setMemRef(Value *value) { setOperand(1, value); } - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; - } - }]; - - let hasCanonicalizer = 1; -} - -def TensorCastOp : CastOp<"tensor_cast"> { - let summary = "tensor cast operation"; - let description = [{ - The "tensor_cast" operation converts a tensor from one type to an equivalent - type without changing any data elements. The source and destination types - must both be tensor types with the same element type. If both are ranked - then the rank should be the same and static dimensions should match. The - operation is invalid if converting to a mismatching constant dimension. - - Convert from unknown rank to rank 2 with unknown dimension sizes. - %2 = tensor_cast %1 : tensor to tensor - }]; - - let arguments = (ins AnyTensor); - let results = (outs AnyTensor); - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a tensor_cast is always a tensor. - TensorType getType() { return getResult()->getType().cast(); } - }]; -} - -def XOrOp : IntArithmeticOp<"xor", [Commutative]> { - let summary = "integer binary xor"; - let hasFolder = 1; -} - -#endif // STANDARD_OPS diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index ff48a902134..c59d76ae047 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -25,8 +25,8 @@ #ifndef MLIR_TRANSFORMS_UTILS_H #define MLIR_TRANSFORMS_UTILS_H +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineMap.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 51a6ec2aecf..f3af9599b59 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -23,7 +24,6 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 28c4eae941e..e074e5d4405 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -24,12 +24,12 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" +#include 
"mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index b2b2c6970b9..b1e818ac02c 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -21,11 +21,11 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallPtrSet.h" diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 743907ba39c..79620f95373 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -26,10 +26,10 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Support/MathExtras.h" #include "mlir/VectorOps/VectorOps.h" diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index b043d4734fd..85fe3109f6a 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -25,9 +25,9 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "memref-bound-check" diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 18be6cf3bc9..c7c0db90a7b 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -17,7 +17,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/AffineOps/AffineOps.h" -#include "mlir/StandardOps/Ops.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 1802b736fad..9ecdcf7c2fe 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -24,9 +24,9 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "test-memref-dependence-check" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 85e39e37f65..d4fc42ceff7 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -25,8 +25,8 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" +#include 
"mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 23061561dfb..2e85b168a37 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -19,10 +19,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" #include "mlir/VectorOps/VectorOps.h" diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index fece5cbb063..a0a1bdad2f3 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -5,13 +5,10 @@ add_subdirectory(Dialect) add_subdirectory(EDSC) add_subdirectory(ExecutionEngine) add_subdirectory(IR) -add_subdirectory(LLVMIR) add_subdirectory(Linalg) add_subdirectory(Parser) add_subdirectory(Pass) add_subdirectory(Quantizer) -add_subdirectory(SDBM) -add_subdirectory(StandardOps) add_subdirectory(Support) add_subdirectory(TableGen) add_subdirectory(Target) diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index 9535dc7d903..d68c2658f6e 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -22,12 +22,12 @@ #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index a3b80b1e9e0..7073e5e46ee 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -25,12 +25,12 @@ #include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index b819de2471e..12f65c76ad5 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -21,13 +21,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/Module.h" #include 
"mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 32b0caf180a..3ba3e430853 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -21,10 +21,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" -#include "mlir/LLVMIR/NVVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 6ca4cb39f83..13ba898dc44 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -25,9 +25,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/RegionUtils.h" diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 4240e3e7ae7..e33da63f6b7 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -23,13 +23,13 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 035de4f815d..b7dfff4cef3 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -22,7 +22,7 @@ #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/StandardOps/Ops.h" +#include "mlir/Dialect/StandardOps/Ops.h" using namespace mlir; diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td index 9198e8538a1..4cfd5596db3 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -16,7 +16,7 @@ #ifdef STANDARD_OPS #else -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" #endif // STANDARD_OPS #ifdef SPIRV_OPS diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp 
b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 1e4b8ca6419..238bd920341 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -18,6 +18,7 @@ #include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" @@ -26,7 +27,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 8898c43fc1d..7c6a4fafc4d 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,8 +1,11 @@ add_subdirectory(FxpMathOps) add_subdirectory(GPU) +add_subdirectory(LLVMIR) add_subdirectory(LoopOps) add_subdirectory(QuantOps) +add_subdirectory(SDBM) add_subdirectory(SPIRV) +add_subdirectory(StandardOps) add_llvm_library(MLIRDialect Traits.cpp diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index e6c351bd105..83307da957b 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -19,10 +19,10 @@ #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/FxpMathOps/Passes.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; using namespace mlir::fxpmath; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 2fbaa49f56e..22d433a74fc 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -20,13 +20,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; using namespace mlir::gpu; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index ea64ea8058b..481ed247e81 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -21,10 +21,10 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 index 00000000000..4469e7606d3 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,17 @@ +add_llvm_library(MLIRLLVMIR + IR/LLVMDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + ) 
+add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) +target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport) + +add_llvm_library(MLIRNVVMIR + IR/NVVMDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + ) +add_dependencies(MLIRNVVMIR MLIRNVVMOpsIncGen MLIRNVVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) +target_link_libraries(MLIRNVVMIR LLVMAsmParser LLVMCore LLVMSupport) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp new file mode 100644 index 00000000000..906cf344347 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -0,0 +1,1394 @@ +//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the types and operation details for the LLVM IR dialect in +// MLIR, and the LLVM IR dialect. It also registers the dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Support/SourceMgr.h" + +using namespace mlir; +using namespace mlir::LLVM; + +#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::CmpOp. +//===----------------------------------------------------------------------===// +static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) { + *p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) + << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); + p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); + *p << " : " << op.lhs()->getType(); +} + +static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) { + *p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) + << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); + p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); + *p << " : " << op.lhs()->getType(); +} + +// ::= `llvm.icmp` string-literal ssa-use `,` ssa-use +// attribute-dict? `:` type +// ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use +// attribute-dict? 
`:` type +template +static ParseResult parseCmpOp(OpAsmParser *parser, OperationState *result) { + Builder &builder = parser->getBuilder(); + + Attribute predicate; + SmallVector attrs; + OpAsmParser::OperandType lhs, rhs; + Type type; + llvm::SMLoc predicateLoc, trailingTypeLoc; + if (parser->getCurrentLocation(&predicateLoc) || + parser->parseAttribute(predicate, "predicate", attrs) || + parser->parseOperand(lhs) || parser->parseComma() || + parser->parseOperand(rhs) || parser->parseOptionalAttributeDict(attrs) || + parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || + parser->parseType(type) || + parser->resolveOperand(lhs, type, result->operands) || + parser->resolveOperand(rhs, type, result->operands)) + return failure(); + + // Replace the string attribute `predicate` with an integer attribute. + auto predicateStr = predicate.dyn_cast(); + if (!predicateStr) + return parser->emitError(predicateLoc, + "expected 'predicate' attribute of string type"); + + int64_t predicateValue = 0; + if (std::is_same()) { + Optional predicate = + symbolizeICmpPredicate(predicateStr.getValue()); + if (!predicate) + return parser->emitError(predicateLoc) + << "'" << predicateStr.getValue() + << "' is an incorrect value of the 'predicate' attribute"; + predicateValue = static_cast(predicate.getValue()); + } else { + Optional predicate = + symbolizeFCmpPredicate(predicateStr.getValue()); + if (!predicate) + return parser->emitError(predicateLoc) + << "'" << predicateStr.getValue() + << "' is an incorrect value of the 'predicate' attribute"; + predicateValue = static_cast(predicate.getValue()); + } + + attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue); + + // The result type is either i1 or a vector type if the inputs are + // vectors. + auto *dialect = builder.getContext()->getRegisteredDialect(); + auto resultType = LLVMType::getInt1Ty(dialect); + auto argType = type.dyn_cast(); + if (!argType) + return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"); + if (argType.getUnderlyingType()->isVectorTy()) + resultType = LLVMType::getVectorTy( + resultType, argType.getUnderlyingType()->getVectorNumElements()); + + result->attributes = attrs; + result->addTypes({resultType}); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::AllocaOp. +//===----------------------------------------------------------------------===// + +static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) { + auto elemTy = op.getType().cast().getPointerElementTy(); + + auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()}, + op.getContext()); + + *p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy; + if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) + p->printOptionalAttrDict(op.getAttrs()); + else + p->printOptionalAttrDict(op.getAttrs(), {"alignment"}); + *p << " : " << funcTy; +} + +// ::= `llvm.alloca` ssa-use `x` type attribute-dict? 
+// `:` type `,` type +static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) { + SmallVector attrs; + OpAsmParser::OperandType arraySize; + Type type, elemType; + llvm::SMLoc trailingTypeLoc; + if (parser->parseOperand(arraySize) || parser->parseKeyword("x") || + parser->parseType(elemType) || + parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || + parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) + return failure(); + + // Extract the result type from the trailing function type. + auto funcType = type.dyn_cast(); + if (!funcType || funcType.getNumInputs() != 1 || + funcType.getNumResults() != 1) + return parser->emitError( + trailingTypeLoc, + "expected trailing function type with one argument and one result"); + + if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands)) + return failure(); + + result->attributes = attrs; + result->addTypes({funcType.getResult(0)}); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::GEPOp. +//===----------------------------------------------------------------------===// + +static void printGEPOp(OpAsmPrinter *p, GEPOp &op) { + SmallVector types(op.getOperandTypes()); + auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); + + *p << op.getOperationName() << ' ' << *op.base() << '['; + p->printOperands(std::next(op.operand_begin()), op.operand_end()); + *p << ']'; + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << funcTy; +} + +// ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` +// attribute-dict? `:` type +static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) { + SmallVector attrs; + OpAsmParser::OperandType base; + SmallVector indices; + Type type; + llvm::SMLoc trailingTypeLoc; + if (parser->parseOperand(base) || + parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || + parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) + return failure(); + + // Deconstruct the trailing function type to extract the types of the base + // pointer and result (same type) and the types of the indices. + auto funcType = type.dyn_cast(); + if (!funcType || funcType.getNumResults() != 1 || + funcType.getNumInputs() == 0) + return parser->emitError(trailingTypeLoc, + "expected trailing function type with at least " + "one argument and one result"); + + if (parser->resolveOperand(base, funcType.getInput(0), result->operands) || + parser->resolveOperands(indices, funcType.getInputs().drop_front(), + parser->getNameLoc(), result->operands)) + return failure(); + + result->attributes = attrs; + result->addTypes(funcType.getResults()); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::LoadOp. +//===----------------------------------------------------------------------===// + +static void printLoadOp(OpAsmPrinter *p, LoadOp &op) { + *p << op.getOperationName() << ' ' << *op.addr(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.addr()->getType(); +} + +// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return +// the resulting type wrapped in MLIR, or nullptr on error. 
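+// Illustrative example (not part of this patch; the quoted !llvm type syntax
+// is assumed from the dialect printer below): a load/store pair round-trips
+// through this custom syntax roughly as
+//   %0 = llvm.load %ptr : !llvm<"float*">
+//   llvm.store %0, %ptr : !llvm<"float*">
+// where the element type (float here) is recovered from the pointer type by
+// the helper below.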
+static Type getLoadStoreElementType(OpAsmParser *parser, Type type, + llvm::SMLoc trailingTypeLoc) { + auto llvmTy = type.dyn_cast(); + if (!llvmTy) + return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"), + nullptr; + if (!llvmTy.getUnderlyingType()->isPointerTy()) + return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"), + nullptr; + return llvmTy.getPointerElementTy(); +} + +// ::= `llvm.load` ssa-use attribute-dict? `:` type +static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { + SmallVector attrs; + OpAsmParser::OperandType addr; + Type type; + llvm::SMLoc trailingTypeLoc; + + if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || + parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || + parser->parseType(type) || + parser->resolveOperand(addr, type, result->operands)) + return failure(); + + Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); + + result->attributes = attrs; + result->addTypes(elemTy); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::StoreOp. +//===----------------------------------------------------------------------===// + +static void printStoreOp(OpAsmPrinter *p, StoreOp &op) { + *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.addr()->getType(); +} + +// ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type +static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { + SmallVector attrs; + OpAsmParser::OperandType addr, value; + Type type; + llvm::SMLoc trailingTypeLoc; + + if (parser->parseOperand(value) || parser->parseComma() || + parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || + parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || + parser->parseType(type)) + return failure(); + + Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); + if (!elemTy) + return failure(); + + if (parser->resolveOperand(value, elemTy, result->operands) || + parser->resolveOperand(addr, type, result->operands)) + return failure(); + + result->attributes = attrs; + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::CallOp. +//===----------------------------------------------------------------------===// + +static void printCallOp(OpAsmPrinter *p, CallOp &op) { + auto callee = op.callee(); + bool isDirect = callee.hasValue(); + + // Print the direct callee if present as a function attribute, or an indirect + // callee (first operand) otherwise. + *p << op.getOperationName() << ' '; + if (isDirect) + *p << '@' << callee.getValue(); + else + *p << *op.getOperand(0); + + *p << '('; + p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1)); + *p << ')'; + + p->printOptionalAttrDict(op.getAttrs(), {"callee"}); + + // Reconstruct the function MLIR function type from operand and result types. + SmallVector resultTypes(op.getResultTypes()); + SmallVector argTypes( + llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); + + *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); +} + +// ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` +// attribute-dict? 
`:` function-type +static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { + SmallVector attrs; + SmallVector operands; + Type type; + SymbolRefAttr funcAttr; + llvm::SMLoc trailingTypeLoc; + + // Parse an operand list that will, in practice, contain 0 or 1 operand. In + // case of an indirect call, there will be 1 operand before `(`. In case of a + // direct call, there will be no operands and the parser will stop at the + // function identifier without complaining. + if (parser->parseOperandList(operands)) + return failure(); + bool isDirect = operands.empty(); + + // Optionally parse a function identifier. + if (isDirect) + if (parser->parseAttribute(funcAttr, "callee", attrs)) + return failure(); + + if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || + parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) + return failure(); + + auto funcType = type.dyn_cast(); + if (!funcType) + return parser->emitError(trailingTypeLoc, "expected function type"); + if (isDirect) { + // Make sure types match. + if (parser->resolveOperands(operands, funcType.getInputs(), + parser->getNameLoc(), result->operands)) + return failure(); + result->addTypes(funcType.getResults()); + } else { + // Construct the LLVM IR Dialect function type that the first operand + // should match. + if (funcType.getNumResults() > 1) + return parser->emitError(trailingTypeLoc, + "expected function with 0 or 1 result"); + + Builder &builder = parser->getBuilder(); + auto *llvmDialect = + builder.getContext()->getRegisteredDialect(); + LLVM::LLVMType llvmResultType; + if (funcType.getNumResults() == 0) { + llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); + } else { + llvmResultType = funcType.getResult(0).dyn_cast(); + if (!llvmResultType) + return parser->emitError(trailingTypeLoc, + "expected result to have LLVM type"); + } + + SmallVector argTypes; + argTypes.reserve(funcType.getNumInputs()); + for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { + auto argType = funcType.getInput(i).dyn_cast(); + if (!argType) + return parser->emitError(trailingTypeLoc, + "expected LLVM types as inputs"); + argTypes.push_back(argType); + } + auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, + /*isVarArg=*/false); + auto wrappedFuncType = llvmFuncType.getPointerTo(); + + auto funcArguments = + ArrayRef(operands).drop_front(); + + // Make sure that the first operand (indirect callee) matches the wrapped + // LLVM IR function type, and that the types of the other call operands + // match the types of the function arguments. + if (parser->resolveOperand(operands[0], wrappedFuncType, + result->operands) || + parser->resolveOperands(funcArguments, funcType.getInputs(), + parser->getNameLoc(), result->operands)) + return failure(); + + result->addTypes(llvmResultType); + } + + result->attributes = attrs; + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ExtractElementOp. +//===----------------------------------------------------------------------===// +// Expects vector to be of wrapped LLVM vector type and position to be of +// wrapped LLVM i32 type. 
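+// For illustration only (names and the vector type spelling are assumptions,
+// not taken from this patch), an extractelement in this syntax might read
+//   %elt = llvm.extractelement %vec, %idx : !llvm<"<4 x float>">
+// and produce a value of the vector's element type.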
+void LLVM::ExtractElementOp::build(Builder *b, OperationState *result, + Value *vector, Value *position, + ArrayRef attrs) { + auto wrappedVectorType = vector->getType().cast(); + auto llvmType = wrappedVectorType.getVectorElementType(); + build(b, result, llvmType, vector, position); + result->addAttributes(attrs); +} + +static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) { + *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.vector()->getType(); +} + +// ::= `llvm.extractelement` ssa-use `, ` ssa-use +// attribute-dict? `:` type +static ParseResult parseExtractElementOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + OpAsmParser::OperandType vector, position; + auto *llvmDialect = parser->getBuilder() + .getContext() + ->getRegisteredDialect(); + Type type, i32Type = LLVMType::getInt32Ty(llvmDialect); + if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || + parser->parseComma() || parser->parseOperand(position) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(vector, type, result->operands) || + parser->resolveOperand(position, i32Type, result->operands)) + return failure(); + auto wrappedVectorType = type.dyn_cast(); + if (!wrappedVectorType || + !wrappedVectorType.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + result->addTypes(wrappedVectorType.getVectorElementType()); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ExtractValueOp. +//===----------------------------------------------------------------------===// + +static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) { + *p << op.getOperationName() << ' ' << *op.container() << op.position(); + p->printOptionalAttrDict(op.getAttrs(), {"position"}); + *p << " : " << op.container()->getType(); +} + +// Extract the type at `position` in the wrapped LLVM IR aggregate type +// `containerType`. Position is an integer array attribute where each value +// is a zero-based position of the element in the aggregate type. Return the +// resulting type wrapped in MLIR, or nullptr on error. +static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, + Type containerType, + Attribute positionAttr, + llvm::SMLoc attributeLoc, + llvm::SMLoc typeLoc) { + auto wrappedContainerType = containerType.dyn_cast(); + if (!wrappedContainerType) + return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; + + auto positionArrayAttr = positionAttr.dyn_cast(); + if (!positionArrayAttr) + return parser->emitError(attributeLoc, "expected an array attribute"), + nullptr; + + // Infer the element type from the structure type: iteratively step inside the + // type by taking the element type, indexed by the position attribute for + // stuctures. Check the position index before accessing, it is supposed to be + // in bounds. 
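+  // Walk-through (illustrative, not taken from a test): for a container type
+  // such as !llvm<"{i32, [4 x float]}"> and position [1, 2], the loop below
+  // first selects the array member [4 x float] and then its element type
+  // float, which becomes the type of the extracted or inserted value.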
+ for (Attribute subAttr : positionArrayAttr) { + auto positionElementAttr = subAttr.dyn_cast(); + if (!positionElementAttr) + return parser->emitError(attributeLoc, + "expected an array of integer literals"), + nullptr; + int position = positionElementAttr.getInt(); + auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); + if (llvmContainerType->isArrayTy()) { + if (position < 0 || static_cast(position) >= + llvmContainerType->getArrayNumElements()) + return parser->emitError(attributeLoc, "position out of bounds"), + nullptr; + wrappedContainerType = wrappedContainerType.getArrayElementType(); + } else if (llvmContainerType->isStructTy()) { + if (position < 0 || static_cast(position) >= + llvmContainerType->getStructNumElements()) + return parser->emitError(attributeLoc, "position out of bounds"), + nullptr; + wrappedContainerType = + wrappedContainerType.getStructElementType(position); + } else { + return parser->emitError(typeLoc, + "expected wrapped LLVM IR structure/array type"), + nullptr; + } + } + return wrappedContainerType; +} + +// ::= `llvm.extractvalue` ssa-use +// `[` integer-literal (`,` integer-literal)* `]` +// attribute-dict? `:` type +static ParseResult parseExtractValueOp(OpAsmParser *parser, + OperationState *result) { + SmallVector attrs; + OpAsmParser::OperandType container; + Type containerType; + Attribute positionAttr; + llvm::SMLoc attributeLoc, trailingTypeLoc; + + if (parser->parseOperand(container) || + parser->getCurrentLocation(&attributeLoc) || + parser->parseAttribute(positionAttr, "position", attrs) || + parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || + parser->getCurrentLocation(&trailingTypeLoc) || + parser->parseType(containerType) || + parser->resolveOperand(container, containerType, result->operands)) + return failure(); + + auto elementType = getInsertExtractValueElementType( + parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); + if (!elementType) + return failure(); + + result->attributes = attrs; + result->addTypes(elementType); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::InsertElementOp. +//===----------------------------------------------------------------------===// + +static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) { + *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value() + << ", " << *op.position(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.vector()->getType(); +} + +// ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use +// attribute-dict? 
`:` type +static ParseResult parseInsertElementOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + OpAsmParser::OperandType vector, value, position; + auto *llvmDialect = parser->getBuilder() + .getContext() + ->getRegisteredDialect(); + Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect); + if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || + parser->parseComma() || parser->parseOperand(value) || + parser->parseComma() || parser->parseOperand(position) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(vectorType)) + return failure(); + + auto wrappedVectorType = vectorType.dyn_cast(); + if (!wrappedVectorType || + !wrappedVectorType.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + auto valueType = wrappedVectorType.getVectorElementType(); + if (!valueType) + return failure(); + + if (parser->resolveOperand(vector, vectorType, result->operands) || + parser->resolveOperand(value, valueType, result->operands) || + parser->resolveOperand(position, i32Type, result->operands)) + return failure(); + + result->addTypes(vectorType); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::InsertValueOp. +//===----------------------------------------------------------------------===// + +static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) { + *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container() + << op.position(); + p->printOptionalAttrDict(op.getAttrs(), {"position"}); + *p << " : " << op.container()->getType(); +} + +// ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use +// `[` integer-literal (`,` integer-literal)* `]` +// attribute-dict? `:` type +static ParseResult parseInsertValueOp(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType container, value; + Type containerType; + Attribute positionAttr; + llvm::SMLoc attributeLoc, trailingTypeLoc; + + if (parser->parseOperand(value) || parser->parseComma() || + parser->parseOperand(container) || + parser->getCurrentLocation(&attributeLoc) || + parser->parseAttribute(positionAttr, "position", result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || + parser->parseType(containerType)) + return failure(); + + auto valueType = getInsertExtractValueElementType( + parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); + if (!valueType) + return failure(); + + if (parser->resolveOperand(container, containerType, result->operands) || + parser->resolveOperand(value, valueType, result->operands)) + return failure(); + + result->addTypes(containerType); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::SelectOp. +//===----------------------------------------------------------------------===// + +static void printSelectOp(OpAsmPrinter *p, SelectOp &op) { + *p << op.getOperationName() << ' ' << *op.condition() << ", " + << *op.trueValue() << ", " << *op.falseValue(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType(); +} + +// ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use +// attribute-dict? 
`:` type, type +static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType condition, trueValue, falseValue; + Type conditionType, argType; + + if (parser->parseOperand(condition) || parser->parseComma() || + parser->parseOperand(trueValue) || parser->parseComma() || + parser->parseOperand(falseValue) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(conditionType) || parser->parseComma() || + parser->parseType(argType)) + return failure(); + + if (parser->resolveOperand(condition, conditionType, result->operands) || + parser->resolveOperand(trueValue, argType, result->operands) || + parser->resolveOperand(falseValue, argType, result->operands)) + return failure(); + + result->addTypes(argType); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::BrOp. +//===----------------------------------------------------------------------===// + +static void printBrOp(OpAsmPrinter *p, BrOp &op) { + *p << op.getOperationName() << ' '; + p->printSuccessorAndUseList(op.getOperation(), 0); + p->printOptionalAttrDict(op.getAttrs()); +} + +// ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? +// attribute-dict? +static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) { + Block *dest; + SmallVector operands; + if (parser->parseSuccessorAndUseList(dest, operands) || + parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + result->addSuccessor(dest, operands); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::CondBrOp. +//===----------------------------------------------------------------------===// + +static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) { + *p << op.getOperationName() << ' ' << *op.getOperand(0) << ", "; + p->printSuccessorAndUseList(op.getOperation(), 0); + *p << ", "; + p->printSuccessorAndUseList(op.getOperation(), 1); + p->printOptionalAttrDict(op.getAttrs()); +} + +// ::= `llvm.cond_br` ssa-use `,` +// bb-id (`[` ssa-use-and-type-list `]`)? `,` +// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? +static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) { + Block *trueDest; + Block *falseDest; + SmallVector trueOperands; + SmallVector falseOperands; + OpAsmParser::OperandType condition; + + Builder &builder = parser->getBuilder(); + auto *llvmDialect = + builder.getContext()->getRegisteredDialect(); + auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); + + if (parser->parseOperand(condition) || parser->parseComma() || + parser->parseSuccessorAndUseList(trueDest, trueOperands) || + parser->parseComma() || + parser->parseSuccessorAndUseList(falseDest, falseOperands) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->resolveOperand(condition, i1Type, result->operands)) + return failure(); + + result->addSuccessor(trueDest, trueOperands); + result->addSuccessor(falseDest, falseOperands); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ReturnOp. 
+//===----------------------------------------------------------------------===// + +static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) { + *p << op.getOperationName(); + p->printOptionalAttrDict(op.getAttrs()); + assert(op.getNumOperands() <= 1); + + if (op.getNumOperands() == 0) + return; + + *p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType(); +} + +// ::= `llvm.return` ssa-use-list attribute-dict? `:` +// type-list-no-parens +static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { + SmallVector operands; + Type type; + + if (parser->parseOperandList(operands) || + parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + if (operands.empty()) + return success(); + + if (parser->parseColonType(type) || + parser->resolveOperand(operands[0], type, result->operands)) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::UndefOp. +//===----------------------------------------------------------------------===// + +static void printUndefOp(OpAsmPrinter *p, UndefOp &op) { + *p << op.getOperationName(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.res()->getType(); +} + +// ::= `llvm.undef` attribute-dict? : type +static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) { + Type type; + + if (parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return failure(); + + result->addTypes(type); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printer, parser and verifier for LLVM::AddressOfOp. +//===----------------------------------------------------------------------===// + +GlobalOp AddressOfOp::getGlobal() { + auto module = getParentOfType(); + assert(module && "unexpected operation outside of a module"); + return module.lookupSymbol(global_name()); +} + +static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) { + *p << op.getOperationName() << " @" << op.global_name(); + p->printOptionalAttrDict(op.getAttrs(), {"global_name"}); + *p << " : " << op.getResult()->getType(); +} + +static ParseResult parseAddressOfOp(OpAsmParser *parser, + OperationState *result) { + Attribute symRef; + Type type; + if (parser->parseAttribute(symRef, "global_name", result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->addTypeToList(type, result->types)) + return failure(); + + if (!symRef.isa()) + return parser->emitError(parser->getNameLoc(), "expected symbol reference"); + return success(); +} + +static LogicalResult verify(AddressOfOp op) { + auto global = op.getGlobal(); + if (!global) + return op.emitOpError("must reference a global defined by 'llvm.global'"); + + if (global.getType().getPointerTo() != op.getResult()->getType()) + return op.emitOpError( + "the type must be a pointer to the type of the referred global"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ConstantOp. +//===----------------------------------------------------------------------===// + +static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) { + *p << op.getOperationName() << '(' << op.value() << ')'; + p->printOptionalAttrDict(op.getAttrs(), {"value"}); + *p << " : " << op.res()->getType(); +} + +// ::= `llvm.constant` `(` attribute `)` attribute-list? 
: type +static ParseResult parseConstantOp(OpAsmParser *parser, + OperationState *result) { + Attribute valueAttr; + Type type; + + if (parser->parseLParen() || + parser->parseAttribute(valueAttr, "value", result->attributes) || + parser->parseRParen() || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return failure(); + + result->addTypes(type); + return success(); +} + +//===----------------------------------------------------------------------===// +// Builder, printer and verifier for LLVM::GlobalOp. +//===----------------------------------------------------------------------===// + +void GlobalOp::build(Builder *builder, OperationState *result, LLVMType type, + bool isConstant, StringRef name, Attribute value, + ArrayRef attrs) { + result->addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + result->addAttribute("type", builder->getTypeAttr(type)); + if (isConstant) + result->addAttribute("constant", builder->getUnitAttr()); + result->addAttribute("value", value); + result->attributes.append(attrs.begin(), attrs.end()); +} + +static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) { + *p << op.getOperationName() << ' '; + if (op.constant()) + *p << "constant "; + *p << '@' << op.sym_name() << '('; + p->printAttribute(op.value()); + *p << ')'; + p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(), + "type", "constant", "value"}); + + // Print the trailing type unless it's a string global. + if (op.value().isa()) + return; + *p << " : "; + p->printType(op.type()); +} + +// ::= `llvm.global` `constant`? `@` identifier `(` attribute `)` +// attribute-list? (`:` type)? +// +// The type can be omitted for string attributes, in which case it will be +// inferred from the value of the string as [strlen(value) x i8]. 
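+// Illustrative forms this parser would accept (assuming the quoted !llvm type
+// syntax used elsewhere in this file; the names are made up):
+//   llvm.global constant @msg("hello")
+//   llvm.global @count(0 : i64) : !llvm<"i64">
+// In the first case the type is omitted and inferred as [5 x i8] from the
+// length of the string attribute.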
+static ParseResult parseGlobalOp(OpAsmParser *parser, OperationState *result) { + if (succeeded(parser->parseOptionalKeyword("constant"))) + result->addAttribute("constant", parser->getBuilder().getUnitAttr()); + + Attribute value; + StringAttr name; + SmallVector types; + if (parser->parseSymbolName(name, SymbolTable::getSymbolAttrName(), + result->attributes) || + parser->parseLParen() || + parser->parseAttribute(value, "value", result->attributes) || + parser->parseRParen() || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseOptionalColonTypeList(types)) + return failure(); + + if (types.size() > 1) + return parser->emitError(parser->getNameLoc(), "expected zero or one type"); + + if (types.empty()) { + if (auto strAttr = value.dyn_cast()) { + MLIRContext *context = parser->getBuilder().getContext(); + auto *dialect = context->getRegisteredDialect(); + auto arrayType = LLVM::LLVMType::getArrayTy( + LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); + types.push_back(arrayType); + } else { + return parser->emitError(parser->getNameLoc(), + "type can only be omitted for string globals"); + } + } + + result->addAttribute("type", parser->getBuilder().getTypeAttr(types[0])); + return success(); +} + +static LogicalResult verify(GlobalOp op) { + if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) + return op.emitOpError( + "expects type to be a valid element type for an LLVM pointer"); + if (op.getParentOp() && !isa(op.getParentOp())) + return op.emitOpError("must appear at the module level"); + if (auto strAttr = op.value().dyn_cast()) { + auto type = op.getType(); + if (!type.getUnderlyingType()->isArrayTy() || + !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || + type.getArrayNumElements() != strAttr.getValue().size()) + return op.emitOpError( + "requires an i8 array type of the length equal to that of the string " + "attribute"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ShuffleVectorOp. +//===----------------------------------------------------------------------===// +// Expects vector to be of wrapped LLVM vector type and position to be of +// wrapped LLVM i32 type. +void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1, + Value *v2, ArrayAttr mask, + ArrayRef attrs) { + auto wrappedContainerType1 = v1->getType().cast(); + auto vType = LLVMType::getVectorTy( + wrappedContainerType1.getVectorElementType(), mask.size()); + build(b, result, vType, v1, v2, mask); + result->addAttributes(attrs); +} + +static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) { + *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " " + << op.mask(); + p->printOptionalAttrDict(op.getAttrs(), {"mask"}); + *p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); +} + +// ::= `llvm.shufflevector` ssa-use `, ` ssa-use +// `[` integer-literal (`,` integer-literal)* `]` +// attribute-dict? 
`:` type +static ParseResult parseShuffleVectorOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + SmallVector attrs; + OpAsmParser::OperandType v1, v2; + Attribute maskAttr; + Type typeV1, typeV2; + if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) || + parser->parseComma() || parser->parseOperand(v2) || + parser->parseAttribute(maskAttr, "mask", attrs) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(typeV1) || parser->parseComma() || + parser->parseType(typeV2) || + parser->resolveOperand(v1, typeV1, result->operands) || + parser->resolveOperand(v2, typeV2, result->operands)) + return failure(); + auto wrappedContainerType1 = typeV1.dyn_cast(); + if (!wrappedContainerType1 || + !wrappedContainerType1.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + auto vType = + LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(), + maskAttr.cast().size()); + result->attributes = attrs; + result->addTypes(vType); + return success(); +} + +//===----------------------------------------------------------------------===// +// Builder, printer and verifier for LLVM::LLVMFuncOp. +//===----------------------------------------------------------------------===// + +void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, + LLVMType type, ArrayRef attrs, + ArrayRef argAttrs) { + result->addRegion(); + result->addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + result->addAttribute("type", builder->getTypeAttr(type)); + result->attributes.append(attrs.begin(), attrs.end()); + if (argAttrs.empty()) + return; + + unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); + assert(numInputs == argAttrs.size() && + "expected as many argument attribute lists as arguments"); + SmallString<8> argAttrName; + for (unsigned i = 0; i < numInputs; ++i) + if (auto argDict = argAttrs[i].getDictionary()) + result->addAttribute(getArgAttrName(i, argAttrName), argDict); +} + +// Build an LLVM function type from the given lists of input and output types. +// Returns a null type if any of the types provided are non-LLVM types, or if +// there is more than one output type. +static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, + ArrayRef outputs, + impl::VariadicFlag variadicFlag, + std::string &errorMessage) { + if (outputs.size() > 1) { + errorMessage = "expected zero or one function result"; + return {}; + } + + // Convert inputs to LLVM types, exit early on error. + SmallVector llvmInputs; + for (auto t : inputs) { + auto llvmTy = t.dyn_cast(); + if (!llvmTy) { + errorMessage = "expected LLVM type for function arguments"; + return {}; + } + llvmInputs.push_back(llvmTy); + } + + // Get the dialect from the input type, if any exist. Look it up in the + // context otherwise. + LLVMDialect *dialect = + llvmInputs.empty() ? b.getContext()->getRegisteredDialect() + : &llvmInputs.front().getDialect(); + + // No output is denoted as "void" in LLVM type system. + LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) + : outputs.front().dyn_cast(); + if (!llvmOutput) { + errorMessage = "expected LLVM type for function results"; + return {}; + } + return LLVMType::getFunctionTy(llvmOutput, llvmInputs, + variadicFlag.isVariadic()); +} + +// Print the LLVMFuncOp. Collects argument and result types and passes them +// to the trait printer. Drops "void" result since it cannot be parsed back. 
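+// For example (illustrative, with an assumed argument type spelling), a
+// function with no results would print as
+//   llvm.func @foo(%arg0: !llvm<"i32">)
+// rather than with an explicit void result, which this printer drops.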
+static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { + LLVMType fnType = op.getType(); + SmallVector argTypes; + SmallVector resTypes; + argTypes.reserve(fnType.getFunctionNumParams()); + for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) + argTypes.push_back(fnType.getFunctionParamType(i)); + + LLVMType returnType = fnType.getFunctionResultType(); + if (!returnType.getUnderlyingType()->isVoidTy()) + resTypes.push_back(returnType); + + impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes); +} + +// Hook for OpTrait::FunctionLike, called after verifying that the 'type' +// attribute is present. This can check for preconditions of the +// getNumArguments hook not failing. +LogicalResult LLVMFuncOp::verifyType() { + auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); + if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) + return emitOpError("requires '" + getTypeAttrName() + + "' attribute of wrapped LLVM function type"); + + return success(); +} + +// Hook for OpTrait::FunctionLike, returns the number of function arguments. +// Depends on the type attribute being correct as checked by verifyType +unsigned LLVMFuncOp::getNumFuncArguments() { + return getType().getUnderlyingType()->getFunctionNumParams(); +} + +static LogicalResult verify(LLVMFuncOp op) { + if (op.isExternal()) + return success(); + + if (op.isVarArg()) + return op.emitOpError("only external functions can be variadic"); + + auto *funcType = cast(op.getType().getUnderlyingType()); + unsigned numArguments = funcType->getNumParams(); + Block &entryBlock = op.front(); + for (unsigned i = 0; i < numArguments; ++i) { + Type argType = entryBlock.getArgument(i)->getType(); + auto argLLVMType = argType.dyn_cast(); + if (!argLLVMType) + return op.emitOpError("entry block argument #") + << i << " is not of LLVM type"; + if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) + return op.emitOpError("the type of entry block argument #") + << i << " does not match the function signature"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// LLVMDialect initialization, type parsing, and registration. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace LLVM { +namespace detail { +struct LLVMDialectImpl { + LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} + + llvm::LLVMContext llvmContext; + llvm::Module module; + + /// A set of LLVMTypes that are cached on construction to avoid any lookups or + /// locking. + LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; + LLVMType doubleTy, floatTy, halfTy; + LLVMType voidTy; + + /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not + /// multi-threaded and requires locked access to prevent race conditions. + llvm::sys::SmartMutex mutex; +}; +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + +LLVMDialect::LLVMDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context), + impl(new detail::LLVMDialectImpl()) { + addTypes(); + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" + >(); + + // Support unknown operations because not all LLVM operations are registered. + allowUnknownOperations(); + + // Cache some of the common LLVM types to avoid the need for lookups/locking. + auto &llvmContext = impl->llvmContext; + /// Integer Types. 
+ impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext)); + impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext)); + impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext)); + impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext)); + impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext)); + impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext)); + /// Float Types. + impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); + impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); + impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); + /// Other Types. + impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); +} + +LLVMDialect::~LLVMDialect() {} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" + +llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } +llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } + +/// Parse a type registered to this dialect. +Type LLVMDialect::parseType(StringRef tyData, Location loc) const { + // LLVM is not thread-safe, so lock access to it. + llvm::sys::SmartScopedLock lock(impl->mutex); + + llvm::SMDiagnostic errorMessage; + llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); + if (!type) + return (emitError(loc, errorMessage.getMessage()), nullptr); + return LLVMType::get(getContext(), type); +} + +/// Print a type registered to this dialect. +void LLVMDialect::printType(Type type, raw_ostream &os) const { + auto llvmType = type.dyn_cast(); + assert(llvmType && "printing wrong type"); + assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); + llvmType.getUnderlyingType()->print(os); +} + +/// Verify LLVMIR function argument attributes. +LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, + unsigned regionIdx, + unsigned argIdx, + NamedAttribute argAttr) { + // Check that llvm.noalias is a boolean attribute. + if (argAttr.first == "llvm.noalias" && !argAttr.second.isa()) + return op->emitError() + << "llvm.noalias argument attribute of non boolean type"; + return success(); +} + +static DialectRegistration llvmDialect; + +//===----------------------------------------------------------------------===// +// LLVMType. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace LLVM { +namespace detail { +struct LLVMTypeStorage : public ::mlir::TypeStorage { + LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} + + // LLVM types are pointer-unique. + using KeyTy = llvm::Type *; + bool operator==(const KeyTy &key) const { return key == underlyingType; } + + static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, + llvm::Type *ty) { + return new (allocator.allocate()) LLVMTypeStorage(ty); + } + + llvm::Type *underlyingType; +}; +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + +LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { + return Base::get(context, FIRST_LLVM_TYPE, llvmType); +} + +/// Get an LLVMType with an llvm type that may cause changes to the underlying +/// llvm context when constructed. +LLVMType LLVMType::getLocked(LLVMDialect *dialect, + llvm::function_ref typeBuilder) { + // Lock access to the llvm context and build the type. 
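+  // The type builder callback typically wraps an llvm::Type factory (for
+  // example llvm::ArrayType::get in getArrayTy below) that may create new
+  // types in the LLVM context, hence the locking.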
+ llvm::sys::SmartScopedLock lock(dialect->impl->mutex); + return get(dialect->getContext(), typeBuilder()); +} + +LLVMDialect &LLVMType::getDialect() { + return static_cast(Type::getDialect()); +} + +llvm::Type *LLVMType::getUnderlyingType() const { + return getImpl()->underlyingType; +} + +/// Array type utilities. +LLVMType LLVMType::getArrayElementType() { + return get(getContext(), getUnderlyingType()->getArrayElementType()); +} +unsigned LLVMType::getArrayNumElements() { + return getUnderlyingType()->getArrayNumElements(); +} + +/// Vector type utilities. +LLVMType LLVMType::getVectorElementType() { + return get(getContext(), getUnderlyingType()->getVectorElementType()); +} + +/// Function type utilities. +LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { + return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); +} +unsigned LLVMType::getFunctionNumParams() { + return getUnderlyingType()->getFunctionNumParams(); +} +LLVMType LLVMType::getFunctionResultType() { + return get( + getContext(), + llvm::cast(getUnderlyingType())->getReturnType()); +} + +/// Pointer type utilities. +LLVMType LLVMType::getPointerTo(unsigned addrSpace) { + // Lock access to the dialect as this may modify the LLVM context. + return getLocked(&getDialect(), [=] { + return getUnderlyingType()->getPointerTo(addrSpace); + }); +} +LLVMType LLVMType::getPointerElementTy() { + return get(getContext(), getUnderlyingType()->getPointerElementType()); +} + +/// Struct type utilities. +LLVMType LLVMType::getStructElementType(unsigned i) { + return get(getContext(), getUnderlyingType()->getStructElementType(i)); +} + +/// Utilities used to generate floating point types. +LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { + return dialect->impl->doubleTy; +} +LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { + return dialect->impl->floatTy; +} +LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { + return dialect->impl->halfTy; +} + +/// Utilities used to generate integer types. +LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { + switch (numBits) { + case 1: + return dialect->impl->int1Ty; + case 8: + return dialect->impl->int8Ty; + case 16: + return dialect->impl->int16Ty; + case 32: + return dialect->impl->int32Ty; + case 64: + return dialect->impl->int64Ty; + case 128: + return dialect->impl->int128Ty; + default: + break; + } + + // Lock access to the dialect as this may modify the LLVM context. + return getLocked(dialect, [=] { + return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits); + }); +} + +/// Utilities used to generate other miscellaneous types. +LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { + // Lock access to the dialect as this may modify the LLVM context. + return getLocked(&elementType.getDialect(), [=] { + return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements); + }); +} +LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef params, + bool isVarArg) { + SmallVector llvmParams; + for (auto param : params) + llvmParams.push_back(param.getUnderlyingType()); + + // Lock access to the dialect as this may modify the LLVM context. 
+ return getLocked(&result.getDialect(), [=] { + return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, + isVarArg); + }); +} +LLVMType LLVMType::getStructTy(LLVMDialect *dialect, + ArrayRef elements, bool isPacked) { + SmallVector llvmElements; + for (auto elt : elements) + llvmElements.push_back(elt.getUnderlyingType()); + + // Lock access to the dialect as this may modify the LLVM context. + return getLocked(dialect, [=] { + return llvm::StructType::get(dialect->getLLVMContext(), llvmElements, + isPacked); + }); +} +LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { + // Lock access to the dialect as this may modify the LLVM context. + return getLocked(&elementType.getDialect(), [=] { + return llvm::VectorType::get(elementType.getUnderlyingType(), numElements); + }); +} +LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { + return dialect->impl->voidTy; +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp new file mode 100644 index 00000000000..8d6f308e5b3 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -0,0 +1,88 @@ +//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the types and operation details for the NVVM IR dialect in +// MLIR, and the LLVM IR dialect. It also registers the dialect. +// +// The NVVM dialect only contains GPU specific additions on top of the general +// LLVM dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace NVVM { + +//===----------------------------------------------------------------------===// +// Printing/parsing for NVVM ops +//===----------------------------------------------------------------------===// + +static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) { + *p << op->getName() << " : "; + if (op->getNumResults() == 1) { + *p << op->getResult(0)->getType(); + } else { + *p << "###invalid type###"; + } +} + +// ::= `llvm.nvvm.XYZ` : type +static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser, + OperationState *result) { + Type type; + if (parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return failure(); + + result->addTypes(type); + return success(); +} + +//===----------------------------------------------------------------------===// +// NVVMDialect initialization, type parsing, and registration. +//===----------------------------------------------------------------------===// + +// TODO(herhut): This should be the llvm.nvvm dialect once this is supported. +NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" + >(); + + // Support unknown operations because not all NVVM operations are registered. 
+ allowUnknownOperations(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" + +static DialectRegistration nvvmDialect; + +} // namespace NVVM +} // namespace mlir diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp index 13dc35ec7ce..4d99cac3a04 100644 --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -26,7 +27,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index efb202b7491..e3a17b057d4 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -19,12 +19,12 @@ #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/QuantOps/QuantizeUtils.h" #include "mlir/Dialect/QuantOps/UniformSupport.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/SDBM/CMakeLists.txt b/mlir/lib/Dialect/SDBM/CMakeLists.txt new file mode 100644 index 00000000000..e36308e0eda --- /dev/null +++ b/mlir/lib/Dialect/SDBM/CMakeLists.txt @@ -0,0 +1,10 @@ +add_llvm_library(MLIRSDBM + SDBM.cpp + SDBMExpr.cpp + SDBMDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SDBM +) +add_dependencies(MLIRSDBM MLIRIR) +target_link_libraries(MLIRSDBM MLIRIR) diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp new file mode 100644 index 00000000000..5450a61b17b --- /dev/null +++ b/mlir/lib/Dialect/SDBM/SDBM.cpp @@ -0,0 +1,561 @@ +//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined +// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. 
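+// For instance, f(x_1, x_2) = x_2 - x_1 defines the half-space
+// {(x_1, x_2) | x_1 <= x_2}.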
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SDBM/SDBM.h" +#include "mlir/Dialect/SDBM/SDBMExpr.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +// Helper function for SDBM construction that collects information necessary to +// start building an SDBM in one sweep. In particular, it records the largest +// position of a dimension in `dim`, that of a symbol in `symbol` as well as +// collects all unique stripe expressions in `stripes`. Uses SetVector to +// ensure these expressions always have the same order. +static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol, + llvm::SmallSetVector &stripes) { + struct Visitor : public SDBMVisitor { + void visitDim(SDBMDimExpr dimExpr) { + int p = dimExpr.getPosition(); + if (p > maxDimPosition) + maxDimPosition = p; + } + void visitSymbol(SDBMSymbolExpr symbExpr) { + int p = symbExpr.getPosition(); + if (p > maxSymbPosition) + maxSymbPosition = p; + } + void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); } + + Visitor(llvm::SmallSetVector &stripes) : stripes(stripes) {} + + int maxDimPosition = -1; + int maxSymbPosition = -1; + llvm::SmallSetVector &stripes; + }; + + Visitor visitor(stripes); + visitor.walkPostorder(expr); + dim = std::max(dim, visitor.maxDimPosition); + symbol = std::max(symbol, visitor.maxSymbPosition); +} + +namespace { +// Utility class for SDBMBuilder. Represents a value that can be inserted in +// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is +// any combination of the positive and negative positions. Since multiple +// variables can be declared equal to the same stripe expression, the +// constraints on this expression must be reflected to all these variables. For +// example, if +// d0 = s0 # 42 +// d1 = s0 # 42 +// d2 = s1 # 2 +// d3 = s1 # 2 +// the constraint +// s0 # 42 - s1 # 2 <= C +// should be reflected in the DB matrix as +// d0 - d2 <= C +// d1 - d2 <= C +// d0 - d3 <= C +// d1 - d3 <= C +// since the DB matrix has no knowledge of the transitive equality between d0, +// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be +// obtained by computing a transitive closure, which is impossible until the +// DBM is actually built. +struct SDBMBuilderResult { + // Positions in the matrix of the variables taken with the "+" sign in the + // difference expression, 0 if it is a constant rather than a variable. + llvm::SmallVector positivePos; + + // Positions in the matrix of the variables taken with the "-" sign in the + // difference expression, 0 if it is a constant rather than a variable. + llvm::SmallVector negativePos; + + // Constant value in the difference expression. + int64_t value = 0; +}; + +// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM +// expression, produces an update to the SDB matrix specifying the positions in +// the matrix and the negated value that should be stored. Both the positive +// and the negative positions may be lists of indices in cases where multiple +// variables are equal to the same stripe expression. In such cases, the update +// applies to the cross product of positions because elements involved in the +// update are (transitively) equal and should have the same constraints, but we +// may not have an explicit equality for them. 
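+// In the example above, visiting s0 # 42 - s1 # 2 produces a result whose
+// positivePos holds the positions of d0 and d1 and whose negativePos holds the
+// positions of d2 and d3, so the update is applied to all four
+// (positive, negative) pairs.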
+struct SDBMBuilder : public SDBMVisitor { +public: + // A difference expression produces both the positive and the negative + // coordinate in the matrix, recursively traversing the LHS and the RHS. The + // value is the difference between values obtained from LHS and RHS. + SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) { + auto lhs = visit(diffExpr.getLHS()); + auto rhs = visit(diffExpr.getRHS()); + assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && + "unexpected negative expression in a difference expression"); + assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && + "unexpected negative expression in a difference expression"); + + SDBMBuilderResult result; + result.positivePos = lhs.positivePos; + result.negativePos = rhs.positivePos; + result.value = lhs.value - rhs.value; + return result; + } + + // An input expression is always taken with the "+" sign and therefore + // produces a positive coordinate keeping the negative coordinate zero for an + // eventual constant. + SDBMBuilderResult visitInput(SDBMInputExpr expr) { + SDBMBuilderResult r; + r.positivePos.push_back(linearPosition(expr)); + r.negativePos.push_back(0); + return r; + } + + // A stripe expression is always equal to one or more variables, which may be + // temporaries, and appears with a "+" sign in the SDBM expression tree. Take + // the positions of the corresponding variables as positive coordinates. + SDBMBuilderResult visitStripe(SDBMStripeExpr expr) { + SDBMBuilderResult r; + assert(pointExprToStripe.count(expr)); + r.positivePos = pointExprToStripe[expr]; + r.negativePos.push_back(0); + return r; + } + + // A constant expression has both coordinates at zero. + SDBMBuilderResult visitConstant(SDBMConstantExpr expr) { + SDBMBuilderResult r; + r.positivePos.push_back(0); + r.negativePos.push_back(0); + r.value = expr.getValue(); + return r; + } + + // A negation expression swaps the positive and the negative coordinates + // and also negates the constant value. + SDBMBuilderResult visitNeg(SDBMNegExpr expr) { + SDBMBuilderResult result = visit(expr.getVar()); + std::swap(result.positivePos, result.negativePos); + result.value = -result.value; + return result; + } + + // The RHS of a sum expression must be a constant and therefore must have both + // positive and negative coordinates at zero. Take the sum of the values + // between LHS and RHS and keep LHS coordinates. + SDBMBuilderResult visitSum(SDBMSumExpr expr) { + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + for (auto pos : rhs.negativePos) { + (void)pos; + assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); + } + for (auto pos : rhs.positivePos) { + (void)pos; + assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); + } + + lhs.value += rhs.value; + return lhs; + } + + SDBMBuilder(llvm::DenseMap> + &pointExprToStripe, + llvm::function_ref callback) + : pointExprToStripe(pointExprToStripe), linearPosition(callback) {} + + llvm::DenseMap> &pointExprToStripe; + llvm::function_ref linearPosition; +}; +} // namespace + +SDBM SDBM::get(ArrayRef inequalities, ArrayRef equalities) { + SDBM result; + + // TODO(zinenko): consider detecting equalities in the list of inequalities. 
+ // This is potentially expensive and requires to + // - create a list of negated inequalities (may allocate under lock); + // - perform a pairwise comparison of direct and negated inequalities; + // - copy the lists of equalities and inequalities, and move entries between + // them; + // only for the purpose of sparing a temporary variable in cases where an + // implicit equality between a variable and a stripe expression is present in + // the input. + + // Do the first sweep over (in)equalities to collect the information necessary + // to allocate the SDB matrix (number of dimensions, symbol and temporary + // variables required for stripe expressions). + llvm::SmallSetVector stripes; + int maxDim = -1; + int maxSymbol = -1; + for (auto expr : inequalities) + collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); + for (auto expr : equalities) + collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); + // Indexing of dimensions starts with 0, obtain the number of dimensions by + // incrementing the maximal position of the dimension seen in expressions. + result.numDims = maxDim + 1; + result.numSymbols = maxSymbol + 1; + result.numTemporaries = 0; + + // Helper function that returns the position of the variable represented by + // an SDBM input expression. + auto linearPosition = [result](SDBMInputExpr expr) { + if (expr.isa()) + return result.getDimPosition(expr.getPosition()); + return result.getSymbolPosition(expr.getPosition()); + }; + + // Check if some stripe expressions are equal to another variable. In + // particular, look for the equalities of the form + // d0 - stripe-expression = 0, or + // stripe-expression - d0 = 0. + // There may be multiple variables that are equal to the same stripe + // expression. Keep track of those in pointExprToStripe. + // There may also be multiple stripe expressions equal to the same variable. + // Introduce a temporary variable for each of those. + llvm::DenseMap> pointExprToStripe; + unsigned numTemporaries = 0; + + auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe, + linearPosition](SDBMInputExpr input, + SDBMExpr expr) { + unsigned position = linearPosition(input); + if (result.stripeToPoint.count(position) && + result.stripeToPoint[position] != expr) { + position = result.getNumVariables() + numTemporaries++; + } + pointExprToStripe[expr].push_back(position); + result.stripeToPoint.insert(std::make_pair(position, expr)); + }; + + for (auto eq : equalities) { + auto diffExpr = eq.dyn_cast(); + if (!diffExpr) + continue; + + auto lhs = diffExpr.getLHS(); + auto rhs = diffExpr.getRHS(); + auto lhsInput = lhs.dyn_cast(); + auto rhsInput = rhs.dyn_cast(); + + if (lhsInput && stripes.count(rhs)) + updateStripePointMaps(lhsInput, rhs); + if (rhsInput && stripes.count(lhs)) + updateStripePointMaps(rhsInput, lhs); + } + + // Assign the remaining stripe expressions to temporary variables. These + // expressions are the ones that could not be associated with an existing + // variable in the previous step. + for (auto expr : stripes) { + if (pointExprToStripe.count(expr)) + continue; + unsigned position = result.getNumVariables() + numTemporaries++; + pointExprToStripe[expr].push_back(position); + result.stripeToPoint.insert(std::make_pair(position, expr)); + } + + // Create the DBM matrix, initialized to infinity values for the least tight + // possible bound (x - y <= infinity is always true). 
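+  // Position 0 in the matrix is the special "constant" column/row; it is
+  // followed by dimensions, then symbols, then temporaries (see getVarName in
+  // print() below).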
+ result.numTemporaries = numTemporaries; + result.matrix.resize(result.getNumVariables() * result.getNumVariables(), + IntInfty::infinity()); + + SDBMBuilder builder(pointExprToStripe, linearPosition); + + // Only keep the tightest constraint. Since we transform everything into + // less-than-or-equals-to inequalities, keep the smallest constant. For + // example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter. + // Note that the input expressions are in the shape of d0 - d1 + -42 <= 0 + // so we negate the value before storing it. + // In case where the positive and the negative positions are equal, the + // corresponding expression has the form d0 - d0 + -42 <= 0. If the constant + // value is positive, the set defined by SDBM is trivially empty. We store + // this value anyway and continue processing to maintain the correspondence + // between the matrix form and the list-of-SDBMExpr form. + // TODO(zinenko): we may want to reconsider this once we have canonicalization + // or simplification in place + auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) { + for (auto positivePos : r.positivePos) { + for (auto negativePos : r.negativePos) { + auto &m = sdbm.at(negativePos, positivePos); + m = m < -r.value ? m : -r.value; + } + } + }; + + // Do the second sweep on (in)equalities, updating the SDB matrix to reflect + // the constraints. + for (auto ineq : inequalities) + updateMatrix(result, builder.visit(ineq)); + + // An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0; + // f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}. + for (auto eq : equalities) { + updateMatrix(result, builder.visit(eq)); + updateMatrix(result, builder.visit(-eq)); + } + + // Add the inequalities induced by stripe equalities. + // t = x # C => t <= x <= t + C - 1 + // which is equivalent to + // {t - x <= 0; + // x - t - (C - 1) <= 0}. + for (const auto &pair : result.stripeToPoint) { + auto stripe = pair.second.cast(); + SDBMBuilderResult update = builder.visit(stripe.getVar()); + assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 && + "unexpected negated variable in stripe expression"); + assert(update.value == 0 && + "unexpected non-zero value in stripe expression"); + update.negativePos.clear(); + update.negativePos.push_back(pair.first); + update.value = -(stripe.getStripeFactor().getValue() - 1); + updateMatrix(result, update); + + std::swap(update.negativePos, update.positivePos); + update.value = 0; + updateMatrix(result, update); + } + + return result; +} + +// Given a row and a column position in the square DBM, insert one equality +// or up to two inequalities that correspond the entries (col, row) and (row, +// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that +// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM. +// If one of the expressions is derived from another using a stripe operation, +// check if the inequalities induced by the stripe operation subsume the +// inequalities defined in the DBM and if so, elide these inequalities. +void SDBM::convertDBMElement(unsigned row, unsigned col, + SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities) { + using ops_assertions::operator+; + using ops_assertions::operator-; + + auto diffIJValue = at(col, row); + auto diffJIValue = at(row, col); + + // If symmetric entries are opposite, the corresponding expressions are equal. 
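+  // For example, at(col, row) == 3 and at(row, col) == -3 encode
+  // rowExpr - colExpr <= 3 and colExpr - rowExpr <= -3, which combine into the
+  // equality rowExpr - colExpr == 3.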
+ if (diffIJValue.isFinite() && + diffIJValue.getValue() == -diffJIValue.getValue()) { + equalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); + return; + } + + // Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived + // from x1: x0 = x1 # B. If so, it would imply the constraints + // x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1). + // Therefore, if A >= 0, this inequality is subsumed by that implied + // by the stripe equality and thus can be elided. + // Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C. + // If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=> + // <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this + // inequality can be elided. + // + // Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe + // expressions being stored without temporaries on the RHS and being passed + // into this function as is. + auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr, + SDBMExpr x1Expr, int64_t value) { + if (stripeToPoint.count(x0)) { + auto stripe = stripeToPoint[x0].cast(); + SDBMPositiveExpr var = stripe.getVar(); + if (x1Expr == var && value >= 0) + return true; + } + if (stripeToPoint.count(x1)) { + auto stripe = stripeToPoint[x1].cast(); + SDBMPositiveExpr var = stripe.getVar(); + if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1) + return true; + } + return false; + }; + + // Check row - col. + if (diffIJValue.isFinite() && + !canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) { + inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); + } + // Check col - row. + if (diffJIValue.isFinite() && + !canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) { + inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue()); + } +} + +// The values on the main diagonal correspond to the upper bound on the +// difference between a variable and itself: d0 - d0 <= C, or alternatively +// to -C <= 0. Only construct the inequalities when C is negative, which +// are trivially false but necessary for the returned system of inequalities +// to indicate that the set it defines is empty. +void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, + SmallVectorImpl &inequalities) { + auto selfDifference = at(pos, pos); + if (selfDifference.isFinite() && selfDifference < 0) { + auto selfDifferenceValueExpr = + SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue()); + inequalities.push_back(selfDifferenceValueExpr); + } +} + +void SDBM::getSDBMExpressions(SDBMDialect *dialect, + SmallVectorImpl &inequalities, + SmallVectorImpl &equalities) { + using ops_assertions::operator-; + using ops_assertions::operator+; + + // Helper function that creates an SDBMInputExpr given the linearized position + // of variable in the DBM. + auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr { + if (matrixPos < numDims) + return SDBMDimExpr::get(dialect, matrixPos); + return SDBMSymbolExpr::get(dialect, matrixPos - numDims); + }; + + // The top-left value corresponds to inequality 0 <= C. If C is negative, the + // set defined by SDBM is trivially empty and we add the constraint -C <= 0 to + // the list of inequalities. Otherwise, the constraint is trivially true and + // we ignore it. 
+ auto difference = at(0, 0); + if (difference.isFinite() && difference < 0) { + inequalities.push_back( + SDBMConstantExpr::get(dialect, -difference.getValue())); + } + + // Traverse the segment of the matrix that involves non-temporary variables. + unsigned numTrueVariables = numDims + numSymbols; + for (unsigned i = 0; i < numTrueVariables; ++i) { + // The first row and column represent numerical upper and lower bound on + // each variable. Transform them into inequalities if they are finite. + auto upperBound = at(0, 1 + i); + auto lowerBound = at(1 + i, 0); + auto inputExpr = getInput(i); + if (upperBound.isFinite() && + upperBound.getValue() == -lowerBound.getValue()) { + equalities.push_back(inputExpr - upperBound.getValue()); + } else if (upperBound.isFinite()) { + inequalities.push_back(inputExpr - upperBound.getValue()); + } else if (lowerBound.isFinite()) { + inequalities.push_back(-inputExpr - lowerBound.getValue()); + } + + // Introduce trivially false inequalities if required by diagonal elements. + convertDBMDiagonalElement(1 + i, inputExpr, inequalities); + + // Introduce equalities or inequalities between non-temporary variables. + for (unsigned j = 0; j < i; ++j) { + convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities, + equalities); + } + } + + // Add equalities for stripe expressions that define non-temporary + // variables. Temporary variables will be substituted into their uses and + // should not appear in the resulting equalities. + for (const auto &stripePair : stripeToPoint) { + unsigned position = stripePair.first; + if (position < 1 + numTrueVariables) { + equalities.push_back(getInput(position - 1) - stripePair.second); + } + } + + // Add equalities / inequalities involving temporaries by replacing the + // temporaries with stripe expressions that define them. + for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) { + // Mixed constraints involving one temporary (j) and one non-temporary (i) + // variable. + for (unsigned j = 0; j < numTrueVariables; ++j) { + convertDBMElement(i, 1 + j, stripeToPoint[i].cast(), + getInput(j), inequalities, equalities); + } + + // Constraints involving only temporary variables. + for (unsigned j = 1 + numTrueVariables; j < i; ++j) { + convertDBMElement(i, j, stripeToPoint[i].cast(), + stripeToPoint[j].cast(), inequalities, + equalities); + } + + // Introduce trivially false inequalities if required by diagonal elements. + convertDBMDiagonalElement(i, stripeToPoint[i].cast(), + inequalities); + } +} + +void SDBM::print(llvm::raw_ostream &os) { + unsigned numVariables = getNumVariables(); + + // Helper function that prints the name of the variable given its linearized + // position in the DBM. + auto getVarName = [this](unsigned matrixPos) -> std::string { + if (matrixPos == 0) + return "cst"; + matrixPos -= 1; + if (matrixPos < numDims) + return llvm::formatv("d{0}", matrixPos); + matrixPos -= numDims; + if (matrixPos < numSymbols) + return llvm::formatv("s{0}", matrixPos); + matrixPos -= numSymbols; + return llvm::formatv("t{0}", matrixPos); + }; + + // Header row. + os << " cst"; + for (unsigned i = 1; i < numVariables; ++i) { + os << llvm::formatv(" {0,4}", getVarName(i)); + } + os << '\n'; + + // Data rows. 
+ for (unsigned i = 0; i < numVariables; ++i) { + os << llvm::formatv("{0,-4}", getVarName(i)); + for (unsigned j = 0; j < numVariables; ++j) { + IntInfty value = operator()(i, j); + if (!value.isFinite()) + os << " inf"; + else + os << llvm::formatv(" {0,4}", value.getValue()); + } + os << '\n'; + } + + // Explanation of temporaries. + for (const auto &pair : stripeToPoint) { + os << getVarName(pair.first) << " = "; + pair.second.print(os); + os << '\n'; + } +} + +void SDBM::dump() { print(llvm::errs()); } diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp new file mode 100644 index 00000000000..d3d895fec88 --- /dev/null +++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp @@ -0,0 +1,20 @@ +//===- SDBMDialect.cpp - Dialect for striped difference-bound matrices ----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/SDBM/SDBMDialect.h" + +static mlir::DialectRegistration SDBMDialect; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp new file mode 100644 index 00000000000..a174c8c84f2 --- /dev/null +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -0,0 +1,647 @@ +//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) expression is a constant expression, +// an identifier, a binary expression with constant RHS and +, stripe operators +// or a difference expression between two identifiers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SDBM/SDBMExpr.h" +#include "SDBMExprDetail.h" +#include "mlir/Dialect/SDBM/SDBMDialect.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" + +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { +/// A simple compositional matcher for AffineExpr +/// +/// Example usage: +/// +/// ```c++ +/// AffineExprMatcher x, C, m; +/// AffineExprMatcher pattern1 = ((x % C) * m) + x; +/// AffineExprMatcher pattern2 = x + ((x % C) * m); +/// if (pattern1.match(expr) || pattern2.match(expr)) { +/// ... 
+/// } +/// ``` +class AffineExprMatcherStorage; +class AffineExprMatcher { +public: + AffineExprMatcher(); + AffineExprMatcher(const AffineExprMatcher &other); + + AffineExprMatcher operator+(AffineExprMatcher other) { + return AffineExprMatcher(AffineExprKind::Add, *this, other); + } + AffineExprMatcher operator*(AffineExprMatcher other) { + return AffineExprMatcher(AffineExprKind::Mul, *this, other); + } + AffineExprMatcher floorDiv(AffineExprMatcher other) { + return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other); + } + AffineExprMatcher ceilDiv(AffineExprMatcher other) { + return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other); + } + AffineExprMatcher operator%(AffineExprMatcher other) { + return AffineExprMatcher(AffineExprKind::Mod, *this, other); + } + + AffineExpr match(AffineExpr expr); + AffineExpr matched(); + Optional getMatchedConstantValue(); + +private: + AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b); + AffineExprKind kind; // only used to match in binary op cases. + // A shared_ptr allows multiple references to same matcher storage without + // worrying about ownership or dealing with an arena. To be cleaned up if we + // go with this. + std::shared_ptr storage; +}; + +class AffineExprMatcherStorage { +public: + AffineExprMatcherStorage() {} + AffineExprMatcherStorage(const AffineExprMatcherStorage &other) + : subExprs(other.subExprs.begin(), other.subExprs.end()), + matched(other.matched) {} + AffineExprMatcherStorage(ArrayRef exprs) + : subExprs(exprs.begin(), exprs.end()) {} + AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b) + : subExprs({a, b}) {} + llvm::SmallVector subExprs; + AffineExpr matched; +}; +} // namespace + +AffineExprMatcher::AffineExprMatcher() + : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {} + +AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) + : kind(other.kind), storage(other.storage) {} + +Optional AffineExprMatcher::getMatchedConstantValue() { + if (auto cst = storage->matched.dyn_cast()) + return cst.getValue(); + return None; +} + +AffineExpr AffineExprMatcher::match(AffineExpr expr) { + if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) { + if (storage->matched) + if (storage->matched != expr) + return AffineExpr(); + storage->matched = expr; + return storage->matched; + } + if (kind != expr.getKind()) { + return AffineExpr(); + } + if (auto bin = expr.dyn_cast()) { + if (!storage->subExprs.empty() && + !storage->subExprs[0].match(bin.getLHS())) { + return AffineExpr(); + } + if (!storage->subExprs.empty() && + !storage->subExprs[1].match(bin.getRHS())) { + return AffineExpr(); + } + if (storage->matched) + if (storage->matched != expr) + return AffineExpr(); + storage->matched = expr; + return storage->matched; + } + llvm_unreachable("binary expected"); +} + +AffineExpr AffineExprMatcher::matched() { return storage->matched; } + +AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, + AffineExprMatcher b) + : kind(k), storage(new AffineExprMatcherStorage(a, b)) { + storage->subExprs.push_back(a); + storage->subExprs.push_back(b); +} + +//===----------------------------------------------------------------------===// +// SDBMExpr +//===----------------------------------------------------------------------===// + +SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); } + +MLIRContext *SDBMExpr::getContext() const { + return impl->dialect->getContext(); +} + +SDBMDialect *SDBMExpr::getDialect() 
const { return impl->dialect; } + +void SDBMExpr::print(raw_ostream &os) const { + struct Printer : public SDBMVisitor { + Printer(raw_ostream &ostream) : prn(ostream) {} + + void visitSum(SDBMSumExpr expr) { + visitVarying(expr.getLHS()); + prn << " + "; + visitConstant(expr.getRHS()); + } + void visitDiff(SDBMDiffExpr expr) { + visitPositive(expr.getLHS()); + prn << " - "; + visitPositive(expr.getRHS()); + } + void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } + void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } + void visitStripe(SDBMStripeExpr expr) { + visitPositive(expr.getVar()); + prn << " # "; + visitConstant(expr.getStripeFactor()); + } + void visitNeg(SDBMNegExpr expr) { + prn << '-'; + visitPositive(expr.getVar()); + } + void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } + + raw_ostream &prn; + }; + Printer printer(os); + printer.visit(*this); +} + +void SDBMExpr::dump() const { + print(llvm::errs()); + llvm::errs() << '\n'; +} + +namespace { +// Helper class to perform negation of an SDBM expression. +struct SDBMNegator : public SDBMVisitor { + // Any positive expression is wrapped into a negation expression. + // -(x) = -x + SDBMExpr visitPositive(SDBMPositiveExpr expr) { + return SDBMNegExpr::get(expr); + } + // A negation expression is unwrapped. + // -(-x) = x + SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } + // The value of the constant is negated. + SDBMExpr visitConstant(SDBMConstantExpr expr) { + return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); + } + // Both terms of the sum are negated recursively. + SDBMExpr visitSum(SDBMSumExpr expr) { + return SDBMSumExpr::get(visit(expr.getLHS()).cast(), + visit(expr.getRHS()).cast()); + } + // Terms of a difference are interchanged. + // -(x - y) = y - x + SDBMExpr visitDiff(SDBMDiffExpr expr) { + return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS()); + } +}; +} // namespace + +SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); } + +//===----------------------------------------------------------------------===// +// SDBMSumExpr +//===----------------------------------------------------------------------===// + +SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { + assert(lhs && "expected SDBM variable expression"); + assert(rhs && "expected SDBM constant"); + + // If LHS of a sum is another sum, fold the constant RHS parts. 
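+  // For example, (d0 + 2) + 3 becomes d0 + 5.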
+ if (auto lhsSum = lhs.dyn_cast()) { + lhs = lhsSum.getLHS(); + rhs = SDBMConstantExpr::get(rhs.getDialect(), + rhs.getValue() + lhsSum.getRHS().getValue()); + } + + StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); +} + +SDBMVaryingExpr SDBMSumExpr::getLHS() const { + return static_cast(impl)->lhs; +} + +SDBMConstantExpr SDBMSumExpr::getRHS() const { + return static_cast(impl)->rhs; +} + +AffineExpr SDBMExpr::getAsAffineExpr() const { + struct Converter : public SDBMVisitor { + AffineExpr visitSum(SDBMSumExpr expr) { + AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); + return lhs + rhs; + } + + AffineExpr visitStripe(SDBMStripeExpr expr) { + AffineExpr lhs = visit(expr.getVar()), + rhs = visit(expr.getStripeFactor()); + return lhs - (lhs % rhs); + } + + AffineExpr visitDiff(SDBMDiffExpr expr) { + AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); + return lhs - rhs; + } + + AffineExpr visitDim(SDBMDimExpr expr) { + return getAffineDimExpr(expr.getPosition(), expr.getContext()); + } + + AffineExpr visitSymbol(SDBMSymbolExpr expr) { + return getAffineSymbolExpr(expr.getPosition(), expr.getContext()); + } + + AffineExpr visitNeg(SDBMNegExpr expr) { + return getAffineBinaryOpExpr(AffineExprKind::Mul, + getAffineConstantExpr(-1, expr.getContext()), + visit(expr.getVar())); + } + + AffineExpr visitConstant(SDBMConstantExpr expr) { + return getAffineConstantExpr(expr.getValue(), expr.getContext()); + } + } converter; + return converter.visit(*this); +} + +Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { + struct Converter : public AffineExprVisitor { + SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { + // Attempt to recover a stripe expression. Because AffineExprs don't have + // a first-class difference kind, we check for both x + -1 * (x mod C) and + // -1 * (x mod C) + x cases. + AffineExprMatcher x, C, m; + AffineExprMatcher pattern1 = ((x % C) * m) + x; + AffineExprMatcher pattern2 = x + ((x % C) * m); + if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) || + (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) { + if (auto convertedLHS = visit(x.matched())) { + // TODO(ntv): return convertedLHS.stripe(C); + return SDBMStripeExpr::get( + convertedLHS.cast(), + visit(C.matched()).cast()); + } + } + auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); + if (!lhs || !rhs) + return {}; + + // In a "add" AffineExpr, the constant always appears on the right. If + // there were two constants, they would have been folded away. + assert(!lhs.isa() && "non-canonical affine expression"); + auto rhsConstant = rhs.dyn_cast(); + + // SDBM accepts LHS variables and RHS constants in a sum. + auto lhsVar = lhs.dyn_cast(); + auto rhsVar = rhs.dyn_cast(); + if (rhsConstant && lhsVar) + return SDBMSumExpr::get(lhsVar, rhsConstant); + + // The sum of a negated variable and a non-negated variable is a + // difference, supported as a special kind in SDBM. Because AffineExprs + // don't have first-class difference kind, check both LHS and RHS for + // negation. + auto lhsPos = lhs.dyn_cast(); + auto rhsPos = rhs.dyn_cast(); + auto lhsNeg = lhs.dyn_cast(); + auto rhsNeg = rhs.dyn_cast(); + if (lhsNeg && rhsVar) + return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar()); + if (rhsNeg && lhsVar) + return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar()); + + // Other cases don't fit into SDBM. 
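+      // (For example, the sum of two non-constant variables such as d0 + d1.)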
+ return {}; + } + + SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { + // Attempt to recover a stripe expression "x # C = (x floordiv C) * C". + AffineExprMatcher x, C; + AffineExprMatcher pattern = (x.floorDiv(C)) * C; + if (pattern.match(expr)) { + if (SDBMExpr converted = visit(x.matched())) { + if (auto varConverted = converted.dyn_cast()) + // TODO(ntv): return varConverted.stripe(C.getConstantValue()); + return SDBMStripeExpr::get( + varConverted, + SDBMConstantExpr::get(dialect, + C.getMatchedConstantValue().getValue())); + } + } + + auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); + if (!lhs || !rhs) + return {}; + + // In a "mul" AffineExpr, the constant always appears on the right. If + // there were two constants, they would have been folded away. + assert(!lhs.isa() && "non-canonical affine expression"); + auto rhsConstant = rhs.dyn_cast(); + if (!rhsConstant) + return {}; + + // The only supported "multiplication" expression is an SDBM is dimension + // negation, that is a product of dimension and constant -1. + auto lhsVar = lhs.dyn_cast(); + if (lhsVar && rhsConstant.getValue() == -1) + return SDBMNegExpr::get(lhsVar); + + // Other multiplications are not allowed in SDBM. + return {}; + } + + SDBMExpr visitModExpr(AffineBinaryOpExpr expr) { + auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); + if (!lhs || !rhs) + return {}; + + // 'mod' can only be converted to SDBM if its LHS is a variable + // and its RHS is a constant. Then it `x mod c = x - x stripe c`. + auto rhsConstant = rhs.dyn_cast(); + auto lhsVar = rhs.dyn_cast(); + if (!lhsVar || !rhsConstant) + return {}; + return SDBMDiffExpr::get(lhsVar, + SDBMStripeExpr::get(lhsVar, rhsConstant)); + } + + // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM + SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; } + SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; } + + // Dimensions, symbols and constants are converted trivially. 
+ SDBMExpr visitConstantExpr(AffineConstantExpr expr) { + return SDBMConstantExpr::get(dialect, expr.getValue()); + } + SDBMExpr visitDimExpr(AffineDimExpr expr) { + return SDBMDimExpr::get(dialect, expr.getPosition()); + } + SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) { + return SDBMSymbolExpr::get(dialect, expr.getPosition()); + } + + SDBMDialect *dialect; + } converter; + converter.dialect = affine.getContext()->getRegisteredDialect(); + + if (auto result = converter.visit(affine)) + return result; + return None; +} + +//===----------------------------------------------------------------------===// +// SDBMDiffExpr +//===----------------------------------------------------------------------===// + +SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { + assert(lhs && "expected SDBM dimension"); + assert(rhs && "expected SDBM dimension"); + + StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); +} + +SDBMPositiveExpr SDBMDiffExpr::getLHS() const { + return static_cast(impl)->lhs; +} + +SDBMPositiveExpr SDBMDiffExpr::getRHS() const { + return static_cast(impl)->rhs; +} + +//===----------------------------------------------------------------------===// +// SDBMStripeExpr +//===----------------------------------------------------------------------===// + +SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, + SDBMConstantExpr stripeFactor) { + assert(var && "expected SDBM variable expression"); + assert(stripeFactor && "expected non-null stripe factor"); + if (stripeFactor.getValue() <= 0) + llvm::report_fatal_error("non-positive stripe factor"); + + StorageUniquer &uniquer = var.getDialect()->getUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, + stripeFactor); +} + +SDBMPositiveExpr SDBMStripeExpr::getVar() const { + if (SDBMVaryingExpr lhs = static_cast(impl)->lhs) + return lhs.cast(); + return {}; +} + +SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const { + return static_cast(impl)->rhs; +} + +//===----------------------------------------------------------------------===// +// SDBMInputExpr +//===----------------------------------------------------------------------===// + +unsigned SDBMInputExpr::getPosition() const { + return static_cast(impl)->position; +} + +//===----------------------------------------------------------------------===// +// SDBMDimExpr +//===----------------------------------------------------------------------===// + +SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { + assert(dialect && "expected non-null dialect"); + + auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { + storage->dialect = dialect; + }; + + StorageUniquer &uniquer = dialect->getUniquer(); + return uniquer.get( + assignDialect, static_cast(SDBMExprKind::DimId), position); +} + +//===----------------------------------------------------------------------===// +// SDBMSymbolExpr +//===----------------------------------------------------------------------===// + +SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) { + assert(dialect && "expected non-null dialect"); + + auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { + storage->dialect = dialect; + }; + + StorageUniquer &uniquer = dialect->getUniquer(); + return uniquer.get( + assignDialect, static_cast(SDBMExprKind::SymbolId), position); +} + 
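A minimal sketch of how the factories above compose, assuming `dialect` is a
pointer to an already-registered SDBMDialect instance; it builds the uniqued
expressions "d0 # 4" and "d0 - s0":

  SDBMDimExpr d0 = SDBMDimExpr::get(dialect, 0);
  SDBMSymbolExpr s0 = SDBMSymbolExpr::get(dialect, 0);
  // "d0 # 4": stripe d0 by the constant factor 4.
  SDBMStripeExpr striped =
      SDBMStripeExpr::get(d0, SDBMConstantExpr::get(dialect, 4));
  // "d0 - s0": difference of two positive expressions.
  SDBMDiffExpr diff = SDBMDiffExpr::get(d0, s0);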
+//===----------------------------------------------------------------------===// +// SDBMConstantExpr +//===----------------------------------------------------------------------===// + +SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) { + assert(dialect && "expected non-null dialect"); + + auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) { + storage->dialect = dialect; + }; + + StorageUniquer &uniquer = dialect->getUniquer(); + return uniquer.get( + assignCtx, static_cast(SDBMExprKind::Constant), value); +} + +int64_t SDBMConstantExpr::getValue() const { + return static_cast(impl)->constant; +} + +//===----------------------------------------------------------------------===// +// SDBMNegExpr +//===----------------------------------------------------------------------===// + +SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { + assert(var && "expected non-null SDBM variable expression"); + + StorageUniquer &uniquer = var.getDialect()->getUniquer(); + return uniquer.get( + /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); +} + +SDBMPositiveExpr SDBMNegExpr::getVar() const { + return static_cast(impl)->dim; +} + +namespace mlir { +namespace ops_assertions { + +SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { + // If one of the operands is a negation, take a difference rather than a sum. + auto lhsNeg = lhs.dyn_cast(); + auto rhsNeg = rhs.dyn_cast(); + assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of " + "a sum of variables and not a correct SDBM"); + if (lhsNeg) + return rhs - lhsNeg.getVar(); + if (rhsNeg) + return lhs - rhsNeg.getVar(); + + // If LHS is a constant and RHS is not, swap the order to get into a supported + // sum case. From now on, RHS must be a constant. + auto lhsConstant = lhs.dyn_cast(); + auto rhsConstant = rhs.dyn_cast(); + if (!rhsConstant && lhsConstant) { + std::swap(lhs, rhs); + std::swap(lhsConstant, rhsConstant); + } + assert(rhsConstant && "at least one operand must be a constant"); + + // If LHS is another sum, first compute the sum of its variable + // part with the other argument and then add the constant part to enable + // constant folding (the variable part may, e.g., be a negation that requires + // to enter this function again). + auto lhsSum = lhs.dyn_cast(); + if (lhsSum) + return lhsSum.getLHS() + + (lhsSum.getRHS().getValue() + rhsConstant.getValue()); + + // Constant-fold if LHS is a constant. + if (lhsConstant) + return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + + rhsConstant.getValue()); + + // Fold x + 0 == x. + if (rhsConstant.getValue() == 0) + return lhs; + + return SDBMSumExpr::get(lhs.cast(), + rhs.cast()); +} + +SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { + // Fold x - x == 0. + if (lhs == rhs) + return SDBMConstantExpr::get(lhs.getDialect(), 0); + + // LHS and RHS may be constants. + auto lhsConstant = lhs.dyn_cast(); + auto rhsConstant = rhs.dyn_cast(); + + // Constant fold if both LHS and RHS are constants. + if (lhsConstant && rhsConstant) + return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() - + rhsConstant.getValue()); + + // Replace a difference with a sum with a negated value if one of LHS and RHS + // is a constant: + // x - C == x + (-C); + // C - x == -x + C. + // This calls into operator+ for further simplification. 
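+  // For instance, d0 - 5 becomes d0 + (-5), and 5 - d0 becomes -d0 + 5.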
+ if (rhsConstant) + return lhs + (-rhsConstant); + if (lhsConstant) + return -rhs + lhsConstant; + + // Hoist constant factors outside the difference if any of sides is a sum: + // (x + A) - (y - B) == x - y + (A - B). + // If either LHS or RHS is a sum, collect the constant values separately and + // update LHS and RHS to point to the variable part of the sum. + auto lhsSum = lhs.dyn_cast(); + auto rhsSum = rhs.dyn_cast(); + int64_t value = 0; + if (lhsSum) { + value += lhsSum.getRHS().getValue(); + lhs = lhsSum.getLHS(); + } + if (rhsSum) { + value -= rhsSum.getRHS().getValue(); + rhs = rhsSum.getLHS(); + } + + // This calls into operator+ for futher simplification in case value == 0. + return SDBMDiffExpr::get(lhs.cast(), + rhs.cast()) + + value; +} + +SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { + auto constantFactor = factor.cast(); + assert(constantFactor.getValue() > 0 && "non-positive stripe"); + + // Fold x # 1 = x. + if (constantFactor.getValue() == 1) + return expr; + + return SDBMStripeExpr::get(expr.cast(), constantFactor); +} + +} // namespace ops_assertions +} // namespace mlir diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h new file mode 100644 index 00000000000..1721b02dae7 --- /dev/null +++ b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h @@ -0,0 +1,138 @@ +//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of SDBMExpr, in particular underlying +// storage types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SDBMEXPRDETAIL_H +#define MLIR_IR_SDBMEXPRDETAIL_H + +#include "mlir/Dialect/SDBM/SDBMExpr.h" +#include "mlir/Support/StorageUniquer.h" + +namespace mlir { + +class SDBMDialect; + +namespace detail { + +// Base storage class for SDBMExpr. +struct SDBMExprStorage : public StorageUniquer::BaseStorage { + SDBMExprKind getKind() { + return static_cast(BaseStorage::getKind()); + } + + SDBMDialect *dialect; +}; + +// Storage class for SDBM sum and stripe expressions. +struct SDBMBinaryExprStorage : public SDBMExprStorage { + using KeyTy = std::pair; + + bool operator==(const KeyTy &key) const { + return std::get<0>(key) == lhs && std::get<1>(key) == rhs; + } + + static SDBMBinaryExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->lhs = std::get<0>(key); + result->rhs = std::get<1>(key); + result->dialect = result->lhs.getDialect(); + return result; + } + + SDBMVaryingExpr lhs; + SDBMConstantExpr rhs; +}; + +// Storage class for SDBM difference expressions. 
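+// The uniquing key is the ordered (lhs, rhs) pair of positive expressions.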
+struct SDBMDiffExprStorage : public SDBMExprStorage { + using KeyTy = std::pair; + + bool operator==(const KeyTy &key) const { + return std::get<0>(key) == lhs && std::get<1>(key) == rhs; + } + + static SDBMDiffExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->lhs = std::get<0>(key); + result->rhs = std::get<1>(key); + result->dialect = result->lhs.getDialect(); + return result; + } + + SDBMPositiveExpr lhs; + SDBMPositiveExpr rhs; +}; + +// Storage class for SDBM constant expressions. +struct SDBMConstantExprStorage : public SDBMExprStorage { + using KeyTy = int64_t; + + bool operator==(const KeyTy &key) const { return constant == key; } + + static SDBMConstantExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->constant = key; + return result; + } + + int64_t constant; +}; + +// Storage class for SDBM dimension and symbol expressions. +struct SDBMPositiveExprStorage : public SDBMExprStorage { + using KeyTy = unsigned; + + bool operator==(const KeyTy &key) const { return position == key; } + + static SDBMPositiveExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->position = key; + return result; + } + + unsigned position; +}; + +// Storage class for SDBM negation expressions. +struct SDBMNegExprStorage : public SDBMExprStorage { + using KeyTy = SDBMPositiveExpr; + + bool operator==(const KeyTy &key) const { return key == dim; } + + static SDBMNegExprStorage * + construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto *result = allocator.allocate(); + result->dim = key; + result->dialect = key.getDialect(); + return result; + } + + SDBMPositiveExpr dim; +}; + +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_IR_SDBMEXPRDETAIL_H diff --git a/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp index 38e8d93752e..cda56e27b1a 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp @@ -22,10 +22,10 @@ #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Translation.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt new file mode 100644 index 00000000000..f10c173af8a --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -0,0 +1,9 @@ +file(GLOB globbed *.c *.cpp) +add_llvm_library(MLIRStandardOps + ${globbed} + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps + ) +add_dependencies(MLIRStandardOps MLIRStandardOpsIncGen LLVMSupport) +target_link_libraries(MLIRStandardOps LLVMSupport) diff --git a/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp b/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp new file mode 100644 index 00000000000..6b5578f93cf --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register standard Op dialect -------------===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/StandardOps/Ops.h" +using namespace mlir; + +// Static initialization for standard op dialect registration. +static DialectRegistration StandardOps; diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp new file mode 100644 index 00000000000..4e484e6b50b --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -0,0 +1,2102 @@ +//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/StandardOps/Ops.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// StandardOpsDialect +//===----------------------------------------------------------------------===// + +/// A custom binary operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { + assert(op->getNumOperands() == 2 && "binary op should have two operands"); + assert(op->getNumResults() == 1 && "binary op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing. + auto resultType = op->getResult(0)->getType(); + if (op->getOperand(0)->getType() != resultType || + op->getOperand(1)->getType() != resultType) { + p->printGenericOp(op); + return; + } + + *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' + << *op->getOperand(0) << ", " << *op->getOperand(1); + p->printOptionalAttrDict(op->getAttrs()); + + // Now we can output only one type for all operands and the result. 
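+  // For example, an integer addition prints as "addi %lhs, %rhs : i32" rather
+  // than the generic form, which would repeat the type of every operand.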
+ *p << " : " << op->getResult(0)->getType(); +} + +/// A custom cast operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardCastOp(Operation *op, OpAsmPrinter *p) { + *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' + << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " + << op->getResult(0)->getType(); +} + +/// A custom cast operation verifier. +template static LogicalResult verifyCastOp(T op) { + auto opType = op.getOperand()->getType(); + auto resType = op.getType(); + if (!T::areCastCompatible(opType, resType)) + return op.emitError("operand type ") << opType << " and result type " + << resType << " are cast incompatible"; + + return success(); +} + +StandardOpsDialect::StandardOpsDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations(); +} + +void mlir::printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter *p) { + *p << '('; + p->printOperands(begin, begin + numDims); + *p << ')'; + + if (begin + numDims != end) { + *p << '['; + p->printOperands(begin + numDims, end); + *p << ']'; + } +} + +// Parses dimension and symbol list, and sets 'numDims' to the number of +// dimension operands parsed. +// Returns 'false' on success and 'true' on error. +ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims) { + SmallVector opInfos; + if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + return failure(); + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto affineIntTy = parser->getBuilder().getIndexType(); + if (parser->parseOperandList(opInfos, + OpAsmParser::Delimiter::OptionalSquare) || + parser->resolveOperands(opInfos, affineIntTy, operands)) + return failure(); + return success(); +} + +/// Matches a ConstantIndexOp. +/// TODO: This should probably just be a general matcher that uses m_Constant +/// and checks the operation for an index type. +static detail::op_matcher m_ConstantIndex() { + return detail::op_matcher(); +} + +//===----------------------------------------------------------------------===// +// Common canonicalization pattern support logic +//===----------------------------------------------------------------------===// + +namespace { +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref_cast +/// into the root operation directly. +struct MemRefCastFolder : public RewritePattern { + /// The rootOpName is the name of the root operation to match against. + MemRefCastFolder(StringRef rootOpName, MLIRContext *context) + : RewritePattern(rootOpName, 1, context) {} + + PatternMatchResult match(Operation *op) const override { + for (auto *operand : op->getOperands()) + if (matchPattern(operand, m_Op())) + return matchSuccess(); + + return matchFailure(); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + if (auto *memref = op->getOperand(i)->getDefiningOp()) + if (auto cast = dyn_cast(memref)) + op->setOperand(i, cast.getOperand()); + rewriter.updatedRootInPlace(op); + } +}; + +/// Performs const folding `calculate` with element-wise behavior on the two +/// attributes in `operands` and returns the result if possible. 
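+/// For example, the fold hook of an integer addition below invokes this
+/// roughly as
+///   constFoldBinaryOp<IntegerAttr>(operands,
+///                                  [](APInt a, APInt b) { return a + b; });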
+template > +Attribute constFoldBinaryOp(ArrayRef operands, + const CalculationT &calculate) { + assert(operands.size() == 2 && "binary op takes two operands"); + + if (auto lhs = operands[0].dyn_cast_or_null()) { + auto rhs = operands[1].dyn_cast_or_null(); + if (!rhs || lhs.getType() != rhs.getType()) + return {}; + + return AttrElementT::get(lhs.getType(), + calculate(lhs.getValue(), rhs.getValue())); + } else if (auto lhs = operands[0].dyn_cast_or_null()) { + auto rhs = operands[1].dyn_cast_or_null(); + if (!rhs || lhs.getType() != rhs.getType()) + return {}; + + auto elementResult = constFoldBinaryOp( + {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); + if (!elementResult) + return {}; + + return DenseElementsAttr::get(lhs.getType(), elementResult); + } + return {}; +} +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// AddFOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddFOp::fold(ArrayRef operands) { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// AddIOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddIOp::fold(ArrayRef operands) { + /// addi(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, AllocOp op) { + *p << "alloc"; + + // Print dynamic dimension operands. + MemRefType type = op.getType(); + printDimAndSymbolList(op.operand_begin(), op.operand_end(), + type.getNumDynamicDims(), p); + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); + *p << " : " << type; +} + +static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) { + MemRefType type; + + // Parse the dimension operands and optional symbol operands, followed by a + // memref type. + unsigned numDimOperands; + if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return failure(); + + // Check numDynamicDims against number of question marks in memref type. + // Note: this check remains here (instead of in verify()), because the + // partition between dim operands and symbol operands is lost after parsing. + // Verification still checks that the total number of operands matches + // the number of symbols in the affine map, plus the number of dynamic + // dimensions in the memref. + if (numDimOperands != type.getNumDynamicDims()) + return parser->emitError(parser->getNameLoc()) + << "dimension operand count does not equal memref dynamic dimension " + "count"; + result->types.push_back(type); + return success(); +} + +static LogicalResult verify(AllocOp op) { + auto memRefType = op.getResult()->getType().dyn_cast(); + if (!memRefType) + return op.emitOpError("result must be a memref"); + + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) { + // Store number of symbols used in affine map (used in subsequent check). 
+ AffineMap affineMap = memRefType.getAffineMaps()[0]; + numSymbols = affineMap.getNumSymbols(); + } + + // Check that the total number of operands matches the number of symbols in + // the affine map, plus the number of dynamic dimensions specified in the + // memref type. + unsigned numDynamicDims = memRefType.getNumDynamicDims(); + if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols) + return op.emitOpError( + "operand count does not equal dimension plus symbol operand count"); + + // Verify that all operands are of type Index. + for (auto operandType : op.getOperandTypes()) + if (!operandType.isIndex()) + return op.emitOpError("requires operands to be of type Index"); + return success(); +} + +namespace { +/// Fold constant dimensions into an alloc operation. +struct SimplifyAllocConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { + // Check to see if any dimensions operands are constants. If so, we can + // substitute and drop them. + if (llvm::none_of(alloc.getOperands(), [](Value *operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return matchFailure(); + + auto memrefType = alloc.getType(); + + // Ok, we have one or more constant operands. Collect the non-constant ones + // and keep track of the resultant memref type to build. + SmallVector newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + SmallVector newOperands; + SmallVector droppedOperands; + + unsigned dynamicDimPos = 0; + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (dimSize != -1) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp.getValue()); + // Record to check for zero uses later below. + droppedOperands.push_back(constantIndexOp); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(-1); + newOperands.push_back(alloc.getOperand(dynamicDimPos)); + } + dynamicDimPos++; + } + + // Create new memref type (which will have fewer dynamic dimensions). + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(static_cast(newOperands.size()) == + newMemRefType.getNumDynamicDims()); + + // Create and insert the alloc op for the new memref. + auto newAlloc = + rewriter.create(alloc.getLoc(), newMemRefType, newOperands); + // Insert a cast so we have the same type as the old alloc. + auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, + alloc.getType()); + + rewriter.replaceOp(alloc, {resultCast}, droppedOperands); + return matchSuccess(); + } +}; + +/// Fold alloc operations with no uses. Alloc has side effects on the heap, +/// but can still be deleted if it has zero uses. +struct SimplifyDeadAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { + // Check if the alloc'ed value has any uses. + if (!alloc.use_empty()) + return matchFailure(); + + // If it doesn't, we can eliminate it. 
+ alloc.erase(); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BranchOp +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) { + Block *dest; + SmallVector destOperands; + if (parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result->addSuccessor(dest, destOperands); + return success(); +} + +static void print(OpAsmPrinter *p, BranchOp op) { + *p << "br "; + p->printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getOperation()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector operands; + auto calleeLoc = parser->getNameLoc(); + if (parser->parseAttribute(calleeAttr, "callee", result->attributes) || + parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(calleeType) || + parser->addTypesToList(calleeType.getResults(), result->types) || + parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, + result->operands)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter *p, CallOp op) { + *p << "call " << op.getAttr("callee") << '('; + p->printOperands(op.getOperands()); + *p << ')'; + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + *p << " : "; + p->printType(op.getCalleeType()); +} + +static LogicalResult verify(CallOp op) { + // Check that the callee attribute was specified. + auto fnAttr = op.getAttrOfType("callee"); + if (!fnAttr) + return op.emitOpError("requires a 'callee' symbol reference attribute"); + auto fn = + op.getParentOfType().lookupSymbol(fnAttr.getValue()); + if (!fn) + return op.emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. 
+ auto fnType = fn.getType(); + if (fnType.getNumInputs() != op.getNumOperands()) + return op.emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (op.getOperand(i)->getType() != fnType.getInput(i)) + return op.emitOpError("operand type mismatch"); + + if (fnType.getNumResults() != op.getNumResults()) + return op.emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (op.getResult(i)->getType() != fnType.getResult(i)) + return op.emitOpError("result type mismatch"); + + return success(); +} + +FunctionType CallOp::getCalleeType() { + SmallVector resultTypes(getResultTypes()); + SmallVector argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// CallIndirectOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold indirect calls that have a constant function as the callee operand. +struct SimplifyIndirectCallWithKnownCallee + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, + PatternRewriter &rewriter) const override { + // Check that the callee is a constant callee. + SymbolRefAttr calledFn; + if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) + return matchFailure(); + + // Replace with a direct call. + SmallVector callResults(indirectCall.getResultTypes()); + SmallVector callOperands(indirectCall.getArgOperands()); + rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), + callResults, callOperands); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static ParseResult parseCallIndirectOp(OpAsmParser *parser, + OperationState *result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector operands; + return failure( + parser->parseOperand(callee) || + parser->getCurrentLocation(&operandsLoc) || + parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(calleeType) || + parser->resolveOperand(callee, calleeType, result->operands) || + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result->operands) || + parser->addTypesToList(calleeType.getResults(), result->types)); +} + +static void print(OpAsmPrinter *p, CallIndirectOp op) { + *p << "call_indirect "; + p->printOperand(op.getCallee()); + *p << '('; + p->printOperands(op.getArgOperands()); + *p << ')'; + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + *p << " : " << op.getCallee()->getType(); +} + +static LogicalResult verify(CallIndirectOp op) { + // The callee must be a function. + auto fnType = op.getCallee()->getType().dyn_cast(); + if (!fnType) + return op.emitOpError("callee must have function type"); + + // Verify that the operand and result types match the callee. 
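+  // Operand 0 is the callee value itself, so operand i + 1 of the call
+  // corresponds to input i of the callee's function type.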
+ if (fnType.getNumInputs() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) + return op.emitOpError("operand type mismatch"); + + if (fnType.getNumResults() != op.getNumResults()) + return op.emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (op.getResult(i)->getType() != fnType.getResult(i)) + return op.emitOpError("result type mismatch"); + + return success(); +} + +void CallIndirectOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// General helpers for comparison ops +//===----------------------------------------------------------------------===// + +// Return the type of the same shape (scalar, vector or tensor) containing i1. +static Type getCheckedI1SameShape(Builder *build, Type type) { + auto i1Type = build->getI1Type(); + if (type.isIntOrIndexOrFloat()) + return i1Type; + if (auto tensorType = type.dyn_cast()) + return build->getTensorType(tensorType.getShape(), i1Type); + if (type.isa()) + return build->getTensorType(i1Type); + if (auto vectorType = type.dyn_cast()) + return build->getVectorType(vectorType.getShape(), i1Type); + return Type(); +} + +static Type getI1SameShape(Builder *build, Type type) { + Type res = getCheckedI1SameShape(build, type); + assert(res && "expected type with valid i1 shape"); + return res; +} + +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +// Returns an array of mnemonics for CmpIPredicates indexed by values thereof. +static inline const char *const *getCmpIPredicateNames() { + static const char *predicateNames[]{ + /*EQ*/ "eq", + /*NE*/ "ne", + /*SLT*/ "slt", + /*SLE*/ "sle", + /*SGT*/ "sgt", + /*SGE*/ "sge", + /*ULT*/ "ult", + /*ULE*/ "ule", + /*UGT*/ "ugt", + /*UGE*/ "uge", + }; + static_assert(std::extent::value == + (size_t)CmpIPredicate::NumPredicates, + "wrong number of predicate names"); + return predicateNames; +} + +// Returns a value of the predicate corresponding to the given mnemonic. +// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
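+// For example, "sle" maps to CmpIPredicate::SLE, and an unrecognized mnemonic
+// maps to CmpIPredicate::NumPredicates.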
+CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { + return llvm::StringSwitch(name) + .Case("eq", CmpIPredicate::EQ) + .Case("ne", CmpIPredicate::NE) + .Case("slt", CmpIPredicate::SLT) + .Case("sle", CmpIPredicate::SLE) + .Case("sgt", CmpIPredicate::SGT) + .Case("sge", CmpIPredicate::SGE) + .Case("ult", CmpIPredicate::ULT) + .Case("ule", CmpIPredicate::ULE) + .Case("ugt", CmpIPredicate::UGT) + .Case("uge", CmpIPredicate::UGE) + .Default(CmpIPredicate::NumPredicates); +} + +static void buildCmpIOp(Builder *build, OperationState *result, + CmpIPredicate predicate, Value *lhs, Value *rhs) { + result->addOperands({lhs, rhs}); + result->types.push_back(getI1SameShape(build, lhs->getType())); + result->addAttribute( + CmpIOp::getPredicateAttrName(), + build->getI64IntegerAttr(static_cast(predicate))); +} + +static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Attribute predicateNameAttr; + Type type; + if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), + attrs) || + parser->parseComma() || parser->parseOperandList(ops, 2) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(type) || + parser->resolveOperands(ops, type, result->operands)) + return failure(); + + if (!predicateNameAttr.isa()) + return parser->emitError(parser->getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast().getValue(); + auto predicate = CmpIOp::getPredicateByName(predicateName); + if (predicate == CmpIPredicate::NumPredicates) + return parser->emitError(parser->getNameLoc()) + << "unknown comparison predicate \"" << predicateName << "\""; + + auto builder = parser->getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + result->attributes = attrs; + + result->addTypes({i1Type}); + return success(); +} + +static void print(OpAsmPrinter *p, CmpIOp op) { + *p << "cmpi "; + + auto predicateValue = + op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); + assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && + predicateValue < static_cast(CmpIPredicate::NumPredicates) && + "unknown predicate index"); + Builder b(op.getContext()); + auto predicateStringAttr = + b.getStringAttr(getCmpIPredicateNames()[predicateValue]); + p->printAttribute(predicateStringAttr); + + *p << ", "; + p->printOperand(op.lhs()); + *p << ", "; + p->printOperand(op.rhs()); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); + *p << " : " << op.lhs()->getType(); +} + +static LogicalResult verify(CmpIOp op) { + auto predicateAttr = + op.getAttrOfType(CmpIOp::getPredicateAttrName()); + if (!predicateAttr) + return op.emitOpError("requires an integer attribute named 'predicate'"); + auto predicate = predicateAttr.getInt(); + if (predicate < (int64_t)CmpIPredicate::FirstValidValue || + predicate >= (int64_t)CmpIPredicate::NumPredicates) + return op.emitOpError("'predicate' attribute value out of range"); + + return success(); +} + +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer +// comparison predicates. 
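+// For example, with 32-bit operands, SLT compares -1 < 0 as true, while ULT
+// treats the same bit pattern as 0xFFFFFFFF < 0 and yields false.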
+static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, + const APInt &rhs) { + switch (predicate) { + case CmpIPredicate::EQ: + return lhs.eq(rhs); + case CmpIPredicate::NE: + return lhs.ne(rhs); + case CmpIPredicate::SLT: + return lhs.slt(rhs); + case CmpIPredicate::SLE: + return lhs.sle(rhs); + case CmpIPredicate::SGT: + return lhs.sgt(rhs); + case CmpIPredicate::SGE: + return lhs.sge(rhs); + case CmpIPredicate::ULT: + return lhs.ult(rhs); + case CmpIPredicate::ULE: + return lhs.ule(rhs); + case CmpIPredicate::UGT: + return lhs.ugt(rhs); + case CmpIPredicate::UGE: + return lhs.uge(rhs); + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +OpFoldResult CmpIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "cmpi takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +// Returns an array of mnemonics for CmpFPredicates indexed by values thereof. +static inline const char *const *getCmpFPredicateNames() { + static const char *predicateNames[] = { + /*AlwaysFalse*/ "false", + /*OEQ*/ "oeq", + /*OGT*/ "ogt", + /*OGE*/ "oge", + /*OLT*/ "olt", + /*OLE*/ "ole", + /*ONE*/ "one", + /*ORD*/ "ord", + /*UEQ*/ "ueq", + /*UGT*/ "ugt", + /*UGE*/ "uge", + /*ULT*/ "ult", + /*ULE*/ "ule", + /*UNE*/ "une", + /*UNO*/ "uno", + /*AlwaysTrue*/ "true", + }; + static_assert(std::extent::value == + (size_t)CmpFPredicate::NumPredicates, + "wrong number of predicate names"); + return predicateNames; +} + +// Returns a value of the predicate corresponding to the given mnemonic. +// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
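+// The mnemonics follow LLVM's fcmp convention: an "o" prefix means ordered
+// (false if either operand is NaN), while a "u" prefix means unordered (true
+// if either operand is NaN).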
+CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { + return llvm::StringSwitch(name) + .Case("false", CmpFPredicate::AlwaysFalse) + .Case("oeq", CmpFPredicate::OEQ) + .Case("ogt", CmpFPredicate::OGT) + .Case("oge", CmpFPredicate::OGE) + .Case("olt", CmpFPredicate::OLT) + .Case("ole", CmpFPredicate::OLE) + .Case("one", CmpFPredicate::ONE) + .Case("ord", CmpFPredicate::ORD) + .Case("ueq", CmpFPredicate::UEQ) + .Case("ugt", CmpFPredicate::UGT) + .Case("uge", CmpFPredicate::UGE) + .Case("ult", CmpFPredicate::ULT) + .Case("ule", CmpFPredicate::ULE) + .Case("une", CmpFPredicate::UNE) + .Case("uno", CmpFPredicate::UNO) + .Case("true", CmpFPredicate::AlwaysTrue) + .Default(CmpFPredicate::NumPredicates); +} + +static void buildCmpFOp(Builder *build, OperationState *result, + CmpFPredicate predicate, Value *lhs, Value *rhs) { + result->addOperands({lhs, rhs}); + result->types.push_back(getI1SameShape(build, lhs->getType())); + result->addAttribute( + CmpFOp::getPredicateAttrName(), + build->getI64IntegerAttr(static_cast(predicate))); +} + +static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Attribute predicateNameAttr; + Type type; + if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), + attrs) || + parser->parseComma() || parser->parseOperandList(ops, 2) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(type) || + parser->resolveOperands(ops, type, result->operands)) + return failure(); + + if (!predicateNameAttr.isa()) + return parser->emitError(parser->getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast().getValue(); + auto predicate = CmpFOp::getPredicateByName(predicateName); + if (predicate == CmpFPredicate::NumPredicates) + return parser->emitError(parser->getNameLoc(), + "unknown comparison predicate \"" + predicateName + + "\""); + + auto builder = parser->getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + result->attributes = attrs; + + result->addTypes({i1Type}); + return success(); +} + +static void print(OpAsmPrinter *p, CmpFOp op) { + *p << "cmpf "; + + auto predicateValue = + op.getAttrOfType(CmpFOp::getPredicateAttrName()).getInt(); + assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && + predicateValue < static_cast(CmpFPredicate::NumPredicates) && + "unknown predicate index"); + Builder b(op.getContext()); + auto predicateStringAttr = + b.getStringAttr(getCmpFPredicateNames()[predicateValue]); + p->printAttribute(predicateStringAttr); + + *p << ", "; + p->printOperand(op.lhs()); + *p << ", "; + p->printOperand(op.rhs()); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); + *p << " : " << op.lhs()->getType(); +} + +static LogicalResult verify(CmpFOp op) { + auto predicateAttr = + op.getAttrOfType(CmpFOp::getPredicateAttrName()); + if (!predicateAttr) + return op.emitOpError("requires an integer attribute named 'predicate'"); + auto predicate = predicateAttr.getInt(); + if (predicate < (int64_t)CmpFPredicate::FirstValidValue || + predicate >= (int64_t)CmpFPredicate::NumPredicates) + return op.emitOpError("'predicate' attribute value out of range"); + + return success(); +} + +// Compute 
`lhs` `pred` `rhs`, where `pred` is one of the known floating point +// comparison predicates. +static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, + const APFloat &rhs) { + auto cmpResult = lhs.compare(rhs); + switch (predicate) { + case CmpFPredicate::AlwaysFalse: + return false; + case CmpFPredicate::OEQ: + return cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OGT: + return cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::OGE: + return cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OLT: + return cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::OLE: + return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ONE: + return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; + case CmpFPredicate::ORD: + return cmpResult != APFloat::cmpUnordered; + case CmpFPredicate::UEQ: + return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UGT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::UGE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ULT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::ULE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UNE: + return cmpResult != APFloat::cmpEqual; + case CmpFPredicate::UNO: + return cmpResult == APFloat::cmpUnordered; + case CmpFPredicate::AlwaysTrue: + return true; + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +OpFoldResult CmpFOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "cmpf takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs || + // TODO(b/122019992) Implement and test constant folding for nan/inf when + // it is possible to have constant nan/inf + !lhs.getValue().isFinite() || !rhs.getValue().isFinite()) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// +// CondBranchOp +//===----------------------------------------------------------------------===// + +namespace { +/// cond_br true, ^bb1, ^bb2 -> br ^bb1 +/// cond_br false, ^bb1, ^bb2 -> br ^bb2 +/// +struct SimplifyConstCondBranchPred : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // Check that the condition is a constant. + if (!matchPattern(condbr.getCondition(), m_Op())) + return matchFailure(); + + Block *foldedDest; + SmallVector branchArgs; + + // If the condition is known to evaluate to false we fold to a branch to the + // false destination. Otherwise, we fold to a branch to the true + // destination. 
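+    // For example, when the condition is a constant false,
+    //   cond_br %cond, ^bb1(%a : i32), ^bb2(%b : i32)
+    // becomes
+    //   br ^bb2(%b : i32)
+    // forwarding the false-destination operands.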
+ if (matchPattern(condbr.getCondition(), m_Zero())) { + foldedDest = condbr.getFalseDest(); + branchArgs.assign(condbr.false_operand_begin(), + condbr.false_operand_end()); + } else { + foldedDest = condbr.getTrueDest(); + branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); + } + + rewriter.replaceOpWithNewOp(condbr, foldedDest, branchArgs); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static ParseResult parseCondBranchOp(OpAsmParser *parser, + OperationState *result) { + SmallVector destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser->getBuilder().getI1Type(); + if (parser->parseOperand(condInfo) || parser->parseComma() || + parser->resolveOperand(condInfo, int1Ty, result->operands)) { + return parser->emitError(parser->getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result->addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser->parseComma() || + parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result->addSuccessor(dest, destOperands); + + return success(); +} + +static void print(OpAsmPrinter *p, CondBranchOp op) { + *p << "cond_br "; + p->printOperand(op.getCondition()); + *p << ", "; + p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + *p << ", "; + p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +void CondBranchOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// Constant*Op +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, ConstantOp &op) { + *p << "constant "; + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); + + if (op.getAttrs().size() > 1) + *p << ' '; + p->printAttribute(op.getValue()); + + // If the value is a symbol reference, print a trailing type. + if (op.getValue().isa()) + *p << " : " << op.getType(); +} + +static ParseResult parseConstantOp(OpAsmParser *parser, + OperationState *result) { + Attribute valueAttr; + if (parser->parseOptionalAttributeDict(result->attributes) || + parser->parseAttribute(valueAttr, "value", result->attributes)) + return failure(); + + // If the attribute is a symbol reference, then we expect a trailing type. + Type type; + if (!valueAttr.isa()) + type = valueAttr.getType(); + else if (parser->parseColonType(type)) + return failure(); + + // Add the attribute type to the list. + return parser->addTypeToList(type, result->types); +} + +/// The constant op requires an attribute, and furthermore requires that it +/// matches the return type. +static LogicalResult verify(ConstantOp &op) { + auto value = op.getValue(); + if (!value) + return op.emitOpError("requires a 'value' attribute"); + + auto type = op.getType(); + if (!value.getType().isa() && type != value.getType()) + return op.emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; + + if (type.isa() || value.isa()) + return success(); + + if (auto intAttr = value.dyn_cast()) { + // If the type has a known bitwidth we verify that the value can be + // represented with the given bitwidth. 
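+    // The value is accepted if it fits in 'bitwidth' bits under either a
+    // signed or an unsigned interpretation.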
+ auto bitwidth = type.cast().getWidth(); + auto intVal = intAttr.getValue(); + if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) + return op.emitOpError("requires 'value' to be an integer within the " + "range of the integer result type"); + return success(); + } + + if (type.isa()) { + if (!value.isa()) + return op.emitOpError("requires 'value' to be a floating point constant"); + return success(); + } + + if (type.isa()) { + if (!value.isa()) + return op.emitOpError("requires 'value' to be a shaped constant"); + return success(); + } + + if (type.isa()) { + auto fnAttr = value.dyn_cast(); + if (!fnAttr) + return op.emitOpError("requires 'value' to be a function reference"); + + // Try to find the referenced function. + auto fn = + op.getParentOfType().lookupSymbol(fnAttr.getValue()); + if (!fn) + return op.emitOpError("reference to undefined function 'bar'"); + + // Check that the referenced function has the correct type. + if (fn.getType() != type) + return op.emitOpError("reference to function with mismatched type"); + + return success(); + } + + if (type.isa() && value.isa()) + return success(); + + return op.emitOpError("unsupported 'value' attribute: ") << value; +} + +OpFoldResult ConstantOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +/// Returns true if a constant operation can be built with the given value and +/// result type. +bool ConstantOp::isBuildableWith(Attribute value, Type type) { + // SymbolRefAttr can only be used with a function type. + if (value.isa()) + return type.isa(); + // Otherwise, the attribute must have the same type as 'type'. + if (value.getType() != type) + return false; + // Finally, check that the attribute kind is handled. + return value.isa() || value.isa() || + value.isa() || value.isa() || + value.isa(); +} + +void ConstantFloatOp::build(Builder *builder, OperationState *result, + const APFloat &value, FloatType type) { + ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); +} + +bool ConstantFloatOp::classof(Operation *op) { + return ConstantOp::classof(op) && + op->getResult(0)->getType().isa(); +} + +/// ConstantIntOp only matches values whose result type is an IntegerType. +bool ConstantIntOp::classof(Operation *op) { + return ConstantOp::classof(op) && + op->getResult(0)->getType().isa(); +} + +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, unsigned width) { + Type type = builder->getIntegerType(width); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// Build a constant int op producing an integer with the specified type, +/// which must be an integer type. +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, Type type) { + assert(type.isa() && "ConstantIntOp can only have integer type"); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// ConstantIndexOp only matches values whose result type is Index. 
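+/// For example, "constant 42 : index" is a ConstantIndexOp, whereas
+/// "constant 42 : i32" is a ConstantIntOp.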
+bool ConstantIndexOp::classof(Operation *op) { + return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); +} + +void ConstantIndexOp::build(Builder *builder, OperationState *result, + int64_t value) { + Type type = builder->getIndexType(); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold Dealloc operations that are deallocating an AllocOp that is only used +/// by other Dealloc operations. +struct SimplifyDeadDealloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(DeallocOp dealloc, + PatternRewriter &rewriter) const override { + // Check that the memref operand's defining operation is an AllocOp. + Value *memref = dealloc.memref(); + if (!isa_and_nonnull(memref->getDefiningOp())) + return matchFailure(); + + // Check that all of the uses of the AllocOp are other DeallocOps. + for (auto *user : memref->getUsers()) + if (!isa(user)) + return matchFailure(); + + // Erase the dealloc operation. + rewriter.replaceOp(dealloc, llvm::None); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static void print(OpAsmPrinter *p, DeallocOp op) { + *p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); +} + +static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType memrefInfo; + MemRefType type; + + return failure(parser->parseOperand(memrefInfo) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands)); +} + +static LogicalResult verify(DeallocOp op) { + if (!op.memref()->getType().isa()) + return op.emitOpError("operand must be a memref"); + return success(); +} + +void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dealloc(memrefcast) -> dealloc + results.insert(getOperationName(), context); + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// DimOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, DimOp op) { + *p << "dim " << *op.getOperand() << ", " << op.getIndex(); + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); + *p << " : " << op.getOperand()->getType(); +} + +static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType operandInfo; + IntegerAttr indexAttr; + Type type; + Type indexType = parser->getBuilder().getIndexType(); + + return failure(parser->parseOperand(operandInfo) || parser->parseComma() || + parser->parseAttribute(indexAttr, indexType, "index", + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types)); +} + +static LogicalResult verify(DimOp op) { + // Check that we have an integer index operand. 
+ auto indexAttr = op.getAttrOfType("index"); + if (!indexAttr) + return op.emitOpError("requires an integer attribute named 'index'"); + int64_t index = indexAttr.getValue().getSExtValue(); + + auto type = op.getOperand()->getType(); + if (auto tensorType = type.dyn_cast()) { + if (index >= tensorType.getRank()) + return op.emitOpError("index is out of range"); + } else if (auto memrefType = type.dyn_cast()) { + if (index >= memrefType.getRank()) + return op.emitOpError("index is out of range"); + + } else if (type.isa()) { + // ok, assumed to be in-range. + } else { + return op.emitOpError("requires an operand with tensor or memref type"); + } + + return success(); +} + +OpFoldResult DimOp::fold(ArrayRef operands) { + // Constant fold dim when the size along the index referred to is a constant. + auto opType = getOperand()->getType(); + int64_t indexSize = -1; + if (auto tensorType = opType.dyn_cast()) + indexSize = tensorType.getShape()[getIndex()]; + else if (auto memrefType = opType.dyn_cast()) + indexSize = memrefType.getShape()[getIndex()]; + + if (indexSize >= 0) + return IntegerAttr::get(IndexType::get(getContext()), indexSize); + + return {}; +} + +//===----------------------------------------------------------------------===// +// DivISOp +//===----------------------------------------------------------------------===// + +OpFoldResult DivISOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + // Don't fold if it requires division by zero. + if (rhs.getValue().isNullValue()) + return {}; + + // Don't fold if it would overflow. + bool overflow; + auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); + return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result); +} + +//===----------------------------------------------------------------------===// +// DivIUOp +//===----------------------------------------------------------------------===// + +OpFoldResult DivIUOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + // Don't fold if it requires division by zero. 
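+  // Unlike signed division, unsigned division cannot overflow, so the
+  // divide-by-zero case is the only one that blocks folding.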
+ auto rhsValue = rhs.getValue(); + if (rhsValue.isNullValue()) + return {}; + + return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue)); +} + +// --------------------------------------------------------------------------- +// DmaStartOp +// --------------------------------------------------------------------------- + +void DmaStartOp::build(Builder *builder, OperationState *result, + Value *srcMemRef, ArrayRef srcIndices, + Value *destMemRef, ArrayRef destIndices, + Value *numElements, Value *tagMemRef, + ArrayRef tagIndices, Value *stride, + Value *elementsPerStride) { + result->addOperands(srcMemRef); + result->addOperands(srcIndices); + result->addOperands(destMemRef); + result->addOperands(destIndices); + result->addOperands({numElements, tagMemRef}); + result->addOperands(tagIndices); + if (stride) + result->addOperands({stride, elementsPerStride}); +} + +void DmaStartOp::print(OpAsmPrinter *p) { + *p << "dma_start " << *getSrcMemRef() << '['; + p->printOperands(getSrcIndices()); + *p << "], " << *getDstMemRef() << '['; + p->printOperands(getDstIndices()); + *p << "], " << *getNumElements(); + *p << ", " << *getTagMemRef() << '['; + p->printOperands(getTagIndices()); + *p << ']'; + if (isStrided()) { + *p << ", " << *getStride(); + *p << ", " << *getNumElementsPerStride(); + } + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getSrcMemRef()->getType(); + *p << ", " << getDstMemRef()->getType(); + *p << ", " << getTagMemRef()->getType(); +} + +// Parse DmaStartOp. +// Ex: +// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, +// %tag[%index], %stride, %num_elt_per_stride : +// : memref<3076 x f32, 0>, +// memref<1024 x f32, 2>, +// memref<1 x i32> +// +ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + SmallVector srcIndexInfos; + OpAsmParser::OperandType dstMemRefInfo; + SmallVector dstIndexInfos; + OpAsmParser::OperandType numElementsInfo; + OpAsmParser::OperandType tagMemrefInfo; + SmallVector tagIndexInfos; + SmallVector strideInfo; + + SmallVector types; + auto indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) source memref followed by its indices (in square brackets). + // *) destination memref followed by its indices (in square brackets). + // *) dma size in KiB. + if (parser->parseOperand(srcMemRefInfo) || + parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || + parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseComma() || parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) + return failure(); + + // Parse optional stride and elements per stride. 
+ if (parser->parseTrailingOperandList(strideInfo)) + return failure(); + + bool isStrided = strideInfo.size() == 2; + if (!strideInfo.empty() && !isStrided) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + + if (parser->parseColonTypeList(types)) + return failure(); + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcIndexInfos, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstIndexInfos, indexType, result->operands) || + // size should be an index. + parser->resolveOperand(numElementsInfo, indexType, result->operands) || + parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || + // tag indices should be index. + parser->resolveOperands(tagIndexInfos, indexType, result->operands)) + return failure(); + + auto memrefType0 = types[0].dyn_cast(); + if (!memrefType0) + return parser->emitError(parser->getNameLoc(), + "expected source to be of memref type"); + + auto memrefType1 = types[1].dyn_cast(); + if (!memrefType1) + return parser->emitError(parser->getNameLoc(), + "expected destination to be of memref type"); + + auto memrefType2 = types[2].dyn_cast(); + if (!memrefType2) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (isStrided) { + if (parser->resolveOperands(strideInfo, indexType, result->operands)) + return failure(); + } + + // Check that source/destination index list size matches associated rank. + if (static_cast(srcIndexInfos.size()) != memrefType0.getRank() || + static_cast(dstIndexInfos.size()) != memrefType1.getRank()) + return parser->emitError(parser->getNameLoc(), + "memref rank not equal to indices count"); + if (static_cast(tagIndexInfos.size()) != memrefType2.getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + return success(); +} + +LogicalResult DmaStartOp::verify() { + // DMAs from different memory spaces supported. + if (getSrcMemorySpace() == getDstMemorySpace()) + return emitOpError("DMA should be between different memory spaces"); + + if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 && + getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + return success(); +} + +void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dma_start(memrefcast) -> dma_start + results.insert(getOperationName(), context); +} + +// --------------------------------------------------------------------------- +// DmaWaitOp +// --------------------------------------------------------------------------- + +void DmaWaitOp::build(Builder *builder, OperationState *result, + Value *tagMemRef, ArrayRef tagIndices, + Value *numElements) { + result->addOperands(tagMemRef); + result->addOperands(tagIndices); + result->addOperands(numElements); +} + +void DmaWaitOp::print(OpAsmPrinter *p) { + *p << "dma_wait "; + p->printOperand(getTagMemRef()); + *p << '['; + p->printOperands(getTagIndices()); + *p << "], "; + p->printOperand(getNumElements()); + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getTagMemRef()->getType(); +} + +// Parse DmaWaitOp. 
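+// (dma_wait blocks until the DMA transfer associated with the given tag
+// element has completed.)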
+// Eg: +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> +// +ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType tagMemrefInfo; + SmallVector tagIndexInfos; + Type type; + auto indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its indices, and dma size. + if (parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseColonType(type) || + parser->resolveOperand(tagMemrefInfo, type, result->operands) || + parser->resolveOperands(tagIndexInfos, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + auto memrefType = type.dyn_cast(); + if (!memrefType) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (static_cast(tagIndexInfos.size()) != memrefType.getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + return success(); +} + +void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dma_wait(memrefcast) -> dma_wait + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, ExtractElementOp op) { + *p << "extract_element " << *op.getAggregate() << '['; + p->printOperands(op.getIndices()); + *p << ']'; + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getAggregate()->getType(); +} + +static ParseResult parseExtractElementOp(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType aggregateInfo; + SmallVector indexInfo; + ShapedType type; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return failure( + parser->parseOperand(aggregateInfo) || + parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(aggregateInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +static LogicalResult verify(ExtractElementOp op) { + auto aggregateType = op.getAggregate()->getType().cast(); + + // This should be possible with tablegen type constraints + if (op.getType() != aggregateType.getElementType()) + return op.emitOpError("result type must match element type of aggregate"); + + // Verify the # indices match if we have a ranked type. + if (aggregateType.hasRank() && + aggregateType.getRank() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of indices for extract_element"); + + return success(); +} + +OpFoldResult ExtractElementOp::fold(ArrayRef operands) { + assert(!operands.empty() && "extract_element takes atleast one operand"); + + // The aggregate operand must be a known constant. + Attribute aggregate = operands.front(); + if (!aggregate) + return {}; + + // If this is a splat elements attribute, simply return the value. All of the + // elements of a splat attribute are the same. 
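+  // Ex (illustrative): extract_element on a constant splat such as
+  //   dense<1.0> : tensor<4xf32>
+  // folds to 1.0 regardless of the index operands.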
+ if (auto splatAggregate = aggregate.dyn_cast()) + return splatAggregate.getSplatValue(); + + // Otherwise, collect the constant indices into the aggregate. + SmallVector indices; + for (Attribute indice : llvm::drop_begin(operands, 1)) { + if (!indice || !indice.isa()) + return {}; + indices.push_back(indice.cast().getInt()); + } + + // If this is an elements attribute, query the value at the given indices. + auto elementsAttr = aggregate.dyn_cast(); + if (elementsAttr && elementsAttr.isValidIndex(indices)) + return elementsAttr.getValue(indices); + return {}; +} + +//===----------------------------------------------------------------------===// +// IndexCastOp +//===----------------------------------------------------------------------===// + +// Index cast is applicable from index to integer and backwards. +bool IndexCastOp::areCastCompatible(Type a, Type b) { + return (a.isIndex() && b.isa()) || + (a.isa() && b.isIndex()); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, LoadOp op) { + *p << "load " << *op.getMemRef() << '['; + p->printOperands(op.getIndices()); + *p << ']'; + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getMemRefType(); +} + +static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType memrefInfo; + SmallVector indexInfo; + MemRefType type; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return failure( + parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +static LogicalResult verify(LoadOp op) { + if (op.getType() != op.getMemRefType().getElementType()) + return op.emitOpError("result type must match element type of memref"); + + if (op.getMemRefType().getRank() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of indices for load"); + + for (auto *idx : op.getIndices()) + if (!idx->getType().isIndex()) + return op.emitOpError("index to load must have 'index' type"); + + // TODO: Verify we have the right number of indices. + + // TODO: in Function verify that the indices are parameters, IV's, or the + // result of an affine.apply. + return success(); +} + +void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// load(memrefcast) -> load + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// MemRefCastOp +//===----------------------------------------------------------------------===// + +bool MemRefCastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + + if (!aT || !bT) + return false; + if (aT.getElementType() != bT.getElementType()) + return false; + if (aT.getAffineMaps() != bT.getAffineMaps()) + return false; + if (aT.getMemorySpace() != bT.getMemorySpace()) + return false; + + // They must have the same rank, and any specified dimensions must match. 
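+  // Ex: memref<4x?xf32> and memref<?x8xf32> are cast-compatible, since a
+  // dynamic dimension matches any static size; memref<4xf32> and
+  // memref<8xf32> are not.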
+ if (aT.getRank() != bT.getRank()) + return false; + + for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { + int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); + if (aDim != -1 && bDim != -1 && aDim != bDim) + return false; + } + + return true; +} + +OpFoldResult MemRefCastOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); +} + +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulFOp::fold(ArrayRef operands) { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// MulIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulIOp::fold(ArrayRef operands) { + /// muli(x, 0) -> 0 + if (matchPattern(rhs(), m_Zero())) + return rhs(); + /// muli(x, 1) -> x + if (matchPattern(rhs(), m_One())) + return getOperand(0); + + // TODO: Handle the overflow case. + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, RankOp op) { + *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); +} + +static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType operandInfo; + Type type; + Type indexType = parser->getBuilder().getIndexType(); + return failure(parser->parseOperand(operandInfo) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types)); +} + +OpFoldResult RankOp::fold(ArrayRef operands) { + // Constant fold rank when the rank of the tensor is known. + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast()) + return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); + return IntegerAttr(); +} + +//===----------------------------------------------------------------------===// +// RemISOp +//===----------------------------------------------------------------------===// + +OpFoldResult RemISOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "remis takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) + return {}; + auto rhsValue = rhs.getValue(); + + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + + // Don't fold if it requires division by zero. + if (rhsValue.isNullValue()) + return {}; + + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); +} + +//===----------------------------------------------------------------------===// +// RemIUOp +//===----------------------------------------------------------------------===// + +OpFoldResult RemIUOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "remiu takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) + return {}; + auto rhsValue = rhs.getValue(); + + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + + // Don't fold if it requires division by zero. 
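+  // Ex: "remiu %x, %c0" (x % 0) is left as-is rather than folded.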
+ if (rhsValue.isNullValue()) + return {}; + + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { + SmallVector opInfo; + SmallVector types; + llvm::SMLoc loc = parser->getCurrentLocation(); + return failure(parser->parseOperandList(opInfo) || + (!opInfo.empty() && parser->parseColonTypeList(types)) || + parser->resolveOperands(opInfo, types, loc, result->operands)); +} + +static void print(OpAsmPrinter *p, ReturnOp op) { + *p << "return"; + if (op.getNumOperands() != 0) { + *p << ' '; + p->printOperands(op.getOperands()); + *p << " : "; + interleaveComma(op.getOperandTypes(), *p); + } +} + +static LogicalResult verify(ReturnOp op) { + auto function = cast(op.getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError("has ") + << op.getNumOperands() + << " operands, but enclosing function returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (op.getOperand(i)->getType() != results[i]) + return op.emitError() + << "type of return operand " << i << " (" + << op.getOperand(i)->getType() + << ") doesn't match function result type (" << results[i] << ")"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SIToFPOp +//===----------------------------------------------------------------------===// + +// sitofp is applicable from integer types to float types. 
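+// Ex (illustrative): "%f = sitofp %i : i32 to f32" is accepted; the reverse
+// direction (f32 to i32) is not a valid sitofp.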
+bool SIToFPOp::areCastCompatible(Type a, Type b) { + return a.isa() && b.isa(); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Type type; + if (parser->parseOperandList(ops, 3) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return failure(); + + auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + SmallVector types = {i1Type, type, type}; + return failure(parser->resolveOperands(ops, types, parser->getNameLoc(), + result->operands) || + parser->addTypeToList(type, result->types)); +} + +static void print(OpAsmPrinter *p, SelectOp op) { + *p << "select "; + p->printOperands(op.getOperands()); + *p << " : " << op.getTrueValue()->getType(); + p->printOptionalAttrDict(op.getAttrs()); +} + +static LogicalResult verify(SelectOp op) { + auto trueType = op.getTrueValue()->getType(); + auto falseType = op.getFalseValue()->getType(); + + if (trueType != falseType) + return op.emitOpError( + "requires 'true' and 'false' arguments to be of the same type"); + + return success(); +} + +OpFoldResult SelectOp::fold(ArrayRef operands) { + auto *condition = getCondition(); + + // select true, %0, %1 => %0 + if (matchPattern(condition, m_One())) + return getTrueValue(); + + // select false, %0, %1 => %1 + if (matchPattern(condition, m_Zero())) + return getFalseValue(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter *p, StoreOp op) { + *p << "store " << *op.getValueToStore(); + *p << ", " << *op.getMemRef() << '['; + p->printOperands(op.getIndices()); + *p << ']'; + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getMemRefType(); +} + +static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + SmallVector indexInfo; + MemRefType memrefType; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(memrefType) || + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands)); +} + +static LogicalResult verify(StoreOp op) { + // First operand must have same type as memref element type. 
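+  // Ex: storing an f32 value into a memref<16xi32> is rejected by the check
+  // below.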
+ if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) + return op.emitOpError( + "first operand must have same type memref element type"); + + if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) + return op.emitOpError("store index operand count not equal to memref rank"); + + for (auto *idx : op.getIndices()) + if (!idx->getType().isIndex()) + return op.emitOpError("index to load must have 'index' type"); + + // TODO: Verify we have the right number of indices. + + // TODO: in Function verify that the indices are parameters, IV's, or the + // result of an affine.apply. + return success(); +} + +void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// store(memrefcast) -> store + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// SubFOp +//===----------------------------------------------------------------------===// + +OpFoldResult SubFOp::fold(ArrayRef operands) { + return constFoldBinaryOp( + operands, [](APFloat a, APFloat b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// SubIOp +//===----------------------------------------------------------------------===// + +OpFoldResult SubIOp::fold(ArrayRef operands) { + // subi(x,x) -> 0 + if (getOperand(0) == getOperand(1)) + return Builder(getContext()).getZeroAttr(getType()); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + +OpFoldResult AndOp::fold(ArrayRef operands) { + /// and(x, 0) -> 0 + if (matchPattern(rhs(), m_Zero())) + return rhs(); + /// and(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a & b; }); +} + +//===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + +OpFoldResult OrOp::fold(ArrayRef operands) { + /// or(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// or(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a | b; }); +} + +//===----------------------------------------------------------------------===// +// XOrOp +//===----------------------------------------------------------------------===// + +OpFoldResult XOrOp::fold(ArrayRef operands) { + /// xor(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// xor(x,x) -> 0 + if (lhs() == rhs()) + return Builder(getContext()).getZeroAttr(getType()); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a ^ b; }); +} + +//===----------------------------------------------------------------------===// +// TensorCastOp +//===----------------------------------------------------------------------===// + +bool TensorCastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + if (!aT || !bT) + return false; + + if (aT.getElementType() != bT.getElementType()) + return false; + + // If the either are unranked, then the cast is valid. 
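+  // Ex: tensor<4xf32> can be cast to tensor<*xf32> and back; two ranked
+  // tensor types must additionally agree on rank and any static dimensions.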
+ auto aRType = aT.dyn_cast(); + auto bRType = bT.dyn_cast(); + if (!aRType || !bRType) + return true; + + // If they are both ranked, they have to have the same rank, and any specified + // dimensions must match. + if (aRType.getRank() != bRType.getRank()) + return false; + + for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) { + int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i); + if (aDim != -1 && bDim != -1 && aDim != bDim) + return false; + } + + return true; +} + +OpFoldResult TensorCastOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/StandardOps/Ops.cpp.inc" diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index d52490055e4..c620ac555f5 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -16,8 +16,8 @@ // ============================================================================= #include "mlir/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/Optional.h" diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index e6266d373e6..b4455c43c1e 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -16,8 +16,8 @@ // ============================================================================= #include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; using namespace mlir::edsc; diff --git a/mlir/lib/LLVMIR/CMakeLists.txt b/mlir/lib/LLVMIR/CMakeLists.txt deleted file mode 100644 index 5e21850dbac..00000000000 --- a/mlir/lib/LLVMIR/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_llvm_library(MLIRLLVMIR - IR/LLVMDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/LLVMIR - ) -add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) -target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport) - -add_llvm_library(MLIRNVVMIR - IR/NVVMDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/LLVMIR - ) -add_dependencies(MLIRNVVMIR MLIRNVVMOpsIncGen MLIRNVVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) -target_link_libraries(MLIRNVVMIR LLVMAsmParser LLVMCore LLVMSupport) diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp deleted file mode 100644 index d051cc94d5b..00000000000 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ /dev/null @@ -1,1394 +0,0 @@ -//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= -// -// This file defines the types and operation details for the LLVM IR dialect in -// MLIR, and the LLVM IR dialect. It also registers the dialect. -// -//===----------------------------------------------------------------------===// -#include "mlir/LLVMIR/LLVMDialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/StandardTypes.h" - -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/Mutex.h" -#include "llvm/Support/SourceMgr.h" - -using namespace mlir; -using namespace mlir::LLVM; - -#include "mlir/LLVMIR/LLVMOpsEnums.cpp.inc" - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::CmpOp. -//===----------------------------------------------------------------------===// -static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) { - *p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) - << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); - p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); - *p << " : " << op.lhs()->getType(); -} - -static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) { - *p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) - << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); - p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); - *p << " : " << op.lhs()->getType(); -} - -// ::= `llvm.icmp` string-literal ssa-use `,` ssa-use -// attribute-dict? `:` type -// ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use -// attribute-dict? `:` type -template -static ParseResult parseCmpOp(OpAsmParser *parser, OperationState *result) { - Builder &builder = parser->getBuilder(); - - Attribute predicate; - SmallVector attrs; - OpAsmParser::OperandType lhs, rhs; - Type type; - llvm::SMLoc predicateLoc, trailingTypeLoc; - if (parser->getCurrentLocation(&predicateLoc) || - parser->parseAttribute(predicate, "predicate", attrs) || - parser->parseOperand(lhs) || parser->parseComma() || - parser->parseOperand(rhs) || parser->parseOptionalAttributeDict(attrs) || - parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || - parser->parseType(type) || - parser->resolveOperand(lhs, type, result->operands) || - parser->resolveOperand(rhs, type, result->operands)) - return failure(); - - // Replace the string attribute `predicate` with an integer attribute. - auto predicateStr = predicate.dyn_cast(); - if (!predicateStr) - return parser->emitError(predicateLoc, - "expected 'predicate' attribute of string type"); - - int64_t predicateValue = 0; - if (std::is_same()) { - Optional predicate = - symbolizeICmpPredicate(predicateStr.getValue()); - if (!predicate) - return parser->emitError(predicateLoc) - << "'" << predicateStr.getValue() - << "' is an incorrect value of the 'predicate' attribute"; - predicateValue = static_cast(predicate.getValue()); - } else { - Optional predicate = - symbolizeFCmpPredicate(predicateStr.getValue()); - if (!predicate) - return parser->emitError(predicateLoc) - << "'" << predicateStr.getValue() - << "' is an incorrect value of the 'predicate' attribute"; - predicateValue = static_cast(predicate.getValue()); - } - - attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue); - - // The result type is either i1 or a vector type if the inputs are - // vectors. 
- auto *dialect = builder.getContext()->getRegisteredDialect(); - auto resultType = LLVMType::getInt1Ty(dialect); - auto argType = type.dyn_cast(); - if (!argType) - return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"); - if (argType.getUnderlyingType()->isVectorTy()) - resultType = LLVMType::getVectorTy( - resultType, argType.getUnderlyingType()->getVectorNumElements()); - - result->attributes = attrs; - result->addTypes({resultType}); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::AllocaOp. -//===----------------------------------------------------------------------===// - -static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) { - auto elemTy = op.getType().cast().getPointerElementTy(); - - auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()}, - op.getContext()); - - *p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy; - if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) - p->printOptionalAttrDict(op.getAttrs()); - else - p->printOptionalAttrDict(op.getAttrs(), {"alignment"}); - *p << " : " << funcTy; -} - -// ::= `llvm.alloca` ssa-use `x` type attribute-dict? -// `:` type `,` type -static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) { - SmallVector attrs; - OpAsmParser::OperandType arraySize; - Type type, elemType; - llvm::SMLoc trailingTypeLoc; - if (parser->parseOperand(arraySize) || parser->parseKeyword("x") || - parser->parseType(elemType) || - parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || - parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return failure(); - - // Extract the result type from the trailing function type. - auto funcType = type.dyn_cast(); - if (!funcType || funcType.getNumInputs() != 1 || - funcType.getNumResults() != 1) - return parser->emitError( - trailingTypeLoc, - "expected trailing function type with one argument and one result"); - - if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands)) - return failure(); - - result->attributes = attrs; - result->addTypes({funcType.getResult(0)}); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::GEPOp. -//===----------------------------------------------------------------------===// - -static void printGEPOp(OpAsmPrinter *p, GEPOp &op) { - SmallVector types(op.getOperandTypes()); - auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); - - *p << op.getOperationName() << ' ' << *op.base() << '['; - p->printOperands(std::next(op.operand_begin()), op.operand_end()); - *p << ']'; - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << funcTy; -} - -// ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` -// attribute-dict? `:` type -static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) { - SmallVector attrs; - OpAsmParser::OperandType base; - SmallVector indices; - Type type; - llvm::SMLoc trailingTypeLoc; - if (parser->parseOperand(base) || - parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || - parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return failure(); - - // Deconstruct the trailing function type to extract the types of the base - // pointer and result (same type) and the types of the indices. 
- auto funcType = type.dyn_cast(); - if (!funcType || funcType.getNumResults() != 1 || - funcType.getNumInputs() == 0) - return parser->emitError(trailingTypeLoc, - "expected trailing function type with at least " - "one argument and one result"); - - if (parser->resolveOperand(base, funcType.getInput(0), result->operands) || - parser->resolveOperands(indices, funcType.getInputs().drop_front(), - parser->getNameLoc(), result->operands)) - return failure(); - - result->attributes = attrs; - result->addTypes(funcType.getResults()); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::LoadOp. -//===----------------------------------------------------------------------===// - -static void printLoadOp(OpAsmPrinter *p, LoadOp &op) { - *p << op.getOperationName() << ' ' << *op.addr(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.addr()->getType(); -} - -// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return -// the resulting type wrapped in MLIR, or nullptr on error. -static Type getLoadStoreElementType(OpAsmParser *parser, Type type, - llvm::SMLoc trailingTypeLoc) { - auto llvmTy = type.dyn_cast(); - if (!llvmTy) - return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"), - nullptr; - if (!llvmTy.getUnderlyingType()->isPointerTy()) - return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"), - nullptr; - return llvmTy.getPointerElementTy(); -} - -// ::= `llvm.load` ssa-use attribute-dict? `:` type -static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { - SmallVector attrs; - OpAsmParser::OperandType addr; - Type type; - llvm::SMLoc trailingTypeLoc; - - if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || - parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || - parser->parseType(type) || - parser->resolveOperand(addr, type, result->operands)) - return failure(); - - Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); - - result->attributes = attrs; - result->addTypes(elemTy); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::StoreOp. -//===----------------------------------------------------------------------===// - -static void printStoreOp(OpAsmPrinter *p, StoreOp &op) { - *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.addr()->getType(); -} - -// ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type -static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { - SmallVector attrs; - OpAsmParser::OperandType addr, value; - Type type; - llvm::SMLoc trailingTypeLoc; - - if (parser->parseOperand(value) || parser->parseComma() || - parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || - parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || - parser->parseType(type)) - return failure(); - - Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); - if (!elemTy) - return failure(); - - if (parser->resolveOperand(value, elemTy, result->operands) || - parser->resolveOperand(addr, type, result->operands)) - return failure(); - - result->attributes = attrs; - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::CallOp. 
-//===----------------------------------------------------------------------===// - -static void printCallOp(OpAsmPrinter *p, CallOp &op) { - auto callee = op.callee(); - bool isDirect = callee.hasValue(); - - // Print the direct callee if present as a function attribute, or an indirect - // callee (first operand) otherwise. - *p << op.getOperationName() << ' '; - if (isDirect) - *p << '@' << callee.getValue(); - else - *p << *op.getOperand(0); - - *p << '('; - p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1)); - *p << ')'; - - p->printOptionalAttrDict(op.getAttrs(), {"callee"}); - - // Reconstruct the function MLIR function type from operand and result types. - SmallVector resultTypes(op.getResultTypes()); - SmallVector argTypes( - llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); - - *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); -} - -// ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` -// attribute-dict? `:` function-type -static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { - SmallVector attrs; - SmallVector operands; - Type type; - SymbolRefAttr funcAttr; - llvm::SMLoc trailingTypeLoc; - - // Parse an operand list that will, in practice, contain 0 or 1 operand. In - // case of an indirect call, there will be 1 operand before `(`. In case of a - // direct call, there will be no operands and the parser will stop at the - // function identifier without complaining. - if (parser->parseOperandList(operands)) - return failure(); - bool isDirect = operands.empty(); - - // Optionally parse a function identifier. - if (isDirect) - if (parser->parseAttribute(funcAttr, "callee", attrs)) - return failure(); - - if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || - parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return failure(); - - auto funcType = type.dyn_cast(); - if (!funcType) - return parser->emitError(trailingTypeLoc, "expected function type"); - if (isDirect) { - // Make sure types match. - if (parser->resolveOperands(operands, funcType.getInputs(), - parser->getNameLoc(), result->operands)) - return failure(); - result->addTypes(funcType.getResults()); - } else { - // Construct the LLVM IR Dialect function type that the first operand - // should match. 
- if (funcType.getNumResults() > 1) - return parser->emitError(trailingTypeLoc, - "expected function with 0 or 1 result"); - - Builder &builder = parser->getBuilder(); - auto *llvmDialect = - builder.getContext()->getRegisteredDialect(); - LLVM::LLVMType llvmResultType; - if (funcType.getNumResults() == 0) { - llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); - } else { - llvmResultType = funcType.getResult(0).dyn_cast(); - if (!llvmResultType) - return parser->emitError(trailingTypeLoc, - "expected result to have LLVM type"); - } - - SmallVector argTypes; - argTypes.reserve(funcType.getNumInputs()); - for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { - auto argType = funcType.getInput(i).dyn_cast(); - if (!argType) - return parser->emitError(trailingTypeLoc, - "expected LLVM types as inputs"); - argTypes.push_back(argType); - } - auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, - /*isVarArg=*/false); - auto wrappedFuncType = llvmFuncType.getPointerTo(); - - auto funcArguments = - ArrayRef(operands).drop_front(); - - // Make sure that the first operand (indirect callee) matches the wrapped - // LLVM IR function type, and that the types of the other call operands - // match the types of the function arguments. - if (parser->resolveOperand(operands[0], wrappedFuncType, - result->operands) || - parser->resolveOperands(funcArguments, funcType.getInputs(), - parser->getNameLoc(), result->operands)) - return failure(); - - result->addTypes(llvmResultType); - } - - result->attributes = attrs; - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::ExtractElementOp. -//===----------------------------------------------------------------------===// -// Expects vector to be of wrapped LLVM vector type and position to be of -// wrapped LLVM i32 type. -void LLVM::ExtractElementOp::build(Builder *b, OperationState *result, - Value *vector, Value *position, - ArrayRef attrs) { - auto wrappedVectorType = vector->getType().cast(); - auto llvmType = wrappedVectorType.getVectorElementType(); - build(b, result, llvmType, vector, position); - result->addAttributes(attrs); -} - -static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) { - *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.vector()->getType(); -} - -// ::= `llvm.extractelement` ssa-use `, ` ssa-use -// attribute-dict? 
`:` type -static ParseResult parseExtractElementOp(OpAsmParser *parser, - OperationState *result) { - llvm::SMLoc loc; - OpAsmParser::OperandType vector, position; - auto *llvmDialect = parser->getBuilder() - .getContext() - ->getRegisteredDialect(); - Type type, i32Type = LLVMType::getInt32Ty(llvmDialect); - if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || - parser->parseComma() || parser->parseOperand(position) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(vector, type, result->operands) || - parser->resolveOperand(position, i32Type, result->operands)) - return failure(); - auto wrappedVectorType = type.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) - return parser->emitError( - loc, "expected LLVM IR dialect vector type for operand #1"); - result->addTypes(wrappedVectorType.getVectorElementType()); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::ExtractValueOp. -//===----------------------------------------------------------------------===// - -static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) { - *p << op.getOperationName() << ' ' << *op.container() << op.position(); - p->printOptionalAttrDict(op.getAttrs(), {"position"}); - *p << " : " << op.container()->getType(); -} - -// Extract the type at `position` in the wrapped LLVM IR aggregate type -// `containerType`. Position is an integer array attribute where each value -// is a zero-based position of the element in the aggregate type. Return the -// resulting type wrapped in MLIR, or nullptr on error. -static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, - Type containerType, - Attribute positionAttr, - llvm::SMLoc attributeLoc, - llvm::SMLoc typeLoc) { - auto wrappedContainerType = containerType.dyn_cast(); - if (!wrappedContainerType) - return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; - - auto positionArrayAttr = positionAttr.dyn_cast(); - if (!positionArrayAttr) - return parser->emitError(attributeLoc, "expected an array attribute"), - nullptr; - - // Infer the element type from the structure type: iteratively step inside the - // type by taking the element type, indexed by the position attribute for - // stuctures. Check the position index before accessing, it is supposed to be - // in bounds. 
- for (Attribute subAttr : positionArrayAttr) { - auto positionElementAttr = subAttr.dyn_cast(); - if (!positionElementAttr) - return parser->emitError(attributeLoc, - "expected an array of integer literals"), - nullptr; - int position = positionElementAttr.getInt(); - auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); - if (llvmContainerType->isArrayTy()) { - if (position < 0 || static_cast(position) >= - llvmContainerType->getArrayNumElements()) - return parser->emitError(attributeLoc, "position out of bounds"), - nullptr; - wrappedContainerType = wrappedContainerType.getArrayElementType(); - } else if (llvmContainerType->isStructTy()) { - if (position < 0 || static_cast(position) >= - llvmContainerType->getStructNumElements()) - return parser->emitError(attributeLoc, "position out of bounds"), - nullptr; - wrappedContainerType = - wrappedContainerType.getStructElementType(position); - } else { - return parser->emitError(typeLoc, - "expected wrapped LLVM IR structure/array type"), - nullptr; - } - } - return wrappedContainerType; -} - -// ::= `llvm.extractvalue` ssa-use -// `[` integer-literal (`,` integer-literal)* `]` -// attribute-dict? `:` type -static ParseResult parseExtractValueOp(OpAsmParser *parser, - OperationState *result) { - SmallVector attrs; - OpAsmParser::OperandType container; - Type containerType; - Attribute positionAttr; - llvm::SMLoc attributeLoc, trailingTypeLoc; - - if (parser->parseOperand(container) || - parser->getCurrentLocation(&attributeLoc) || - parser->parseAttribute(positionAttr, "position", attrs) || - parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || - parser->getCurrentLocation(&trailingTypeLoc) || - parser->parseType(containerType) || - parser->resolveOperand(container, containerType, result->operands)) - return failure(); - - auto elementType = getInsertExtractValueElementType( - parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); - if (!elementType) - return failure(); - - result->attributes = attrs; - result->addTypes(elementType); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::InsertElementOp. -//===----------------------------------------------------------------------===// - -static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) { - *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value() - << ", " << *op.position(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.vector()->getType(); -} - -// ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use -// attribute-dict? 
`:` type -static ParseResult parseInsertElementOp(OpAsmParser *parser, - OperationState *result) { - llvm::SMLoc loc; - OpAsmParser::OperandType vector, value, position; - auto *llvmDialect = parser->getBuilder() - .getContext() - ->getRegisteredDialect(); - Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect); - if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || - parser->parseComma() || parser->parseOperand(value) || - parser->parseComma() || parser->parseOperand(position) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(vectorType)) - return failure(); - - auto wrappedVectorType = vectorType.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) - return parser->emitError( - loc, "expected LLVM IR dialect vector type for operand #1"); - auto valueType = wrappedVectorType.getVectorElementType(); - if (!valueType) - return failure(); - - if (parser->resolveOperand(vector, vectorType, result->operands) || - parser->resolveOperand(value, valueType, result->operands) || - parser->resolveOperand(position, i32Type, result->operands)) - return failure(); - - result->addTypes(vectorType); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::InsertValueOp. -//===----------------------------------------------------------------------===// - -static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) { - *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container() - << op.position(); - p->printOptionalAttrDict(op.getAttrs(), {"position"}); - *p << " : " << op.container()->getType(); -} - -// ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use -// `[` integer-literal (`,` integer-literal)* `]` -// attribute-dict? `:` type -static ParseResult parseInsertValueOp(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType container, value; - Type containerType; - Attribute positionAttr; - llvm::SMLoc attributeLoc, trailingTypeLoc; - - if (parser->parseOperand(value) || parser->parseComma() || - parser->parseOperand(container) || - parser->getCurrentLocation(&attributeLoc) || - parser->parseAttribute(positionAttr, "position", result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || - parser->parseType(containerType)) - return failure(); - - auto valueType = getInsertExtractValueElementType( - parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); - if (!valueType) - return failure(); - - if (parser->resolveOperand(container, containerType, result->operands) || - parser->resolveOperand(value, valueType, result->operands)) - return failure(); - - result->addTypes(containerType); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::SelectOp. -//===----------------------------------------------------------------------===// - -static void printSelectOp(OpAsmPrinter *p, SelectOp &op) { - *p << op.getOperationName() << ' ' << *op.condition() << ", " - << *op.trueValue() << ", " << *op.falseValue(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType(); -} - -// ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use -// attribute-dict? 
`:` type, type -static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType condition, trueValue, falseValue; - Type conditionType, argType; - - if (parser->parseOperand(condition) || parser->parseComma() || - parser->parseOperand(trueValue) || parser->parseComma() || - parser->parseOperand(falseValue) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(conditionType) || parser->parseComma() || - parser->parseType(argType)) - return failure(); - - if (parser->resolveOperand(condition, conditionType, result->operands) || - parser->resolveOperand(trueValue, argType, result->operands) || - parser->resolveOperand(falseValue, argType, result->operands)) - return failure(); - - result->addTypes(argType); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::BrOp. -//===----------------------------------------------------------------------===// - -static void printBrOp(OpAsmPrinter *p, BrOp &op) { - *p << op.getOperationName() << ' '; - p->printSuccessorAndUseList(op.getOperation(), 0); - p->printOptionalAttrDict(op.getAttrs()); -} - -// ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? -// attribute-dict? -static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) { - Block *dest; - SmallVector operands; - if (parser->parseSuccessorAndUseList(dest, operands) || - parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - - result->addSuccessor(dest, operands); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::CondBrOp. -//===----------------------------------------------------------------------===// - -static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) { - *p << op.getOperationName() << ' ' << *op.getOperand(0) << ", "; - p->printSuccessorAndUseList(op.getOperation(), 0); - *p << ", "; - p->printSuccessorAndUseList(op.getOperation(), 1); - p->printOptionalAttrDict(op.getAttrs()); -} - -// ::= `llvm.cond_br` ssa-use `,` -// bb-id (`[` ssa-use-and-type-list `]`)? `,` -// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? -static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) { - Block *trueDest; - Block *falseDest; - SmallVector trueOperands; - SmallVector falseOperands; - OpAsmParser::OperandType condition; - - Builder &builder = parser->getBuilder(); - auto *llvmDialect = - builder.getContext()->getRegisteredDialect(); - auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); - - if (parser->parseOperand(condition) || parser->parseComma() || - parser->parseSuccessorAndUseList(trueDest, trueOperands) || - parser->parseComma() || - parser->parseSuccessorAndUseList(falseDest, falseOperands) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->resolveOperand(condition, i1Type, result->operands)) - return failure(); - - result->addSuccessor(trueDest, trueOperands); - result->addSuccessor(falseDest, falseOperands); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::ReturnOp. 
-//===----------------------------------------------------------------------===// - -static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) { - *p << op.getOperationName(); - p->printOptionalAttrDict(op.getAttrs()); - assert(op.getNumOperands() <= 1); - - if (op.getNumOperands() == 0) - return; - - *p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType(); -} - -// ::= `llvm.return` ssa-use-list attribute-dict? `:` -// type-list-no-parens -static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { - SmallVector operands; - Type type; - - if (parser->parseOperandList(operands) || - parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - if (operands.empty()) - return success(); - - if (parser->parseColonType(type) || - parser->resolveOperand(operands[0], type, result->operands)) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::UndefOp. -//===----------------------------------------------------------------------===// - -static void printUndefOp(OpAsmPrinter *p, UndefOp &op) { - *p << op.getOperationName(); - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.res()->getType(); -} - -// ::= `llvm.undef` attribute-dict? : type -static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) { - Type type; - - if (parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return failure(); - - result->addTypes(type); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printer, parser and verifier for LLVM::AddressOfOp. -//===----------------------------------------------------------------------===// - -GlobalOp AddressOfOp::getGlobal() { - auto module = getParentOfType(); - assert(module && "unexpected operation outside of a module"); - return module.lookupSymbol(global_name()); -} - -static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) { - *p << op.getOperationName() << " @" << op.global_name(); - p->printOptionalAttrDict(op.getAttrs(), {"global_name"}); - *p << " : " << op.getResult()->getType(); -} - -static ParseResult parseAddressOfOp(OpAsmParser *parser, - OperationState *result) { - Attribute symRef; - Type type; - if (parser->parseAttribute(symRef, "global_name", result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->addTypeToList(type, result->types)) - return failure(); - - if (!symRef.isa()) - return parser->emitError(parser->getNameLoc(), "expected symbol reference"); - return success(); -} - -static LogicalResult verify(AddressOfOp op) { - auto global = op.getGlobal(); - if (!global) - return op.emitOpError("must reference a global defined by 'llvm.global'"); - - if (global.getType().getPointerTo() != op.getResult()->getType()) - return op.emitOpError( - "the type must be a pointer to the type of the referred global"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::ConstantOp. -//===----------------------------------------------------------------------===// - -static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) { - *p << op.getOperationName() << '(' << op.value() << ')'; - p->printOptionalAttrDict(op.getAttrs(), {"value"}); - *p << " : " << op.res()->getType(); -} - -// ::= `llvm.constant` `(` attribute `)` attribute-list? 
: type -static ParseResult parseConstantOp(OpAsmParser *parser, - OperationState *result) { - Attribute valueAttr; - Type type; - - if (parser->parseLParen() || - parser->parseAttribute(valueAttr, "value", result->attributes) || - parser->parseRParen() || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return failure(); - - result->addTypes(type); - return success(); -} - -//===----------------------------------------------------------------------===// -// Builder, printer and verifier for LLVM::GlobalOp. -//===----------------------------------------------------------------------===// - -void GlobalOp::build(Builder *builder, OperationState *result, LLVMType type, - bool isConstant, StringRef name, Attribute value, - ArrayRef attrs) { - result->addAttribute(SymbolTable::getSymbolAttrName(), - builder->getStringAttr(name)); - result->addAttribute("type", builder->getTypeAttr(type)); - if (isConstant) - result->addAttribute("constant", builder->getUnitAttr()); - result->addAttribute("value", value); - result->attributes.append(attrs.begin(), attrs.end()); -} - -static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) { - *p << op.getOperationName() << ' '; - if (op.constant()) - *p << "constant "; - *p << '@' << op.sym_name() << '('; - p->printAttribute(op.value()); - *p << ')'; - p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(), - "type", "constant", "value"}); - - // Print the trailing type unless it's a string global. - if (op.value().isa()) - return; - *p << " : "; - p->printType(op.type()); -} - -// ::= `llvm.global` `constant`? `@` identifier `(` attribute `)` -// attribute-list? (`:` type)? -// -// The type can be omitted for string attributes, in which case it will be -// inferred from the value of the string as [strlen(value) x i8]. 
-static ParseResult parseGlobalOp(OpAsmParser *parser, OperationState *result) { - if (succeeded(parser->parseOptionalKeyword("constant"))) - result->addAttribute("constant", parser->getBuilder().getUnitAttr()); - - Attribute value; - StringAttr name; - SmallVector types; - if (parser->parseSymbolName(name, SymbolTable::getSymbolAttrName(), - result->attributes) || - parser->parseLParen() || - parser->parseAttribute(value, "value", result->attributes) || - parser->parseRParen() || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseOptionalColonTypeList(types)) - return failure(); - - if (types.size() > 1) - return parser->emitError(parser->getNameLoc(), "expected zero or one type"); - - if (types.empty()) { - if (auto strAttr = value.dyn_cast()) { - MLIRContext *context = parser->getBuilder().getContext(); - auto *dialect = context->getRegisteredDialect(); - auto arrayType = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); - types.push_back(arrayType); - } else { - return parser->emitError(parser->getNameLoc(), - "type can only be omitted for string globals"); - } - } - - result->addAttribute("type", parser->getBuilder().getTypeAttr(types[0])); - return success(); -} - -static LogicalResult verify(GlobalOp op) { - if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) - return op.emitOpError( - "expects type to be a valid element type for an LLVM pointer"); - if (op.getParentOp() && !isa(op.getParentOp())) - return op.emitOpError("must appear at the module level"); - if (auto strAttr = op.value().dyn_cast()) { - auto type = op.getType(); - if (!type.getUnderlyingType()->isArrayTy() || - !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || - type.getArrayNumElements() != strAttr.getValue().size()) - return op.emitOpError( - "requires an i8 array type of the length equal to that of the string " - "attribute"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::ShuffleVectorOp. -//===----------------------------------------------------------------------===// -// Expects vector to be of wrapped LLVM vector type and position to be of -// wrapped LLVM i32 type. -void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1, - Value *v2, ArrayAttr mask, - ArrayRef attrs) { - auto wrappedContainerType1 = v1->getType().cast(); - auto vType = LLVMType::getVectorTy( - wrappedContainerType1.getVectorElementType(), mask.size()); - build(b, result, vType, v1, v2, mask); - result->addAttributes(attrs); -} - -static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) { - *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " " - << op.mask(); - p->printOptionalAttrDict(op.getAttrs(), {"mask"}); - *p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); -} - -// ::= `llvm.shufflevector` ssa-use `, ` ssa-use -// `[` integer-literal (`,` integer-literal)* `]` -// attribute-dict? 
`:` type -static ParseResult parseShuffleVectorOp(OpAsmParser *parser, - OperationState *result) { - llvm::SMLoc loc; - SmallVector attrs; - OpAsmParser::OperandType v1, v2; - Attribute maskAttr; - Type typeV1, typeV2; - if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) || - parser->parseComma() || parser->parseOperand(v2) || - parser->parseAttribute(maskAttr, "mask", attrs) || - parser->parseOptionalAttributeDict(attrs) || - parser->parseColonType(typeV1) || parser->parseComma() || - parser->parseType(typeV2) || - parser->resolveOperand(v1, typeV1, result->operands) || - parser->resolveOperand(v2, typeV2, result->operands)) - return failure(); - auto wrappedContainerType1 = typeV1.dyn_cast(); - if (!wrappedContainerType1 || - !wrappedContainerType1.getUnderlyingType()->isVectorTy()) - return parser->emitError( - loc, "expected LLVM IR dialect vector type for operand #1"); - auto vType = - LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(), - maskAttr.cast().size()); - result->attributes = attrs; - result->addTypes(vType); - return success(); -} - -//===----------------------------------------------------------------------===// -// Builder, printer and verifier for LLVM::LLVMFuncOp. -//===----------------------------------------------------------------------===// - -void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, - LLVMType type, ArrayRef attrs, - ArrayRef argAttrs) { - result->addRegion(); - result->addAttribute(SymbolTable::getSymbolAttrName(), - builder->getStringAttr(name)); - result->addAttribute("type", builder->getTypeAttr(type)); - result->attributes.append(attrs.begin(), attrs.end()); - if (argAttrs.empty()) - return; - - unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); - assert(numInputs == argAttrs.size() && - "expected as many argument attribute lists as arguments"); - SmallString<8> argAttrName; - for (unsigned i = 0; i < numInputs; ++i) - if (auto argDict = argAttrs[i].getDictionary()) - result->addAttribute(getArgAttrName(i, argAttrName), argDict); -} - -// Build an LLVM function type from the given lists of input and output types. -// Returns a null type if any of the types provided are non-LLVM types, or if -// there is more than one output type. -static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, - ArrayRef outputs, - impl::VariadicFlag variadicFlag, - std::string &errorMessage) { - if (outputs.size() > 1) { - errorMessage = "expected zero or one function result"; - return {}; - } - - // Convert inputs to LLVM types, exit early on error. - SmallVector llvmInputs; - for (auto t : inputs) { - auto llvmTy = t.dyn_cast(); - if (!llvmTy) { - errorMessage = "expected LLVM type for function arguments"; - return {}; - } - llvmInputs.push_back(llvmTy); - } - - // Get the dialect from the input type, if any exist. Look it up in the - // context otherwise. - LLVMDialect *dialect = - llvmInputs.empty() ? b.getContext()->getRegisteredDialect() - : &llvmInputs.front().getDialect(); - - // No output is denoted as "void" in LLVM type system. - LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) - : outputs.front().dyn_cast(); - if (!llvmOutput) { - errorMessage = "expected LLVM type for function results"; - return {}; - } - return LLVMType::getFunctionTy(llvmOutput, llvmInputs, - variadicFlag.isVariadic()); -} - -// Print the LLVMFuncOp. Collects argument and result types and passes them -// to the trait printer. Drops "void" result since it cannot be parsed back. 
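// For example (illustrative), a function whose wrapped LLVM type is
// "void (i8*)" is printed with an empty result list; when that printed form is
// parsed back, buildLLVMFunctionType above reinstates the void result type.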
-static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { - LLVMType fnType = op.getType(); - SmallVector argTypes; - SmallVector resTypes; - argTypes.reserve(fnType.getFunctionNumParams()); - for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) - argTypes.push_back(fnType.getFunctionParamType(i)); - - LLVMType returnType = fnType.getFunctionResultType(); - if (!returnType.getUnderlyingType()->isVoidTy()) - resTypes.push_back(returnType); - - impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes); -} - -// Hook for OpTrait::FunctionLike, called after verifying that the 'type' -// attribute is present. This can check for preconditions of the -// getNumArguments hook not failing. -LogicalResult LLVMFuncOp::verifyType() { - auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); - if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of wrapped LLVM function type"); - - return success(); -} - -// Hook for OpTrait::FunctionLike, returns the number of function arguments. -// Depends on the type attribute being correct as checked by verifyType -unsigned LLVMFuncOp::getNumFuncArguments() { - return getType().getUnderlyingType()->getFunctionNumParams(); -} - -static LogicalResult verify(LLVMFuncOp op) { - if (op.isExternal()) - return success(); - - if (op.isVarArg()) - return op.emitOpError("only external functions can be variadic"); - - auto *funcType = cast(op.getType().getUnderlyingType()); - unsigned numArguments = funcType->getNumParams(); - Block &entryBlock = op.front(); - for (unsigned i = 0; i < numArguments; ++i) { - Type argType = entryBlock.getArgument(i)->getType(); - auto argLLVMType = argType.dyn_cast(); - if (!argLLVMType) - return op.emitOpError("entry block argument #") - << i << " is not of LLVM type"; - if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) - return op.emitOpError("the type of entry block argument #") - << i << " does not match the function signature"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// LLVMDialect initialization, type parsing, and registration. -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace LLVM { -namespace detail { -struct LLVMDialectImpl { - LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} - - llvm::LLVMContext llvmContext; - llvm::Module module; - - /// A set of LLVMTypes that are cached on construction to avoid any lookups or - /// locking. - LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; - LLVMType doubleTy, floatTy, halfTy; - LLVMType voidTy; - - /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not - /// multi-threaded and requires locked access to prevent race conditions. - llvm::sys::SmartMutex mutex; -}; -} // end namespace detail -} // end namespace LLVM -} // end namespace mlir - -LLVMDialect::LLVMDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context), - impl(new detail::LLVMDialectImpl()) { - addTypes(); - addOperations< -#define GET_OP_LIST -#include "mlir/LLVMIR/LLVMOps.cpp.inc" - >(); - - // Support unknown operations because not all LLVM operations are registered. - allowUnknownOperations(); - - // Cache some of the common LLVM types to avoid the need for lookups/locking. - auto &llvmContext = impl->llvmContext; - /// Integer Types. 
- impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext)); - impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext)); - impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext)); - impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext)); - impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext)); - impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext)); - /// Float Types. - impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); - impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); - impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); - /// Other Types. - impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); -} - -LLVMDialect::~LLVMDialect() {} - -#define GET_OP_CLASSES -#include "mlir/LLVMIR/LLVMOps.cpp.inc" - -llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } -llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } - -/// Parse a type registered to this dialect. -Type LLVMDialect::parseType(StringRef tyData, Location loc) const { - // LLVM is not thread-safe, so lock access to it. - llvm::sys::SmartScopedLock lock(impl->mutex); - - llvm::SMDiagnostic errorMessage; - llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); - if (!type) - return (emitError(loc, errorMessage.getMessage()), nullptr); - return LLVMType::get(getContext(), type); -} - -/// Print a type registered to this dialect. -void LLVMDialect::printType(Type type, raw_ostream &os) const { - auto llvmType = type.dyn_cast(); - assert(llvmType && "printing wrong type"); - assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); - llvmType.getUnderlyingType()->print(os); -} - -/// Verify LLVMIR function argument attributes. -LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, - unsigned regionIdx, - unsigned argIdx, - NamedAttribute argAttr) { - // Check that llvm.noalias is a boolean attribute. - if (argAttr.first == "llvm.noalias" && !argAttr.second.isa()) - return op->emitError() - << "llvm.noalias argument attribute of non boolean type"; - return success(); -} - -static DialectRegistration llvmDialect; - -//===----------------------------------------------------------------------===// -// LLVMType. -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace LLVM { -namespace detail { -struct LLVMTypeStorage : public ::mlir::TypeStorage { - LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} - - // LLVM types are pointer-unique. - using KeyTy = llvm::Type *; - bool operator==(const KeyTy &key) const { return key == underlyingType; } - - static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, - llvm::Type *ty) { - return new (allocator.allocate()) LLVMTypeStorage(ty); - } - - llvm::Type *underlyingType; -}; -} // end namespace detail -} // end namespace LLVM -} // end namespace mlir - -LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { - return Base::get(context, FIRST_LLVM_TYPE, llvmType); -} - -/// Get an LLVMType with an llvm type that may cause changes to the underlying -/// llvm context when constructed. -LLVMType LLVMType::getLocked(LLVMDialect *dialect, - llvm::function_ref typeBuilder) { - // Lock access to the llvm context and build the type. 
- llvm::sys::SmartScopedLock lock(dialect->impl->mutex); - return get(dialect->getContext(), typeBuilder()); -} - -LLVMDialect &LLVMType::getDialect() { - return static_cast(Type::getDialect()); -} - -llvm::Type *LLVMType::getUnderlyingType() const { - return getImpl()->underlyingType; -} - -/// Array type utilities. -LLVMType LLVMType::getArrayElementType() { - return get(getContext(), getUnderlyingType()->getArrayElementType()); -} -unsigned LLVMType::getArrayNumElements() { - return getUnderlyingType()->getArrayNumElements(); -} - -/// Vector type utilities. -LLVMType LLVMType::getVectorElementType() { - return get(getContext(), getUnderlyingType()->getVectorElementType()); -} - -/// Function type utilities. -LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { - return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); -} -unsigned LLVMType::getFunctionNumParams() { - return getUnderlyingType()->getFunctionNumParams(); -} -LLVMType LLVMType::getFunctionResultType() { - return get( - getContext(), - llvm::cast(getUnderlyingType())->getReturnType()); -} - -/// Pointer type utilities. -LLVMType LLVMType::getPointerTo(unsigned addrSpace) { - // Lock access to the dialect as this may modify the LLVM context. - return getLocked(&getDialect(), [=] { - return getUnderlyingType()->getPointerTo(addrSpace); - }); -} -LLVMType LLVMType::getPointerElementTy() { - return get(getContext(), getUnderlyingType()->getPointerElementType()); -} - -/// Struct type utilities. -LLVMType LLVMType::getStructElementType(unsigned i) { - return get(getContext(), getUnderlyingType()->getStructElementType(i)); -} - -/// Utilities used to generate floating point types. -LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { - return dialect->impl->doubleTy; -} -LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { - return dialect->impl->floatTy; -} -LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { - return dialect->impl->halfTy; -} - -/// Utilities used to generate integer types. -LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { - switch (numBits) { - case 1: - return dialect->impl->int1Ty; - case 8: - return dialect->impl->int8Ty; - case 16: - return dialect->impl->int16Ty; - case 32: - return dialect->impl->int32Ty; - case 64: - return dialect->impl->int64Ty; - case 128: - return dialect->impl->int128Ty; - default: - break; - } - - // Lock access to the dialect as this may modify the LLVM context. - return getLocked(dialect, [=] { - return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits); - }); -} - -/// Utilities used to generate other miscellaneous types. -LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { - // Lock access to the dialect as this may modify the LLVM context. - return getLocked(&elementType.getDialect(), [=] { - return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements); - }); -} -LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef params, - bool isVarArg) { - SmallVector llvmParams; - for (auto param : params) - llvmParams.push_back(param.getUnderlyingType()); - - // Lock access to the dialect as this may modify the LLVM context. 
- return getLocked(&result.getDialect(), [=] { - return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, - isVarArg); - }); -} -LLVMType LLVMType::getStructTy(LLVMDialect *dialect, - ArrayRef elements, bool isPacked) { - SmallVector llvmElements; - for (auto elt : elements) - llvmElements.push_back(elt.getUnderlyingType()); - - // Lock access to the dialect as this may modify the LLVM context. - return getLocked(dialect, [=] { - return llvm::StructType::get(dialect->getLLVMContext(), llvmElements, - isPacked); - }); -} -LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { - // Lock access to the dialect as this may modify the LLVM context. - return getLocked(&elementType.getDialect(), [=] { - return llvm::VectorType::get(elementType.getUnderlyingType(), numElements); - }); -} -LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { - return dialect->impl->voidTy; -} diff --git a/mlir/lib/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/LLVMIR/IR/NVVMDialect.cpp deleted file mode 100644 index f586f0e5c7c..00000000000 --- a/mlir/lib/LLVMIR/IR/NVVMDialect.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the types and operation details for the NVVM IR dialect in -// MLIR, and the LLVM IR dialect. It also registers the dialect. -// -// The NVVM dialect only contains GPU specific additions on top of the general -// LLVM dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/LLVMIR/NVVMDialect.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace NVVM { - -//===----------------------------------------------------------------------===// -// Printing/parsing for NVVM ops -//===----------------------------------------------------------------------===// - -static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) { - *p << op->getName() << " : "; - if (op->getNumResults() == 1) { - *p << op->getResult(0)->getType(); - } else { - *p << "###invalid type###"; - } -} - -// ::= `llvm.nvvm.XYZ` : type -static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser, - OperationState *result) { - Type type; - if (parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return failure(); - - result->addTypes(type); - return success(); -} - -//===----------------------------------------------------------------------===// -// NVVMDialect initialization, type parsing, and registration. 
-//===----------------------------------------------------------------------===// - -// TODO(herhut): This should be the llvm.nvvm dialect once this is supported. -NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) { - addOperations< -#define GET_OP_LIST -#include "mlir/LLVMIR/NVVMOps.cpp.inc" - >(); - - // Support unknown operations because not all NVVM operations are registered. - allowUnknownOperations(); -} - -#define GET_OP_CLASSES -#include "mlir/LLVMIR/NVVMOps.cpp.inc" - -static DialectRegistration nvvmDialect; - -} // namespace NVVM -} // namespace mlir diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 05bdf24a975..6fa075fa6d9 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Attributes.h" @@ -28,7 +29,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/Linalg/IR/LinalgTypes.h" #include "mlir/Linalg/Passes.h" diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index faef51f5c8c..1fd50666f00 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -17,6 +17,7 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -28,7 +29,6 @@ #include "mlir/Linalg/Utils/Intrinsics.h" #include "mlir/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index d31fe0d3006..9472b80f58e 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -21,6 +21,7 @@ #include "mlir/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -30,7 +31,6 @@ #include "mlir/Linalg/Passes.h" #include "mlir/Linalg/Utils/Intrinsics.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/FoldUtils.h" diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp index 4119bde5ac1..94e364238c5 100644 --- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -26,13 +26,13 @@ #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" #include "mlir/Quantizer/Support/Metadata.h" #include 
"mlir/Quantizer/Support/Statistics.h" #include "mlir/Quantizer/Support/UniformConstraints.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; using namespace mlir::quantizer; diff --git a/mlir/lib/SDBM/CMakeLists.txt b/mlir/lib/SDBM/CMakeLists.txt deleted file mode 100644 index 30b2f641a7b..00000000000 --- a/mlir/lib/SDBM/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_llvm_library(MLIRSDBM - SDBM.cpp - SDBMExpr.cpp - SDBMDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/SDBM -) -add_dependencies(MLIRSDBM MLIRIR) -target_link_libraries(MLIRSDBM MLIRIR) diff --git a/mlir/lib/SDBM/SDBM.cpp b/mlir/lib/SDBM/SDBM.cpp deleted file mode 100644 index 13932c649b0..00000000000 --- a/mlir/lib/SDBM/SDBM.cpp +++ /dev/null @@ -1,561 +0,0 @@ -//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined -// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. -// -//===----------------------------------------------------------------------===// - -#include "mlir/SDBM/SDBM.h" -#include "mlir/SDBM/SDBMExpr.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -// Helper function for SDBM construction that collects information necessary to -// start building an SDBM in one sweep. In particular, it records the largest -// position of a dimension in `dim`, that of a symbol in `symbol` as well as -// collects all unique stripe expressions in `stripes`. Uses SetVector to -// ensure these expressions always have the same order. -static void collectSDBMBuildInfo(SDBMExpr expr, int &dim, int &symbol, - llvm::SmallSetVector &stripes) { - struct Visitor : public SDBMVisitor { - void visitDim(SDBMDimExpr dimExpr) { - int p = dimExpr.getPosition(); - if (p > maxDimPosition) - maxDimPosition = p; - } - void visitSymbol(SDBMSymbolExpr symbExpr) { - int p = symbExpr.getPosition(); - if (p > maxSymbPosition) - maxSymbPosition = p; - } - void visitStripe(SDBMStripeExpr stripeExpr) { stripes.insert(stripeExpr); } - - Visitor(llvm::SmallSetVector &stripes) : stripes(stripes) {} - - int maxDimPosition = -1; - int maxSymbPosition = -1; - llvm::SmallSetVector &stripes; - }; - - Visitor visitor(stripes); - visitor.walkPostorder(expr); - dim = std::max(dim, visitor.maxDimPosition); - symbol = std::max(symbol, visitor.maxSymbPosition); -} - -namespace { -// Utility class for SDBMBuilder. Represents a value that can be inserted in -// the SDB matrix that corresponds to "v0 - v1 + C <= 0", where v0 and v1 is -// any combination of the positive and negative positions. 
Since multiple -// variables can be declared equal to the same stripe expression, the -// constraints on this expression must be reflected to all these variables. For -// example, if -// d0 = s0 # 42 -// d1 = s0 # 42 -// d2 = s1 # 2 -// d3 = s1 # 2 -// the constraint -// s0 # 42 - s1 # 2 <= C -// should be reflected in the DB matrix as -// d0 - d2 <= C -// d1 - d2 <= C -// d0 - d3 <= C -// d1 - d3 <= C -// since the DB matrix has no knowledge of the transitive equality between d0, -// d1 and s0 # 42 as well as between d2, d3 and s1 # 2. This knowledge can be -// obtained by computing a transitive closure, which is impossible until the -// DBM is actually built. -struct SDBMBuilderResult { - // Positions in the matrix of the variables taken with the "+" sign in the - // difference expression, 0 if it is a constant rather than a variable. - llvm::SmallVector positivePos; - - // Positions in the matrix of the variables taken with the "-" sign in the - // difference expression, 0 if it is a constant rather than a variable. - llvm::SmallVector negativePos; - - // Constant value in the difference expression. - int64_t value = 0; -}; - -// Visitor for building an SDBM from SDBM expressions. After traversing an SDBM -// expression, produces an update to the SDB matrix specifying the positions in -// the matrix and the negated value that should be stored. Both the positive -// and the negative positions may be lists of indices in cases where multiple -// variables are equal to the same stripe expression. In such cases, the update -// applies to the cross product of positions because elements involved in the -// update are (transitively) equal and should have the same constraints, but we -// may not have an explicit equality for them. -struct SDBMBuilder : public SDBMVisitor { -public: - // A difference expression produces both the positive and the negative - // coordinate in the matrix, recursively traversing the LHS and the RHS. The - // value is the difference between values obtained from LHS and RHS. - SDBMBuilderResult visitDiff(SDBMDiffExpr diffExpr) { - auto lhs = visit(diffExpr.getLHS()); - auto rhs = visit(diffExpr.getRHS()); - assert(lhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && - "unexpected negative expression in a difference expression"); - assert(rhs.negativePos.size() == 1 && lhs.negativePos[0] == 0 && - "unexpected negative expression in a difference expression"); - - SDBMBuilderResult result; - result.positivePos = lhs.positivePos; - result.negativePos = rhs.positivePos; - result.value = lhs.value - rhs.value; - return result; - } - - // An input expression is always taken with the "+" sign and therefore - // produces a positive coordinate keeping the negative coordinate zero for an - // eventual constant. - SDBMBuilderResult visitInput(SDBMInputExpr expr) { - SDBMBuilderResult r; - r.positivePos.push_back(linearPosition(expr)); - r.negativePos.push_back(0); - return r; - } - - // A stripe expression is always equal to one or more variables, which may be - // temporaries, and appears with a "+" sign in the SDBM expression tree. Take - // the positions of the corresponding variables as positive coordinates. - SDBMBuilderResult visitStripe(SDBMStripeExpr expr) { - SDBMBuilderResult r; - assert(pointExprToStripe.count(expr)); - r.positivePos = pointExprToStripe[expr]; - r.negativePos.push_back(0); - return r; - } - - // A constant expression has both coordinates at zero. 
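  // For instance, visiting the constant 42 produces positivePos = {0},
  // negativePos = {0} and value = 42, i.e. both coordinates point at the
  // special constant row/column of the matrix (illustrative example).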
- SDBMBuilderResult visitConstant(SDBMConstantExpr expr) { - SDBMBuilderResult r; - r.positivePos.push_back(0); - r.negativePos.push_back(0); - r.value = expr.getValue(); - return r; - } - - // A negation expression swaps the positive and the negative coordinates - // and also negates the constant value. - SDBMBuilderResult visitNeg(SDBMNegExpr expr) { - SDBMBuilderResult result = visit(expr.getVar()); - std::swap(result.positivePos, result.negativePos); - result.value = -result.value; - return result; - } - - // The RHS of a sum expression must be a constant and therefore must have both - // positive and negative coordinates at zero. Take the sum of the values - // between LHS and RHS and keep LHS coordinates. - SDBMBuilderResult visitSum(SDBMSumExpr expr) { - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - for (auto pos : rhs.negativePos) { - (void)pos; - assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); - } - for (auto pos : rhs.positivePos) { - (void)pos; - assert(pos == 0 && "unexpected variable on the RHS of SDBM sum"); - } - - lhs.value += rhs.value; - return lhs; - } - - SDBMBuilder(llvm::DenseMap> - &pointExprToStripe, - llvm::function_ref callback) - : pointExprToStripe(pointExprToStripe), linearPosition(callback) {} - - llvm::DenseMap> &pointExprToStripe; - llvm::function_ref linearPosition; -}; -} // namespace - -SDBM SDBM::get(ArrayRef inequalities, ArrayRef equalities) { - SDBM result; - - // TODO(zinenko): consider detecting equalities in the list of inequalities. - // This is potentially expensive and requires to - // - create a list of negated inequalities (may allocate under lock); - // - perform a pairwise comparison of direct and negated inequalities; - // - copy the lists of equalities and inequalities, and move entries between - // them; - // only for the purpose of sparing a temporary variable in cases where an - // implicit equality between a variable and a stripe expression is present in - // the input. - - // Do the first sweep over (in)equalities to collect the information necessary - // to allocate the SDB matrix (number of dimensions, symbol and temporary - // variables required for stripe expressions). - llvm::SmallSetVector stripes; - int maxDim = -1; - int maxSymbol = -1; - for (auto expr : inequalities) - collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); - for (auto expr : equalities) - collectSDBMBuildInfo(expr, maxDim, maxSymbol, stripes); - // Indexing of dimensions starts with 0, obtain the number of dimensions by - // incrementing the maximal position of the dimension seen in expressions. - result.numDims = maxDim + 1; - result.numSymbols = maxSymbol + 1; - result.numTemporaries = 0; - - // Helper function that returns the position of the variable represented by - // an SDBM input expression. - auto linearPosition = [result](SDBMInputExpr expr) { - if (expr.isa()) - return result.getDimPosition(expr.getPosition()); - return result.getSymbolPosition(expr.getPosition()); - }; - - // Check if some stripe expressions are equal to another variable. In - // particular, look for the equalities of the form - // d0 - stripe-expression = 0, or - // stripe-expression - d0 = 0. - // There may be multiple variables that are equal to the same stripe - // expression. Keep track of those in pointExprToStripe. - // There may also be multiple stripe expressions equal to the same variable. - // Introduce a temporary variable for each of those. 
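  // For example (illustrative), given the equalities
  //   d0 - (s0 # 16) = 0
  //   d1 - (s0 # 16) = 0
  // the positions of both d0 and d1 are recorded under the single key
  // "s0 # 16" in pointExprToStripe below; a stripe expression that is not
  // equated to any variable gets a fresh temporary position instead (see the
  // loop over `stripes` further down).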
- llvm::DenseMap> pointExprToStripe; - unsigned numTemporaries = 0; - - auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe, - linearPosition](SDBMInputExpr input, - SDBMExpr expr) { - unsigned position = linearPosition(input); - if (result.stripeToPoint.count(position) && - result.stripeToPoint[position] != expr) { - position = result.getNumVariables() + numTemporaries++; - } - pointExprToStripe[expr].push_back(position); - result.stripeToPoint.insert(std::make_pair(position, expr)); - }; - - for (auto eq : equalities) { - auto diffExpr = eq.dyn_cast(); - if (!diffExpr) - continue; - - auto lhs = diffExpr.getLHS(); - auto rhs = diffExpr.getRHS(); - auto lhsInput = lhs.dyn_cast(); - auto rhsInput = rhs.dyn_cast(); - - if (lhsInput && stripes.count(rhs)) - updateStripePointMaps(lhsInput, rhs); - if (rhsInput && stripes.count(lhs)) - updateStripePointMaps(rhsInput, lhs); - } - - // Assign the remaining stripe expressions to temporary variables. These - // expressions are the ones that could not be associated with an existing - // variable in the previous step. - for (auto expr : stripes) { - if (pointExprToStripe.count(expr)) - continue; - unsigned position = result.getNumVariables() + numTemporaries++; - pointExprToStripe[expr].push_back(position); - result.stripeToPoint.insert(std::make_pair(position, expr)); - } - - // Create the DBM matrix, initialized to infinity values for the least tight - // possible bound (x - y <= infinity is always true). - result.numTemporaries = numTemporaries; - result.matrix.resize(result.getNumVariables() * result.getNumVariables(), - IntInfty::infinity()); - - SDBMBuilder builder(pointExprToStripe, linearPosition); - - // Only keep the tightest constraint. Since we transform everything into - // less-than-or-equals-to inequalities, keep the smallest constant. For - // example, if we have d0 - d1 <= 42 and d0 - d1 <= 2, we keep the latter. - // Note that the input expressions are in the shape of d0 - d1 + -42 <= 0 - // so we negate the value before storing it. - // In case where the positive and the negative positions are equal, the - // corresponding expression has the form d0 - d0 + -42 <= 0. If the constant - // value is positive, the set defined by SDBM is trivially empty. We store - // this value anyway and continue processing to maintain the correspondence - // between the matrix form and the list-of-SDBMExpr form. - // TODO(zinenko): we may want to reconsider this once we have canonicalization - // or simplification in place - auto updateMatrix = [](SDBM &sdbm, const SDBMBuilderResult &r) { - for (auto positivePos : r.positivePos) { - for (auto negativePos : r.negativePos) { - auto &m = sdbm.at(negativePos, positivePos); - m = m < -r.value ? m : -r.value; - } - } - }; - - // Do the second sweep on (in)equalities, updating the SDB matrix to reflect - // the constraints. - for (auto ineq : inequalities) - updateMatrix(result, builder.visit(ineq)); - - // An equality f(x) = 0 is represented as a pair of inequalities {f(x) >= 0; - // f(x) <= 0} or, alternatively, {-f(x) <= 0 and f(x) <= 0}. - for (auto eq : equalities) { - updateMatrix(result, builder.visit(eq)); - updateMatrix(result, builder.visit(-eq)); - } - - // Add the inequalities induced by stripe equalities. - // t = x # C => t <= x <= t + C - 1 - // which is equivalent to - // {t - x <= 0; - // x - t - (C - 1) <= 0}. 
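  // Concretely, for t = x # 4 this adds t - x <= 0 and x - t - 3 <= 0,
  // i.e. t <= x <= t + 3 (illustrative instance of the rule above).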
- for (const auto &pair : result.stripeToPoint) { - auto stripe = pair.second.cast(); - SDBMBuilderResult update = builder.visit(stripe.getVar()); - assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 && - "unexpected negated variable in stripe expression"); - assert(update.value == 0 && - "unexpected non-zero value in stripe expression"); - update.negativePos.clear(); - update.negativePos.push_back(pair.first); - update.value = -(stripe.getStripeFactor().getValue() - 1); - updateMatrix(result, update); - - std::swap(update.negativePos, update.positivePos); - update.value = 0; - updateMatrix(result, update); - } - - return result; -} - -// Given a row and a column position in the square DBM, insert one equality -// or up to two inequalities that correspond the entries (col, row) and (row, -// col) in the DBM. `rowExpr` and `colExpr` contain the expressions such that -// colExpr - rowExpr <= V where V is the value at (row, col) in the DBM. -// If one of the expressions is derived from another using a stripe operation, -// check if the inequalities induced by the stripe operation subsume the -// inequalities defined in the DBM and if so, elide these inequalities. -void SDBM::convertDBMElement(unsigned row, unsigned col, - SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities) { - using ops_assertions::operator+; - using ops_assertions::operator-; - - auto diffIJValue = at(col, row); - auto diffJIValue = at(row, col); - - // If symmetric entries are opposite, the corresponding expressions are equal. - if (diffIJValue.isFinite() && - diffIJValue.getValue() == -diffJIValue.getValue()) { - equalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); - return; - } - - // Given an inequality x0 - x1 <= A, check if x0 is a stripe variable derived - // from x1: x0 = x1 # B. If so, it would imply the constraints - // x0 <= x1 <= x0 + (B - 1) <=> x0 - x1 <= 0 and x1 - x0 <= (B - 1). - // Therefore, if A >= 0, this inequality is subsumed by that implied - // by the stripe equality and thus can be elided. - // Similarly, check if x1 is a stripe variable derived from x0: x1 = x0 # C. - // If so, it would imply the constraints x1 <= x0 <= x1 + (C - 1) <=> - // <=> x1 - x0 <= 0 and x0 - x1 <= (C - 1). Therefore, if A >= (C - 1), this - // inequality can be elided. - // - // Note: x0 and x1 may be a stripe expressions themselves, we rely on stripe - // expressions being stored without temporaries on the RHS and being passed - // into this function as is. - auto canElide = [this](unsigned x0, unsigned x1, SDBMExpr x0Expr, - SDBMExpr x1Expr, int64_t value) { - if (stripeToPoint.count(x0)) { - auto stripe = stripeToPoint[x0].cast(); - SDBMPositiveExpr var = stripe.getVar(); - if (x1Expr == var && value >= 0) - return true; - } - if (stripeToPoint.count(x1)) { - auto stripe = stripeToPoint[x1].cast(); - SDBMPositiveExpr var = stripe.getVar(); - if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1) - return true; - } - return false; - }; - - // Check row - col. - if (diffIJValue.isFinite() && - !canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) { - inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue()); - } - // Check col - row. 
- if (diffJIValue.isFinite() && - !canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) { - inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue()); - } -} - -// The values on the main diagonal correspond to the upper bound on the -// difference between a variable and itself: d0 - d0 <= C, or alternatively -// to -C <= 0. Only construct the inequalities when C is negative, which -// are trivially false but necessary for the returned system of inequalities -// to indicate that the set it defines is empty. -void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, - SmallVectorImpl &inequalities) { - auto selfDifference = at(pos, pos); - if (selfDifference.isFinite() && selfDifference < 0) { - auto selfDifferenceValueExpr = - SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue()); - inequalities.push_back(selfDifferenceValueExpr); - } -} - -void SDBM::getSDBMExpressions(SDBMDialect *dialect, - SmallVectorImpl &inequalities, - SmallVectorImpl &equalities) { - using ops_assertions::operator-; - using ops_assertions::operator+; - - // Helper function that creates an SDBMInputExpr given the linearized position - // of variable in the DBM. - auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr { - if (matrixPos < numDims) - return SDBMDimExpr::get(dialect, matrixPos); - return SDBMSymbolExpr::get(dialect, matrixPos - numDims); - }; - - // The top-left value corresponds to inequality 0 <= C. If C is negative, the - // set defined by SDBM is trivially empty and we add the constraint -C <= 0 to - // the list of inequalities. Otherwise, the constraint is trivially true and - // we ignore it. - auto difference = at(0, 0); - if (difference.isFinite() && difference < 0) { - inequalities.push_back( - SDBMConstantExpr::get(dialect, -difference.getValue())); - } - - // Traverse the segment of the matrix that involves non-temporary variables. - unsigned numTrueVariables = numDims + numSymbols; - for (unsigned i = 0; i < numTrueVariables; ++i) { - // The first row and column represent numerical upper and lower bound on - // each variable. Transform them into inequalities if they are finite. - auto upperBound = at(0, 1 + i); - auto lowerBound = at(1 + i, 0); - auto inputExpr = getInput(i); - if (upperBound.isFinite() && - upperBound.getValue() == -lowerBound.getValue()) { - equalities.push_back(inputExpr - upperBound.getValue()); - } else if (upperBound.isFinite()) { - inequalities.push_back(inputExpr - upperBound.getValue()); - } else if (lowerBound.isFinite()) { - inequalities.push_back(-inputExpr - lowerBound.getValue()); - } - - // Introduce trivially false inequalities if required by diagonal elements. - convertDBMDiagonalElement(1 + i, inputExpr, inequalities); - - // Introduce equalities or inequalities between non-temporary variables. - for (unsigned j = 0; j < i; ++j) { - convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities, - equalities); - } - } - - // Add equalities for stripe expressions that define non-temporary - // variables. Temporary variables will be substituted into their uses and - // should not appear in the resulting equalities. - for (const auto &stripePair : stripeToPoint) { - unsigned position = stripePair.first; - if (position < 1 + numTrueVariables) { - equalities.push_back(getInput(position - 1) - stripePair.second); - } - } - - // Add equalities / inequalities involving temporaries by replacing the - // temporaries with stripe expressions that define them. 
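  // For example (illustrative), if t0 was introduced for the stripe
  // expression d0 # 8, a finite matrix entry relating d1 and t0 is emitted
  // with d0 # 8 substituted for t0, e.g. as d1 - (d0 # 8) - C <= 0 for the
  // corresponding bound C; temporaries never appear in the returned
  // expressions.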
- for (unsigned i = 1 + numTrueVariables, e = getNumVariables(); i < e; ++i) { - // Mixed constraints involving one temporary (j) and one non-temporary (i) - // variable. - for (unsigned j = 0; j < numTrueVariables; ++j) { - convertDBMElement(i, 1 + j, stripeToPoint[i].cast(), - getInput(j), inequalities, equalities); - } - - // Constraints involving only temporary variables. - for (unsigned j = 1 + numTrueVariables; j < i; ++j) { - convertDBMElement(i, j, stripeToPoint[i].cast(), - stripeToPoint[j].cast(), inequalities, - equalities); - } - - // Introduce trivially false inequalities if required by diagonal elements. - convertDBMDiagonalElement(i, stripeToPoint[i].cast(), - inequalities); - } -} - -void SDBM::print(llvm::raw_ostream &os) { - unsigned numVariables = getNumVariables(); - - // Helper function that prints the name of the variable given its linearized - // position in the DBM. - auto getVarName = [this](unsigned matrixPos) -> std::string { - if (matrixPos == 0) - return "cst"; - matrixPos -= 1; - if (matrixPos < numDims) - return llvm::formatv("d{0}", matrixPos); - matrixPos -= numDims; - if (matrixPos < numSymbols) - return llvm::formatv("s{0}", matrixPos); - matrixPos -= numSymbols; - return llvm::formatv("t{0}", matrixPos); - }; - - // Header row. - os << " cst"; - for (unsigned i = 1; i < numVariables; ++i) { - os << llvm::formatv(" {0,4}", getVarName(i)); - } - os << '\n'; - - // Data rows. - for (unsigned i = 0; i < numVariables; ++i) { - os << llvm::formatv("{0,-4}", getVarName(i)); - for (unsigned j = 0; j < numVariables; ++j) { - IntInfty value = operator()(i, j); - if (!value.isFinite()) - os << " inf"; - else - os << llvm::formatv(" {0,4}", value.getValue()); - } - os << '\n'; - } - - // Explanation of temporaries. - for (const auto &pair : stripeToPoint) { - os << getVarName(pair.first) << " = "; - pair.second.print(os); - os << '\n'; - } -} - -void SDBM::dump() { print(llvm::errs()); } diff --git a/mlir/lib/SDBM/SDBMDialect.cpp b/mlir/lib/SDBM/SDBMDialect.cpp deleted file mode 100644 index e000209e165..00000000000 --- a/mlir/lib/SDBM/SDBMDialect.cpp +++ /dev/null @@ -1,20 +0,0 @@ -//===- SDBMDialect.cpp - Dialect for striped difference-bound matrices ----===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/SDBM/SDBMDialect.h" - -static mlir::DialectRegistration SDBMDialect; diff --git a/mlir/lib/SDBM/SDBMExpr.cpp b/mlir/lib/SDBM/SDBMExpr.cpp deleted file mode 100644 index 5757ebefe52..00000000000 --- a/mlir/lib/SDBM/SDBMExpr.cpp +++ /dev/null @@ -1,647 +0,0 @@ -//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// A striped difference-bound matrix (SDBM) expression is a constant expression, -// an identifier, a binary expression with constant RHS and +, stripe operators -// or a difference expression between two identifiers. -// -//===----------------------------------------------------------------------===// - -#include "mlir/SDBM/SDBMExpr.h" -#include "SDBMExprDetail.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineExprVisitor.h" -#include "mlir/SDBM/SDBMDialect.h" - -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -namespace { -/// A simple compositional matcher for AffineExpr -/// -/// Example usage: -/// -/// ```c++ -/// AffineExprMatcher x, C, m; -/// AffineExprMatcher pattern1 = ((x % C) * m) + x; -/// AffineExprMatcher pattern2 = x + ((x % C) * m); -/// if (pattern1.match(expr) || pattern2.match(expr)) { -/// ... -/// } -/// ``` -class AffineExprMatcherStorage; -class AffineExprMatcher { -public: - AffineExprMatcher(); - AffineExprMatcher(const AffineExprMatcher &other); - - AffineExprMatcher operator+(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Add, *this, other); - } - AffineExprMatcher operator*(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Mul, *this, other); - } - AffineExprMatcher floorDiv(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other); - } - AffineExprMatcher ceilDiv(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other); - } - AffineExprMatcher operator%(AffineExprMatcher other) { - return AffineExprMatcher(AffineExprKind::Mod, *this, other); - } - - AffineExpr match(AffineExpr expr); - AffineExpr matched(); - Optional getMatchedConstantValue(); - -private: - AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b); - AffineExprKind kind; // only used to match in binary op cases. - // A shared_ptr allows multiple references to same matcher storage without - // worrying about ownership or dealing with an arena. To be cleaned up if we - // go with this. 
- std::shared_ptr storage; -}; - -class AffineExprMatcherStorage { -public: - AffineExprMatcherStorage() {} - AffineExprMatcherStorage(const AffineExprMatcherStorage &other) - : subExprs(other.subExprs.begin(), other.subExprs.end()), - matched(other.matched) {} - AffineExprMatcherStorage(ArrayRef exprs) - : subExprs(exprs.begin(), exprs.end()) {} - AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b) - : subExprs({a, b}) {} - llvm::SmallVector subExprs; - AffineExpr matched; -}; -} // namespace - -AffineExprMatcher::AffineExprMatcher() - : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {} - -AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) - : kind(other.kind), storage(other.storage) {} - -Optional AffineExprMatcher::getMatchedConstantValue() { - if (auto cst = storage->matched.dyn_cast()) - return cst.getValue(); - return None; -} - -AffineExpr AffineExprMatcher::match(AffineExpr expr) { - if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) { - if (storage->matched) - if (storage->matched != expr) - return AffineExpr(); - storage->matched = expr; - return storage->matched; - } - if (kind != expr.getKind()) { - return AffineExpr(); - } - if (auto bin = expr.dyn_cast()) { - if (!storage->subExprs.empty() && - !storage->subExprs[0].match(bin.getLHS())) { - return AffineExpr(); - } - if (!storage->subExprs.empty() && - !storage->subExprs[1].match(bin.getRHS())) { - return AffineExpr(); - } - if (storage->matched) - if (storage->matched != expr) - return AffineExpr(); - storage->matched = expr; - return storage->matched; - } - llvm_unreachable("binary expected"); -} - -AffineExpr AffineExprMatcher::matched() { return storage->matched; } - -AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, - AffineExprMatcher b) - : kind(k), storage(new AffineExprMatcherStorage(a, b)) { - storage->subExprs.push_back(a); - storage->subExprs.push_back(b); -} - -//===----------------------------------------------------------------------===// -// SDBMExpr -//===----------------------------------------------------------------------===// - -SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); } - -MLIRContext *SDBMExpr::getContext() const { - return impl->dialect->getContext(); -} - -SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; } - -void SDBMExpr::print(raw_ostream &os) const { - struct Printer : public SDBMVisitor { - Printer(raw_ostream &ostream) : prn(ostream) {} - - void visitSum(SDBMSumExpr expr) { - visitVarying(expr.getLHS()); - prn << " + "; - visitConstant(expr.getRHS()); - } - void visitDiff(SDBMDiffExpr expr) { - visitPositive(expr.getLHS()); - prn << " - "; - visitPositive(expr.getRHS()); - } - void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } - void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } - void visitStripe(SDBMStripeExpr expr) { - visitPositive(expr.getVar()); - prn << " # "; - visitConstant(expr.getStripeFactor()); - } - void visitNeg(SDBMNegExpr expr) { - prn << '-'; - visitPositive(expr.getVar()); - } - void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } - - raw_ostream &prn; - }; - Printer printer(os); - printer.visit(*this); -} - -void SDBMExpr::dump() const { - print(llvm::errs()); - llvm::errs() << '\n'; -} - -namespace { -// Helper class to perform negation of an SDBM expression. -struct SDBMNegator : public SDBMVisitor { - // Any positive expression is wrapped into a negation expression. 
- // -(x) = -x - SDBMExpr visitPositive(SDBMPositiveExpr expr) { - return SDBMNegExpr::get(expr); - } - // A negation expression is unwrapped. - // -(-x) = x - SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } - // The value of the constant is negated. - SDBMExpr visitConstant(SDBMConstantExpr expr) { - return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); - } - // Both terms of the sum are negated recursively. - SDBMExpr visitSum(SDBMSumExpr expr) { - return SDBMSumExpr::get(visit(expr.getLHS()).cast(), - visit(expr.getRHS()).cast()); - } - // Terms of a difference are interchanged. - // -(x - y) = y - x - SDBMExpr visitDiff(SDBMDiffExpr expr) { - return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS()); - } -}; -} // namespace - -SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); } - -//===----------------------------------------------------------------------===// -// SDBMSumExpr -//===----------------------------------------------------------------------===// - -SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { - assert(lhs && "expected SDBM variable expression"); - assert(rhs && "expected SDBM constant"); - - // If LHS of a sum is another sum, fold the constant RHS parts. - if (auto lhsSum = lhs.dyn_cast()) { - lhs = lhsSum.getLHS(); - rhs = SDBMConstantExpr::get(rhs.getDialect(), - rhs.getValue() + lhsSum.getRHS().getValue()); - } - - StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); -} - -SDBMVaryingExpr SDBMSumExpr::getLHS() const { - return static_cast(impl)->lhs; -} - -SDBMConstantExpr SDBMSumExpr::getRHS() const { - return static_cast(impl)->rhs; -} - -AffineExpr SDBMExpr::getAsAffineExpr() const { - struct Converter : public SDBMVisitor { - AffineExpr visitSum(SDBMSumExpr expr) { - AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - return lhs + rhs; - } - - AffineExpr visitStripe(SDBMStripeExpr expr) { - AffineExpr lhs = visit(expr.getVar()), - rhs = visit(expr.getStripeFactor()); - return lhs - (lhs % rhs); - } - - AffineExpr visitDiff(SDBMDiffExpr expr) { - AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - return lhs - rhs; - } - - AffineExpr visitDim(SDBMDimExpr expr) { - return getAffineDimExpr(expr.getPosition(), expr.getContext()); - } - - AffineExpr visitSymbol(SDBMSymbolExpr expr) { - return getAffineSymbolExpr(expr.getPosition(), expr.getContext()); - } - - AffineExpr visitNeg(SDBMNegExpr expr) { - return getAffineBinaryOpExpr(AffineExprKind::Mul, - getAffineConstantExpr(-1, expr.getContext()), - visit(expr.getVar())); - } - - AffineExpr visitConstant(SDBMConstantExpr expr) { - return getAffineConstantExpr(expr.getValue(), expr.getContext()); - } - } converter; - return converter.visit(*this); -} - -Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { - struct Converter : public AffineExprVisitor { - SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { - // Attempt to recover a stripe expression. Because AffineExprs don't have - // a first-class difference kind, we check for both x + -1 * (x mod C) and - // -1 * (x mod C) + x cases. 
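      // For example (illustrative), the affine expression
      // d0 + (d0 mod 4) * -1, i.e. d0 - d0 mod 4, is recovered as the stripe
      // expression d0 # 4.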
- AffineExprMatcher x, C, m; - AffineExprMatcher pattern1 = ((x % C) * m) + x; - AffineExprMatcher pattern2 = x + ((x % C) * m); - if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) || - (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) { - if (auto convertedLHS = visit(x.matched())) { - // TODO(ntv): return convertedLHS.stripe(C); - return SDBMStripeExpr::get( - convertedLHS.cast(), - visit(C.matched()).cast()); - } - } - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // In a "add" AffineExpr, the constant always appears on the right. If - // there were two constants, they would have been folded away. - assert(!lhs.isa() && "non-canonical affine expression"); - auto rhsConstant = rhs.dyn_cast(); - - // SDBM accepts LHS variables and RHS constants in a sum. - auto lhsVar = lhs.dyn_cast(); - auto rhsVar = rhs.dyn_cast(); - if (rhsConstant && lhsVar) - return SDBMSumExpr::get(lhsVar, rhsConstant); - - // The sum of a negated variable and a non-negated variable is a - // difference, supported as a special kind in SDBM. Because AffineExprs - // don't have first-class difference kind, check both LHS and RHS for - // negation. - auto lhsPos = lhs.dyn_cast(); - auto rhsPos = rhs.dyn_cast(); - auto lhsNeg = lhs.dyn_cast(); - auto rhsNeg = rhs.dyn_cast(); - if (lhsNeg && rhsVar) - return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar()); - if (rhsNeg && lhsVar) - return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar()); - - // Other cases don't fit into SDBM. - return {}; - } - - SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { - // Attempt to recover a stripe expression "x # C = (x floordiv C) * C". - AffineExprMatcher x, C; - AffineExprMatcher pattern = (x.floorDiv(C)) * C; - if (pattern.match(expr)) { - if (SDBMExpr converted = visit(x.matched())) { - if (auto varConverted = converted.dyn_cast()) - // TODO(ntv): return varConverted.stripe(C.getConstantValue()); - return SDBMStripeExpr::get( - varConverted, - SDBMConstantExpr::get(dialect, - C.getMatchedConstantValue().getValue())); - } - } - - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // In a "mul" AffineExpr, the constant always appears on the right. If - // there were two constants, they would have been folded away. - assert(!lhs.isa() && "non-canonical affine expression"); - auto rhsConstant = rhs.dyn_cast(); - if (!rhsConstant) - return {}; - - // The only supported "multiplication" expression is an SDBM is dimension - // negation, that is a product of dimension and constant -1. - auto lhsVar = lhs.dyn_cast(); - if (lhsVar && rhsConstant.getValue() == -1) - return SDBMNegExpr::get(lhsVar); - - // Other multiplications are not allowed in SDBM. - return {}; - } - - SDBMExpr visitModExpr(AffineBinaryOpExpr expr) { - auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return {}; - - // 'mod' can only be converted to SDBM if its LHS is a variable - // and its RHS is a constant. Then it `x mod c = x - x stripe c`. - auto rhsConstant = rhs.dyn_cast(); - auto lhsVar = rhs.dyn_cast(); - if (!lhsVar || !rhsConstant) - return {}; - return SDBMDiffExpr::get(lhsVar, - SDBMStripeExpr::get(lhsVar, rhsConstant)); - } - - // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM - SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; } - SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; } - - // Dimensions, symbols and constants are converted trivially. 
- SDBMExpr visitConstantExpr(AffineConstantExpr expr) { - return SDBMConstantExpr::get(dialect, expr.getValue()); - } - SDBMExpr visitDimExpr(AffineDimExpr expr) { - return SDBMDimExpr::get(dialect, expr.getPosition()); - } - SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) { - return SDBMSymbolExpr::get(dialect, expr.getPosition()); - } - - SDBMDialect *dialect; - } converter; - converter.dialect = affine.getContext()->getRegisteredDialect(); - - if (auto result = converter.visit(affine)) - return result; - return None; -} - -//===----------------------------------------------------------------------===// -// SDBMDiffExpr -//===----------------------------------------------------------------------===// - -SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { - assert(lhs && "expected SDBM dimension"); - assert(rhs && "expected SDBM dimension"); - - StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); -} - -SDBMPositiveExpr SDBMDiffExpr::getLHS() const { - return static_cast(impl)->lhs; -} - -SDBMPositiveExpr SDBMDiffExpr::getRHS() const { - return static_cast(impl)->rhs; -} - -//===----------------------------------------------------------------------===// -// SDBMStripeExpr -//===----------------------------------------------------------------------===// - -SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, - SDBMConstantExpr stripeFactor) { - assert(var && "expected SDBM variable expression"); - assert(stripeFactor && "expected non-null stripe factor"); - if (stripeFactor.getValue() <= 0) - llvm::report_fatal_error("non-positive stripe factor"); - - StorageUniquer &uniquer = var.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, - stripeFactor); -} - -SDBMPositiveExpr SDBMStripeExpr::getVar() const { - if (SDBMVaryingExpr lhs = static_cast(impl)->lhs) - return lhs.cast(); - return {}; -} - -SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const { - return static_cast(impl)->rhs; -} - -//===----------------------------------------------------------------------===// -// SDBMInputExpr -//===----------------------------------------------------------------------===// - -unsigned SDBMInputExpr::getPosition() const { - return static_cast(impl)->position; -} - -//===----------------------------------------------------------------------===// -// SDBMDimExpr -//===----------------------------------------------------------------------===// - -SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { - assert(dialect && "expected non-null dialect"); - - auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - assignDialect, static_cast(SDBMExprKind::DimId), position); -} - -//===----------------------------------------------------------------------===// -// SDBMSymbolExpr -//===----------------------------------------------------------------------===// - -SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) { - assert(dialect && "expected non-null dialect"); - - auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - assignDialect, static_cast(SDBMExprKind::SymbolId), position); -} - 
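For reference, the stripe construct that the affine-to-SDBM converter above recovers relies on the arithmetic identity x stripe C == (x floordiv C) * C == x - (x mod C) for C > 0, which is why both the `(x floordiv C) * C` multiplication pattern and the `x + (x mod C) * -1` addition pattern are matched. A minimal standalone C++ sketch (independent of the MLIR API; the helper names floorDiv/mod/stripe are illustrative only, and operands are restricted to non-negative values so that C++ `%` agrees with affine modulo) checking that identity on sample values:

#include <cassert>
#include <cstdint>

// Floor division and modulo for non-negative x and positive c, mirroring the
// affine semantics assumed by the stripe rewrite.
static int64_t floorDiv(int64_t x, int64_t c) { return x / c; }
static int64_t mod(int64_t x, int64_t c) { return x % c; }

// "x stripe c": round x down to the nearest multiple of c.
static int64_t stripe(int64_t x, int64_t c) { return floorDiv(x, c) * c; }

int main() {
  // The two affine forms matched by the converter denote the same value.
  for (int64_t x = 0; x < 100; ++x)
    for (int64_t c = 1; c < 20; ++c)
      assert(stripe(x, c) == x - mod(x, c));
  return 0;
}
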
-//===----------------------------------------------------------------------===// -// SDBMConstantExpr -//===----------------------------------------------------------------------===// - -SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) { - assert(dialect && "expected non-null dialect"); - - auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) { - storage->dialect = dialect; - }; - - StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - assignCtx, static_cast(SDBMExprKind::Constant), value); -} - -int64_t SDBMConstantExpr::getValue() const { - return static_cast(impl)->constant; -} - -//===----------------------------------------------------------------------===// -// SDBMNegExpr -//===----------------------------------------------------------------------===// - -SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { - assert(var && "expected non-null SDBM variable expression"); - - StorageUniquer &uniquer = var.getDialect()->getUniquer(); - return uniquer.get( - /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); -} - -SDBMPositiveExpr SDBMNegExpr::getVar() const { - return static_cast(impl)->dim; -} - -namespace mlir { -namespace ops_assertions { - -SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { - // If one of the operands is a negation, take a difference rather than a sum. - auto lhsNeg = lhs.dyn_cast(); - auto rhsNeg = rhs.dyn_cast(); - assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of " - "a sum of variables and not a correct SDBM"); - if (lhsNeg) - return rhs - lhsNeg.getVar(); - if (rhsNeg) - return lhs - rhsNeg.getVar(); - - // If LHS is a constant and RHS is not, swap the order to get into a supported - // sum case. From now on, RHS must be a constant. - auto lhsConstant = lhs.dyn_cast(); - auto rhsConstant = rhs.dyn_cast(); - if (!rhsConstant && lhsConstant) { - std::swap(lhs, rhs); - std::swap(lhsConstant, rhsConstant); - } - assert(rhsConstant && "at least one operand must be a constant"); - - // If LHS is another sum, first compute the sum of its variable - // part with the other argument and then add the constant part to enable - // constant folding (the variable part may, e.g., be a negation that requires - // to enter this function again). - auto lhsSum = lhs.dyn_cast(); - if (lhsSum) - return lhsSum.getLHS() + - (lhsSum.getRHS().getValue() + rhsConstant.getValue()); - - // Constant-fold if LHS is a constant. - if (lhsConstant) - return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + - rhsConstant.getValue()); - - // Fold x + 0 == x. - if (rhsConstant.getValue() == 0) - return lhs; - - return SDBMSumExpr::get(lhs.cast(), - rhs.cast()); -} - -SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { - // Fold x - x == 0. - if (lhs == rhs) - return SDBMConstantExpr::get(lhs.getDialect(), 0); - - // LHS and RHS may be constants. - auto lhsConstant = lhs.dyn_cast(); - auto rhsConstant = rhs.dyn_cast(); - - // Constant fold if both LHS and RHS are constants. - if (lhsConstant && rhsConstant) - return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() - - rhsConstant.getValue()); - - // Replace a difference with a sum with a negated value if one of LHS and RHS - // is a constant: - // x - C == x + (-C); - // C - x == -x + C. - // This calls into operator+ for further simplification. 
- if (rhsConstant) - return lhs + (-rhsConstant); - if (lhsConstant) - return -rhs + lhsConstant; - - // Hoist constant factors outside the difference if any of sides is a sum: - // (x + A) - (y - B) == x - y + (A - B). - // If either LHS or RHS is a sum, collect the constant values separately and - // update LHS and RHS to point to the variable part of the sum. - auto lhsSum = lhs.dyn_cast(); - auto rhsSum = rhs.dyn_cast(); - int64_t value = 0; - if (lhsSum) { - value += lhsSum.getRHS().getValue(); - lhs = lhsSum.getLHS(); - } - if (rhsSum) { - value -= rhsSum.getRHS().getValue(); - rhs = rhsSum.getLHS(); - } - - // This calls into operator+ for futher simplification in case value == 0. - return SDBMDiffExpr::get(lhs.cast(), - rhs.cast()) + - value; -} - -SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { - auto constantFactor = factor.cast(); - assert(constantFactor.getValue() > 0 && "non-positive stripe"); - - // Fold x # 1 = x. - if (constantFactor.getValue() == 1) - return expr; - - return SDBMStripeExpr::get(expr.cast(), constantFactor); -} - -} // namespace ops_assertions -} // namespace mlir diff --git a/mlir/lib/SDBM/SDBMExprDetail.h b/mlir/lib/SDBM/SDBMExprDetail.h deleted file mode 100644 index d2c241e744b..00000000000 --- a/mlir/lib/SDBM/SDBMExprDetail.h +++ /dev/null @@ -1,138 +0,0 @@ -//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This holds implementation details of SDBMExpr, in particular underlying -// storage types. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_SDBMEXPRDETAIL_H -#define MLIR_IR_SDBMEXPRDETAIL_H - -#include "mlir/SDBM/SDBMExpr.h" -#include "mlir/Support/StorageUniquer.h" - -namespace mlir { - -class SDBMDialect; - -namespace detail { - -// Base storage class for SDBMExpr. -struct SDBMExprStorage : public StorageUniquer::BaseStorage { - SDBMExprKind getKind() { - return static_cast(BaseStorage::getKind()); - } - - SDBMDialect *dialect; -}; - -// Storage class for SDBM sum and stripe expressions. -struct SDBMBinaryExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; - - bool operator==(const KeyTy &key) const { - return std::get<0>(key) == lhs && std::get<1>(key) == rhs; - } - - static SDBMBinaryExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->lhs = std::get<0>(key); - result->rhs = std::get<1>(key); - result->dialect = result->lhs.getDialect(); - return result; - } - - SDBMVaryingExpr lhs; - SDBMConstantExpr rhs; -}; - -// Storage class for SDBM difference expressions. 
-struct SDBMDiffExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; - - bool operator==(const KeyTy &key) const { - return std::get<0>(key) == lhs && std::get<1>(key) == rhs; - } - - static SDBMDiffExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->lhs = std::get<0>(key); - result->rhs = std::get<1>(key); - result->dialect = result->lhs.getDialect(); - return result; - } - - SDBMPositiveExpr lhs; - SDBMPositiveExpr rhs; -}; - -// Storage class for SDBM constant expressions. -struct SDBMConstantExprStorage : public SDBMExprStorage { - using KeyTy = int64_t; - - bool operator==(const KeyTy &key) const { return constant == key; } - - static SDBMConstantExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->constant = key; - return result; - } - - int64_t constant; -}; - -// Storage class for SDBM dimension and symbol expressions. -struct SDBMPositiveExprStorage : public SDBMExprStorage { - using KeyTy = unsigned; - - bool operator==(const KeyTy &key) const { return position == key; } - - static SDBMPositiveExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->position = key; - return result; - } - - unsigned position; -}; - -// Storage class for SDBM negation expressions. -struct SDBMNegExprStorage : public SDBMExprStorage { - using KeyTy = SDBMPositiveExpr; - - bool operator==(const KeyTy &key) const { return key == dim; } - - static SDBMNegExprStorage * - construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); - result->dim = key; - result->dialect = key.getDialect(); - return result; - } - - SDBMPositiveExpr dim; -}; - -} // end namespace detail -} // end namespace mlir - -#endif // MLIR_IR_SDBMEXPRDETAIL_H diff --git a/mlir/lib/StandardOps/CMakeLists.txt b/mlir/lib/StandardOps/CMakeLists.txt deleted file mode 100644 index e9fce2b0baf..00000000000 --- a/mlir/lib/StandardOps/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -file(GLOB globbed *.c *.cpp) -add_llvm_library(MLIRStandardOps - ${globbed} - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/StandardOps - ) -add_dependencies(MLIRStandardOps MLIRStandardOpsIncGen LLVMSupport) -target_link_libraries(MLIRStandardOps LLVMSupport) diff --git a/mlir/lib/StandardOps/DialectRegistration.cpp b/mlir/lib/StandardOps/DialectRegistration.cpp deleted file mode 100644 index 1f71a3d014e..00000000000 --- a/mlir/lib/StandardOps/DialectRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- DialectRegistration.cpp - Register standard Op dialect -------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#include "mlir/StandardOps/Ops.h" -using namespace mlir; - -// Static initialization for standard op dialect registration. -static DialectRegistration StandardOps; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp deleted file mode 100644 index 22148eeadc3..00000000000 --- a/mlir/lib/StandardOps/Ops.cpp +++ /dev/null @@ -1,2102 +0,0 @@ -//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/StandardOps/Ops.h" - -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/MathExtras.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" -using namespace mlir; - -//===----------------------------------------------------------------------===// -// StandardOpsDialect -//===----------------------------------------------------------------------===// - -/// A custom binary operation printer that omits the "std." prefix from the -/// operation names. -static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { - assert(op->getNumOperands() == 2 && "binary op should have two operands"); - assert(op->getNumResults() == 1 && "binary op should have one result"); - - // If not all the operand and result types are the same, just use the - // generic assembly form to avoid omitting information in printing. - auto resultType = op->getResult(0)->getType(); - if (op->getOperand(0)->getType() != resultType || - op->getOperand(1)->getType() != resultType) { - p->printGenericOp(op); - return; - } - - *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' - << *op->getOperand(0) << ", " << *op->getOperand(1); - p->printOptionalAttrDict(op->getAttrs()); - - // Now we can output only one type for all operands and the result. - *p << " : " << op->getResult(0)->getType(); -} - -/// A custom cast operation printer that omits the "std." prefix from the -/// operation names. -static void printStandardCastOp(Operation *op, OpAsmPrinter *p) { - *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' - << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " - << op->getResult(0)->getType(); -} - -/// A custom cast operation verifier. 
-template static LogicalResult verifyCastOp(T op) { - auto opType = op.getOperand()->getType(); - auto resType = op.getType(); - if (!T::areCastCompatible(opType, resType)) - return op.emitError("operand type ") << opType << " and result type " - << resType << " are cast incompatible"; - - return success(); -} - -StandardOpsDialect::StandardOpsDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { - addOperations(); -} - -void mlir::printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, - unsigned numDims, OpAsmPrinter *p) { - *p << '('; - p->printOperands(begin, begin + numDims); - *p << ')'; - - if (begin + numDims != end) { - *p << '['; - p->printOperands(begin + numDims, end); - *p << ']'; - } -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims) { - SmallVector opInfos; - if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) - return failure(); - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. - auto affineIntTy = parser->getBuilder().getIndexType(); - if (parser->parseOperandList(opInfos, - OpAsmParser::Delimiter::OptionalSquare) || - parser->resolveOperands(opInfos, affineIntTy, operands)) - return failure(); - return success(); -} - -/// Matches a ConstantIndexOp. -/// TODO: This should probably just be a general matcher that uses m_Constant -/// and checks the operation for an index type. -static detail::op_matcher m_ConstantIndex() { - return detail::op_matcher(); -} - -//===----------------------------------------------------------------------===// -// Common canonicalization pattern support logic -//===----------------------------------------------------------------------===// - -namespace { -/// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref_cast -/// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Operation *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast(memref)) - op->setOperand(i, cast.getOperand()); - rewriter.updatedRootInPlace(op); - } -}; - -/// Performs const folding `calculate` with element-wise behavior on the two -/// attributes in `operands` and returns the result if possible. 
-template > -Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { - assert(operands.size() == 2 && "binary op takes two operands"); - - if (auto lhs = operands[0].dyn_cast_or_null()) { - auto rhs = operands[1].dyn_cast_or_null(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; - - return AttrElementT::get(lhs.getType(), - calculate(lhs.getValue(), rhs.getValue())); - } else if (auto lhs = operands[0].dyn_cast_or_null()) { - auto rhs = operands[1].dyn_cast_or_null(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; - - auto elementResult = constFoldBinaryOp( - {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); - if (!elementResult) - return {}; - - return DenseElementsAttr::get(lhs.getType(), elementResult); - } - return {}; -} -} // end anonymous namespace. - -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -OpFoldResult AddFOp::fold(ArrayRef operands) { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a + b; }); -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -OpFoldResult AddIOp::fold(ArrayRef operands) { - /// addi(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a + b; }); -} - -//===----------------------------------------------------------------------===// -// AllocOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, AllocOp op) { - *p << "alloc"; - - // Print dynamic dimension operands. - MemRefType type = op.getType(); - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - type.getNumDynamicDims(), p); - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); - *p << " : " << type; -} - -static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) { - MemRefType type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return failure(); - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type.getNumDynamicDims()) - return parser->emitError(parser->getNameLoc()) - << "dimension operand count does not equal memref dynamic dimension " - "count"; - result->types.push_back(type); - return success(); -} - -static LogicalResult verify(AllocOp op) { - auto memRefType = op.getResult()->getType().dyn_cast(); - if (!memRefType) - return op.emitOpError("result must be a memref"); - - unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) { - // Store number of symbols used in affine map (used in subsequent check). 
- AffineMap affineMap = memRefType.getAffineMaps()[0]; - numSymbols = affineMap.getNumSymbols(); - } - - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. - unsigned numDynamicDims = memRefType.getNumDynamicDims(); - if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols) - return op.emitOpError( - "operand count does not equal dimension plus symbol operand count"); - - // Verify that all operands are of type Index. - for (auto operandType : op.getOperandTypes()) - if (!operandType.isIndex()) - return op.emitOpError("requires operands to be of type Index"); - return success(); -} - -namespace { -/// Fold constant dimensions into an alloc operation. -struct SimplifyAllocConst : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(AllocOp alloc, - PatternRewriter &rewriter) const override { - // Check to see if any dimensions operands are constants. If so, we can - // substitute and drop them. - if (llvm::none_of(alloc.getOperands(), [](Value *operand) { - return matchPattern(operand, m_ConstantIndex()); - })) - return matchFailure(); - - auto memrefType = alloc.getType(); - - // Ok, we have one or more constant operands. Collect the non-constant ones - // and keep track of the resultant memref type to build. - SmallVector newShapeConstants; - newShapeConstants.reserve(memrefType.getRank()); - SmallVector newOperands; - SmallVector droppedOperands; - - unsigned dynamicDimPos = 0; - for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { - int64_t dimSize = memrefType.getDimSize(dim); - // If this is already static dimension, keep it. - if (dimSize != -1) { - newShapeConstants.push_back(dimSize); - continue; - } - auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); - if (auto constantIndexOp = dyn_cast_or_null(defOp)) { - // Dynamic shape dimension will be folded. - newShapeConstants.push_back(constantIndexOp.getValue()); - // Record to check for zero uses later below. - droppedOperands.push_back(constantIndexOp); - } else { - // Dynamic shape dimension not folded; copy operand from old memref. - newShapeConstants.push_back(-1); - newOperands.push_back(alloc.getOperand(dynamicDimPos)); - } - dynamicDimPos++; - } - - // Create new memref type (which will have fewer dynamic dimensions). - auto newMemRefType = MemRefType::get( - newShapeConstants, memrefType.getElementType(), - memrefType.getAffineMaps(), memrefType.getMemorySpace()); - assert(static_cast(newOperands.size()) == - newMemRefType.getNumDynamicDims()); - - // Create and insert the alloc op for the new memref. - auto newAlloc = - rewriter.create(alloc.getLoc(), newMemRefType, newOperands); - // Insert a cast so we have the same type as the old alloc. - auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, - alloc.getType()); - - rewriter.replaceOp(alloc, {resultCast}, droppedOperands); - return matchSuccess(); - } -}; - -/// Fold alloc operations with no uses. Alloc has side effects on the heap, -/// but can still be deleted if it has zero uses. -struct SimplifyDeadAlloc : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(AllocOp alloc, - PatternRewriter &rewriter) const override { - // Check if the alloc'ed value has any uses. - if (!alloc.use_empty()) - return matchFailure(); - - // If it doesn't, we can eliminate it. 
- alloc.erase(); - return matchSuccess(); - } -}; -} // end anonymous namespace. - -void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BranchOp -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) { - Block *dest; - SmallVector destOperands; - if (parser->parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result->addSuccessor(dest, destOperands); - return success(); -} - -static void print(OpAsmPrinter *p, BranchOp op) { - *p << "br "; - p->printSuccessorAndUseList(op.getOperation(), 0); -} - -Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getOperation()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// CallOp -//===----------------------------------------------------------------------===// - -static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector operands; - auto calleeLoc = parser->getNameLoc(); - if (parser->parseAttribute(calleeAttr, "callee", result->attributes) || - parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->addTypesToList(calleeType.getResults(), result->types) || - parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, - result->operands)) - return failure(); - - return success(); -} - -static void print(OpAsmPrinter *p, CallOp op) { - *p << "call " << op.getAttr("callee") << '('; - p->printOperands(op.getOperands()); - *p << ')'; - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - *p << " : "; - p->printType(op.getCalleeType()); -} - -static LogicalResult verify(CallOp op) { - // Check that the callee attribute was specified. - auto fnAttr = op.getAttrOfType("callee"); - if (!fnAttr) - return op.emitOpError("requires a 'callee' symbol reference attribute"); - auto fn = - op.getParentOfType().lookupSymbol(fnAttr.getValue()); - if (!fn) - return op.emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; - - // Verify that the operand and result types match the callee. 
- auto fnType = fn.getType(); - if (fnType.getNumInputs() != op.getNumOperands()) - return op.emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i)->getType() != fnType.getInput(i)) - return op.emitOpError("operand type mismatch"); - - if (fnType.getNumResults() != op.getNumResults()) - return op.emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i)->getType() != fnType.getResult(i)) - return op.emitOpError("result type mismatch"); - - return success(); -} - -FunctionType CallOp::getCalleeType() { - SmallVector resultTypes(getResultTypes()); - SmallVector argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// CallIndirectOp -//===----------------------------------------------------------------------===// -namespace { -/// Fold indirect calls that have a constant function as the callee operand. -struct SimplifyIndirectCallWithKnownCallee - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, - PatternRewriter &rewriter) const override { - // Check that the callee is a constant callee. - SymbolRefAttr calledFn; - if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) - return matchFailure(); - - // Replace with a direct call. - SmallVector callResults(indirectCall.getResultTypes()); - SmallVector callOperands(indirectCall.getArgOperands()); - rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), - callResults, callOperands); - return matchSuccess(); - } -}; -} // end anonymous namespace. - -static ParseResult parseCallIndirectOp(OpAsmParser *parser, - OperationState *result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector operands; - return failure( - parser->parseOperand(callee) || - parser->getCurrentLocation(&operandsLoc) || - parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result->operands) || - parser->addTypesToList(calleeType.getResults(), result->types)); -} - -static void print(OpAsmPrinter *p, CallIndirectOp op) { - *p << "call_indirect "; - p->printOperand(op.getCallee()); - *p << '('; - p->printOperands(op.getArgOperands()); - *p << ')'; - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - *p << " : " << op.getCallee()->getType(); -} - -static LogicalResult verify(CallIndirectOp op) { - // The callee must be a function. - auto fnType = op.getCallee()->getType().dyn_cast(); - if (!fnType) - return op.emitOpError("callee must have function type"); - - // Verify that the operand and result types match the callee. 
- if (fnType.getNumInputs() != op.getNumOperands() - 1) - return op.emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) - return op.emitOpError("operand type mismatch"); - - if (fnType.getNumResults() != op.getNumResults()) - return op.emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i)->getType() != fnType.getResult(i)) - return op.emitOpError("result type mismatch"); - - return success(); -} - -void CallIndirectOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// General helpers for comparison ops -//===----------------------------------------------------------------------===// - -// Return the type of the same shape (scalar, vector or tensor) containing i1. -static Type getCheckedI1SameShape(Builder *build, Type type) { - auto i1Type = build->getI1Type(); - if (type.isIntOrIndexOrFloat()) - return i1Type; - if (auto tensorType = type.dyn_cast()) - return build->getTensorType(tensorType.getShape(), i1Type); - if (type.isa()) - return build->getTensorType(i1Type); - if (auto vectorType = type.dyn_cast()) - return build->getVectorType(vectorType.getShape(), i1Type); - return Type(); -} - -static Type getI1SameShape(Builder *build, Type type) { - Type res = getCheckedI1SameShape(build, type); - assert(res && "expected type with valid i1 shape"); - return res; -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -// Returns an array of mnemonics for CmpIPredicates indexed by values thereof. -static inline const char *const *getCmpIPredicateNames() { - static const char *predicateNames[]{ - /*EQ*/ "eq", - /*NE*/ "ne", - /*SLT*/ "slt", - /*SLE*/ "sle", - /*SGT*/ "sgt", - /*SGE*/ "sge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UGT*/ "ugt", - /*UGE*/ "uge", - }; - static_assert(std::extent::value == - (size_t)CmpIPredicate::NumPredicates, - "wrong number of predicate names"); - return predicateNames; -} - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
-CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch(name) - .Case("eq", CmpIPredicate::EQ) - .Case("ne", CmpIPredicate::NE) - .Case("slt", CmpIPredicate::SLT) - .Case("sle", CmpIPredicate::SLE) - .Case("sgt", CmpIPredicate::SGT) - .Case("sge", CmpIPredicate::SGE) - .Case("ult", CmpIPredicate::ULT) - .Case("ule", CmpIPredicate::ULE) - .Case("ugt", CmpIPredicate::UGT) - .Case("uge", CmpIPredicate::UGE) - .Default(CmpIPredicate::NumPredicates); -} - -static void buildCmpIOp(Builder *build, OperationState *result, - CmpIPredicate predicate, Value *lhs, Value *rhs) { - result->addOperands({lhs, rhs}); - result->types.push_back(getI1SameShape(build, lhs->getType())); - result->addAttribute( - CmpIOp::getPredicateAttrName(), - build->getI64IntegerAttr(static_cast(predicate))); -} - -static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) { - SmallVector ops; - SmallVector attrs; - Attribute predicateNameAttr; - Type type; - if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), - attrs) || - parser->parseComma() || parser->parseOperandList(ops, 2) || - parser->parseOptionalAttributeDict(attrs) || - parser->parseColonType(type) || - parser->resolveOperands(ops, type, result->operands)) - return failure(); - - if (!predicateNameAttr.isa()) - return parser->emitError(parser->getNameLoc(), - "expected string comparison predicate attribute"); - - // Rewrite string attribute to an enum value. - StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = CmpIOp::getPredicateByName(predicateName); - if (predicate == CmpIPredicate::NumPredicates) - return parser->emitError(parser->getNameLoc()) - << "unknown comparison predicate \"" << predicateName << "\""; - - auto builder = parser->getBuilder(); - Type i1Type = getCheckedI1SameShape(&builder, type); - if (!i1Type) - return parser->emitError(parser->getNameLoc(), - "expected type with valid i1 shape"); - - attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); - result->attributes = attrs; - - result->addTypes({i1Type}); - return success(); -} - -static void print(OpAsmPrinter *p, CmpIOp op) { - *p << "cmpi "; - - auto predicateValue = - op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && - predicateValue < static_cast(CmpIPredicate::NumPredicates) && - "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpIPredicateNames()[predicateValue]); - p->printAttribute(predicateStringAttr); - - *p << ", "; - p->printOperand(op.lhs()); - *p << ", "; - p->printOperand(op.rhs()); - p->printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); - *p << " : " << op.lhs()->getType(); -} - -static LogicalResult verify(CmpIOp op) { - auto predicateAttr = - op.getAttrOfType(CmpIOp::getPredicateAttrName()); - if (!predicateAttr) - return op.emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpIPredicate::FirstValidValue || - predicate >= (int64_t)CmpIPredicate::NumPredicates) - return op.emitOpError("'predicate' attribute value out of range"); - - return success(); -} - -// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer -// comparison predicates. 
-static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, - const APInt &rhs) { - switch (predicate) { - case CmpIPredicate::EQ: - return lhs.eq(rhs); - case CmpIPredicate::NE: - return lhs.ne(rhs); - case CmpIPredicate::SLT: - return lhs.slt(rhs); - case CmpIPredicate::SLE: - return lhs.sle(rhs); - case CmpIPredicate::SGT: - return lhs.sgt(rhs); - case CmpIPredicate::SGE: - return lhs.sge(rhs); - case CmpIPredicate::ULT: - return lhs.ult(rhs); - case CmpIPredicate::ULE: - return lhs.ule(rhs); - case CmpIPredicate::UGT: - return lhs.ugt(rhs); - case CmpIPredicate::UGE: - return lhs.uge(rhs); - default: - llvm_unreachable("unknown comparison predicate"); - } -} - -// Constant folding hook for comparisons. -OpFoldResult CmpIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "cmpi takes two arguments"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); -} - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -// Returns an array of mnemonics for CmpFPredicates indexed by values thereof. -static inline const char *const *getCmpFPredicateNames() { - static const char *predicateNames[] = { - /*AlwaysFalse*/ "false", - /*OEQ*/ "oeq", - /*OGT*/ "ogt", - /*OGE*/ "oge", - /*OLT*/ "olt", - /*OLE*/ "ole", - /*ONE*/ "one", - /*ORD*/ "ord", - /*UEQ*/ "ueq", - /*UGT*/ "ugt", - /*UGE*/ "uge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UNE*/ "une", - /*UNO*/ "uno", - /*AlwaysTrue*/ "true", - }; - static_assert(std::extent::value == - (size_t)CmpFPredicate::NumPredicates, - "wrong number of predicate names"); - return predicateNames; -} - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. 
-CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch(name) - .Case("false", CmpFPredicate::AlwaysFalse) - .Case("oeq", CmpFPredicate::OEQ) - .Case("ogt", CmpFPredicate::OGT) - .Case("oge", CmpFPredicate::OGE) - .Case("olt", CmpFPredicate::OLT) - .Case("ole", CmpFPredicate::OLE) - .Case("one", CmpFPredicate::ONE) - .Case("ord", CmpFPredicate::ORD) - .Case("ueq", CmpFPredicate::UEQ) - .Case("ugt", CmpFPredicate::UGT) - .Case("uge", CmpFPredicate::UGE) - .Case("ult", CmpFPredicate::ULT) - .Case("ule", CmpFPredicate::ULE) - .Case("une", CmpFPredicate::UNE) - .Case("uno", CmpFPredicate::UNO) - .Case("true", CmpFPredicate::AlwaysTrue) - .Default(CmpFPredicate::NumPredicates); -} - -static void buildCmpFOp(Builder *build, OperationState *result, - CmpFPredicate predicate, Value *lhs, Value *rhs) { - result->addOperands({lhs, rhs}); - result->types.push_back(getI1SameShape(build, lhs->getType())); - result->addAttribute( - CmpFOp::getPredicateAttrName(), - build->getI64IntegerAttr(static_cast(predicate))); -} - -static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) { - SmallVector ops; - SmallVector attrs; - Attribute predicateNameAttr; - Type type; - if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), - attrs) || - parser->parseComma() || parser->parseOperandList(ops, 2) || - parser->parseOptionalAttributeDict(attrs) || - parser->parseColonType(type) || - parser->resolveOperands(ops, type, result->operands)) - return failure(); - - if (!predicateNameAttr.isa()) - return parser->emitError(parser->getNameLoc(), - "expected string comparison predicate attribute"); - - // Rewrite string attribute to an enum value. - StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = CmpFOp::getPredicateByName(predicateName); - if (predicate == CmpFPredicate::NumPredicates) - return parser->emitError(parser->getNameLoc(), - "unknown comparison predicate \"" + predicateName + - "\""); - - auto builder = parser->getBuilder(); - Type i1Type = getCheckedI1SameShape(&builder, type); - if (!i1Type) - return parser->emitError(parser->getNameLoc(), - "expected type with valid i1 shape"); - - attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); - result->attributes = attrs; - - result->addTypes({i1Type}); - return success(); -} - -static void print(OpAsmPrinter *p, CmpFOp op) { - *p << "cmpf "; - - auto predicateValue = - op.getAttrOfType(CmpFOp::getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && - predicateValue < static_cast(CmpFPredicate::NumPredicates) && - "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpFPredicateNames()[predicateValue]); - p->printAttribute(predicateStringAttr); - - *p << ", "; - p->printOperand(op.lhs()); - *p << ", "; - p->printOperand(op.rhs()); - p->printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); - *p << " : " << op.lhs()->getType(); -} - -static LogicalResult verify(CmpFOp op) { - auto predicateAttr = - op.getAttrOfType(CmpFOp::getPredicateAttrName()); - if (!predicateAttr) - return op.emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpFPredicate::FirstValidValue || - predicate >= (int64_t)CmpFPredicate::NumPredicates) - return op.emitOpError("'predicate' attribute value out of range"); - - return success(); -} - -// Compute 
`lhs` `pred` `rhs`, where `pred` is one of the known floating point -// comparison predicates. -static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, - const APFloat &rhs) { - auto cmpResult = lhs.compare(rhs); - switch (predicate) { - case CmpFPredicate::AlwaysFalse: - return false; - case CmpFPredicate::OEQ: - return cmpResult == APFloat::cmpEqual; - case CmpFPredicate::OGT: - return cmpResult == APFloat::cmpGreaterThan; - case CmpFPredicate::OGE: - return cmpResult == APFloat::cmpGreaterThan || - cmpResult == APFloat::cmpEqual; - case CmpFPredicate::OLT: - return cmpResult == APFloat::cmpLessThan; - case CmpFPredicate::OLE: - return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; - case CmpFPredicate::ONE: - return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; - case CmpFPredicate::ORD: - return cmpResult != APFloat::cmpUnordered; - case CmpFPredicate::UEQ: - return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; - case CmpFPredicate::UGT: - return cmpResult == APFloat::cmpUnordered || - cmpResult == APFloat::cmpGreaterThan; - case CmpFPredicate::UGE: - return cmpResult == APFloat::cmpUnordered || - cmpResult == APFloat::cmpGreaterThan || - cmpResult == APFloat::cmpEqual; - case CmpFPredicate::ULT: - return cmpResult == APFloat::cmpUnordered || - cmpResult == APFloat::cmpLessThan; - case CmpFPredicate::ULE: - return cmpResult == APFloat::cmpUnordered || - cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; - case CmpFPredicate::UNE: - return cmpResult != APFloat::cmpEqual; - case CmpFPredicate::UNO: - return cmpResult == APFloat::cmpUnordered; - case CmpFPredicate::AlwaysTrue: - return true; - default: - llvm_unreachable("unknown comparison predicate"); - } -} - -// Constant folding hook for comparisons. -OpFoldResult CmpFOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "cmpf takes two arguments"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs || - // TODO(b/122019992) Implement and test constant folding for nan/inf when - // it is possible to have constant nan/inf - !lhs.getValue().isFinite() || !rhs.getValue().isFinite()) - return {}; - - auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); -} - -//===----------------------------------------------------------------------===// -// CondBranchOp -//===----------------------------------------------------------------------===// - -namespace { -/// cond_br true, ^bb1, ^bb2 -> br ^bb1 -/// cond_br false, ^bb1, ^bb2 -> br ^bb2 -/// -struct SimplifyConstCondBranchPred : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - // Check that the condition is a constant. - if (!matchPattern(condbr.getCondition(), m_Op())) - return matchFailure(); - - Block *foldedDest; - SmallVector branchArgs; - - // If the condition is known to evaluate to false we fold to a branch to the - // false destination. Otherwise, we fold to a branch to the true - // destination. 
- if (matchPattern(condbr.getCondition(), m_Zero())) { - foldedDest = condbr.getFalseDest(); - branchArgs.assign(condbr.false_operand_begin(), - condbr.false_operand_end()); - } else { - foldedDest = condbr.getTrueDest(); - branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); - } - - rewriter.replaceOpWithNewOp(condbr, foldedDest, branchArgs); - return matchSuccess(); - } -}; -} // end anonymous namespace. - -static ParseResult parseCondBranchOp(OpAsmParser *parser, - OperationState *result) { - SmallVector destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser->getBuilder().getI1Type(); - if (parser->parseOperand(condInfo) || parser->parseComma() || - parser->resolveOperand(condInfo, int1Ty, result->operands)) { - return parser->emitError(parser->getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser->parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result->addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser->parseComma() || - parser->parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result->addSuccessor(dest, destOperands); - - return success(); -} - -static void print(OpAsmPrinter *p, CondBranchOp op) { - *p << "cond_br "; - p->printOperand(op.getCondition()); - *p << ", "; - p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - *p << ", "; - p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - -void CondBranchOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// Constant*Op -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, ConstantOp &op) { - *p << "constant "; - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); - - if (op.getAttrs().size() > 1) - *p << ' '; - p->printAttribute(op.getValue()); - - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) - *p << " : " << op.getType(); -} - -static ParseResult parseConstantOp(OpAsmParser *parser, - OperationState *result) { - Attribute valueAttr; - if (parser->parseOptionalAttributeDict(result->attributes) || - parser->parseAttribute(valueAttr, "value", result->attributes)) - return failure(); - - // If the attribute is a symbol reference, then we expect a trailing type. - Type type; - if (!valueAttr.isa()) - type = valueAttr.getType(); - else if (parser->parseColonType(type)) - return failure(); - - // Add the attribute type to the list. - return parser->addTypeToList(type, result->types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. -static LogicalResult verify(ConstantOp &op) { - auto value = op.getValue(); - if (!value) - return op.emitOpError("requires a 'value' attribute"); - - auto type = op.getType(); - if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - - if (type.isa() || value.isa()) - return success(); - - if (auto intAttr = value.dyn_cast()) { - // If the type has a known bitwidth we verify that the value can be - // represented with the given bitwidth. 
- auto bitwidth = type.cast().getWidth(); - auto intVal = intAttr.getValue(); - if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) - return op.emitOpError("requires 'value' to be an integer within the " - "range of the integer result type"); - return success(); - } - - if (type.isa()) { - if (!value.isa()) - return op.emitOpError("requires 'value' to be a floating point constant"); - return success(); - } - - if (type.isa()) { - if (!value.isa()) - return op.emitOpError("requires 'value' to be a shaped constant"); - return success(); - } - - if (type.isa()) { - auto fnAttr = value.dyn_cast(); - if (!fnAttr) - return op.emitOpError("requires 'value' to be a function reference"); - - // Try to find the referenced function. - auto fn = - op.getParentOfType().lookupSymbol(fnAttr.getValue()); - if (!fn) - return op.emitOpError("reference to undefined function 'bar'"); - - // Check that the referenced function has the correct type. - if (fn.getType() != type) - return op.emitOpError("reference to function with mismatched type"); - - return success(); - } - - if (type.isa() && value.isa()) - return success(); - - return op.emitOpError("unsupported 'value' attribute: ") << value; -} - -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - return getValue(); -} - -/// Returns true if a constant operation can be built with the given value and -/// result type. -bool ConstantOp::isBuildableWith(Attribute value, Type type) { - // SymbolRefAttr can only be used with a function type. - if (value.isa()) - return type.isa(); - // Otherwise, the attribute must have the same type as 'type'. - if (value.getType() != type) - return false; - // Finally, check that the attribute kind is handled. - return value.isa() || value.isa() || - value.isa() || value.isa() || - value.isa(); -} - -void ConstantFloatOp::build(Builder *builder, OperationState *result, - const APFloat &value, FloatType type) { - ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); -} - -bool ConstantFloatOp::classof(Operation *op) { - return ConstantOp::classof(op) && - op->getResult(0)->getType().isa(); -} - -/// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::classof(Operation *op) { - return ConstantOp::classof(op) && - op->getResult(0)->getType().isa(); -} - -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, unsigned width) { - Type type = builder->getIntegerType(width); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -/// Build a constant int op producing an integer with the specified type, -/// which must be an integer type. -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, Type type) { - assert(type.isa() && "ConstantIntOp can only have integer type"); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -/// ConstantIndexOp only matches values whose result type is Index. 
-bool ConstantIndexOp::classof(Operation *op) { - return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); -} - -void ConstantIndexOp::build(Builder *builder, OperationState *result, - int64_t value) { - Type type = builder->getIndexType(); - ConstantOp::build(builder, result, type, - builder->getIntegerAttr(type, value)); -} - -//===----------------------------------------------------------------------===// -// DeallocOp -//===----------------------------------------------------------------------===// -namespace { -/// Fold Dealloc operations that are deallocating an AllocOp that is only used -/// by other Dealloc operations. -struct SimplifyDeadDealloc : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(DeallocOp dealloc, - PatternRewriter &rewriter) const override { - // Check that the memref operand's defining operation is an AllocOp. - Value *memref = dealloc.memref(); - if (!isa_and_nonnull(memref->getDefiningOp())) - return matchFailure(); - - // Check that all of the uses of the AllocOp are other DeallocOps. - for (auto *user : memref->getUsers()) - if (!isa(user)) - return matchFailure(); - - // Erase the dealloc operation. - rewriter.replaceOp(dealloc, llvm::None); - return matchSuccess(); - } -}; -} // end anonymous namespace. - -static void print(OpAsmPrinter *p, DeallocOp op) { - *p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); -} - -static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - MemRefType type; - - return failure(parser->parseOperand(memrefInfo) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands)); -} - -static LogicalResult verify(DeallocOp op) { - if (!op.memref()->getType().isa()) - return op.emitOpError("operand must be a memref"); - return success(); -} - -void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// dealloc(memrefcast) -> dealloc - results.insert(getOperationName(), context); - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// DimOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, DimOp op) { - *p << "dim " << *op.getOperand() << ", " << op.getIndex(); - p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); - *p << " : " << op.getOperand()->getType(); -} - -static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType operandInfo; - IntegerAttr indexAttr; - Type type; - Type indexType = parser->getBuilder().getIndexType(); - - return failure(parser->parseOperand(operandInfo) || parser->parseComma() || - parser->parseAttribute(indexAttr, indexType, "index", - result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(operandInfo, type, result->operands) || - parser->addTypeToList(indexType, result->types)); -} - -static LogicalResult verify(DimOp op) { - // Check that we have an integer index operand. 
- auto indexAttr = op.getAttrOfType("index"); - if (!indexAttr) - return op.emitOpError("requires an integer attribute named 'index'"); - int64_t index = indexAttr.getValue().getSExtValue(); - - auto type = op.getOperand()->getType(); - if (auto tensorType = type.dyn_cast()) { - if (index >= tensorType.getRank()) - return op.emitOpError("index is out of range"); - } else if (auto memrefType = type.dyn_cast()) { - if (index >= memrefType.getRank()) - return op.emitOpError("index is out of range"); - - } else if (type.isa()) { - // ok, assumed to be in-range. - } else { - return op.emitOpError("requires an operand with tensor or memref type"); - } - - return success(); -} - -OpFoldResult DimOp::fold(ArrayRef operands) { - // Constant fold dim when the size along the index referred to is a constant. - auto opType = getOperand()->getType(); - int64_t indexSize = -1; - if (auto tensorType = opType.dyn_cast()) - indexSize = tensorType.getShape()[getIndex()]; - else if (auto memrefType = opType.dyn_cast()) - indexSize = memrefType.getShape()[getIndex()]; - - if (indexSize >= 0) - return IntegerAttr::get(IndexType::get(getContext()), indexSize); - - return {}; -} - -//===----------------------------------------------------------------------===// -// DivISOp -//===----------------------------------------------------------------------===// - -OpFoldResult DivISOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - // Don't fold if it requires division by zero. - if (rhs.getValue().isNullValue()) - return {}; - - // Don't fold if it would overflow. - bool overflow; - auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); - return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result); -} - -//===----------------------------------------------------------------------===// -// DivIUOp -//===----------------------------------------------------------------------===// - -OpFoldResult DivIUOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; - - // Don't fold if it requires division by zero. 
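For reference on the signed-division folder shown above (which uses APInt::sdiv_ov): the only two inputs it refuses to fold are division by zero and the single overflowing quotient, the minimum value divided by -1. Below is a standalone plain-C++ sketch of the same rule on fixed-width int64_t, not the MLIR folding hook itself; the function name is made up for illustration.

#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Return the folded quotient, or nothing when folding must be skipped,
// mirroring the early `return {}` exits in the folder above.
static std::optional<int64_t> foldSignedDiv(int64_t lhs, int64_t rhs) {
  if (rhs == 0)
    return std::nullopt; // division by zero: leave the op alone
  if (lhs == std::numeric_limits<int64_t>::min() && rhs == -1)
    return std::nullopt; // quotient not representable: don't fold
  return lhs / rhs;
}

int main() {
  assert(foldSignedDiv(7, 2) == 3);
  assert(!foldSignedDiv(5, 0));
  assert(!foldSignedDiv(std::numeric_limits<int64_t>::min(), -1));
}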
- auto rhsValue = rhs.getValue(); - if (rhsValue.isNullValue()) - return {}; - - return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue)); -} - -// --------------------------------------------------------------------------- -// DmaStartOp -// --------------------------------------------------------------------------- - -void DmaStartOp::build(Builder *builder, OperationState *result, - Value *srcMemRef, ArrayRef srcIndices, - Value *destMemRef, ArrayRef destIndices, - Value *numElements, Value *tagMemRef, - ArrayRef tagIndices, Value *stride, - Value *elementsPerStride) { - result->addOperands(srcMemRef); - result->addOperands(srcIndices); - result->addOperands(destMemRef); - result->addOperands(destIndices); - result->addOperands({numElements, tagMemRef}); - result->addOperands(tagIndices); - if (stride) - result->addOperands({stride, elementsPerStride}); -} - -void DmaStartOp::print(OpAsmPrinter *p) { - *p << "dma_start " << *getSrcMemRef() << '['; - p->printOperands(getSrcIndices()); - *p << "], " << *getDstMemRef() << '['; - p->printOperands(getDstIndices()); - *p << "], " << *getNumElements(); - *p << ", " << *getTagMemRef() << '['; - p->printOperands(getTagIndices()); - *p << ']'; - if (isStrided()) { - *p << ", " << *getStride(); - *p << ", " << *getNumElementsPerStride(); - } - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getSrcMemRef()->getType(); - *p << ", " << getDstMemRef()->getType(); - *p << ", " << getTagMemRef()->getType(); -} - -// Parse DmaStartOp. -// Ex: -// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, -// %tag[%index], %stride, %num_elt_per_stride : -// : memref<3076 x f32, 0>, -// memref<1024 x f32, 2>, -// memref<1 x i32> -// -ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType srcMemRefInfo; - SmallVector srcIndexInfos; - OpAsmParser::OperandType dstMemRefInfo; - SmallVector dstIndexInfos; - OpAsmParser::OperandType numElementsInfo; - OpAsmParser::OperandType tagMemrefInfo; - SmallVector tagIndexInfos; - SmallVector strideInfo; - - SmallVector types; - auto indexType = parser->getBuilder().getIndexType(); - - // Parse and resolve the following list of operands: - // *) source memref followed by its indices (in square brackets). - // *) destination memref followed by its indices (in square brackets). - // *) dma size in KiB. - if (parser->parseOperand(srcMemRefInfo) || - parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(dstMemRefInfo) || - parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseComma() || parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) - return failure(); - - // Parse optional stride and elements per stride. 
- if (parser->parseTrailingOperandList(strideInfo)) - return failure(); - - bool isStrided = strideInfo.size() == 2; - if (!strideInfo.empty() && !isStrided) { - return parser->emitError(parser->getNameLoc(), - "expected two stride related operands"); - } - - if (parser->parseColonTypeList(types)) - return failure(); - if (types.size() != 3) - return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); - - if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || - parser->resolveOperands(srcIndexInfos, indexType, result->operands) || - parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || - parser->resolveOperands(dstIndexInfos, indexType, result->operands) || - // size should be an index. - parser->resolveOperand(numElementsInfo, indexType, result->operands) || - parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || - // tag indices should be index. - parser->resolveOperands(tagIndexInfos, indexType, result->operands)) - return failure(); - - auto memrefType0 = types[0].dyn_cast(); - if (!memrefType0) - return parser->emitError(parser->getNameLoc(), - "expected source to be of memref type"); - - auto memrefType1 = types[1].dyn_cast(); - if (!memrefType1) - return parser->emitError(parser->getNameLoc(), - "expected destination to be of memref type"); - - auto memrefType2 = types[2].dyn_cast(); - if (!memrefType2) - return parser->emitError(parser->getNameLoc(), - "expected tag to be of memref type"); - - if (isStrided) { - if (parser->resolveOperands(strideInfo, indexType, result->operands)) - return failure(); - } - - // Check that source/destination index list size matches associated rank. - if (static_cast(srcIndexInfos.size()) != memrefType0.getRank() || - static_cast(dstIndexInfos.size()) != memrefType1.getRank()) - return parser->emitError(parser->getNameLoc(), - "memref rank not equal to indices count"); - if (static_cast(tagIndexInfos.size()) != memrefType2.getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - return success(); -} - -LogicalResult DmaStartOp::verify() { - // DMAs from different memory spaces supported. - if (getSrcMemorySpace() == getDstMemorySpace()) - return emitOpError("DMA should be between different memory spaces"); - - if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + - getDstMemRefRank() + 3 + 1 && - getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + - getDstMemRefRank() + 3 + 1 + 2) { - return emitOpError("incorrect number of operands"); - } - return success(); -} - -void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// dma_start(memrefcast) -> dma_start - results.insert(getOperationName(), context); -} - -// --------------------------------------------------------------------------- -// DmaWaitOp -// --------------------------------------------------------------------------- - -void DmaWaitOp::build(Builder *builder, OperationState *result, - Value *tagMemRef, ArrayRef tagIndices, - Value *numElements) { - result->addOperands(tagMemRef); - result->addOperands(tagIndices); - result->addOperands(numElements); -} - -void DmaWaitOp::print(OpAsmPrinter *p) { - *p << "dma_wait "; - p->printOperand(getTagMemRef()); - *p << '['; - p->printOperands(getTagIndices()); - *p << "], "; - p->printOperand(getNumElements()); - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getTagMemRef()->getType(); -} - -// Parse DmaWaitOp. 
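To make the operand-count check in DmaStartOp::verify() above concrete, here is a standalone sketch of the expected count, using the shapes from the parser comment (memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>, all rank 1). This is plain C++ for illustration only; the helper name is made up.

#include <cassert>

// src memref + srcRank indices, dst memref + dstRank indices, num_elements,
// tag memref + tagRank indices, plus an optional {stride,
// num_elements_per_stride} pair: srcRank + dstRank + tagRank + 3 + 1 (+ 2).
static unsigned expectedDmaStartOperands(unsigned srcRank, unsigned dstRank,
                                         unsigned tagRank, bool strided) {
  return srcRank + dstRank + tagRank + 3 + 1 + (strided ? 2u : 0u);
}

int main() {
  assert(expectedDmaStartOperands(1, 1, 1, /*strided=*/false) == 7);
  assert(expectedDmaStartOperands(1, 1, 1, /*strided=*/true) == 9);
}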
-// Eg: -// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> -// -ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType tagMemrefInfo; - SmallVector tagIndexInfos; - Type type; - auto indexType = parser->getBuilder().getIndexType(); - OpAsmParser::OperandType numElementsInfo; - - // Parse tag memref, its indices, and dma size. - if (parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseColonType(type) || - parser->resolveOperand(tagMemrefInfo, type, result->operands) || - parser->resolveOperands(tagIndexInfos, indexType, result->operands) || - parser->resolveOperand(numElementsInfo, indexType, result->operands)) - return failure(); - - auto memrefType = type.dyn_cast(); - if (!memrefType) - return parser->emitError(parser->getNameLoc(), - "expected tag to be of memref type"); - - if (static_cast(tagIndexInfos.size()) != memrefType.getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - return success(); -} - -void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// dma_wait(memrefcast) -> dma_wait - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, ExtractElementOp op) { - *p << "extract_element " << *op.getAggregate() << '['; - p->printOperands(op.getIndices()); - *p << ']'; - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.getAggregate()->getType(); -} - -static ParseResult parseExtractElementOp(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType aggregateInfo; - SmallVector indexInfo; - ShapedType type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return failure( - parser->parseOperand(aggregateInfo) || - parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(aggregateInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types)); -} - -static LogicalResult verify(ExtractElementOp op) { - auto aggregateType = op.getAggregate()->getType().cast(); - - // This should be possible with tablegen type constraints - if (op.getType() != aggregateType.getElementType()) - return op.emitOpError("result type must match element type of aggregate"); - - // Verify the # indices match if we have a ranked type. - if (aggregateType.hasRank() && - aggregateType.getRank() != op.getNumOperands() - 1) - return op.emitOpError("incorrect number of indices for extract_element"); - - return success(); -} - -OpFoldResult ExtractElementOp::fold(ArrayRef operands) { - assert(!operands.empty() && "extract_element takes atleast one operand"); - - // The aggregate operand must be a known constant. - Attribute aggregate = operands.front(); - if (!aggregate) - return {}; - - // If this is a splat elements attribute, simply return the value. All of the - // elements of a splat attribute are the same. 
- if (auto splatAggregate = aggregate.dyn_cast()) - return splatAggregate.getSplatValue(); - - // Otherwise, collect the constant indices into the aggregate. - SmallVector indices; - for (Attribute indice : llvm::drop_begin(operands, 1)) { - if (!indice || !indice.isa()) - return {}; - indices.push_back(indice.cast().getInt()); - } - - // If this is an elements attribute, query the value at the given indices. - auto elementsAttr = aggregate.dyn_cast(); - if (elementsAttr && elementsAttr.isValidIndex(indices)) - return elementsAttr.getValue(indices); - return {}; -} - -//===----------------------------------------------------------------------===// -// IndexCastOp -//===----------------------------------------------------------------------===// - -// Index cast is applicable from index to integer and backwards. -bool IndexCastOp::areCastCompatible(Type a, Type b) { - return (a.isIndex() && b.isa()) || - (a.isa() && b.isIndex()); -} - -//===----------------------------------------------------------------------===// -// LoadOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, LoadOp op) { - *p << "load " << *op.getMemRef() << '['; - p->printOperands(op.getIndices()); - *p << ']'; - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.getMemRefType(); -} - -static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return failure( - parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types)); -} - -static LogicalResult verify(LoadOp op) { - if (op.getType() != op.getMemRefType().getElementType()) - return op.emitOpError("result type must match element type of memref"); - - if (op.getMemRefType().getRank() != op.getNumOperands() - 1) - return op.emitOpError("incorrect number of indices for load"); - - for (auto *idx : op.getIndices()) - if (!idx->getType().isIndex()) - return op.emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in Function verify that the indices are parameters, IV's, or the - // result of an affine.apply. - return success(); -} - -void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// load(memrefcast) -> load - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// MemRefCastOp -//===----------------------------------------------------------------------===// - -bool MemRefCastOp::areCastCompatible(Type a, Type b) { - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); - - if (!aT || !bT) - return false; - if (aT.getElementType() != bT.getElementType()) - return false; - if (aT.getAffineMaps() != bT.getAffineMaps()) - return false; - if (aT.getMemorySpace() != bT.getMemorySpace()) - return false; - - // They must have the same rank, and any specified dimensions must match. 
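The rank-and-dimension check that follows (and the analogous TensorCastOp check later in this file) treats -1 as a dynamic size that is compatible with anything. A standalone plain-C++ sketch of that rule, with made-up names and std::vector standing in for the MLIR shape accessors:

#include <cassert>
#include <cstdint>
#include <vector>

// Two shapes are cast-compatible if the ranks match and every dimension
// agrees unless at least one side is dynamic (encoded as -1).
static bool shapesCastCompatible(const std::vector<int64_t> &a,
                                 const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return false;
  for (size_t i = 0, e = a.size(); i != e; ++i)
    if (a[i] != -1 && b[i] != -1 && a[i] != b[i])
      return false;
  return true;
}

int main() {
  assert(shapesCastCompatible({4, -1}, {4, 16}));   // dynamic dim matches any
  assert(!shapesCastCompatible({4, 8}, {4, 16}));   // static mismatch
  assert(!shapesCastCompatible({4, 8}, {4, 8, 2})); // rank mismatch
}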
- if (aT.getRank() != bT.getRank()) - return false; - - for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { - int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); - if (aDim != -1 && bDim != -1 && aDim != bDim) - return false; - } - - return true; -} - -OpFoldResult MemRefCastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); -} - -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -OpFoldResult MulFOp::fold(ArrayRef operands) { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a * b; }); -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -OpFoldResult MulIOp::fold(ArrayRef operands) { - /// muli(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) - return rhs(); - /// muli(x, 1) -> x - if (matchPattern(rhs(), m_One())) - return getOperand(0); - - // TODO: Handle the overflow case. - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a * b; }); -} - -//===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, RankOp op) { - *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); -} - -static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType operandInfo; - Type type; - Type indexType = parser->getBuilder().getIndexType(); - return failure(parser->parseOperand(operandInfo) || - parser->parseColonType(type) || - parser->resolveOperand(operandInfo, type, result->operands) || - parser->addTypeToList(indexType, result->types)); -} - -OpFoldResult RankOp::fold(ArrayRef operands) { - // Constant fold rank when the rank of the tensor is known. - auto type = getOperand()->getType(); - if (auto tensorType = type.dyn_cast()) - return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); - return IntegerAttr(); -} - -//===----------------------------------------------------------------------===// -// RemISOp -//===----------------------------------------------------------------------===// - -OpFoldResult RemISOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "remis takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); - - // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; - - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); -} - -//===----------------------------------------------------------------------===// -// RemIUOp -//===----------------------------------------------------------------------===// - -OpFoldResult RemIUOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "remiu takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); - - // Don't fold if it requires division by zero. 
- if (rhsValue.isNullValue()) - return {}; - - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser->getCurrentLocation(); - return failure(parser->parseOperandList(opInfo) || - (!opInfo.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(opInfo, types, loc, result->operands)); -} - -static void print(OpAsmPrinter *p, ReturnOp op) { - *p << "return"; - if (op.getNumOperands() != 0) { - *p << ' '; - p->printOperands(op.getOperands()); - *p << " : "; - interleaveComma(op.getOperandTypes(), *p); - } -} - -static LogicalResult verify(ReturnOp op) { - auto function = cast(op.getParentOp()); - - // The operand number and types must match the function signature. - const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) - return op.emitOpError("has ") - << op.getNumOperands() - << " operands, but enclosing function returns " << results.size(); - - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i)->getType() != results[i]) - return op.emitError() - << "type of return operand " << i << " (" - << op.getOperand(i)->getType() - << ") doesn't match function result type (" << results[i] << ")"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SIToFPOp -//===----------------------------------------------------------------------===// - -// sitofp is applicable from integer types to float types. 
-bool SIToFPOp::areCastCompatible(Type a, Type b) { - return a.isa() && b.isa(); -} - -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { - SmallVector ops; - SmallVector attrs; - Type type; - if (parser->parseOperandList(ops, 3) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return failure(); - - auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); - if (!i1Type) - return parser->emitError(parser->getNameLoc(), - "expected type with valid i1 shape"); - - SmallVector types = {i1Type, type, type}; - return failure(parser->resolveOperands(ops, types, parser->getNameLoc(), - result->operands) || - parser->addTypeToList(type, result->types)); -} - -static void print(OpAsmPrinter *p, SelectOp op) { - *p << "select "; - p->printOperands(op.getOperands()); - *p << " : " << op.getTrueValue()->getType(); - p->printOptionalAttrDict(op.getAttrs()); -} - -static LogicalResult verify(SelectOp op) { - auto trueType = op.getTrueValue()->getType(); - auto falseType = op.getFalseValue()->getType(); - - if (trueType != falseType) - return op.emitOpError( - "requires 'true' and 'false' arguments to be of the same type"); - - return success(); -} - -OpFoldResult SelectOp::fold(ArrayRef operands) { - auto *condition = getCondition(); - - // select true, %0, %1 => %0 - if (matchPattern(condition, m_One())) - return getTrueValue(); - - // select false, %0, %1 => %1 - if (matchPattern(condition, m_Zero())) - return getFalseValue(); - return nullptr; -} - -//===----------------------------------------------------------------------===// -// StoreOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter *p, StoreOp op) { - *p << "store " << *op.getValueToStore(); - *p << ", " << *op.getMemRef() << '['; - p->printOperands(op.getIndices()); - *p << ']'; - p->printOptionalAttrDict(op.getAttrs()); - *p << " : " << op.getMemRefType(); -} - -static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; - SmallVector indexInfo; - MemRefType memrefType; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return failure( - parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType.getElementType(), - result->operands) || - parser->resolveOperand(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands)); -} - -static LogicalResult verify(StoreOp op) { - // First operand must have same type as memref element type. 
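The StoreOp verifier that follows counts operands as the stored value, the memref, and one index per memref dimension. A minimal standalone illustration of that arithmetic (plain C++, names made up):

#include <cassert>

// A store takes 2 + rank operands: value + memref + one index per dimension.
static unsigned expectedStoreOperands(unsigned memrefRank) {
  return 2 + memrefRank;
}

int main() {
  // store %v, %m[%i, %j] : memref<4x4xf32>  -> 4 operands in total.
  assert(expectedStoreOperands(2) == 4);
}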
- if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) - return op.emitOpError( - "first operand must have same type memref element type"); - - if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) - return op.emitOpError("store index operand count not equal to memref rank"); - - for (auto *idx : op.getIndices()) - if (!idx->getType().isIndex()) - return op.emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in Function verify that the indices are parameters, IV's, or the - // result of an affine.apply. - return success(); -} - -void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - /// store(memrefcast) -> store - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -OpFoldResult SubFOp::fold(ArrayRef operands) { - return constFoldBinaryOp( - operands, [](APFloat a, APFloat b) { return a - b; }); -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -OpFoldResult SubIOp::fold(ArrayRef operands) { - // subi(x,x) -> 0 - if (getOperand(0) == getOperand(1)) - return Builder(getContext()).getZeroAttr(getType()); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a - b; }); -} - -//===----------------------------------------------------------------------===// -// AndOp -//===----------------------------------------------------------------------===// - -OpFoldResult AndOp::fold(ArrayRef operands) { - /// and(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) - return rhs(); - /// and(x,x) -> x - if (lhs() == rhs()) - return rhs(); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a & b; }); -} - -//===----------------------------------------------------------------------===// -// OrOp -//===----------------------------------------------------------------------===// - -OpFoldResult OrOp::fold(ArrayRef operands) { - /// or(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); - /// or(x,x) -> x - if (lhs() == rhs()) - return rhs(); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a | b; }); -} - -//===----------------------------------------------------------------------===// -// XOrOp -//===----------------------------------------------------------------------===// - -OpFoldResult XOrOp::fold(ArrayRef operands) { - /// xor(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); - /// xor(x,x) -> 0 - if (lhs() == rhs()) - return Builder(getContext()).getZeroAttr(getType()); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a ^ b; }); -} - -//===----------------------------------------------------------------------===// -// TensorCastOp -//===----------------------------------------------------------------------===// - -bool TensorCastOp::areCastCompatible(Type a, Type b) { - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); - if (!aT || !bT) - return false; - - if (aT.getElementType() != bT.getElementType()) - return false; - - // If the either are unranked, then the cast is valid. 
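Looking back at the AndOp/OrOp/XOrOp folders above, they rely on the usual bitwise identities before falling back to constant folding. A standalone sketch with plain integers instead of APInt, for illustration only:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t x = 0b1011;
  assert((x & 0) == 0); // and(x, 0) -> 0
  assert((x & x) == x); // and(x, x) -> x
  assert((x | 0) == x); // or(x, 0)  -> x
  assert((x | x) == x); // or(x, x)  -> x
  assert((x ^ 0) == x); // xor(x, 0) -> x
  assert((x ^ x) == 0); // xor(x, x) -> 0
}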
- auto aRType = aT.dyn_cast(); - auto bRType = bT.dyn_cast(); - if (!aRType || !bRType) - return true; - - // If they are both ranked, they have to have the same rank, and any specified - // dimensions must match. - if (aRType.getRank() != bRType.getRank()) - return false; - - for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) { - int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i); - if (aDim != -1 && bDim != -1 && aDim != bDim) - return false; - } - - return true; -} - -OpFoldResult TensorCastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "mlir/StandardOps/Ops.cpp.inc" diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp index 26a5fc12cce..afa356ea69f 100644 --- a/mlir/lib/Support/JitRunner.cpp +++ b/mlir/lib/Support/JitRunner.cpp @@ -26,13 +26,13 @@ #include "mlir/Support/JitRunner.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index a1e09fda84d..98dc43c7105 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -23,9 +23,9 @@ #include "mlir/Target/NVVMIR.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/LLVMIR/NVVMDialect.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Translation.h" @@ -55,7 +55,7 @@ protected: LogicalResult convertOperation(Operation &opInst, llvm::IRBuilder<> &builder) override { -#include "mlir/LLVMIR/NVVMConversions.inc" +#include "mlir/Dialect/LLVMIR/NVVMConversions.inc" return LLVM::ModuleTranslation::convertOperation(opInst, builder); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 19ff0961497..bea22c9753c 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -22,9 +22,9 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Module.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SetVector.h" @@ -202,7 +202,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, return position; }; -#include "mlir/LLVMIR/LLVMConversions.inc" +#include "mlir/Dialect/LLVMIR/LLVMConversions.inc" // Emit function calls. 
If the "callee" attribute is present, this is a // direct function call and we also need to look up the remapped function diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 5030f722519..33b73336ff7 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -31,9 +31,9 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/MapVector.h" diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index 2ce0fbd011b..c4024fe303f 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -16,8 +16,8 @@ // ============================================================================= #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 98d01b24be0..98798938077 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -24,11 +24,11 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopFusionUtils.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index fddc890edcf..094f8fc421d 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -25,11 +25,11 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 1879ff63af2..5a7d926d4f9 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -23,13 +23,13 @@ #include "mlir/Transforms/LowerAffine.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp 
b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 8cb50e805f8..ab98340f0af 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -25,6 +25,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" @@ -37,7 +38,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" #include "mlir/VectorOps/VectorOps.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 811c6fc7ad5..eaa4d002969 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -28,6 +28,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -36,7 +37,6 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 59a4fbe93ab..33433e50d0f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -26,8 +26,8 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SmallPtrSet.h" #include diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index db78f500867..b58b6debc05 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -25,9 +25,9 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index 435ea85ea98..6c313e20932 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -22,10 +22,10 @@ #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" using namespace mlir; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fe15fb49865..361580811e6 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -19,9 +19,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/StandardOps/Ops.h" 
#include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 4c079bd88aa..63150c14742 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -26,13 +26,13 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a4717ad507b..8b62d007f47 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -27,6 +27,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -34,7 +35,6 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index ffc19d1a1d3..e2253c77f67 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -27,10 +27,10 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d00174ba2fa..6b3c4449667 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -26,12 +26,12 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index a39388e2fd6..386d744ef32 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -18,6 +18,7 @@ // RUN: mlir-edsc-builder-api-test | FileCheck %s #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" @@ -28,7 +29,6 @@ #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" 
-#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp index 39a4a0d3189..b8cbaef5c35 100644 --- a/mlir/test/SDBM/sdbm-api-test.cpp +++ b/mlir/test/SDBM/sdbm-api-test.cpp @@ -17,10 +17,10 @@ // RUN: mlir-sdbm-api-test | FileCheck %s +#include "mlir/Dialect/SDBM/SDBM.h" +#include "mlir/Dialect/SDBM/SDBMDialect.h" +#include "mlir/Dialect/SDBM/SDBMExpr.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/SDBM/SDBM.h" -#include "mlir/SDBM/SDBMDialect.h" -#include "mlir/SDBM/SDBMExpr.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index 34480f09f57..35a7eba5478 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -16,10 +16,10 @@ // ============================================================================= #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 8b55d351bdc..4dd06a58904 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -24,9 +24,9 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" -#include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LoopFusionUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index 1d174eb8395..9bd9222bbef 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -29,9 +29,9 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/JitRunner.h" diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index c4c1c56f51f..850d1480320 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -15,11 +15,11 @@ // limitations under the License. // ============================================================================= -#include "mlir/SDBM/SDBM.h" +#include "mlir/Dialect/SDBM/SDBM.h" +#include "mlir/Dialect/SDBM/SDBMDialect.h" +#include "mlir/Dialect/SDBM/SDBMExpr.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/SDBM/SDBMDialect.h" -#include "mlir/SDBM/SDBMExpr.h" #include "gtest/gtest.h" #include "llvm/ADT/DenseSet.h" -- cgit v1.2.3 From ffde975e215e8ccba2b96a05f66a5756bebc8b64 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 20 Aug 2019 15:36:08 -0700 Subject: NFC: Move AffineOps dialect to the Dialect sub-directory. 
PiperOrigin-RevId: 264482571 --- .../Linalg/Linalg1/include/linalg1/Common.h | 2 +- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 2 +- mlir/include/mlir/AffineOps/AffineOps.h | 598 ------- mlir/include/mlir/AffineOps/AffineOps.td | 259 --- mlir/include/mlir/AffineOps/AffineOpsBase.td | 44 - mlir/include/mlir/AffineOps/CMakeLists.txt | 4 - mlir/include/mlir/CMakeLists.txt | 1 - mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 598 +++++++ mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 259 +++ .../mlir/Dialect/AffineOps/AffineOpsBase.td | 44 + mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt | 4 + mlir/include/mlir/Dialect/CMakeLists.txt | 1 + .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 2 +- mlir/include/mlir/EDSC/Builders.h | 2 +- mlir/lib/AffineOps/AffineOps.cpp | 1764 -------------------- mlir/lib/AffineOps/CMakeLists.txt | 10 - mlir/lib/AffineOps/DialectRegistration.cpp | 22 - mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/NestedMatcher.cpp | 2 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/TestParallelismDetection.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- mlir/lib/CMakeLists.txt | 1 - mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 2 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 2 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 1764 ++++++++++++++++++++ mlir/lib/Dialect/AffineOps/CMakeLists.txt | 10 + mlir/lib/Dialect/AffineOps/DialectRegistration.cpp | 22 + mlir/lib/Dialect/CMakeLists.txt | 1 + .../lib/Dialect/Linalg/Transforms/LowerToLoops.cpp | 12 +- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 +- mlir/lib/Transforms/Utils/Utils.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/EDSC/builder-api-test.cpp | 2 +- mlir/test/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/test/lib/Transforms/TestLoopFusion.cpp | 2 +- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- 53 files changed, 2743 insertions(+), 2743 deletions(-) delete mode 100644 mlir/include/mlir/AffineOps/AffineOps.h delete mode 100644 mlir/include/mlir/AffineOps/AffineOps.td delete mode 100644 mlir/include/mlir/AffineOps/AffineOpsBase.td delete mode 100644 mlir/include/mlir/AffineOps/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/AffineOps/AffineOps.h create mode 100644 mlir/include/mlir/Dialect/AffineOps/AffineOps.td create mode 100644 mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td create mode 100644 mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt delete mode 100644 mlir/lib/AffineOps/AffineOps.cpp delete mode 100644 mlir/lib/AffineOps/CMakeLists.txt delete mode 100644 mlir/lib/AffineOps/DialectRegistration.cpp create mode 100644 mlir/lib/Dialect/AffineOps/AffineOps.cpp create mode 100644 mlir/lib/Dialect/AffineOps/CMakeLists.txt create mode 100644 
mlir/lib/Dialect/AffineOps/DialectRegistration.cpp (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index 29ff9bd2d3e..8bedf513907 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -18,9 +18,9 @@ #ifndef LINALG1_COMMON_H_ #define LINALG1_COMMON_H_ -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Verifier.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 15e544a773c..68143dcc146 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -23,7 +23,7 @@ #include "linalg3/Intrinsics.h" #include "linalg3/TensorOps.h" -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/LoopUtils.h" diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h deleted file mode 100644 index 59f7fc782e6..00000000000 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ /dev/null @@ -1,598 +0,0 @@ -//===- AffineOps.h - MLIR Affine Operations -------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines convenience types for working with Affine operations -// in the MLIR operation set. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_AFFINEOPS_AFFINEOPS_H -#define MLIR_AFFINEOPS_AFFINEOPS_H - -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -class AffineBound; -class AffineValueMap; -class AffineTerminatorOp; -class FlatAffineConstraints; -class OpBuilder; - -/// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool isTopLevelSymbol(Value *value); - -class AffineOpsDialect : public Dialect { -public: - AffineOpsDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "affine"; } -}; - -/// The "affine.apply" operation applies an affine map to a list of operands, -/// yielding a single result. The operand list must be the same size as the -/// number of arguments to the affine mapping. All operands and the result are -/// of type 'Index'. This operation requires a single affine map attribute named -/// "map". 
For example: -/// -/// %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } : -/// (index) -> (index) -/// -/// equivalently: -/// -/// #map42 = (d0)->(d0+1) -/// %y = affine.apply #map42(%x) -/// -class AffineApplyOp : public Op { -public: - using Op::Op; - - /// Builds an affine apply op with the specified map and operands. - static void build(Builder *builder, OperationState *result, AffineMap map, - ArrayRef operands); - - /// Returns the affine map to be applied by this operation. - AffineMap getAffineMap() { - return getAttrOfType("map").getValue(); - } - - /// Returns true if the result of this operation can be used as dimension id. - bool isValidDim(); - - /// Returns true if the result of this operation is a symbol. - bool isValidSymbol(); - - static StringRef getOperationName() { return "affine.apply"; } - - // Hooks to customize behavior of this op. - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - OpFoldResult fold(ArrayRef operands); - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data -/// from a source memref to a destination memref. The source and destination -/// memref need not be of the same dimensionality, but need to have the same -/// elemental type. The operands include the source and destination memref's -/// each followed by its indices, size of the data transfer in terms of the -/// number of elements (of the elemental type of the memref), a tag memref with -/// its indices, and optionally at the end, a stride and a -/// number_of_elements_per_stride arguments. The tag location is used by an -/// AffineDmaWaitOp to check for completion. The indices of the source memref, -/// destination memref, and the tag memref have the same restrictions as any -/// affine.load/store. In particular, index for each memref dimension must be an -/// affine expression of loop induction variables and symbols. -/// The optional stride arguments should be of 'index' type, and specify a -/// stride for the slower memory space (memory space with a lower memory space -/// id), tranferring chunks of number_of_elements_per_stride every stride until -/// %num_elements are transferred. Either both or no stride arguments should be -/// specified. The value of 'num_elements' must be a multiple of -/// 'number_of_elements_per_stride'. -// -// For example, a DmaStartOp operation that transfers 256 elements of a memref -// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory -// space 1 at indices [%k + 7, %l], would be specified as follows: -// -// %num_elements = constant 256 -// %idx = constant 0 : index -// %tag = alloc() : memref<1xi32, 4> -// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], -// %num_elements : -// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> -// -// If %stride and %num_elt_per_stride are specified, the DMA is expected to -// transfer %num_elt_per_stride elements every %stride elements apart from -// memory space 0 until %num_elements are transferred. -// -// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, -// %stride, %num_elt_per_stride : ... -// -// TODO(mlir-team): add additional operands to allow source and destination -// striding, and multiple stride levels (possibly using AffineMaps to specify -// multiple levels of striding). 
-// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. -class AffineDmaStartOp : public Op { -public: - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *srcMemRef, - AffineMap srcMap, ArrayRef srcIndices, - Value *destMemRef, AffineMap dstMap, - ArrayRef destIndices, Value *tagMemRef, - AffineMap tagMap, ArrayRef tagIndices, - Value *numElements, Value *stride = nullptr, - Value *elementsPerStride = nullptr); - - /// Returns the operand index of the src memref. - unsigned getSrcMemRefOperandIndex() { return 0; } - - /// Returns the source MemRefType for this DMA operation. - Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } - MemRefType getSrcMemRefType() { - return getSrcMemRef()->getType().cast(); - } - - /// Returns the rank (number of indices) of the source MemRefType. - unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } - - /// Returns the affine map used to access the src memref. - AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } - AffineMapAttr getSrcMapAttr() { - return getAttr(getSrcMapAttrName()).cast(); - } - - /// Returns the source memref affine map indices for this DMA operation. - operand_range getSrcIndices() { - return {operand_begin() + getSrcMemRefOperandIndex() + 1, - operand_begin() + getSrcMemRefOperandIndex() + 1 + - getSrcMap().getNumInputs()}; - } - - /// Returns the memory space of the src memref. - unsigned getSrcMemorySpace() { - return getSrcMemRef()->getType().cast().getMemorySpace(); - } - - /// Returns the operand index of the dst memref. - unsigned getDstMemRefOperandIndex() { - return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); - } - - /// Returns the destination MemRefType for this DMA operations. - Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } - MemRefType getDstMemRefType() { - return getDstMemRef()->getType().cast(); - } - - /// Returns the rank (number of indices) of the destination MemRefType. - unsigned getDstMemRefRank() { - return getDstMemRef()->getType().cast().getRank(); - } - - /// Returns the memory space of the src memref. - unsigned getDstMemorySpace() { - return getDstMemRef()->getType().cast().getMemorySpace(); - } - - /// Returns the affine map used to access the dst memref. - AffineMap getDstMap() { return getDstMapAttr().getValue(); } - AffineMapAttr getDstMapAttr() { - return getAttr(getDstMapAttrName()).cast(); - } - - /// Returns the destination memref indices for this DMA operation. - operand_range getDstIndices() { - return {operand_begin() + getDstMemRefOperandIndex() + 1, - operand_begin() + getDstMemRefOperandIndex() + 1 + - getDstMap().getNumInputs()}; - } - - /// Returns the operand index of the tag memref. - unsigned getTagMemRefOperandIndex() { - return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); - } - - /// Returns the Tag MemRef for this DMA operation. - Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } - MemRefType getTagMemRefType() { - return getTagMemRef()->getType().cast(); - } - - /// Returns the rank (number of indices) of the tag MemRefType. - unsigned getTagMemRefRank() { - return getTagMemRef()->getType().cast().getRank(); - } - - /// Returns the affine map used to access the tag memref. - AffineMap getTagMap() { return getTagMapAttr().getValue(); } - AffineMapAttr getTagMapAttr() { - return getAttr(getTagMapAttrName()).cast(); - } - - /// Returns the tag memref indices for this DMA operation. 
- operand_range getTagIndices() { - return {operand_begin() + getTagMemRefOperandIndex() + 1, - operand_begin() + getTagMemRefOperandIndex() + 1 + - getTagMap().getNumInputs()}; - } - - /// Returns the number of elements being transferred by this DMA operation. - Value *getNumElements() { - return getOperand(getTagMemRefOperandIndex() + 1 + - getTagMap().getNumInputs()); - } - - /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { - if (memref == getSrcMemRef()) - return {Identifier::get(getSrcMapAttrName(), getContext()), - getSrcMapAttr()}; - else if (memref == getDstMemRef()) - return {Identifier::get(getDstMapAttrName(), getContext()), - getDstMapAttr()}; - assert(memref == getTagMemRef() && - "DmaStartOp expected source, destination or tag memref"); - return {Identifier::get(getTagMapAttrName(), getContext()), - getTagMapAttr()}; - } - - /// Returns true if this is a DMA from a faster memory space to a slower one. - bool isDestMemorySpaceFaster() { - return (getSrcMemorySpace() < getDstMemorySpace()); - } - - /// Returns true if this is a DMA from a slower memory space to a faster one. - bool isSrcMemorySpaceFaster() { - // Assumes that a lower number is for a slower memory space. - return (getDstMemorySpace() < getSrcMemorySpace()); - } - - /// Given a DMA start operation, returns the operand position of either the - /// source or destination memref depending on the one that is at the higher - /// level of the memory hierarchy. Asserts failure if neither is true. - unsigned getFasterMemPos() { - assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); - return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); - } - - static StringRef getSrcMapAttrName() { return "src_map"; } - static StringRef getDstMapAttrName() { return "dst_map"; } - static StringRef getTagMapAttrName() { return "tag_map"; } - - static StringRef getOperationName() { return "affine.dma_start"; } - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - - /// Returns true if this DMA operation is strided, returns false otherwise. - bool isStrided() { - return getNumOperands() != - getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; - } - - /// Returns the stride value for this DMA operation. - Value *getStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1 - 1); - } - - /// Returns the number of elements to transfer per stride for this DMA op. - Value *getNumElementsPerStride() { - if (!isStrided()) - return nullptr; - return getOperand(getNumOperands() - 1); - } -}; - -/// AffineDmaWaitOp blocks until the completion of a DMA operation associated -/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be -/// an index with the same restrictions as any load/store index. In particular, -/// index for each memref dimension must be an affine expression of loop -/// induction variables and symbols. %num_elements is the number of elements -/// associated with the DMA operation. For example: -// -// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : -// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> -// ... -// ... 
-// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> -// -class AffineDmaWaitOp : public Op { -public: - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *tagMemRef, - AffineMap tagMap, ArrayRef tagIndices, - Value *numElements); - - static StringRef getOperationName() { return "affine.dma_wait"; } - - // Returns the Tag MemRef associated with the DMA operation being waited on. - Value *getTagMemRef() { return getOperand(0); } - MemRefType getTagMemRefType() { - return getTagMemRef()->getType().cast(); - } - - /// Returns the affine map used to access the tag memref. - AffineMap getTagMap() { return getTagMapAttr().getValue(); } - AffineMapAttr getTagMapAttr() { - return getAttr(getTagMapAttrName()).cast(); - } - - // Returns the tag memref index for this DMA operation. - operand_range getTagIndices() { - return {operand_begin() + 1, - operand_begin() + 1 + getTagMap().getNumInputs()}; - } - - // Returns the rank (number of indices) of the tag memref. - unsigned getTagMemRefRank() { - return getTagMemRef()->getType().cast().getRank(); - } - - /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { - assert(memref == getTagMemRef()); - return {Identifier::get(getTagMapAttrName(), getContext()), - getTagMapAttr()}; - } - - /// Returns the number of elements transferred in the associated DMA op. - Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } - - static StringRef getTagMapAttrName() { return "tag_map"; } - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// The "affine.load" op reads an element from a memref, where the index -/// for each memref dimension is an affine expression of loop induction -/// variables and symbols. The output of 'affine.load' is a new value with the -/// same type as the elements of the memref. An affine expression of loop IVs -/// and symbols must be specified for each dimension of the memref. The keyword -/// 'symbol' can be used to indicate SSA identifiers which are symbolic. -// -// Example 1: -// -// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> -// -// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. -// -// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] -// : memref<100x100xf32> -// -class AffineLoadOp : public Op::Impl> { -public: - using Op::Op; - - /// Builds an affine load op with the specified map and operands. - static void build(Builder *builder, OperationState *result, AffineMap map, - ArrayRef operands); - /// Builds an affine load op an identify map and operands. - static void build(Builder *builder, OperationState *result, Value *memref, - ArrayRef indices = {}); - - /// Returns the operand index of the memref. - unsigned getMemRefOperandIndex() { return 0; } - - /// Get memref operand. - Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - /// Get affine map operands. - operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); } - - /// Returns the affine map used to index the memref for this operation. 
- AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } - AffineMapAttr getAffineMapAttr() { - return getAttr(getMapAttrName()).cast(); - } - - /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { - assert(memref == getMemRef()); - return {Identifier::get(getMapAttrName(), getContext()), - getAffineMapAttr()}; - } - - static StringRef getMapAttrName() { return "map"; } - static StringRef getOperationName() { return "affine.load"; } - - // Hooks to customize behavior of this op. - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// The "affine.store" op writes an element to a memref, where the index -/// for each memref dimension is an affine expression of loop induction -/// variables and symbols. The 'affine.store' op stores a new value which is the -/// same type as the elements of the memref. An affine expression of loop IVs -/// and symbols must be specified for each dimension of the memref. The keyword -/// 'symbol' can be used to indicate SSA identifiers which are symbolic. -// -// Example 1: -// -// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> -// -// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. -// -// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] -// : memref<100x100xf32> -// -class AffineStoreOp : public Op::Impl> { -public: - using Op::Op; - - /// Builds an affine store operation with the specified map and operands. - static void build(Builder *builder, OperationState *result, - Value *valueToStore, AffineMap map, - ArrayRef operands); - /// Builds an affine store operation with an identity map and operands. - static void build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef operands); - - /// Get value to be stored by store operation. - Value *getValueToStore() { return getOperand(0); } - - /// Returns the operand index of the memref. - unsigned getMemRefOperandIndex() { return 1; } - - /// Get memref operand. - Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } - - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - /// Get affine map operands. - operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); } - - /// Returns the affine map used to index the memref for this operation. - AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } - AffineMapAttr getAffineMapAttr() { - return getAttr(getMapAttrName()).cast(); - } - - /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { - assert(memref == getMemRef()); - return {Identifier::get(getMapAttrName(), getContext()), - getAffineMapAttr()}; - } - - static StringRef getMapAttrName() { return "map"; } - static StringRef getOperationName() { return "affine.store"; } - - // Hooks to customize behavior of this op. - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// Returns true if the given Value can be used as a dimension id. 
-bool isValidDim(Value *value); - -/// Returns true if the given Value can be used as a symbol. -bool isValidSymbol(Value *value); - -/// Modifies both `map` and `operands` in-place so as to: -/// 1. drop duplicate operands -/// 2. drop unused dims and symbols from map -void canonicalizeMapAndOperands(AffineMap *map, - llvm::SmallVectorImpl *operands); - -/// Returns a composed AffineApplyOp by composing `map` and `operands` with -/// other AffineApplyOps supplying those operands. The operands of the resulting -/// AffineApplyOp do not change the length of AffineApplyOp chains. -AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - llvm::ArrayRef operands); - -/// Given an affine map `map` and its input `operands`, this method composes -/// into `map`, maps of AffineApplyOps whose results are the values in -/// `operands`, iteratively until no more of `operands` are the result of an -/// AffineApplyOp. When this function returns, `map` becomes the composed affine -/// map, and each Value in `operands` is guaranteed to be either a loop IV or a -/// terminal symbol, i.e., a symbol defined at the top level or a block/function -/// argument. -void fullyComposeAffineMapAndOperands(AffineMap *map, - llvm::SmallVectorImpl *operands); - -#define GET_OP_CLASSES -#include "mlir/AffineOps/AffineOps.h.inc" - -/// Returns if the provided value is the induction variable of a AffineForOp. -bool isForInductionVar(Value *val); - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -AffineForOp getForInductionVarOwner(Value *val); - -/// Extracts the induction variables from a list of AffineForOps and places them -/// in the output argument `ivs`. -void extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs); - -/// AffineBound represents a lower or upper bound in the for operation. -/// This class does not own the underlying operands. Instead, it refers -/// to the operands stored in the AffineForOp. Its life span should not exceed -/// that of the for operation it refers to. -class AffineBound { -public: - AffineForOp getAffineForOp() { return op; } - AffineMap getMap() { return map; } - - /// Returns an AffineValueMap representing this bound. - AffineValueMap getAsAffineValueMap(); - - unsigned getNumOperands() { return opEnd - opStart; } - Value *getOperand(unsigned idx) { - return op.getOperation()->getOperand(opStart + idx); - } - - using operand_iterator = AffineForOp::operand_iterator; - using operand_range = AffineForOp::operand_range; - - operand_iterator operand_begin() { return op.operand_begin() + opStart; } - operand_iterator operand_end() { return op.operand_begin() + opEnd; } - operand_range getOperands() { return {operand_begin(), operand_end()}; } - -private: - // 'affine.for' operation that contains this bound. - AffineForOp op; - // Start and end positions of this affine bound operands in the list of - // the containing 'affine.for' operation operands. - unsigned opStart, opEnd; - // Affine map for this bound. 
- AffineMap map; - - AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) - : op(op), opStart(opStart), opEnd(opEnd), map(map) {} - - friend class AffineForOp; -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/AffineOps/AffineOps.td b/mlir/include/mlir/AffineOps/AffineOps.td deleted file mode 100644 index c517ed0244d..00000000000 --- a/mlir/include/mlir/AffineOps/AffineOps.td +++ /dev/null @@ -1,259 +0,0 @@ -//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Defines MLIR affine operations. -// -//===----------------------------------------------------------------------===// - -#ifdef AFFINE_OPS -#else -#define AFFINE_OPS - -#ifdef OP_BASE -#else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -include "mlir/AffineOps/AffineOpsBase.td" - -def Affine_Dialect : Dialect { - let name = "affine"; - let cppNamespace = ""; -} - -// Base class for Affine dialect ops. -class Affine_Op traits = []> : - Op { - // For every affine op, there needs to be a: - // * void print(OpAsmPrinter *p, ${C++ class of Op} op) - // * LogicalResult verify(${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser *parser, - // OperationState *result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} - -// Require regions to have affine terminator. -def ImplicitAffineTerminator - : SingleBlockImplicitTerminator<"AffineTerminatorOp">; - -def AffineForOp : Affine_Op<"for", [ImplicitAffineTerminator]> { - let summary = "for operation"; - let description = [{ - The "affine.for" operation represents an affine loop nest, defining an SSA - value for its induction variable. It has one region capturing the loop body. - The induction variable is represented as a argument of this region. This SSA - value always has type index, which is the size of the machine word. The - stride, represented by step, is a positive constant integer which defaults - to "1" if not present. The lower and upper bounds specify a half-open range: - the range includes the lower bound but does not include the upper bound. - - The body region must contain exactly one block that terminates with - "affine.terminator". Calling AffineForOp::build will create such region - and insert the terminator, so will the parsing even in cases if it is absent - from the custom format. - - The lower and upper bounds of a for operation are represented as an - application of an affine mapping to a list of SSA values passed to the map. - The same restrictions hold for these SSA values as for all bindings of SSA - values to dimensions and symbols. 
The affine mappings for the bounds may - return multiple results, in which case the max/min keywords are required - (for the lower/upper bound respectively), and the bound is the - maximum/minimum of the returned values. - - Example: - - affine.for %i = 1 to 10 { - ... - } - - }]; - let arguments = (ins Variadic); - let regions = (region SizedRegion<1>:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState *result, " - "int64_t lowerBound, int64_t upperBound, int64_t step = 1">, - OpBuilder<"Builder *builder, OperationState *result, " - "ArrayRef lbOperands, AffineMap lbMap, " - "ArrayRef ubOperands, AffineMap ubMap, " - "int64_t step = 1"> - ]; - - let extraClassDeclaration = [{ - static StringRef getStepAttrName() { return "step"; } - static StringRef getLowerBoundAttrName() { return "lower_bound"; } - static StringRef getUpperBoundAttrName() { return "upper_bound"; } - - Block *getBody() { return ®ion().front(); } - Value *getInductionVar() { return getBody()->getArgument(0); } - OpBuilder getBodyBuilder() { - return OpBuilder(getBody(), std::prev(getBody()->end())); - } - - // TODO: provide iterators for the lower and upper bound operands - // if the current access via getLowerBound(), getUpperBound() is too slow. - - /// Returns operands for the lower bound map. - operand_range getLowerBoundOperands(); - - /// Returns operands for the upper bound map. - operand_range getUpperBoundOperands(); - - /// Returns information about the lower bound as a single object. - AffineBound getLowerBound(); - - /// Returns information about the upper bound as a single object. - AffineBound getUpperBound(); - - /// Returns loop step. - int64_t getStep() { - return getAttr(getStepAttrName()).cast().getInt(); - } - - /// Returns affine map for the lower bound. - AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); } - AffineMapAttr getLowerBoundMapAttr() { - return getAttr(getLowerBoundAttrName()).cast(); - } - /// Returns affine map for the upper bound. The upper bound is exclusive. - AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); } - AffineMapAttr getUpperBoundMapAttr() { - return getAttr(getUpperBoundAttrName()).cast(); - } - - /// Set lower bound. The new bound must have the same number of operands as - /// the current bound map. Otherwise, 'replaceForLowerBound' should be used. - void setLowerBound(ArrayRef operands, AffineMap map); - /// Set upper bound. The new bound must not have more operands than the - /// current bound map. Otherwise, 'replaceForUpperBound' should be used. - void setUpperBound(ArrayRef operands, AffineMap map); - - /// Set the lower bound map without changing operands. - void setLowerBoundMap(AffineMap map); - - /// Set the upper bound map without changing operands. - void setUpperBoundMap(AffineMap map); - - /// Set loop step. - void setStep(int64_t step) { - assert(step > 0 && "step has to be a positive integer constant"); - auto *context = getLowerBoundMap().getContext(); - setAttr(Identifier::get(getStepAttrName(), context), - IntegerAttr::get(IndexType::get(context), step)); - } - - /// Returns true if the lower bound is constant. - bool hasConstantLowerBound(); - /// Returns true if the upper bound is constant. - bool hasConstantUpperBound(); - /// Returns true if both bounds are constant. - bool hasConstantBounds() { - return hasConstantLowerBound() && hasConstantUpperBound(); - } - /// Returns the value of the constant lower bound. 
- /// Fails assertion if the bound is non-constant. - int64_t getConstantLowerBound(); - /// Returns the value of the constant upper bound. The upper bound is - /// exclusive. Fails assertion if the bound is non-constant. - int64_t getConstantUpperBound(); - /// Sets the lower bound to the given constant value. - void setConstantLowerBound(int64_t value); - /// Sets the upper bound to the given constant value. - void setConstantUpperBound(int64_t value); - - /// Returns true if both the lower and upper bound have the same operand - /// lists (same operands in the same order). - bool matchingBoundOperandList(); - }]; - - let hasCanonicalizer = 1; -} - -def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { - let summary = "if-then-else operation"; - let description = [{ - The "if" operation represents an if-then-else construct for conditionally - executing two regions of code. The operands to an if operation are an - IntegerSet condition and a set of symbol/dimension operands to the - condition set. The operation produces no results. For example: - - affine.if #set(%i) { - ... - } else { - ... - } - - The 'else' blocks to the if operation are optional, and may be omitted. For - example: - - affine.if #set(%i) { - ... - } - }]; - let arguments = (ins Variadic); - let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState *result, " - "Value *cond, bool withElseRegion"> - ]; - - let extraClassDeclaration = [{ - static StringRef getConditionAttrName() { return "condition"; } - - IntegerSet getIntegerSet(); - void setIntegerSet(IntegerSet newSet); - - OpBuilder getThenBodyBuilder() { - assert(!thenRegion().empty() && "Unexpected empty 'then' region."); - Block &body = thenRegion().front(); - return OpBuilder(&body, std::prev(body.end())); - } - OpBuilder getElseBodyBuilder() { - assert(!elseRegion().empty() && "Unexpected empty 'else' region."); - Block &body = elseRegion().front(); - return OpBuilder(&body, std::prev(body.end())); - } - }]; -} - -def AffineTerminatorOp : - Affine_Op<"terminator", [Terminator]> { - let summary = "affine terminator operation"; - let description = [{ - Affine terminator is a special terminator operation for blocks inside affine - loops and branches. It unconditionally transmits the control flow to the - successor of the operation enclosing the region. - - This operation does _not_ have a custom syntax. However, affine control - operations omit the terminator in their custom syntax for brevity. - }]; - - // No custom parsing/printing form. - let parser = ?; - let printer = ?; - - // Fully specified by traits. - let verifier = ?; -} - -#endif // AFFINE_OPS diff --git a/mlir/include/mlir/AffineOps/AffineOpsBase.td b/mlir/include/mlir/AffineOps/AffineOpsBase.td deleted file mode 100644 index 2ac1d379c12..00000000000 --- a/mlir/include/mlir/AffineOps/AffineOpsBase.td +++ /dev/null @@ -1,44 +0,0 @@ -//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Defines base support for MLIR affine operations. -// -//===----------------------------------------------------------------------===// - -#ifdef AFFINE_OPS_BASE -#else -#define AFFINE_OPS_BASE - -#ifdef OP_BASE -#else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -// Attributes containing affine maps. -def AffineMapAttr : Attr< - CPred<"$_self.isa()">, "AffineMap attribute"> { - let storageType = [{ AffineMapAttr }]; - let returnType = [{ AffineMap }]; - let constBuilderCall = "$_builder.getAffineMapAttr($0)"; -} - -def AffineMapArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; -} - -#endif // AFFINE_OPS_BASE diff --git a/mlir/include/mlir/AffineOps/CMakeLists.txt b/mlir/include/mlir/AffineOps/CMakeLists.txt deleted file mode 100644 index 6c5a58c957b..00000000000 --- a/mlir/include/mlir/AffineOps/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS AffineOps.td) -mlir_tablegen(AffineOps.h.inc -gen-op-decls) -mlir_tablegen(AffineOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRAffineOpsIncGen) diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt index 043db03641f..b393ea2c0e8 100644 --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,3 +1,2 @@ -add_subdirectory(AffineOps) add_subdirectory(Dialect) add_subdirectory(EDSC) diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h new file mode 100644 index 00000000000..a6af20eca0b --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -0,0 +1,598 @@ +//===- AffineOps.h - MLIR Affine Operations -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with Affine operations +// in the MLIR operation set. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H +#define MLIR_DIALECT_AFFINEOPS_AFFINEOPS_H + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +class AffineBound; +class AffineValueMap; +class AffineTerminatorOp; +class FlatAffineConstraints; +class OpBuilder; + +/// A utility function to check if a value is defined at the top level of a +/// function. A value defined at the top level is always a valid symbol. +bool isTopLevelSymbol(Value *value); + +class AffineOpsDialect : public Dialect { +public: + AffineOpsDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "affine"; } +}; + +/// The "affine.apply" operation applies an affine map to a list of operands, +/// yielding a single result. The operand list must be the same size as the +/// number of arguments to the affine mapping. All operands and the result are +/// of type 'Index'. This operation requires a single affine map attribute named +/// "map". For example: +/// +/// %y = "affine.apply" (%x) { map: (d0) -> (d0 + 1) } : +/// (index) -> (index) +/// +/// equivalently: +/// +/// #map42 = (d0)->(d0+1) +/// %y = affine.apply #map42(%x) +/// +class AffineApplyOp : public Op { +public: + using Op::Op; + + /// Builds an affine apply op with the specified map and operands. + static void build(Builder *builder, OperationState *result, AffineMap map, + ArrayRef operands); + + /// Returns the affine map to be applied by this operation. + AffineMap getAffineMap() { + return getAttrOfType("map").getValue(); + } + + /// Returns true if the result of this operation can be used as dimension id. + bool isValidDim(); + + /// Returns true if the result of this operation is a symbol. + bool isValidSymbol(); + + static StringRef getOperationName() { return "affine.apply"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + OpFoldResult fold(ArrayRef operands); + + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data +/// from a source memref to a destination memref. The source and destination +/// memref need not be of the same dimensionality, but need to have the same +/// elemental type. The operands include the source and destination memref's +/// each followed by its indices, size of the data transfer in terms of the +/// number of elements (of the elemental type of the memref), a tag memref with +/// its indices, and optionally at the end, a stride and a +/// number_of_elements_per_stride arguments. The tag location is used by an +/// AffineDmaWaitOp to check for completion. The indices of the source memref, +/// destination memref, and the tag memref have the same restrictions as any +/// affine.load/store. In particular, index for each memref dimension must be an +/// affine expression of loop induction variables and symbols. +/// The optional stride arguments should be of 'index' type, and specify a +/// stride for the slower memory space (memory space with a lower memory space +/// id), tranferring chunks of number_of_elements_per_stride every stride until +/// %num_elements are transferred. 
Either both or no stride arguments should be +/// specified. The value of 'num_elements' must be a multiple of +/// 'number_of_elements_per_stride'. +// +// For example, a DmaStartOp operation that transfers 256 elements of a memref +// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory +// space 1 at indices [%k + 7, %l], would be specified as follows: +// +// %num_elements = constant 256 +// %idx = constant 0 : index +// %tag = alloc() : memref<1xi32, 4> +// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], +// %num_elements : +// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> +// +// If %stride and %num_elt_per_stride are specified, the DMA is expected to +// transfer %num_elt_per_stride elements every %stride elements apart from +// memory space 0 until %num_elements are transferred. +// +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, +// %stride, %num_elt_per_stride : ... +// +// TODO(mlir-team): add additional operands to allow source and destination +// striding, and multiple stride levels (possibly using AffineMaps to specify +// multiple levels of striding). +// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. +class AffineDmaStartOp : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState *result, Value *srcMemRef, + AffineMap srcMap, ArrayRef srcIndices, + Value *destMemRef, AffineMap dstMap, + ArrayRef destIndices, Value *tagMemRef, + AffineMap tagMap, ArrayRef tagIndices, + Value *numElements, Value *stride = nullptr, + Value *elementsPerStride = nullptr); + + /// Returns the operand index of the src memref. + unsigned getSrcMemRefOperandIndex() { return 0; } + + /// Returns the source MemRefType for this DMA operation. + Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } + MemRefType getSrcMemRefType() { + return getSrcMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the source MemRefType. + unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } + + /// Returns the affine map used to access the src memref. + AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } + AffineMapAttr getSrcMapAttr() { + return getAttr(getSrcMapAttrName()).cast(); + } + + /// Returns the source memref affine map indices for this DMA operation. + operand_range getSrcIndices() { + return {operand_begin() + getSrcMemRefOperandIndex() + 1, + operand_begin() + getSrcMemRefOperandIndex() + 1 + + getSrcMap().getNumInputs()}; + } + + /// Returns the memory space of the src memref. + unsigned getSrcMemorySpace() { + return getSrcMemRef()->getType().cast().getMemorySpace(); + } + + /// Returns the operand index of the dst memref. + unsigned getDstMemRefOperandIndex() { + return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); + } + + /// Returns the destination MemRefType for this DMA operations. + Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } + MemRefType getDstMemRefType() { + return getDstMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef()->getType().cast().getRank(); + } + + /// Returns the memory space of the src memref. + unsigned getDstMemorySpace() { + return getDstMemRef()->getType().cast().getMemorySpace(); + } + + /// Returns the affine map used to access the dst memref. 
+ AffineMap getDstMap() { return getDstMapAttr().getValue(); } + AffineMapAttr getDstMapAttr() { + return getAttr(getDstMapAttrName()).cast(); + } + + /// Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {operand_begin() + getDstMemRefOperandIndex() + 1, + operand_begin() + getDstMemRefOperandIndex() + 1 + + getDstMap().getNumInputs()}; + } + + /// Returns the operand index of the tag memref. + unsigned getTagMemRefOperandIndex() { + return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); + } + + /// Returns the Tag MemRef for this DMA operation. + Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast(); + } + + /// Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast(); + } + + /// Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { + return {operand_begin() + getTagMemRefOperandIndex() + 1, + operand_begin() + getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()}; + } + + /// Returns the number of elements being transferred by this DMA operation. + Value *getNumElements() { + return getOperand(getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + if (memref == getSrcMemRef()) + return {Identifier::get(getSrcMapAttrName(), getContext()), + getSrcMapAttr()}; + else if (memref == getDstMemRef()) + return {Identifier::get(getDstMapAttrName(), getContext()), + getDstMapAttr()}; + assert(memref == getTagMemRef() && + "DmaStartOp expected source, destination or tag memref"); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. + bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); + } + + static StringRef getSrcMapAttrName() { return "src_map"; } + static StringRef getDstMapAttrName() { return "dst_map"; } + static StringRef getTagMapAttrName() { return "tag_map"; } + + static StringRef getOperationName() { return "affine.dma_start"; } + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); + + /// Returns true if this DMA operation is strided, returns false otherwise. 
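The operand accessors earlier in this class all reduce to index arithmetic over one flat operand list, and the strided check below simply compares against the expected non-strided operand count. As a rough illustration of that layout, here is a standalone sketch (hypothetical ranks, not the op's actual implementation) of how the operand count falls out:

  #include <cstdio>

  // Mirrors the operand layout used by AffineDmaStartOp:
  //   [src memref][src indices][dst memref][dst indices]
  //   [tag memref][tag indices][num_elements][stride][elts_per_stride]?
  // where each index count is the number of inputs of the respective map.
  static unsigned dmaStartOperandCount(unsigned numSrcIndices,
                                       unsigned numDstIndices,
                                       unsigned numTagIndices, bool strided) {
    unsigned srcMemRefPos = 0;
    unsigned dstMemRefPos = srcMemRefPos + 1 + numSrcIndices;
    unsigned tagMemRefPos = dstMemRefPos + 1 + numDstIndices;
    unsigned numElementsPos = tagMemRefPos + 1 + numTagIndices;
    // A strided transfer carries two extra trailing operands.
    return numElementsPos + 1 + (strided ? 2 : 0);
  }

  int main() {
    // Hypothetical 2-d src, 2-d dst, 1-d tag, non-strided: 9 operands.
    std::printf("%u operands\n", dmaStartOperandCount(2, 2, 1, /*strided=*/false));
    // Same transfer with stride and elements-per-stride operands: 11.
    std::printf("%u operands\n", dmaStartOperandCount(2, 2, 1, /*strided=*/true));
  }

A strided op therefore has an operand count that differs from the non-strided total by exactly two, which is what the check below detects.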
+ bool isStrided() { + return getNumOperands() != + getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; + } + + /// Returns the stride value for this DMA operation. + Value *getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1 - 1); + } + + /// Returns the number of elements to transfer per stride for this DMA op. + Value *getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } +}; + +/// AffineDmaWaitOp blocks until the completion of a DMA operation associated +/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be +/// an index with the same restrictions as any load/store index. In particular, +/// index for each memref dimension must be an affine expression of loop +/// induction variables and symbols. %num_elements is the number of elements +/// associated with the DMA operation. For example: +// +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : +// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> +// ... +// ... +// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> +// +class AffineDmaWaitOp : public Op { +public: + using Op::Op; + + static void build(Builder *builder, OperationState *result, Value *tagMemRef, + AffineMap tagMap, ArrayRef tagIndices, + Value *numElements); + + static StringRef getOperationName() { return "affine.dma_wait"; } + + // Returns the Tag MemRef associated with the DMA operation being waited on. + Value *getTagMemRef() { return getOperand(0); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast(); + } + + // Returns the tag memref index for this DMA operation. + operand_range getTagIndices() { + return {operand_begin() + 1, + operand_begin() + 1 + getTagMap().getNumInputs()}; + } + + // Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast().getRank(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getTagMemRef()); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns the number of elements transferred in the associated DMA op. + Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } + + static StringRef getTagMapAttrName() { return "tag_map"; } + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// The "affine.load" op reads an element from a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The output of 'affine.load' is a new value with the +/// same type as the elements of the memref. An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. 
+// +// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineLoadOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine load op with the specified map and operands. + static void build(Builder *builder, OperationState *result, AffineMap map, + ArrayRef operands); + /// Builds an affine load op an identify map and operands. + static void build(Builder *builder, OperationState *result, Value *memref, + ArrayRef indices = {}); + + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 0; } + + /// Get memref operand. + Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. + operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } + static StringRef getOperationName() { return "affine.load"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// The "affine.store" op writes an element to a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The 'affine.store' op stores a new value which is the +/// same type as the elements of the memref. An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. +// +// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineStoreOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine store operation with the specified map and operands. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, AffineMap map, + ArrayRef operands); + /// Builds an affine store operation with an identity map and operands. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef operands); + + /// Get value to be stored by store operation. + Value *getValueToStore() { return getOperand(0); } + + /// Returns the operand index of the memref. + unsigned getMemRefOperandIndex() { return 1; } + + /// Get memref operand. + Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } + + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. 
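For affine.store the affine map operands begin only after the stored value and the memref, so the accessor below skips the first two operands, whereas affine.load above skips only the memref. A small standalone sketch of that slicing over a hypothetical operand list (the SSA names are made up for illustration):

  #include <cstdio>
  #include <string>
  #include <vector>

  int main() {
    // Hypothetical operand list for: affine.store %v0, %A[%i0 + 3, %i1 + 7]
    std::vector<std::string> operands = {"%v0", "%A", "%i0", "%i1"};
    // Operand 0 is the value to store, operand 1 is the memref, and the
    // remaining operands feed the affine map (one per map input).
    for (size_t i = 2, e = operands.size(); i < e; ++i)
      std::printf("map operand: %s\n", operands[i].c_str());
  }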
+ operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } + static StringRef getOperationName() { return "affine.store"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + static void getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context); +}; + +/// Returns true if the given Value can be used as a dimension id. +bool isValidDim(Value *value); + +/// Returns true if the given Value can be used as a symbol. +bool isValidSymbol(Value *value); + +/// Modifies both `map` and `operands` in-place so as to: +/// 1. drop duplicate operands +/// 2. drop unused dims and symbols from map +void canonicalizeMapAndOperands(AffineMap *map, + llvm::SmallVectorImpl *operands); + +/// Returns a composed AffineApplyOp by composing `map` and `operands` with +/// other AffineApplyOps supplying those operands. The operands of the resulting +/// AffineApplyOp do not change the length of AffineApplyOp chains. +AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, + llvm::ArrayRef operands); + +/// Given an affine map `map` and its input `operands`, this method composes +/// into `map`, maps of AffineApplyOps whose results are the values in +/// `operands`, iteratively until no more of `operands` are the result of an +/// AffineApplyOp. When this function returns, `map` becomes the composed affine +/// map, and each Value in `operands` is guaranteed to be either a loop IV or a +/// terminal symbol, i.e., a symbol defined at the top level or a block/function +/// argument. +void fullyComposeAffineMapAndOperands(AffineMap *map, + llvm::SmallVectorImpl *operands); + +#define GET_OP_CLASSES +#include "mlir/Dialect/AffineOps/AffineOps.h.inc" + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool isForInductionVar(Value *val); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +AffineForOp getForInductionVarOwner(Value *val); + +/// Extracts the induction variables from a list of AffineForOps and places them +/// in the output argument `ivs`. +void extractForInductionVars(ArrayRef forInsts, + SmallVectorImpl *ivs); + +/// AffineBound represents a lower or upper bound in the for operation. +/// This class does not own the underlying operands. Instead, it refers +/// to the operands stored in the AffineForOp. Its life span should not exceed +/// that of the for operation it refers to. +class AffineBound { +public: + AffineForOp getAffineForOp() { return op; } + AffineMap getMap() { return map; } + + /// Returns an AffineValueMap representing this bound. 
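A bound is effectively a half-open window [opStart, opEnd) into the parent affine.for operation's operand list, which stores the lower-bound operands followed by the upper-bound operands (see the private members and the AffineForOp builders later in this patch). A standalone sketch of that slicing, with made-up operand names and window positions chosen purely for illustration:

  #include <cstdio>
  #include <string>
  #include <vector>

  // A bound is a half-open window [opStart, opEnd) over the for op's operands.
  struct BoundView {
    unsigned opStart, opEnd;
  };

  int main() {
    // Hypothetical affine.for with a two-operand lower bound map and a
    // one-operand upper bound map: lb operands come first, then ub operands.
    std::vector<std::string> forOperands = {"%a", "%b", "%c"};
    BoundView lower{0, 2}, upper{2, 3};
    for (unsigned i = lower.opStart; i < lower.opEnd; ++i)
      std::printf("lower bound operand: %s\n", forOperands[i].c_str());
    for (unsigned i = upper.opStart; i < upper.opEnd; ++i)
      std::printf("upper bound operand: %s\n", forOperands[i].c_str());
  }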
+ AffineValueMap getAsAffineValueMap(); + + unsigned getNumOperands() { return opEnd - opStart; } + Value *getOperand(unsigned idx) { + return op.getOperation()->getOperand(opStart + idx); + } + + using operand_iterator = AffineForOp::operand_iterator; + using operand_range = AffineForOp::operand_range; + + operand_iterator operand_begin() { return op.operand_begin() + opStart; } + operand_iterator operand_end() { return op.operand_begin() + opEnd; } + operand_range getOperands() { return {operand_begin(), operand_end()}; } + +private: + // 'affine.for' operation that contains this bound. + AffineForOp op; + // Start and end positions of this affine bound operands in the list of + // the containing 'affine.for' operation operands. + unsigned opStart, opEnd; + // Affine map for this bound. + AffineMap map; + + AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) + : op(op), opStart(opStart), opEnd(opEnd), map(map) {} + + friend class AffineForOp; +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td new file mode 100644 index 00000000000..237692c04a7 --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -0,0 +1,259 @@ +//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines MLIR affine operations. +// +//===----------------------------------------------------------------------===// + +#ifdef AFFINE_OPS +#else +#define AFFINE_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +include "mlir/Dialect/AffineOps/AffineOpsBase.td" + +def Affine_Dialect : Dialect { + let name = "affine"; + let cppNamespace = ""; +} + +// Base class for Affine dialect ops. +class Affine_Op traits = []> : + Op { + // For every affine op, there needs to be a: + // * void print(OpAsmPrinter *p, ${C++ class of Op} op) + // * LogicalResult verify(${C++ class of Op} op) + // * ParseResult parse${C++ class of Op}(OpAsmParser *parser, + // OperationState *result) + // functions. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +// Require regions to have affine terminator. +def ImplicitAffineTerminator + : SingleBlockImplicitTerminator<"AffineTerminatorOp">; + +def AffineForOp : Affine_Op<"for", [ImplicitAffineTerminator]> { + let summary = "for operation"; + let description = [{ + The "affine.for" operation represents an affine loop nest, defining an SSA + value for its induction variable. It has one region capturing the loop body. + The induction variable is represented as a argument of this region. This SSA + value always has type index, which is the size of the machine word. 
The + stride, represented by step, is a positive constant integer which defaults + to "1" if not present. The lower and upper bounds specify a half-open range: + the range includes the lower bound but does not include the upper bound. + + The body region must contain exactly one block that terminates with + "affine.terminator". Calling AffineForOp::build will create such region + and insert the terminator, so will the parsing even in cases if it is absent + from the custom format. + + The lower and upper bounds of a for operation are represented as an + application of an affine mapping to a list of SSA values passed to the map. + The same restrictions hold for these SSA values as for all bindings of SSA + values to dimensions and symbols. The affine mappings for the bounds may + return multiple results, in which case the max/min keywords are required + (for the lower/upper bound respectively), and the bound is the + maximum/minimum of the returned values. + + Example: + + affine.for %i = 1 to 10 { + ... + } + + }]; + let arguments = (ins Variadic); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, " + "int64_t lowerBound, int64_t upperBound, int64_t step = 1">, + OpBuilder<"Builder *builder, OperationState *result, " + "ArrayRef lbOperands, AffineMap lbMap, " + "ArrayRef ubOperands, AffineMap ubMap, " + "int64_t step = 1"> + ]; + + let extraClassDeclaration = [{ + static StringRef getStepAttrName() { return "step"; } + static StringRef getLowerBoundAttrName() { return "lower_bound"; } + static StringRef getUpperBoundAttrName() { return "upper_bound"; } + + Block *getBody() { return ®ion().front(); } + Value *getInductionVar() { return getBody()->getArgument(0); } + OpBuilder getBodyBuilder() { + return OpBuilder(getBody(), std::prev(getBody()->end())); + } + + // TODO: provide iterators for the lower and upper bound operands + // if the current access via getLowerBound(), getUpperBound() is too slow. + + /// Returns operands for the lower bound map. + operand_range getLowerBoundOperands(); + + /// Returns operands for the upper bound map. + operand_range getUpperBoundOperands(); + + /// Returns information about the lower bound as a single object. + AffineBound getLowerBound(); + + /// Returns information about the upper bound as a single object. + AffineBound getUpperBound(); + + /// Returns loop step. + int64_t getStep() { + return getAttr(getStepAttrName()).cast().getInt(); + } + + /// Returns affine map for the lower bound. + AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); } + AffineMapAttr getLowerBoundMapAttr() { + return getAttr(getLowerBoundAttrName()).cast(); + } + /// Returns affine map for the upper bound. The upper bound is exclusive. + AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); } + AffineMapAttr getUpperBoundMapAttr() { + return getAttr(getUpperBoundAttrName()).cast(); + } + + /// Set lower bound. The new bound must have the same number of operands as + /// the current bound map. Otherwise, 'replaceForLowerBound' should be used. + void setLowerBound(ArrayRef operands, AffineMap map); + /// Set upper bound. The new bound must not have more operands than the + /// current bound map. Otherwise, 'replaceForUpperBound' should be used. + void setUpperBound(ArrayRef operands, AffineMap map); + + /// Set the lower bound map without changing operands. 
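The constant-bound and step helpers declared just below compose into simple transformations. A minimal usage sketch, assuming this header is included and that forOp was obtained elsewhere (for instance via getForInductionVarOwner); this is not a standalone program, only an illustration of the declared API:

  // Doubles the trip count of a loop whose bounds are known constants.
  void doubleConstantTripCount(AffineForOp forOp) {
    if (!forOp.hasConstantBounds())
      return;
    int64_t lb = forOp.getConstantLowerBound();
    int64_t ub = forOp.getConstantUpperBound(); // exclusive
    // Extend the half-open range [lb, ub) to [lb, lb + 2 * (ub - lb)).
    forOp.setConstantUpperBound(lb + 2 * (ub - lb));
  }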
+ void setLowerBoundMap(AffineMap map); + + /// Set the upper bound map without changing operands. + void setUpperBoundMap(AffineMap map); + + /// Set loop step. + void setStep(int64_t step) { + assert(step > 0 && "step has to be a positive integer constant"); + auto *context = getLowerBoundMap().getContext(); + setAttr(Identifier::get(getStepAttrName(), context), + IntegerAttr::get(IndexType::get(context), step)); + } + + /// Returns true if the lower bound is constant. + bool hasConstantLowerBound(); + /// Returns true if the upper bound is constant. + bool hasConstantUpperBound(); + /// Returns true if both bounds are constant. + bool hasConstantBounds() { + return hasConstantLowerBound() && hasConstantUpperBound(); + } + /// Returns the value of the constant lower bound. + /// Fails assertion if the bound is non-constant. + int64_t getConstantLowerBound(); + /// Returns the value of the constant upper bound. The upper bound is + /// exclusive. Fails assertion if the bound is non-constant. + int64_t getConstantUpperBound(); + /// Sets the lower bound to the given constant value. + void setConstantLowerBound(int64_t value); + /// Sets the upper bound to the given constant value. + void setConstantUpperBound(int64_t value); + + /// Returns true if both the lower and upper bound have the same operand + /// lists (same operands in the same order). + bool matchingBoundOperandList(); + }]; + + let hasCanonicalizer = 1; +} + +def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { + let summary = "if-then-else operation"; + let description = [{ + The "if" operation represents an if-then-else construct for conditionally + executing two regions of code. The operands to an if operation are an + IntegerSet condition and a set of symbol/dimension operands to the + condition set. The operation produces no results. For example: + + affine.if #set(%i) { + ... + } else { + ... + } + + The 'else' blocks to the if operation are optional, and may be omitted. For + example: + + affine.if #set(%i) { + ... + } + }]; + let arguments = (ins Variadic); + let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, " + "Value *cond, bool withElseRegion"> + ]; + + let extraClassDeclaration = [{ + static StringRef getConditionAttrName() { return "condition"; } + + IntegerSet getIntegerSet(); + void setIntegerSet(IntegerSet newSet); + + OpBuilder getThenBodyBuilder() { + assert(!thenRegion().empty() && "Unexpected empty 'then' region."); + Block &body = thenRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + OpBuilder getElseBodyBuilder() { + assert(!elseRegion().empty() && "Unexpected empty 'else' region."); + Block &body = elseRegion().front(); + return OpBuilder(&body, std::prev(body.end())); + } + }]; +} + +def AffineTerminatorOp : + Affine_Op<"terminator", [Terminator]> { + let summary = "affine terminator operation"; + let description = [{ + Affine terminator is a special terminator operation for blocks inside affine + loops and branches. It unconditionally transmits the control flow to the + successor of the operation enclosing the region. + + This operation does _not_ have a custom syntax. However, affine control + operations omit the terminator in their custom syntax for brevity. + }]; + + // No custom parsing/printing form. + let parser = ?; + let printer = ?; + + // Fully specified by traits. 
+ let verifier = ?; +} + +#endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td new file mode 100644 index 00000000000..2ac1d379c12 --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td @@ -0,0 +1,44 @@ +//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines base support for MLIR affine operations. +// +//===----------------------------------------------------------------------===// + +#ifdef AFFINE_OPS_BASE +#else +#define AFFINE_OPS_BASE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +// Attributes containing affine maps. +def AffineMapAttr : Attr< + CPred<"$_self.isa()">, "AffineMap attribute"> { + let storageType = [{ AffineMapAttr }]; + let returnType = [{ AffineMap }]; + let constBuilderCall = "$_builder.getAffineMapAttr($0)"; +} + +def AffineMapArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; +} + +#endif // AFFINE_OPS_BASE diff --git a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt new file mode 100644 index 00000000000..6c5a58c957b --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS AffineOps.td) +mlir_tablegen(AffineOps.h.inc -gen-op-decls) +mlir_tablegen(AffineOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRAffineOpsIncGen) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index ce53bfc9a57..9235436995a 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(AffineOps) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 29977c1c637..47d30cf2836 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -24,7 +24,7 @@ #else #define LINALG_LIBRARY_OPS -include "mlir/AffineOps/AffineOpsBase.td" +include "mlir/Dialect/AffineOps/AffineOpsBase.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" class LinalgParametricNativeOpTrait : diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 29e2e9e1ea7..51c5c331fe9 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -23,7 +23,7 @@ #ifndef MLIR_EDSC_BUILDERS_H_ #define MLIR_EDSC_BUILDERS_H_ -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Builders.h" diff 
--git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp deleted file mode 100644 index b00a11083ec..00000000000 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ /dev/null @@ -1,1764 +0,0 @@ -//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/Support/Debug.h" -using namespace mlir; -using llvm::dbgs; - -#define DEBUG_TYPE "affine-analysis" - -//===----------------------------------------------------------------------===// -// AffineOpsDialect -//===----------------------------------------------------------------------===// - -AffineOpsDialect::AffineOpsDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { - addOperations(); -} - -/// A utility function to check if a given region is attached to a function. -static bool isFunctionRegion(Region *region) { - return llvm::isa(region->getParentOp()); -} - -/// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(Value *value) { - if (auto *arg = dyn_cast(value)) - return isFunctionRegion(arg->getOwner()->getParent()); - return isFunctionRegion(value->getDefiningOp()->getParentRegion()); -} - -// Value can be used as a dimension id if it is valid as a symbol, or -// it is an induction variable, or it is a result of affine apply operation -// with dimension id arguments. -bool mlir::isValidDim(Value *value) { - // The value must be an index type. - if (!value->getType().isIndex()) - return false; - - if (auto *op = value->getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidDim(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast(op)) - return isTopLevelSymbol(dimOp.getOperand()); - return false; - } - // This value is a block argument (which also includes 'affine.for' loop IVs). - return true; -} - -// Value can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments. -bool mlir::isValidSymbol(Value *value) { - // The value must be an index type. 
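// Classification sketch (illustrative IR, hypothetical names), showing how
// isValidDim/isValidSymbol partition values:
//
//   func @f(%m: memref<?xf32>, %n: index) {
//     %c0 = constant 0 : index
//     affine.for %i = 0 to 10 {
//       %d = dim %m, 0 : memref<?xf32>
//       ...
//     }
//   }
//
// %n (top-level argument), %c0 (constant) and %d (dim of a top-level memref)
// are valid symbols, and hence also valid dims; the induction variable %i is
// a valid dim but not a valid symbol, since it is neither constant nor
// defined at the top level of the function.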
- if (!value->getType().isIndex()) - return false; - - if (auto *op = value->getDefiningOp()) { - // Top level operation or constant operation is ok. - if (isFunctionRegion(op->getParentRegion()) || isa(op)) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = dyn_cast(op)) - return applyOp.isValidSymbol(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast(op)) - return isTopLevelSymbol(dimOp.getOperand()); - return false; - } - // Otherwise, check that the value is a top level symbol. - return isTopLevelSymbol(value); -} - -// Returns true if 'value' is a valid index to an affine operation (e.g. -// affine.load, affine.store, affine.dma_start, affine.dma_wait). -// Returns false otherwise. -static bool isValidAffineIndexOperand(Value *value) { - return isValidDim(value) || isValidSymbol(value); -} - -/// Utility function to verify that a set of operands are valid dimension and -/// symbol identifiers. The operands should be layed out such that the dimension -/// operands are before the symbol operands. This function returns failure if -/// there was an invalid operand. An operation is provided to emit any necessary -/// errors. -template -static LogicalResult -verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, - unsigned numDims) { - unsigned opIt = 0; - for (auto *operand : operands) { - if (opIt++ < numDims) { - if (!isValidDim(operand)) - return op.emitOpError("operand cannot be used as a dimension id"); - } else if (!isValidSymbol(operand)) { - return op.emitOpError("operand cannot be used as a symbol"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// AffineApplyOp -//===----------------------------------------------------------------------===// - -void AffineApplyOp::build(Builder *builder, OperationState *result, - AffineMap map, ArrayRef operands) { - result->addOperands(operands); - result->types.append(map.getNumResults(), builder->getIndexType()); - result->addAttribute("map", builder->getAffineMapAttr(map)); -} - -ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { - auto &builder = parser->getBuilder(); - auto affineIntTy = builder.getIndexType(); - - AffineMapAttr mapAttr; - unsigned numDims; - if (parser->parseAttribute(mapAttr, "map", result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims) || - parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - auto map = mapAttr.getValue(); - - if (map.getNumDims() != numDims || - numDims + map.getNumSymbols() != result->operands.size()) { - return parser->emitError(parser->getNameLoc(), - "dimension or symbol index mismatch"); - } - - result->types.append(map.getNumResults(), affineIntTy); - return success(); -} - -void AffineApplyOp::print(OpAsmPrinter *p) { - *p << "affine.apply " << getAttr("map"); - printDimAndSymbolList(operand_begin(), operand_end(), - getAffineMap().getNumDims(), p); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"}); -} - -LogicalResult AffineApplyOp::verify() { - // Check that affine map attribute was specified. - auto affineMapAttr = getAttrOfType("map"); - if (!affineMapAttr) - return emitOpError("requires an affine map"); - - // Check input and output dimensions match. - auto map = affineMapAttr.getValue(); - - // Verify that operand count matches affine map dimension and symbol count. 
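// Illustrative check (hypothetical values): the map
//   (d0, d1)[s0] -> (d0 + d1 + s0)
// has 2 dims and 1 symbol, so a matching affine.apply must supply exactly
// three 'index' operands, dims first:
//   %r = affine.apply (d0, d1)[s0] -> (d0 + d1 + s0) (%i, %j)[%n]
// Any other operand count, a non-index operand, or a map with more than one
// result is rejected by the checks below.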
- if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) - return emitOpError( - "operand count and affine map dimension and symbol count must match"); - - // Verify that all operands are of `index` type. - for (Type t : getOperandTypes()) { - if (!t.isIndex()) - return emitOpError("operands must be of type 'index'"); - } - - if (!getResult()->getType().isIndex()) - return emitOpError("result must be of type 'index'"); - - // Verify that the operands are valid dimension and symbol identifiers. - if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), - map.getNumDims()))) - return failure(); - - // Verify that the map only produces one result. - if (map.getNumResults() != 1) - return emitOpError("mapping must produce one value"); - - return success(); -} - -// The result of the affine apply operation can be used as a dimension id if it -// is a CFG value or if it is an Value, and all the operands are valid -// dimension ids. -bool AffineApplyOp::isValidDim() { - return llvm::all_of(getOperands(), - [](Value *op) { return mlir::isValidDim(op); }); -} - -// The result of the affine apply operation can be used as a symbol if it is -// a CFG value or if it is an Value, and all the operands are symbols. -bool AffineApplyOp::isValidSymbol() { - return llvm::all_of(getOperands(), - [](Value *op) { return mlir::isValidSymbol(op); }); -} - -OpFoldResult AffineApplyOp::fold(ArrayRef operands) { - auto map = getAffineMap(); - - // Fold dims and symbols to existing values. - auto expr = map.getResult(0); - if (auto dim = expr.dyn_cast()) - return getOperand(dim.getPosition()); - if (auto sym = expr.dyn_cast()) - return getOperand(map.getNumDims() + sym.getPosition()); - - // Otherwise, default to folding the map. - SmallVector result; - if (failed(map.constantFold(operands, result))) - return {}; - return result[0]; -} - -namespace { -/// An `AffineApplyNormalizer` is a helper class that is not visible to the user -/// and supports renumbering operands of AffineApplyOp. This acts as a -/// reindexing map of Value* to positional dims or symbols and allows -/// simplifications such as: -/// -/// ```mlir -/// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) -/// ``` -/// -/// into: -/// -/// ```mlir -/// %1 = affine.apply () -> (0) -/// ``` -struct AffineApplyNormalizer { - AffineApplyNormalizer(AffineMap map, ArrayRef operands); - - /// Returns the AffineMap resulting from normalization. - AffineMap getAffineMap() { return affineMap; } - - SmallVector getOperands() { - SmallVector res(reorderedDims); - res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); - return res; - } - -private: - /// Helper function to insert `v` into the coordinate system of the current - /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding - /// renumbered position. - AffineDimExpr renumberOneDim(Value *v); - - /// Given an `other` normalizer, this rewrites `other.affineMap` in the - /// coordinate system of the current AffineApplyNormalizer. - /// Returns the rewritten AffineMap and updates the dims and symbols of - /// `this`. - AffineMap renumber(const AffineApplyNormalizer &other); - - /// Maps of Value* to position in `affineMap`. - DenseMap dimValueToPosition; - - /// Ordered dims and symbols matching positional dims and symbols in - /// `affineMap`. - SmallVector reorderedDims; - SmallVector concatenatedSymbols; - - AffineMap affineMap; - - /// Used with RAII to control the depth at which AffineApply are composed - /// recursively. 
Only accepts depth 1 for now to allow a behavior where a - /// newly composed AffineApplyOp does not increase the length of the chain of - /// AffineApplyOps. Full composition is implemented iteratively on top of - /// this behavior. - static unsigned &affineApplyDepth() { - static thread_local unsigned depth = 0; - return depth; - } - static constexpr unsigned kMaxAffineApplyDepth = 1; - - AffineApplyNormalizer() { affineApplyDepth()++; } - -public: - ~AffineApplyNormalizer() { affineApplyDepth()--; } -}; -} // end anonymous namespace. - -AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { - DenseMap::iterator iterPos; - bool inserted = false; - std::tie(iterPos, inserted) = - dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); - if (inserted) { - reorderedDims.push_back(v); - } - return getAffineDimExpr(iterPos->second, v->getContext()) - .cast(); -} - -AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { - SmallVector dimRemapping; - for (auto *v : other.reorderedDims) { - auto kvp = other.dimValueToPosition.find(v); - if (dimRemapping.size() <= kvp->second) - dimRemapping.resize(kvp->second + 1); - dimRemapping[kvp->second] = renumberOneDim(kvp->first); - } - unsigned numSymbols = concatenatedSymbols.size(); - unsigned numOtherSymbols = other.concatenatedSymbols.size(); - SmallVector symRemapping(numOtherSymbols); - for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { - symRemapping[idx] = - getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); - } - concatenatedSymbols.insert(concatenatedSymbols.end(), - other.concatenatedSymbols.begin(), - other.concatenatedSymbols.end()); - auto map = other.affineMap; - return map.replaceDimsAndSymbols(dimRemapping, symRemapping, - dimRemapping.size(), symRemapping.size()); -} - -// Gather the positions of the operands that are produced by an AffineApplyOp. -static llvm::SetVector -indicesFromAffineApplyOp(ArrayRef operands) { - llvm::SetVector res; - for (auto en : llvm::enumerate(operands)) - if (isa_and_nonnull(en.value()->getDefiningOp())) - res.insert(en.index()); - return res; -} - -// Support the special case of a symbol coming from an AffineApplyOp that needs -// to be composed into the current AffineApplyOp. -// This case is handled by rewriting all such symbols into dims for the purpose -// of allowing mathematical AffineMap composition. -// Returns an AffineMap where symbols that come from an AffineApplyOp have been -// rewritten as dims and are ordered after the original dims. -// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which -// symbols are represented as dims. This loss is static but can still be -// recovered dynamically (with `isValidSymbol`). Still this is annoying for the -// semi-affine map case. A dynamic canonicalization of all dims that are valid -// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even -// results in better simplifications and foldings. But we should evaluate -// whether this behavior is what we really want after using more. -static AffineMap promoteComposedSymbolsAsDims(AffineMap map, - ArrayRef symbols) { - if (symbols.empty()) { - return map; - } - - // Sanity check on symbols. - for (auto *sym : symbols) { - assert(isValidSymbol(sym) && "Expected only valid symbols"); - (void)sym; - } - - // Extract the symbol positions that come from an AffineApplyOp and - // needs to be rewritten as dims. 
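// Worked example (illustrative): for the map
//   (d0)[s0, s1] -> (d0 + s0 + s1)
// where only the operand bound to s0 is produced by an affine.apply, s0 is
// promoted to a new trailing dim and s1 is renumbered, giving
//   (d0, d1)[s0] -> (d0 + d1 + s0)
// while the operand list (dim operand, promoted-symbol operand,
// remaining-symbol operand) keeps its original order.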
- auto symPositions = indicesFromAffineApplyOp(symbols); - if (symPositions.empty()) { - return map; - } - - // Create the new map by replacing each symbol at pos by the next new dim. - unsigned numDims = map.getNumDims(); - unsigned numSymbols = map.getNumSymbols(); - unsigned numNewDims = 0; - unsigned numNewSymbols = 0; - SmallVector symReplacements(numSymbols); - for (unsigned i = 0; i < numSymbols; ++i) { - symReplacements[i] = - symPositions.count(i) > 0 - ? getAffineDimExpr(numDims + numNewDims++, map.getContext()) - : getAffineSymbolExpr(numNewSymbols++, map.getContext()); - } - assert(numSymbols >= numNewDims); - AffineMap newMap = map.replaceDimsAndSymbols( - {}, symReplacements, numDims + numNewDims, numNewSymbols); - - return newMap; -} - -/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to -/// keep a correspondence between the mathematical `map` and the `operands` of -/// a given AffineApplyOp. This correspondence is maintained by iterating over -/// the operands and forming an `auxiliaryMap` that can be composed -/// mathematically with `map`. To keep this correspondence in cases where -/// symbols are produced by affine.apply operations, we perform a local rewrite -/// of symbols as dims. -/// -/// Rationale for locally rewriting symbols as dims: -/// ================================================ -/// The mathematical composition of AffineMap must always concatenate symbols -/// because it does not have enough information to do otherwise. For example, -/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce -/// `(d0)[s0, s1] -> (d0 + s0 + s1)`. -/// -/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when -/// applied to the same mlir::Value* for both s0 and s1. -/// As a consequence mathematical composition of AffineMap always concatenates -/// symbols. -/// -/// When AffineMaps are used in AffineApplyOp however, they may specify -/// composition via symbols, which is ambiguous mathematically. This corner case -/// is handled by locally rewriting such symbols that come from AffineApplyOp -/// into dims and composing through dims. -/// TODO(andydavis, ntv): Composition via symbols comes at a significant code -/// complexity. Alternatively we should investigate whether we want to -/// explicitly disallow symbols coming from affine.apply and instead force the -/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 -/// extra API calls for such uses, which haven't popped up until now) and the -/// benefit potentially big: simpler and more maintainable code for a -/// non-trivial, recursive, procedure. -AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, - ArrayRef operands) - : AffineApplyNormalizer() { - static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); - assert(map.getNumInputs() == operands.size() && - "number of operands does not match the number of map inputs"); - - LLVM_DEBUG(map.print(dbgs() << "\nInput map: ")); - - // Promote symbols that come from an AffineApplyOp to dims by rewriting the - // map to always refer to: - // (dims, symbols coming from AffineApplyOp, other symbols). - // The order of operands can remain unchanged. - // This is a simplification that relies on 2 ordering properties: - // 1. rewritten symbols always appear after the original dims in the map; - // 2. operands are traversed in order and either dispatched to: - // a. auxiliaryExprs (dims and symbols rewritten as dims); - // b. 
concatenatedSymbols (all other symbols) - // This allows operand order to remain unchanged. - unsigned numDimsBeforeRewrite = map.getNumDims(); - map = promoteComposedSymbolsAsDims(map, - operands.take_back(map.getNumSymbols())); - - LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: ")); - - SmallVector auxiliaryExprs; - bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); - // We fully spell out the 2 cases below. In this particular instance a little - // code duplication greatly improves readability. - // Note that the first branch would disappear if we only supported full - // composition (i.e. infinite kMaxAffineApplyDepth). - if (!furtherCompose) { - // 1. Only dispatch dims or symbols. - for (auto en : llvm::enumerate(operands)) { - auto *t = en.value(); - assert(t->getType().isIndex()); - bool isDim = (en.index() < map.getNumDims()); - if (isDim) { - // a. The mathematical composition of AffineMap composes dims. - auxiliaryExprs.push_back(renumberOneDim(t)); - } else { - // b. The mathematical composition of AffineMap concatenates symbols. - // We do the same for symbol operands. - concatenatedSymbols.push_back(t); - } - } - } else { - assert(numDimsBeforeRewrite <= operands.size()); - // 2. Compose AffineApplyOps and dispatch dims or symbols. - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - auto *t = operands[i]; - auto affineApply = dyn_cast_or_null(t->getDefiningOp()); - if (affineApply) { - // a. Compose affine.apply operations. - LLVM_DEBUG(affineApply.getOperation()->print( - dbgs() << "\nCompose AffineApplyOp recursively: ")); - AffineMap affineApplyMap = affineApply.getAffineMap(); - SmallVector affineApplyOperands( - affineApply.getOperands().begin(), affineApply.getOperands().end()); - AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); - - LLVM_DEBUG(normalizer.affineMap.print( - dbgs() << "\nRenumber into current normalizer: ")); - - auto renumberedMap = renumber(normalizer); - - LLVM_DEBUG( - renumberedMap.print(dbgs() << "\nRecursive composition yields: ")); - - auxiliaryExprs.push_back(renumberedMap.getResult(0)); - } else { - if (i < numDimsBeforeRewrite) { - // b. The mathematical composition of AffineMap composes dims. - auxiliaryExprs.push_back(renumberOneDim(t)); - } else { - // c. The mathematical composition of AffineMap concatenates symbols. - // We do the same for symbol operands. - concatenatedSymbols.push_back(t); - } - } - } - } - - // Early exit if `map` is already composed. - if (auxiliaryExprs.empty()) { - affineMap = map; - return; - } - - assert(concatenatedSymbols.size() >= map.getNumSymbols() && - "Unexpected number of concatenated symbols"); - auto numDims = dimValueToPosition.size(); - auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); - auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs); - - LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); - LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); - LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: ")); - - // TODO(andydavis,ntv): Disabling simplification results in major speed gains. - // Another option is to cache the results as it is expected a lot of redundant - // work is performed in practice. - affineMap = simplifyAffineMap(map.compose(auxiliaryMap)); - - LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); - LLVM_DEBUG(dbgs() << "\n"); -} - -/// Implements `map` and `operands` composition and simplification to support -/// `makeComposedAffineApply`. 
This can be called to achieve the same effects -/// on `map` and `operands` without creating an AffineApplyOp that needs to be -/// immediately deleted. -static void composeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { - AffineApplyNormalizer normalizer(*map, *operands); - auto normalizedMap = normalizer.getAffineMap(); - auto normalizedOperands = normalizer.getOperands(); - canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); - *map = normalizedMap; - *operands = normalizedOperands; - assert(*map); -} - -void mlir::fullyComposeAffineMapAndOperands( - AffineMap *map, SmallVectorImpl *operands) { - while (llvm::any_of(*operands, [](Value *v) { - return isa_and_nonnull(v->getDefiningOp()); - })) { - composeAffineMapAndOperands(map, operands); - } -} - -AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef operands) { - AffineMap normalizedMap = map; - SmallVector normalizedOperands(operands.begin(), operands.end()); - composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); - assert(normalizedMap); - return b.create(loc, normalizedMap, normalizedOperands); -} - -// A symbol may appear as a dim in affine.apply operations. This function -// canonicalizes dims that are valid symbols into actual symbols. -static void -canonicalizePromotedSymbols(AffineMap *map, - llvm::SmallVectorImpl *operands) { - if (!map || operands->empty()) - return; - - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); - - auto *context = map->getContext(); - SmallVector resultOperands; - resultOperands.reserve(operands->size()); - SmallVector remappedSymbols; - remappedSymbols.reserve(operands->size()); - unsigned nextDim = 0; - unsigned nextSym = 0; - unsigned oldNumSyms = map->getNumSymbols(); - SmallVector dimRemapping(map->getNumDims()); - for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) { - if (i < map->getNumDims()) { - if (isValidSymbol((*operands)[i])) { - // This is a valid symbols that appears as a dim, canonicalize it. - dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); - remappedSymbols.push_back((*operands)[i]); - } else { - dimRemapping[i] = getAffineDimExpr(nextDim++, context); - resultOperands.push_back((*operands)[i]); - } - } else { - resultOperands.push_back((*operands)[i]); - } - } - - resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); - *operands = resultOperands; - *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim, - oldNumSyms + nextSym); - - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); -} - -void mlir::canonicalizeMapAndOperands( - AffineMap *map, llvm::SmallVectorImpl *operands) { - if (!map || operands->empty()) - return; - - assert(map->getNumInputs() == operands->size() && - "map inputs must match number of operands"); - - canonicalizePromotedSymbols(map, operands); - - // Check to see what dims are used. 
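// Two illustrative effects of this canonicalization (hypothetical values):
//   - duplicate operands are unified:
//       (d0, d1) -> (d0 + d1) applied to (%v, %v)
//     becomes (d0) -> (d0 + d0) applied to (%v);
//   - operands bound to unused dims/symbols are dropped:
//       (d0, d1) -> (d1) applied to (%a, %b)
//     becomes (d0) -> (d0) applied to (%b).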
- llvm::SmallBitVector usedDims(map->getNumDims()); - llvm::SmallBitVector usedSyms(map->getNumSymbols()); - map->walkExprs([&](AffineExpr expr) { - if (auto dimExpr = expr.dyn_cast()) - usedDims[dimExpr.getPosition()] = true; - else if (auto symExpr = expr.dyn_cast()) - usedSyms[symExpr.getPosition()] = true; - }); - - auto *context = map->getContext(); - - SmallVector resultOperands; - resultOperands.reserve(operands->size()); - - llvm::SmallDenseMap seenDims; - SmallVector dimRemapping(map->getNumDims()); - unsigned nextDim = 0; - for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { - if (usedDims[i]) { - auto it = seenDims.find((*operands)[i]); - if (it == seenDims.end()) { - dimRemapping[i] = getAffineDimExpr(nextDim++, context); - resultOperands.push_back((*operands)[i]); - seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); - } else { - dimRemapping[i] = it->second; - } - } - } - llvm::SmallDenseMap seenSymbols; - SmallVector symRemapping(map->getNumSymbols()); - unsigned nextSym = 0; - for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { - if (usedSyms[i]) { - auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); - if (it == seenSymbols.end()) { - symRemapping[i] = getAffineSymbolExpr(nextSym++, context); - resultOperands.push_back((*operands)[i + map->getNumDims()]); - seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()], - symRemapping[i])); - } else { - symRemapping[i] = it->second; - } - } - } - *map = - map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); - *operands = resultOperands; -} - -namespace { -/// Simplify AffineApply operations. -/// -struct SimplifyAffineApply : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(AffineApplyOp apply, - PatternRewriter &rewriter) const override { - auto map = apply.getAffineMap(); - - AffineMap oldMap = map; - SmallVector resultOperands(apply.getOperands()); - composeAffineMapAndOperands(&map, &resultOperands); - if (map == oldMap) - return matchFailure(); - - rewriter.replaceOpWithNewOp(apply, map, resultOperands); - return matchSuccess(); - } -}; -} // end anonymous namespace. - -void AffineApplyOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// Common canonicalization pattern support logic -//===----------------------------------------------------------------------===// - -namespace { -/// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref_cast -/// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Operation *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast(memref)) - op->setOperand(i, cast.getOperand()); - rewriter.updatedRootInPlace(op); - } -}; - -} // end anonymous namespace. 
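// Sketch of the fold performed by the pattern above (hypothetical IR;
// 'memref_cast' is the standard-dialect cast op):
//
//   %1 = memref_cast %0 : memref<1024xf32> to memref<?xf32>
//   %v = affine.load %1[%i] : memref<?xf32>
//
// MemRefCastFolder rewrites the load to take %0 directly, so the cast (if
// otherwise unused) becomes dead. The affine load/store and DMA operations
// below register this pattern in their getCanonicalizationPatterns.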
- -//===----------------------------------------------------------------------===// -// AffineDmaStartOp -//===----------------------------------------------------------------------===// - -// TODO(b/133776335) Check that map operands are loop IVs or symbols. -void AffineDmaStartOp::build(Builder *builder, OperationState *result, - Value *srcMemRef, AffineMap srcMap, - ArrayRef srcIndices, Value *destMemRef, - AffineMap dstMap, ArrayRef destIndices, - Value *tagMemRef, AffineMap tagMap, - ArrayRef tagIndices, Value *numElements, - Value *stride, Value *elementsPerStride) { - result->addOperands(srcMemRef); - result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap)); - result->addOperands(srcIndices); - result->addOperands(destMemRef); - result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap)); - result->addOperands(destIndices); - result->addOperands(tagMemRef); - result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); - result->addOperands(tagIndices); - result->addOperands(numElements); - if (stride) { - result->addOperands({stride, elementsPerStride}); - } -} - -void AffineDmaStartOp::print(OpAsmPrinter *p) { - *p << "affine.dma_start " << *getSrcMemRef() << '['; - SmallVector operands(getSrcIndices()); - p->printAffineMapOfSSAIds(getSrcMapAttr(), operands); - *p << "], " << *getDstMemRef() << '['; - operands.assign(getDstIndices().begin(), getDstIndices().end()); - p->printAffineMapOfSSAIds(getDstMapAttr(), operands); - *p << "], " << *getTagMemRef() << '['; - operands.assign(getTagIndices().begin(), getTagIndices().end()); - p->printAffineMapOfSSAIds(getTagMapAttr(), operands); - *p << "], " << *getNumElements(); - if (isStrided()) { - *p << ", " << *getStride(); - *p << ", " << *getNumElementsPerStride(); - } - *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " - << getTagMemRefType(); -} - -// Parse AffineDmaStartOp. -// Ex: -// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, -// %stride, %num_elt_per_stride -// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> -// -ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType srcMemRefInfo; - AffineMapAttr srcMapAttr; - SmallVector srcMapOperands; - OpAsmParser::OperandType dstMemRefInfo; - AffineMapAttr dstMapAttr; - SmallVector dstMapOperands; - OpAsmParser::OperandType tagMemRefInfo; - AffineMapAttr tagMapAttr; - SmallVector tagMapOperands; - OpAsmParser::OperandType numElementsInfo; - SmallVector strideInfo; - - SmallVector types; - auto indexType = parser->getBuilder().getIndexType(); - - // Parse and resolve the following list of operands: - // *) dst memref followed by its affine maps operands (in square brackets). - // *) src memref followed by its affine map operands (in square brackets). - // *) tag memref followed by its affine map operands (in square brackets). - // *) number of elements transferred by DMA operation. 
- if (parser->parseOperand(srcMemRefInfo) || - parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, - getSrcMapAttrName(), result->attributes) || - parser->parseComma() || parser->parseOperand(dstMemRefInfo) || - parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, - getDstMapAttrName(), result->attributes) || - parser->parseComma() || parser->parseOperand(tagMemRefInfo) || - parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, - getTagMapAttrName(), result->attributes) || - parser->parseComma() || parser->parseOperand(numElementsInfo)) - return failure(); - - // Parse optional stride and elements per stride. - if (parser->parseTrailingOperandList(strideInfo)) { - return failure(); - } - if (!strideInfo.empty() && strideInfo.size() != 2) { - return parser->emitError(parser->getNameLoc(), - "expected two stride related operands"); - } - bool isStrided = strideInfo.size() == 2; - - if (parser->parseColonTypeList(types)) - return failure(); - - if (types.size() != 3) - return parser->emitError(parser->getNameLoc(), "expected three types"); - - if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || - parser->resolveOperands(srcMapOperands, indexType, result->operands) || - parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || - parser->resolveOperands(dstMapOperands, indexType, result->operands) || - parser->resolveOperand(tagMemRefInfo, types[2], result->operands) || - parser->resolveOperands(tagMapOperands, indexType, result->operands) || - parser->resolveOperand(numElementsInfo, indexType, result->operands)) - return failure(); - - if (isStrided) { - if (parser->resolveOperands(strideInfo, indexType, result->operands)) - return failure(); - } - - // Check that src/dst/tag operand counts match their map.numInputs. - if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || - dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || - tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) - return parser->emitError(parser->getNameLoc(), - "memref operand count not equal to map.numInputs"); - return success(); -} - -LogicalResult AffineDmaStartOp::verify() { - if (!getOperand(getSrcMemRefOperandIndex())->getType().isa()) - return emitOpError("expected DMA source to be of memref type"); - if (!getOperand(getDstMemRefOperandIndex())->getType().isa()) - return emitOpError("expected DMA destination to be of memref type"); - if (!getOperand(getTagMemRefOperandIndex())->getType().isa()) - return emitOpError("expected DMA tag to be of memref type"); - - // DMAs from different memory spaces supported. 
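// Operand layout sketch for the count check below (counts are illustrative):
// with src/dst/tag maps taking 2, 2 and 1 inputs, a non-strided
// affine.dma_start has 3 memrefs + 5 map operands + 1 element count
// = 9 operands; the strided form appends the stride and the
// elements-per-stride operand for a total of 11.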
- if (getSrcMemorySpace() == getDstMemorySpace()) { - return emitOpError("DMA should be between different memory spaces"); - } - unsigned numInputsAllMaps = getSrcMap().getNumInputs() + - getDstMap().getNumInputs() + - getTagMap().getNumInputs(); - if (getNumOperands() != numInputsAllMaps + 3 + 1 && - getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { - return emitOpError("incorrect number of operands"); - } - - for (auto *idx : getSrcIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("src index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("src index must be a dimension or symbol identifier"); - } - for (auto *idx : getDstIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("dst index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("dst index must be a dimension or symbol identifier"); - } - for (auto *idx : getTagIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("tag index to dma_start must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("tag index must be a dimension or symbol identifier"); - } - return success(); -} - -void AffineDmaStartOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - /// dma_start(memrefcast) -> dma_start - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// AffineDmaWaitOp -//===----------------------------------------------------------------------===// - -// TODO(b/133776335) Check that map operands are loop IVs or symbols. -void AffineDmaWaitOp::build(Builder *builder, OperationState *result, - Value *tagMemRef, AffineMap tagMap, - ArrayRef tagIndices, Value *numElements) { - result->addOperands(tagMemRef); - result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); - result->addOperands(tagIndices); - result->addOperands(numElements); -} - -void AffineDmaWaitOp::print(OpAsmPrinter *p) { - *p << "affine.dma_wait " << *getTagMemRef() << '['; - SmallVector operands(getTagIndices()); - p->printAffineMapOfSSAIds(getTagMapAttr(), operands); - *p << "], "; - p->printOperand(getNumElements()); - *p << " : " << getTagMemRef()->getType(); -} - -// Parse AffineDmaWaitOp. -// Eg: -// affine.dma_wait %tag[%index], %num_elements -// : memref<1 x i32, (d0) -> (d0), 4> -// -ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType tagMemRefInfo; - AffineMapAttr tagMapAttr; - SmallVector tagMapOperands; - Type type; - auto indexType = parser->getBuilder().getIndexType(); - OpAsmParser::OperandType numElementsInfo; - - // Parse tag memref, its map operands, and dma size. 
- if (parser->parseOperand(tagMemRefInfo) || - parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, - getTagMapAttrName(), result->attributes) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseColonType(type) || - parser->resolveOperand(tagMemRefInfo, type, result->operands) || - parser->resolveOperands(tagMapOperands, indexType, result->operands) || - parser->resolveOperand(numElementsInfo, indexType, result->operands)) - return failure(); - - if (!type.isa()) - return parser->emitError(parser->getNameLoc(), - "expected tag to be of memref type"); - - if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) - return parser->emitError(parser->getNameLoc(), - "tag memref operand count != to map.numInputs"); - return success(); -} - -LogicalResult AffineDmaWaitOp::verify() { - if (!getOperand(0)->getType().isa()) - return emitOpError("expected DMA tag to be of memref type"); - for (auto *idx : getTagIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("index to dma_wait must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("index must be a dimension or symbol identifier"); - } - return success(); -} - -void AffineDmaWaitOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - /// dma_wait(memrefcast) -> dma_wait - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// AffineForOp -//===----------------------------------------------------------------------===// - -void AffineForOp::build(Builder *builder, OperationState *result, - ArrayRef lbOperands, AffineMap lbMap, - ArrayRef ubOperands, AffineMap ubMap, - int64_t step) { - assert(((!lbMap && lbOperands.empty()) || - lbOperands.size() == lbMap.getNumInputs()) && - "lower bound operand count does not match the affine map"); - assert(((!ubMap && ubOperands.empty()) || - ubOperands.size() == ubMap.getNumInputs()) && - "upper bound operand count does not match the affine map"); - assert(step > 0 && "step has to be a positive integer constant"); - - // Add an attribute for the step. - result->addAttribute(getStepAttrName(), - builder->getIntegerAttr(builder->getIndexType(), step)); - - // Add the lower bound. - result->addAttribute(getLowerBoundAttrName(), - builder->getAffineMapAttr(lbMap)); - result->addOperands(lbOperands); - - // Add the upper bound. - result->addAttribute(getUpperBoundAttrName(), - builder->getAffineMapAttr(ubMap)); - result->addOperands(ubOperands); - - // Create a region and a block for the body. The argument of the region is - // the loop induction variable. - Region *bodyRegion = result->addRegion(); - Block *body = new Block(); - body->addArgument(IndexType::get(builder->getContext())); - bodyRegion->push_back(body); - ensureTerminator(*bodyRegion, *builder, result->location); - - // Set the operands list as resizable so that we can freely modify the bounds. - result->setOperandListToResizable(); -} - -void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, - int64_t ub, int64_t step) { - auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); - auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); - return build(builder, result, {}, lbMap, {}, ubMap, step); -} - -static LogicalResult verify(AffineForOp op) { - // Check that the body defines as single block argument for the induction - // variable. 
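// A minimal usage sketch of the constant-bound builder above (names such as
// 'b' and 'loc' are illustrative):
//
//   auto forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/10);
//   OpBuilder bodyBuilder = forOp.getBodyBuilder();
//   Value *iv = forOp.getInductionVar();
//
// The builder creates the body block with its single 'index' induction
// variable and inserts the implicit terminator; the verifier below checks
// that single 'index' block argument.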
- auto *body = op.getBody(); - if (body->getNumArguments() != 1 || - !body->getArgument(0)->getType().isIndex()) - return op.emitOpError( - "expected body to have a single index argument for the " - "induction variable"); - - // Verify that there are enough operands for the bounds. - AffineMap lowerBoundMap = op.getLowerBoundMap(), - upperBoundMap = op.getUpperBoundMap(); - if (op.getNumOperands() != - (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) - return op.emitOpError( - "operand count must match with affine map dimension and symbol count"); - - // Verify that the bound operands are valid dimension/symbols. - /// Lower bound. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), - op.getLowerBoundMap().getNumDims()))) - return failure(); - /// Upper bound. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), - op.getUpperBoundMap().getNumDims()))) - return failure(); - return success(); -} - -/// Parse a for operation loop bounds. -static ParseResult parseBound(bool isLower, OperationState *result, - OpAsmParser *p) { - // 'min' / 'max' prefixes are generally syntactic sugar, but are required if - // the map has multiple results. - bool failedToParsedMinMax = - failed(p->parseOptionalKeyword(isLower ? "max" : "min")); - - auto &builder = p->getBuilder(); - auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() - : AffineForOp::getUpperBoundAttrName(); - - // Parse ssa-id as identity map. - SmallVector boundOpInfos; - if (p->parseOperandList(boundOpInfos)) - return failure(); - - if (!boundOpInfos.empty()) { - // Check that only one operand was parsed. - if (boundOpInfos.size() > 1) - return p->emitError(p->getNameLoc(), - "expected only one loop bound operand"); - - // TODO: improve error message when SSA value is not an affine integer. - // Currently it is 'use of value ... expects different type than prior uses' - if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), - result->operands)) - return failure(); - - // Create an identity map using symbol id. This representation is optimized - // for storage. Analysis passes may expand it into a multi-dimensional map - // if desired. - AffineMap map = builder.getSymbolIdentityMap(); - result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); - return success(); - } - - // Get the attribute location. - llvm::SMLoc attrLoc = p->getCurrentLocation(); - - Attribute boundAttr; - if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, - result->attributes)) - return failure(); - - // Parse full form - affine map followed by dim and symbol list. - if (auto affineMapAttr = boundAttr.dyn_cast()) { - unsigned currentNumOperands = result->operands.size(); - unsigned numDims; - if (parseDimAndSymbolList(p, result->operands, numDims)) - return failure(); - - auto map = affineMapAttr.getValue(); - if (map.getNumDims() != numDims) - return p->emitError( - p->getNameLoc(), - "dim operand count and integer set dim count must match"); - - unsigned numDimAndSymbolOperands = - result->operands.size() - currentNumOperands; - if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) - return p->emitError( - p->getNameLoc(), - "symbol operand count and integer set symbol count must match"); - - // If the map has multiple results, make sure that we parsed the min/max - // prefix. 
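// Illustrative bound forms handled here (hypothetical map aliases #lb, #ub):
//
//   affine.for %i = 0 to 10 step 2 { ... }        // constant bounds
//   affine.for %i = %lb to %ub { ... }            // single-symbol bounds
//   affine.for %i = max #lb(%m)[%n] to min #ub()[%n] { ... }
//
// The last form uses multi-result bound maps, so the 'max'/'min' keywords
// are required, which is what the check just below enforces.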
- if (map.getNumResults() > 1 && failedToParsedMinMax) { - if (isLower) { - return p->emitError(attrLoc, "lower loop bound affine map with " - "multiple results requires 'max' prefix"); - } - return p->emitError(attrLoc, "upper loop bound affine map with multiple " - "results requires 'min' prefix"); - } - return success(); - } - - // Parse custom assembly form. - if (auto integerAttr = boundAttr.dyn_cast()) { - result->attributes.pop_back(); - result->addAttribute( - boundAttrName, builder.getAffineMapAttr( - builder.getConstantAffineMap(integerAttr.getInt()))); - return success(); - } - - return p->emitError( - p->getNameLoc(), - "expected valid affine map representation for loop bounds"); -} - -ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) { - auto &builder = parser->getBuilder(); - OpAsmParser::OperandType inductionVariable; - // Parse the induction variable followed by '='. - if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) - return failure(); - - // Parse loop bounds. - if (parseBound(/*isLower=*/true, result, parser) || - parser->parseKeyword("to", " between bounds") || - parseBound(/*isLower=*/false, result, parser)) - return failure(); - - // Parse the optional loop step, we default to 1 if one is not present. - if (parser->parseOptionalKeyword("step")) { - result->addAttribute( - AffineForOp::getStepAttrName(), - builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); - } else { - llvm::SMLoc stepLoc = parser->getCurrentLocation(); - IntegerAttr stepAttr; - if (parser->parseAttribute(stepAttr, builder.getIndexType(), - AffineForOp::getStepAttrName().data(), - result->attributes)) - return failure(); - - if (stepAttr.getValue().getSExtValue() < 0) - return parser->emitError( - stepLoc, - "expected step to be representable as a positive signed integer"); - } - - // Parse the body region. - Region *body = result->addRegion(); - if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) - return failure(); - - AffineForOp::ensureTerminator(*body, builder, result->location); - - // Parse the optional attribute list. - if (parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - - // Set the operands list as resizable so that we can freely modify the bounds. - result->setOperandListToResizable(); - return success(); -} - -static void printBound(AffineMapAttr boundMap, - Operation::operand_range boundOperands, - const char *prefix, OpAsmPrinter *p) { - AffineMap map = boundMap.getValue(); - - // Check if this bound should be printed using custom assembly form. - // The decision to restrict printing custom assembly form to trivial cases - // comes from the will to roundtrip MLIR binary -> text -> binary in a - // lossless way. - // Therefore, custom assembly form parsing and printing is only supported for - // zero-operand constant maps and single symbol operand identity maps. - if (map.getNumResults() == 1) { - AffineExpr expr = map.getResult(0); - - // Print constant bound. - if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { - if (auto constExpr = expr.dyn_cast()) { - *p << constExpr.getValue(); - return; - } - } - - // Print bound that consists of a single SSA symbol if the map is over a - // single symbol. - if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { - if (auto symExpr = expr.dyn_cast()) { - p->printOperand(*boundOperands.begin()); - return; - } - } - } else { - // Map has multiple results. Print 'min' or 'max' prefix. 
- *p << prefix << ' '; - } - - // Print the map and its operands. - *p << boundMap; - printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), - map.getNumDims(), p); -} - -void print(OpAsmPrinter *p, AffineForOp op) { - *p << "affine.for "; - p->printOperand(op.getBody()->getArgument(0)); - *p << " = "; - printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); - *p << " to "; - printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); - - if (op.getStep() != 1) - *p << " step " << op.getStep(); - p->printRegion(op.region(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - p->printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{op.getLowerBoundAttrName(), - op.getUpperBoundAttrName(), - op.getStepAttrName()}); -} - -namespace { -/// This is a pattern to fold constant loop bounds. -struct AffineForLoopBoundFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - auto foldLowerOrUpperBound = [&forOp](bool lower) { - // Check to see if each of the operands is the result of a constant. If - // so, get the value. If not, ignore it. - SmallVector operandConstants; - auto boundOperands = - lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); - for (auto *operand : boundOperands) { - Attribute operandCst; - matchPattern(operand, m_Constant(&operandCst)); - operandConstants.push_back(operandCst); - } - - AffineMap boundMap = - lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); - assert(boundMap.getNumResults() >= 1 && - "bound maps should have at least one result"); - SmallVector foldedResults; - if (failed(boundMap.constantFold(operandConstants, foldedResults))) - return failure(); - - // Compute the max or min as applicable over the results. - assert(!foldedResults.empty() && - "bounds should have at least one result"); - auto maxOrMin = foldedResults[0].cast().getValue(); - for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { - auto foldedResult = foldedResults[i].cast().getValue(); - maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) - : llvm::APIntOps::smin(maxOrMin, foldedResult); - } - lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) - : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); - return success(); - }; - - // Try to fold the lower bound. - bool folded = false; - if (!forOp.hasConstantLowerBound()) - folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); - - // Try to fold the upper bound. - if (!forOp.hasConstantUpperBound()) - folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); - - // If any of the bounds were folded we return success. 
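// Folding sketch (illustrative): a lower bound with map ()[s0] -> (s0, 4)
// whose symbol operand is defined by 'constant 10' constant-folds to the
// results {10, 4}; since this is a lower bound, the maximum, 10, is kept and
// the bound is rewritten via setConstantLowerBound(10). An upper bound would
// keep the minimum instead.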
- if (!folded) - return matchFailure(); - rewriter.updatedRootInPlace(forOp); - return matchSuccess(); - } -}; -} // end anonymous namespace - -void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -AffineBound AffineForOp::getLowerBound() { - auto lbMap = getLowerBoundMap(); - return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); -} - -AffineBound AffineForOp::getUpperBound() { - auto lbMap = getLowerBoundMap(); - auto ubMap = getUpperBoundMap(); - return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(), - ubMap); -} - -void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { - assert(lbOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(lbOperands.begin(), lbOperands.end()); - - auto ubOperands = getUpperBoundOperands(); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperation()->setOperands(newOperands); - - setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); -} - -void AffineForOp::setUpperBound(ArrayRef ubOperands, AffineMap map) { - assert(ubOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector newOperands(getLowerBoundOperands()); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperation()->setOperands(newOperands); - - setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); -} - -void AffineForOp::setLowerBoundMap(AffineMap map) { - auto lbMap = getLowerBoundMap(); - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - (void)lbMap; - setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); -} - -void AffineForOp::setUpperBoundMap(AffineMap map) { - auto ubMap = getUpperBoundMap(); - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - (void)ubMap; - setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); -} - -bool AffineForOp::hasConstantLowerBound() { - return getLowerBoundMap().isSingleConstant(); -} - -bool AffineForOp::hasConstantUpperBound() { - return getUpperBoundMap().isSingleConstant(); -} - -int64_t AffineForOp::getConstantLowerBound() { - return getLowerBoundMap().getSingleConstantResult(); -} - -int64_t AffineForOp::getConstantUpperBound() { - return getUpperBoundMap().getSingleConstantResult(); -} - -void AffineForOp::setConstantLowerBound(int64_t value) { - setLowerBound({}, AffineMap::getConstantMap(value, getContext())); -} - -void AffineForOp::setConstantUpperBound(int64_t value) { - setUpperBound({}, AffineMap::getConstantMap(value, getContext())); -} - -AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool AffineForOp::matchingBoundOperandList() { - auto lbMap = getLowerBoundMap(); - auto ubMap = getUpperBoundMap(); - if (lbMap.getNumDims() != ubMap.getNumDims() || - lbMap.getNumSymbols() != ubMap.getNumSymbols()) - return false; - - unsigned numOperands = lbMap.getNumInputs(); - for (unsigned i = 0, e = 
lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. - if (getOperand(i) != getOperand(numOperands + i)) - return false; - } - return true; -} - -/// Returns if the provided value is the induction variable of a AffineForOp. -bool mlir::isForInductionVar(Value *val) { - return getForInductionVarOwner(val) != AffineForOp(); -} - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -AffineForOp mlir::getForInductionVarOwner(Value *val) { - auto *ivArg = dyn_cast(val); - if (!ivArg || !ivArg->getOwner()) - return AffineForOp(); - auto *containingInst = ivArg->getOwner()->getParent()->getParentOp(); - return dyn_cast(containingInst); -} - -/// Extracts the induction variables from a list of AffineForOps and returns -/// them. -void mlir::extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs) { - ivs->reserve(forInsts.size()); - for (auto forInst : forInsts) - ivs->push_back(forInst.getInductionVar()); -} - -//===----------------------------------------------------------------------===// -// AffineIfOp -//===----------------------------------------------------------------------===// - -static LogicalResult verify(AffineIfOp op) { - // Verify that we have a condition attribute. - auto conditionAttr = - op.getAttrOfType(op.getConditionAttrName()); - if (!conditionAttr) - return op.emitOpError( - "requires an integer set attribute named 'condition'"); - - // Verify that there are enough operands for the condition. - IntegerSet condition = conditionAttr.getValue(); - if (op.getNumOperands() != condition.getNumOperands()) - return op.emitOpError( - "operand count and condition integer set dimension and " - "symbol count must match"); - - // Verify that the operands are valid dimension/symbols. - if (failed(verifyDimAndSymbolIdentifiers( - op, op.getOperation()->getNonSuccessorOperands(), - condition.getNumDims()))) - return failure(); - - // Verify that the entry of each child region does not have arguments. - for (auto ®ion : op.getOperation()->getRegions()) { - for (auto &b : region) - if (b.getNumArguments() != 0) - return op.emitOpError( - "requires that child entry blocks have no arguments"); - } - return success(); -} - -ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) { - // Parse the condition attribute set. - IntegerSetAttr conditionAttr; - unsigned numDims; - if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), - result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims)) - return failure(); - - // Verify the condition operands. - auto set = conditionAttr.getValue(); - if (set.getNumDims() != numDims) - return parser->emitError( - parser->getNameLoc(), - "dim operand count and integer set dim count must match"); - if (numDims + set.getNumSymbols() != result->operands.size()) - return parser->emitError( - parser->getNameLoc(), - "symbol operand count and integer set symbol count must match"); - - // Create the regions for 'then' and 'else'. The latter must be created even - // if it remains empty for the validity of the operation. - result->regions.reserve(2); - Region *thenRegion = result->addRegion(); - Region *elseRegion = result->addRegion(); - - // Parse the 'then' region. - if (parser->parseRegion(*thenRegion, {}, {})) - return failure(); - AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(), - result->location); - - // If we find an 'else' keyword then parse the 'else' region. 
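// Illustrative custom form handled by this parser (hypothetical set alias
// #set0 and operand names): dims appear in parentheses and symbols in square
// brackets, matching parseDimAndSymbolList above:
//
//   affine.if #set0(%i, %j)[%n] {
//     ...
//   } else {
//     ...
//   }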
- if (!parser->parseOptionalKeyword("else")) { - if (parser->parseRegion(*elseRegion, {}, {})) - return failure(); - AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(), - result->location); - } - - // Parse the optional attribute list. - if (parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - - return success(); -} - -void print(OpAsmPrinter *p, AffineIfOp op) { - auto conditionAttr = - op.getAttrOfType(op.getConditionAttrName()); - *p << "affine.if " << conditionAttr; - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - conditionAttr.getValue().getNumDims(), p); - p->printRegion(op.thenRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - - // Print the 'else' regions if it has any blocks. - auto &elseRegion = op.elseRegion(); - if (!elseRegion.empty()) { - *p << " else"; - p->printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - } - - // Print the attribute list. - p->printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/op.getConditionAttrName()); -} - -IntegerSet AffineIfOp::getIntegerSet() { - return getAttrOfType(getConditionAttrName()).getValue(); -} -void AffineIfOp::setIntegerSet(IntegerSet newSet) { - setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); -} - -//===----------------------------------------------------------------------===// -// AffineLoadOp -//===----------------------------------------------------------------------===// - -void AffineLoadOp::build(Builder *builder, OperationState *result, - AffineMap map, ArrayRef operands) { - result->addOperands(operands); - if (map) - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); - auto memrefType = operands[0]->getType().cast(); - result->types.push_back(memrefType.getElementType()); -} - -void AffineLoadOp::build(Builder *builder, OperationState *result, - Value *memref, ArrayRef indices) { - result->addOperands(memref); - result->addOperands(indices); - auto memrefType = memref->getType().cast(); - auto rank = memrefType.getRank(); - // Create identity map for memrefs with at least one dimension or () -> () - // for zero-dimensional memrefs. - auto map = rank ? 
builder->getMultiDimIdentityMap(rank) - : builder->getEmptyAffineMap(); - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); - result->types.push_back(memrefType.getElementType()); -} - -ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { - auto &builder = parser->getBuilder(); - auto affineIntTy = builder.getIndexType(); - - MemRefType type; - OpAsmParser::OperandType memrefInfo; - AffineMapAttr mapAttr; - SmallVector mapOperands; - return failure( - parser->parseOperand(memrefInfo) || - parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), - result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(mapOperands, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types)); -} - -void AffineLoadOp::print(OpAsmPrinter *p) { - *p << "affine.load " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - SmallVector operands(getIndices()); - p->printAffineMapOfSSAIds(mapAttr, operands); - } - *p << ']'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); - *p << " : " << getMemRefType(); -} - -LogicalResult AffineLoadOp::verify() { - if (getType() != getMemRefType().getElementType()) - return emitOpError("result type must match element type of memref"); - - auto mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - AffineMap map = getAttrOfType(getMapAttrName()).getValue(); - if (map.getNumResults() != getMemRefType().getRank()) - return emitOpError("affine.load affine map num results must equal" - " memref rank"); - if (map.getNumInputs() != getNumOperands() - 1) - return emitOpError("expects as many subscripts as affine map inputs"); - } else { - if (getMemRefType().getRank() != getNumOperands() - 1) - return emitOpError( - "expects the number of subscripts to be equal to memref rank"); - } - - for (auto *idx : getIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("index to load must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("index must be a dimension or symbol identifier"); - } - return success(); -} - -void AffineLoadOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - /// load(memrefcast) -> load - results.insert(getOperationName(), context); -} - -//===----------------------------------------------------------------------===// -// AffineStoreOp -//===----------------------------------------------------------------------===// - -void AffineStoreOp::build(Builder *builder, OperationState *result, - Value *valueToStore, AffineMap map, - ArrayRef operands) { - result->addOperands(valueToStore); - result->addOperands(operands); - if (map) - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); -} - -void AffineStoreOp::build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef operands) { - result->addOperands(valueToStore); - result->addOperands(memref); - result->addOperands(operands); - auto memrefType = memref->getType().cast(); - auto rank = memrefType.getRank(); - // Create identity map for memrefs with at least one dimension or () -> () - // for zero-dimensional memrefs. - auto map = rank ? 
builder->getMultiDimIdentityMap(rank) - : builder->getEmptyAffineMap(); - result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); -} - -ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { - auto affineIntTy = parser->getBuilder().getIndexType(); - - MemRefType type; - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; - AffineMapAttr mapAttr; - SmallVector mapOperands; - return failure( - parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || - parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), - result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(storeValueInfo, type.getElementType(), - result->operands) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(mapOperands, affineIntTy, result->operands)); -} - -void AffineStoreOp::print(OpAsmPrinter *p) { - *p << "affine.store " << *getValueToStore(); - *p << ", " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - SmallVector operands(getIndices()); - p->printAffineMapOfSSAIds(mapAttr, operands); - } - *p << ']'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); - *p << " : " << getMemRefType(); -} - -LogicalResult AffineStoreOp::verify() { - // First operand must have same type as memref element type. - if (getValueToStore()->getType() != getMemRefType().getElementType()) - return emitOpError("first operand must have same type memref element type"); - - auto mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - AffineMap map = mapAttr.getValue(); - if (map.getNumResults() != getMemRefType().getRank()) - return emitOpError("affine.store affine map num results must equal" - " memref rank"); - if (map.getNumInputs() != getNumOperands() - 2) - return emitOpError("expects as many subscripts as affine map inputs"); - } else { - if (getMemRefType().getRank() != getNumOperands() - 2) - return emitOpError( - "expects the number of subscripts to be equal to memref rank"); - } - - for (auto *idx : getIndices()) { - if (!idx->getType().isIndex()) - return emitOpError("index to store must have 'index' type"); - if (!isValidAffineIndexOperand(idx)) - return emitOpError("index must be a dimension or symbol identifier"); - } - return success(); -} - -void AffineStoreOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - /// load(memrefcast) -> load - results.insert(getOperationName(), context); -} - -#define GET_OP_CLASSES -#include "mlir/AffineOps/AffineOps.cpp.inc" diff --git a/mlir/lib/AffineOps/CMakeLists.txt b/mlir/lib/AffineOps/CMakeLists.txt deleted file mode 100644 index a8cf24e6c2b..00000000000 --- a/mlir/lib/AffineOps/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_llvm_library(MLIRAffineOps - AffineOps.cpp - DialectRegistration.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/AffineOps - ) -add_dependencies(MLIRAffineOps MLIRAffineOpsIncGen MLIRIR MLIRStandardOps) -target_link_libraries(MLIRAffineOps MLIRIR MLIRStandardOps) - diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp deleted file mode 100644 index 0afb32c1bd6..00000000000 --- a/mlir/lib/AffineOps/DialectRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// 
-// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -using namespace mlir; - -// Static initialization for Affine op dialect registration. -static DialectRegistration StandardOps; diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index e074e5d4405..92997ad27a7 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -21,9 +21,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index b1e818ac02c..70daca9754f 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -20,7 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineStructures.h" -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 1e1095743c9..21d47c3c1ea 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,11 +21,11 @@ #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 85fe3109f6a..849407520da 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -20,11 +20,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index c7c0db90a7b..9d7d17f836c 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ 
b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,7 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "llvm/ADT/ArrayRef.h" diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index c240d779c44..2f7eddf5ab3 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -20,8 +20,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 9ecdcf7c2fe..477121fcc24 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -19,11 +19,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index 246cfbe9720..351a6a7a191 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -19,9 +19,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index d4fc42ceff7..aaefd98d1bd 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,9 +22,9 @@ #include "mlir/Analysis/Utils.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index f34515f73a0..9846abb7be2 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -16,9 +16,9 @@ // ============================================================================= #include "mlir/Analysis/VectorAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index bcb2d21d2da..f34b1e8bead 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(AffineOps) 
add_subdirectory(Analysis) add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 13ba898dc44..154a8660bee 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -22,7 +22,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 4b241e497c6..9dd9fdbbb87 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -16,8 +16,8 @@ // ============================================================================= #include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..7db3fa07c52 --- /dev/null +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -0,0 +1,1764 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/Debug.h" +using namespace mlir; +using llvm::dbgs; + +#define DEBUG_TYPE "affine-analysis" + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations(); +} + +/// A utility function to check if a given region is attached to a function. +static bool isFunctionRegion(Region *region) { + return llvm::isa(region->getParentOp()); +} + +/// A utility function to check if a value is defined at the top level of a +/// function. A value defined at the top level is always a valid symbol. 
+bool mlir::isTopLevelSymbol(Value *value) {
+  if (auto *arg = dyn_cast<BlockArgument>(value))
+    return isFunctionRegion(arg->getOwner()->getParent());
+  return isFunctionRegion(value->getDefiningOp()->getParentRegion());
+}
+
+// Value can be used as a dimension id if it is valid as a symbol, or
+// it is an induction variable, or it is a result of an affine apply operation
+// with dimension id arguments.
+bool mlir::isValidDim(Value *value) {
+  // The value must be an index type.
+  if (!value->getType().isIndex())
+    return false;
+
+  if (auto *op = value->getDefiningOp()) {
+    // Top level operation or constant operation is ok.
+    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+      return true;
+    // Affine apply operation is ok if all of its operands are ok.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+      return applyOp.isValidDim();
+    // The dim op is okay if its operand memref/tensor is defined at the top
+    // level.
+    if (auto dimOp = dyn_cast<DimOp>(op))
+      return isTopLevelSymbol(dimOp.getOperand());
+    return false;
+  }
+  // This value is a block argument (which also includes 'affine.for' loop IVs).
+  return true;
+}
+
+// Value can be used as a symbol if it is a constant, or it is defined at
+// the top level, or it is a result of an affine apply operation with symbol
+// arguments.
+bool mlir::isValidSymbol(Value *value) {
+  // The value must be an index type.
+  if (!value->getType().isIndex())
+    return false;
+
+  if (auto *op = value->getDefiningOp()) {
+    // Top level operation or constant operation is ok.
+    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+      return true;
+    // Affine apply operation is ok if all of its operands are ok.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+      return applyOp.isValidSymbol();
+    // The dim op is okay if its operand memref/tensor is defined at the top
+    // level.
+    if (auto dimOp = dyn_cast<DimOp>(op))
+      return isTopLevelSymbol(dimOp.getOperand());
+    return false;
+  }
+  // Otherwise, check that the value is a top level symbol.
+  return isTopLevelSymbol(value);
+}
+
+// Returns true if 'value' is a valid index to an affine operation (e.g.
+// affine.load, affine.store, affine.dma_start, affine.dma_wait).
+// Returns false otherwise.
+static bool isValidAffineIndexOperand(Value *value) {
+  return isValidDim(value) || isValidSymbol(value);
+}
+
+/// Utility function to verify that a set of operands are valid dimension and
+/// symbol identifiers. The operands should be laid out such that the dimension
+/// operands are before the symbol operands. This function returns failure if
+/// there was an invalid operand. An operation is provided to emit any necessary
+/// errors.
+template +static LogicalResult +verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, + unsigned numDims) { + unsigned opIt = 0; + for (auto *operand : operands) { + if (opIt++ < numDims) { + if (!isValidDim(operand)) + return op.emitOpError("operand cannot be used as a dimension id"); + } else if (!isValidSymbol(operand)) { + return op.emitOpError("operand cannot be used as a symbol"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AffineApplyOp +//===----------------------------------------------------------------------===// + +void AffineApplyOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef operands) { + result->addOperands(operands); + result->types.append(map.getNumResults(), builder->getIndexType()); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto affineIntTy = builder.getIndexType(); + + AffineMapAttr mapAttr; + unsigned numDims; + if (parser->parseAttribute(mapAttr, "map", result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims) || + parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + auto map = mapAttr.getValue(); + + if (map.getNumDims() != numDims || + numDims + map.getNumSymbols() != result->operands.size()) { + return parser->emitError(parser->getNameLoc(), + "dimension or symbol index mismatch"); + } + + result->types.append(map.getNumResults(), affineIntTy); + return success(); +} + +void AffineApplyOp::print(OpAsmPrinter *p) { + *p << "affine.apply " << getAttr("map"); + printDimAndSymbolList(operand_begin(), operand_end(), + getAffineMap().getNumDims(), p); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"}); +} + +LogicalResult AffineApplyOp::verify() { + // Check that affine map attribute was specified. + auto affineMapAttr = getAttrOfType("map"); + if (!affineMapAttr) + return emitOpError("requires an affine map"); + + // Check input and output dimensions match. + auto map = affineMapAttr.getValue(); + + // Verify that operand count matches affine map dimension and symbol count. + if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) + return emitOpError( + "operand count and affine map dimension and symbol count must match"); + + // Verify that all operands are of `index` type. + for (Type t : getOperandTypes()) { + if (!t.isIndex()) + return emitOpError("operands must be of type 'index'"); + } + + if (!getResult()->getType().isIndex()) + return emitOpError("result must be of type 'index'"); + + // Verify that the operands are valid dimension and symbol identifiers. + if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), + map.getNumDims()))) + return failure(); + + // Verify that the map only produces one result. + if (map.getNumResults() != 1) + return emitOpError("mapping must produce one value"); + + return success(); +} + +// The result of the affine apply operation can be used as a dimension id if it +// is a CFG value or if it is an Value, and all the operands are valid +// dimension ids. +bool AffineApplyOp::isValidDim() { + return llvm::all_of(getOperands(), + [](Value *op) { return mlir::isValidDim(op); }); +} + +// The result of the affine apply operation can be used as a symbol if it is +// a CFG value or if it is an Value, and all the operands are symbols. 
+bool AffineApplyOp::isValidSymbol() { + return llvm::all_of(getOperands(), + [](Value *op) { return mlir::isValidSymbol(op); }); +} + +OpFoldResult AffineApplyOp::fold(ArrayRef operands) { + auto map = getAffineMap(); + + // Fold dims and symbols to existing values. + auto expr = map.getResult(0); + if (auto dim = expr.dyn_cast()) + return getOperand(dim.getPosition()); + if (auto sym = expr.dyn_cast()) + return getOperand(map.getNumDims() + sym.getPosition()); + + // Otherwise, default to folding the map. + SmallVector result; + if (failed(map.constantFold(operands, result))) + return {}; + return result[0]; +} + +namespace { +/// An `AffineApplyNormalizer` is a helper class that is not visible to the user +/// and supports renumbering operands of AffineApplyOp. This acts as a +/// reindexing map of Value* to positional dims or symbols and allows +/// simplifications such as: +/// +/// ```mlir +/// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) +/// ``` +/// +/// into: +/// +/// ```mlir +/// %1 = affine.apply () -> (0) +/// ``` +struct AffineApplyNormalizer { + AffineApplyNormalizer(AffineMap map, ArrayRef operands); + + /// Returns the AffineMap resulting from normalization. + AffineMap getAffineMap() { return affineMap; } + + SmallVector getOperands() { + SmallVector res(reorderedDims); + res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); + return res; + } + +private: + /// Helper function to insert `v` into the coordinate system of the current + /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding + /// renumbered position. + AffineDimExpr renumberOneDim(Value *v); + + /// Given an `other` normalizer, this rewrites `other.affineMap` in the + /// coordinate system of the current AffineApplyNormalizer. + /// Returns the rewritten AffineMap and updates the dims and symbols of + /// `this`. + AffineMap renumber(const AffineApplyNormalizer &other); + + /// Maps of Value* to position in `affineMap`. + DenseMap dimValueToPosition; + + /// Ordered dims and symbols matching positional dims and symbols in + /// `affineMap`. + SmallVector reorderedDims; + SmallVector concatenatedSymbols; + + AffineMap affineMap; + + /// Used with RAII to control the depth at which AffineApply are composed + /// recursively. Only accepts depth 1 for now to allow a behavior where a + /// newly composed AffineApplyOp does not increase the length of the chain of + /// AffineApplyOps. Full composition is implemented iteratively on top of + /// this behavior. + static unsigned &affineApplyDepth() { + static thread_local unsigned depth = 0; + return depth; + } + static constexpr unsigned kMaxAffineApplyDepth = 1; + + AffineApplyNormalizer() { affineApplyDepth()++; } + +public: + ~AffineApplyNormalizer() { affineApplyDepth()--; } +}; +} // end anonymous namespace. 
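The renumbering performed by the normalizer above can be illustrated with a small standalone sketch (an illustrative aside, not the actual AffineApplyNormalizer interface): every distinct operand value is assigned exactly one dim position, which is why (d0, d1) -> (d0 - d1) applied to (%0, %0) can later fold to the constant 0. The ValueHandle type and renumberDims helper below are hypothetical stand-ins rather than MLIR APIs.

#include <map>
#include <vector>

using ValueHandle = const void *; // hypothetical stand-in for mlir::Value *

// Deduplicates operands: each distinct value is assigned one dim position, and
// remapping[i] gives the position that operand i was renumbered to.
std::vector<unsigned> renumberDims(const std::vector<ValueHandle> &operands,
                                   std::vector<ValueHandle> &uniqueOperands) {
  std::map<ValueHandle, unsigned> position;
  std::vector<unsigned> remapping;
  remapping.reserve(operands.size());
  for (ValueHandle v : operands) {
    auto it = position.find(v);
    if (it == position.end()) {
      it = position.insert({v, static_cast<unsigned>(uniqueOperands.size())})
               .first;
      uniqueOperands.push_back(v);
    }
    remapping.push_back(it->second);
  }
  return remapping; // For operands (%0, %0): remapping = {0, 0}, one unique dim.
}

With both operands renumbered to position 0, the expression d0 - d1 becomes d0 - d0 and simplifies to 0, matching the `affine.apply () -> (0)` form shown in the class comment above.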
+ +AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { + DenseMap::iterator iterPos; + bool inserted = false; + std::tie(iterPos, inserted) = + dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); + if (inserted) { + reorderedDims.push_back(v); + } + return getAffineDimExpr(iterPos->second, v->getContext()) + .cast(); +} + +AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { + SmallVector dimRemapping; + for (auto *v : other.reorderedDims) { + auto kvp = other.dimValueToPosition.find(v); + if (dimRemapping.size() <= kvp->second) + dimRemapping.resize(kvp->second + 1); + dimRemapping[kvp->second] = renumberOneDim(kvp->first); + } + unsigned numSymbols = concatenatedSymbols.size(); + unsigned numOtherSymbols = other.concatenatedSymbols.size(); + SmallVector symRemapping(numOtherSymbols); + for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { + symRemapping[idx] = + getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); + } + concatenatedSymbols.insert(concatenatedSymbols.end(), + other.concatenatedSymbols.begin(), + other.concatenatedSymbols.end()); + auto map = other.affineMap; + return map.replaceDimsAndSymbols(dimRemapping, symRemapping, + dimRemapping.size(), symRemapping.size()); +} + +// Gather the positions of the operands that are produced by an AffineApplyOp. +static llvm::SetVector +indicesFromAffineApplyOp(ArrayRef operands) { + llvm::SetVector res; + for (auto en : llvm::enumerate(operands)) + if (isa_and_nonnull(en.value()->getDefiningOp())) + res.insert(en.index()); + return res; +} + +// Support the special case of a symbol coming from an AffineApplyOp that needs +// to be composed into the current AffineApplyOp. +// This case is handled by rewriting all such symbols into dims for the purpose +// of allowing mathematical AffineMap composition. +// Returns an AffineMap where symbols that come from an AffineApplyOp have been +// rewritten as dims and are ordered after the original dims. +// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which +// symbols are represented as dims. This loss is static but can still be +// recovered dynamically (with `isValidSymbol`). Still this is annoying for the +// semi-affine map case. A dynamic canonicalization of all dims that are valid +// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even +// results in better simplifications and foldings. But we should evaluate +// whether this behavior is what we really want after using more. +static AffineMap promoteComposedSymbolsAsDims(AffineMap map, + ArrayRef symbols) { + if (symbols.empty()) { + return map; + } + + // Sanity check on symbols. + for (auto *sym : symbols) { + assert(isValidSymbol(sym) && "Expected only valid symbols"); + (void)sym; + } + + // Extract the symbol positions that come from an AffineApplyOp and + // needs to be rewritten as dims. + auto symPositions = indicesFromAffineApplyOp(symbols); + if (symPositions.empty()) { + return map; + } + + // Create the new map by replacing each symbol at pos by the next new dim. + unsigned numDims = map.getNumDims(); + unsigned numSymbols = map.getNumSymbols(); + unsigned numNewDims = 0; + unsigned numNewSymbols = 0; + SmallVector symReplacements(numSymbols); + for (unsigned i = 0; i < numSymbols; ++i) { + symReplacements[i] = + symPositions.count(i) > 0 + ? 
getAffineDimExpr(numDims + numNewDims++, map.getContext()) + : getAffineSymbolExpr(numNewSymbols++, map.getContext()); + } + assert(numSymbols >= numNewDims); + AffineMap newMap = map.replaceDimsAndSymbols( + {}, symReplacements, numDims + numNewDims, numNewSymbols); + + return newMap; +} + +/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to +/// keep a correspondence between the mathematical `map` and the `operands` of +/// a given AffineApplyOp. This correspondence is maintained by iterating over +/// the operands and forming an `auxiliaryMap` that can be composed +/// mathematically with `map`. To keep this correspondence in cases where +/// symbols are produced by affine.apply operations, we perform a local rewrite +/// of symbols as dims. +/// +/// Rationale for locally rewriting symbols as dims: +/// ================================================ +/// The mathematical composition of AffineMap must always concatenate symbols +/// because it does not have enough information to do otherwise. For example, +/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce +/// `(d0)[s0, s1] -> (d0 + s0 + s1)`. +/// +/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when +/// applied to the same mlir::Value* for both s0 and s1. +/// As a consequence mathematical composition of AffineMap always concatenates +/// symbols. +/// +/// When AffineMaps are used in AffineApplyOp however, they may specify +/// composition via symbols, which is ambiguous mathematically. This corner case +/// is handled by locally rewriting such symbols that come from AffineApplyOp +/// into dims and composing through dims. +/// TODO(andydavis, ntv): Composition via symbols comes at a significant code +/// complexity. Alternatively we should investigate whether we want to +/// explicitly disallow symbols coming from affine.apply and instead force the +/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 +/// extra API calls for such uses, which haven't popped up until now) and the +/// benefit potentially big: simpler and more maintainable code for a +/// non-trivial, recursive, procedure. +AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, + ArrayRef operands) + : AffineApplyNormalizer() { + static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); + assert(map.getNumInputs() == operands.size() && + "number of operands does not match the number of map inputs"); + + LLVM_DEBUG(map.print(dbgs() << "\nInput map: ")); + + // Promote symbols that come from an AffineApplyOp to dims by rewriting the + // map to always refer to: + // (dims, symbols coming from AffineApplyOp, other symbols). + // The order of operands can remain unchanged. + // This is a simplification that relies on 2 ordering properties: + // 1. rewritten symbols always appear after the original dims in the map; + // 2. operands are traversed in order and either dispatched to: + // a. auxiliaryExprs (dims and symbols rewritten as dims); + // b. concatenatedSymbols (all other symbols) + // This allows operand order to remain unchanged. + unsigned numDimsBeforeRewrite = map.getNumDims(); + map = promoteComposedSymbolsAsDims(map, + operands.take_back(map.getNumSymbols())); + + LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: ")); + + SmallVector auxiliaryExprs; + bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); + // We fully spell out the 2 cases below. In this particular instance a little + // code duplication greatly improves readability. 
+ // Note that the first branch would disappear if we only supported full + // composition (i.e. infinite kMaxAffineApplyDepth). + if (!furtherCompose) { + // 1. Only dispatch dims or symbols. + for (auto en : llvm::enumerate(operands)) { + auto *t = en.value(); + assert(t->getType().isIndex()); + bool isDim = (en.index() < map.getNumDims()); + if (isDim) { + // a. The mathematical composition of AffineMap composes dims. + auxiliaryExprs.push_back(renumberOneDim(t)); + } else { + // b. The mathematical composition of AffineMap concatenates symbols. + // We do the same for symbol operands. + concatenatedSymbols.push_back(t); + } + } + } else { + assert(numDimsBeforeRewrite <= operands.size()); + // 2. Compose AffineApplyOps and dispatch dims or symbols. + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + auto *t = operands[i]; + auto affineApply = dyn_cast_or_null(t->getDefiningOp()); + if (affineApply) { + // a. Compose affine.apply operations. + LLVM_DEBUG(affineApply.getOperation()->print( + dbgs() << "\nCompose AffineApplyOp recursively: ")); + AffineMap affineApplyMap = affineApply.getAffineMap(); + SmallVector affineApplyOperands( + affineApply.getOperands().begin(), affineApply.getOperands().end()); + AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); + + LLVM_DEBUG(normalizer.affineMap.print( + dbgs() << "\nRenumber into current normalizer: ")); + + auto renumberedMap = renumber(normalizer); + + LLVM_DEBUG( + renumberedMap.print(dbgs() << "\nRecursive composition yields: ")); + + auxiliaryExprs.push_back(renumberedMap.getResult(0)); + } else { + if (i < numDimsBeforeRewrite) { + // b. The mathematical composition of AffineMap composes dims. + auxiliaryExprs.push_back(renumberOneDim(t)); + } else { + // c. The mathematical composition of AffineMap concatenates symbols. + // We do the same for symbol operands. + concatenatedSymbols.push_back(t); + } + } + } + } + + // Early exit if `map` is already composed. + if (auxiliaryExprs.empty()) { + affineMap = map; + return; + } + + assert(concatenatedSymbols.size() >= map.getNumSymbols() && + "Unexpected number of concatenated symbols"); + auto numDims = dimValueToPosition.size(); + auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); + auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs); + + LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); + LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); + LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: ")); + + // TODO(andydavis,ntv): Disabling simplification results in major speed gains. + // Another option is to cache the results as it is expected a lot of redundant + // work is performed in practice. + affineMap = simplifyAffineMap(map.compose(auxiliaryMap)); + + LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); + LLVM_DEBUG(dbgs() << "\n"); +} + +/// Implements `map` and `operands` composition and simplification to support +/// `makeComposedAffineApply`. This can be called to achieve the same effects +/// on `map` and `operands` without creating an AffineApplyOp that needs to be +/// immediately deleted. 
+static void composeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl *operands) { + AffineApplyNormalizer normalizer(*map, *operands); + auto normalizedMap = normalizer.getAffineMap(); + auto normalizedOperands = normalizer.getOperands(); + canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); + *map = normalizedMap; + *operands = normalizedOperands; + assert(*map); +} + +void mlir::fullyComposeAffineMapAndOperands( + AffineMap *map, SmallVectorImpl *operands) { + while (llvm::any_of(*operands, [](Value *v) { + return isa_and_nonnull(v->getDefiningOp()); + })) { + composeAffineMapAndOperands(map, operands); + } +} + +AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef operands) { + AffineMap normalizedMap = map; + SmallVector normalizedOperands(operands.begin(), operands.end()); + composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); + assert(normalizedMap); + return b.create(loc, normalizedMap, normalizedOperands); +} + +// A symbol may appear as a dim in affine.apply operations. This function +// canonicalizes dims that are valid symbols into actual symbols. +static void +canonicalizePromotedSymbols(AffineMap *map, + llvm::SmallVectorImpl *operands) { + if (!map || operands->empty()) + return; + + assert(map->getNumInputs() == operands->size() && + "map inputs must match number of operands"); + + auto *context = map->getContext(); + SmallVector resultOperands; + resultOperands.reserve(operands->size()); + SmallVector remappedSymbols; + remappedSymbols.reserve(operands->size()); + unsigned nextDim = 0; + unsigned nextSym = 0; + unsigned oldNumSyms = map->getNumSymbols(); + SmallVector dimRemapping(map->getNumDims()); + for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) { + if (i < map->getNumDims()) { + if (isValidSymbol((*operands)[i])) { + // This is a valid symbols that appears as a dim, canonicalize it. + dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); + remappedSymbols.push_back((*operands)[i]); + } else { + dimRemapping[i] = getAffineDimExpr(nextDim++, context); + resultOperands.push_back((*operands)[i]); + } + } else { + resultOperands.push_back((*operands)[i]); + } + } + + resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); + *operands = resultOperands; + *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim, + oldNumSyms + nextSym); + + assert(map->getNumInputs() == operands->size() && + "map inputs must match number of operands"); +} + +void mlir::canonicalizeMapAndOperands( + AffineMap *map, llvm::SmallVectorImpl *operands) { + if (!map || operands->empty()) + return; + + assert(map->getNumInputs() == operands->size() && + "map inputs must match number of operands"); + + canonicalizePromotedSymbols(map, operands); + + // Check to see what dims are used. 
+ llvm::SmallBitVector usedDims(map->getNumDims()); + llvm::SmallBitVector usedSyms(map->getNumSymbols()); + map->walkExprs([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast()) + usedDims[dimExpr.getPosition()] = true; + else if (auto symExpr = expr.dyn_cast()) + usedSyms[symExpr.getPosition()] = true; + }); + + auto *context = map->getContext(); + + SmallVector resultOperands; + resultOperands.reserve(operands->size()); + + llvm::SmallDenseMap seenDims; + SmallVector dimRemapping(map->getNumDims()); + unsigned nextDim = 0; + for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { + if (usedDims[i]) { + auto it = seenDims.find((*operands)[i]); + if (it == seenDims.end()) { + dimRemapping[i] = getAffineDimExpr(nextDim++, context); + resultOperands.push_back((*operands)[i]); + seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); + } else { + dimRemapping[i] = it->second; + } + } + } + llvm::SmallDenseMap seenSymbols; + SmallVector symRemapping(map->getNumSymbols()); + unsigned nextSym = 0; + for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { + if (usedSyms[i]) { + auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); + if (it == seenSymbols.end()) { + symRemapping[i] = getAffineSymbolExpr(nextSym++, context); + resultOperands.push_back((*operands)[i + map->getNumDims()]); + seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()], + symRemapping[i])); + } else { + symRemapping[i] = it->second; + } + } + } + *map = + map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); + *operands = resultOperands; +} + +namespace { +/// Simplify AffineApply operations. +/// +struct SimplifyAffineApply : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineApplyOp apply, + PatternRewriter &rewriter) const override { + auto map = apply.getAffineMap(); + + AffineMap oldMap = map; + SmallVector resultOperands(apply.getOperands()); + composeAffineMapAndOperands(&map, &resultOperands); + if (map == oldMap) + return matchFailure(); + + rewriter.replaceOpWithNewOp(apply, map, resultOperands); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +void AffineApplyOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// Common canonicalization pattern support logic +//===----------------------------------------------------------------------===// + +namespace { +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref_cast +/// into the root operation directly. +struct MemRefCastFolder : public RewritePattern { + /// The rootOpName is the name of the root operation to match against. + MemRefCastFolder(StringRef rootOpName, MLIRContext *context) + : RewritePattern(rootOpName, 1, context) {} + + PatternMatchResult match(Operation *op) const override { + for (auto *operand : op->getOperands()) + if (matchPattern(operand, m_Op())) + return matchSuccess(); + + return matchFailure(); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + if (auto *memref = op->getOperand(i)->getDefiningOp()) + if (auto cast = dyn_cast(memref)) + op->setOperand(i, cast.getOperand()); + rewriter.updatedRootInPlace(op); + } +}; + +} // end anonymous namespace. 
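To make the "someop(memrefcast) -> someop" folding concrete, a rough before/after at the IR level is sketched below (illustration only; the textual syntax is approximated and not taken from this change). MemRefCastFolder only rewrites the consuming op's operand in place; the now-unused memref_cast is left for later dead-code cleanup.

//   %1 = memref_cast %0 : memref<4xf32> to memref<?xf32>
//   %2 = affine.load %1[%i] : memref<?xf32>
//
// After the fold, the load reads directly through the source of the cast:
//
//   %2 = affine.load %0[%i] : memref<4xf32>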
+ +//===----------------------------------------------------------------------===// +// AffineDmaStartOp +//===----------------------------------------------------------------------===// + +// TODO(b/133776335) Check that map operands are loop IVs or symbols. +void AffineDmaStartOp::build(Builder *builder, OperationState *result, + Value *srcMemRef, AffineMap srcMap, + ArrayRef srcIndices, Value *destMemRef, + AffineMap dstMap, ArrayRef destIndices, + Value *tagMemRef, AffineMap tagMap, + ArrayRef tagIndices, Value *numElements, + Value *stride, Value *elementsPerStride) { + result->addOperands(srcMemRef); + result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap)); + result->addOperands(srcIndices); + result->addOperands(destMemRef); + result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap)); + result->addOperands(destIndices); + result->addOperands(tagMemRef); + result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result->addOperands(tagIndices); + result->addOperands(numElements); + if (stride) { + result->addOperands({stride, elementsPerStride}); + } +} + +void AffineDmaStartOp::print(OpAsmPrinter *p) { + *p << "affine.dma_start " << *getSrcMemRef() << '['; + SmallVector operands(getSrcIndices()); + p->printAffineMapOfSSAIds(getSrcMapAttr(), operands); + *p << "], " << *getDstMemRef() << '['; + operands.assign(getDstIndices().begin(), getDstIndices().end()); + p->printAffineMapOfSSAIds(getDstMapAttr(), operands); + *p << "], " << *getTagMemRef() << '['; + operands.assign(getTagIndices().begin(), getTagIndices().end()); + p->printAffineMapOfSSAIds(getTagMapAttr(), operands); + *p << "], " << *getNumElements(); + if (isStrided()) { + *p << ", " << *getStride(); + *p << ", " << *getNumElementsPerStride(); + } + *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " + << getTagMemRefType(); +} + +// Parse AffineDmaStartOp. +// Ex: +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, +// %stride, %num_elt_per_stride +// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> +// +ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + AffineMapAttr srcMapAttr; + SmallVector srcMapOperands; + OpAsmParser::OperandType dstMemRefInfo; + AffineMapAttr dstMapAttr; + SmallVector dstMapOperands; + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector tagMapOperands; + OpAsmParser::OperandType numElementsInfo; + SmallVector strideInfo; + + SmallVector types; + auto indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) dst memref followed by its affine maps operands (in square brackets). + // *) src memref followed by its affine map operands (in square brackets). + // *) tag memref followed by its affine map operands (in square brackets). + // *) number of elements transferred by DMA operation. 
+ if (parser->parseOperand(srcMemRefInfo) || + parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, + getSrcMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || + parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, + getDstMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(tagMemRefInfo) || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(numElementsInfo)) + return failure(); + + // Parse optional stride and elements per stride. + if (parser->parseTrailingOperandList(strideInfo)) { + return failure(); + } + if (!strideInfo.empty() && strideInfo.size() != 2) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + bool isStrided = strideInfo.size() == 2; + + if (parser->parseColonTypeList(types)) + return failure(); + + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "expected three types"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcMapOperands, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstMapOperands, indexType, result->operands) || + parser->resolveOperand(tagMemRefInfo, types[2], result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (isStrided) { + if (parser->resolveOperands(strideInfo, indexType, result->operands)) + return failure(); + } + + // Check that src/dst/tag operand counts match their map.numInputs. + if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || + dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || + tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "memref operand count not equal to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaStartOp::verify() { + if (!getOperand(getSrcMemRefOperandIndex())->getType().isa()) + return emitOpError("expected DMA source to be of memref type"); + if (!getOperand(getDstMemRefOperandIndex())->getType().isa()) + return emitOpError("expected DMA destination to be of memref type"); + if (!getOperand(getTagMemRefOperandIndex())->getType().isa()) + return emitOpError("expected DMA tag to be of memref type"); + + // DMAs from different memory spaces supported. 
+ if (getSrcMemorySpace() == getDstMemorySpace()) { + return emitOpError("DMA should be between different memory spaces"); + } + unsigned numInputsAllMaps = getSrcMap().getNumInputs() + + getDstMap().getNumInputs() + + getTagMap().getNumInputs(); + if (getNumOperands() != numInputsAllMaps + 3 + 1 && + getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + + for (auto *idx : getSrcIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("src index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("src index must be a dimension or symbol identifier"); + } + for (auto *idx : getDstIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("dst index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("dst index must be a dimension or symbol identifier"); + } + for (auto *idx : getTagIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("tag index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("tag index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineDmaStartOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_start(memrefcast) -> dma_start + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineDmaWaitOp +//===----------------------------------------------------------------------===// + +// TODO(b/133776335) Check that map operands are loop IVs or symbols. +void AffineDmaWaitOp::build(Builder *builder, OperationState *result, + Value *tagMemRef, AffineMap tagMap, + ArrayRef tagIndices, Value *numElements) { + result->addOperands(tagMemRef); + result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result->addOperands(tagIndices); + result->addOperands(numElements); +} + +void AffineDmaWaitOp::print(OpAsmPrinter *p) { + *p << "affine.dma_wait " << *getTagMemRef() << '['; + SmallVector operands(getTagIndices()); + p->printAffineMapOfSSAIds(getTagMapAttr(), operands); + *p << "], "; + p->printOperand(getNumElements()); + *p << " : " << getTagMemRef()->getType(); +} + +// Parse AffineDmaWaitOp. +// Eg: +// affine.dma_wait %tag[%index], %num_elements +// : memref<1 x i32, (d0) -> (d0), 4> +// +ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector tagMapOperands; + Type type; + auto indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its map operands, and dma size. 
+ if (parser->parseOperand(tagMemRefInfo) || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseColonType(type) || + parser->resolveOperand(tagMemRefInfo, type, result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (!type.isa()) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "tag memref operand count != to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaWaitOp::verify() { + if (!getOperand(0)->getType().isa()) + return emitOpError("expected DMA tag to be of memref type"); + for (auto *idx : getTagIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("index to dma_wait must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineDmaWaitOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_wait(memrefcast) -> dma_wait + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef lbOperands, AffineMap lbMap, + ArrayRef ubOperands, AffineMap ubMap, + int64_t step) { + assert(((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs()) && + "lower bound operand count does not match the affine map"); + assert(((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs()) && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Create a region and a block for the body. The argument of the region is + // the loop induction variable. + Region *bodyRegion = result->addRegion(); + Block *body = new Block(); + body->addArgument(IndexType::get(builder->getContext())); + bodyRegion->push_back(body); + ensureTerminator(*bodyRegion, *builder, result->location); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +static LogicalResult verify(AffineForOp op) { + // Check that the body defines as single block argument for the induction + // variable. 
+ auto *body = op.getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return op.emitOpError( + "expected body to have a single index argument for the " + "induction variable"); + + // Verify that there are enough operands for the bounds. + AffineMap lowerBoundMap = op.getLowerBoundMap(), + upperBoundMap = op.getUpperBoundMap(); + if (op.getNumOperands() != + (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) + return op.emitOpError( + "operand count must match with affine map dimension and symbol count"); + + // Verify that the bound operands are valid dimension/symbols. + /// Lower bound. + if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), + op.getLowerBoundMap().getNumDims()))) + return failure(); + /// Upper bound. + if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), + op.getUpperBoundMap().getNumDims()))) + return failure(); + return success(); +} + +/// Parse a for operation loop bounds. +static ParseResult parseBound(bool isLower, OperationState *result, + OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = + failed(p->parseOptionalKeyword(isLower ? "max" : "min")); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. + SmallVector boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return failure(); + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return failure(); + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. + AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return success(); + } + + // Get the attribute location. + llvm::SMLoc attrLoc = p->getCurrentLocation(); + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, + result->attributes)) + return failure(); + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return failure(); + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. 
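For reference, a short sketch of what each surface form accepted by parseBound() above is normalized to (assuming `builder` is the same Builder the parser uses):

  // "%n"  (bare SSA symbol)       ->  ()[s0] -> (s0), with %n as its only operand.
  AffineMap ssaShorthand = builder.getSymbolIdentityMap();
  // "10"  (bare integer constant) ->  () -> (10), with no operands.
  AffineMap cstShorthand = builder.getConstantAffineMap(10);
  // "max #map(%d)[%s]" full form  ->  the parsed map itself, plus its dim and
  // symbol operands, validated against the map's dim and symbol counts.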
+ if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(attrLoc, "lower loop bound affine map with " + "multiple results requires 'max' prefix"); + } + return p->emitError(attrLoc, "upper loop bound affine map with multiple " + "results requires 'min' prefix"); + } + return success(); + } + + // Parse custom assembly form. + if (auto integerAttr = boundAttr.dyn_cast()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return success(); + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + OpAsmParser::OperandType inductionVariable; + // Parse the induction variable followed by '='. + if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) + return failure(); + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return failure(); + + // Parse the optional loop step, we default to 1 if one is not present. + if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + AffineForOp::getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc = parser->getCurrentLocation(); + IntegerAttr stepAttr; + if (parser->parseAttribute(stepAttr, builder.getIndexType(), + AffineForOp::getStepAttrName().data(), + result->attributes)) + return failure(); + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body region. + Region *body = result->addRegion(); + if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) + return failure(); + + AffineForOp::ensureTerminator(*body, builder, result->location); + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); + return success(); +} + +static void printBound(AffineMapAttr boundMap, + Operation::operand_range boundOperands, + const char *prefix, OpAsmPrinter *p) { + AffineMap map = boundMap.getValue(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast()) { + p->printOperand(*boundOperands.begin()); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. 
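To make the shorthand criteria concrete, a hedged sketch (assuming a Builder `b`) of bounds that do and do not qualify for the custom assembly form handled above:

  AffineMap cst = b.getConstantAffineMap(128);   // prints as the literal: 128
  AffineMap sym = b.getSymbolIdentityMap();      // prints as its single operand, e.g. %n
  AffineMap gen = b.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                 b.getAffineDimExpr(0) + 1);
  // `gen` involves a dimension, so it is printed in full map form with its operands.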
+ *p << prefix << ' '; + } + + // Print the map and its operands. + *p << boundMap; + printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), + map.getNumDims(), p); +} + +void print(OpAsmPrinter *p, AffineForOp op) { + *p << "affine.for "; + p->printOperand(op.getBody()->getArgument(0)); + *p << " = "; + printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); + *p << " to "; + printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); + + if (op.getStep() != 1) + *p << " step " << op.getStep(); + p->printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{op.getLowerBoundAttrName(), + op.getUpperBoundAttrName(), + op.getStepAttrName()}); +} + +namespace { +/// This is a pattern to fold constant loop bounds. +struct AffineForLoopBoundFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineForOp forOp, + PatternRewriter &rewriter) const override { + auto foldLowerOrUpperBound = [&forOp](bool lower) { + // Check to see if each of the operands is the result of a constant. If + // so, get the value. If not, ignore it. + SmallVector operandConstants; + auto boundOperands = + lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); + for (auto *operand : boundOperands) { + Attribute operandCst; + matchPattern(operand, m_Constant(&operandCst)); + operandConstants.push_back(operandCst); + } + + AffineMap boundMap = + lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); + assert(boundMap.getNumResults() >= 1 && + "bound maps should have at least one result"); + SmallVector foldedResults; + if (failed(boundMap.constantFold(operandConstants, foldedResults))) + return failure(); + + // Compute the max or min as applicable over the results. + assert(!foldedResults.empty() && + "bounds should have at least one result"); + auto maxOrMin = foldedResults[0].cast().getValue(); + for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { + auto foldedResult = foldedResults[i].cast().getValue(); + maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) + : llvm::APIntOps::smin(maxOrMin, foldedResult); + } + lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) + : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); + return success(); + }; + + // Try to fold the lower bound. + bool folded = false; + if (!forOp.hasConstantLowerBound()) + folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); + + // Try to fold the upper bound. + if (!forOp.hasConstantUpperBound()) + folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); + + // If any of the bounds were folded we return success. 
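A hedged, worked instance of the constant folding performed by the pattern above (assuming a Builder `b`): a lower bound map with results (s0, 10) whose operand is produced by a `constant 32` op.

  SmallVector<Attribute, 1> cstOperands = {b.getIntegerAttr(b.getIndexType(), 32)};
  SmallVector<Attribute, 2> folded;
  AffineMap lbMap = b.getAffineMap(/*dimCount=*/0, /*symbolCount=*/1,
                                   {b.getAffineSymbolExpr(0), b.getAffineConstantExpr(10)});
  if (succeeded(lbMap.constantFold(cstOperands, folded))) {
    // folded == {32, 10}; a lower bound takes the max of the results, so the
    // loop is rewritten via forOp.setConstantLowerBound(32). Upper bounds take the min.
  }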
+ if (!folded) + return matchFailure(); + rewriter.updatedRootInPlace(forOp); + return matchSuccess(); + } +}; +} // end anonymous namespace + +void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +AffineBound AffineForOp::getLowerBound() { + auto lbMap = getLowerBoundMap(); + return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); +} + +AffineBound AffineForOp::getUpperBound() { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(), + ubMap); +} + +void AffineForOp::setLowerBound(ArrayRef lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(lbOperands.begin(), lbOperands.end()); + + auto ubOperands = getUpperBoundOperands(); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getOperation()->setOperands(newOperands); + + setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBound(ArrayRef ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector newOperands(getLowerBoundOperands()); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getOperation()->setOperands(newOperands); + + setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); +} + +void AffineForOp::setLowerBoundMap(AffineMap map) { + auto lbMap = getLowerBoundMap(); + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)lbMap; + setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBoundMap(AffineMap map) { + auto ubMap = getUpperBoundMap(); + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)ubMap; + setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); +} + +bool AffineForOp::hasConstantLowerBound() { + return getLowerBoundMap().isSingleConstant(); +} + +bool AffineForOp::hasConstantUpperBound() { + return getUpperBoundMap().isSingleConstant(); +} + +int64_t AffineForOp::getConstantLowerBound() { + return getLowerBoundMap().getSingleConstantResult(); +} + +int64_t AffineForOp::getConstantUpperBound() { + return getUpperBoundMap().getSingleConstantResult(); +} + +void AffineForOp::setConstantLowerBound(int64_t value) { + setLowerBound({}, AffineMap::getConstantMap(value, getContext())); +} + +void AffineForOp::setConstantUpperBound(int64_t value) { + setUpperBound({}, AffineMap::getConstantMap(value, getContext())); +} + +AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool AffineForOp::matchingBoundOperandList() { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = 
lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. + if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool mlir::isForInductionVar(Value *val) { + return getForInductionVarOwner(val) != AffineForOp(); +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +AffineForOp mlir::getForInductionVarOwner(Value *val) { + auto *ivArg = dyn_cast(val); + if (!ivArg || !ivArg->getOwner()) + return AffineForOp(); + auto *containingInst = ivArg->getOwner()->getParent()->getParentOp(); + return dyn_cast(containingInst); +} + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. +void mlir::extractForInductionVars(ArrayRef forInsts, + SmallVectorImpl *ivs) { + ivs->reserve(forInsts.size()); + for (auto forInst : forInsts) + ivs->push_back(forInst.getInductionVar()); +} + +//===----------------------------------------------------------------------===// +// AffineIfOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AffineIfOp op) { + // Verify that we have a condition attribute. + auto conditionAttr = + op.getAttrOfType(op.getConditionAttrName()); + if (!conditionAttr) + return op.emitOpError( + "requires an integer set attribute named 'condition'"); + + // Verify that there are enough operands for the condition. + IntegerSet condition = conditionAttr.getValue(); + if (op.getNumOperands() != condition.getNumOperands()) + return op.emitOpError( + "operand count and condition integer set dimension and " + "symbol count must match"); + + // Verify that the operands are valid dimension/symbols. + if (failed(verifyDimAndSymbolIdentifiers( + op, op.getOperation()->getNonSuccessorOperands(), + condition.getNumDims()))) + return failure(); + + // Verify that the entry of each child region does not have arguments. + for (auto ®ion : op.getOperation()->getRegions()) { + for (auto &b : region) + if (b.getNumArguments() != 0) + return op.emitOpError( + "requires that child entry blocks have no arguments"); + } + return success(); +} + +ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return failure(); + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Create the regions for 'then' and 'else'. The latter must be created even + // if it remains empty for the validity of the operation. + result->regions.reserve(2); + Region *thenRegion = result->addRegion(); + Region *elseRegion = result->addRegion(); + + // Parse the 'then' region. + if (parser->parseRegion(*thenRegion, {}, {})) + return failure(); + AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(), + result->location); + + // If we find an 'else' keyword then parse the 'else' region. 
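A small sketch of the induction-variable helpers defined earlier in this file, with `val` an assumed Value* that may or may not be a loop IV:

  if (isForInductionVar(val)) {
    AffineForOp owner = getForInductionVarOwner(val);
    // `owner` is the affine.for whose body argument is `val`; e.g. inspect bounds.
    if (owner.hasConstantLowerBound() && owner.hasConstantUpperBound()) {
      int64_t span = owner.getConstantUpperBound() - owner.getConstantLowerBound();
      (void)span; // e.g. feed a trip-count estimate.
    }
  }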
+ if (!parser->parseOptionalKeyword("else")) { + if (parser->parseRegion(*elseRegion, {}, {})) + return failure(); + AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(), + result->location); + } + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + return success(); +} + +void print(OpAsmPrinter *p, AffineIfOp op) { + auto conditionAttr = + op.getAttrOfType(op.getConditionAttrName()); + *p << "affine.if " << conditionAttr; + printDimAndSymbolList(op.operand_begin(), op.operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printRegion(op.thenRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print the 'else' regions if it has any blocks. + auto &elseRegion = op.elseRegion(); + if (!elseRegion.empty()) { + *p << " else"; + p->printRegion(elseRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + } + + // Print the attribute list. + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/op.getConditionAttrName()); +} + +IntegerSet AffineIfOp::getIntegerSet() { + return getAttrOfType(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); +} + +//===----------------------------------------------------------------------===// +// AffineLoadOp +//===----------------------------------------------------------------------===// + +void AffineLoadOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef operands) { + result->addOperands(operands); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + auto memrefType = operands[0]->getType().cast(); + result->types.push_back(memrefType.getElementType()); +} + +void AffineLoadOp::build(Builder *builder, OperationState *result, + Value *memref, ArrayRef indices) { + result->addOperands(memref); + result->addOperands(indices); + auto memrefType = memref->getType().cast(); + auto rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () + // for zero-dimensional memrefs. + auto map = rank ? 
builder->getMultiDimIdentityMap(rank) + : builder->getEmptyAffineMap(); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result->types.push_back(memrefType.getElementType()); +} + +ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto affineIntTy = builder.getIndexType(); + + MemRefType type; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector mapOperands; + return failure( + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +void AffineLoadOp::print(OpAsmPrinter *p) { + *p << "affine.load " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); +} + +LogicalResult AffineLoadOp::verify() { + if (getType() != getMemRefType().getElementType()) + return emitOpError("result type must match element type of memref"); + + auto mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + AffineMap map = getAttrOfType(getMapAttrName()).getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.load affine map num results must equal" + " memref rank"); + if (map.getNumInputs() != getNumOperands() - 1) + return emitOpError("expects as many subscripts as affine map inputs"); + } else { + if (getMemRefType().getRank() != getNumOperands() - 1) + return emitOpError( + "expects the number of subscripts to be equal to memref rank"); + } + + for (auto *idx : getIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineLoadOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// load(memrefcast) -> load + results.insert(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineStoreOp +//===----------------------------------------------------------------------===// + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, AffineMap map, + ArrayRef operands) { + result->addOperands(valueToStore); + result->addOperands(operands); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); +} + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef operands) { + result->addOperands(valueToStore); + result->addOperands(memref); + result->addOperands(operands); + auto memrefType = memref->getType().cast(); + auto rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () + // for zero-dimensional memrefs. + auto map = rank ? 
builder->getMultiDimIdentityMap(rank) + : builder->getEmptyAffineMap(); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); +} + +ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { + auto affineIntTy = parser->getBuilder().getIndexType(); + + MemRefType type; + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector mapOperands; + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(storeValueInfo, type.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands)); +} + +void AffineStoreOp::print(OpAsmPrinter *p) { + *p << "affine.store " << *getValueToStore(); + *p << ", " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); +} + +LogicalResult AffineStoreOp::verify() { + // First operand must have same type as memref element type. + if (getValueToStore()->getType() != getMemRefType().getElementType()) + return emitOpError("first operand must have same type memref element type"); + + auto mapAttr = getAttrOfType(getMapAttrName()); + if (mapAttr) { + AffineMap map = mapAttr.getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.store affine map num results must equal" + " memref rank"); + if (map.getNumInputs() != getNumOperands() - 2) + return emitOpError("expects as many subscripts as affine map inputs"); + } else { + if (getMemRefType().getRank() != getNumOperands() - 2) + return emitOpError( + "expects the number of subscripts to be equal to memref rank"); + } + + for (auto *idx : getIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("index to store must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineStoreOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// load(memrefcast) -> load + results.insert(getOperationName(), context); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc" diff --git a/mlir/lib/Dialect/AffineOps/CMakeLists.txt b/mlir/lib/Dialect/AffineOps/CMakeLists.txt new file mode 100644 index 00000000000..dbe469369a3 --- /dev/null +++ b/mlir/lib/Dialect/AffineOps/CMakeLists.txt @@ -0,0 +1,10 @@ +add_llvm_library(MLIRAffineOps + AffineOps.cpp + DialectRegistration.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AffineOps + ) +add_dependencies(MLIRAffineOps MLIRAffineOpsIncGen MLIRIR MLIRStandardOps) +target_link_libraries(MLIRAffineOps MLIRIR MLIRStandardOps) + diff --git a/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp new file mode 100644 index 00000000000..9197e3c619f --- /dev/null +++ b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- 
DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration StandardOps; diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 294041df4a5..b0641a9611f 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(AffineOps) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp index 1c5bb6e70c8..c48437f60db 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp @@ -15,7 +15,12 @@ // limitations under the License. // ============================================================================= -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" @@ -23,11 +28,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 33b73336ff7..71f6c78462d 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -28,9 +28,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 98798938077..46713dcff49 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,11 +19,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include 
"mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 094f8fc421d..293e565cda7 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -19,12 +19,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index c521a8f6f5d..02787b12e3d 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -19,11 +19,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopUtils.h" diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index fbe1dcc09f9..2acc5a90f5f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -21,8 +21,8 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index ef92861adf9..3e92ad739e8 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -43,8 +43,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 5a7d926d4f9..e8a8284d392 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -21,7 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/LowerAffine.h" -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 0c6a3567ef3..bfdd5bf05f2 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ 
b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -28,6 +27,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 33433e50d0f..9b71ada100c 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -22,10 +22,10 @@ // SSA scalars live out of 'affine.for'/'affine.if' statements is available. //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index b58b6debc05..0cd979a1c82 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -21,10 +21,10 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 63150c14742..8b314780c9f 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/LoopFusionUtils.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 8b62d007f47..d6a31f92aed 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/LoopUtils.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index e2253c77f67..8d7b7a8b3a1 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -22,11 +22,11 @@ #include "mlir/Transforms/Utils.h" -#include "mlir/AffineOps/AffineOps.h" #include 
"mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 08ee944dc45..cbf616eae10 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -20,12 +20,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index f4c04eeb7fb..529aaa052f6 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -17,7 +17,7 @@ // RUN: mlir-edsc-builder-api-test | FileCheck %s -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Helpers.h" diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index 35a7eba5478..9c541699e99 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 4dd06a58904..604b42817e2 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -19,11 +19,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 6fe277dcfcb..3f00eb01e11 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -19,11 +19,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/StandardTypes.h" -- cgit v1.2.3 From aa2cee9cf53678dde2950087548f502009fec814 Mon Sep 17 00:00:00 2001 From: Uday 
Bondhugula Date: Tue, 27 Aug 2019 17:56:25 -0700 Subject: Refactor / improve replaceAllMemRefUsesWith Refactor replaceAllMemRefUsesWith to split it into two methods: the new method does the replacement on a single op, and is used by the existing one. - make the methods return LogicalResult instead of bool - Earlier, when replacement failed (due to non-deferencing uses of the memref), the set of ops that had already been processed would have been replaced leaving the IR in an inconsistent state. Now, a pass is made over all ops to first check for non-deferencing uses, and then replacement is performed. No test cases were affected because all clients of this method were first checking for non-deferencing uses before calling this method (for other reasons). This isn't true for a use case in another upcoming PR (scalar replacement); clients can now bail out with consistent IR on failure of replaceAllMemRefUsesWith. Add test case. - multiple deferencing uses of the same memref in a single op is possible (we have no such use cases/scenarios), and this has always remained unsupported. Add an assertion for this. - minor fix to another test pipeline-data-transfer case. Signed-off-by: Uday Bondhugula Closes tensorflow/mlir#87 PiperOrigin-RevId: 265808183 --- mlir/include/mlir/Transforms/Utils.h | 38 ++- mlir/lib/Transforms/LoopFusion.cpp | 7 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 21 +- mlir/lib/Transforms/Utils/Utils.cpp | 359 +++++++++++++---------- mlir/test/Transforms/pipeline-data-transfer.mlir | 37 ++- 5 files changed, 270 insertions(+), 192 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index c59d76ae047..23286af8a49 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -37,26 +37,26 @@ class AffineForOp; class Location; class OpBuilder; -/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally -/// remapping the old memref's indices using the supplied affine map, -/// 'indexRemap'. The new memref could be of a different shape or rank. -/// 'extraIndices' provides additional access indices to be added to the start. +/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while +/// optionally remapping the old memref's indices using the supplied affine map, +/// `indexRemap`. The new memref could be of a different shape or rank. +/// `extraIndices` provides additional access indices to be added to the start. /// -/// 'indexRemap' remaps indices of the old memref access to a new set of indices +/// `indexRemap` remaps indices of the old memref access to a new set of indices /// that are used to index the memref. Additional input operands to indexRemap /// can be optionally provided, and they are added at the start of its input -/// list. 'indexRemap' is expected to have only dimensional inputs, and the +/// list. `indexRemap` is expected to have only dimensional inputs, and the /// number of its inputs equal to extraOperands.size() plus rank of the memref. /// 'extraOperands' is an optional argument that corresponds to additional /// operands (inputs) for indexRemap at the beginning of its input list. 
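To make the new LogicalResult contract described in the commit message concrete, a hedged caller-side sketch (with `b` an assumed Builder and `oldMemRef`/`newMemRef` assumed rank-2 memref values): replace each access %A[%i, %j] with %Anew[%j, %i], and bail out cleanly if the memref escapes.

  AffineMap transposeRemap = b.getAffineMap(
      /*dimCount=*/2, /*symbolCount=*/0,
      {b.getAffineDimExpr(1), b.getAffineDimExpr(0)});
  if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                      /*extraIndices=*/{}, transposeRemap))) {
    // A non-dereferencing (escaping) use was found; nothing was rewritten.
  }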
/// -/// 'domInstFilter', if non-null, restricts the replacement to only those +/// `domInstFilter`, if non-null, restricts the replacement to only those /// operations that are dominated by the former; similarly, `postDomInstFilter` /// restricts replacement to only those operations that are postdominated by it. /// /// Returns true on success and false if the replacement is not possible, -/// whenever a memref is used as an operand in a non-deferencing context, except -/// for dealloc's on the memref which are left untouched. See comments at +/// whenever a memref is used as an operand in a non-dereferencing context, +/// except for dealloc's on the memref which are left untouched. See comments at /// function definition for an example. // // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]: @@ -66,12 +66,20 @@ class OpBuilder; // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). // TODO(bondhugula): allow extraIndices to be added at any position. -bool replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, - ArrayRef extraIndices = {}, - AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - Operation *domInstFilter = nullptr, - Operation *postDomInstFilter = nullptr); +LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, + ArrayRef extraIndices = {}, + AffineMap indexRemap = AffineMap(), + ArrayRef extraOperands = {}, + Operation *domInstFilter = nullptr, + Operation *postDomInstFilter = nullptr); + +/// Performs the same replacement as the other version above but only for the +/// dereferencing uses of `oldMemRef` in `op`. +LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, + Operation *op, + ArrayRef extraIndices = {}, + AffineMap indexRemap = AffineMap(), + ArrayRef extraOperands = {}); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 46713dcff49..a17481f89c9 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -952,12 +952,13 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, ? AffineMap() : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs); // Replace all users of 'oldMemRef' with 'newMemRef'. - bool ret = + LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, /*domInstFilter=*/&*forOp.getBody()->begin()); - assert(ret && "replaceAllMemrefUsesWith should always succeed here"); - (void)ret; + assert(succeeded(res) && + "replaceAllMemrefUsesWith should always succeed here"); + (void)res; return newMemRef; } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 0cd979a1c82..a814af92a5f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -115,13 +115,14 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, forOp.getInductionVar()); - // replaceAllMemRefUsesWith will always succeed unless the forOp body has - // non-deferencing uses of the memref (dealloc's are fine though). 
- if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, - /*extraIndices=*/{ivModTwoOp}, - /*indexRemap=*/AffineMap(), - /*extraOperands=*/{}, - /*domInstFilter=*/&*forOp.getBody()->begin())) { + // replaceAllMemRefUsesWith will succeed unless the forOp body has + // non-dereferencing uses of the memref (dealloc's are fine though). + if (failed(replaceAllMemRefUsesWith( + oldMemRef, newMemRef, + /*extraIndices=*/{ivModTwoOp}, + /*indexRemap=*/AffineMap(), + /*extraOperands=*/{}, + /*domInstFilter=*/&*forOp.getBody()->begin()))) { LLVM_DEBUG( forOp.emitError("memref replacement for double buffering failed")); ivModTwoOp.erase(); @@ -276,9 +277,9 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. - LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); - LLVM_DEBUG(dmaStartInst->dump()); - // IR still in a valid state. + LLVM_DEBUG(llvm::dbgs() + << "double buffering failed for" << dmaStartInst << "\n";); + // IR still valid and semantically correct. return; } // If the old memref has no more uses, remove its 'dead' alloc if it was diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 8d7b7a8b3a1..b0c9b942352 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -57,16 +57,181 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { return cast(op).getAffineMapAttrForMemRef(memref); } -bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, - ArrayRef extraIndices, - AffineMap indexRemap, - ArrayRef extraOperands, - Operation *domInstFilter, - Operation *postDomInstFilter) { +// Perform the replacement in `op`. +LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, + Operation *op, + ArrayRef extraIndices, + AffineMap indexRemap, + ArrayRef extraOperands) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); - (void)newMemRefRank; + (void)oldMemRefRank; + if (indexRemap) { + assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); + assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); + } else { + assert(oldMemRefRank + extraIndices.size() == newMemRefRank); + } + + // Assert same elemental type. + assert(oldMemRef->getType().cast().getElementType() == + newMemRef->getType().cast().getElementType()); + + if (!isMemRefDereferencingOp(*op)) + // Failure: memref used in a non-dereferencing context (potentially + // escapes); no replacement in these cases. + return failure(); + + SmallVector usePositions; + for (const auto &opEntry : llvm::enumerate(op->getOperands())) { + if (opEntry.value() == oldMemRef) + usePositions.push_back(opEntry.index()); + } + + // If memref doesn't appear, nothing to do. + if (usePositions.empty()) + return success(); + + if (usePositions.size() > 1) { + // TODO(mlir-team): extend it for this case when needed (rare). 
+ assert(false && "multiple dereferencing uses in a single op not supported"); + return failure(); + } + + unsigned memRefOperandPos = usePositions.front(); + + OpBuilder builder(op); + NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); + AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); + unsigned oldMapNumInputs = oldMap.getNumInputs(); + SmallVector oldMapOperands( + op->operand_begin() + memRefOperandPos + 1, + op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); + + // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. + SmallVector oldMemRefOperands; + SmallVector affineApplyOps; + oldMemRefOperands.reserve(oldMemRefRank); + if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { + for (auto resultExpr : oldMap.getResults()) { + auto singleResMap = builder.getAffineMap( + oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); + auto afOp = builder.create(op->getLoc(), singleResMap, + oldMapOperands); + oldMemRefOperands.push_back(afOp); + affineApplyOps.push_back(afOp); + } + } else { + oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); + } + + // Construct new indices as a remap of the old ones if a remapping has been + // provided. The indices of a memref come right after it, i.e., + // at position memRefOperandPos + 1. + SmallVector remapOperands; + remapOperands.reserve(extraOperands.size() + oldMemRefRank); + remapOperands.append(extraOperands.begin(), extraOperands.end()); + remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + + SmallVector remapOutputs; + remapOutputs.reserve(oldMemRefRank); + + if (indexRemap && + indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { + // Remapped indices. + for (auto resultExpr : indexRemap.getResults()) { + auto singleResMap = builder.getAffineMap( + indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); + auto afOp = builder.create(op->getLoc(), singleResMap, + remapOperands); + remapOutputs.push_back(afOp); + affineApplyOps.push_back(afOp); + } + } else { + // No remapping specified. + remapOutputs.append(remapOperands.begin(), remapOperands.end()); + } + + SmallVector newMapOperands; + newMapOperands.reserve(newMemRefRank); + + // Prepend 'extraIndices' in 'newMapOperands'. + for (auto *extraIndex : extraIndices) { + assert(extraIndex->getDefiningOp()->getNumResults() == 1 && + "single result op's expected to generate these indices"); + assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && + "invalid memory op index"); + newMapOperands.push_back(extraIndex); + } + + // Append 'remapOutputs' to 'newMapOperands'. + newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); + + // Create new fully composed AffineMap for new op to be created. + assert(newMapOperands.size() == newMemRefRank); + auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); + // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here. + fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); + newMap = simplifyAffineMap(newMap); + canonicalizeMapAndOperands(&newMap, &newMapOperands); + // Remove any affine.apply's that became dead as a result of composition. + for (auto *value : affineApplyOps) + if (value->use_empty()) + value->getDefiningOp()->erase(); + + // Construct the new operation using this memref. 
+ OperationState state(op->getLoc(), op->getName()); + state.setOperandListToResizable(op->hasResizableOperandsList()); + state.operands.reserve(op->getNumOperands() + extraIndices.size()); + // Insert the non-memref operands. + state.operands.append(op->operand_begin(), + op->operand_begin() + memRefOperandPos); + // Insert the new memref value. + state.operands.push_back(newMemRef); + + // Insert the new memref map operands. + state.operands.append(newMapOperands.begin(), newMapOperands.end()); + + // Insert the remaining operands unmodified. + state.operands.append(op->operand_begin() + memRefOperandPos + 1 + + oldMapNumInputs, + op->operand_end()); + + // Result types don't change. Both memref's are of the same elemental type. + state.types.reserve(op->getNumResults()); + for (auto *result : op->getResults()) + state.types.push_back(result->getType()); + + // Add attribute for 'newMap', other Attributes do not change. + auto newMapAttr = builder.getAffineMapAttr(newMap); + for (auto namedAttr : op->getAttrs()) { + if (namedAttr.first == oldMapAttrPair.first) { + state.attributes.push_back({namedAttr.first, newMapAttr}); + } else { + state.attributes.push_back(namedAttr); + } + } + + // Create the new operation. + auto *repOp = builder.createOperation(state); + op->replaceAllUsesWith(repOp); + op->erase(); + + return success(); +} + +LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, + ArrayRef extraIndices, + AffineMap indexRemap, + ArrayRef extraOperands, + Operation *domInstFilter, + Operation *postDomInstFilter) { + unsigned newMemRefRank = newMemRef->getType().cast().getRank(); + (void)newMemRefRank; // unused in opt mode + unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); + (void)oldMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); @@ -89,170 +254,44 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, postDomInfo = std::make_unique( postDomInstFilter->getParentOfType()); - // The ops where memref replacement succeeds are replaced with new ones. - SmallVector opsToErase; - - // Walk all uses of old memref. Operation using the memref gets replaced. - for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) { + // Walk all uses of old memref; collect ops to perform replacement. We use a + // DenseSet since an operation could potentially have multiple uses of a + // memref (although rare), and the replacement later is going to erase ops. + DenseSet opsToReplace; + for (auto *op : oldMemRef->getUsers()) { // Skip this use if it's not dominated by domInstFilter. - if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) + if (domInstFilter && !domInfo->dominates(domInstFilter, op)) continue; // Skip this use if it's not post-dominated by postDomInstFilter. - if (postDomInstFilter && - !postDomInfo->postDominates(postDomInstFilter, opInst)) + if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op)) continue; - // Skip dealloc's - no replacement is necessary, and a replacement doesn't - // hurt dealloc's. - if (isa(opInst)) + // Skip dealloc's - no replacement is necessary, and a memref replacement + // at other uses doesn't hurt these dealloc's. + if (isa(op)) continue; - // Check if the memref was used in a non-deferencing context. 
It is fine for - // the memref to be used in a non-deferencing way outside of the region - // where this replacement is happening. - if (!isMemRefDereferencingOp(*opInst)) - // Failure: memref used in a non-deferencing op (potentially escapes); no - // replacement in these cases. - return false; - - auto getMemRefOperandPos = [&]() -> unsigned { - unsigned i, e; - for (i = 0, e = opInst->getNumOperands(); i < e; i++) { - if (opInst->getOperand(i) == oldMemRef) - break; - } - assert(i < opInst->getNumOperands() && "operand guaranteed to be found"); - return i; - }; - - OpBuilder builder(opInst); - unsigned memRefOperandPos = getMemRefOperandPos(); - NamedAttribute oldMapAttrPair = - getAffineMapAttrForMemRef(opInst, oldMemRef); - AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); - unsigned oldMapNumInputs = oldMap.getNumInputs(); - SmallVector oldMapOperands( - opInst->operand_begin() + memRefOperandPos + 1, - opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); - SmallVector affineApplyOps; - - // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. - SmallVector oldMemRefOperands; - oldMemRefOperands.reserve(oldMemRefRank); - if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { - for (auto resultExpr : oldMap.getResults()) { - auto singleResMap = builder.getAffineMap( - oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); - auto afOp = builder.create(opInst->getLoc(), - singleResMap, oldMapOperands); - oldMemRefOperands.push_back(afOp); - affineApplyOps.push_back(afOp); - } - } else { - oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); - } - - // Construct new indices as a remap of the old ones if a remapping has been - // provided. The indices of a memref come right after it, i.e., - // at position memRefOperandPos + 1. - SmallVector remapOperands; - remapOperands.reserve(extraOperands.size() + oldMemRefRank); - remapOperands.append(extraOperands.begin(), extraOperands.end()); - remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); - - SmallVector remapOutputs; - remapOutputs.reserve(oldMemRefRank); - - if (indexRemap && - indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { - // Remapped indices. - for (auto resultExpr : indexRemap.getResults()) { - auto singleResMap = builder.getAffineMap( - indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); - auto afOp = builder.create(opInst->getLoc(), - singleResMap, remapOperands); - remapOutputs.push_back(afOp); - affineApplyOps.push_back(afOp); - } - } else { - // No remapping specified. - remapOutputs.append(remapOperands.begin(), remapOperands.end()); - } - - SmallVector newMapOperands; - newMapOperands.reserve(newMemRefRank); - - // Prepend 'extraIndices' in 'newMapOperands'. - for (auto *extraIndex : extraIndices) { - assert(extraIndex->getDefiningOp()->getNumResults() == 1 && - "single result op's expected to generate these indices"); - assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && - "invalid memory op index"); - newMapOperands.push_back(extraIndex); - } - - // Append 'remapOutputs' to 'newMapOperands'. - newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); - - // Create new fully composed AffineMap for new op to be created. - assert(newMapOperands.size() == newMemRefRank); - auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); - // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here. 
- fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); - newMap = simplifyAffineMap(newMap); - canonicalizeMapAndOperands(&newMap, &newMapOperands); - // Remove any affine.apply's that became dead as a result of composition. - for (auto *value : affineApplyOps) - if (value->use_empty()) - value->getDefiningOp()->erase(); - - // Construct the new operation using this memref. - OperationState state(opInst->getLoc(), opInst->getName()); - state.setOperandListToResizable(opInst->hasResizableOperandsList()); - state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); - // Insert the non-memref operands. - state.operands.append(opInst->operand_begin(), - opInst->operand_begin() + memRefOperandPos); - // Insert the new memref value. - state.operands.push_back(newMemRef); - - // Insert the new memref map operands. - state.operands.append(newMapOperands.begin(), newMapOperands.end()); - - // Insert the remaining operands unmodified. - state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 + - oldMapNumInputs, - opInst->operand_end()); - - // Result types don't change. Both memref's are of the same elemental type. - state.types.reserve(opInst->getNumResults()); - for (auto *result : opInst->getResults()) - state.types.push_back(result->getType()); - - // Add attribute for 'newMap', other Attributes do not change. - auto newMapAttr = builder.getAffineMapAttr(newMap); - for (auto namedAttr : opInst->getAttrs()) { - if (namedAttr.first == oldMapAttrPair.first) { - state.attributes.push_back({namedAttr.first, newMapAttr}); - } else { - state.attributes.push_back(namedAttr); - } - } - - // Create the new operation. - auto *repOp = builder.createOperation(state); - opInst->replaceAllUsesWith(repOp); - - // Collect and erase at the end since one of these op's could be - // domInstFilter or postDomInstFilter as well! - opsToErase.push_back(opInst); + // Check if the memref was used in a non-dereferencing context. It is fine + // for the memref to be used in a non-dereferencing way outside of the + // region where this replacement is happening. + if (!isMemRefDereferencingOp(*op)) + // Failure: memref used in a non-dereferencing op (potentially escapes); + // no replacement in these cases. + return failure(); + + // We'll first collect and then replace --- since replacement erases the op + // that has the use, and that op could be postDomFilter or domFilter itself! 
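
The collect-then-replace split above matters because each per-op replacement erases the op that holds the use, which would invalidate a direct walk of the use list and could even erase the dom/post-dom filter ops mid-walk. A stripped-down sketch of the idiom (illustrative helper, not the MLIR utility itself):

    static LogicalResult replaceCollectedUsers(
        Value *oldMemRef, llvm::function_ref<LogicalResult(Operation *)> replaceOne) {
      // DenseSet dedups ops that use the memref more than once.
      DenseSet<Operation *> users;
      for (Operation *user : oldMemRef->getUsers())
        users.insert(user);
      // Safe to mutate now: we are no longer iterating the use list.
      for (Operation *user : users)
        if (failed(replaceOne(user)))
          return failure();
      return success();
    }
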
+ opsToReplace.insert(op); } - for (auto *opInst : opsToErase) - opInst->erase(); + for (auto *op : opsToReplace) { + if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices, + indexRemap, extraOperands))) + assert(false && "memref replacement guaranteed to succeed here"); + } - return true; + return success(); } /// Given an operation, inserts one or more single result affine diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index 2b4d1438d5a..ce266d5da60 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -14,7 +14,7 @@ func @loop_nest_dma() { %tag = alloc() : memref<1 x f32> %zero = constant 0 : index - %num_elts = constant 128 : index + %num_elts = constant 32 : index affine.for %i = 0 to 8 { affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> @@ -22,7 +22,7 @@ func @loop_nest_dma() { %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) affine.store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> - affine.for %j = 0 to 128 { + affine.for %j = 0 to 32 { "do_more_compute"(%i, %j) : (index, index) -> () } } @@ -41,7 +41,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1> // CHECK-NEXT: %{{.*}} = "compute"(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1> -// CHECK-NEXT: affine.for %{{.*}} = 0 to 128 { +// CHECK-NEXT: affine.for %{{.*}} = 0 to 32 { // CHECK-NEXT: "do_more_compute"(%{{.*}}, %{{.*}}) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } @@ -52,7 +52,7 @@ func @loop_nest_dma() { // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1> // CHECK-NEXT: %{{.*}} = "compute"(%{{.*}}) : (f32) -> f32 // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 2, %{{.*}}] : memref<2x32xf32, 1> -// CHECK-NEXT: affine.for %{{.*}} = 0 to 128 { +// CHECK-NEXT: affine.for %{{.*}} = 0 to 32 { // CHECK-NEXT: "do_more_compute"(%{{.*}}, %{{.*}}) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: dealloc %{{.*}} : memref<2x1xf32> @@ -297,3 +297,32 @@ func @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) { // CHECK: affine.dma_wait %{{.*}}[%{{.*}} mod 2, symbol(%{{.*}})], %{{.*}} : memref<2x1xi32> // CHECK: return } + +// Memref replacement will fail here due to a non-dereferencing use. However, +// no incorrect transformation is performed since replaceAllMemRefUsesWith +// checks for escaping uses before performing any replacement. +// CHECK-LABEL: func @escaping_use +func @escaping_use() { + %A = alloc() : memref<256 x f32, (d0) -> (d0), 0> + %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1> + %tag = alloc() : memref<1 x f32> + %zero = constant 0 : index + %num_elts = constant 32 : index + + // alloc for the buffer is created but no replacement should happen. 
+ affine.for %i = 0 to 8 { + affine.dma_start %A[%i], %Ah[%i], %tag[%zero], %num_elts : memref<256 x f32>, memref<32 x f32, 1>, memref<1 x f32> + affine.dma_wait %tag[%zero], %num_elts : memref<1 x f32> + "compute"(%Ah) : (memref<32 x f32, 1>) -> () + %v = affine.load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> + "foo"(%v) : (f32) -> () + } + return +} +// No replacement +// CHECK: affine.for %{{.*}} = 0 to 8 { +// CHECK-NEXT: affine.dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}} +// CHECK-NEXT: affine.dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32> +// CHECK-NEXT: "compute"(%{{.*}}) : (memref<32xf32, 1>) -> () +// CHECK-NEXT: [[VAL:%[0-9]+]] = affine.load %{{.*}}[%{{.*}}] : memref<32xf32, 1> +// CHECK-NEXT: "foo"([[VAL]]) : (f32) -> () -- cgit v1.2.3 From f1b100c77ba005899c60f3dea74607d5daad3f52 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 13 Sep 2019 13:33:46 -0700 Subject: NFC: Finish replacing FunctionPassBase/ModulePassBase with OpPassBase. These directives were temporary during the generalization of FunctionPass/ModulePass to OpPass. PiperOrigin-RevId: 268970259 --- .../Linalg/Linalg1/include/linalg1/Passes.h | 3 +- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 2 +- .../Linalg/Linalg3/include/linalg3/Transforms.h | 4 +- mlir/examples/Linalg/Linalg3/lib/Transforms.cpp | 2 +- mlir/include/mlir/Analysis/Passes.h | 8 ++-- .../ControlFlowToCFG/ConvertControlFlowToCFG.h | 3 +- .../mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h | 8 ++-- .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 3 +- .../mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h | 3 +- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 7 ++-- .../mlir/Conversion/VectorToLLVM/VectorToLLVM.h | 3 +- mlir/include/mlir/Dialect/FxpMathOps/Passes.h | 5 +-- mlir/include/mlir/Dialect/GPU/Passes.h | 3 +- mlir/include/mlir/Dialect/Linalg/Passes.h | 10 ++--- mlir/include/mlir/Dialect/QuantOps/Passes.h | 5 +-- mlir/include/mlir/Dialect/SPIRV/Passes.h | 2 +- mlir/include/mlir/Pass/Pass.h | 5 --- mlir/include/mlir/Quantizer/Transforms/Passes.h | 6 +-- mlir/include/mlir/Transforms/Passes.h | 45 +++++++++++----------- mlir/include/mlir/Transforms/ViewOpGraph.h | 3 +- mlir/include/mlir/Transforms/ViewRegionGraph.h | 7 ++-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 4 +- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 5 ++- mlir/lib/Analysis/TestParallelismDetection.cpp | 4 +- .../ControlFlowToCFG/ConvertControlFlowToCFG.cpp | 2 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 2 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 2 +- .../GPUToCUDA/GenerateCubinAccessors.cpp | 2 +- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 4 +- .../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 2 +- mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 2 +- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 4 +- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 3 +- .../lib/Dialect/Linalg/Transforms/LowerToLoops.cpp | 3 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 +- .../QuantOps/Transforms/ConvertSimQuant.cpp | 2 +- .../Transforms/AddDefaultStatsTestPass.cpp | 3 +- .../Transforms/InferQuantizedTypesPass.cpp | 3 +- .../Transforms/RemoveInstrumentationPass.cpp | 2 +- 
mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 2 +- mlir/lib/Transforms/CSE.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- mlir/lib/Transforms/LoopCoalescing.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 2 +- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 2 +- mlir/lib/Transforms/LoopTiling.cpp | 2 +- mlir/lib/Transforms/LoopUnroll.cpp | 2 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 +- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Transforms/StripDebugInfo.cpp | 2 +- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/lib/Transforms/ViewOpGraph.cpp | 2 +- mlir/lib/Transforms/ViewRegionGraph.cpp | 6 +-- mlir/test/lib/Transforms/TestConstantFold.cpp | 2 +- mlir/test/lib/Transforms/TestLoopFusion.cpp | 2 +- .../lib/Transforms/TestLoopParametricTiling.cpp | 2 +- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 2 +- 69 files changed, 119 insertions(+), 133 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h index 0347e182a50..8be517db917 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Passes.h @@ -29,12 +29,11 @@ namespace mlir { class ModuleOp; template class OpPassBase; -using ModulePassBase = OpPassBase; } // namespace mlir namespace linalg { -mlir::ModulePassBase *createLowerLinalgToLLVMPass(); +mlir::OpPassBase *createLowerLinalgToLLVMPass(); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 9073169b260..abbd9c95ac9 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -440,7 +440,7 @@ struct LowerLinalgToLLVMPass : public ModulePass { }; } // namespace -ModulePassBase *linalg::createLowerLinalgToLLVMPass() { +OpPassBase *linalg::createLowerLinalgToLLVMPass() { return new LowerLinalgToLLVMPass(); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 849d65a6b6f..5381734721c 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -31,7 +31,6 @@ class Operation; class Value; template class OpPassBase; -using FunctionPassBase = OpPassBase; } // namespace mlir namespace linalg { @@ -75,7 +74,8 @@ void lowerToLoops(mlir::FuncOp f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. 
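
Note that the export has stripped the angle-bracketed template arguments from many signatures in this patch (e.g. "std::unique_ptr>"); the intended shape of the new factory declarations is the following sketch (illustrative pass name, CRTP FunctionPass style of this revision):

    #include "mlir/Pass/Pass.h"
    #include <memory>

    namespace {
    struct MyCleanupPass : public mlir::FunctionPass<MyCleanupPass> {
      void runOnFunction() override { /* rewrite the current FuncOp */ }
    };
    } // namespace

    // Old style (removed):  FunctionPassBase *createMyCleanupPass();
    // New style (this CL):  return the concrete OpPassBase instantiation.
    std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> createMyCleanupPass() {
      return std::make_unique<MyCleanupPass>();
    }
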
-std::unique_ptr createLowerLinalgLoadStorePass(); +std::unique_ptr> +createLowerLinalgLoadStorePass(); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index ce2656520fa..184f1da528f 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -300,6 +300,6 @@ Rewriter::matchAndRewrite(linalg::StoreOp store, } } // namespace -std::unique_ptr linalg::createLowerLinalgLoadStorePass() { +std::unique_ptr> linalg::createLowerLinalgLoadStorePass() { return std::make_unique(); } diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index 8c947e6c222..b233ab5f209 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -24,21 +24,21 @@ #define MLIR_ANALYSIS_PASSES_H #include "mlir/Support/LLVM.h" +#include namespace mlir { class FuncOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; /// Creates a pass to check memref accesses in a Function. -FunctionPassBase *createMemRefBoundCheckPass(); +std::unique_ptr> createMemRefBoundCheckPass(); /// Creates a pass to check memref access dependences in a Function. -FunctionPassBase *createTestMemRefDependenceCheckPass(); +std::unique_ptr> createTestMemRefDependenceCheckPass(); /// Creates a pass to test parallelism detection; emits note for parallel loops. -FunctionPassBase *createParallelismDetectionTestPass(); +std::unique_ptr> createParallelismDetectionTestPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index 56b0ed1d290..b6a29da3900 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -26,7 +26,6 @@ class FuncOp; struct LogicalResult; class MLIRContext; template class OpPassBase; -using FunctionPassBase = OpPassBase; class RewritePattern; // Owning list of rewriting patterns. @@ -39,7 +38,7 @@ void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx); /// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG. -std::unique_ptr createLowerToCFGPass(); +std::unique_ptr> createLowerToCFGPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index 8d5c5013599..161f68701d6 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -36,7 +36,6 @@ class LLVMDialect; } template class OpPassBase; -using ModulePassBase = OpPassBase; using OwnedCubin = std::unique_ptr>; using CubinGenerator = std::function; @@ -50,7 +49,7 @@ using CubinGenerator = std::function; /// attached as a string attribute named 'nvvm.cubin' to the kernel function. /// After the transformation, the body of the kernel function is removed (i.e., /// it is turned into a declaration). 
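
These cubin-related factories are meant to be chained; the mlir-cuda-runner change at the end of this patch wires them up roughly as in the sketch below (`compilePtxToCubin` stands for whatever CubinGenerator the embedder supplies):

    mlir::PassManager pm(module.getContext());
    pm.addPass(createGpuKernelOutliningPass());
    pm.addPass(createLowerGpuOpsToNVVMOpsPass());
    pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
    pm.addPass(createGenerateCubinAccessorPass());
    pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());
    if (failed(pm.run(module)))
      return failure();  // error handling is up to the embedding tool
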
-std::unique_ptr +std::unique_ptr> createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// Creates a pass to convert a gpu.launch_func operation into a sequence of @@ -59,11 +58,12 @@ createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); /// This pass does not generate code to call CUDA directly but instead uses a /// small wrapper library that exports a stable and conveniently typed ABI /// ontop of CUDA. -std::unique_ptr createConvertGpuLaunchFuncToCudaCallsPass(); +std::unique_ptr> +createConvertGpuLaunchFuncToCudaCallsPass(); /// Creates a pass to augment a module with getter functions for all contained /// cubins as encoded via the 'nvvm.cubin' attribute. -std::unique_ptr createGenerateCubinAccessorPass(); +std::unique_ptr> createGenerateCubinAccessorPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 01e50baa592..9a15b41f7de 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -25,14 +25,13 @@ class OwningRewritePatternList; class ModuleOp; template class OpPassBase; -using ModulePassBase = OpPassBase; /// Collect a set of patterns to convert from the GPU dialect to NVVM. void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. -std::unique_ptr createLowerGpuOpsToNVVMOpsPass(); +std::unique_ptr> createLowerGpuOpsToNVVMOpsPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 9ef21ea97b6..960a93dd566 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -22,7 +22,6 @@ namespace mlir { class FuncOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; /// Create a pass that converts loop nests into GPU kernels. It considers /// top-level affine.for and linalg.for operations as roots of loop nests and @@ -32,7 +31,7 @@ using FunctionPassBase = OpPassBase; /// parallelization is performed, it is under the responsibility of the caller /// to strip-mine the loops and to perform the dependence analysis before /// calling the conversion. -std::unique_ptr +std::unique_ptr> createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 589571d0a46..98e105aa2b5 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -34,7 +34,6 @@ struct LogicalResult; class MLIRContext; class ModuleOp; template class OpPassBase; -using ModulePassBase = OpPassBase; class RewritePattern; class Type; @@ -58,12 +57,12 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. -std::unique_ptr createLowerToLLVMPass(); +std::unique_ptr> createLowerToLLVMPass(); /// Creates a pass to convert operations to the LLVMIR dialect. 
The conversion /// is defined by a list of patterns and a type converter that will be obtained /// during the pass using the provided callbacks. -std::unique_ptr +std::unique_ptr> createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker); @@ -72,7 +71,7 @@ createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, /// callback and an optional type conversion class, an instance is created /// during the pass. template -std::unique_ptr +std::unique_ptr> createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) { return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) { return std::make_unique(context); diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h index c781858a672..34d783ae131 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h @@ -23,14 +23,13 @@ class ModuleOp; class OwningRewritePatternList; template class OpPassBase; -using ModulePassBase = OpPassBase; /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Create a pass to convert vector operations to the LLVMIR dialect. -ModulePassBase *createLowerVectorToLLVMPass(); +OpPassBase *createLowerVectorToLLVMPass(); } // namespace mlir #endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_ diff --git a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h index f4099ab7754..415b1c0b253 100644 --- a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h +++ b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h @@ -25,7 +25,6 @@ namespace mlir { class FuncOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; namespace fxpmath { @@ -33,11 +32,11 @@ namespace fxpmath { /// arithmetic. This will leave unrecognized real math ops as-is and is /// typically followed by a pass that lowers any unrecognized ops to a pure /// floating point form. -FunctionPassBase *createLowerUniformRealMathPass(); +OpPassBase *createLowerUniformRealMathPass(); /// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent /// operations that perform quantize/dequantize. 
-FunctionPassBase *createLowerUniformCastsPass(); +OpPassBase *createLowerUniformCastsPass(); } // namespace fxpmath } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h index 14a9f013c99..7c8ce02db90 100644 --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -28,9 +28,8 @@ namespace mlir { class ModuleOp; template class OpPassBase; -using ModulePassBase = OpPassBase; -std::unique_ptr createGpuKernelOutliningPass(); +std::unique_ptr> createGpuKernelOutliningPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 118e278ef60..2b58df71a48 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -29,20 +29,18 @@ namespace mlir { class FuncOp; class ModuleOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; -using ModulePassBase = OpPassBase; namespace linalg { -std::unique_ptr +std::unique_ptr> createLinalgFusionPass(ArrayRef tileSizes = {}); -std::unique_ptr +std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}, bool promoteViews = false); -std::unique_ptr createLowerLinalgToLoopsPass(); +std::unique_ptr> createLowerLinalgToLoopsPass(); -std::unique_ptr createLowerLinalgToLLVMPass(); +std::unique_ptr> createLowerLinalgToLLVMPass(); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/QuantOps/Passes.h b/mlir/include/mlir/Dialect/QuantOps/Passes.h index 5e5fd700f92..c57d7bf41fe 100644 --- a/mlir/include/mlir/Dialect/QuantOps/Passes.h +++ b/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -30,20 +30,19 @@ namespace mlir { class FuncOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; namespace quant { /// Creates a pass that converts quantization simulation operations (i.e. /// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. -std::unique_ptr createConvertSimulatedQuantPass(); +std::unique_ptr> createConvertSimulatedQuantPass(); /// Creates a pass that converts constants followed by a qbarrier to a /// constant whose value is quantized. This is typically one of the last /// passes done when lowering to express actual quantized arithmetic in a /// low level representation. Because it modifies the constant, it is /// destructive and cannot be undone. -std::unique_ptr createConvertConstPass(); +std::unique_ptr> createConvertConstPass(); } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h index 85f4f79ed59..ce4c19bf059 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -27,7 +27,7 @@ namespace mlir { namespace spirv { -std::unique_ptr createConvertStandardToSPIRVPass(); +std::unique_ptr> createConvertStandardToSPIRVPass(); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index fc440d5cb2e..441fd29bdd1 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -298,11 +298,6 @@ template struct ModulePass : public OpPass { /// Return the current module being transformed. ModuleOp getModule() { return this->getOperation(); } }; - -/// Using directives defining legacy base classes. -// TODO(riverriddle) These should be removed in favor of OpPassBase. 
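
Removing the legacy aliases is a mechanical spelling change at every declaration site; a hypothetical before/after:

    // Before (relied on the removed aliases):
    //   FunctionPassBase *createFooPass();
    //   ModulePassBase   *createBarPass();
    // After: spell the OpPassBase instantiation directly (most factories in
    // this CL also switch to returning std::unique_ptr at the same time).
    OpPassBase<FuncOp> *createFooPass();
    OpPassBase<ModuleOp> *createBarPass();
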
-using FunctionPassBase = OpPassBase; -using ModulePassBase = OpPassBase; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Quantizer/Transforms/Passes.h b/mlir/include/mlir/Quantizer/Transforms/Passes.h index f894ea801e0..4fdea58daf4 100644 --- a/mlir/include/mlir/Quantizer/Transforms/Passes.h +++ b/mlir/include/mlir/Quantizer/Transforms/Passes.h @@ -33,17 +33,17 @@ class TargetConfiguration; /// Creates a pass that infers quantized types based on metadata discovered /// in the computation. -std::unique_ptr +std::unique_ptr> createInferQuantizedTypesPass(SolverContext &solverContext, const TargetConfiguration &config); /// Creates a pass which removes any instrumentation and hint ops which have /// no effect on final runtime. -std::unique_ptr createRemoveInstrumentationPass(); +std::unique_ptr> createRemoveInstrumentationPass(); /// Adds default (dummy) statistics to ops that can benefit from runtime stats. /// Meant for testing. -std::unique_ptr createAddDefaultStatsPass(); +std::unique_ptr> createAddDefaultStatsPass(); } // namespace quantizer } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 0c777ec6035..2656a777d23 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -33,32 +33,30 @@ class AffineForOp; class FuncOp; class ModuleOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; -using ModulePassBase = OpPassBase; /// Creates a constant folding pass. Note that this pass solely provides simple /// top-down constant folding functionality; it is intended to be used for /// testing purpose. Use Canonicalizer pass, which exploits more simplification /// opportunties exposed by constant folding, for the general cases. -std::unique_ptr createTestConstantFoldPass(); +std::unique_ptr> createTestConstantFoldPass(); /// Creates an instance of the Canonicalizer pass. -std::unique_ptr createCanonicalizerPass(); +std::unique_ptr> createCanonicalizerPass(); /// Creates a pass to perform common sub expression elimination. -std::unique_ptr createCSEPass(); +std::unique_ptr> createCSEPass(); /// Creates a pass to vectorize loops, operations and data types using a /// target-independent, n-D super-vector abstraction. -std::unique_ptr +std::unique_ptr> createVectorizePass(llvm::ArrayRef virtualVectorSize); /// Creates a pass to allow independent testing of vectorizer functionality with /// FileCheck. -std::unique_ptr createVectorizerTestPass(); +std::unique_ptr> createVectorizerTestPass(); /// Creates a pass to lower super-vectors to target-dependent HW vectors. -std::unique_ptr +std::unique_ptr> createMaterializeVectorsPass(llvm::ArrayRef vectorSize); /// Creates a loop unrolling pass with the provided parameters. @@ -67,75 +65,76 @@ createMaterializeVectorsPass(llvm::ArrayRef vectorSize); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -std::unique_ptr createLoopUnrollPass( +std::unique_ptr> createLoopUnrollPass( int unrollFactor = -1, int unrollFull = -1, const std::function &getUnrollFactor = nullptr); /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. 
-std::unique_ptr +std::unique_ptr> createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates a simplification pass for affine structures (maps and sets). In /// addition, this pass also normalizes memrefs to have the trivial (identity) /// layout map. -std::unique_ptr createSimplifyAffineStructuresPass(); +std::unique_ptr> createSimplifyAffineStructuresPass(); /// Creates a loop fusion pass which fuses loops. Buffers of size less than or /// equal to `localBufSizeThreshold` are promoted to memory space /// `fastMemorySpace'. -std::unique_ptr +std::unique_ptr> createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, bool maximalFusion = false); /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. -std::unique_ptr createLoopInvariantCodeMotionPass(); +std::unique_ptr> createLoopInvariantCodeMotionPass(); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr createPipelineDataTransferPass(); +std::unique_ptr> createPipelineDataTransferPass(); /// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) /// to equivalent lower-level constructs (flow of basic blocks and arithmetic /// primitives). -std::unique_ptr createLowerAffinePass(); +std::unique_ptr> createLowerAffinePass(); /// Creates a pass to perform tiling on loop nests. -std::unique_ptr createLoopTilingPass(uint64_t cacheSizeBytes); +std::unique_ptr> +createLoopTilingPass(uint64_t cacheSizeBytes); /// Creates a pass that performs parametric tiling so that the outermost loops /// have the given fixed number of iterations. Assumes outermost loop nests /// are permutable. -std::unique_ptr +std::unique_ptr> createSimpleParametricTilingPass(ArrayRef outerLoopSizes); /// Creates a pass that transforms perfectly nested loops with independent /// bounds into a single loop. -std::unique_ptr createLoopCoalescingPass(); +std::unique_ptr> createLoopCoalescingPass(); /// Performs packing (or explicit copying) of accessed memref regions into /// buffers in the specified faster memory space through either pointwise copies /// or DMA operations. -std::unique_ptr createAffineDataCopyGenerationPass( +std::unique_ptr> createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp. -std::unique_ptr createLowerVectorTransfersPass(); +std::unique_ptr> createLowerVectorTransfersPass(); /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -std::unique_ptr createMemRefDataFlowOptPass(); +std::unique_ptr> createMemRefDataFlowOptPass(); /// Creates a pass to strip debug information from a function. -std::unique_ptr createStripDebugInfoPass(); +std::unique_ptr> createStripDebugInfoPass(); /// Creates a pass which tests loop fusion utilities. 
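
Taken together, the factories in this header compose into a typical affine optimization pipeline. A sketch of how a tool might assemble one with the updated signatures (parameter values are arbitrary examples; depending on this revision's PassManager nesting rules, function passes may need to go into a nested function-level pass manager):

    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::createLoopFusionPass(/*fastMemorySpace=*/0,
                                          /*localBufSizeThreshold=*/0,
                                          /*maximalFusion=*/false));
    pm.addPass(mlir::createLoopTilingPass(/*cacheSizeBytes=*/512 * 1024));
    pm.addPass(mlir::createLoopUnrollPass(/*unrollFactor=*/4));
    pm.addPass(mlir::createPipelineDataTransferPass());
    pm.addPass(mlir::createCanonicalizerPass());
    pm.addPass(mlir::createCSEPass());
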
-std::unique_ptr createTestLoopFusionPass(); +std::unique_ptr> createTestLoopFusionPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h index 9ba85c242ea..4f9856e9f93 100644 --- a/mlir/include/mlir/Transforms/ViewOpGraph.h +++ b/mlir/include/mlir/Transforms/ViewOpGraph.h @@ -30,7 +30,6 @@ namespace mlir { class Block; class ModuleOp; template class OpPassBase; -using ModulePassBase = OpPassBase; /// Displays the graph in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. @@ -42,7 +41,7 @@ llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Block &block, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print op graphs. -std::unique_ptr +std::unique_ptr> createPrintOpGraphPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, const llvm::Twine &title = ""); diff --git a/mlir/include/mlir/Transforms/ViewRegionGraph.h b/mlir/include/mlir/Transforms/ViewRegionGraph.h index f54d35643eb..626afc31284 100644 --- a/mlir/include/mlir/Transforms/ViewRegionGraph.h +++ b/mlir/include/mlir/Transforms/ViewRegionGraph.h @@ -29,7 +29,6 @@ namespace mlir { class FuncOp; template class OpPassBase; -using FunctionPassBase = OpPassBase; class Region; /// Displays the CFG in a window. This is for use from the debugger and @@ -42,9 +41,9 @@ llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Region ®ion, bool shortNames = false, const Twine &title = ""); /// Creates a pass to print CFG graphs. -FunctionPassBase *createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(), - bool shortNames = false, - const llvm::Twine &title = ""); +OpPassBase * +createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(), + bool shortNames = false, const llvm::Twine &title = ""); } // end namespace mlir diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 849407520da..1d115b13082 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -43,8 +43,8 @@ struct MemRefBoundCheck : public FunctionPass { } // end anonymous namespace -FunctionPassBase *mlir::createMemRefBoundCheckPass() { - return new MemRefBoundCheck(); +std::unique_ptr> mlir::createMemRefBoundCheckPass() { + return std::make_unique(); } void MemRefBoundCheck::runOnFunction() { diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 477121fcc24..c73bf72f127 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -45,8 +45,9 @@ struct TestMemRefDependenceCheck } // end anonymous namespace -FunctionPassBase *mlir::createTestMemRefDependenceCheckPass() { - return new TestMemRefDependenceCheck(); +std::unique_ptr> +mlir::createTestMemRefDependenceCheckPass() { + return std::make_unique(); } // Returns a result string which represents the direction vector (if there was diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index 75982a8e0c5..a9f9ea94a45 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -36,8 +36,8 @@ struct TestParallelismDetection } // end anonymous namespace -FunctionPassBase *mlir::createParallelismDetectionTestPass() { - return new TestParallelismDetection(); +std::unique_ptr> mlir::createParallelismDetectionTestPass() { + return std::make_unique(); } // Walks the 
function and emits a note for all 'affine.for' ops detected as diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index 81426aaa243..cbff101e15d 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -270,7 +270,7 @@ void ControlFlowToCFGPass::runOnFunction() { signalPassFailure(); } -std::unique_ptr mlir::createLowerToCFGPass() { +std::unique_ptr> mlir::createLowerToCFGPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 29771fe7ea5..2cefa787ae8 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -163,7 +163,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) { return success(); } -std::unique_ptr +std::unique_ptr> mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) { return std::make_unique(cubinGenerator); } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index ba0bc475168..5a435a5cc88 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -369,7 +369,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( launchOp.erase(); } -std::unique_ptr +std::unique_ptr> mlir::createConvertGpuLaunchFuncToCudaCallsPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index c4daf8af956..f8c6f5d15ff 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -110,7 +110,7 @@ private: } // anonymous namespace -std::unique_ptr createGenerateCubinAccessorPass() { +std::unique_ptr> createGenerateCubinAccessorPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index ed7ebfbced1..1ae83ae9ae2 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -162,7 +162,7 @@ void mlir::populateGpuToNVVMConversionPatterns( converter); } -std::unique_ptr mlir::createLowerGpuOpsToNVVMOpsPass() { +std::unique_ptr> mlir::createLowerGpuOpsToNVVMOpsPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 6746594ce87..544232e9860 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -166,7 +166,7 @@ void GPUToSPIRVPass::runOnModule() { } } -ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } +OpPassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } static PassRegistration pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect"); diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 9dd9fdbbb87..6d4cb9d8256 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -66,7 +66,7 @@ struct ForLoopMapper : 
public FunctionPass { }; } // namespace -std::unique_ptr +std::unique_ptr> mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims) { return std::make_unique(numBlockDims, numThreadDims); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 8d0dc6bb6b2..ce844e9dfc8 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1222,11 +1222,11 @@ struct LLVMLoweringPass : public ModulePass { }; } // end namespace -std::unique_ptr mlir::createLowerToLLVMPass() { +std::unique_ptr> mlir::createLowerToLLVMPass() { return std::make_unique(); } -std::unique_ptr +std::unique_ptr> mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker) { return std::make_unique(patternListFiller, diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 174a4477560..dcecb84453f 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -48,7 +48,7 @@ void ConvertStandardToSPIRVPass::runOnModule() { } } -std::unique_ptr +std::unique_ptr> mlir::spirv::createConvertStandardToSPIRVPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 174e3d6910c..2b15637ae14 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -194,7 +194,7 @@ void LowerVectorToLLVMPass::runOnModule() { } } -ModulePassBase *mlir::createLowerVectorToLLVMPass() { +OpPassBase *mlir::createLowerVectorToLLVMPass() { return new LowerVectorToLLVMPass(); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 83307da957b..a4fd98bb89e 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -372,7 +372,7 @@ void LowerUniformRealMathPass::runOnFunction() { applyPatternsGreedily(fn, patterns); } -FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() { +OpPassBase *mlir::fxpmath::createLowerUniformRealMathPass() { return new LowerUniformRealMathPass(); } @@ -392,7 +392,7 @@ void LowerUniformCastsPass::runOnFunction() { applyPatternsGreedily(fn, patterns); } -FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() { +OpPassBase *mlir::fxpmath::createLowerUniformCastsPass() { return new LowerUniformCastsPass(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 26449f6e6f1..4328fb39c29 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -147,7 +147,7 @@ public: } // namespace -std::unique_ptr mlir::createGpuKernelOutliningPass() { +std::unique_ptr> mlir::createGpuKernelOutliningPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 0ce6c82679b..bfad37dffaf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -349,7 +349,7 @@ 
LinalgFusionPass::LinalgFusionPass(ArrayRef sizes) this->tileSizes.assign(sizes.begin(), sizes.end()); } -std::unique_ptr +std::unique_ptr> mlir::linalg::createLinalgFusionPass(ArrayRef tileSizes) { return std::make_unique(tileSizes); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 59d78d2e870..48b4eda8697 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -907,7 +907,8 @@ void LowerLinalgToLLVMPass::runOnModule() { } } -std::unique_ptr mlir::linalg::createLowerLinalgToLLVMPass() { +std::unique_ptr> +mlir::linalg::createLowerLinalgToLLVMPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp index 54c0350504e..64773903f87 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp @@ -390,7 +390,8 @@ void LowerLinalgToLoopsPass::runOnFunction() { } } -std::unique_ptr mlir::linalg::createLowerLinalgToLoopsPass() { +std::unique_ptr> +mlir::linalg::createLowerLinalgToLoopsPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index cacec86dc35..f13ce6485bd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -527,7 +527,7 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef sizes, bool promoteViews) { this->promoteViews = promoteViews; } -std::unique_ptr +std::unique_ptr> mlir::linalg::createLinalgTilingPass(ArrayRef tileSizes, bool promoteViews) { return std::make_unique(tileSizes, promoteViews); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index e3a17b057d4..61636dcdd8b 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -112,7 +112,7 @@ void ConvertConstPass::runOnFunction() { applyPatternsGreedily(func, patterns); } -std::unique_ptr mlir::quant::createConvertConstPass() { +std::unique_ptr> mlir::quant::createConvertConstPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 1000b1fabbf..e65f30d035b 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -147,7 +147,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { signalPassFailure(); } -std::unique_ptr +std::unique_ptr> mlir::quant::createConvertSimulatedQuantPass() { return std::make_unique(); } diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index a2d38ce211d..696c1e2db3a 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -118,7 +118,8 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, }); } -std::unique_ptr mlir::quantizer::createAddDefaultStatsPass() { +std::unique_ptr> +mlir::quantizer::createAddDefaultStatsPass() { return std::make_unique(); } diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index ff293fc93aa..7c449e32c4c 100644 --- 
a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -286,7 +286,8 @@ void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, } } -std::unique_ptr mlir::quantizer::createInferQuantizedTypesPass( +std::unique_ptr> +mlir::quantizer::createInferQuantizedTypesPass( SolverContext &solverContext, const TargetConfiguration &config) { return std::make_unique(solverContext, config); } diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index b9fbf27d24f..0266520bec3 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -66,7 +66,7 @@ void RemoveInstrumentationPass::runOnFunction() { applyPatternsGreedily(func, patterns); } -std::unique_ptr +std::unique_ptr> mlir::quantizer::createRemoveInstrumentationPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index fa483008d15..5b2a3185469 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -165,7 +165,7 @@ struct AffineDataCopyGeneration /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. -std::unique_ptr mlir::createAffineDataCopyGenerationPass( +std::unique_ptr> mlir::createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, int minDmaTransferSize, uint64_t fastMemCapacityBytes) { return std::make_unique( diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index bb89aef7fef..0e6dae6c549 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -258,7 +258,7 @@ void CSE::runOnFunction() { markAnalysesPreserved(); } -std::unique_ptr mlir::createCSEPass() { +std::unique_ptr> mlir::createCSEPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index db6c8ee26e6..7e08d363648 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -53,7 +53,7 @@ void Canonicalizer::runOnFunction() { } /// Create a Canonicalizer pass. 
-std::unique_ptr mlir::createCanonicalizerPass() { +std::unique_ptr> mlir::createCanonicalizerPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index 8e220607f06..c1eec56526e 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -96,7 +96,7 @@ public: } // namespace -std::unique_ptr mlir::createLoopCoalescingPass() { +std::unique_ptr> mlir::createLoopCoalescingPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index a17481f89c9..8257bf05f5d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -111,7 +111,7 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -std::unique_ptr +std::unique_ptr> mlir::createLoopFusionPass(unsigned fastMemorySpace, uint64_t localBufSizeThreshold, bool maximalFusion) { return std::make_unique(fastMemorySpace, localBufSizeThreshold, diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 6150996a3d4..ed0adbf21a0 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -76,7 +76,7 @@ static bool isMemRefDereferencingOp(Operation &op) { return false; } -std::unique_ptr mlir::createLoopInvariantCodeMotionPass() { +std::unique_ptr> mlir::createLoopInvariantCodeMotionPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 02787b12e3d..d90e727b0ac 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -81,7 +81,7 @@ struct LoopTiling : public FunctionPass { /// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. -std::unique_ptr +std::unique_ptr> mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { return std::make_unique(cacheSizeBytes); } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 5e132794149..40f48ada4d7 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -180,7 +180,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } -std::unique_ptr mlir::createLoopUnrollPass( +std::unique_ptr> mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function &getUnrollFactor) { return std::make_unique( diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index b6b2f3d4ad7..559f94bedf0 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -82,7 +82,7 @@ struct LoopUnrollAndJam : public FunctionPass { }; } // end anonymous namespace -std::unique_ptr +std::unique_ptr> mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return std::make_unique( unrollJamFactor == -1 ? None : Optional(unrollJamFactor)); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index e8a8284d392..2ed01a7cc32 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -529,7 +529,7 @@ class LowerAffinePass : public FunctionPass { /// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. 
-std::unique_ptr mlir::createLowerAffinePass() { +std::unique_ptr> mlir::createLowerAffinePass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 86ab2484e2a..126a29edffb 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -373,7 +373,7 @@ struct LowerVectorTransfersPass } // end anonymous namespace -std::unique_ptr mlir::createLowerVectorTransfersPass() { +std::unique_ptr> mlir::createLowerVectorTransfersPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index bfdd5bf05f2..737af704992 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -766,7 +766,7 @@ void MaterializeVectorsPass::runOnFunction() { signalPassFailure(); } -std::unique_ptr +std::unique_ptr> mlir::createMaterializeVectorsPass(llvm::ArrayRef vectorSize) { return std::make_unique(vectorSize); } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index f922d508c69..58703394479 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -88,7 +88,7 @@ struct MemRefDataFlowOpt : public FunctionPass { /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -std::unique_ptr mlir::createMemRefDataFlowOptPass() { +std::unique_ptr> mlir::createMemRefDataFlowOptPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index fe201572ca3..d8d8dba9620 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -49,7 +49,7 @@ struct PipelineDataTransfer : public FunctionPass { /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr mlir::createPipelineDataTransferPass() { +std::unique_ptr> mlir::createPipelineDataTransferPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 5eaf8f3460a..e243c1bec54 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -86,7 +86,7 @@ struct SimplifyAffineStructures } // end anonymous namespace -std::unique_ptr mlir::createSimplifyAffineStructuresPass() { +std::unique_ptr> mlir::createSimplifyAffineStructuresPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 15db8b58e88..772df3da3c7 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -38,7 +38,7 @@ void StripDebugInfo::runOnFunction() { } /// Creates a pass to strip debug information from a function. 
-std::unique_ptr mlir::createStripDebugInfoPass() { +std::unique_ptr> mlir::createStripDebugInfoPass() { return std::make_unique(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 89e3da7477d..606cdb77a42 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1276,7 +1276,7 @@ void Vectorize::runOnFunction() { LLVM_DEBUG(dbgs() << "\n"); } -std::unique_ptr +std::unique_ptr> mlir::createVectorizePass(llvm::ArrayRef virtualVectorSize) { return std::make_unique(virtualVectorSize); } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index afb65c7d148..7f65a143a96 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -153,7 +153,7 @@ llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block, return llvm::WriteGraph(os, &block, shortNames, title); } -std::unique_ptr +std::unique_ptr> mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title) { return std::make_unique(os, shortNames, title); diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp index 5a0e8e5ea99..91ac397200a 100644 --- a/mlir/lib/Transforms/ViewRegionGraph.cpp +++ b/mlir/lib/Transforms/ViewRegionGraph.cpp @@ -85,9 +85,9 @@ private: }; } // namespace -FunctionPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, - bool shortNames, - const llvm::Twine &title) { +OpPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, + bool shortNames, + const llvm::Twine &title) { return new PrintCFGPass(os, shortNames, title); } diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index b1c895257c3..15ecaabb149 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -74,7 +74,7 @@ void TestConstantFold::runOnFunction() { } /// Creates a constant folding pass. 
-std::unique_ptr mlir::createTestConstantFoldPass() { +std::unique_ptr> mlir::createTestConstantFoldPass() { return std::make_unique(); } diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 604b42817e2..026a897fa8d 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -58,7 +58,7 @@ struct TestLoopFusion : public FunctionPass { } // end anonymous namespace -std::unique_ptr mlir::createTestLoopFusionPass() { +std::unique_ptr> mlir::createTestLoopFusionPass() { return std::make_unique(); } diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index 6dc0bfde371..bce1e08402d 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -55,7 +55,7 @@ public: }; } // end namespace -std::unique_ptr +std::unique_ptr> mlir::createSimpleParametricTilingPass(ArrayRef outerLoopSizes) { return std::make_unique(outerLoopSizes); } diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 3f00eb01e11..4fdb66071bf 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -290,7 +290,7 @@ void VectorizerTestPass::runOnFunction() { } } -std::unique_ptr mlir::createVectorizerTestPass() { +std::unique_ptr> mlir::createVectorizerTestPass() { return std::make_unique(); } diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index df69407fa9e..deddc63eb10 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -140,7 +140,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) { PassManager pm(m.getContext()); pm.addPass(createGpuKernelOutliningPass()); - pm.addPass(static_cast>( + pm.addPass(static_cast>>( std::make_unique())); pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); pm.addPass(createGenerateCubinAccessorPass()); -- cgit v1.2.3 From 727a50ae2db4492a8c3168647996abacd75d0622 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 18 Sep 2019 11:25:33 -0700 Subject: Support symbolic operands for memref replacement; fix memrefNormalize - allow symbols in index remapping provided for memref replacement - fix memref normalize crash on cases with layout maps with symbols Signed-off-by: Uday Bondhugula Reported by: Alex Zinenko Closes tensorflow/mlir#139 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/139 from bondhugula:memref-rep-symbols 2f48c1fdb5d4c58915bbddbd9f07b18541819233 PiperOrigin-RevId: 269851182 --- mlir/include/mlir/Dialect/StandardOps/Ops.td | 6 ++++++ mlir/include/mlir/Transforms/Utils.h | 16 +++++++------- mlir/lib/Transforms/LoopFusion.cpp | 1 + mlir/lib/Transforms/PipelineDataTransfer.cpp | 1 + mlir/lib/Transforms/Utils/LoopUtils.cpp | 1 + mlir/lib/Transforms/Utils/Utils.cpp | 31 ++++++++++++++++++++-------- mlir/test/Transforms/memref-normalize.mlir | 15 ++++++++++++++ 7 files changed, 55 insertions(+), 16 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 426ec656b0e..7de48c07e44 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -162,6 +162,12 @@ def AllocOp : Std_Op<"alloc"> 
{ unsigned getNumSymbolicOperands() { return getNumOperands() - getType().getNumDynamicDims(); } + + /// Returns the symbolic operands (the ones in square brackets), which bind + /// to the symbols of the memref's layout map. + operand_range getSymbolicOperands() { + return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 0644bc8064f..c682b48f331 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -40,15 +40,15 @@ class OpBuilder; /// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while /// optionally remapping the old memref's indices using the supplied affine map, /// `indexRemap`. The new memref could be of a different shape or rank. -/// `extraIndices` provides additional access indices to be added to the start. +/// `extraIndices` provides any additional access indices to be added to the +/// start. /// /// `indexRemap` remaps indices of the old memref access to a new set of indices /// that are used to index the memref. Additional input operands to indexRemap -/// can be optionally provided, and they are added at the start of its input -/// list. `indexRemap` is expected to have only dimensional inputs, and the -/// number of its inputs equal to extraOperands.size() plus rank of the memref. -/// 'extraOperands' is an optional argument that corresponds to additional -/// operands (inputs) for indexRemap at the beginning of its input list. +/// can be optionally provided in `extraOperands`, and they occupy the start +/// of its input list. `indexRemap`'s dimensional inputs are expected to +/// correspond to memref's indices, and its symbolic inputs if any should be +/// provided in `symbolOperands`. /// /// `domInstFilter`, if non-null, restricts the replacement to only those /// operations that are dominated by the former; similarly, `postDomInstFilter` @@ -70,6 +70,7 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, Operation *postDomInstFilter = nullptr); @@ -79,7 +80,8 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, Operation *op, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}); + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}); /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. 
Returns failure if any of its uses diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8257bf05f5d..188165b94e1 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -955,6 +955,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, + /*symbolOperands=*/{}, /*domInstFilter=*/&*forOp.getBody()->begin()); assert(succeeded(res) && "replaceAllMemrefUsesWith should always succeed here"); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d8d8dba9620..b4d67262c17 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -122,6 +122,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { /*extraIndices=*/{ivModTwoOp}, /*indexRemap=*/AffineMap(), /*extraOperands=*/{}, + /*symbolOperands=*/{}, /*domInstFilter=*/&*forOp.getBody()->begin()))) { LLVM_DEBUG( forOp.emitError("memref replacement for double buffering failed")); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index e038512c0c0..0c9a666a6ec 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1548,6 +1548,7 @@ static LogicalResult generateCopy( replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/regionSymbols, + /*symbolOperands=*/{}, /*domInstFilter=*/&*begin, /*postDomInstFilter=*/&*postDomFilter); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index e57d40e5a1c..d6400ac50ed 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, Operation *op, ArrayRef extraIndices, AffineMap indexRemap, - ArrayRef extraOperands) { + ArrayRef extraOperands, + ArrayRef symbolOperands) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); - (void)oldMemRefRank; + (void)oldMemRefRank; // unused in opt mode if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbolic operand count mistmatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. 
SmallVector remapOperands; - remapOperands.reserve(extraOperands.size() + oldMemRefRank); + remapOperands.reserve(extraOperands.size() + oldMemRefRank + + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + remapOperands.append(symbolOperands.begin(), symbolOperands.end()); SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); @@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, + ArrayRef symbolOperands, Operation *domInstFilter, Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); @@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); (void)oldMemRefRank; if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbol operand count mismatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, for (auto *op : opsToReplace) { if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices, - indexRemap, extraOperands))) + indexRemap, extraOperands, + symbolOperands))) llvm_unreachable("memref replacement guaranteed to succeed here"); } @@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { } auto *oldMemRef = allocOp.getResult(); + SmallVector symbolOperands(allocOp.getSymbolicOperands()); + auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); auto newAlloc = b.create(allocOp.getLoc(), newMemRefType); @@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, /*extraIndices=*/{}, - /*indexRemap=*/layoutMap))) { + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/symbolOperands))) { // If it failed (due to escapes for example), bail out. newAlloc.erase(); return failure(); diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/memref-normalize.mlir index c4973e8ecee..e9b63624120 100644 --- a/mlir/test/Transforms/memref-normalize.mlir +++ b/mlir/test/Transforms/memref-normalize.mlir @@ -96,6 +96,21 @@ func @strided_cumulative() { return } +// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith +// when the index remap has symbols. +// CHECK-LABEL: func @symbolic_operands +func @symbolic_operands(%s : index) { + // CHECK: alloc() : memref<100xf32> + %A = alloc()[%s] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)> + affine.for %i = 0 to 10 { + affine.for %j = 0 to 10 { + // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32> + affine.load %A[%i, %j] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)> + } + } + return +} + // Memref escapes; no normalization. 
// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}> func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> { -- cgit v1.2.3 From 3451055614a26f353438430d32e7920ce57ab4b9 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 9 Oct 2019 10:36:54 -0700 Subject: Add support for some multi-store cases in affine fusion This PR is a stepping stone towards supporting generic multi-store source loop nests in affine loop fusion. It extends the algorithm to support fusion of multi-store loop nests that: 1. have only one store that writes to a function-local live out, and 2. the remaining stores are involved in loop nest self dependences or no dependences within the function. Closes tensorflow/mlir#162 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/162 from dcaballe:dcaballe/multi-output-fusion 7fb7dec6fe8b45f5ce176f018bfe37b256420c45 PiperOrigin-RevId: 273773907 --- mlir/lib/Transforms/LoopFusion.cpp | 100 ++++++++++++++++++++++------------ mlir/test/Transforms/loop-fusion.mlir | 75 +++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 34 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 188165b94e1..15dc36c9c13 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -322,6 +322,44 @@ public: return false; } + // Returns the unique AffineStoreOp in `node` that meets all the following: + // *) store is the only one that writes to a function-local memref live out + // of `node`, + // *) store is not the source of a self-dependence on `node`. + // Otherwise, returns a null AffineStoreOp. + AffineStoreOp getUniqueOutgoingStore(Node *node) { + AffineStoreOp uniqueStore; + + // Return null if `node` doesn't have any outgoing edges. + auto outEdgeIt = outEdges.find(node->id); + if (outEdgeIt == outEdges.end()) + return nullptr; + + const auto &nodeOutEdges = outEdgeIt->second; + for (auto *op : node->stores) { + auto storeOp = cast(op); + auto *memref = storeOp.getMemRef(); + // Skip this store if there are no dependences on its memref. This means + // that store either: + // *) writes to a memref that is only read within the same loop nest + // (self-dependence edges are not represented in graph at the moment), + // *) writes to a function live out memref (function parameter), or + // *) is dead. + if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { + return (edge.value != memref); + })) + continue; + + if (uniqueStore) + // Found multiple stores to function-local live-out memrefs. + return nullptr; + // Found first store to function-local live-out memref. + uniqueStore = storeOp; + } + + return uniqueStore; + } + // Returns true if node 'id' can be removed from the graph. Returns false // otherwise. A node can be removed from the graph iff the following // conditions are met: @@ -963,42 +1001,30 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, return newMemRef; } -// Checks if node 'srcId' (which writes to a live out memref), can be safely -// fused into node 'dstId'. Returns true if the following conditions are met: -// *) 'srcNode' only writes to live out 'memref'. -// *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId'). -// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's -// write region to 'memref'. +// Checks if node 'srcId' can be safely fused into node 'dstId'. 
Node 'srcId' +// may write to multiple memrefs but it is required that only one of them, +// 'srcLiveOutStoreOp', have an output edge. +// Returns true if 'dstNode's read/write region to 'memref' is a super set of +// 'srcNode's write region to 'memref'. // TODO(andydavis) Generalize this to handle more live in/out cases. static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, - Value *memref, + AffineStoreOp srcLiveOutStoreOp, MemRefDependenceGraph *mdg) { - auto *srcNode = mdg->getNode(srcId); + assert(srcLiveOutStoreOp && "Expected a valid store op"); + assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge"); auto *dstNode = mdg->getNode(dstId); + Value *memref = srcLiveOutStoreOp.getMemRef(); - // Gather all memrefs from 'srcNode' store ops. - DenseSet storeMemrefs; - for (auto *storeOpInst : srcNode->stores) { - storeMemrefs.insert(cast(storeOpInst).getMemRef()); - } - // Return false if any of the following are true: - // *) 'srcNode' writes to a live in/out memref other than 'memref'. - // *) 'srcNode' has more than one output edge on 'memref'. - // Check that all stores are to the same memref. - if (storeMemrefs.size() != 1 || - mdg->getOutEdgeCount(srcNode->id, memref) != 1) - return false; - // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. - auto *srcStoreOpInst = srcNode->stores.front(); - MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); - if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. + MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); + if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute MemRefRegion for source operation\n."); return false; } SmallVector srcShape; // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. - // by 'srcStoreOpInst' at depth 'dstLoopDepth'. + // by 'srcStoreOp' at depth 'dstLoopDepth'. Optional srcNumElements = srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); if (!srcNumElements.hasValue()) @@ -1491,17 +1517,25 @@ public: // Skip if 'srcNode' is not a loop nest. if (!isa(srcNode->op)) continue; - // Skip if 'srcNode' has more than one store to any memref. - // TODO(andydavis) Support fusing multi-output src loop nests. - if (srcNode->stores.size() != 1) + // Skip if 'srcNode' has more than one live-out store to a + // function-local memref. + // TODO(andydavis) Support more generic multi-output src loop nests + // fusion. + auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); + if (!srcStoreOp) continue; + // Unique outgoing store found must write to 'memref' since 'memref' + // is the one that established the producer-consumer relationship + // between 'srcNode' and 'dstNode'. + assert(srcStoreOp.getMemRef() == memref && + "Found store to unexpected memref"); // Skip if 'srcNode' writes to any live in or escaping memrefs, // and cannot be fused. bool writesToLiveInOrOut = mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); if (writesToLiveInOrOut && - !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg)) + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) continue; // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. @@ -1515,8 +1549,6 @@ public: if (insertPointInst == nullptr) continue; - // Get unique 'srcNode' store op. - auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. 
SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) @@ -1526,8 +1558,8 @@ public: unsigned bestDstLoopDepth; mlir::ComputationSliceState sliceState; // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst, - dstLoadOpInsts, dstStoreOpInsts, &sliceState, + if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, &bestDstLoopDepth, maximalFusion)) continue; // TODO(andydavis) Remove the following test code when canFuseLoops @@ -1542,7 +1574,7 @@ public: } // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); + srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest) { LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" << *sliceLoopNest.getOperation() << "\n"); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index c97e3df715e..6ff31de7318 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1251,6 +1251,7 @@ func @should_not_fuse_multi_output_producer() { } affine.for %i1 = 0 to 10 { %v0 = affine.load %a[%i1] : memref<10xf32> + %v1 = affine.load %b[%i1] : memref<10xf32> } // CHECK: affine.for %{{.*}} = 0 to 10 { @@ -1259,6 +1260,7 @@ func @should_not_fuse_multi_output_producer() { // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -2266,3 +2268,76 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10 // CHECK-NEXT: } return } + +// ----- + +// CHECK-LABEL: func @should_fuse_self_dependence_multi_store_producer() { +func @should_fuse_self_dependence_multi_store_producer() { + %m = alloc() : memref<10xf32> + %local_m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %local_m[%i0] : memref<10xf32> + %v0 = affine.load %local_m[%i0] : memref<10xf32> + affine.store %v0, %m[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v1 = affine.load %m[%i1] : memref<10xf32> + } + // CHECK: affine.for %[[i0:.*]] = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, [[LOCAL_M:%.*]][%[[i0]]] : memref<10xf32> + // CHECK-NEXT: [[v0:%.*]] = affine.load [[LOCAL_M]][%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.store [[v0]], %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_dead_multi_store_producer() { +func @should_fuse_dead_multi_store_producer() { + %m = alloc() : memref<10xf32> + %dead_m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %dead_m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %m[%i1] : memref<10xf32> + } + // CHECK: affine.for %[[i0:.*]] = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_function_live_out_multi_store_producer +func 
@should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref<10xf32>) { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %live_in_out_m[%i0] : memref<10xf32> + affine.store %cf7, %m[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %m[%i1] : memref<10xf32> + } + // CHECK: affine.for %[[i0:.*]] = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +} -- cgit v1.2.3 From 2acc220f17bacbf933d024a68385a909b44352fd Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 17 Oct 2019 20:08:01 -0700 Subject: NFC: Remove trivial builder get methods. These don't add any value, and some are even more restrictive than the respective static 'get' method. PiperOrigin-RevId: 275391240 --- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 10 +-- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 6 +- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 10 +-- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 6 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 10 +-- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 6 +- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 10 +-- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 6 +- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 10 +-- mlir/examples/toy/Ch6/mlir/MLIRGen.cpp | 6 +- .../StandardToSPIRV/ConvertStandardToSPIRV.h | 5 +- .../mlir/Dialect/AffineOps/AffineOpsBase.td | 2 +- mlir/include/mlir/Dialect/StandardOps/Ops.td | 4 +- mlir/include/mlir/IR/Builders.h | 28 +-------- mlir/include/mlir/IR/OpBase.td | 6 +- mlir/lib/Analysis/LoopAnalysis.cpp | 2 +- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 2 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 28 ++++----- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 6 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 2 +- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 8 +-- .../DecorateSPIRVCompositeTypeLayoutPass.cpp | 2 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 8 +-- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 4 +- mlir/lib/IR/Builders.cpp | 72 +--------------------- mlir/lib/IR/Function.cpp | 2 +- mlir/lib/IR/FunctionSupport.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 20 +++--- .../Transforms/AddDefaultStatsTestPass.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 6 +- mlir/lib/Transforms/LoopTiling.cpp | 4 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 8 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 24 ++++---- mlir/lib/Transforms/Utils/Utils.cpp | 14 ++--- mlir/lib/Transforms/Vectorize.cpp | 2 +- mlir/test/EDSC/builder-api-test.cpp | 2 +- mlir/unittests/Dialect/SPIRV/SerializationTest.cpp | 2 +- 38 files changed, 128 insertions(+), 223 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index 0603943270b..d0746bb3c79 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -50,7 +50,7 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { /// expected to fill in order to build the operation. 
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &result, double value) { - auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataType = RankedTensorType::get({}, builder->getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); ConstantOp::build(builder, result, dataType, dataAttribute); } @@ -88,7 +88,7 @@ static mlir::LogicalResult verify(ConstantOp op) { static void buildAddOp(mlir::Builder *builder, mlir::OperationState &result, mlir::Value *lhs, mlir::Value *rhs) { - result.addTypes(builder->getTensorType(builder->getF64Type())); + result.addTypes(UnrankedTensorType::get(builder->getF64Type())); result.addOperands({lhs, rhs}); } @@ -96,14 +96,14 @@ static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &result, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. - result.addTypes(builder->getTensorType(builder->getF64Type())); + result.addTypes(UnrankedTensorType::get(builder->getF64Type())); result.addOperands(arguments); result.addAttribute("callee", builder->getSymbolRefAttr(callee)); } static void buildMulOp(mlir::Builder *builder, mlir::OperationState &result, mlir::Value *lhs, mlir::Value *rhs) { - result.addTypes(builder->getTensorType(builder->getF64Type())); + result.addTypes(UnrankedTensorType::get(builder->getF64Type())); result.addOperands({lhs, rhs}); } @@ -144,7 +144,7 @@ static mlir::LogicalResult verify(ReturnOp op) { static void buildTransposeOp(mlir::Builder *builder, mlir::OperationState &result, mlir::Value *value) { - result.addTypes(builder->getTensorType(builder->getF64Type())); + result.addTypes(UnrankedTensorType::get(builder->getF64Type())); result.addOperands(value); } diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 5f12d0a8798..55391d72245 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -282,7 +282,7 @@ private: // The type of this attribute is tensor of 64-bit floating-point with the // shape of the literal. mlir::Type elementType = builder.getF64Type(); - auto dataType = builder.getTensorType(lit.getDims(), elementType); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); // This is the actual attribute that holds the list of values for this // tensor literal. @@ -443,10 +443,10 @@ private: mlir::Type getType(ArrayRef shape) { // If the shape is empty, then this type is unranked. if (shape.empty()) - return builder.getTensorType(builder.getF64Type()); + return mlir::UnrankedTensorType::get(builder.getF64Type()); // Otherwise, we use the given shape. - return builder.getTensorType(shape, builder.getF64Type()); + return mlir::RankedTensorType::get(shape, builder.getF64Type()); } /// Build an MLIR type from a Toy AST variable type (forward to the generic diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 5ca50d961bf..37292a2d98a 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -50,7 +50,7 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { /// expected to fill in order to build the operation. 
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, double value) { - auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataType = RankedTensorType::get({}, builder->getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); ConstantOp::build(builder, state, dataType, dataAttribute); } @@ -88,7 +88,7 @@ static mlir::LogicalResult verify(ConstantOp op) { static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -96,14 +96,14 @@ static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &state, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); state.addAttribute("callee", builder->getSymbolRefAttr(callee)); } static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -144,7 +144,7 @@ static mlir::LogicalResult verify(ReturnOp op) { static void buildTransposeOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *value) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 5f12d0a8798..55391d72245 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -282,7 +282,7 @@ private: // The type of this attribute is tensor of 64-bit floating-point with the // shape of the literal. mlir::Type elementType = builder.getF64Type(); - auto dataType = builder.getTensorType(lit.getDims(), elementType); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); // This is the actual attribute that holds the list of values for this // tensor literal. @@ -443,10 +443,10 @@ private: mlir::Type getType(ArrayRef shape) { // If the shape is empty, then this type is unranked. if (shape.empty()) - return builder.getTensorType(builder.getF64Type()); + return mlir::UnrankedTensorType::get(builder.getF64Type()); // Otherwise, we use the given shape. - return builder.getTensorType(shape, builder.getF64Type()); + return mlir::RankedTensorType::get(shape, builder.getF64Type()); } /// Build an MLIR type from a Toy AST variable type (forward to the generic diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index e31cb917d89..254f92f08fd 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -100,7 +100,7 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { /// expected to fill in order to build the operation. 
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, double value) { - auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataType = RankedTensorType::get({}, builder->getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); ConstantOp::build(builder, state, dataType, dataAttribute); } @@ -142,7 +142,7 @@ static mlir::LogicalResult verify(ConstantOp op) { static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -154,7 +154,7 @@ static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &state, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); state.addAttribute("callee", builder->getSymbolRefAttr(callee)); } @@ -171,7 +171,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -235,7 +235,7 @@ static mlir::LogicalResult verify(ReturnOp op) { static void buildTransposeOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *value) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 5f12d0a8798..55391d72245 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -282,7 +282,7 @@ private: // The type of this attribute is tensor of 64-bit floating-point with the // shape of the literal. mlir::Type elementType = builder.getF64Type(); - auto dataType = builder.getTensorType(lit.getDims(), elementType); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); // This is the actual attribute that holds the list of values for this // tensor literal. @@ -443,10 +443,10 @@ private: mlir::Type getType(ArrayRef shape) { // If the shape is empty, then this type is unranked. if (shape.empty()) - return builder.getTensorType(builder.getF64Type()); + return mlir::UnrankedTensorType::get(builder.getF64Type()); // Otherwise, we use the given shape. - return builder.getTensorType(shape, builder.getF64Type()); + return mlir::RankedTensorType::get(shape, builder.getF64Type()); } /// Build an MLIR type from a Toy AST variable type (forward to the generic diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index e31cb917d89..254f92f08fd 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -100,7 +100,7 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { /// expected to fill in order to build the operation. 
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, double value) { - auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataType = RankedTensorType::get({}, builder->getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); ConstantOp::build(builder, state, dataType, dataAttribute); } @@ -142,7 +142,7 @@ static mlir::LogicalResult verify(ConstantOp op) { static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -154,7 +154,7 @@ static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &state, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); state.addAttribute("callee", builder->getSymbolRefAttr(callee)); } @@ -171,7 +171,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -235,7 +235,7 @@ static mlir::LogicalResult verify(ReturnOp op) { static void buildTransposeOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *value) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 5f12d0a8798..55391d72245 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -282,7 +282,7 @@ private: // The type of this attribute is tensor of 64-bit floating-point with the // shape of the literal. mlir::Type elementType = builder.getF64Type(); - auto dataType = builder.getTensorType(lit.getDims(), elementType); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); // This is the actual attribute that holds the list of values for this // tensor literal. @@ -443,10 +443,10 @@ private: mlir::Type getType(ArrayRef shape) { // If the shape is empty, then this type is unranked. if (shape.empty()) - return builder.getTensorType(builder.getF64Type()); + return mlir::UnrankedTensorType::get(builder.getF64Type()); // Otherwise, we use the given shape. - return builder.getTensorType(shape, builder.getF64Type()); + return mlir::RankedTensorType::get(shape, builder.getF64Type()); } /// Build an MLIR type from a Toy AST variable type (forward to the generic diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index e31cb917d89..254f92f08fd 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -100,7 +100,7 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { /// expected to fill in order to build the operation. 
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, double value) { - auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataType = RankedTensorType::get({}, builder->getF64Type()); auto dataAttribute = DenseElementsAttr::get(dataType, value); ConstantOp::build(builder, state, dataType, dataAttribute); } @@ -142,7 +142,7 @@ static mlir::LogicalResult verify(ConstantOp op) { static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -154,7 +154,7 @@ static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &state, StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); state.addAttribute("callee", builder->getSymbolRefAttr(callee)); } @@ -171,7 +171,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *lhs, mlir::Value *rhs) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -235,7 +235,7 @@ static mlir::LogicalResult verify(ReturnOp op) { static void buildTransposeOp(mlir::Builder *builder, mlir::OperationState &state, mlir::Value *value) { - state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index 5f12d0a8798..55391d72245 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -282,7 +282,7 @@ private: // The type of this attribute is tensor of 64-bit floating-point with the // shape of the literal. mlir::Type elementType = builder.getF64Type(); - auto dataType = builder.getTensorType(lit.getDims(), elementType); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); // This is the actual attribute that holds the list of values for this // tensor literal. @@ -443,10 +443,10 @@ private: mlir::Type getType(ArrayRef shape) { // If the shape is empty, then this type is unranked. if (shape.empty()) - return builder.getTensorType(builder.getF64Type()); + return mlir::UnrankedTensorType::get(builder.getF64Type()); // Otherwise, we use the given shape. 
- return builder.getTensorType(shape, builder.getF64Type()); + return mlir::RankedTensorType::get(shape, builder.getF64Type()); } /// Build an MLIR type from a Toy AST variable type (forward to the generic diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h index e92ad03d776..63e63cfebb9 100644 --- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h @@ -135,11 +135,10 @@ private: case spirv::BuiltIn::LocalInvocationId: case spirv::BuiltIn::GlobalInvocationId: { auto ptrType = spirv::PointerType::get( - builder.getVectorType({3}, builder.getIntegerType(32)), + VectorType::get({3}, builder.getIntegerType(32)), spirv::StorageClass::Input); newVarOp = builder.create( - loc, builder.getTypeAttr(ptrType), builder.getStringAttr(name), - nullptr); + loc, TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr); newVarOp.setAttr( convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)), builder.getStringAttr(stringifyBuiltIn(builtin))); diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td index 2ac1d379c12..fb4439a43ac 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td @@ -33,7 +33,7 @@ def AffineMapAttr : Attr< CPred<"$_self.isa()">, "AffineMap attribute"> { let storageType = [{ AffineMapAttr }]; let returnType = [{ AffineMap }]; - let constBuilderCall = "$_builder.getAffineMapAttr($0)"; + let constBuilderCall = "AffineMapAttr::get($0)"; } def AffineMapArrayAttr : TypedArrayAttrBasegetType().cast(); - auto resultType = builder->getTensorType(memrefType.getShape(), - memrefType.getElementType()); + auto resultType = RankedTensorType::get(memrefType.getShape(), + memrefType.getElementType()); result.addOperands(memref); result.addTypes(resultType); }]>]; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index dcc5280f49e..0005a395e70 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -80,12 +80,6 @@ public: IntegerType getI1Type(); IntegerType getIntegerType(unsigned width); FunctionType getFunctionType(ArrayRef inputs, ArrayRef results); - MemRefType getMemRefType(ArrayRef shape, Type elementType, - ArrayRef affineMapComposition = {}, - unsigned memorySpace = 0); - VectorType getVectorType(ArrayRef shape, Type elementType); - RankedTensorType getTensorType(ArrayRef shape, Type elementType); - UnrankedTensorType getTensorType(Type elementType); TupleType getTupleType(ArrayRef elementTypes); NoneType getNoneType(); @@ -105,22 +99,10 @@ public: FloatAttr getFloatAttr(Type type, double value); FloatAttr getFloatAttr(Type type, const APFloat &value); StringAttr getStringAttr(StringRef bytes); - StringAttr getStringAttr(StringRef bytes, Type type); ArrayAttr getArrayAttr(ArrayRef value); - AffineMapAttr getAffineMapAttr(AffineMap map); - IntegerSetAttr getIntegerSetAttr(IntegerSet set); - TypeAttr getTypeAttr(Type type); SymbolRefAttr getSymbolRefAttr(Operation *value); SymbolRefAttr getSymbolRefAttr(StringRef value); - ElementsAttr getDenseElementsAttr(ShapedType type, - ArrayRef values); - ElementsAttr getDenseIntElementsAttr(ShapedType type, - ArrayRef values); - ElementsAttr getSparseElementsAttr(ShapedType type, - DenseIntElementsAttr indices, - DenseElementsAttr values); - ElementsAttr 
getOpaqueElementsAttr(Dialect *dialect, ShapedType type, - StringRef bytes); + // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 16-/32-/64-bit float types, and vector or // ranked tensor of them. Returns null attribute otherwise. @@ -149,9 +131,6 @@ public: AffineExpr getAffineSymbolExpr(unsigned position); AffineExpr getAffineConstantExpr(int64_t constant); - AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results); - // Special cases of affine maps and integer sets /// Returns a zero result affine map with no dimensions or symbols: () -> (). AffineMap getEmptyAffineMap(); @@ -175,11 +154,6 @@ public: /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); - // Integer set. - IntegerSet getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef isEq); - // TODO: Helpers for affine map/exprs, etc. protected: MLIRContext *context; }; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index dbb1e7f0a73..c1bd04fea4e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -970,7 +970,7 @@ class IntElementsAttr : ElementsAttrBase< // Note that this is only constructing scalar elements attribute. let constBuilderCall = "DenseElementsAttr::get(" - "$_builder.getTensorType({}, $_builder.getIntegerType(" # width # ")), " + "RankedTensorType::get({}, $_builder.getIntegerType(" # width # ")), " "llvm::makeArrayRef($0)).cast()"; let convertFromStorage = "$_self"; } @@ -989,7 +989,7 @@ class FloatElementsAttr : ElementsAttrBase< // Note that this is only constructing scalar elements attribute. let constBuilderCall = "DenseElementsAttr::get(" - "$_builder.getTensorType({}, $_builder.getF" # width # "Type())," + "RankedTensorType::get({}, $_builder.getF" # width # "Type())," "llvm::makeArrayRef($0))"; let convertFromStorage = "$_self"; } @@ -1013,7 +1013,7 @@ class RankedFloatElementsAttr dims> : ElementsAttrBase< let returnType = [{ DenseFPElementsAttr }]; let constBuilderCall = "DenseElementsAttr::get(" - "$_builder.getTensorType({" # StrJoinInt.result # + "RankedTensorType::get({" # StrJoinInt.result # "}, $_builder.getF" # width # "Type()), " "llvm::makeArrayRef($0)).cast()"; let convertFromStorage = "$_self"; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index b1895d308ad..55a0ac6df84 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -77,7 +77,7 @@ void mlir::buildTripCountMapAndOperands( SmallVector lbSplatExpr(ubValueMap.getNumResults(), lbMap.getResult(0)); auto lbMapSplat = - b.getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr); + AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr); AffineValueMap lbSplatValueMap(lbMapSplat, lbOperands); AffineValueMap tripCountValueMap; diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index a0906b75950..718f8077981 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -132,7 +132,7 @@ static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter, funcOp.getName().str() + "_arg_" + std::to_string(origArgNum); var = rewriter.create( funcOp.getLoc(), - rewriter.getTypeAttr(getGlobalVarTypeForEntryFnArg(origArg->getType())), + 
TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())), rewriter.getStringAttr(varName), nullptr); var.setAttr( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 38d4edb60f4..9980497640a 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -198,7 +198,7 @@ void AffineApplyOp::build(Builder *builder, OperationState &result, AffineMap map, ArrayRef operands) { result.addOperands(operands); result.types.append(map.getNumResults(), builder->getIndexType()); - result.addAttribute("map", builder->getAffineMapAttr(map)); + result.addAttribute("map", AffineMapAttr::get(map)); } ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { @@ -817,13 +817,13 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result, ArrayRef tagIndices, Value *numElements, Value *stride, Value *elementsPerStride) { result.addOperands(srcMemRef); - result.addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap)); + result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); result.addOperands(srcIndices); result.addOperands(destMemRef); - result.addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap)); + result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap)); result.addOperands(destIndices); result.addOperands(tagMemRef); - result.addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); result.addOperands(numElements); if (stride) { @@ -985,7 +985,7 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result, Value *tagMemRef, AffineMap tagMap, ArrayRef tagIndices, Value *numElements) { result.addOperands(tagMemRef); - result.addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); result.addOperands(numElements); } @@ -1073,13 +1073,11 @@ void AffineForOp::build(Builder *builder, OperationState &result, builder->getIntegerAttr(builder->getIndexType(), step)); // Add the lower bound. - result.addAttribute(getLowerBoundAttrName(), - builder->getAffineMapAttr(lbMap)); + result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap)); result.addOperands(lbOperands); // Add the upper bound. - result.addAttribute(getUpperBoundAttrName(), - builder->getAffineMapAttr(ubMap)); + result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap)); result.addOperands(ubOperands); // Create a region and a block for the body. The argument of the region is @@ -1164,7 +1162,7 @@ static ParseResult parseBound(bool isLower, OperationState &result, // for storage. Analysis passes may expand it into a multi-dimensional map // if desired. 
AffineMap map = builder.getSymbolIdentityMap(); - result.addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + result.addAttribute(boundAttrName, AffineMapAttr::get(map)); return success(); } @@ -1213,8 +1211,8 @@ static ParseResult parseBound(bool isLower, OperationState &result, if (auto integerAttr = boundAttr.dyn_cast()) { result.attributes.pop_back(); result.addAttribute( - boundAttrName, builder.getAffineMapAttr( - builder.getConstantAffineMap(integerAttr.getInt()))); + boundAttrName, + AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); return success(); } @@ -1752,7 +1750,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); result.addOperands(operands); if (map) - result.addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); auto memrefType = operands[0]->getType().cast(); result.types.push_back(memrefType.getElementType()); } @@ -1764,7 +1762,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, result.addOperands(memref); result.addOperands(mapOperands); auto memrefType = memref->getType().cast(); - result.addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } @@ -1855,7 +1853,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result, result.addOperands(valueToStore); result.addOperands(memref); result.addOperands(mapOperands); - result.addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); } // Use identity map. 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5862efe71b5..23e3889c049 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -869,7 +869,7 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)); - result.addAttribute("type", builder->getTypeAttr(type)); + result.addAttribute("type", TypeAttr::get(type)); if (isConstant) result.addAttribute("constant", builder->getUnitAttr()); if (value) @@ -939,7 +939,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { } } - result.addAttribute("type", parser.getBuilder().getTypeAttr(types[0])); + result.addAttribute("type", TypeAttr::get(types[0])); return success(); } @@ -1026,7 +1026,7 @@ void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name, result.addRegion(); result.addAttribute(SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)); - result.addAttribute("type", builder->getTypeAttr(type)); + result.addAttribute("type", TypeAttr::get(type)); result.attributes.append(attrs.begin(), attrs.end()); if (argAttrs.empty()) return; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 4b1c4e4089d..85d106ed33e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1244,7 +1244,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser, if (!type.isa()) { return parser.emitError(loc, "expected spv.ptr type"); } - state.addAttribute(kTypeAttrName, parser.getBuilder().getTypeAttr(type)); + state.addAttribute(kTypeAttrName, TypeAttr::get(type)); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index accee7c2214..6ba18d1f1d0 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -954,8 +954,8 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { << wordIndex << " of " << operands.size() << " processed"; } auto varOp = opBuilder.create( - unknownLoc, opBuilder.getTypeAttr(type), - opBuilder.getStringAttr(variableName), initializer); + unknownLoc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), + initializer); // Decorations. if (decorations.count(variableID)) { @@ -1065,7 +1065,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return emitError(unknownLoc, "OpTypeVector references undefined ") << operands[1]; } - typeMap[operands[0]] = opBuilder.getVectorType({operands[2]}, elementTy); + typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); } break; case spirv::Opcode::OpTypePointer: { if (operands.size() != 3) { @@ -1391,7 +1391,7 @@ Deserializer::processConstantComposite(ArrayRef operands) { auto resultID = operands[1]; if (auto vectorType = resultType.dyn_cast()) { - auto attr = opBuilder.getDenseElementsAttr(vectorType, elements); + auto attr = DenseElementsAttr::get(vectorType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
constantMap.try_emplace(resultID, attr, resultType); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index a854a1d511c..1fd6274b16e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -58,7 +58,7 @@ public: } rewriter.replaceOpWithNewOp( - op, rewriter.getTypeAttr(decoratedType), globalVarAttrs); + op, TypeAttr::get(decoratedType), globalVarAttrs); return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 7177cfe7dff..739def41f8e 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -653,11 +653,11 @@ static Type getCheckedI1SameShape(Builder *build, Type type) { if (type.isIntOrIndexOrFloat()) return i1Type; if (auto tensorType = type.dyn_cast()) - return build->getTensorType(tensorType.getShape(), i1Type); + return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa()) - return build->getTensorType(i1Type); + return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) - return build->getVectorType(vectorType.getShape(), i1Type); + return VectorType::get(vectorType.getShape(), i1Type); return Type(); } @@ -2241,7 +2241,7 @@ OpFoldResult TensorCastOp::fold(ArrayRef operands) { static Type getTensorTypeFromMemRefType(Builder &b, Type type) { if (auto memref = type.dyn_cast()) - return b.getTensorType(memref.getShape(), memref.getElementType()); + return RankedTensorType::get(memref.getShape(), memref.getElementType()); return b.getNoneType(); } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 22f25683087..a7006f0f1a9 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -206,7 +206,7 @@ void VectorTransferReadOp::build(Builder *builder, OperationState &result, result.addOperands({*paddingValue}); } result.addAttribute(getPermutationMapAttrName(), - builder->getAffineMapAttr(permutationMap)); + AffineMapAttr::get(permutationMap)); result.addTypes(vectorType); } @@ -383,7 +383,7 @@ void VectorTransferWriteOp::build(Builder *builder, OperationState &result, result.addOperands({srcVector, dstMemRef}); result.addOperands(dstIndices); result.addAttribute(getPermutationMapAttrName(), - builder->getAffineMapAttr(permutationMap)); + AffineMapAttr::get(permutationMap)); } auto VectorTransferWriteOp::getIndices() -> operand_range { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 0a9389768b9..7ec5c3b6a26 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -72,25 +72,6 @@ FunctionType Builder::getFunctionType(ArrayRef inputs, return FunctionType::get(inputs, results, context); } -MemRefType Builder::getMemRefType(ArrayRef shape, Type elementType, - ArrayRef affineMapComposition, - unsigned memorySpace) { - return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); -} - -VectorType Builder::getVectorType(ArrayRef shape, Type elementType) { - return VectorType::get(shape, elementType); -} - -RankedTensorType Builder::getTensorType(ArrayRef shape, - Type elementType) { - return RankedTensorType::get(shape, elementType); -} - -UnrankedTensorType Builder::getTensorType(Type elementType) { - return UnrankedTensorType::get(elementType); -} - TupleType 
Builder::getTupleType(ArrayRef elementTypes) { return TupleType::get(elementTypes, context); } @@ -165,24 +146,10 @@ StringAttr Builder::getStringAttr(StringRef bytes) { return StringAttr::get(bytes, context); } -StringAttr Builder::getStringAttr(StringRef bytes, Type type) { - return StringAttr::get(bytes, type); -} - ArrayAttr Builder::getArrayAttr(ArrayRef value) { return ArrayAttr::get(value, context); } -AffineMapAttr Builder::getAffineMapAttr(AffineMap map) { - return AffineMapAttr::get(map); -} - -IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { - return IntegerSetAttr::get(set); -} - -TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } - SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { auto symName = value->getAttrOfType(SymbolTable::getSymbolAttrName()); @@ -193,27 +160,6 @@ SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { return SymbolRefAttr::get(value, getContext()); } -ElementsAttr Builder::getDenseElementsAttr(ShapedType type, - ArrayRef values) { - return DenseElementsAttr::get(type, values); -} - -ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type, - ArrayRef values) { - return DenseIntElementsAttr::get(type, values); -} - -ElementsAttr Builder::getSparseElementsAttr(ShapedType type, - DenseIntElementsAttr indices, - DenseElementsAttr values) { - return SparseElementsAttr::get(type, indices, values); -} - -ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type, - StringRef bytes) { - return OpaqueElementsAttr::get(dialect, type, bytes); -} - ArrayAttr Builder::getI32ArrayAttr(ArrayRef values) { auto attrs = functional::map( [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values); @@ -255,7 +201,7 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef values) { ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef values) { auto attrs = functional::map( - [this](AffineMap v) -> Attribute { return getAffineMapAttr(v); }, values); + [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }, values); return getArrayAttr(attrs); } @@ -278,7 +224,7 @@ Attribute Builder::getZeroAttr(Type type) { auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; - return getDenseElementsAttr(vtType, element); + return DenseElementsAttr::get(vtType, element); } default: break; @@ -290,11 +236,6 @@ Attribute Builder::getZeroAttr(Type type) { // Affine Expressions, Affine Maps, and Integet Sets. 
//===----------------------------------------------------------------------===// -AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, - ArrayRef results) { - return AffineMap::get(dimCount, symbolCount, results); -} - AffineExpr Builder::getAffineDimExpr(unsigned position) { return mlir::getAffineDimExpr(position, context); } @@ -307,12 +248,6 @@ AffineExpr Builder::getAffineConstantExpr(int64_t constant) { return mlir::getAffineConstantExpr(constant, context); } -IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef isEq) { - return IntegerSet::get(dimCount, symbolCount, constraints, isEq); -} - AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); } AffineMap Builder::getConstantAffineMap(int64_t val) { @@ -347,9 +282,8 @@ AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { SmallVector shiftedResults; shiftedResults.reserve(map.getNumResults()); - for (auto resultExpr : map.getResults()) { + for (auto resultExpr : map.getResults()) shiftedResults.push_back(resultExpr + shift); - } return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults); } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 474dd1a5934..4f5a4737698 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -57,7 +57,7 @@ void FuncOp::build(Builder *builder, OperationState &result, StringRef name, FunctionType type, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)); - result.addAttribute(getTypeAttrName(), builder->getTypeAttr(type)); + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); result.attributes.append(attrs.begin(), attrs.end()); result.addRegion(); } diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index b40eebb04b2..468301e9431 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -133,7 +133,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, std::string errorMessage; if (auto type = funcTypeBuilder(builder, argTypes, results, impl::VariadicFlag(isVariadic), errorMessage)) - result.addAttribute(getTypeAttrName(), builder.getTypeAttr(type)); + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); else return parser.emitError(signatureLocation) << "failed to construct function type" diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 80d34ab76e5..873476ebffb 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1063,9 +1063,9 @@ Attribute Parser::parseAttribute(Type type) { if (parseAffineMapOrIntegerSetReference(map, set)) return nullptr; if (map) - return builder.getAffineMapAttr(map); + return AffineMapAttr::get(map); assert(set); - return builder.getIntegerSetAttr(set); + return IntegerSetAttr::get(set); } // Parse an array attribute. @@ -1164,7 +1164,7 @@ Attribute Parser::parseAttribute(Type type) { default: // Parse a type attribute. 
if (Type type = parseType()) - return builder.getTypeAttr(type); + return TypeAttr::get(type); return nullptr; } } @@ -1381,7 +1381,7 @@ Attribute Parser::parseOpaqueElementsAttr() { if (!type) return nullptr; - return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val)); + return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); } namespace { @@ -2496,8 +2496,8 @@ ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { if (exprs.empty()) map = AffineMap(); else - map = builder.getAffineMap(numDimOperands, - dimsAndSymbols.size() - numDimOperands, exprs); + map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, + exprs); return success(); } @@ -2525,7 +2525,7 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, return AffineMap(); // Parsed a valid affine map. - return builder.getAffineMap(numDims, numSymbols, exprs); + return AffineMap::get(numDims, numSymbols, exprs); } /// Parse an affine constraint. @@ -2600,11 +2600,11 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, if (constraints.empty()) { /* 0 == 0 */ auto zero = getAffineConstantExpr(0, getContext()); - return builder.getIntegerSet(numDims, numSymbols, zero, true); + return IntegerSet::get(numDims, numSymbols, zero, true); } // Parsed a valid integer set. - return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); + return IntegerSet::get(numDims, numSymbols, constraints, isEqs); } /// Parse an ambiguous reference to either and affine map or an integer set. @@ -3715,7 +3715,7 @@ public: return failure(); // Add AffineMap attribute. if (map) { - mapAttr = parser.builder.getAffineMapAttr(map); + mapAttr = AffineMapAttr::get(map); attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); } diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index a82a288caf3..a32bb2c9b3c 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -81,7 +81,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, APFloat minValue(-1.0f); APFloat maxValue(1.0f); ElementsAttr layerStats = DenseFPElementsAttr::get( - b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); + RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue}); auto statsOp = b.create(func.getLoc(), arg, layerStats, nullptr, nullptr); arg->replaceAllUsesWith(statsOp); @@ -107,7 +107,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, APFloat minValue(-1.0f); APFloat maxValue(1.0f); ElementsAttr layerStats = DenseFPElementsAttr::get( - b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); + RankedTensorType::get({2}, b.getF32Type()), {minValue, maxValue}); auto statsOp = b.create(op->getLoc(), op->getResult(0), layerStats, nullptr, nullptr); originalResult->replaceAllUsesWith(statsOp); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 15dc36c9c13..7e08c6bbc57 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -953,8 +953,8 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, } else { newMemSpace = oldMemRefType.getMemorySpace(); } - auto newMemRefType = top.getMemRefType( - newShape, oldMemRefType.getElementType(), {}, newMemSpace); + auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), + {}, newMemSpace); // Gather alloc 
operands for the dynamic dimensions of the memref. SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -988,7 +988,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, } auto indexRemap = zeroOffsetCount == rank ? AffineMap() - : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs); + : AffineMap::get(outerIVs.size() + rank, 0, remapExprs); // Replace all users of 'oldMemRef' with 'newMemRef'. LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index af1ecd06ee6..4ee7197f2df 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -168,13 +168,13 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, boundExprs.push_back(dim + tileSizes[i]); boundExprs.append(origUbMap.getResults().begin(), origUbMap.getResults().end()); - auto ubMap = b.getAffineMap(origUbMap.getNumDims() + 1, + auto ubMap = AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(), boundExprs); newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap); } else { // No need of the min expression. auto dim = b.getAffineDimExpr(0); - auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i]); + auto ubMap = AffineMap::get(1, 0, dim + tileSizes[i]); newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap); } } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 2b7330f4175..230869abcd5 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -223,7 +223,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}); + auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); auto ivUnroll = builder.create(forInst->getLoc(), bumpMap, forOpIV); operandMapping.map(forOpIV, ivUnroll); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index bb9c39baf55..7e175fb22d2 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -82,8 +82,8 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { newShape[0] = 2; std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); auto newMemRefType = - bInner.getMemRefType(newShape, oldMemRefType.getElementType(), {}, - oldMemRefType.getMemorySpace()); + MemRefType::get(newShape, oldMemRefType.getElementType(), {}, + oldMemRefType.getMemorySpace()); return newMemRefType; }; @@ -109,8 +109,8 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { // Create 'iv mod 2' value to index the leading dimension. 
auto d0 = bInner.getAffineDimExpr(0); int64_t step = forOp.getStep(); - auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, - {d0.floorDiv(step) % 2}); + auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, + {d0.floorDiv(step) % 2}); auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, forOp.getInductionVar()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 1872044b0fb..fb96772f053 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -87,7 +87,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) { auto tripCountExpr = tripCountMap.getResult(i); bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step; - auto bumpMap = b.getAffineMap(tripCountMap.getNumDims(), + auto bumpMap = AffineMap::get(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), bumpExprs[i]); bumpValues[i] = b.create(forOp.getLoc(), bumpMap, tripCountOperands); @@ -100,7 +100,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, operands->clear(); operands->push_back(lb); operands->append(bumpValues.begin(), bumpValues.end()); - *map = b.getAffineMap(1 + tripCountMap.getNumResults(), 0, newUbExprs); + *map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs); // Simplify the map + operands. fullyComposeAffineMapAndOperands(map, operands); *map = simplifyAffineMap(*map); @@ -487,7 +487,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, if (!forOpIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}); + auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); auto ivUnroll = builder.create(forOp.getLoc(), bumpMap, forOpIV); operandMap.map(forOpIV, ivUnroll); @@ -676,7 +676,7 @@ static void augmentMapAndBounds(OpBuilder &b, Value *iv, AffineMap *map, auto bounds = llvm::to_vector<4>(map->getResults()); bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset); operands->insert(operands->begin() + map->getNumDims(), iv); - *map = b.getAffineMap(map->getNumDims() + 1, map->getNumSymbols(), bounds); + *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds); canonicalizeMapAndOperands(map, operands); } @@ -1229,7 +1229,7 @@ static AffineForOp generatePointWiseCopy(Location loc, Value *memref, ? memIndicesStart[d] : b.create( loc, - b.getAffineMap(memAffineMap.getNumDims(), + AffineMap::get(memAffineMap.getNumDims(), memAffineMap.getNumSymbols(), memAffineMap.getResult(d)), memIndicesStart); @@ -1238,7 +1238,7 @@ static AffineForOp generatePointWiseCopy(Location loc, Value *memref, SmallVector operands = {memBase, forOp.getInductionVar()}; auto memIndex = b.create( loc, - b.getAffineMap(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)), + AffineMap::get(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)), operands); memIndices.push_back(memIndex); } @@ -1381,7 +1381,7 @@ static LogicalResult generateCopy( } else { // The coordinate for the start location is just the lower bound along the // corresponding dimension on the memory region (stored in 'offset'). 
- auto map = top.getAffineMap( + auto map = AffineMap::get( cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset); memIndices.push_back(b.create(loc, map, regionSymbols)); } @@ -1401,8 +1401,8 @@ static LogicalResult generateCopy( if (!existingBuf) { AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank); auto fastMemRefType = - top.getMemRefType(fastBufferShape, memRefType.getElementType(), - fastBufferLayout, copyOptions.fastMemorySpace); + MemRefType::get(fastBufferShape, memRefType.getElementType(), + fastBufferLayout, copyOptions.fastMemorySpace); // Create the fast memory space buffer just before the 'affine.for' // operation. @@ -1470,8 +1470,8 @@ static LogicalResult generateCopy( } else { // DMA generation. // Create a tag (single element 1-d memref) for the DMA. - auto tagMemRefType = top.getMemRefType({1}, top.getIntegerType(32), {}, - copyOptions.tagMemorySpace); + auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, + copyOptions.tagMemorySpace); auto tagMemRef = prologue.create(loc, tagMemRefType); SmallVector tagIndices({zeroIndex}); @@ -1532,7 +1532,7 @@ static LogicalResult generateCopy( auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i); remapExprs.push_back(dimExpr - offsets[i]); } - auto indexRemap = b.getAffineMap(regionSymbols.size() + rank, 0, remapExprs); + auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs); // Record the begin since it may be invalidated by memref replacement. Block::iterator prevOfBegin; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index d6400ac50ed..35a5273a28f 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -119,8 +119,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, oldMemRefOperands.reserve(oldMemRefRank); if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { for (auto resultExpr : oldMap.getResults()) { - auto singleResMap = builder.getAffineMap( - oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); + auto singleResMap = AffineMap::get(oldMap.getNumDims(), + oldMap.getNumSymbols(), resultExpr); auto afOp = builder.create(op->getLoc(), singleResMap, oldMapOperands); oldMemRefOperands.push_back(afOp); @@ -147,7 +147,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { // Remapped indices. for (auto resultExpr : indexRemap.getResults()) { - auto singleResMap = builder.getAffineMap( + auto singleResMap = AffineMap::get( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); auto afOp = builder.create(op->getLoc(), singleResMap, remapOperands); @@ -210,7 +210,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, state.types.push_back(result->getType()); // Add attribute for 'newMap', other Attributes do not change. - auto newMapAttr = builder.getAffineMapAttr(newMap); + auto newMapAttr = AffineMapAttr::get(newMap); for (auto namedAttr : op->getAttrs()) { if (namedAttr.first == oldMapAttrPair.first) { state.attributes.push_back({namedAttr.first, newMapAttr}); @@ -371,8 +371,8 @@ void mlir::createAffineComputationSlice( // Create an affine.apply for each of the map results. 
sliceOps->reserve(composedMap.getNumResults()); for (auto resultExpr : composedMap.getResults()) { - auto singleResMap = builder.getAffineMap( - composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); + auto singleResMap = AffineMap::get(composedMap.getNumDims(), + composedMap.getNumSymbols(), resultExpr); sliceOps->push_back(builder.create( opInst->getLoc(), singleResMap, composedOpOperands)); } @@ -457,7 +457,7 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { auto *oldMemRef = allocOp.getResult(); SmallVector symbolOperands(allocOp.getSymbolicOperands()); - auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(), + auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); auto newAlloc = b.create(allocOp.getLoc(), newMemRefType); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index a54b05e980a..1e10f372b5f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -772,7 +772,7 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, OpBuilder builder(op); for (auto resultExpr : map.getResults()) { auto singleResMap = - builder.getAffineMap(map.getNumDims(), map.getNumSymbols(), resultExpr); + AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); auto afOp = builder.create(op->getLoc(), singleResMap, mapOperands); results.push_back(afOp); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 73b2141970a..73a7366b5e8 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -772,7 +772,7 @@ TEST_FUNC(affine_if_op) { builder.getAffineSymbolExpr(0), // s0 >= 0 builder.getAffineSymbolExpr(1) // s1 >= 0 }; - auto intSet = builder.getIntegerSet(2, 2, affineExprs, isEq); + auto intSet = IntegerSet::get(2, 2, affineExprs, isEq); SmallVector affineIfArgs = {zero, zero, ten, ten}; intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false); diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index ad9d6b21c95..ecf37139ac3 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -71,7 +71,7 @@ protected: OpBuilder opBuilder(module.body()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); opBuilder.create( - UnknownLoc::get(&context), opBuilder.getTypeAttr(ptrType), + UnknownLoc::get(&context), TypeAttr::get(ptrType), opBuilder.getStringAttr(name), nullptr); } -- cgit v1.2.3 From 8bfedb3ca599ccf3c507e721f4bf5e6a3b026f8c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 20 Oct 2019 00:11:03 -0700 Subject: Fix minor spelling tweaks (NFC) Closes tensorflow/mlir#177 PiperOrigin-RevId: 275692653 --- mlir/lib/Analysis/AffineAnalysis.cpp | 16 +++++++-------- mlir/lib/Analysis/AffineStructures.cpp | 14 ++++++------- mlir/lib/Analysis/MemRefBoundCheck.cpp | 2 +- mlir/lib/Analysis/NestedMatcher.cpp | 2 +- mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 4 ++-- mlir/lib/Analysis/VectorAnalysis.cpp | 2 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 6 +++--- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 4 ++-- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 3 ++- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 6 +++--- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 8 ++++---- 
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Dialect/QuantOps/IR/TypeDetail.h | 2 +- .../Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 2 +- mlir/lib/Dialect/SDBM/SDBMExpr.cpp | 4 ++-- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 6 +++--- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 24 +++++++++++----------- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 14 ++++++------- mlir/lib/ExecutionEngine/OptUtils.cpp | 2 +- mlir/lib/IR/Builders.cpp | 2 +- mlir/lib/IR/Diagnostics.cpp | 2 +- mlir/lib/IR/MLIRContext.cpp | 2 +- mlir/lib/IR/SymbolTable.cpp | 2 +- mlir/lib/Parser/Lexer.cpp | 4 ++-- mlir/lib/Parser/Parser.cpp | 10 ++++----- mlir/lib/Pass/Pass.cpp | 2 +- mlir/lib/Pass/PassTiming.cpp | 6 +++--- mlir/lib/Quantizer/Support/UniformSolvers.cpp | 2 +- mlir/lib/TableGen/Pattern.cpp | 2 +- mlir/lib/TableGen/Predicate.cpp | 4 ++-- mlir/lib/Transforms/LoopFusion.cpp | 20 +++++++++--------- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 2 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 10 ++++----- mlir/lib/Transforms/Vectorize.cpp | 12 +++++------ .../FxpMathOps/lower-uniform-real-math-mulew.mlir | 2 +- 38 files changed, 107 insertions(+), 106 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 4d2fac96913..c4b88453487 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -114,7 +114,7 @@ LogicalResult mlir::getIndexSet(MutableArrayRef forOps, // Computes the iteration domain for 'opInst' and populates 'indexSet', which // encapsulates the constraints involving loops surrounding 'opInst' and // potentially involving any Function symbols. The dimensional identifiers in -// 'indexSet' correspond to the loops surounding 'op' from outermost to +// 'indexSet' correspond to the loops surrounding 'op' from outermost to // innermost. // TODO(andydavis) Add support to handle IfInsts surrounding 'op'. static LogicalResult getInstIndexSet(Operation *op, @@ -133,11 +133,11 @@ static LogicalResult getInstIndexSet(Operation *op, // Position lookups return the absolute position in the new space which // has the following format: // -// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifers] +// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers] // // Note: access function non-IV dimension identifiers (that have 'dimension' // positions in the access function position space) are assigned as symbols -// in the output position space. Convienience access functions which lookup +// in the output position space. Convenience access functions which lookup // an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle // the common case of resolving positions for all access function operands. // @@ -634,7 +634,7 @@ static void computeDirectionVector( dependenceDomain->addDimId(j); } - // Add equality contraints for each common loop, setting newly introduced + // Add equality constraints for each common loop, setting newly introduced // variable at column 'j' to the 'dst' IV minus the 'src IV. SmallVector eq; eq.resize(dependenceDomain->getNumCols()); @@ -698,7 +698,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // composed with AffineApplyOps reachable from operands of that access, // until operands of the AffineValueMap are loop IVs or symbols. // *) Build iteration domain constraints for each access. 
Iteration domain -// constraints are pairs of inequality contraints representing the +// constraints are pairs of inequality constraints representing the // upper/lower loop bounds for each AffineForOp in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map @@ -709,8 +709,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // // [src-dim-identifiers, dst-dim-identifiers, symbols, constant] // -// For example, given the following MLIR code with with "source" and -// "destination" accesses to the same memref labled, and symbols %M, %N, %K: +// For example, given the following MLIR code with "source" and "destination" +// accesses to the same memref label, and symbols %M, %N, %K: // // affine.for %i0 = 0 to 100 { // affine.for %i1 = 0 to 50 { @@ -819,7 +819,7 @@ DependenceResult mlir::checkMemrefAccessDependence( return DependenceResult::NoDependence; } // Build dim and symbol position maps for each access from access operand - // Value to position in merged contstraint system. + // Value to position in merged constraint system. ValuePositionMap valuePosMap; buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, dstAccessMap, &valuePosMap, diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index b568d6eacf6..4b171f0bede 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -916,7 +916,7 @@ findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, } // Normalizes the coefficient values across all columns in 'rowIDx' by their -// GCD in equality or inequality contraints as specified by 'isEq'. +// GCD in equality or inequality constraints as specified by 'isEq'. template static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, unsigned rowIdx) { @@ -1161,7 +1161,7 @@ bool FlatAffineConstraints::isEmpty() const { getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); // Check for a constraint explosion. This rarely happens in practice, but // this check exists as a safeguard against improperly constructed - // constraint systems or artifically created arbitrarily complex systems + // constraint systems or artificially created arbitrarily complex systems // that aren't the intended use case for FlatAffineConstraints. This is // needed since FM has a worst case exponential complexity in theory. if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { @@ -1233,7 +1233,7 @@ void FlatAffineConstraints::GCDTightenInequalities() { } } -// Eliminates all identifer variables in column range [posStart, posLimit). +// Eliminates all identifier variables in column range [posStart, posLimit). // Returns the number of variables eliminated. unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, unsigned posLimit) { @@ -1712,7 +1712,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, // Work on a copy so that we don't update this constraint system. if (!tmpClone) { tmpClone.emplace(FlatAffineConstraints(*this)); - // Removing redudnant inequalities is necessary so that we don't get + // Removing redundant inequalities is necessary so that we don't get // redundant loop bounds. 
tmpClone->removeRedundantInequalities(); } @@ -1766,7 +1766,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, if (eq) lower = true; - // Fully commpose map and operands; canonicalize and simplify so that we + // Fully compose map and operands; canonicalize and simplify so that we // transitively get to terminal symbols or loop IVs. auto map = boundMap; SmallVector operands(boundOperands.begin(), boundOperands.end()); @@ -1996,7 +1996,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { numSymbols = newSymbolCount; } -/// Sets the specified identifer to a constant value. +/// Sets the specified identifier to a constant value. void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { unsigned offset = equalities.size(); equalities.resize(equalities.size() + numReservedCols); @@ -2006,7 +2006,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { equalities[offset + getNumCols() - 1] = -val; } -/// Sets the specified identifer to a constant value; asserts if the id is not +/// Sets the specified identifier to a constant value; asserts if the id is not /// found. void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) { unsigned pos; diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 1d115b13082..52379c0a1d0 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file implements a pass to check memref accessses for out of bound +// This file implements a pass to check memref accesses for out of bound // accesses. // //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 9d7d17f836c..5f2be48b327 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -74,7 +74,7 @@ unsigned NestedPattern::getDepth() const { /// there is no match; /// 2. calls the customizable filter function to refine the single operation /// match with extra semantic constraints; -/// 3. if all is good, recursivey matches the nested patterns; +/// 3. if all is good, recursively matches the nested patterns; /// 4. if all nested match then the single operation matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index c73bf72f127..d0351e9bcf9 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -78,7 +78,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, return result; } -// For each access in 'loadsAndStores', runs a depence check between this +// For each access in 'loadsAndStores', runs a dependence check between this // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. 
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 354d03423a7..042c744c74f 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -40,7 +40,7 @@ using llvm::SmallDenseMap; void mlir::getLoopIVs(Operation &op, SmallVectorImpl *loops) { auto *currOp = op.getParentOp(); AffineForOp currAffineForOp; - // Traverse up the hierarchy collecing all 'affine.for' operation while + // Traverse up the hierarchy collecting all 'affine.for' operation while // skipping over 'affine.if' operations. while (currOp && ((currAffineForOp = dyn_cast(currOp)) || isa(currOp))) { @@ -222,7 +222,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, cst.reset(numDims, numSymbols, 0, operands); // Add equality constraints. - // Add inequalties for loop lower/upper bounds. + // Add inequalities for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { auto *operand = operands[i]; if (auto loop = getForInductionVarOwner(operand)) { diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 9846abb7be2..e765ce35e74 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -182,7 +182,7 @@ AffineMap mlir::makePermutationMap( bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType) { - // First, extract the vector type and ditinguish between: + // First, extract the vector type and distinguish between: // a. ops that *must* lower a super-vector (i.e. vector.transfer_read, // vector.transfer_write); and // b. ops that *may* lower a super-vector (all other ops). diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index c76381f9d0a..8c6fdf9aba0 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -94,9 +94,9 @@ private: std::string translateModuleToPtx(llvm::Module &module, llvm::TargetMachine &target_machine); - /// Converts llvmModule to cubin using the user-provded generator. Location is - /// used for error reporting and name is forwarded to the CUBIN generator to - /// use in its logging mechanisms. + /// Converts llvmModule to cubin using the user-provided generator. Location + /// is used for error reporting and name is forwarded to the CUBIN generator + /// to use in its logging mechanisms. OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Location loc, StringRef name); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 718f8077981..56b243c2971 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -353,7 +353,7 @@ public: /// Convert load -> spv.LoadOp. The operands of the replaced operation are of /// IndexType while that of the replacement operation are of type i32. This is -/// not suppored in tablegen based pattern specification. +/// not supported in tablegen based pattern specification. // TODO(ravishankarm) : These could potentially be templated on the operation // being converted, since the same logic should work for linalg.load. class LoadOpConversion final : public ConversionPattern { @@ -398,7 +398,7 @@ public: /// Convert store -> spv.StoreOp. 
The operands of the replaced operation are of /// IndexType while that of the replacement operation are of type i32. This is -/// not suppored in tablegen based pattern specification. +/// not supported in tablegen based pattern specification. // TODO(ravishankarm) : These could potentially be templated on the operation // being converted, since the same logic should work for linalg.store. class StoreOpConversion final : public ConversionPattern { diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index a0b075ace6a..3982a6a4713 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -232,7 +232,8 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, info.rhsType.getScale() / info.resultType.getScale(); if (outputMultiplierReal > 1.0) { - info.op->emitWarning("unimplemented: cannot multiply with multipler > 1.0"); + info.op->emitWarning( + "unimplemented: cannot multiply with multiplier > 1.0"); return failure(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index e2b0e463de0..8377ec64e7e 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -1,4 +1,4 @@ -//===- KernelOutlining.cpp - Implementation of GPU kernel outling ---------===// +//===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===// // // Copyright 2019 The MLIR Authors. // diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ad1fa7705eb..d7d33072d7d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -487,8 +487,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, // Infer the element type from the structure type: iteratively step inside the // type by taking the element type, indexed by the position attribute for - // stuctures. Check the position index before accessing, it is supposed to be - // in bounds. + // structures. Check the position index before accessing, it is supposed to + // be in bounds. for (Attribute subAttr : positionArrayAttr) { auto positionElementAttr = subAttr.dyn_cast(); if (!positionElementAttr) @@ -1470,7 +1470,7 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, LLVM::LLVMDialect *llvmDialect) { assert(builder.getInsertionBlock() && builder.getInsertionBlock()->getParentOp() && - "expected builder to point to a block constained in an op"); + "expected builder to point to a block constrained in an op"); auto module = builder.getInsertionBlock()->getParentOp()->getParentOfType(); assert(module && "builder points to an op outside of a module"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 90a76dedd9b..7adf589d888 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -443,12 +443,12 @@ public: TransposeOpOperandAdaptor adaptor(operands); Value *baseDesc = adaptor.view(); - auto tranposeOp = cast(op); + auto transposeOp = cast(op); // No permutation, early exit. 
- if (tranposeOp.permutation().isIdentity()) + if (transposeOp.permutation().isIdentity()) return rewriter.replaceOp(op, baseDesc), matchSuccess(); - BaseViewConversionHelper helper(op->getLoc(), tranposeOp.getViewType(), + BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(), rewriter, lowering); LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; Value *desc = helper.desc; @@ -463,7 +463,7 @@ public: desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos); // Iterate over the dimensions and apply size/stride permutation. - for (auto en : llvm::enumerate(tranposeOp.permutation().getResults())) { + for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast().getPosition(); Value *size = extractvalue(int64Ty, baseDesc, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 31cdd4e2c06..a499f342c95 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -112,7 +112,7 @@ struct TileCheck : public AffineExprVisitor { visit(expr.getRHS()); if (expr.getKind() == mlir::AffineExprKind::Mul) assert(expr.getRHS().cast().getValue() > 0 && - "nonpositive multipliying coefficient"); + "nonpositive multiplying coefficient"); } bool isTiled; ArrayRef tileSizes; diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h b/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h index 4949b128481..13a88da3043 100644 --- a/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h +++ b/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h @@ -224,7 +224,7 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { }; // We pass scales and zeroPoints in directly rather than relying on KeyTy - // because we have to create new reallocated versions in `constrcut` below. + // because we have to create new reallocated versions in `construct` below. UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef scales, ArrayRef zeroPoints) : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 2e1bd958b79..10668f87ed4 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -58,7 +58,7 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, // If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted // to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero // point is derived from the shifted range, and the scale isn't changed. As -// a consequence some values, which are supposeed in the original [rmin, rmax] +// a consequence some values, which are supposed in the original [rmin, rmax] // range will be outside the shifted range and be clamped during quantization. // TODO(fengliuai): we should nudge the scale as well, but that requires the // fake quant op used in the training to use the nudged scale as well. 
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 04e6eb3d67b..8f6b59d8e45 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -336,7 +336,7 @@ Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated, // Construct an expression lhs + constant while maintaining the canonical form // of the SDBM expressions, in particular sink the constant expression to the -// nearest sum expression in the left subtree of the expresison tree. +// nearest sum expression in the left subtree of the expression tree. static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) { if (auto lhsDiff = lhs.dyn_cast()) return addConstantAndSink( @@ -438,7 +438,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { assert(!lhs.isa() && "non-canonical affine expression"); // If RHS is a constant, we can always extend the SDBM expression to - // include it by sinking the constant into the nearest sum expresion. + // include it by sinking the constant into the nearest sum expression. if (auto rhsConstant = rhs.dyn_cast()) { int64_t constant = rhsConstant.getValue(); auto varying = lhs.dyn_cast(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index b2ae902997a..44fecf32b9e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -804,7 +804,7 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser, } else { return parser.emitError( attrLocation, - "expexted an 32-bit integer for index, but found '") + "expected an 32-bit integer for index, but found '") << indexAttr << "'"; } @@ -838,7 +838,7 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) { if (!indicesArrayAttr.size()) { return compExOp.emitOpError( - "expexted at least one index for spv.CompositeExtractOp"); + "expected at least one index for spv.CompositeExtractOp"); } int32_t index; @@ -953,7 +953,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { if (type.getKind() >= Type::FIRST_SPIRV_TYPE && type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) { - // TODO(antiagainst): support contant struct + // TODO(antiagainst): support constant struct return type.isa(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 6ba18d1f1d0..8e7673b026b 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file defines the SPIR-V binary to MLIR SPIR-V module deseralization. +// This file defines the SPIR-V binary to MLIR SPIR-V module deserialization. // //===----------------------------------------------------------------------===// @@ -92,7 +92,7 @@ private: /// in the deserializer. LogicalResult processCapability(ArrayRef operands); - /// Attaches all collected capabilites to `module` as an attribute. + /// Attaches all collected capabilities to `module` as an attribute. void attachCapabilities(); /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping @@ -135,7 +135,7 @@ private: /// Gets the constant's attribute and type associated with the given . Optional> getConstant(uint32_t id); - /// Gets the constants's integer attribute with the given . Returns a null + /// Gets the constant's integer attribute with the given . 
Returns a null /// IntegerAttr if the given is not registered or does not correspond to an /// integer constant. IntegerAttr getConstantInt(uint32_t id); @@ -306,7 +306,7 @@ private: /// This method is the main entrance for handling SPIR-V instruction; it /// checks the instruction opcode and dispatches to the corresponding handler. /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) - /// might need to be defered, since they contain forward references to s + /// might need to be deferred, since they contain forward references to s /// in the deserialized binary, but module in SPIR-V dialect expects these to /// be ssa-uses. LogicalResult processInstruction(spirv::Opcode opcode, @@ -436,7 +436,7 @@ private: // Result to extended instruction set name. DenseMap extendedInstSets; - // List of instructions that are processed in a defered fashion (after an + // List of instructions that are processed in a deferred fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function // s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and @@ -444,7 +444,7 @@ private: // are deserialized and stored for processing once the entire binary is // processed. SmallVector>, 4> - deferedInstructions; + deferredInstructions; }; } // namespace @@ -462,7 +462,7 @@ LogicalResult Deserializer::deserialize() { auto binarySize = binary.size(); while (curOffset < binarySize) { // Slice the next instruction out and populate `opcode` and `operands`. - // Interally this also updates `curOffset`. + // Internally this also updates `curOffset`. if (failed(sliceInstruction(opcode, operands))) return failure(); @@ -473,8 +473,8 @@ LogicalResult Deserializer::deserialize() { assert(curOffset == binarySize && "deserializer should never index beyond the binary end"); - for (auto &defered : deferedInstructions) { - if (failed(processInstruction(defered.first, defered.second, false))) { + for (auto &deferred : deferredInstructions) { + if (failed(processInstruction(deferred.first, deferred.second, false))) { return failure(); } } @@ -564,7 +564,7 @@ LogicalResult Deserializer::processExtInstImport(ArrayRef words) { if (words.size() < 2) { return emitError(unknownLoc, "OpExtInstImport must have a result and a literal " - "string for the extensed instruction set name"); + "string for the extended instruction set name"); } unsigned wordIndex = 1; @@ -1049,7 +1049,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, floatTy = opBuilder.getF64Type(); break; default: - return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ") + return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") << operands[1]; } typeMap[operands[0]] = floatTy; @@ -1885,7 +1885,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpEntryPoint: case spirv::Opcode::OpExecutionMode: if (deferInstructions) { - deferedInstructions.emplace_back(opcode, operands); + deferredInstructions.emplace_back(opcode, operands); return success(); } break; diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 58a33356f6e..241be2a4297 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -15,7 +15,7 @@ // limitations under the License. 
// ============================================================================= // -// This file defines the MLIR SPIR-V module to SPIR-V binary seralization. +// This file defines the MLIR SPIR-V module to SPIR-V binary serialization. // //===----------------------------------------------------------------------===// @@ -149,7 +149,7 @@ private: template LogicalResult processTypeDecoration(Location loc, DType type, uint32_t resultId) { - return emitError(loc, "unhandled decoraion for type:") << type; + return emitError(loc, "unhandled decoration for type:") << type; } /// Process member decoration @@ -371,7 +371,7 @@ LogicalResult Serializer::serialize() { processExtension(); processMemoryModel(); - // Iterate over the module body to serialze it. Assumptions are that there is + // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp for (auto &op : module.getBlock()) { if (failed(processOperation(&op))) { @@ -1073,7 +1073,7 @@ uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec) { if (!isSpec) { - // We can de-duplicate nomral contants, but not specialization constants. + // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(boolAttr)) { return id; } @@ -1102,7 +1102,7 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec) { if (!isSpec) { - // We can de-duplicate nomral contants, but not specialization constants. + // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(intAttr)) { return id; } @@ -1168,7 +1168,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { - // We can de-duplicate nomral contants, but not specialization constants. + // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(floatAttr)) { return id; } @@ -1549,7 +1549,7 @@ template <> LogicalResult Serializer::processOp(spirv::EntryPointOp op) { SmallVector operands; - // Add the ExectionModel. + // Add the ExecutionModel. operands.push_back(static_cast(op.execution_model())); // Add the function . auto funcID = getFunctionID(op.fn()); diff --git a/mlir/lib/ExecutionEngine/OptUtils.cpp b/mlir/lib/ExecutionEngine/OptUtils.cpp index e8c6652f446..dc3bd20794e 100644 --- a/mlir/lib/ExecutionEngine/OptUtils.cpp +++ b/mlir/lib/ExecutionEngine/OptUtils.cpp @@ -84,7 +84,7 @@ static void populatePassManagers(llvm::legacy::PassManager &modulePM, if (targetMachine) { // Add pass to initialize TTI for this specific target. Otherwise, TTI will - // be initialized to NoTTIImpl by defaul. + // be initialized to NoTTIImpl by default. modulePM.add(createTargetTransformInfoWrapperPass( targetMachine->getTargetIRAnalysis())); funcPM.add(createTargetTransformInfoWrapperPass( diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 7ec5c3b6a26..24ae2072f77 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -233,7 +233,7 @@ Attribute Builder::getZeroAttr(Type type) { } //===----------------------------------------------------------------------===// -// Affine Expressions, Affine Maps, and Integet Sets. +// Affine Expressions, Affine Maps, and Integer Sets. 
//===----------------------------------------------------------------------===// AffineExpr Builder::getAffineDimExpr(unsigned position) { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index edd67d75866..cdf08f6adcf 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -654,7 +654,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {} SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() { - // Ensure that all expected diagnosics were handled. + // Ensure that all expected diagnostics were handled. (void)verify(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index c624a06d18d..be904f8da44 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -636,7 +636,7 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, }; // If this instance is uniqued, then we handle it separately so that multiple - // threads may simulatenously access existing instances. + // threads may simultaneously access existing instances. if (constraints.size() < IntegerSet::kUniquingThreshold) { auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags); return safeGetOrCreate(impl.integerSets, key, impl.affineMutex, diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 144be24c4d2..39ba7ca5733 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -137,7 +137,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { return op->emitOpError() << "Operations with a 'SymbolTable' must have exactly one region"; - // Check that all symboles are uniquely named within child regions. + // Check that all symbols are uniquely named within child regions. llvm::StringMap nameToOrigLoc; for (auto &block : op->getRegion(0)) { for (auto &op : block) { diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index 991c4c92567..917cf913f14 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -30,7 +30,7 @@ using namespace mlir; using llvm::SMLoc; using llvm::SourceMgr; -// Returns true if 'c' is an allowable puncuation character: [$._-] +// Returns true if 'c' is an allowable punctuation character: [$._-] // Returns false otherwise. static bool isPunct(char c) { return c == '$' || c == '.' || c == '_' || c == '-'; @@ -284,7 +284,7 @@ Token Lexer::lexNumber(const char *tokStart) { // Handle the hexadecimal case. if (curPtr[-1] == '0' && *curPtr == 'x') { - // If we see stuff like 0xi32, this is a literal `0` follwed by an + // If we see stuff like 0xi32, this is a literal `0` followed by an // identifier `xi32`, stop after `0`. if (!isxdigit(curPtr[1])) return formToken(Token::integer, tokStart); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 873476ebffb..60a0a9015fc 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -350,7 +350,7 @@ ParseResult Parser::parseCommaSeparatedListUntil( } /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, -/// and may be recursive. Return with the 'prettyName' StringRef encompasing +/// and may be recursive. Return with the 'prettyName' StringRef encompassing /// the entire pretty name. /// /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' @@ -2815,7 +2815,7 @@ private: /// This keeps track of the block names as well as the location of the first /// reference for each nested name scope. 
This is used to diagnose invalid - /// block references and memoize them. + /// block references and memorize them. SmallVector>, 2> blocksByName; SmallVector, 2> forwardRef; @@ -3250,7 +3250,7 @@ ParseResult OperationParser::parseSuccessors( namespace { // RAII-style guard for cleaning up the regions in the operation state before // deleting them. Within the parser, regions may get deleted if parsing failed, -// and other errors may be present, in praticular undominated uses. This makes +// and other errors may be present, in particular undominated uses. This makes // sure such uses are deleted. struct CleanupOpStateRegions { ~CleanupOpStateRegions() { @@ -3352,7 +3352,7 @@ Operation *OperationParser::parseGenericOperation() { return nullptr; } - // Add the sucessors, and their operands after the proper operands. + // Add the successors, and their operands after the proper operands. for (const auto &succ : llvm::zip(successors, successorOperands)) { Block *successor = std::get<0>(succ); const SmallVector &operands = std::get<1>(succ); @@ -3730,7 +3730,7 @@ public: //===--------------------------------------------------------------------===// /// Parse a region that takes `arguments` of `argTypes` types. This - /// effectively defines the SSA values of `arguments` and assignes their type. + /// effectively defines the SSA values of `arguments` and assigns their type. ParseResult parseRegion(Region ®ion, ArrayRef arguments, ArrayRef argTypes, bool enableNameShadowing) override { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index ed62574d7cb..a195bb0c0c8 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -301,7 +301,7 @@ static LogicalResult runPipeline(OpPassManager &pm, Operation *op, // Clear out any computed operation analyses. These analyses won't be used // any more in this pipeline, and this helps reduce the current working set // of memory. If preserving these analyses becomes important in the future - // we can re-evalutate this. + // we can re-evaluate this. am.clear(); return result; } diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index 7ba5b46fc68..69a2cb723e5 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -102,7 +102,7 @@ struct Timer { .count()); } - // Otheriwse, accumulate the timing from each of the children. + // Otherwise, accumulate the timing from each of the children. TimeRecord totalTime; for (auto &child : children) totalTime += child.second->getTotalTime(); @@ -120,7 +120,7 @@ struct Timer { mergeChildren(std::move(other.children)); } - /// Merge the timer chilren in 'otherChildren' with the children of this + /// Merge the timer children in 'otherChildren' with the children of this /// timer. void mergeChildren(ChildrenMap &&otherChildren) { // Check for an empty children list. @@ -130,7 +130,7 @@ struct Timer { } // Pipeline merges are handled separately as the children are merged - // lexographically. + // lexicographically. 
if (kind == TimerKind::Pipeline) { assert(children.size() == otherChildren.size() && "pipeline merge requires the same number of children"); diff --git a/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/mlir/lib/Quantizer/Support/UniformSolvers.cpp index db5eaa0a05c..341df5bf888 100644 --- a/mlir/lib/Quantizer/Support/UniformSolvers.cpp +++ b/mlir/lib/Quantizer/Support/UniformSolvers.cpp @@ -59,7 +59,7 @@ bool UniformParamsFromMinMaxSolver::compute() { double minOvershoot = boundingMin - adjMinMax.first; // If undershooting on the min or max end, return that because it is // to be unconditionally avoided. Otherwise return the end with the - // greateast magnitude of overshoot. + // greatest magnitude of overshoot. if (maxOvershoot < 0) return maxOvershoot; if (minOvershoot < 0) diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 9db085a310b..2986f30dee6 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -475,7 +475,7 @@ std::vector tblgen::Pattern::getConstraints() const { for (auto it : *listInit) { auto *dagInit = dyn_cast(it); if (!dagInit) - PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity " + PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " "constraints should be DAG nodes"); std::vector entities; diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp index bc2b424ab00..f8f23e04c3f 100644 --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -170,7 +170,7 @@ static PredNode *buildPredicateTree(const tblgen::Pred &root, } // If the current combined predicate is a leaf substitution, append it to the - // list before contiuing. + // list before continuing. auto allSubstitutions = llvm::to_vector<4>(substitutions); if (rootNode->kind == PredCombinerKind::SubstLeaves) { const auto &substPred = static_cast(root); @@ -223,7 +223,7 @@ static PredNode *propagateGroundTruth( // TODO(zinenko,jpienaar): we can support ground truth for rewritten // predicates by either (a) having our own unique'ing of the predicates // instead of relying on TableGen record pointers or (b) taking ground truth - // values optinally prefixed with a list of substitutions to apply, e.g. + // values optionally prefixed with a list of substitutions to apply, e.g. // "predX is true by itself as well as predSubY leaf substitution had been // applied to it". if (node->kind == PredCombinerKind::SubstLeaves) { diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7e08c6bbc57..24d91c2fe63 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -153,7 +153,7 @@ static bool isMemRefDereferencingOp(Operation &op) { // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level operations in a FuncOp which contain load/store ops, and edges // are memref dependences between the nodes. -// TODO(andydavis) Add a more flexible dependece graph representation. +// TODO(andydavis) Add a more flexible dependence graph representation. // TODO(andydavis) Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { public: @@ -224,7 +224,7 @@ public: } }; - // Edge represents a data dependece between nodes in the graph. + // Edge represents a data dependence between nodes in the graph. struct Edge { // The id of the node at the other end of the edge. 
// If this edge is stored in Edge = Node.inEdges[i], then @@ -672,7 +672,7 @@ public: void dump() const { print(llvm::errs()); } }; -// Intializes the data dependence graph by walking operations in 'f'. +// Initializes the data dependence graph by walking operations in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. @@ -921,7 +921,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, "non-constant number of elts in local buffer"); const FlatAffineConstraints *cst = region.getConstraints(); - // 'outerIVs' holds the values that this memory region is symbolic/paramteric + // 'outerIVs' holds the values that this memory region is symbolic/parametric // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. SmallVector outerIVs; @@ -1065,7 +1065,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // The argument 'srcStoreOpInst' is used to calculate the storage reduction on // the memref being produced and consumed, which is an input to the cost model. -// For producer-constumer fusion, 'srcStoreOpInst' will be the same as +// For producer-consumer fusion, 'srcStoreOpInst' will be the same as // 'srcOpInst', as we are slicing w.r.t to that producer. // For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which // reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' @@ -1084,8 +1084,8 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // nest). // *) Computes the cost of fusing a slice of the src loop nest into the dst // loop nest at various values of dst loop depth, attempting to fuse -// the largest compution slice at the maximal dst loop depth (closest to the -// load) to minimize reuse distance and potentially enable subsequent +// the largest computation slice at the maximal dst loop depth (closest to +// the load) to minimize reuse distance and potentially enable subsequent // load/store forwarding. // NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for // the same memref as is written by 'srcOpInst', then the union of slice @@ -1095,7 +1095,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // NOTE: We attempt to maximize the dst loop depth, but there are cases // where a particular setting for 'dstLoopNest' might fuse an unsliced // loop (within the src computation slice) at a depth which results in -// execessive recomputation (see unit tests for examples). +// excessive recomputation (see unit tests for examples). // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. @@ -1612,7 +1612,7 @@ public: mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } - // Collect dst loop stats after memref privatizaton transformation. + // Collect dst loop stats after memref privatization transformation. LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstAffineForOp.getOperation()); @@ -1876,7 +1876,7 @@ public: promoteIfSingleIteration(forOp); } - // Collect dst loop stats after memref privatizaton transformation. + // Collect dst loop stats after memref privatization transformation. 
auto dstForInst = cast(dstNode->op); LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstForInst.getOperation()); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index b3e811b7123..d50c5e0e8c7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -431,7 +431,7 @@ public: if (!maybeExpandedMap) return matchFailure(); - // Build std.store valutToStore, memref[expandedMap.results]. + // Build std.store valueToStore, memref[expandedMap.results]. rewriter.replaceOpWithNewOp(op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap); return matchSuccess(); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 870804d989b..c531ca551b4 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -51,7 +51,7 @@ namespace { // all store op's that have a dependence into the load, is provably the last // writer to the particular memref location being loaded at the load op, and its // store value can be forwarded to the load. Note that the only dependences -// that are to be considered are those that are satisifed at the block* of the +// that are to be considered are those that are satisfied at the block* of the // innermost common surrounding loop of the being considered. // // (* A dependence being satisfied at a block: a dependence that is satisfied by diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index fb96772f053..e09d8c89b37 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -235,7 +235,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // This method uses an algorithm// in time linear in the number of operations // in the body of the for loop - (using the 'sweep line' paradigm). This method // asserts preservation of SSA dominance. A check for that as well as that for -// memory-based depedence preservation check rests with the users of this +// memory-based dependence preservation check rests with the users of this // method. LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef shifts, bool unrollPrologueEpilogue) { @@ -531,7 +531,7 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { // Checks each dependence component against the permutation to see if the // desired loop interchange would violate dependences by making the -// dependence componenent lexicographically negative. +// dependence component lexicographically negative. static bool checkLoopInterchangeDependences( const std::vector> &depCompsVec, ArrayRef loops, ArrayRef loopPermMap) { @@ -829,7 +829,7 @@ Loops mlir::tile(ArrayRef forOps, ArrayRef sizes, Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes) { - // Collect prefectly nested loops. If more size values provided than nested + // Collect perfectly nested loops. If more size values provided than nested // loops available, truncate `sizes`. SmallVector forOps; forOps.reserve(sizes.size()); @@ -842,7 +842,7 @@ Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, // Build the IR that performs ceil division of a positive value by a constant: // ceildiv(a, B) = divis(a + (B-1), B) -// where divis is roundning-to-zero division. +// where divis is rounding-to-zero division. 
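As a quick sanity check of the ceildiv identity above, here is a tiny standalone C++ snippet (illustrative only; it mirrors the builder-based ceilDivPositive that follows rather than reproducing it):

#include <cassert>
#include <cstdint>

// ceildiv(a, B) via round-to-zero division, valid for a >= 0 and B > 0.
static int64_t ceilDivPositive(int64_t a, int64_t b) {
  return (a + b - 1) / b; // C++ integer division truncates toward zero.
}

int main() {
  assert(ceilDivPositive(10, 5) == 2); // exact multiple
  assert(ceilDivPositive(11, 5) == 3); // rounds up
  assert(ceilDivPositive(1, 5) == 1);  // small dividend still rounds up
  return 0;
}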
static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend, int64_t divisor) { assert(divisor > 0 && "expected positive divisor"); @@ -1343,7 +1343,7 @@ static LogicalResult generateCopy( } const FlatAffineConstraints *cst = region.getConstraints(); - // 'regionSymbols' hold values that this memory region is symbolic/paramteric + // 'regionSymbols' hold values that this memory region is symbolic/parametric // on; these typically include loop IVs surrounding the level at which the // copy generation is being done or other valid symbols in MLIR. SmallVector regionSymbols; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 1e10f372b5f..a1e87568745 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -191,7 +191,7 @@ using namespace mlir; /// programmer/library: we derive information from scalar code + annotations. /// 2. After dependence analysis and before polyhedral scheduling: the /// information that supports vectorization does not need to be supplied by a -/// higher level of abstraction. Traditional dependence anaysis is available +/// higher level of abstraction. Traditional dependence analysis is available /// in MLIR and will be used to drive vectorization and cost models. /// /// Let's pause here and remark that applying super-vectorization as described @@ -211,7 +211,7 @@ using namespace mlir; /// operating on elemental vector types. For this reason, the pattern /// profitability analysis should include a component that also captures the /// maximal amount of fusion available under a particular pattern. This is -/// still at the stage of rought ideas but in this context, search is our +/// still at the stage of rough ideas but in this context, search is our /// friend as the Tensor Comprehensions and auto-TVM contributions /// demonstrated previously. /// Bottom-line is we do not yet have good answers for the above but aim at @@ -253,8 +253,8 @@ using namespace mlir; /// 1. defining super-vectorization patterns and matching them on the tree of /// AffineForOp. A super-vectorization pattern is defined as a recursive /// data structures that matches and captures nested, imperfectly-nested -/// loops that have a. comformable loop annotations attached (e.g. parallel, -/// reduction, vectoriable, ...) as well as b. all contiguous load/store +/// loops that have a. conformable loop annotations attached (e.g. parallel, +/// reduction, vectorizable, ...) as well as b. all contiguous load/store /// operations along a specified minor dimension (not necessarily the /// fastest varying) ; /// 2. analyzing those patterns for profitability (TODO(ntv): and @@ -482,7 +482,7 @@ using namespace mlir; /// --test-fastest-varying=1 --test-fastest-varying=0 /// ``` /// -/// produces this more insteresting mixed outer-innermost-loop vectorized code: +/// produces this more interesting mixed outer-innermost-loop vectorized code: /// ```mlir /// mlfunc @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { /// %0 = alloc(%arg0, %arg1) : memref @@ -1099,7 +1099,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, /// Iterates over the forward slice from the loads in the vectorization pattern /// and rewrites them using their vectorized counterpart by: -/// 1. Create the forward slice starting from the laods in the vectorization +/// 1. Create the forward slice starting from the loads in the vectorization /// pattern. /// 2. Topologically sorts the forward slice. /// 3. 
For each operation in the slice, create the vector form of this diff --git a/mlir/test/Dialect/FxpMathOps/lower-uniform-real-math-mulew.mlir b/mlir/test/Dialect/FxpMathOps/lower-uniform-real-math-mulew.mlir index 7fea6cf40ce..9fc120fb936 100644 --- a/mlir/test/Dialect/FxpMathOps/lower-uniform-real-math-mulew.mlir +++ b/mlir/test/Dialect/FxpMathOps/lower-uniform-real-math-mulew.mlir @@ -88,7 +88,7 @@ func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !typ !type_rhs = type tensor<4x!quant.uniform> !type_result = type tensor<4x!quant.uniform> func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - // expected-warning@+1 {{unimplemented: cannot multiply with multipler > 1.0}} + // expected-warning@+1 {{unimplemented: cannot multiply with multiplier > 1.0}} %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } -- cgit v1.2.3 From 68a8da4a938e5489ba915d615352af0b069ae56a Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Mon, 18 Nov 2019 11:20:03 -0800 Subject: Fix Affine Loop Fusion test case reported on github. This CL utilizies the more robust fusion feasibility analysis being built out in LoopFusionUtils, which will eventually be used to replace the current affine loop fusion pass. PiperOrigin-RevId: 281112340 --- mlir/lib/Analysis/Utils.cpp | 4 +- mlir/lib/Transforms/LoopFusion.cpp | 82 +++++++++++++++++++++++++++-------- mlir/test/Transforms/loop-fusion.mlir | 60 ++++++++++++++++++++++--- 3 files changed, 122 insertions(+), 24 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 042c744c74f..23361e38745 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -616,7 +616,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, return failure(); } // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. - if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { + if (sliceUnionCst.getNumLocalIds() > 0 || + tmpSliceCst.getNumLocalIds() > 0 || + failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute union bounding box of slice bounds." "\n."); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 24d91c2fe63..7985ca1c5ef 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -546,8 +546,10 @@ public: } // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' - // has been replaced in node at 'dstId' by a private memref. - void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) { + // has been replaced in node at 'dstId' by a private memref depending + // on the value of 'createPrivateMemRef'. + void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef, + bool createPrivateMemRef) { // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector oldInEdges = inEdges[srcId]; @@ -569,7 +571,7 @@ public: // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being // replaced by a private memref). These edges could come from nodes // other than 'srcId' which were removed in the previous step. 
- if (inEdges.count(dstId) > 0) { + if (inEdges.count(dstId) > 0 && createPrivateMemRef) { SmallVector oldInEdges = inEdges[dstId]; for (auto &inEdge : oldInEdges) if (inEdge.value == oldMemRef) @@ -1522,8 +1524,27 @@ public: // TODO(andydavis) Support more generic multi-output src loop nests // fusion. auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); - if (!srcStoreOp) - continue; + if (!srcStoreOp) { + // Get the src store op at the deepest loop depth. + // We will use 'LoopFusionUtils::canFuseLoops' to check fusion + // feasibility for loops with multiple stores. + unsigned maxLoopDepth = 0; + for (auto *op : srcNode->stores) { + auto storeOp = cast(op); + if (storeOp.getMemRef() != memref) { + srcStoreOp = nullptr; + break; + } + unsigned loopDepth = getNestingDepth(*storeOp); + if (loopDepth > maxLoopDepth) { + maxLoopDepth = loopDepth; + srcStoreOp = storeOp; + } + } + if (!srcStoreOp) + continue; + } + // Unique outgoing store found must write to 'memref' since 'memref' // is the one that established the producer-consumer relationship // between 'srcNode' and 'dstNode'. @@ -1538,6 +1559,15 @@ public: !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) continue; + // Dont create a private memref if 'writesToLiveInOrOut'. + bool createPrivateMemref = !writesToLiveInOrOut; + // Dont create a private memref if 'srcNode' has in edges on 'memref', + // or if 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || + mdg->getOutEdgeCount(dstNode->id, memref) > 0) { + createPrivateMemref = false; + } + // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) continue; @@ -1549,6 +1579,29 @@ public: if (insertPointInst == nullptr) continue; + // Compute the innermost common loop depth for dstNode loads/stores. + SmallVector dstOps(dstNode->loads.begin(), + dstNode->loads.end()); + dstOps.append(dstNode->stores.begin(), dstNode->stores.end()); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps); + // Check the feasibility of fusing src loop nest into dst loop nest + // at loop depths in range [1, dstLoopDepthTest]. + // TODO(andydavis) Use slice union computation and union of memref + // read/write regions to cost model and fusion. + bool canFuse = false; + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + ComputationSliceState sliceUnion; + FusionResult result = mlir::canFuseLoops( + cast(srcNode->op), cast(dstNode->op), + /*dstLoopDepth=*/i, &sliceUnion); + if (result.value == FusionResult::Success) + canFuse = true; + } + + // Skip if fusion is not feasible at all loop depths. + if (!canFuse) + continue; + // Gather 'dstNode' store ops to 'memref'. SmallVector dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) @@ -1562,16 +1615,7 @@ public: dstStoreOpInsts, &sliceState, &bestDstLoopDepth, maximalFusion)) continue; - // TODO(andydavis) Remove the following test code when canFuseLoops - // is fully functional. - mlir::ComputationSliceState sliceUnion; - if (!maximalFusion) { - FusionResult result = mlir::canFuseLoops( - cast(srcNode->op), cast(dstNode->op), - bestDstLoopDepth, &sliceUnion); - assert(result.value == FusionResult::Success); - (void)result; - } + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 
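To make the new feasibility scan above easier to follow, here is a reduced standalone sketch (the callback is a placeholder for the per-depth mlir::canFuseLoops query, not the actual API): every destination loop depth in [1, maxDepth] is tried, and fusion is attempted only if at least one depth is accepted.

#include <functional>

// Placeholder signature; the real pass builds a ComputationSliceState and
// queries fusion feasibility for each candidate destination loop depth.
using FeasibilityAtDepthFn = std::function<bool(unsigned dstLoopDepth)>;

// Returns true if fusing the source nest is feasible at any depth in
// [1, maxDepth]; mirrors the 'canFuse' accumulation in the pass.
static bool anyDepthFusible(unsigned maxDepth,
                            const FeasibilityAtDepthFn &isFeasibleAtDepth) {
  bool canFuse = false;
  for (unsigned d = 1; d <= maxDepth; ++d)
    if (isFeasibleAtDepth(d))
      canFuse = true;
  return canFuse;
}

The actual materialization of the chosen slice follows below.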
auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); @@ -1584,7 +1628,8 @@ public: dstAffineForOp.getOperation()->moveBefore(insertPointInst); } // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref); + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); // Collect slice loop stats. LoopNestStateCollector sliceCollector; @@ -1593,14 +1638,15 @@ public: for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); } - if (!writesToLiveInOrOut) { + if (createPrivateMemref) { // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (cast(storeOpInst).getMemRef() == memref) storesForMemref.push_back(storeOpInst); } - assert(storesForMemref.size() == 1); + // TODO(andydavis) Use union of memref write regions to compute + // private memref footprint. auto *newMemRef = createPrivateMemRef( dstAffineForOp, storesForMemref[0], bestDstLoopDepth, fastMemorySpace, localBufSizeThreshold); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 36bcd0e5466..592b45d4ab1 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -321,11 +321,8 @@ func @should_fuse_producer_consumer() { // TODO(andydavis) When the fusion pass is run to a fixed-point, it should // fuse all three of these loop nests. // CHECK: %{{.*}} = alloc() : memref<1xf32> - // CHECK: %{{.*}} = alloc() : memref<10xf32> // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } @@ -1238,7 +1235,6 @@ func @R3_to_R2_reshape() { // ----- -// CHECK-LABEL: func @should_not_fuse_multi_output_producer() { func @should_not_fuse_multi_output_producer() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> @@ -2341,3 +2337,57 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref // CHECK-NEXT: return return } + +// ----- + +// Test case from github bug 777. 
+// CHECK-LABEL: func @mul_add_0 +func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) { + %cst = constant 0.000000e+00 : f32 + %0 = alloc() : memref<3x3xf32> + affine.for %arg4 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32> + } + } + affine.for %arg4 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 4 { + %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32> + %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32> + %3 = mulf %2, %1 : f32 + %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32> + %5 = addf %4, %3 : f32 + affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32> + } + } + } + affine.for %arg4 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32> + %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32> + %8 = addf %7, %6 : f32 + affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32> + } + } + // CHECK: affine.for %[[i0:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[i1:.*]] = 0 to 3 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32> + // CHECK-NEXT: affine.for %[[i2:.*]] = 0 to 4 { + // CHECK-NEXT: affine.load %{{.*}}[%[[i2]], %[[i1]]] : memref<4x3xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i2]]] : memref<3x4xf32> + // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 + // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32> + // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32> + // CHECK-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32> + // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]], %[[i1]]] : memref<3x3xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: return + + return +} -- cgit v1.2.3 From 330d1ff00ea85363125ca9b7e42dca50f6ea4ebe Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 3 Dec 2019 06:09:21 -0800 Subject: AffineLoopFusion: Prevent fusion of multi-out-edge producer loops tensorflow/mlir#162 introduced a bug that incorrectly allowed fusion of producer loops with multiple outgoing edges. This commit fixes that problem. It also introduces a new flag to disable sibling loop fusion so that we can test producer-consumer fusion in isolation. Closes tensorflow/mlir#259 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/259 from dcaballe:dcaballe/fix_multi_out_edge_producer_fusion 578d5661705fd5c56c555832d5e0528df88c5282 PiperOrigin-RevId: 283531105 --- mlir/lib/Transforms/LoopFusion.cpp | 8 ++++-- mlir/test/Transforms/loop-fusion.mlir | 53 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7985ca1c5ef..cda35297abc 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1005,17 +1005,19 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' // may write to multiple memrefs but it is required that only one of them, -// 'srcLiveOutStoreOp', have an output edge. +// 'srcLiveOutStoreOp', has output edges. 
// Returns true if 'dstNode's read/write region to 'memref' is a super set of -// 'srcNode's write region to 'memref'. +// 'srcNode's write region to 'memref' and 'srcId' has only one output edge. // TODO(andydavis) Generalize this to handle more live in/out cases. static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, AffineStoreOp srcLiveOutStoreOp, MemRefDependenceGraph *mdg) { assert(srcLiveOutStoreOp && "Expected a valid store op"); - assert(mdg->getOutEdgeCount(srcId) == 1 && "Expected only one output edge"); auto *dstNode = mdg->getNode(dstId); Value *memref = srcLiveOutStoreOp.getMemRef(); + // Return false if 'srcNode' has more than one output edge on 'memref'. + if (mdg->getOutEdgeCount(srcId, memref) > 1) + return false; // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 7431eade896..339cc31f549 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2388,6 +2388,59 @@ func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return + return +} + +// ----- + +// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with +// a store that has multiple outgoing edges. Sibling loop fusion should not fuse +// any of these loops due to dependencies on external memref '%a'. +// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1 +func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %arg0 = 0 to 1 { + affine.store %cst, %a[%arg0] : memref<1xf32> + } + + affine.for %arg0 = 0 to 1 { + %0 = affine.load %a[%arg0] : memref<1xf32> + } + + affine.for %arg0 = 0 to 1 { + %0 = affine.load %a[%arg0] : memref<1xf32> + } + // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK: affine.for %{{.*}} = 0 to 1 + return +} + +// ----- + +// Verify that 'fuseProducerConsumerNodes' fuses a producer loop that: 1) has +// multiple outgoing edges, 2) producer store has a single outgoing edge. +// Sibling loop fusion should not fuse any of these loops due to +// dependencies on external memrefs '%a' and '%b'. 
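The distinction between these two cases comes down to counting outgoing edges per memref rather than per node; a reduced sketch of that check (with placeholder types, not the actual MemRefDependenceGraph API) looks like this:

#include <vector>

// Minimal stand-ins for the dependence graph's edge bookkeeping.
struct Edge {
  unsigned dstId; // node at the other end of the edge
  int memref;     // stand-in for the memref the edge is on
};

// Count only the producer's outgoing edges that are on 'memref'.
static unsigned outEdgeCountOn(const std::vector<Edge> &outEdges, int memref) {
  unsigned count = 0;
  for (const Edge &e : outEdges)
    if (e.memref == memref)
      ++count;
  return count;
}

// A live-out producer is rejected only when the memref written by its store
// has more than one outgoing edge, even if the node has other out-edges.
static bool liveOutProducerFusible(const std::vector<Edge> &producerOutEdges,
                                   int storedMemRef) {
  return outEdgeCountOn(producerOutEdges, storedMemRef) <= 1;
}

The two test functions exercising these cases follow.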
+ +// CHECK-LABEL: func @should_fuse_producer_with_multi_outgoing_edges +func @should_fuse_producer_with_multi_outgoing_edges(%a : memref<1xf32>, %b : memref<1xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %arg0 = 0 to 1 { + %0 = affine.load %a[%arg0] : memref<1xf32> + affine.store %cst, %b[%arg0] : memref<1xf32> + } + + affine.for %arg0 = 0 to 1 { + affine.store %cst, %a[%arg0] : memref<1xf32> + %1 = affine.load %b[%arg0] : memref<1xf32> + } + // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK-NEXT: affine.load %[[A:.*]][{{.*}}] + // CHECK-NEXT: affine.store %{{.*}}, %[[B:.*]][{{.*}}] + // CHECK-NEXT: affine.store %{{.*}}, %[[A]] + // CHECK-NEXT: affine.load %[[B]] + // CHECK-NOT: affine.for %{{.*}} return } -- cgit v1.2.3 From 84a6182ddd62a2ca8eee2d8470e3be1ef6147fce Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Dec 2019 05:58:59 -0800 Subject: minor spelling tweaks Closes tensorflow/mlir#290 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/290 from kiszk:spelling_tweaks_201912 9d9afd16a723dd65754a04698b3976f150a6054a PiperOrigin-RevId: 284169681 --- mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 2 +- mlir/g3doc/DeclarativeRewrites.md | 13 ++- mlir/g3doc/Dialects/GPU.md | 2 +- mlir/g3doc/OpDefinitions.md | 130 +++++++++++---------- mlir/g3doc/WritingAPass.md | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 2 +- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 6 +- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 3 +- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/Diagnostics.cpp | 2 +- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/SymbolTable.cpp | 2 +- mlir/lib/Pass/PassTiming.cpp | 2 +- mlir/lib/Transforms/LoopFusion.cpp | 6 +- mlir/test/IR/traits.mlir | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 6 +- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 2 +- mlir/unittests/TableGen/StructsGenTest.cpp | 2 +- 19 files changed, 97 insertions(+), 95 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 227ebcd758b..b33137a1066 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -375,7 +375,7 @@ private: return mlir::success(); } - /// Emit a coinstant for a literal/constant array. It will be emitted as a + /// Emit a constant for a literal/constant array. It will be emitted as a /// flattened array of data in an Attribute attached to a `toy.constant` /// operation. See documentation on [Attributes](LangRef.md#attributes) for /// more details. Here is an excerpt: diff --git a/mlir/g3doc/DeclarativeRewrites.md b/mlir/g3doc/DeclarativeRewrites.md index e319b7d7a83..5adcb320983 100644 --- a/mlir/g3doc/DeclarativeRewrites.md +++ b/mlir/g3doc/DeclarativeRewrites.md @@ -259,9 +259,9 @@ def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>; `AOp` is generated via a nested result pattern; DRR won't be able to deduce the result type for it. A custom builder for `AOp` should be defined and it should -deduce the result type by itself. The builder should have the a separate -parameter for each operand and attribute and deduce the result type internally -by itself. For example, for the above `AOp`, a possible builder is: +deduce the result type by itself. The builder should have the separate parameter +for each operand and attribute and deduce the result type internally by itself. 
+For example, for the above `AOp`, a possible builder is: ```c++ @@ -311,9 +311,10 @@ def DOp : Op<"d_op"> { def : Pat<(AOp $input, $ignored_attr), (DOp (BOp:$b_result) $b_result)>; ``` -In this pattern, a `AOp` is matched and replaced with a `DOp` whose two operands -are from the result of a single `BOp`. This is only possible by binding the -result of the `BOp` to a name and reuse it for the second operand of the `DOp` +In this pattern, an `AOp` is matched and replaced with a `DOp` whose two +operands are from the result of a single `BOp`. This is only possible by binding +the result of the `BOp` to a name and reuse it for the second operand of the +`DOp` #### `NativeCodeCall`: transforming the generated op diff --git a/mlir/g3doc/Dialects/GPU.md b/mlir/g3doc/Dialects/GPU.md index b1cc30e510f..faa07219e03 100644 --- a/mlir/g3doc/Dialects/GPU.md +++ b/mlir/g3doc/Dialects/GPU.md @@ -87,7 +87,7 @@ memory buffers at the module level, we chose to do it at the function level to provide some structuring for the lifetime of those buffers; this avoids the incentive to use the buffers for communicating between different kernels or launches of the same kernel, which should be done through function arguments -intead; we chose not to use `alloca`-style approach that would require more +instead; we chose not to use `alloca`-style approach that would require more complex lifetime analysis following the principles of MLIR that promote structure and representing analysis results in the IR. diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index b72b9937ebb..7fb0e53ea17 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -60,16 +60,17 @@ allowed in a TableGen file (typically with filename suffix `.td`) can be found [here][TableGenIntro]. The formal language specification can be found [here][TableGenRef]. _Roughly_ speaking, -* TableGen `class` is similar to C++ class; it can be templated and subclassed. -* TableGen `def` is similar to C++ object; it can be declared by specializing - a TableGen `class` (e.g., `def MyDef : MyClass<...>;`) or completely - independently (e.g., `def MyDef;`). It cannot be further templated or - subclassed. -* TableGen `dag` is a dedicated type for directed graph of elements. A `dag` - has one operator and zero or more arguments. Its syntax is `(operator arg0, - arg1, argN)`. The operator can be any TableGen `def`; an argument can be - anything, including `dag` itself. We can have names attached to both the - operator and the arguments like `(MyOp:$op_name MyArg:$arg_name)`. +* TableGen `class` is similar to C++ class; it can be templated and + subclassed. +* TableGen `def` is similar to C++ object; it can be declared by specializing + a TableGen `class` (e.g., `def MyDef : MyClass<...>;`) or completely + independently (e.g., `def MyDef;`). It cannot be further templated or + subclassed. +* TableGen `dag` is a dedicated type for directed acyclic graph of elements. A + `dag` has one operator and zero or more arguments. Its syntax is `(operator + arg0, arg1, argN)`. The operator can be any TableGen `def`; an argument can + be anything, including `dag` itself. We can have names attached to both the + operator and the arguments like `(MyOp:$op_name MyArg:$arg_name)`. Please see the [language introduction][TableGenIntro] to learn about all the types and expressions supported by TableGen. @@ -214,13 +215,13 @@ places like constraints. To declare a variadic operand, wrap the `TypeConstraint` for the operand with `Variadic<...>`. 
-Normally operations have no variadic operands or just one variadic operand. -For the latter case, it is easily deduce which dynamic operands are for the -static variadic operand definition. But if an operation has more than one -variadic operands, it would be impossible to attribute dynamic operands to the +Normally operations have no variadic operands or just one variadic operand. For +the latter case, it is easy to deduce which dynamic operands are for the static +variadic operand definition. But if an operation has more than one variadic +operands, it would be impossible to attribute dynamic operands to the corresponding static variadic operand definitions without further information -from the operation. Therefore, the `SameVariadicOperandSize` trait is needed -to indicate that all variadic operands have the same number of dynamic values. +from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to +indicate that all variadic operands have the same number of dynamic values. #### Optional attributes @@ -776,7 +777,7 @@ duplication, which is being worked on right now. ### Enum attributes Some attributes can only take values from an predefined enum, e.g., the -comparsion kind of a comparsion op. To define such attributes, ODS provides +comparison kind of a comparison op. To define such attributes, ODS provides several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`. * `StrEnumAttr`: each enum case is a string, the attribute is stored as a @@ -1042,53 +1043,54 @@ possible). We considered the approaches of several contemporary systems and focused on requirements that were desirable: -* Ops registered using a registry separate from C++ code. - * Unknown ops are allowed in MLIR, so ops need not be registered. The - ability of the compiler to optimize those ops or graphs containing those - ops is constrained but correct. - * The current proposal does not include a runtime op description, but it - does not preclude such description, it can be added later. - * The op registry is essential for generating C++ classes that make - manipulating ops, verifying correct construction etc. in C++ easier by - providing a typed representation and accessors. -* The op registry will be defined in - [TableGen](https://llvm.org/docs/TableGen/index.html) and be used to - generate C++ classes and utility functions - (builder/verifier/parser/printer). - * TableGen is a modelling specification language used by LLVM's backends - and fits in well with trait based modelling. This is an implementation - decision and there are alternative ways of doing this. But the - specification language is good for the requirements of modelling the - traits (as seen from usage in LLVM processor backend modelling) and easy - to extend, so a practical choice. If another good option comes up, we - will consider it. -* MLIR allows both defined and undefined ops. - * Defined ops should have fixed semantics and could have a corresponding - reference implementation defined using, for example, EDSC. - * Dialects are under full control of the dialect owner and normally live - with the framework of the dialect. -* The op's traits (e.g., commutative) are modelled along with the op in - the registry. -* The op's operand/return type constraints are modelled along with the op in - the registry (see [Shape inference](#shape-inference) discussion below), - this allows (e.g.) optimized concise syntax in textual dumps. -* Behavior of the op is documented along with the op with a summary and a - description. 
The description is written in markdown and extracted for - inclusion in the generated LangRef section of the dialect. -* The generic assembly form of printing and parsing is available as normal, - but a custom parser and printer can either be specified or automatically - generated from an optional string representation showing the mapping of the - "assembly" string to operands/type. - * Parser-level remappings (e.g., `eq` to enum) will be supported as part - of the parser generation. -* Matching patterns are specified separately from the op description. - * Contrasted with LLVM there is no "base" set of ops that every backend - needs to be aware of. Instead there are many different dialects and the - transformations/legalizations between these dialects form a graph of - transformations. -* Reference implementation may be provided along with the op definition. - * The reference implementation may be in terms of either standard ops or - other reference implementations. +* Ops registered using a registry separate from C++ code. + * Unknown ops are allowed in MLIR, so ops need not be registered. The + ability of the compiler to optimize those ops or graphs containing those + ops is constrained but correct. + * The current proposal does not include a runtime op description, but it + does not preclude such description, it can be added later. + * The op registry is essential for generating C++ classes that make + manipulating ops, verifying correct construction etc. in C++ easier by + providing a typed representation and accessors. +* The op registry will be defined in + [TableGen](https://llvm.org/docs/TableGen/index.html) and be used to + generate C++ classes and utility functions + (builder/verifier/parser/printer). + * TableGen is a modelling specification language used by LLVM's backends + and fits in well with trait-based modelling. This is an implementation + decision and there are alternative ways of doing this. But the + specification language is good for the requirements of modelling the + traits (as seen from usage in LLVM processor backend modelling) and easy + to extend, so a practical choice. If another good option comes up, we + will consider it. +* MLIR allows both defined and undefined ops. + * Defined ops should have fixed semantics and could have a corresponding + reference implementation defined using, for example, EDSC. + * Dialects are under full control of the dialect owner and normally live + with the framework of the dialect. +* The op's traits (e.g., commutative) are modelled along with the op in the + registry. +* The op's operand/return type constraints are modelled along with the op in + the registry (see [Shape inference](#shape-inference) discussion below), + this allows (e.g.) optimized concise syntax in textual dumps. +* Behavior of the op is documented along with the op with a summary and a + description. The description is written in markdown and extracted for + inclusion in the generated LangRef section of the dialect. +* The generic assembly form of printing and parsing is available as normal, + but a custom parser and printer can either be specified or automatically + generated from an optional string representation showing the mapping of the + "assembly" string to operands/type. + * Parser-level remappings (e.g., `eq` to enum) will be supported as part + of the parser generation. +* Matching patterns are specified separately from the op description. + * Contrasted with LLVM there is no "base" set of ops that every backend + needs to be aware of. 
Instead there are many different dialects and the + transformations/legalizations between these dialects form a graph of + transformations. +* Reference implementation may be provided along with the op definition. + + * The reference implementation may be in terms of either standard ops or + other reference implementations. TODO: document expectation if the dependent op's definition changes. diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md index fc73b7e9ef3..f72d41bea40 100644 --- a/mlir/g3doc/WritingAPass.md +++ b/mlir/g3doc/WritingAPass.md @@ -122,7 +122,7 @@ An analysis may provide additional hooks to control various behavior: Given a preserved analysis set, the analysis returns true if it should truly be invalidated. This allows for more fine-tuned invalidation in cases where an -analysis wasn't explicitly marked preserved, but may be preserved(or +analysis wasn't explicitly marked preserved, but may be preserved (or invalidated) based upon other properties such as analyses sets. ### Querying Analyses diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5a6282e8d4d..7b15b758968 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -510,7 +510,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern { attributes.push_back(attr); } - // Create an LLVM funcion, use external linkage by default until MLIR + // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index e3b550223e5..694a98fd075 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -71,7 +71,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef localSize, Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { // Convert to 32-bit integers for now. Might need a way to control this in // future. - // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To + // TODO(ravishankarm): It is probably better to make it 64-bit integers. To // this some support is needed in SPIR-V dialect for Conversion // instructions. The Vulkan spec requires the builtins like // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be @@ -189,7 +189,7 @@ static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, return nullptr; } -/// Gets name of global variable for a buitlin. +/// Gets name of global variable for a builtin. static std::string getBuiltinVarName(spirv::BuiltIn builtin) { return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; } @@ -230,7 +230,7 @@ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, } /// Gets the global variable associated with a builtin and add -/// it if it doesnt exist. +/// it if it doesn't exist. 
Value *mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, OpBuilder &builder) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 2011c750d83..72d11a19380 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -270,7 +270,6 @@ private: // block and redirect all branches to the old header block to the old // merge block (which contains the spv.selection/spv.loop op now). - /// For OpPhi instructions, we use block arguments to represent them. OpPhi /// encodes a list of (value, predecessor) pairs. At the time of handling the /// block containing an OpPhi instruction, the predecessor block might not be @@ -278,7 +277,7 @@ private: /// the block argument from the predecessors. We use the following approach: /// /// 1. For each OpPhi instruction, add a block argument to the current block - /// in construction. Record the block argment in `valueMap` so its uses + /// in construction. Record the block argument in `valueMap` so its uses /// can be resolved. For the list of (value, predecessor) pairs, update /// `blockPhiInfo` for bookkeeping. /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a3a15dac533..ed97b8b5940 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1116,7 +1116,7 @@ void ModulePrinter::printType(Type type) { //===----------------------------------------------------------------------===// namespace { -/// This class provides the main specialication of the DialectAsmPrinter that is +/// This class provides the main specialization of the DialectAsmPrinter that is /// used to provide support for print attributes and types. This hooks allows /// for dialects to hook into the main ModulePrinter. struct CustomDialectAsmPrinter : public DialectAsmPrinter { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index f2f2f83b3a8..70a802cd856 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -689,7 +689,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i) (void)impl->computeExpectedDiags(mgr.getMemoryBuffer(i + 1)); - // Register a handler to verfy the diagnostics. + // Register a handler to verify the diagnostics. setHandler([&](Diagnostic &diag) { // Process the main diagnostics. process(diag); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 69b8d056cd5..1d213f45dd5 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -286,7 +286,7 @@ void Operation::destroy() { /// Return the context this operation is associated with. MLIRContext *Operation::getContext() { return location->getContext(); } -/// Return the dialact this operation is associated with, or nullptr if the +/// Return the dialect this operation is associated with, or nullptr if the /// associated dialect is not registered. 
Dialect *Operation::getDialect() { if (auto *abstractOp = getAbstractOperation()) diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index b61308b74af..bd8cb59cea7 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -283,7 +283,7 @@ static Optional walkSymbolUses( if (walkSymbolRefs(&op, callback).wasInterrupted()) return WalkResult::interrupt(); - // If this operation has regions, and it as well as its dialect arent't + // If this operation has regions, and it as well as its dialect aren't // registered then conservatively fail. The operation may define a // symbol table, so we can't opaquely know if we should traverse to find // nested uses. diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index 4747249690f..dd193a4d9a9 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -323,7 +323,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) { return; } - // Adapator passes aren't timed directly, so we don't need to stop their + // Adaptor passes aren't timed directly, so we don't need to stop their // timers. if (!isAdaptorPass(pass)) timer->stop(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cda35297abc..6627e73056a 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1561,10 +1561,10 @@ public: !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) continue; - // Dont create a private memref if 'writesToLiveInOrOut'. + // Don't create a private memref if 'writesToLiveInOrOut'. bool createPrivateMemref = !writesToLiveInOrOut; - // Dont create a private memref if 'srcNode' has in edges on 'memref', - // or if 'dstNode' has out edges on 'memref'. + // Don't create a private memref if 'srcNode' has in edges on + // 'memref', or if 'dstNode' has out edges on 'memref'. 
if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || mdg->getOutEdgeCount(dstNode->id, memref) > 0) { createPrivateMemref = false; diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index b78dbf24bc8..794ed4cd4f7 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -265,7 +265,7 @@ func @failedOperandSizeAttrWrongTotalSize(%arg: i32) { // ----- func @failedOperandSizeAttrWrongCount(%arg: i32) { - // expected-error @+1 {{'operand_segment_sizes' attribute for specifiying operand segments must have 4 elements}} + // expected-error @+1 {{'operand_segment_sizes' attribute for specifying operand segments must have 4 elements}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : (i32, i32, i32, i32) -> () } @@ -315,7 +315,7 @@ func @failedResultSizeAttrWrongTotalSize() { // ----- func @failedResultSizeAttrWrongCount() { - // expected-error @+1 {{'result_segment_sizes' attribute for specifiying result segments must have 4 elements}} + // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : () -> (i32, i32, i32, i32) } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 16894ad4cb3..b5fd0862b45 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1439,7 +1439,7 @@ void OpEmitter::genVerifier() { auto sizeAttr = getAttrOfType("{0}"); auto numElements = sizeAttr.getType().cast().getNumElements(); if (numElements != {1}) {{ - return emitOpError("'{0}' attribute for specifiying {2} segments " + return emitOpError("'{0}' attribute for specifying {2} segments " "must have {1} elements"); } )"; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index d321b204f4e..f229a349d27 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -685,7 +685,7 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { } for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); - LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argment #" << i + LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i << " replacement: " << attrs[i] << "\n"); } return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], @@ -769,7 +769,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, if (isSameOperandsAndResultType || useFirstAttr) { // We know how to deduce the result type for ops with these traits and we've - // generated builders taking aggregrate parameters. Use those builders to + // generated builders taking aggregate parameters. Use those builders to // create the ops. // First prepare local variables for op arguments used in builder call. @@ -891,7 +891,7 @@ void PatternEmitter::supplyValuesForOpArgs( Operator &resultOp = node.getDialectOp(opMap); for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); argIndex != numOpArgs; ++argIndex) { - // Start each argment on its own line. + // Start each argument on its own line. 
(os << ",\n").indent(8); Argument opArg = resultOp.getArg(argIndex); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index f39295a22c8..422183ed948 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -687,7 +687,7 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, } static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { - llvm::emitSourceFileHeader("SPIR-V Op Utilites", os); + llvm::emitSourceFileHeader("SPIR-V Op Utilities", os); auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); os << "#ifndef SPIRV_OP_UTILS_H_\n"; diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp index c8b811db935..b446ca9558a 100644 --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -109,7 +109,7 @@ TEST(StructsGenTest, ClassofMissingFalse) { llvm::SmallVector newValues( expectedValues.begin() + 1, expectedValues.end()); - // Make a new DictionaryAttr and validate it is not a validte TestStruct. + // Make a new DictionaryAttr and validate it is not a validate TestStruct. auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } -- cgit v1.2.3 From 2666b97314ad1b50f88fcc4376ae941f601f67ea Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 18 Dec 2019 10:46:16 -0800 Subject: NFC: Cleanup non-conforming usages of namespaces. * Fixes use of anonymous namespace for static methods. * Uses explicit qualifiers(mlir::) instead of wrapping the definition with the namespace. PiperOrigin-RevId: 286222654 --- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 11 +++--- .../VectorToLoops/ConvertVectorToLoops.cpp | 15 ++++---- .../Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 40 +++++++++----------- mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 20 +++++----- mlir/lib/Dialect/SDBM/SDBMExpr.cpp | 12 ++---- .../DecorateSPIRVCompositeTypeLayoutPass.cpp | 2 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 7 +--- mlir/lib/Quantizer/Support/Statistics.cpp | 9 +---- mlir/lib/Quantizer/Support/UniformSolvers.cpp | 14 +++---- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 43 ++++++++++------------ mlir/lib/Transforms/LoopFusion.cpp | 22 ++++++----- 13 files changed, 88 insertions(+), 111 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 92cc02660a2..42483a6e5df 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -18,6 +18,7 @@ // This file implements the conversion patterns from GPU ops to SPIR-V dialect. // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" @@ -350,11 +351,10 @@ PatternMatchResult GPUReturnOpConversion::matchAndRewrite( // GPU To SPIRV Patterns. 
//===----------------------------------------------------------------------===// -namespace mlir { -void populateGPUToSPIRVPatterns(MLIRContext *context, - SPIRVTypeConverter &typeConverter, - OwningRewritePatternList &patterns, - ArrayRef workGroupSize) { +void mlir::populateGPUToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns, + ArrayRef workGroupSize) { patterns.insert(context, typeConverter, workGroupSize); patterns.insert< GPUReturnOpConversion, ForOpConversion, KernelModuleConversion, @@ -366,4 +366,3 @@ void populateGPUToSPIRVPatterns(MLIRContext *context, spirv::BuiltIn::LocalInvocationId>>(context, typeConverter); } -} // namespace mlir diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp index 721e7092cfc..0b39f604b41 100644 --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -117,14 +117,16 @@ struct VectorTransferRewriter : public RewritePattern { PatternRewriter &rewriter) const override; }; +} // namespace + /// Analyzes the `transfer` to find an access dimension along the fastest remote /// MemRef dimension. If such a dimension with coalescing properties is found, /// `pivs` and `vectorView` are swapped so that the invocation of /// LoopNestBuilder captures it in the innermost loop. template -void coalesceCopy(TransferOpTy transfer, - SmallVectorImpl *pivs, - edsc::VectorView *vectorView) { +static void coalesceCopy(TransferOpTy transfer, + SmallVectorImpl *pivs, + edsc::VectorView *vectorView) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. auto remoteRank = transfer.getMemRefType().getRank(); @@ -155,9 +157,9 @@ void coalesceCopy(TransferOpTy transfer, /// Emits remote memory accesses that are clipped to the boundaries of the /// MemRef. template -SmallVector clip(TransferOpTy transfer, - edsc::MemRefView &view, - ArrayRef ivs) { +static SmallVector clip(TransferOpTy transfer, + edsc::MemRefView &view, + ArrayRef ivs) { using namespace mlir::edsc; using namespace edsc::op; using edsc::intrinsics::select; @@ -357,7 +359,6 @@ PatternMatchResult VectorTransferRewriter::matchAndRewrite( rewriter.eraseOp(op); return matchSuccess(); } -} // namespace void mlir::populateVectorToAffineLoopsConversionPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 10668f87ed4..f4256cf25c8 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -18,12 +18,13 @@ #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" -namespace mlir { -namespace quant { -namespace { -bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, - MLIRContext *ctx, Type &storageType, int64_t &qmin, - int64_t &qmax) { +using namespace mlir; +using namespace mlir::quant; + +static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, + bool isSigned, MLIRContext *ctx, + Type &storageType, int64_t &qmin, + int64_t &qmax) { // Hard-coded type mapping from TFLite. 
if (numBits <= 8) { storageType = IntegerType::get(8, ctx); @@ -62,9 +63,9 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, // range will be outside the shifted range and be clamped during quantization. // TODO(fengliuai): we should nudge the scale as well, but that requires the // fake quant op used in the training to use the nudged scale as well. -void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, - double rmax, double &scale, - int64_t &nudgedZeroPoint) { +static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, + double rmax, double &scale, + int64_t &nudgedZeroPoint) { // Determine the scale. const double qminDouble = qmin; const double qmaxDouble = qmax; @@ -103,12 +104,10 @@ void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, assert(nudgedZeroPoint <= qmax); } -} // end namespace - -UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, - double rmin, double rmax, - bool narrowRange, Type expressedType, - bool isSigned) { +UniformQuantizedType +mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, + double rmax, bool narrowRange, + Type expressedType, bool isSigned) { MLIRContext *ctx = expressedType.getContext(); unsigned flags = isSigned ? QuantizationFlags::Signed : 0; Type storageType; @@ -137,10 +136,10 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, loc); } -UniformQuantizedPerAxisType -fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, - ArrayRef rmins, ArrayRef rmaxs, - bool narrowRange, Type expressedType, bool isSigned) { +UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( + Location loc, unsigned numBits, int32_t quantizedDimension, + ArrayRef rmins, ArrayRef rmaxs, bool narrowRange, + Type expressedType, bool isSigned) { size_t axis_size = rmins.size(); if (axis_size != rmaxs.size()) { return (emitError(loc, "mismatched per-axis min and max size: ") @@ -183,6 +182,3 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, qmin, qmax, loc); } - -} // namespace quant -} // namespace mlir diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index e7a1df97599..56e2cbae4f0 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -20,8 +20,9 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" -namespace mlir { -namespace quant { +using namespace mlir; +using namespace mlir::quant; + /// Converts a possible primitive, real expressed value attribute to a /// corresponding storage attribute (typically FloatAttr -> IntegerAttr). /// quantizedElementType is the QuantizedType that describes the expressed @@ -104,10 +105,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr, /// Converts a real expressed Attribute to a corresponding Attribute containing /// quantized storage values assuming the given uniform quantizedElementType and /// converter. 
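// Editor's note: a standalone sketch (not part of the patch) of the
// scale/zero-point nudging performed by getNudgedScaleAndZeroPoint in the
// FakeQuantSupport diff above. It follows the common TFLite fake-quant recipe;
// the exact rounding and edge-case handling in the MLIR implementation may
// differ slightly.
#include <cmath>
#include <cstdint>

void nudgedScaleAndZeroPointSketch(int64_t qmin, int64_t qmax, double rmin,
                                   double rmax, double &scale,
                                   int64_t &nudgedZeroPoint) {
  // Map the real range [rmin, rmax] onto the integer range [qmin, qmax].
  scale = (rmax - rmin) / static_cast<double>(qmax - qmin);

  // The "ideal" zero point is where real 0.0 lands on the integer axis.
  const double zeroPointFromMin = static_cast<double>(qmin) - rmin / scale;

  // Nudge it to the nearest integer inside [qmin, qmax] so that real 0.0 is
  // exactly representable; real values whose images fall outside the shifted
  // range are clamped during quantization.
  if (zeroPointFromMin < static_cast<double>(qmin))
    nudgedZeroPoint = qmin;
  else if (zeroPointFromMin > static_cast<double>(qmax))
    nudgedZeroPoint = qmax;
  else
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointFromMin));
}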
-Attribute quantizeAttrUniform(Attribute realValue, - UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, - Type &outConvertedType) { +Attribute mlir::quant::quantizeAttrUniform( + Attribute realValue, UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. if (realValue.isa()) { // Dense tensor or vector constant. @@ -133,8 +133,9 @@ Attribute quantizeAttrUniform(Attribute realValue, /// quantizedElementType.getStorageType(). /// Returns nullptr if the conversion is not supported. /// On success, stores the converted type in outConvertedType. -Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, - Type &outConvertedType) { +Attribute mlir::quant::quantizeAttr(Attribute realValue, + QuantizedType quantizedElementType, + Type &outConvertedType) { if (auto uniformQuantized = quantizedElementType.dyn_cast()) { UniformQuantizedValueConverter converter(uniformQuantized); @@ -154,6 +155,3 @@ Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, return nullptr; } } - -} // namespace quant -} // namespace mlir diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 8cdd9c8566e..44cdd18cf98 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -671,10 +671,7 @@ SDBMDirectExpr SDBMNegExpr::getVar() const { return static_cast(impl)->expr; } -namespace mlir { -namespace ops_assertions { - -SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { +SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) { if (auto folded = foldSumDiff(lhs, rhs)) return folded; assert(!(lhs.isa() && rhs.isa()) && @@ -707,7 +704,7 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { return addConstant(lhs.cast(), rhsConstant.getValue()); } -SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { +SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) { // Fold x - x == 0. 
if (lhs == rhs) return SDBMConstantExpr::get(lhs.getDialect(), 0); @@ -734,7 +731,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { return buildDiffExpr(lhs.cast(), (-rhs).cast()); } -SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { +SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) { auto constantFactor = factor.cast(); assert(constantFactor.getValue() > 0 && "non-positive stripe"); @@ -744,6 +741,3 @@ SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { return SDBMStripeExpr::get(expr.cast(), constantFactor); } - -} // namespace ops_assertions -} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index 1fd6274b16e..be486f858fe 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -93,6 +93,7 @@ class DecorateSPIRVCompositeTypeLayoutPass private: void runOnModule() override; }; +} // namespace void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { auto module = getModule(); @@ -120,7 +121,6 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { } } } -} // namespace std::unique_ptr> mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass() { diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index bbee80ac4e9..5098ba81762 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -63,14 +63,12 @@ using llvm::orc::RTDyldObjectLinkingLayer; using llvm::orc::ThreadSafeModule; using llvm::orc::TMOwningSimpleCompiler; -// Wrap a string into an llvm::StringError. -static inline Error make_string_error(const Twine &message) { +/// Wrap a string into an llvm::StringError. 
+static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } -namespace mlir { - void SimpleObjectCache::notifyObjectCompiled(const Module *M, MemoryBufferRef ObjBuffer) { cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( @@ -316,4 +314,3 @@ Error ExecutionEngine::invoke(StringRef name, MutableArrayRef args) { return Error::success(); } -} // end namespace mlir diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index d155875cfe3..6753898dbdc 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -95,15 +95,10 @@ bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const { return false; } -namespace mlir { -namespace quantizer { - -raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats) { +raw_ostream &mlir::quantizer::operator<<(raw_ostream &os, + const TensorAxisStatistics &stats) { os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean << ", variance=" << stats.variance << "]"; return os; } - -} // end namespace quantizer -} // end namespace mlir diff --git a/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/mlir/lib/Quantizer/Support/UniformSolvers.cpp index bd2fe686ee1..77d69be8382 100644 --- a/mlir/lib/Quantizer/Support/UniformSolvers.cpp +++ b/mlir/lib/Quantizer/Support/UniformSolvers.cpp @@ -127,16 +127,15 @@ double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const { return (xq - zp) * delta; } -namespace mlir { -namespace quantizer { - -raw_ostream &operator<<(raw_ostream &os, const UniformStorageParams &p) { +raw_ostream &mlir::quantizer::operator<<(raw_ostream &os, + const UniformStorageParams &p) { os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}"; return os; } -raw_ostream &operator<<(raw_ostream &os, - const UniformParamsFromMinMaxSolver &s) { +raw_ostream & +mlir::quantizer::operator<<(raw_ostream &os, + const UniformParamsFromMinMaxSolver &s) { os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){"; os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> "; if (!s.isSatisfied()) { @@ -151,6 +150,3 @@ raw_ostream &operator<<(raw_ostream &os, return os; } - -} // end namespace quantizer -} // end namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 83c486979d6..8baed9854f1 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -36,7 +36,6 @@ using namespace mlir; -namespace { static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, llvm::Intrinsic::ID intrinsic, ArrayRef args = {}) { @@ -56,6 +55,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; } +namespace { class ModuleTranslation : public LLVM::ModuleTranslation { public: diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index c06e1cadbc4..34786fb1868 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -39,7 +39,6 @@ using namespace mlir; -namespace { // Create a call to llvm intrinsic static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, llvm::Intrinsic::ID intrinsic, @@ -67,6 +66,7 @@ static llvm::Value 
*createDeviceFunctionCall(llvm::IRBuilder<> &builder, return builder.CreateCall(fn, ArrayRef(fn_op0)); } +namespace { class ModuleTranslation : public LLVM::ModuleTranslation { public: diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 086c3a831fc..6206a88e870 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -36,13 +36,13 @@ #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Cloning.h" -namespace mlir { -namespace LLVM { +using namespace mlir; +using namespace mlir::LLVM; -// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. -// This currently supports integer, floating point, splat and dense element -// attributes and combinations thereof. In case of error, report it to `loc` -// and return nullptr. +/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. +/// This currently supports integer, floating point, splat and dense element +/// attributes and combinations thereof. In case of error, report it to `loc` +/// and return nullptr. llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc) { @@ -94,7 +94,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, return nullptr; } -// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. +/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { switch (p) { case LLVM::ICmpPredicate::eq: @@ -159,10 +159,10 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { llvm_unreachable("incorrect comparison predicate"); } -// Given a single MLIR operation, create the corresponding LLVM IR operation -// using the `builder`. LLVM IR Builder does not have a generic interface so -// this has to be a long chain of `if`s calling different functions with a -// different number of arguments. +/// Given a single MLIR operation, create the corresponding LLVM IR operation +/// using the `builder`. LLVM IR Builder does not have a generic interface so +/// this has to be a long chain of `if`s calling different functions with a +/// different number of arguments. LogicalResult ModuleTranslation::convertOperation(Operation &opInst, llvm::IRBuilder<> &builder) { auto extractPosition = [](ArrayAttr attr) { @@ -232,9 +232,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, << opInst.getName(); } -// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes -// to define values corresponding to the MLIR block arguments. These nodes -// are not connected to the source basic blocks, which may not exist yet. +/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes +/// to define values corresponding to the MLIR block arguments. These nodes +/// are not connected to the source basic blocks, which may not exist yet. LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { llvm::IRBuilder<> builder(blockMapping[&bb]); @@ -268,7 +268,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { return success(); } -// Convert the LLVM dialect linkage type to LLVM IR linkage type. +/// Convert the LLVM dialect linkage type to LLVM IR linkage type. 
llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { switch (linkage) { case LLVM::Linkage::Private: @@ -297,8 +297,8 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { llvm_unreachable("unknown linkage type"); } -// Create named global variables that correspond to llvm.mlir.global -// definitions. +/// Create named global variables that correspond to llvm.mlir.global +/// definitions. void ModuleTranslation::convertGlobals() { for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = op.getType().getUnderlyingType(); @@ -340,8 +340,8 @@ void ModuleTranslation::convertGlobals() { } } -// Get the SSA value passed to the current block from the terminator operation -// of its predecessor. +/// Get the SSA value passed to the current block from the terminator operation +/// of its predecessor. static Value *getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); @@ -394,7 +394,7 @@ static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { } } -// Sort function blocks topologically. +/// Sort function blocks topologically. static llvm::SetVector topologicalSort(LLVMFuncOp f) { // For each blocks that has not been visited yet (i.e. that has no // predecessors), add it to the list and traverse its successors in DFS @@ -513,6 +513,3 @@ ModuleTranslation::prepareLLVMModule(Operation *m) { return llvmModule; } - -} // namespace LLVM -} // namespace mlir diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6627e73056a..5694c990b9b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -118,6 +118,14 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace, maximalFusion); } +// TODO(b/117228571) Replace when this is modeled through side-effects/op traits +static bool isMemRefDereferencingOp(Operation &op) { + if (isa(op) || isa(op) || + isa(op) || isa(op)) + return true; + return false; +} + namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -142,14 +150,6 @@ struct LoopNestStateCollector { } }; -// TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || - isa(op) || isa(op)) - return true; - return false; -} - // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level operations in a FuncOp which contain load/store ops, and edges // are memref dependences between the nodes. @@ -674,6 +674,8 @@ public: void dump() const { print(llvm::errs()); } }; +} // end anonymous namespace + // Initializes the data dependence graph by walking operations in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the @@ -872,7 +874,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { } // TODO(mlir-team): improve/complete this when we have target data. 
-unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { +static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; @@ -1373,6 +1375,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, return true; } +namespace { + // GreedyFusion greedily fuses loop nests which have a producer/consumer or // input-reuse relationship on a memref, with the goal of improving locality. // -- cgit v1.2.3 From 35807bc4c5c9d8abc31ba0b2f955a82abf276e12 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 22 Dec 2019 21:59:55 -0800 Subject: NFC: Introduce new ValuePtr/ValueRef typedefs to simplify the transition to Value being value-typed. This is an initial step to refactoring the representation of OpResult as proposed in: https://groups.google.com/a/tensorflow.org/g/mlir/c/XXzzKhqqF_0/m/v6bKb08WCgAJ This change will make it much simpler to incrementally transition all of the existing code to use value-typed semantics. PiperOrigin-RevId: 286844725 --- mlir/bindings/python/pybind.cpp | 2 +- mlir/examples/toy/Ch2/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 9 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 41 +-- mlir/examples/toy/Ch3/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 9 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 41 +-- mlir/examples/toy/Ch3/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch4/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 41 +-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch5/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 41 +-- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch6/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 24 +- mlir/examples/toy/Ch6/mlir/MLIRGen.cpp | 41 +-- mlir/examples/toy/Ch6/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch7/include/toy/Ops.td | 10 +- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 15 +- mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 24 +- mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 40 +-- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | 2 +- mlir/g3doc/DeclarativeRewrites.md | 6 +- mlir/g3doc/DialectConversion.md | 6 +- mlir/g3doc/EDSC.md | 8 +- mlir/g3doc/GenericDAGRewriter.md | 2 +- mlir/g3doc/OpDefinitions.md | 14 +- mlir/g3doc/QuickstartRewrites.md | 4 +- mlir/g3doc/Rationale.md | 2 +- mlir/g3doc/Tutorials/Toy/Ch-3.md | 2 +- mlir/g3doc/Tutorials/Toy/Ch-4.md | 4 +- mlir/g3doc/Tutorials/Toy/Ch-5.md | 10 +- mlir/g3doc/UsageOfConst.md | 8 +- mlir/include/mlir/Analysis/AffineAnalysis.h | 9 +- mlir/include/mlir/Analysis/AffineStructures.h | 72 ++--- mlir/include/mlir/Analysis/CallInterfaces.h | 4 +- mlir/include/mlir/Analysis/Dominance.h | 4 +- mlir/include/mlir/Analysis/Liveness.h | 17 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 9 +- mlir/include/mlir/Analysis/Utils.h | 10 +- .../Conversion/AffineToStandard/AffineToStandard.h | 13 +- .../mlir/Conversion/LoopsToGPU/LoopsToGPU.h | 7 +- .../StandardToLLVM/ConvertStandardToLLVM.h | 57 ++-- mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 105 +++---- mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 8 +- mlir/include/mlir/Dialect/GPU/GPUDialect.h | 6 
+- mlir/include/mlir/Dialect/GPU/GPUOps.td | 16 +- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 6 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 22 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.h | 16 +- mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 20 +- .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 10 +- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 16 +- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 10 +- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 8 +- .../Linalg/Transforms/LinalgTransformPatterns.td | 2 +- .../Dialect/Linalg/Transforms/LinalgTransforms.h | 4 +- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 36 +-- mlir/include/mlir/Dialect/LoopOps/LoopOps.h | 2 +- mlir/include/mlir/Dialect/LoopOps/LoopOps.td | 12 +- .../mlir/Dialect/SPIRV/SPIRVCompositeOps.td | 2 +- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 2 +- mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td | 4 +- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h | 4 +- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 6 +- mlir/include/mlir/Dialect/StandardOps/Ops.h | 35 +-- mlir/include/mlir/Dialect/StandardOps/Ops.td | 78 +++--- mlir/include/mlir/Dialect/VectorOps/Utils.h | 5 +- mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 22 +- .../mlir/Dialect/VectorOps/VectorTransforms.h | 5 +- mlir/include/mlir/EDSC/Builders.h | 32 +-- mlir/include/mlir/EDSC/Helpers.h | 10 +- mlir/include/mlir/EDSC/Intrinsics.h | 26 +- mlir/include/mlir/IR/Block.h | 8 +- mlir/include/mlir/IR/BlockAndValueMapping.h | 8 +- mlir/include/mlir/IR/Builders.h | 10 +- mlir/include/mlir/IR/FunctionSupport.h | 2 +- mlir/include/mlir/IR/Matchers.h | 14 +- mlir/include/mlir/IR/OpDefinition.h | 40 +-- mlir/include/mlir/IR/OpImplementation.h | 30 +- mlir/include/mlir/IR/Operation.h | 22 +- mlir/include/mlir/IR/OperationSupport.h | 45 +-- mlir/include/mlir/IR/TypeUtilities.h | 12 +- mlir/include/mlir/IR/Value.h | 22 +- .../Quantizer/Support/ConstraintAnalysisGraph.h | 10 +- .../include/mlir/Target/LLVMIR/ModuleTranslation.h | 2 +- mlir/include/mlir/Transforms/DialectConversion.h | 46 ++-- mlir/include/mlir/Transforms/FoldUtils.h | 10 +- mlir/include/mlir/Transforms/InliningUtils.h | 14 +- mlir/include/mlir/Transforms/LoopLikeInterface.td | 2 +- mlir/include/mlir/Transforms/LoopUtils.h | 12 +- mlir/include/mlir/Transforms/RegionUtils.h | 8 +- mlir/include/mlir/Transforms/Utils.h | 20 +- mlir/lib/Analysis/AffineAnalysis.cpp | 60 ++-- mlir/lib/Analysis/AffineStructures.cpp | 94 +++---- mlir/lib/Analysis/CallGraph.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/Liveness.cpp | 34 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 30 +- mlir/lib/Analysis/SliceAnalysis.cpp | 4 +- mlir/lib/Analysis/Utils.cpp | 42 +-- mlir/lib/Analysis/VectorAnalysis.cpp | 4 +- mlir/lib/Analysis/Verifier.cpp | 6 +- .../AffineToStandard/AffineToStandard.cpp | 139 +++++----- .../GPUCommon/IndexIntrinsicsOpLowering.h | 4 +- .../Conversion/GPUCommon/OpToFuncCallLowering.h | 6 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 46 ++-- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 217 +++++++-------- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 30 +- mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 68 ++--- .../LoopToStandard/ConvertLoopToStandard.cpp | 18 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 111 ++++---- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 301 +++++++++++---------- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 37 +-- 
.../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 4 +- .../StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 8 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 108 ++++---- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 147 +++++----- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 64 ++--- .../FxpMathOps/Transforms/UniformKernelUtils.h | 6 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 47 ++-- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 12 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 26 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 20 +- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 18 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 32 +-- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 44 +-- .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 6 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 35 +-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 53 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 24 +- mlir/lib/Dialect/LoopOps/LoopOps.cpp | 12 +- mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 2 +- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 8 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 39 +-- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 38 +-- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 18 +- .../SPIRV/Transforms/LowerABIAttributesPass.cpp | 6 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 66 ++--- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 30 +- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 76 +++--- mlir/lib/EDSC/Builders.cpp | 23 +- mlir/lib/EDSC/Helpers.cpp | 6 +- mlir/lib/EDSC/Intrinsics.cpp | 12 +- mlir/lib/IR/AsmPrinter.cpp | 50 ++-- mlir/lib/IR/Block.cpp | 4 +- mlir/lib/IR/Builders.cpp | 4 +- mlir/lib/IR/Operation.cpp | 26 +- mlir/lib/IR/OperationSupport.cpp | 13 +- mlir/lib/IR/Region.cpp | 6 +- mlir/lib/IR/TypeUtilities.cpp | 12 +- mlir/lib/IR/Value.cpp | 4 +- mlir/lib/Parser/Parser.cpp | 65 ++--- mlir/lib/Pass/IRPrinting.cpp | 4 +- .../Quantizer/Support/ConstraintAnalysisGraph.cpp | 2 +- .../Transforms/AddDefaultStatsTestPass.cpp | 2 +- .../Transforms/InferQuantizedTypesPass.cpp | 14 +- mlir/lib/TableGen/Pattern.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 38 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 10 +- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 2 +- .../Transforms/AffineLoopInvariantCodeMotion.cpp | 21 +- mlir/lib/Transforms/DialectConversion.cpp | 58 ++-- mlir/lib/Transforms/LoopFusion.cpp | 93 +++---- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 4 +- mlir/lib/Transforms/LoopTiling.cpp | 11 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 6 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 14 +- mlir/lib/Transforms/Utils/FoldUtils.cpp | 8 +- .../Utils/GreedyPatternRewriteDriver.cpp | 8 +- mlir/lib/Transforms/Utils/InliningUtils.cpp | 36 +-- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 16 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 169 ++++++------ mlir/lib/Transforms/Utils/RegionUtils.cpp | 24 +- mlir/lib/Transforms/Utils/Utils.cpp | 57 ++-- mlir/lib/Transforms/Vectorize.cpp | 40 +-- mlir/test/EDSC/builder-api-test.cpp | 4 +- mlir/test/lib/TestDialect/TestDialect.cpp | 8 +- mlir/test/lib/TestDialect/TestOps.td | 2 +- mlir/test/lib/TestDialect/TestPatterns.cpp | 33 +-- mlir/test/lib/Transforms/TestLoopMapping.cpp | 2 +- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- mlir/test/mlir-tblgen/op-attribute.td | 6 +- mlir/test/mlir-tblgen/op-decl.td | 24 +- mlir/test/mlir-tblgen/op-operand.td | 10 +- 
mlir/test/mlir-tblgen/op-result.td | 6 +- mlir/test/mlir-tblgen/predicate.td | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 27 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 20 +- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 2 +- mlir/unittests/IR/OperationSupportTest.cpp | 8 +- 201 files changed, 2493 insertions(+), 2413 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 825f800c0bd..54646cbe800 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -103,7 +103,7 @@ struct PythonValueHandle { assert(value.hasType() && value.getType().isa() && "can only call function-typed values"); - std::vector argValues; + std::vector argValues; argValues.reserve(args.size()); for (auto arg : args) argValues.push_back(arg.value.getValue()); diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index f7c011915ff..dd88b097ab1 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -98,7 +98,7 @@ def AddOp : Toy_Op<"add"> { // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -129,7 +129,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -145,7 +145,7 @@ def MulOp : Toy_Op<"mul"> { // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -219,7 +219,7 @@ def TransposeOp : Toy_Op<"transpose"> { // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index 86f648dbe0e..4a3232dabe3 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -94,7 +94,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -103,7 +103,8 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -114,7 +115,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -161,7 +162,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index da474e809b3..902c634a954 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -99,7 +99,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -109,7 +109,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -132,7 +132,8 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope( + symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -183,7 +184,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -195,10 +196,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -219,8 +220,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. 
- mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName())) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -233,7 +234,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -241,7 +242,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -263,7 +264,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -309,14 +310,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -342,7 +343,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -351,12 +352,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -380,7 +381,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -388,7 +389,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; @@ -408,7 +409,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. 
These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index 921e503e416..6c400169da2 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -98,7 +98,7 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> { // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -129,7 +129,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -145,7 +145,7 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> { // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -225,7 +225,7 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 86f648dbe0e..4a3232dabe3 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -94,7 +94,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -103,7 +103,8 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -114,7 +115,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -161,7 +162,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index da474e809b3..902c634a954 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -99,7 +99,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. 
When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -109,7 +109,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -132,7 +132,8 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope( + symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -183,7 +184,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -195,10 +196,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -219,8 +220,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName())) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -233,7 +234,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -241,7 +242,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -263,7 +264,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -309,14 +310,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. 
- mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -342,7 +343,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -351,12 +352,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -380,7 +381,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -388,7 +389,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; @@ -408,7 +409,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp index 1b9dcd20291..42a10397513 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -48,7 +48,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index aec1cc3cfc9..ef5b30a862b 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -100,7 +100,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -151,7 +151,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. 
let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -168,7 +168,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -245,7 +245,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 7003cbdcc81..8be1094cf15 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -55,7 +55,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -70,7 +70,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -144,7 +144,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -164,7 +164,8 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -185,7 +186,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -236,7 +237,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index da474e809b3..902c634a954 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -99,7 +99,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -109,7 +109,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -132,7 +132,8 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope( + symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -183,7 +184,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -195,10 +196,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -219,8 +220,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. 
- mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName())) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -233,7 +234,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -241,7 +242,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -263,7 +264,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -309,14 +310,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -342,7 +343,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -351,12 +352,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -380,7 +381,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -388,7 +389,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; @@ -408,7 +409,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. 
These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 47e1abc6c74..604e9fa6c83 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index e40b661fd34..b3bda1d647b 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -100,7 +100,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -151,7 +151,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -168,7 +168,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -246,7 +246,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 7003cbdcc81..8be1094cf15 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -55,7 +55,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -70,7 +70,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. 
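The MLIRGen.cpp hunks above switch the symbol table's mapped type to the new spelling. A condensed sketch with the element types restored, assuming `ValuePtr` aliases `mlir::Value *`:

```c++
// Symbol table mapping variable names to the SSA value that defines them.
llvm::ScopedHashTable<llvm::StringRef, mlir::ValuePtr> symbolTable;

// Declare a variable in the current scope; fails if it already exists.
mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) {
  if (symbolTable.count(var))
    return mlir::failure();
  symbolTable.insert(var, value);
  return mlir::success();
}
```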
- Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -144,7 +144,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -164,7 +164,8 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -185,7 +186,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -236,7 +237,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 4ab8c5b501c..3fa761c7404 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -43,8 +43,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static Value *insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -63,11 +63,11 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -78,7 +78,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. 
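The LowerToAffineLoops.cpp hunk above renames the value type used by the lowering callback. A sketch of that callback with the dropped template arguments restored by inference:

```c++
// The callback receives the remapped memref operands plus the loop induction
// variables and returns the value to store at the current index.
using LoopIterationFn = function_ref<ValuePtr(
    PatternRewriter &rewriter, ArrayRef<ValuePtr> memRefOperands,
    ArrayRef<ValuePtr> loopIvs)>;

static void lowerOpToLoops(Operation *op, ArrayRef<ValuePtr> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration);
```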
- SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -94,7 +94,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. - Value *valueToStore = processIteration(rewriter, operands, loopIvs); + ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -113,13 +113,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -172,7 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -231,22 +231,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - Value *input = transposeAdaptor.input(); + ValuePtr input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. 
- SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index da474e809b3..902c634a954 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -99,7 +99,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -109,7 +109,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -132,7 +132,8 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope( + symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -183,7 +184,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -195,10 +196,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -219,8 +220,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName())) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -233,7 +234,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -241,7 +242,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? 
makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -263,7 +264,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -309,14 +310,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -342,7 +343,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -351,12 +352,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -380,7 +381,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -388,7 +389,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; @@ -408,7 +409,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 47e1abc6c74..604e9fa6c83 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. 
- mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index e40b661fd34..b3bda1d647b 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -100,7 +100,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -151,7 +151,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -168,7 +168,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -246,7 +246,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 7003cbdcc81..8be1094cf15 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -55,7 +55,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -70,7 +70,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -144,7 +144,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -164,7 +164,8 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. 
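The ToyCombine.cpp hunk above only touches the local variable's type, but it is easier to read in context. A sketch of the whole pattern as it looks after the rename, assuming the `PatternMatchResult`-era rewrite API used throughout this patch:

```c++
/// Fold transpose(transpose(x)) into x.
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  mlir::PatternMatchResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::ValuePtr transposeInput = op.getOperand();
    TransposeOp transposeInputOp =
        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
    if (!transposeInputOp)
      return matchFailure();
    // The input is itself a transpose: forward its operand.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return matchSuccess();
  }
};
```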
state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -185,7 +186,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -236,7 +237,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index 4ab8c5b501c..3fa761c7404 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -43,8 +43,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static Value *insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -63,11 +63,11 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -78,7 +78,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -94,7 +94,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. 
- Value *valueToStore = processIteration(rewriter, operands, loopIvs); + ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -113,13 +113,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -172,7 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -231,22 +231,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - Value *input = transposeAdaptor.input(); + ValuePtr input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. 
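To make the `lowerOpToLoops` hunks above easier to follow, here is a condensed sketch of the loop-nest emission with the element types restored; the terminator bookkeeping is elided, and `AffineForOp`/`AffineStoreOp` are the affine-dialect ops of that vintage:

```c++
// Build one affine loop per tensor dimension, collecting induction variables.
SmallVector<ValuePtr, 4> loopIvs;
for (auto dim : tensorType.getShape()) {
  auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
  loop.getBody()->clear();
  loopIvs.push_back(loop.getInductionVar());
  // (terminator setup elided) keep emitting inside this loop's body.
  rewriter.setInsertionPointToStart(loop.getBody());
}
// The callback returns the value to store at the current loop indices.
ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs);
rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
                               llvm::makeArrayRef(loopIvs));
```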
- SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index d35cc5c576a..c3180b4a92d 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -51,7 +51,7 @@ public: : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); @@ -64,14 +64,14 @@ public: // Get a symbol reference to the printf function, inserting it if necessary. auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); - Value *formatSpecifierCst = getOrCreateGlobalString( + ValuePtr formatSpecifierCst = getOrCreateGlobalString( loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, llvmDialect); - Value *newLineCst = getOrCreateGlobalString( + ValuePtr newLineCst = getOrCreateGlobalString( loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); // Create a loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { auto lowerBound = rewriter.create(loc, 0); auto upperBound = rewriter.create(loc, memRefShape[i]); @@ -97,7 +97,7 @@ public: auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); rewriter.create( loc, printfRef, rewriter.getIntegerType(32), - ArrayRef({formatSpecifierCst, elementLoad})); + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -130,10 +130,10 @@ private: /// Return a value representing an access into a global string with the given /// name, creating the string if necessary. - static Value *getOrCreateGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static ValuePtr getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { @@ -147,13 +147,13 @@ private: } // Get the pointer to the first character in the global string. - Value *globalPtr = builder.create(loc, global); - Value *cst0 = builder.create( + ValuePtr globalPtr = builder.create(loc, global); + ValuePtr cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index da474e809b3..902c634a954 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -99,7 +99,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. 
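The LowerToLLVM.cpp hunk above renames the return and local types of the global-string helper. A sketch of its signature and the pointer arithmetic with the angle-bracket arguments restored; the creation of the global itself is elided, and the LLVM-dialect calls reflect the API of that era:

```c++
static ValuePtr getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                        StringRef name, StringRef value,
                                        ModuleOp module,
                                        LLVM::LLVMDialect *llvmDialect) {
  LLVM::GlobalOp global = module.lookupSymbol<LLVM::GlobalOp>(name);
  if (!global) {
    // Create the global at the entry of the module (elided).
  }
  // Get the pointer to the first character in the global string.
  ValuePtr globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
  ValuePtr cst0 = builder.create<LLVM::ConstantOp>(
      loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
      builder.getIntegerAttr(builder.getIndexType(), 0));
  return builder.create<LLVM::GEPOp>(
      loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
      ArrayRef<ValuePtr>({cst0, cst0}));
}
```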
- llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -109,7 +109,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -132,7 +132,8 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope( + symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -183,7 +184,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -195,10 +196,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -219,8 +220,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName())) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -233,7 +234,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -241,7 +242,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -263,7 +264,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -309,14 +310,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. 
- SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -342,7 +343,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -351,12 +352,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -380,7 +381,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -388,7 +389,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; @@ -408,7 +409,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp index 47e1abc6c74..604e9fa6c83 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 0d48f74e9fe..94f1bcf3e82 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -112,7 +112,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -164,7 +164,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -181,7 +181,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. 
let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> ]; } @@ -260,7 +260,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> { // Allow building a StructAccessOp with just a struct value and an index. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input, size_t index"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input, size_t index"> ]; let verifier = [{ return ::verify(*this); }]; @@ -299,7 +299,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, Value *input"> + OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 2beaa870a89..0ce896db5de 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -56,7 +56,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -71,7 +71,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -195,7 +195,7 @@ void ConstantOp::inferShapes() { getResult()->setType(value().getType()); } // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -215,7 +215,8 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { + StringRef callee, + ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -236,7 +237,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { + mlir::ValuePtr lhs, mlir::ValuePtr rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -287,7 +288,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // StructAccessOp void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state, - mlir::Value *input, size_t index) { + mlir::ValuePtr input, size_t index) { // Extract the result type from the input type. 
StructType structTy = input->getType().cast(); assert(index < structTy.getNumElementTypes()); @@ -314,7 +315,7 @@ static mlir::LogicalResult verify(StructAccessOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { + mlir::ValuePtr value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 4ab8c5b501c..3fa761c7404 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -43,8 +43,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static Value *insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -63,11 +63,11 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -78,7 +78,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -94,7 +94,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. - Value *valueToStore = processIteration(rewriter, operands, loopIvs); + ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -113,13 +113,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. 
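The Ch7 hunks above change the `StructAccessOp` builder signature. A sketch of the full builder as it reads after the rename, assuming `StructType` is the Toy dialect type used elsewhere in this chapter:

```c++
void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
                           mlir::ValuePtr input, size_t index) {
  // Extract the result type from the accessed struct element.
  StructType structTy = input->getType().cast<StructType>();
  assert(index < structTy.getNumElementTypes());
  mlir::Type resultType = structTy.getElementTypes()[index];

  // Call into the auto-generated build method.
  build(b, state, resultType, input, b->getI64IntegerAttr(index));
}
```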
auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -172,7 +172,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -231,22 +231,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - Value *input = transposeAdaptor.input(); + ValuePtr input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index d35cc5c576a..c3180b4a92d 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -51,7 +51,7 @@ public: : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); @@ -64,14 +64,14 @@ public: // Get a symbol reference to the printf function, inserting it if necessary. auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); - Value *formatSpecifierCst = getOrCreateGlobalString( + ValuePtr formatSpecifierCst = getOrCreateGlobalString( loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, llvmDialect); - Value *newLineCst = getOrCreateGlobalString( + ValuePtr newLineCst = getOrCreateGlobalString( loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); // Create a loop for each of the dimensions within the shape. 
- SmallVector loopIvs; + SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { auto lowerBound = rewriter.create(loc, 0); auto upperBound = rewriter.create(loc, memRefShape[i]); @@ -97,7 +97,7 @@ public: auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); rewriter.create( loc, printfRef, rewriter.getIntegerType(32), - ArrayRef({formatSpecifierCst, elementLoad})); + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -130,10 +130,10 @@ private: /// Return a value representing an access into a global string with the given /// name, creating the string if necessary. - static Value *getOrCreateGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static ValuePtr getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { @@ -147,13 +147,13 @@ private: } // Get the pointer to the first character in the global string. - Value *globalPtr = builder.create(loc, global); - Value *cst0 = builder.create( + ValuePtr globalPtr = builder.create(loc, global); + ValuePtr cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index b33137a1066..590b21e53a1 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -108,11 +108,11 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable> + llvm::ScopedHashTable> symbolTable; using SymbolTableScopeT = llvm::ScopedHashTableScope>; + std::pair>; /// A mapping for the functions that have been code generated to MLIR. llvm::StringMap functionMap; @@ -129,7 +129,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value *value) { + mlir::LogicalResult declare(VarDeclExprAST &var, mlir::ValuePtr value) { if (symbolTable.count(var.getName())) return mlir::failure(); symbolTable.insert(var.getName(), {value, &var}); @@ -301,7 +301,7 @@ private: } /// Emit a binary operation - mlir::Value *mlirGen(BinaryExprAST &binop) { + mlir::ValuePtr mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -313,7 +313,7 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *lhs = mlirGen(*binop.getLHS()); + mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; auto location = loc(binop.loc()); @@ -329,7 +329,7 @@ private: } // Otherwise, this is a normal binary op. 
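In Ch7 the symbol table maps to a pair so that struct declarations stay reachable. A sketch of the declarations from the MLIRGen.cpp hunk above with the template arguments restored, again assuming `ValuePtr` aliases `mlir::Value *`:

```c++
// Each name maps to its SSA value plus the AST node that declared it.
llvm::ScopedHashTable<llvm::StringRef,
                      std::pair<mlir::ValuePtr, VarDeclExprAST *>>
    symbolTable;
using SymbolTableScopeT = llvm::ScopedHashTableScope<
    llvm::StringRef, std::pair<mlir::ValuePtr, VarDeclExprAST *>>;

mlir::LogicalResult declare(VarDeclExprAST &var, mlir::ValuePtr value) {
  if (symbolTable.count(var.getName()))
    return mlir::failure();
  symbolTable.insert(var.getName(), {value, &var});
  return mlir::success();
}
```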
- mlir::Value *rhs = mlirGen(*binop.getRHS()); + mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; @@ -349,8 +349,8 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::Value *mlirGen(VariableExprAST &expr) { - if (auto *variable = symbolTable.lookup(expr.getName()).first) + mlir::ValuePtr mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName()).first) return variable; emitError(loc(expr.loc()), "error: unknown variable '") @@ -363,7 +363,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::Value *expr = nullptr; + mlir::ValuePtr expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -371,7 +371,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -450,7 +450,7 @@ private: } /// Emit an array literal. - mlir::Value *mlirGen(LiteralExprAST &lit) { + mlir::ValuePtr mlirGen(LiteralExprAST &lit) { mlir::Type type = getType(lit.getDims()); mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); @@ -462,7 +462,7 @@ private: /// Emit a struct literal. It will be emitted as an array of /// other literals in an Attribute attached to a `toy.struct_constant` /// operation. - mlir::Value *mlirGen(StructLiteralExprAST &lit) { + mlir::ValuePtr mlirGen(StructLiteralExprAST &lit) { mlir::ArrayAttr dataAttr; mlir::Type dataType; std::tie(dataAttr, dataType) = getConstantAttr(lit); @@ -493,14 +493,14 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value *mlirGen(CallExprAST &call) { + mlir::ValuePtr mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); + auto arg = mlirGen(*expr); if (!arg) return nullptr; operands.push_back(arg); @@ -534,7 +534,7 @@ private: /// Emit a print expression. It emits specific operations for two builtins: /// transpose(x) and print(x). mlir::LogicalResult mlirGen(PrintExprAST &call) { - auto *arg = mlirGen(*call.getArg()); + auto arg = mlirGen(*call.getArg()); if (!arg) return mlir::failure(); @@ -543,12 +543,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value *mlirGen(NumberExprAST &num) { + mlir::ValuePtr mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value *mlirGen(ExprAST &expr) { + mlir::ValuePtr mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -574,7 +574,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. 
- mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -582,7 +582,7 @@ private: return nullptr; } - mlir::Value *value = mlirGen(*init); + mlir::ValuePtr value = mlirGen(*init); if (!value) return nullptr; diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp index ebd4f5d1103..d18396c63bb 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -71,7 +71,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/g3doc/DeclarativeRewrites.md b/mlir/g3doc/DeclarativeRewrites.md index 5adcb320983..9fcd4341611 100644 --- a/mlir/g3doc/DeclarativeRewrites.md +++ b/mlir/g3doc/DeclarativeRewrites.md @@ -233,7 +233,7 @@ In the above, we are using `BOp`'s result for building `COp`. Given that `COp` was specified with table-driven op definition, there will be several `build()` methods generated for it. One of them has aggregated parameters for result types, operands, and attributes in the signature: `void -COp::build(..., ArrayRef resultTypes, Array operands, +COp::build(..., ArrayRef resultTypes, Array operands, ArrayRef attr)`. The pattern in the above calls this `build()` method for constructing the `COp`. @@ -266,7 +266,7 @@ For example, for the above `AOp`, a possible builder is: ```c++ void AOp::build(Builder *builder, OperationState &state, - Value *input, Attribute attr) { + ValuePtr input, Attribute attr) { state.addOperands({input}); state.addAttribute("a_attr", attr); Type type = ...; // Deduce result type here @@ -422,7 +422,7 @@ op; it can be also used to specify how to build an op entirely. An example: If we have a C++ function for building an op: ```c++ -Operation *createMyOp(OpBuilder builder, Value *input, Attribute attr); +Operation *createMyOp(OpBuilder builder, ValuePtr input, Attribute attr); ``` We can wrap it up and invoke it like: diff --git a/mlir/g3doc/DialectConversion.md b/mlir/g3doc/DialectConversion.md index b4e309daf1f..6771860366c 100644 --- a/mlir/g3doc/DialectConversion.md +++ b/mlir/g3doc/DialectConversion.md @@ -209,7 +209,7 @@ class TypeConverter { /// the conversion has finished. virtual Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc); }; ``` @@ -232,7 +232,7 @@ struct MyConversionPattern : public ConversionPattern { /// `operands` parameter, containing the remapped operands of the original /// operation. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const; }; ``` @@ -269,7 +269,7 @@ public: /// Remap an input of the original signature to another `replacement` /// value. This drops the original argument. 
- void remapInput(unsigned origInputNo, Value *replacement); + void remapInput(unsigned origInputNo, ValuePtr replacement); }; ``` diff --git a/mlir/g3doc/EDSC.md b/mlir/g3doc/EDSC.md index afceac2dfc1..eaaeb6c7009 100644 --- a/mlir/g3doc/EDSC.md +++ b/mlir/g3doc/EDSC.md @@ -15,10 +15,10 @@ declarative builders are available within the lifetime of a `ScopedContext`. ## ValueHandle and IndexHandle `mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed -abstractions around an `mlir::Value*`. These abstractions are "delayed", in the -sense that they allow separating declaration from definition. They may -capture IR snippets, as they are built, for programmatic manipulation. -Intuitive operators are provided to allow concise and idiomatic expressions. +abstractions around an `mlir::Value`. These abstractions are "delayed", in the +sense that they allow separating declaration from definition. They may capture +IR snippets, as they are built, for programmatic manipulation. Intuitive +operators are provided to allow concise and idiomatic expressions. ```c++ ValueHandle zero = constant_index(0); diff --git a/mlir/g3doc/GenericDAGRewriter.md b/mlir/g3doc/GenericDAGRewriter.md index 3b26c22eb37..64b8f4f7ade 100644 --- a/mlir/g3doc/GenericDAGRewriter.md +++ b/mlir/g3doc/GenericDAGRewriter.md @@ -128,7 +128,7 @@ complicated :) if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) if (C1->countTrailingZeros() == 0) if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { - Value *NewOr = Builder.CreateOr(Z, ~(*C2)); + ValuePtr NewOr = Builder.CreateOr(Z, ~(*C2)); return Builder.CreateSub(RHS, NewOr, "sub"); } ``` diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index 1f98671d59a..1db18266ee0 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -360,7 +360,7 @@ def MyInterface : OpInterface<"MyInterface"> { // A new non-static method accepting an input argument. InterfaceMethod<"/*insert doc here*/", - "Value *", "bar", (ins "unsigned":$i) + "ValuePtr ", "bar", (ins "unsigned":$i) >, // Query a static property of the derived operation. @@ -438,7 +438,7 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, // for attributes are of mlir::Attribute types. static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type i32_result, Type f32_result, ..., - Value *i32_operand, Value *f32_operand, ..., + ValuePtr i32_operand, ValuePtr f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); // Each result-type/operand/attribute has a separate parameter. The parameters @@ -447,13 +447,13 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, // explanation for more details.) static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type i32_result, Type f32_result, ..., - Value *i32_operand, Value *f32_operand, ..., + ValuePtr i32_operand, ValuePtr f32_operand, ..., APInt i32_attr, StringRef f32_attr, ...); // Each operand/attribute has a separate parameter but result type is aggregate. static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef resultTypes, - Value *i32_operand, Value *f32_operand, ..., + ValuePtr i32_operand, ValuePtr f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); // All operands/attributes have aggregate parameters. @@ -615,7 +615,7 @@ coding style requirements. For each operation, we automatically generate an _operand adaptor_. 
This class solves the problem of accessing operands provided as a list of `Value`s without using "magic" constants. The operand adaptor takes a reference to an array of -`Value *` and provides methods with the same names as those in the operation +`ValuePtr` and provides methods with the same names as those in the operation class to access them. For example, for a binary arithmetic operation, it may provide `.lhs()` to access the first operand and `.rhs()` to access the second operand. @@ -629,11 +629,11 @@ Operand adaptors can be used in function templates that also process operations: ```c++ template -std::pair zip(BinaryOpTy &&op) { +std::pair zip(BinaryOpTy &&op) { return std::make_pair(op.lhs(), op.rhs());; } -void process(AddOp op, ArrayRef newOperands) { +void process(AddOp op, ArrayRef newOperands) { zip(op); zip(OperandAdaptor(newOperands)); /*...*/ diff --git a/mlir/g3doc/QuickstartRewrites.md b/mlir/g3doc/QuickstartRewrites.md index d7bf9a54370..6a4a7cca8b8 100644 --- a/mlir/g3doc/QuickstartRewrites.md +++ b/mlir/g3doc/QuickstartRewrites.md @@ -128,8 +128,8 @@ def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a), ``` ```c++ -static Value* createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, - Value* operand, Attribute attr) { +static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, + Value operand, Attribute attr) { return rewriter.create( op->getLoc(), operands[0]->getType(), /*arg=*/operands[0], /*alpha=*/attrs[0].cast()); diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 66cf800621d..763442dce06 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -1099,7 +1099,7 @@ those chunks independently. The problem is that LLVM has several objects in its IR that are globally uniqued and also mutable: notably constants like `i32 0`. In LLVM, these constants are -`Value*r`'s, which allow them to be used as operands to instructions, and that +`Value`'s, which allow them to be used as operands to instructions, and that they also have SSA use lists. Because these things are uniqued, every `i32 0` in any function shares a use list. This means that optimizing multiple functions in parallel won't work (at least without some sort of synchronization on the use diff --git a/mlir/g3doc/Tutorials/Toy/Ch-3.md b/mlir/g3doc/Tutorials/Toy/Ch-3.md index 07ead64d455..fb470434d6f 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-3.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-3.md @@ -90,7 +90,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::Value *transposeInput = op.getOperand(); + mlir::ValuePtr transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! diff --git a/mlir/g3doc/Tutorials/Toy/Ch-4.md b/mlir/g3doc/Tutorials/Toy/Ch-4.md index ac124699c2f..921e5cdc52a 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-4.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-4.md @@ -75,7 +75,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// previously returned by the call operation with the operands of the /// return. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. 
auto returnOp = cast(op); @@ -207,7 +207,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md index 1124cf14a43..ed62f8954b7 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-5.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-5.md @@ -101,7 +101,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern { /// Match and rewrite the given `toy.transpose` operation, with the given /// operands that have been remapped from `tensor<...>` to `memref<...>`. mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, ArrayRef operands, + matchAndRewrite(mlir::Operation *op, ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -112,18 +112,18 @@ struct TransposeOpLowering : public mlir::ConversionPattern { lowerOpToLoops( op, operands, rewriter, [loc](mlir::PatternRewriter &rewriter, - ArrayRef memRefOperands, - ArrayRef loopIvs) { + ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. This adaptor is automatically provided by the ODS // framework. TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - mlir::Value *input = transposeAdaptor.input(); + mlir::ValuePtr input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/g3doc/UsageOfConst.md b/mlir/g3doc/UsageOfConst.md index 052f14ddf01..5f6d3793164 100644 --- a/mlir/g3doc/UsageOfConst.md +++ b/mlir/g3doc/UsageOfConst.md @@ -10,7 +10,7 @@ understood (even though the LLVM implementation is flawed in many ways). The design team since decided to change to a different module, which eschews `const` entirely for the core IR types: you should never see a `const` method on -`Operation`, should never see the type `const Value *`, and you shouldn't feel +`Operation`, should never see the type `const ValuePtr`, and you shouldn't feel bad about this. That said, you *should* use `const` for non-IR types, like `SmallVector`'s and many other things. @@ -39,7 +39,7 @@ into the MLIR codebase, argues that the cost/benefit tradeoff of this design is a poor tradeoff, and proposes switching to a much simpler approach - eliminating the use of const of these IR types entirely. -**Note:** **This document is only discussing things like `const Value*` and +**Note:** **This document is only discussing things like `const Value` and `const Operation*`. There is no proposed change for other types, e.g. `SmallVector` references, the immutable types like `Attribute`, etc.** @@ -130,7 +130,7 @@ const. operand_iterator operand_begin(); operand_iterator operand_end(); - /// Returns an iterator on the underlying Value's (Value *). + /// Returns an iterator on the underlying Value's (ValuePtr ). 
operand_range getOperands(); // Support const operand iteration. @@ -141,7 +141,7 @@ const. const_operand_iterator operand_begin() const; const_operand_iterator operand_end() const; - /// Returns a const iterator on the underlying Value's (Value *). + /// Returns a const iterator on the underlying Value's (ValuePtr ). llvm::iterator_range getOperands() const; ArrayRef getOpOperands() const { diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 8243d1f6f63..f506470f36a 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -39,10 +39,13 @@ class FlatAffineConstraints; class Operation; class Value; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp /// Operations that are reachable via a search starting from `operands` and /// ending at those operands that are not the result of an AffineApplyOp. -void getReachableAffineApplyOps(ArrayRef operands, +void getReachableAffineApplyOps(ArrayRef operands, SmallVectorImpl &affineApplyOps); /// Builds a system of constraints with dimensional identifiers corresponding to @@ -56,9 +59,9 @@ LogicalResult getIndexSet(MutableArrayRef forOps, /// Encapsulates a memref load or store access information. struct MemRefAccess { - Value *memref; + ValuePtr memref; Operation *opInst; - SmallVector indices; + SmallVector indices; /// Constructs a MemRefAccess from a load or store operation. // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index e53af5024da..65cf13a0ce6 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -123,8 +123,8 @@ public: // Creates an empty AffineValueMap (users should call 'reset' to reset map // and operands). AffineValueMap() {} - AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); + AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); explicit AffineValueMap(AffineApplyOp applyOp); explicit AffineValueMap(AffineBound bound); @@ -132,8 +132,8 @@ public: ~AffineValueMap(); // Resets this AffineValueMap with 'map', 'operands', and 'results'. - void reset(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); + void reset(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); /// Return the value map that is the difference of value maps 'a' and 'b', /// represented as an affine map and its operands. The output map + operands @@ -146,7 +146,7 @@ public: inline bool isMultipleOf(unsigned idx, int64_t factor) const; /// Return true if the idx^th result depends on 'value', false otherwise. - bool isFunctionOf(unsigned idx, Value *value) const; + bool isFunctionOf(unsigned idx, ValuePtr value) const; /// Return true if the result at 'idx' is a constant, false /// otherwise. @@ -162,8 +162,8 @@ public: inline unsigned getNumSymbols() const { return map.getNumSymbols(); } inline unsigned getNumResults() const { return map.getNumResults(); } - Value *getOperand(unsigned i) const; - ArrayRef getOperands() const; + ValuePtr getOperand(unsigned i) const; + ArrayRef getOperands() const; AffineMap getAffineMap() const; private: @@ -172,9 +172,9 @@ private: // TODO: make these trailing objects? /// The SSA operands binding to the dim's and symbols of 'map'. 
- SmallVector operands; + SmallVector operands; /// The SSA results binding to the results of 'map'. - SmallVector results; + SmallVector results; }; /// An IntegerValueSet is an integer set plus its operands. @@ -207,7 +207,7 @@ private: // 'AffineCondition'. MutableIntegerSet set; /// The SSA operands binding to the dim's and symbols of 'set'. - SmallVector operands; + SmallVector operands; }; /// A flat list of affine equalities and inequalities in the form. @@ -245,7 +245,7 @@ public: unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numReservedCols), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -264,7 +264,7 @@ public: /// dimensions and symbols. FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -304,10 +304,10 @@ public: // Clears any existing data and reserves memory for the specified constraints. void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); /// Appends constraints from 'other' into this. This is equivalent to an /// intersection with no simplification of any sort attempted. @@ -396,7 +396,7 @@ public: /// operands. If `eq` is true, add a single equality equal to the bound map's /// first result expr. LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ArrayRef operands, bool eq, + ArrayRef operands, bool eq, bool lower = true); /// Computes the lower and upper bounds of the first 'num' dimensional @@ -415,10 +415,10 @@ public: /// operand list 'operands'. /// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'. /// Note that both lower/upper bounds use operands from 'operands'. - LogicalResult addSliceBounds(ArrayRef values, + LogicalResult addSliceBounds(ArrayRef values, ArrayRef lbMaps, ArrayRef ubMaps, - ArrayRef operands); + ArrayRef operands); // Adds an inequality (>= 0) from the coefficients specified in inEq. void addInequality(ArrayRef inEq); @@ -447,25 +447,25 @@ public: /// Sets the identifier corresponding to the specified Value id to a /// constant. Asserts if the 'id' is not found. - void setIdToConstant(Value &id, int64_t val); + void setIdToConstant(ValueRef id, int64_t val); /// Looks up the position of the identifier with the specified Value. Returns /// true if found (false otherwise). `pos' is set to the (column) position of /// the identifier. - bool findId(Value &id, unsigned *pos) const; + bool findId(ValueRef id, unsigned *pos) const; /// Returns true if an identifier with the specified Value exists, false /// otherwise. - bool containsId(Value &id) const; + bool containsId(ValueRef id) const; // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. The coefficient column corresponding to the added // identifier is initialized to zero. 
'id' is the Value corresponding to the // identifier that can optionally be provided. - void addDimId(unsigned pos, Value *id = nullptr); - void addSymbolId(unsigned pos, Value *id = nullptr); + void addDimId(unsigned pos, ValuePtr id = nullptr); + void addSymbolId(unsigned pos, ValuePtr id = nullptr); void addLocalId(unsigned pos); - void addId(IdKind kind, unsigned pos, Value *id = nullptr); + void addId(IdKind kind, unsigned pos, ValuePtr id = nullptr); /// Add the specified values as a dim or symbol id depending on its nature, if /// it already doesn't exist in the system. `id' has to be either a terminal @@ -473,7 +473,7 @@ public: /// symbols or loop IVs. The identifier is added to the end of the existing /// dims or symbols. Additional information on the identifier is extracted /// from the IR and added to the constraint system. - void addInductionVarOrTerminalSymbol(Value *id); + void addInductionVarOrTerminalSymbol(ValuePtr id); /// Composes the affine value map with this FlatAffineConstrains, adding the /// results of the map as dimensions at the front [0, vMap->getNumResults()) @@ -500,8 +500,8 @@ public: void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } - /// Projects out the identifier that is associate with Value *. - void projectOut(Value *id); + /// Projects out the identifier that is associate with ValuePtr . + void projectOut(ValuePtr id); void removeId(IdKind idKind, unsigned pos); void removeId(unsigned pos); @@ -577,20 +577,20 @@ public: return numIds - numDims - numSymbols; } - inline ArrayRef> getIds() const { + inline ArrayRef> getIds() const { return {ids.data(), ids.size()}; } - inline MutableArrayRef> getIds() { + inline MutableArrayRef> getIds() { return {ids.data(), ids.size()}; } /// Returns the optional Value corresponding to the pos^th identifier. - inline Optional getId(unsigned pos) const { return ids[pos]; } - inline Optional &getId(unsigned pos) { return ids[pos]; } + inline Optional getId(unsigned pos) const { return ids[pos]; } + inline Optional &getId(unsigned pos) { return ids[pos]; } /// Returns the Value associated with the pos^th identifier. Asserts if /// no Value identifier was associated. - inline Value *getIdValue(unsigned pos) const { + inline ValuePtr getIdValue(unsigned pos) const { assert(ids[pos].hasValue() && "identifier's Value not set"); return ids[pos].getValue(); } @@ -598,7 +598,7 @@ public: /// Returns the Values associated with identifiers in range [start, end). /// Asserts if no Value was associated with one of these identifiers. void getIdValues(unsigned start, unsigned end, - SmallVectorImpl *values) const { + SmallVectorImpl *values) const { assert((start < numIds || start == end) && "invalid start position"); assert(end <= numIds && "invalid end position"); values->clear(); @@ -607,17 +607,17 @@ public: values->push_back(getIdValue(i)); } } - inline void getAllIdValues(SmallVectorImpl *values) const { + inline void getAllIdValues(SmallVectorImpl *values) const { getIdValues(0, numIds, values); } /// Sets Value associated with the pos^th identifier. - inline void setIdValue(unsigned pos, Value *val) { + inline void setIdValue(unsigned pos, ValuePtr val) { assert(pos < numIds && "invalid id position"); ids[pos] = val; } /// Sets Values associated with identifiers in the range [start, end). 
- void setIdValues(unsigned start, unsigned end, ArrayRef values) { + void setIdValues(unsigned start, unsigned end, ArrayRef values) { assert((start < numIds || end == start) && "invalid start position"); assert(end <= numIds && "invalid end position"); assert(values.size() == end - start); @@ -766,7 +766,7 @@ private: /// system appearing in the order the identifiers correspond to columns. /// Temporary ones or those that aren't associated to any Value are set to /// None. - SmallVector, 8> ids; + SmallVector, 8> ids; /// A parameter that controls detection of an unrealistic number of /// constraints. If the number of constraints is this many times the number of diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h index dd23d77889f..a18cfa7aba4 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.h +++ b/mlir/include/mlir/Analysis/CallInterfaces.h @@ -30,8 +30,8 @@ namespace mlir { /// A callable is either a symbol, or an SSA value, that is referenced by a /// call-like operation. This represents the destination of the call. -struct CallInterfaceCallable : public PointerUnion { - using PointerUnion::PointerUnion; +struct CallInterfaceCallable : public PointerUnion { + using PointerUnion::PointerUnion; }; #include "mlir/Analysis/CallInterfaces.h.inc" diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 09114eafbb1..f46241e2af0 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -74,10 +74,10 @@ public: } /// Return true if value A properly dominates operation B. - bool properlyDominates(Value *a, Operation *b); + bool properlyDominates(ValuePtr a, Operation *b); /// Return true if operation A dominates operation B. - bool dominates(Value *a, Operation *b) { + bool dominates(ValuePtr a, Operation *b) { return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h index 0bdb474fd92..0aa9d9693e4 100644 --- a/mlir/include/mlir/Analysis/Liveness.h +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -41,6 +41,9 @@ class Operation; class Region; class Value; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + /// Represents an analysis for computing liveness information from a /// given top-level operation. The analysis iterates over all associated /// regions that are attached to the given top-level operation. It @@ -57,7 +60,7 @@ class Liveness { public: using OperationListT = std::vector; using BlockMapT = DenseMap; - using ValueSetT = SmallPtrSet; + using ValueSetT = SmallPtrSet; public: /// Creates a new Liveness analysis that computes liveness @@ -72,7 +75,7 @@ public: /// Note that the operations in this list are not ordered and the current /// implementation is computationally expensive (as it iterates over all /// blocks in which the given value is live). - OperationListT resolveLiveness(Value *value) const; + OperationListT resolveLiveness(ValuePtr value) const; /// Gets liveness info (if any) for the block. const LivenessBlockInfo *getLiveness(Block *block) const; @@ -85,7 +88,7 @@ public: /// Returns true if the given operation represent the last use of the /// given value. - bool isLastUse(Value *value, Operation *operation) const; + bool isLastUse(ValuePtr value, Operation *operation) const; /// Dumps the liveness information in a human readable format. 
void dump() const; @@ -124,20 +127,20 @@ public: const ValueSetT &out() const { return outValues; } /// Returns true if the given value is in the live-in set. - bool isLiveIn(Value *value) const; + bool isLiveIn(ValuePtr value) const; /// Returns true if the given value is in the live-out set. - bool isLiveOut(Value *value) const; + bool isLiveOut(ValuePtr value) const; /// Gets the start operation for the given value. This is the first operation /// the given value is considered to be live. This could either be the start /// operation of the current block (in case the value is live-in) or the /// operation that defines the given value (must be referenced in this block). - Operation *getStartOperation(Value *value) const; + Operation *getStartOperation(ValuePtr value) const; /// Gets the end operation for the given value using the start operation /// provided (must be referenced in this block). - Operation *getEndOperation(Value *value, Operation *startOperation) const; + Operation *getEndOperation(ValuePtr value, Operation *startOperation) const; private: /// The underlying block. diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 47cc22a4923..ad7dc6d6092 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -36,6 +36,9 @@ class NestedPattern; class Operation; class Value; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + /// Returns the trip count of the loop as an affine map with its corresponding /// operands if the latter is expressible as an affine expression, and nullptr /// otherwise. This method always succeeds as long as the lower bound is not a @@ -45,7 +48,7 @@ class Value; // TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a // pure analysis method relying on FlatAffineConstraints void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count @@ -66,8 +69,8 @@ uint64_t getLargestDivisorOfTripCount(AffineForOp forOp); /// /// Emits a note if it encounters a chain of affine.apply and conservatively /// those cases. -DenseSet> -getInvariantAccesses(Value *iv, ArrayRef indices); +DenseSet> +getInvariantAccesses(ValuePtr iv, ArrayRef indices); using VectorizableLoopFun = std::function; diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index cffa222154f..ea0987df3fe 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -55,7 +55,7 @@ unsigned getNestingDepth(Operation &op); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. void getSequentialLoops(AffineForOp forOp, - llvm::SmallDenseSet *sequentialLoops); + llvm::SmallDenseSet *sequentialLoops); /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their /// associated operands for a set of loops within a loop nest (typically the @@ -64,15 +64,15 @@ void getSequentialLoops(AffineForOp forOp, struct ComputationSliceState { // List of sliced loop IVs (ordered from outermost to innermost). // EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'. - SmallVector ivs; + SmallVector ivs; // List of lower bound AffineMaps. SmallVector lbs; // List of upper bound AffineMaps. 
SmallVector ubs; // List of lower bound operands (lbOperands[i] are used by 'lbs[i]'). - std::vector> lbOperands; + std::vector> lbOperands; // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). - std::vector> ubOperands; + std::vector> ubOperands; // Slice loop nest insertion point in target loop nest. Block::iterator insertPoint; // Adds to 'cst' with constraints which represent the slice bounds on 'ivs' @@ -257,7 +257,7 @@ struct MemRefRegion { unsigned getRank() const; /// Memref that this region corresponds to. - Value *memref; + ValuePtr memref; /// Read or write. bool write; diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h index b5c51ad4b4c..4bbe6610e31 100644 --- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -30,14 +30,17 @@ class OpBuilder; class RewritePattern; class Value; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + // Owning list of rewriting patterns. class OwningRewritePatternList; /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. -Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, - ArrayRef dimValues, - ArrayRef symbolValues); +ValuePtr expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, + ArrayRef dimValues, + ArrayRef symbolValues); /// Collect a set of patterns to convert from the Affine dialect to the Standard /// dialect, in particular convert structured affine control flow into CFG @@ -47,11 +50,11 @@ void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns, /// Emit code that computes the lower bound of the given affine loop using /// standard arithmetic operations. -Value *lowerAffineLowerBound(AffineForOp op, OpBuilder &builder); +ValuePtr lowerAffineLowerBound(AffineForOp op, OpBuilder &builder); /// Emit code that computes the upper bound of the given affine loop using /// standard arithmetic operations. -Value *lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); +ValuePtr lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); } // namespace mlir #endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h index 0aab8723eab..58d49a13391 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -24,6 +24,9 @@ class AffineForOp; struct LogicalResult; class Value; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + namespace loop { class ForOp; } // end namespace loop @@ -78,8 +81,8 @@ LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp, /// The above conditions are assumed to be satisfied by the computation rooted /// at `forOp`. 
LogicalResult convertLoopToGPULaunch(loop::ForOp forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes); + ArrayRef numWorkGroups, + ArrayRef workGroupSizes); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index e8d16f064a8..6f41fb68633 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -74,16 +74,16 @@ public: /// Promote the LLVM struct representation of all MemRef descriptors to stack /// and use pointers to struct to avoid the complexity of the /// platform-specific C/C++ ABI lowering related to struct argument passing. - SmallVector promoteMemRefDescriptors(Location loc, - ValueRange opOperands, - ValueRange operands, - OpBuilder &builder); + SmallVector promoteMemRefDescriptors(Location loc, + ValueRange opOperands, + ValueRange operands, + OpBuilder &builder); /// Promote the LLVM struct representation of one MemRef descriptor to stack /// and use pointer to struct to avoid the complexity of the platform-specific /// C/C++ ABI lowering related to struct argument passing. - Value *promoteOneMemRefDescriptor(Location loc, Value *operand, - OpBuilder &builder); + ValuePtr promoteOneMemRefDescriptor(Location loc, ValuePtr operand, + OpBuilder &builder); protected: /// LLVM IR module used to parse/create types. @@ -139,24 +139,24 @@ private: class StructBuilder { public: /// Construct a helper for the given value. - explicit StructBuilder(Value *v); + explicit StructBuilder(ValuePtr v); /// Builds IR creating an `undef` value of the descriptor type. static StructBuilder undef(OpBuilder &builder, Location loc, Type descriptorType); - /*implicit*/ operator Value *() { return value; } + /*implicit*/ operator ValuePtr() { return value; } protected: // LLVM value - Value *value; + ValuePtr value; // Cached struct type. Type structType; protected: /// Builds IR to extract a value from the struct at position pos - Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos); + ValuePtr extractPtr(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR to set a value in the struct at position pos - void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr); + void setPtr(OpBuilder &builder, Location loc, unsigned pos, ValuePtr ptr); }; /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. @@ -164,7 +164,7 @@ protected: class MemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(Value *descriptor); + explicit MemRefDescriptor(ValuePtr descriptor); /// Builds IR creating an `undef` value of the descriptor type. static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); @@ -173,39 +173,40 @@ public: /// type. static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - MemRefType type, Value *memory); + MemRefType type, ValuePtr memory); /// Builds IR extracting the allocated pointer from the descriptor. - Value *allocatedPtr(OpBuilder &builder, Location loc); + ValuePtr allocatedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the allocated pointer into the descriptor. 
- void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr); + void setAllocatedPtr(OpBuilder &builder, Location loc, ValuePtr ptr); /// Builds IR extracting the aligned pointer from the descriptor. - Value *alignedPtr(OpBuilder &builder, Location loc); + ValuePtr alignedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the aligned pointer into the descriptor. - void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr); + void setAlignedPtr(OpBuilder &builder, Location loc, ValuePtr ptr); /// Builds IR extracting the offset from the descriptor. - Value *offset(OpBuilder &builder, Location loc); + ValuePtr offset(OpBuilder &builder, Location loc); /// Builds IR inserting the offset into the descriptor. - void setOffset(OpBuilder &builder, Location loc, Value *offset); + void setOffset(OpBuilder &builder, Location loc, ValuePtr offset); void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); /// Builds IR extracting the pos-th size from the descriptor. - Value *size(OpBuilder &builder, Location loc, unsigned pos); + ValuePtr size(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR inserting the pos-th size into the descriptor - void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size); + void setSize(OpBuilder &builder, Location loc, unsigned pos, ValuePtr size); void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size); /// Builds IR extracting the pos-th size from the descriptor. - Value *stride(OpBuilder &builder, Location loc, unsigned pos); + ValuePtr stride(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR inserting the pos-th stride into the descriptor - void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride); + void setStride(OpBuilder &builder, Location loc, unsigned pos, + ValuePtr stride); void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride); @@ -220,19 +221,19 @@ private: class UnrankedMemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. - explicit UnrankedMemRefDescriptor(Value *descriptor); + explicit UnrankedMemRefDescriptor(ValuePtr descriptor); /// Builds IR creating an `undef` value of the descriptor type. static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); /// Builds IR extracting the rank from the descriptor - Value *rank(OpBuilder &builder, Location loc); + ValuePtr rank(OpBuilder &builder, Location loc); /// Builds IR setting the rank in the descriptor - void setRank(OpBuilder &builder, Location loc, Value *value); + void setRank(OpBuilder &builder, Location loc, ValuePtr value); /// Builds IR extracting ranked memref descriptor ptr - Value *memRefDescPtr(OpBuilder &builder, Location loc); + ValuePtr memRefDescPtr(OpBuilder &builder, Location loc); /// Builds IR setting ranked memref descriptor ptr - void setMemRefDescPtr(OpBuilder &builder, Location loc, Value *value); + void setMemRefDescPtr(OpBuilder &builder, Location loc, ValuePtr value); }; /// Base class for operation conversions targeting the LLVM IR dialect. 
Provides /// conversion patterns with an access to the containing LLVMLowering for the diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 36b4e55e77c..764f439e020 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -41,7 +41,7 @@ class OpBuilder; /// A utility function to check if a value is defined at the top level of a /// function. A value of index type defined at the top level is always a valid /// symbol. -bool isTopLevelValue(Value *value); +bool isTopLevelValue(ValuePtr value); class AffineOpsDialect : public Dialect { public: @@ -148,18 +148,19 @@ class AffineDmaStartOp : public OpgetType().cast(); } @@ -191,7 +192,7 @@ public: } /// Returns the destination MemRefType for this DMA operations. - Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } + ValuePtr getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } MemRefType getDstMemRefType() { return getDstMemRef()->getType().cast(); } @@ -225,7 +226,7 @@ public: } /// Returns the Tag MemRef for this DMA operation. - Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } + ValuePtr getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } MemRefType getTagMemRefType() { return getTagMemRef()->getType().cast(); } @@ -249,13 +250,13 @@ public: } /// Returns the number of elements being transferred by this DMA operation. - Value *getNumElements() { + ValuePtr getNumElements() { return getOperand(getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs()); } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { if (memref == getSrcMemRef()) return {Identifier::get(getSrcMapAttrName(), getContext()), getSrcMapAttr()}; @@ -305,14 +306,14 @@ public: } /// Returns the stride value for this DMA operation. - Value *getStride() { + ValuePtr getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } /// Returns the number of elements to transfer per stride for this DMA op. - Value *getNumElementsPerStride() { + ValuePtr getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); @@ -337,14 +338,14 @@ class AffineDmaWaitOp : public OpgetType().cast(); } @@ -367,14 +368,16 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { assert(memref == getTagMemRef()); return {Identifier::get(getTagMapAttrName(), getContext()), getTagMapAttr()}; } /// Returns the number of elements transferred in the associated DMA op. - Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } + ValuePtr getNumElements() { + return getOperand(1 + getTagMap().getNumInputs()); + } static StringRef getTagMapAttrName() { return "tag_map"; } static ParseResult parse(OpAsmParser &parser, OperationState &result); @@ -409,18 +412,18 @@ public: static void build(Builder *builder, OperationState &result, AffineMap map, ValueRange operands); /// Builds an affine load op with an identity map and operands. 
- static void build(Builder *builder, OperationState &result, Value *memref, + static void build(Builder *builder, OperationState &result, ValuePtr memref, ValueRange indices = {}); /// Builds an affine load op with the specified map and its operands. - static void build(Builder *builder, OperationState &result, Value *memref, + static void build(Builder *builder, OperationState &result, ValuePtr memref, AffineMap map, ValueRange mapOperands); /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 0; } /// Get memref operand. - Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } + ValuePtr getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(ValuePtr value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -435,7 +438,7 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { assert(memref == getMemRef()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; @@ -476,21 +479,21 @@ public: /// Builds an affine store operation with the provided indices (identity map). static void build(Builder *builder, OperationState &result, - Value *valueToStore, Value *memref, ValueRange indices); + ValuePtr valueToStore, ValuePtr memref, ValueRange indices); /// Builds an affine store operation with the specified map and its operands. static void build(Builder *builder, OperationState &result, - Value *valueToStore, Value *memref, AffineMap map, + ValuePtr valueToStore, ValuePtr memref, AffineMap map, ValueRange mapOperands); /// Get value to be stored by store operation. - Value *getValueToStore() { return getOperand(0); } + ValuePtr getValueToStore() { return getOperand(0); } /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 1; } /// Get memref operand. - Value *getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); } + ValuePtr getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(ValuePtr value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); @@ -506,7 +509,7 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { assert(memref == getMemRef()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; @@ -526,10 +529,10 @@ public: }; /// Returns true if the given Value can be used as a dimension id. -bool isValidDim(Value *value); +bool isValidDim(ValuePtr value); /// Returns true if the given Value can be used as a symbol. -bool isValidSymbol(Value *value); +bool isValidSymbol(ValuePtr value); /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands @@ -538,17 +541,17 @@ bool isValidSymbol(Value *value); /// dimensional operands /// 4. propagate constant operands and drop them void canonicalizeMapAndOperands(AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does /// for affine maps. 
void canonicalizeSetAndOperands(IntegerSet *set, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Returns a composed AffineApplyOp by composing `map` and `operands` with /// other AffineApplyOps supplying those operands. The operands of the resulting /// AffineApplyOp do not change the length of AffineApplyOp chains. AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands); + ArrayRef operands); /// Given an affine map `map` and its input `operands`, this method composes /// into `map`, maps of AffineApplyOps whose results are the values in @@ -558,22 +561,22 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, /// terminal symbol, i.e., a symbol defined at the top level or a block/function /// argument. void fullyComposeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); #define GET_OP_CLASSES #include "mlir/Dialect/AffineOps/AffineOps.h.inc" /// Returns if the provided value is the induction variable of a AffineForOp. -bool isForInductionVar(Value *val); +bool isForInductionVar(ValuePtr val); /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -AffineForOp getForInductionVarOwner(Value *val); +AffineForOp getForInductionVarOwner(ValuePtr val); /// Extracts the induction variables from a list of AffineForOps and places them /// in the output argument `ivs`. void extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs); + SmallVectorImpl *ivs); /// AffineBound represents a lower or upper bound in the for operation. /// This class does not own the underlying operands. Instead, it refers @@ -588,7 +591,7 @@ public: AffineValueMap getAsAffineValueMap(); unsigned getNumOperands() { return opEnd - opStart; } - Value *getOperand(unsigned idx) { return op.getOperand(opStart + idx); } + ValuePtr getOperand(unsigned idx) { return op.getOperand(opStart + idx); } using operand_iterator = AffineForOp::operand_iterator; using operand_range = AffineForOp::operand_range; @@ -613,7 +616,7 @@ private: }; /// An `AffineApplyNormalizer` is a helper class that supports renumbering -/// operands of AffineApplyOp. This acts as a reindexing map of Value* to +/// operands of AffineApplyOp. This acts as a reindexing map of Value to /// positional dims or symbols and allows simplifications such as: /// /// ```mlir @@ -626,13 +629,13 @@ private: /// %1 = affine.apply () -> (0) /// ``` struct AffineApplyNormalizer { - AffineApplyNormalizer(AffineMap map, ArrayRef operands); + AffineApplyNormalizer(AffineMap map, ArrayRef operands); /// Returns the AffineMap resulting from normalization. AffineMap getAffineMap() { return affineMap; } - SmallVector getOperands() { - SmallVector res(reorderedDims); + SmallVector getOperands() { + SmallVector res(reorderedDims); res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); return res; } @@ -642,13 +645,13 @@ struct AffineApplyNormalizer { /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this /// normalizer's coordinate space. - void normalize(AffineMap *otherMap, SmallVectorImpl *otherOperands); + void normalize(AffineMap *otherMap, SmallVectorImpl *otherOperands); private: /// Helper function to insert `v` into the coordinate system of the current /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding /// renumbered position. 
- AffineDimExpr renumberOneDim(Value *v); + AffineDimExpr renumberOneDim(ValuePtr v); /// Given an `other` normalizer, this rewrites `other.affineMap` in the /// coordinate system of the current AffineApplyNormalizer. @@ -656,13 +659,13 @@ private: /// `this`. AffineMap renumber(const AffineApplyNormalizer &other); - /// Maps of Value* to position in `affineMap`. - DenseMap dimValueToPosition; + /// Maps of Value to position in `affineMap`. + DenseMap dimValueToPosition; /// Ordered dims and symbols matching positional dims and symbols in /// `affineMap`. - SmallVector reorderedDims; - SmallVector concatenatedSymbols; + SmallVector reorderedDims; + SmallVector concatenatedSymbols; AffineMap affineMap; diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index b40990ecb5d..befdc2f6237 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -101,7 +101,7 @@ def AffineForOp : Affine_Op<"for", static StringRef getUpperBoundAttrName() { return "upper_bound"; } Block *getBody() { return ®ion().front(); } - Value *getInductionVar() { return getBody()->getArgument(0); } + ValuePtr getInductionVar() { return getBody()->getArgument(0); } OpBuilder getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); } @@ -286,8 +286,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> { BoolAttr:$isDataCache); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *memref," - "AffineMap map, ArrayRef mapOperands, bool isWrite," + "Builder *builder, OperationState &result, ValuePtr memref," + "AffineMap map, ArrayRef mapOperands, bool isWrite," "unsigned localityHint, bool isDataCache", [{ assert(map.getNumInputs() == mapOperands.size() @@ -315,7 +315,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> { } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(Value *mref) { + NamedAttribute getAffineMapAttrForMemRef(ValuePtr mref) { assert(mref == memref()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h index 93c0b13ee3e..12c2aa1bbd1 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -77,9 +77,9 @@ public: /// Utility class for the GPU dialect to represent triples of `Value`s /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. struct KernelDim3 { - Value *x; - Value *y; - Value *z; + ValuePtr x; + ValuePtr y; + ValuePtr z; }; #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 6751f0a3f70..def1ff2b8a1 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -157,7 +157,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> { /// Returns a list of block arguments that correspond to buffers located in /// the workgroup memory - ArrayRef getWorkgroupAttributions() { + ArrayRef getWorkgroupAttributions() { auto begin = std::next(getBody().front().args_begin(), getType().getNumInputs()); auto end = std::next(begin, getNumWorkgroupAttributions()); @@ -166,7 +166,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> { /// Returns a list of block arguments that correspond to buffers located in /// the private memory. 
- ArrayRef getPrivateAttributions() { + ArrayRef getPrivateAttributions() { auto begin = std::next(getBody().front().args_begin(), getType().getNumInputs() + getNumWorkgroupAttributions()); @@ -282,8 +282,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">, let builders = [ OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " - "Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, " - "Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, " + "ValuePtr gridSizeX, ValuePtr gridSizeY, ValuePtr gridSizeZ, " + "ValuePtr blockSizeX, ValuePtr blockSizeY, ValuePtr blockSizeZ, " "ValueRange kernelOperands">, OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " "KernelDim3 gridSize, KernelDim3 blockSize, " @@ -302,7 +302,7 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">, StringRef getKernelModuleName(); /// The i-th operand passed to the kernel function. - Value *getKernelOperand(unsigned i); + ValuePtr getKernelOperand(unsigned i); /// Get the SSA values passed as operands to specify the grid size. KernelDim3 getGridSizeOperandValues(); @@ -415,9 +415,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>, let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"Builder *builder, OperationState &result, Value *gridSizeX," - "Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX," - "Value *blockSizeY, Value *blockSizeZ," + OpBuilder<"Builder *builder, OperationState &result, ValuePtr gridSizeX," + "ValuePtr gridSizeY, ValuePtr gridSizeZ, ValuePtr blockSizeX," + "ValuePtr blockSizeY, ValuePtr blockSizeZ," "ValueRange operands"> ]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index dae27d00e5a..a599d51b31f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -194,9 +194,9 @@ private: /// surrounding the insertion point of builder. Obtain the address of that /// global and use it to compute the address of the first character in the /// string (operations inserted at the builder insertion point). -Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name, - StringRef value, LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect); +ValuePtr createGlobalString(Location loc, OpBuilder &builder, StringRef name, + StringRef value, LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect); /// LLVM requires some operations to be inside of a Module operation. This /// function confirms that the Operation has the desired properties. 
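The headers patched above repeat one transitional device: next to each forward declaration of `Value`, the change introduces a local `using ValuePtr = Value *;` (and, where references are taken, `ValueRef`), and then rewrites `Value *` in signatures to the alias, per the `TODO(riverriddle) Remove this after Value is value-typed` notes. Below is a minimal, self-contained sketch of why routing signatures through such an alias is useful; the `toy::Value` class, `sameValue`, and the `main` driver are illustrative stand-ins invented for this sketch, not MLIR APIs — only the `ValuePtr`/`ValueRef` spellings mirror the patch.

```c++
// Sketch, under the assumptions stated above: APIs written against the alias
// concentrate the eventual pointer-to-value-type switch in one place.
#include <cassert>

namespace toy {
class Value { // stand-in for a Value that is still pointer-like today
public:
  explicit Value(int id) : id(id) {}
  int getId() const { return id; }

private:
  int id;
};

// Today the alias names a raw pointer. Once Value becomes value-typed, this
// one line (and the analogous aliases in each header) is what changes, while
// the signatures written against ValuePtr stay textually the same.
using ValuePtr = Value *;
using ValueRef = Value &;

// An API declared against the alias rather than against `Value *` directly.
inline bool sameValue(ValuePtr a, ValuePtr b) {
  return a->getId() == b->getId();
}
} // namespace toy

int main() {
  toy::Value x(7), y(7);
  toy::ValuePtr px = &x, py = &y;
  toy::ValueRef rx = x;
  assert(toy::sameValue(px, py)); // compares through the pointer alias
  assert(rx.getId() == 7);        // reference alias used for by-ref APIs
  return 0;
}
```

The same trade-off shows up throughout the diff: headers that cannot include `Value.h` (they only forward-declare `Value`) each carry their own copy of the alias, which is why the typedef appears repeatedly in `AffineAnalysis.h`, `Liveness.h`, `LoopAnalysis.h`, and the conversion headers rather than once in a shared header.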
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 00acc539dab..cfbbf7da65d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -185,8 +185,8 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ICmpPredicate predicate, Value *lhs, " - "Value *rhs", [{ + "Builder *b, OperationState &result, ICmpPredicate predicate, ValuePtr lhs, " + "ValuePtr rhs", [{ LLVMDialect *dialect = &lhs->getType().cast().getDialect(); build(b, result, LLVMType::getInt1Ty(dialect), b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); @@ -232,8 +232,8 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, FCmpPredicate predicate, Value *lhs, " - "Value *rhs", [{ + "Builder *b, OperationState &result, FCmpPredicate predicate, ValuePtr lhs, " + "ValuePtr rhs", [{ LLVMDialect *dialect = &lhs->getType().cast().getDialect(); build(b, result, LLVMType::getInt1Ty(dialect), b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); @@ -265,7 +265,7 @@ def LLVM_AllocaOp : $res = alloca; }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Type resultType, Value *arraySize, " + "Builder *b, OperationState &result, Type resultType, ValuePtr arraySize, " "unsigned alignment", [{ if (alignment == 0) @@ -292,7 +292,7 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, LLVM_Builder<"$res = builder.CreateLoad($addr);"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *addr", + "Builder *b, OperationState &result, ValuePtr addr", [{ auto type = addr->getType().cast().getPointerElementTy(); build(b, result, type, addr); @@ -353,7 +353,7 @@ def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, $res = builder.CreateExtractElement($vector, $position); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *vector, Value *position," + "Builder *b, OperationState &result, ValuePtr vector, ValuePtr position," "ArrayRef attrs = {}">]; let parser = [{ return parseExtractElementOp(parser, result); }]; let printer = [{ printExtractElementOp(p, *this); }]; @@ -384,7 +384,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, extractPosition($position)); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *container, Value *value, " + "Builder *b, OperationState &result, ValuePtr container, ValuePtr value, " "ArrayAttr position", [{ build(b, result, container->getType(), container, value, position); @@ -398,7 +398,7 @@ def LLVM_ShuffleVectorOp LLVM_Builder< "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *v1, Value *v2, " + "Builder *b, OperationState &result, ValuePtr v1, ValuePtr v2, " "ArrayAttr mask, ArrayRef attrs = {}">]; let verifier = [{ auto wrappedVectorType1 = v1()->getType().cast(); @@ -422,8 +422,8 @@ def LLVM_SelectOp LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *condition, Value 
*lhs, " - "Value *rhs", [{ + "Builder *b, OperationState &result, ValuePtr condition, ValuePtr lhs, " + "ValuePtr rhs", [{ build(b, result, lhs->getType(), condition, lhs, rhs); }]>]; let parser = [{ return parseSelectOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 01d3e4b239c..426708b14a8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -37,15 +37,15 @@ class LinalgOp; class Aliases { public: /// Returns true if v1 and v2 alias. - bool alias(Value *v1, Value *v2) { return find(v1) == find(v2); } + bool alias(ValuePtr v1, ValuePtr v2) { return find(v1) == find(v2); } private: /// Returns the base buffer or block argument into which the view `v` aliases. /// This lazily records the new aliases discovered while walking back the /// use-def chain. - Value *find(Value *v); + ValuePtr find(ValuePtr v); - DenseMap aliases; + DenseMap aliases; }; /// Data structure for holding a dependence graph that operates on LinalgOp and @@ -54,7 +54,7 @@ class LinalgDependenceGraph { public: struct LinalgOpView { Operation *op; - Value *view; + ValuePtr view; }; struct LinalgDependenceGraphElem { // dependentOpView may be either: @@ -64,7 +64,7 @@ public: // View in the op that is used to index in the graph: // 1. src in the case of dependencesFromDstGraphs. // 2. dst in the case of dependencesIntoGraphs. - Value *indexingView; + ValuePtr indexingView; }; using LinalgDependences = SmallVector; using DependenceGraph = DenseMap; @@ -97,14 +97,14 @@ public: /// Dependences are restricted to views aliasing `view`. SmallVector findCoveringReads(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - Value *view) const; + ValuePtr view) const; /// Returns the operations that are interleaved between `srcLinalgOp` and /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`. /// Dependences are restricted to views aliasing `view`. SmallVector findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - Value *view) const; + ValuePtr view) const; private: // Keep dependences in both directions, this is not just a performance gain @@ -130,7 +130,7 @@ private: /// Implementation detail for findCoveringxxx. SmallVector findOperationsWithCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, Value *view, + LinalgOp dstLinalgOp, ValuePtr view, ArrayRef types) const; Aliases &aliases; diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index cf6335278b7..8375e750a5c 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -55,34 +55,34 @@ inline StringRef toString(IterType t) { /// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... 
); /// ``` struct StructuredIndexed { - StructuredIndexed(Value *v) : value(v) {} + StructuredIndexed(ValuePtr v) : value(v) {} StructuredIndexed operator()(ArrayRef indexings) { return StructuredIndexed(value, indexings); } - operator Value *() const /* implicit */ { return value; } + operator ValuePtr() const /* implicit */ { return value; } ArrayRef getExprs() { return exprs; } private: - StructuredIndexed(Value *v, ArrayRef indexings) + StructuredIndexed(ValuePtr v, ArrayRef indexings) : value(v), exprs(indexings.begin(), indexings.end()) { assert(v->getType().isa() && "MemRefType expected"); } StructuredIndexed(ValueHandle v, ArrayRef indexings) : StructuredIndexed(v.getValue(), indexings) {} - Value *value; + ValuePtr value; SmallVector exprs; }; -inline void defaultRegionBuilder(ArrayRef args) {} +inline void defaultRegionBuilder(ArrayRef args) {} Operation *makeLinalgGenericOp(ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - function_ref)> + function_ref)> regionBuilder = defaultRegionBuilder, - ArrayRef otherValues = {}, + ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); namespace ops { @@ -96,7 +96,7 @@ using edsc::intrinsics::linalg_yield; /// Build the body of a region to compute a multiply-accumulate, under the /// current ScopedContext, at the current insert point. -void macRegionBuilder(ArrayRef args); +void macRegionBuilder(ArrayRef args); /// TODO(ntv): In the future we should tie these implementations to something in /// Tablegen that generates the proper interfaces and the proper sugared named @@ -120,7 +120,7 @@ void macRegionBuilder(ArrayRef args); /// with in-place semantics and parallelism. /// Unary pointwise operation (with broadcast) entry point. -using UnaryPointwiseOpBuilder = function_ref; +using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O); @@ -131,7 +131,7 @@ Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); /// Binary pointwise operation (with broadcast) entry point. using BinaryPointwiseOpBuilder = - function_ref; + function_ref; Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 12318a244df..18ca31cc376 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -92,22 +92,22 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> { "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod<"Query the input view at the given index.", - "Value *", "getInput", (ins "unsigned":$i) + "ValuePtr ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<"Query the output view at the given index.", - "Value *", "getOutput", (ins "unsigned":$i) + "ValuePtr ", "getOutput", (ins "unsigned":$i) >, InterfaceMethod<[{ Query the index of the given input value, or `None` if the value is not an input. }], - "Optional", "getIndexOfInput", (ins "Value *":$view) + "Optional", "getIndexOfInput", (ins "ValuePtr ":$view) >, InterfaceMethod<[{ Query the index of the given view value, or `None` if the value is not an view. }], - "Optional", "getIndexOfOutput", (ins "Value *":$view) + "Optional", "getIndexOfOutput", (ins "ValuePtr ":$view) >, InterfaceMethod<[{ Query the type of the input view at the given index. 
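The Linalg interface methods above (`getInput`, `getOutput`, `getIndexOfInput`, `getIndexOfOutput`) now traffic in `ValuePtr` as well. Since the alias preserves pointer semantics, code written against the interface compiles unchanged; a minimal sketch, assuming a linalg::LinalgOp obtained elsewhere (the helper name and include path are illustrative):

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"

  using namespace mlir;

  // Illustrative query: does the first output view also appear as an input?
  // getOutput(0) returns a ValuePtr; getIndexOfInput takes one.
  static bool firstOutputAlsoAnInput(linalg::LinalgOp op) {
    ValuePtr out = op.getOutput(0);
    return op.getIndexOfInput(out).hasValue(); // llvm::Optional<unsigned>
  }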
@@ -228,7 +228,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> { // TODO(ntv) this should go away once the usage of OptionalAttr triggers // emission of builders with default arguments left unspecified. let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *input, Value *output", [{ + "Builder *builder, OperationState &result, ValuePtr input, ValuePtr output", [{ return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index b806d7548fb..5d402a9ded9 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -56,8 +56,8 @@ def Linalg_RangeOp : ```` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *min, Value *max, " - "Value *step", + "Builder *builder, OperationState &result, ValuePtr min, ValuePtr max, " + "ValuePtr step", [{ auto rangeType = RangeType::get(builder->getContext()); build(builder, result, rangeType, min, max, step); @@ -112,7 +112,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *base, " + "Builder *b, OperationState &result, ValuePtr base, " "ValueRange indexings">]; let extraClassDeclaration = [{ @@ -124,12 +124,12 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, MemRefType getBaseViewType() { return view()->getType().cast(); } // Get the underlying indexing at a given rank. - Value *indexing(unsigned rank) { return *(indexings().begin() + rank); } + ValuePtr indexing(unsigned rank) { return *(indexings().begin() + rank); } // Get the subset of indexings that are of RangeType. - SmallVector getRanges() { - SmallVector res; - for (auto *operand : indexings()) + SmallVector getRanges() { + SmallVector res; + for (auto operand : indexings()) if (!operand->getType().isa()) res.push_back(operand); return res; @@ -154,7 +154,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *view, " + "Builder *b, OperationState &result, ValuePtr view, " "AffineMapAttr permutation, ArrayRef attrs = {}">]; let verifier = [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 75b63c93cd8..774be6616cd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -92,22 +92,22 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod<"Query the input view at the given index.", - "Value *", "getInput", (ins "unsigned":$i) + "ValuePtr ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<"Query the output view at the given index.", - "Value *", "getOutput", (ins "unsigned":$i) + "ValuePtr ", "getOutput", (ins "unsigned":$i) >, InterfaceMethod<[{ Query the index of the given input value, or `None` if the value is not an input. }], - "llvm::Optional", "getIndexOfInput", (ins "Value *":$view) + "llvm::Optional", "getIndexOfInput", (ins "ValuePtr ":$view) >, InterfaceMethod<[{ Query the index of the given view value, or `None` if the value is not an view. 
}], - "llvm::Optional", "getIndexOfOutput", (ins "Value *":$view) + "llvm::Optional", "getIndexOfOutput", (ins "ValuePtr ":$view) >, InterfaceMethod<[{ Query the type of the input view at the given index. @@ -228,7 +228,7 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { // TODO(ntv) this should go away once the usage of OptionalAttr triggers // emission of builders with default arguments left unspecified. let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *input, Value *output", [{ + "Builder *builder, OperationState &result, ValuePtr input, ValuePtr output", [{ return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index a24c1ca63c4..d196e6ccf94 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -77,13 +77,13 @@ private: public: /// Return the `i`-th input view. - Value *getInput(unsigned i) { + ValuePtr getInput(unsigned i) { assert(i < nInputs()); return this->getOperation()->getOperand(i); } /// Return the index of `view` in the list of input views if found, llvm::None /// otherwise. - Optional getIndexOfInput(Value *view) { + Optional getIndexOfInput(ValuePtr view) { auto it = llvm::find(getInputs(), view); if (it != getInputs().end()) return it - getInputs().begin(); @@ -99,12 +99,12 @@ public: return {range.begin(), range.begin() + nInputs()}; } /// Return the `i`-th output view. - Value *getOutput(unsigned i) { + ValuePtr getOutput(unsigned i) { return this->getOperation()->getOperand(nInputs() + i); } /// Return the index of `view` in the list of output views if found, /// llvm::None otherwise. - Optional getIndexOfOutput(Value *view) { + Optional getIndexOfOutput(ValuePtr view) { auto it = llvm::find(getOutputs(), view); if (it != getOutputs().end()) return it - getOutputs().begin(); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 415dd918f74..dbc162f4132 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -45,7 +45,7 @@ class AffineMapDomainHasDim : CPred<[{ class HasOperandsOfType: CPred<[{ llvm::any_of($0.getOperands(), - [](Value* v) { + [](ValuePtr v) { return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); }) }]>; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index dfbac5ac193..a1a7458ae7f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -38,7 +38,7 @@ struct LinalgTransforms { namespace detail { // Implementation detail of isProducedByOpOfType avoids the need for explicit // template instantiations. -bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value *consumedView, +bool isProducedByOpOfTypeImpl(Operation *consumerOp, ValuePtr consumedView, function_ref isaOpType); } // namespace detail @@ -46,7 +46,7 @@ bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value *consumedView, // an op of type `OpTy`. This is used to implement use-def type information on // buffers. 
template -bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { +bool isProducedByOpOfType(Operation *consumerOp, ValuePtr consumedView) { return detail::isProducedByOpOfTypeImpl( consumerOp, consumedView, [](Operation *op) { return isa(op); }); } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index f8d10ecfa57..50039dd9336 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -34,7 +34,7 @@ namespace edsc { /// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations. /// More specifically it is meant to be used as a temporary object for -/// representing any nested MLIR construct that is "related to" an mlir::Value* +/// representing any nested MLIR construct that is "related to" an mlir::Value /// (for now an induction variable). class LoopRangeBuilder : public NestedBuilder { public: @@ -42,7 +42,7 @@ public: /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. LoopRangeBuilder(ValueHandle *iv, ValueHandle range); - LoopRangeBuilder(ValueHandle *iv, Value *range); + LoopRangeBuilder(ValueHandle *iv, ValuePtr range); LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range); LoopRangeBuilder(const LoopRangeBuilder &) = delete; @@ -65,7 +65,7 @@ public: LoopNestRangeBuilder(ArrayRef ivs, ArrayRef ranges); LoopNestRangeBuilder(ArrayRef ivs, - ArrayRef ranges); + ArrayRef ranges); LoopNestRangeBuilder(ArrayRef ivs, ArrayRef ranges); edsc::ValueHandle operator()(std::function fun = nullptr); @@ -88,14 +88,14 @@ struct FusionInfo { /// whole `consumedView`. This checks structural dominance, that the dependence /// is a RAW without any interleaved write to any piece of `consumedView`. bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, - LinalgOp consumer, Value *consumedView, + LinalgOp consumer, ValuePtr consumedView, LinalgOp producer); /// Checks whether fusing the specific `producer` of the `consumedView` is /// feasible. This checks `producer` is the last write of `consumedView` and /// that no interleaved dependence would be violated (RAW, WAR or WAW). bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, - Value *consumedView, LinalgOp producer); + ValuePtr consumedView, LinalgOp producer); /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. @@ -111,8 +111,8 @@ Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, /// the inverse, concatenated loopToOperandRangeMaps to this list allows the /// derivation of loop ranges for any linalgOp. template -SmallVector getViewSizes(ConcreteOp linalgOp) { - SmallVector res; +SmallVector getViewSizes(ConcreteOp linalgOp) { + SmallVector res; for (auto v : linalgOp.getInputsAndOutputs()) { MemRefType t = v->getType().template cast(); for (unsigned i = 0; i < t.getRank(); ++i) @@ -125,10 +125,10 @@ SmallVector getViewSizes(ConcreteOp linalgOp) { /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. 
-SmallVector applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef values, - OperationFolder *folder = nullptr); +SmallVector applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef values, + OperationFolder *folder = nullptr); struct TiledLinalgOp { LinalgOp op; @@ -151,7 +151,7 @@ struct TiledLinalgOp { /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, + ArrayRef tileSizes, ArrayRef permutation = {}, OperationFolder *folder = nullptr); @@ -182,9 +182,9 @@ Optional tileLinalgOperation(OpBuilder &b, Operation *op, } struct PromotionInfo { - Value *buffer; - Value *fullLocalView; - Value *partialLocalView; + ValuePtr buffer; + ValuePtr fullLocalView; + ValuePtr partialLocalView; }; /// Promotes the `subViews` into a new buffer allocated at the insertion point @@ -199,13 +199,13 @@ struct PromotionInfo { /// Returns a list of PromotionInfo which hold the promoted buffer and the /// full and partial views indexing into the buffer. SmallVector -promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, +promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, bool dynamicBuffers = false, OperationFolder *folder = nullptr); /// Returns all the operands of `linalgOp` that are not views. /// Asserts that these operands are value types to allow transformations like /// tiling to just use the values when cloning `linalgOp`. -SmallVector getAssumedNonViewOperands(LinalgOp linalgOp); +SmallVector getAssumedNonViewOperands(LinalgOp linalgOp); /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. @@ -226,7 +226,7 @@ void applyPermutationToVector(SmallVector &inVec, /// It is the entry point for declarative transformation /// Returns the cloned `LinalgOp` with the new operands LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, - llvm::SetVector subViews, + llvm::SetVector subViews, bool dynamicBuffers = false, OperationFolder *folder = nullptr); diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h index fdadf4a40dd..e7ff6f84977 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -50,7 +50,7 @@ void ensureLoopTerminator(Region ®ion, Builder &builder, Location loc); /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. 
-ForOp getForInductionVarOwner(Value *val); +ForOp getForInductionVarOwner(ValuePtr val); } // end namespace loop } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td index 5e0b8098411..e0f5b896309 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -74,18 +74,18 @@ def ForOp : Loop_Op<"for", let skipDefaultBuilders = 1; let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "Value *lowerBound, Value *upperBound, Value *step"> + "ValuePtr lowerBound, ValuePtr upperBound, ValuePtr step"> ]; let extraClassDeclaration = [{ Block *getBody() { return ®ion().front(); } - Value *getInductionVar() { return getBody()->getArgument(0); } + ValuePtr getInductionVar() { return getBody()->getArgument(0); } OpBuilder getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); } - void setLowerBound(Value *bound) { getOperation()->setOperand(0, bound); } - void setUpperBound(Value *bound) { getOperation()->setOperand(1, bound); } - void setStep(Value *step) { getOperation()->setOperand(2, step); } + void setLowerBound(ValuePtr bound) { getOperation()->setOperand(0, bound); } + void setUpperBound(ValuePtr bound) { getOperation()->setOperand(1, bound); } + void setStep(ValuePtr step) { getOperation()->setOperand(2, step); } }]; } @@ -116,7 +116,7 @@ def IfOp : Loop_Op<"if", let skipDefaultBuilders = 1; let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "Value *cond, bool withElseRegion"> + "ValuePtr cond, bool withElseRegion"> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td index d6e2e1c6fda..d19fd974684 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -120,7 +120,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let builders = [ OpBuilder<[{Builder *builder, OperationState &state, - Value *composite, ArrayRef indices}]> + ValuePtr composite, ArrayRef indices}]> ]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 464b670dae9..32a78024560 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -132,7 +132,7 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", let builders = [ OpBuilder< - "Builder *builder, OperationState &state, Value *condition, " + "Builder *builder, OperationState &state, ValuePtr condition, " "Block *trueBlock, ValueRange trueArguments, " "Block *falseBlock, ValueRange falseArguments, " "Optional> weights = {}", diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index 0c4b2902a12..e1e94bcd861 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -858,8 +858,8 @@ def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - Value *cond, Value *trueValue, - Value *falseValue}]>]; + ValuePtr cond, ValuePtr trueValue, + ValuePtr falseValue}]>]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index f48a1d0b129..37b4ee24237 100644 --- 
a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -64,8 +64,8 @@ protected: namespace spirv { /// Returns a value that represents a builtin variable value within the SPIR-V /// module. -Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, - OpBuilder &builder); +ValuePtr getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, + OpBuilder &builder); /// Attribute name for specifying argument ABI information. StringRef getInterfaceVarABIAttrName(); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 91ea8d7d676..777e5750486 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -102,7 +102,7 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - Value *basePtr, ValueRange indices}]>]; + ValuePtr basePtr, ValueRange indices}]>]; let hasCanonicalizer = 1; } @@ -272,7 +272,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - Value *basePtr, /*optional*/IntegerAttr memory_access, + ValuePtr basePtr, /*optional*/IntegerAttr memory_access, /*optional*/IntegerAttr alignment}]>]; } @@ -367,7 +367,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> { let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "Value *ptr, Value *value, ArrayRef namedAttrs", [{ + "ValuePtr ptr, ValuePtr value, ArrayRef namedAttrs", [{ state.addOperands(ptr); state.addOperands(value); state.addAttributes(namedAttrs); diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index 1b1cf02d204..563116823d9 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -182,15 +182,15 @@ class DmaStartOp public: using Op::Op; - static void build(Builder *builder, OperationState &result, Value *srcMemRef, - ValueRange srcIndices, Value *destMemRef, - ValueRange destIndices, Value *numElements, - Value *tagMemRef, ValueRange tagIndices, - Value *stride = nullptr, - Value *elementsPerStride = nullptr); + static void build(Builder *builder, OperationState &result, + ValuePtr srcMemRef, ValueRange srcIndices, + ValuePtr destMemRef, ValueRange destIndices, + ValuePtr numElements, ValuePtr tagMemRef, + ValueRange tagIndices, ValuePtr stride = nullptr, + ValuePtr elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. - Value *getSrcMemRef() { return getOperand(0); } + ValuePtr getSrcMemRef() { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() { return getSrcMemRef()->getType().cast().getRank(); @@ -202,7 +202,7 @@ public: } // Returns the destination MemRefType for this DMA operations. - Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } + ValuePtr getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() { return getDstMemRef()->getType().cast().getRank(); @@ -222,12 +222,12 @@ public: } // Returns the number of elements being transferred by this DMA operation. - Value *getNumElements() { + ValuePtr getNumElements() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. 
- Value *getTagMemRef() { + ValuePtr getTagMemRef() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. @@ -276,13 +276,13 @@ public: 1 + 1 + getTagMemRefRank(); } - Value *getStride() { + ValuePtr getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } - Value *getNumElementsPerStride() { + ValuePtr getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); @@ -307,13 +307,14 @@ class DmaWaitOp public: using Op::Op; - static void build(Builder *builder, OperationState &result, Value *tagMemRef, - ValueRange tagIndices, Value *numElements); + static void build(Builder *builder, OperationState &result, + ValuePtr tagMemRef, ValueRange tagIndices, + ValuePtr numElements); static StringRef getOperationName() { return "std.dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. - Value *getTagMemRef() { return getOperand(0); } + ValuePtr getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. operand_range getTagIndices() { @@ -327,7 +328,7 @@ public: } // Returns the number of elements transferred in the associated DMA operation. - Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); } + ValuePtr getNumElements() { return getOperand(1 + getTagMemRefRank()); } static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); @@ -342,7 +343,7 @@ void printDimAndSymbolList(Operation::operand_iterator begin, /// Parses dimension and symbol list and returns true if parsing failed. ParseResult parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, + SmallVectorImpl &operands, unsigned &numDims); raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range); diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index c26baf6a76e..e00674708f6 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -52,7 +52,7 @@ class CastOp traits = []> : let results = (outs AnyType); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source, Type destType", [{ + "Builder *builder, OperationState &result, ValuePtr source, Type destType", [{ impl::buildCastOp(builder, result, source, destType); }]>]; @@ -191,7 +191,7 @@ def AllocOp : Std_Op<"alloc"> { }]>, OpBuilder< "Builder *builder, OperationState &result, MemRefType memrefType, " # - "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ + "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ result.addOperands(operands); result.types.push_back(memrefType); if (alignment) @@ -330,7 +330,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { let results = (outs Variadic); let builders = [OpBuilder< - "Builder *, OperationState &result, Value *callee," + "Builder *, OperationState &result, ValuePtr callee," "ValueRange operands = {}", [{ result.operands.push_back(callee); result.addOperands(operands); @@ -338,7 +338,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { }]>]; let extraClassDeclaration = [{ - Value *getCallee() { return getOperand(0); } + ValuePtr getCallee() { return getOperand(0); } /// Get the argument operands to the called function. 
operand_range getArgOperands() { @@ -395,7 +395,7 @@ def CmpFOp : Std_Op<"cmpf", let builders = [OpBuilder< "Builder *builder, OperationState &result, CmpFPredicate predicate," - "Value *lhs, Value *rhs", [{ + "ValuePtr lhs, ValuePtr rhs", [{ ::buildCmpFOp(builder, result, predicate, lhs, rhs); }]>]; @@ -463,7 +463,7 @@ def CmpIOp : Std_Op<"cmpi", let builders = [OpBuilder< "Builder *builder, OperationState &result, CmpIPredicate predicate," - "Value *lhs, Value *rhs", [{ + "ValuePtr lhs, ValuePtr rhs", [{ ::buildCmpIOp(builder, result, predicate, lhs, rhs); }]>]; @@ -502,7 +502,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { let arguments = (ins I1:$condition, Variadic:$branchOperands); let builders = [OpBuilder< - "Builder *, OperationState &result, Value *condition," + "Builder *, OperationState &result, ValuePtr condition," "Block *trueDest, ValueRange trueOperands," "Block *falseDest, ValueRange falseOperands", [{ result.addOperands(condition); @@ -518,7 +518,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { enum { trueIndex = 0, falseIndex = 1 }; // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } + ValuePtr getCondition() { return getOperand(0); } /// Return the destination if the condition is true. Block *getTrueDest() { @@ -531,12 +531,12 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { } // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { + ValuePtr getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } - void setTrueOperand(unsigned idx, Value *value) { + void setTrueOperand(unsigned idx, ValuePtr value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); } @@ -561,11 +561,11 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { } // Accessors for operands to the 'false' destination. 
- Value *getFalseOperand(unsigned idx) { + ValuePtr getFalseOperand(unsigned idx) { assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } - void setFalseOperand(unsigned idx, Value *value) { + void setFalseOperand(unsigned idx, ValuePtr value) { assert(idx < getNumFalseOperands()); setOperand(getFalseDestOperandIndex() + idx, value); } @@ -678,7 +678,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> { let results = (outs Index); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *memrefOrTensor," + "Builder *builder, OperationState &result, ValuePtr memrefOrTensor," "unsigned index", [{ auto indexType = builder->getIndexType(); auto indexAttr = builder->getIntegerAttr(indexType, index); @@ -730,7 +730,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let results = (outs AnyType); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *aggregate," + "Builder *builder, OperationState &result, ValuePtr aggregate," "ValueRange indices = {}", [{ auto resType = aggregate->getType().cast() .getElementType(); @@ -738,7 +738,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { }]>]; let extraClassDeclaration = [{ - Value *getAggregate() { return getOperand(0); } + ValuePtr getAggregate() { return getOperand(0); } operand_range getIndices() { return {operand_begin() + 1, operand_end()}; @@ -816,7 +816,7 @@ def LoadOp : Std_Op<"load"> { let results = (outs AnyType); let builders = [OpBuilder< - "Builder *, OperationState &result, Value *memref," + "Builder *, OperationState &result, ValuePtr memref," "ValueRange indices = {}", [{ auto memrefType = memref->getType().cast(); result.addOperands(memref); @@ -825,8 +825,8 @@ def LoadOp : Std_Op<"load"> { }]>]; let extraClassDeclaration = [{ - Value *getMemRef() { return getOperand(0); } - void setMemRef(Value *value) { setOperand(0, value); } + ValuePtr getMemRef() { return getOperand(0); } + void setMemRef(ValuePtr value) { setOperand(0, value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -952,8 +952,8 @@ def PrefetchOp : Std_Op<"prefetch"> { BoolAttr:$isDataCache); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *memref," - "ArrayRef indices, bool isWrite, unsigned hint, bool isData", + "Builder *builder, OperationState &result, ValuePtr memref," + "ArrayRef indices, bool isWrite, unsigned hint, bool isData", [{ auto hintAttr = builder->getI32IntegerAttr(hint); auto isWriteAttr = builder->getBoolAttr(isWrite); @@ -990,7 +990,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> { let verifier = ?; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *tensor", [{ + "Builder *builder, OperationState &result, ValuePtr tensor", [{ auto indexType = builder->getIndexType(); build(builder, result, indexType, tensor); }]>]; @@ -1052,16 +1052,16 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> { let results = (outs AnyType); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *condition," - "Value *trueValue, Value *falseValue", [{ + "Builder *builder, OperationState &result, ValuePtr condition," + "ValuePtr trueValue, ValuePtr falseValue", [{ result.addOperands({condition, trueValue, falseValue}); result.addTypes(trueValue->getType()); }]>]; let extraClassDeclaration = [{ - Value *getCondition() { return condition(); } - Value *getTrueValue() { return true_value(); } - Value *getFalseValue() { return 
false_value(); } + ValuePtr getCondition() { return condition(); } + ValuePtr getTrueValue() { return true_value(); } + ValuePtr getFalseValue() { return false_value(); } }]; let hasFolder = 1; @@ -1089,7 +1089,7 @@ def SignExtendIOp : Std_Op<"sexti", let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *value, Type destType", [{ + "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; @@ -1189,7 +1189,7 @@ def SplatOp : Std_Op<"splat", [NoSideEffect]> { let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); let builders = - [OpBuilder<"Builder *builder, OperationState &result, Value *element, " + [OpBuilder<"Builder *builder, OperationState &result, ValuePtr element, " "Type aggregateType", [{ build(builder, result, aggregateType, element); }]>]; @@ -1213,16 +1213,16 @@ def StoreOp : Std_Op<"store"> { Variadic:$indices); let builders = [OpBuilder< - "Builder *, OperationState &result, Value *valueToStore, Value *memref", [{ + "Builder *, OperationState &result, ValuePtr valueToStore, ValuePtr memref", [{ result.addOperands(valueToStore); result.addOperands(memref); }]>]; let extraClassDeclaration = [{ - Value *getValueToStore() { return getOperand(0); } + ValuePtr getValueToStore() { return getOperand(0); } - Value *getMemRef() { return getOperand(1); } - void setMemRef(Value *value) { setOperand(1, value); } + ValuePtr getMemRef() { return getOperand(1); } + void setMemRef(ValuePtr value) { setOperand(1, value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -1364,13 +1364,13 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let builders = [ OpBuilder< - "Builder *b, OperationState &result, Value *source, " + "Builder *b, OperationState &result, ValuePtr source, " "ValueRange offsets, ValueRange sizes, " "ValueRange strides, Type resultType = Type(), " "ArrayRef attrs = {}">, OpBuilder< "Builder *builder, OperationState &result, " - "Type resultType, Value *source"> + "Type resultType, ValuePtr source"> ]; let extraClassDeclaration = [{ @@ -1403,7 +1403,7 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { // offset, size and stride operands of the SubViewOp into a list of triples. // Such a list of triple is sometimes more convenient to manipulate. struct Range { - Value *offset, *size, *stride; + ValuePtr offset, size, stride; }; SmallVector getRanges(); }]; @@ -1465,7 +1465,7 @@ def TensorLoadOp : Std_Op<"tensor_load", let verifier = ?; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *memref", [{ + "Builder *builder, OperationState &result, ValuePtr memref", [{ auto memrefType = memref->getType().cast(); auto resultType = RankedTensorType::get(memrefType.getShape(), memrefType.getElementType()); @@ -1519,7 +1519,7 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *value, Type destType", [{ + "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; @@ -1578,7 +1578,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { /// Returns the dynamic offset for this view operation if specified. /// Returns nullptr if no dynamic offset was specified. 
- Value *getDynamicOffset(); + ValuePtr getDynamicOffset(); /// Returns the starting operand list position of the dynamic size operands. unsigned getDynamicSizesOperandStart() { @@ -1619,7 +1619,7 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *value, Type destType", [{ + "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; diff --git a/mlir/include/mlir/Dialect/VectorOps/Utils.h b/mlir/include/mlir/Dialect/VectorOps/Utils.h index f61a813855d..68c62cc7ec7 100644 --- a/mlir/include/mlir/Dialect/VectorOps/Utils.h +++ b/mlir/include/mlir/Dialect/VectorOps/Utils.h @@ -34,6 +34,9 @@ class Operation; class Value; class VectorType; +// TODO(riverriddle) Remove this after Value is value-typed. +using ValuePtr = Value *; + /// Computes and returns the multi-dimensional ratio of `superShape` to /// `subShape`. This is calculated by performing a traversal from minor to major /// dimensions (i.e. in reverse shape order). If integral division is not @@ -122,7 +125,7 @@ Optional> shapeRatio(VectorType superVectorType, /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. /// AffineMap -makePermutationMap(Operation *op, ArrayRef indices, +makePermutationMap(Operation *op, ArrayRef indices, const DenseMap &loopToVectorDim); namespace matcher { diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 5fd19498350..94262e6f1ff 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -128,8 +128,8 @@ def Vector_ContractionOp : : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *lhs, Value *rhs, " - "Value *acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">]; + "Builder *builder, OperationState &result, ValuePtr lhs, ValuePtr rhs, " + "ValuePtr acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">]; let extraClassDeclaration = [{ VectorType getLhsType() { return lhs()->getType().cast(); @@ -252,7 +252,8 @@ def Vector_ShuffleOp : ``` }]; - let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef">]; + let builders = [OpBuilder<"Builder *builder, OperationState &result," + "ValuePtr v1, ValuePtr v2, ArrayRef">]; let extraClassDeclaration = [{ static StringRef getMaskAttrName() { return "mask"; } VectorType getV1VectorType() { @@ -312,7 +313,8 @@ def Vector_ExtractOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source, ArrayRef">]; + "Builder *builder, OperationState &result, ValuePtr source," + "ArrayRef">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } VectorType getVectorType() { @@ -357,7 +359,7 @@ def Vector_ExtractSlicesOp : }]; let builders = [OpBuilder< "Builder *builder, OperationState &result, TupleType tupleType, " # - "Value *vector, ArrayRef sizes, " # + "ValuePtr vector, ArrayRef sizes, " # "ArrayRef strides">]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { @@ -428,8 +430,8 @@ def Vector_InsertOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source, " # - "Value *dest, ArrayRef">]; + "Builder *builder, OperationState 
&result, ValuePtr source, " # + "ValuePtr dest, ArrayRef">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } Type getSourceType() { return source()->getType(); } @@ -521,7 +523,7 @@ def Vector_InsertStridedSliceOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source, Value *dest, " # + "Builder *builder, OperationState &result, ValuePtr source, ValuePtr dest, " # "ArrayRef offsets, ArrayRef strides">]; let extraClassDeclaration = [{ static StringRef getOffsetsAttrName() { return "offsets"; } @@ -723,7 +725,7 @@ def Vector_StridedSliceOp : vector<4x8x16xf32> to vector<2x4x16xf32> }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source, " # + "Builder *builder, OperationState &result, ValuePtr source, " # "ArrayRef offsets, ArrayRef sizes, " # "ArrayRef strides">]; let extraClassDeclaration = [{ @@ -975,7 +977,7 @@ def Vector_TypeCastOp : }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *source">]; + "Builder *builder, OperationState &result, ValuePtr source">]; let parser = [{ return impl::parseCastOp(parser, result); diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h index 2c2e4e7c4fa..b48cb51533f 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -73,8 +73,9 @@ namespace vector { // // This will be extended in the future to support more advanced use cases than // simple pointwise ops. -Value *unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, - ArrayRef targetShape); +ValuePtr unrollSingleResultOpMatchingType(PatternRewriter &builder, + Operation *op, + ArrayRef targetShape); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 69c72a50870..11ee0bff342 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -152,7 +152,7 @@ private: /// A LoopBuilder is a generic NestedBuilder for loop-like MLIR operations. /// More specifically it is meant to be used as a temporary object for -/// representing any nested MLIR construct that is "related to" an mlir::Value* +/// representing any nested MLIR construct that is "related to" an mlir::Value /// (for now an induction variable). /// This is extensible and will evolve in the future as MLIR evolves, hence /// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder). @@ -242,7 +242,7 @@ class Append {}; /// A BlockBuilder is a NestedBuilder for mlir::Block*. /// This exists by opposition to LoopBuilder which is not related to an -/// mlir::Block* but to a mlir::Value*. +/// mlir::Block* but to a mlir::Value. /// It is meant to be used as a temporary object for representing any nested /// MLIR construct that is "related to" an mlir::Block*. class BlockBuilder : public NestedBuilder { @@ -257,7 +257,7 @@ public: /// /// Prerequisites: /// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are - /// not yet bound to mlir::Value*. + /// not yet bound to mlir::Value. BlockBuilder(BlockHandle *bh, ArrayRef args); /// The only purpose of this operator is to serve as a sequence point so that @@ -291,10 +291,10 @@ protected: /// typed "delayed" value that can be hold a Value in the future; /// 3. constructed state,in which case it holds a Value. 
/// -/// A ValueHandle is meant to capture a single Value* and should be used for +/// A ValueHandle is meant to capture a single Value and should be used for /// operations that have a single result. For convenience of use, we also /// include AffineForOp in this category although it does not return a value. -/// In the case of AffineForOp, the captured Value* is the loop induction +/// In the case of AffineForOp, the captured Value is the loop induction /// variable. class ValueHandle : public CapturableHandle { public: @@ -304,15 +304,15 @@ public: /// A ValueHandle that is constructed from a Type represents a typed "delayed" /// Value. A delayed Value can only capture Values of the specified type. /// Such a delayed value represents the declaration (in the PL sense) of a - /// placeholder for an mlir::Value* that will be constructed and captured at + /// placeholder for an mlir::Value that will be constructed and captured at /// some later point in the program. explicit ValueHandle(Type t) : t(t), v(nullptr) {} - /// A ValueHandle that is constructed from an mlir::Value* is an "eager" + /// A ValueHandle that is constructed from an mlir::Value is an "eager" /// Value. An eager Value represents both the declaration and the definition - /// (in the PL sense) of a placeholder for an mlir::Value* that has already + /// (in the PL sense) of a placeholder for an mlir::Value that has already /// been constructed in the past and that is captured "now" in the program. - explicit ValueHandle(Value *v) : t(v->getType()), v(v) {} + explicit ValueHandle(ValuePtr v) : t(v->getType()), v(v) {} /// Builds a ConstantIndexOp of value `cst`. The constant is created at the /// current insertion point. @@ -336,8 +336,8 @@ public: std::swap(v, other.v); } - /// Implicit conversion useful for automatic conversion to Container. - operator Value *() const { return getValue(); } + /// Implicit conversion useful for automatic conversion to Container. + operator ValuePtr() const { return getValue(); } /// Generic mlir::Op create. This is the key to being extensible to the whole /// of MLIR without duplicating the type system or the op definitions. @@ -355,7 +355,7 @@ public: /// Special case to build composed AffineApply operations. // TODO: createOrFold when available and move inside of the `create` method. static ValueHandle createComposedAffineApply(AffineMap map, - ArrayRef operands); + ArrayRef operands); /// Generic create for a named operation producing a single value. static ValueHandle create(StringRef name, ArrayRef operands, @@ -363,7 +363,7 @@ public: ArrayRef attributes = {}); bool hasValue() const { return v != nullptr; } - Value *getValue() const { + ValuePtr getValue() const { assert(hasValue() && "Unexpected null value;"); return v; } @@ -380,12 +380,12 @@ protected: ValueHandle() : t(), v(nullptr) {} Type t; - Value *v; + ValuePtr v; }; /// An OperationHandle can be used in lieu of ValueHandle to capture the /// operation in cases when one does not care about, or cannot extract, a -/// unique Value* from the operation. +/// unique Value from the operation. /// This can be used for capturing zero result operations as well as /// multi-result operations that are not supported by ValueHandle. 
/// We do not distinguish further between zero and multi-result operations at @@ -529,7 +529,7 @@ ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs); } // namespace op -/// Entry point to build multiple ValueHandle from a `Container` of Value* or +/// Entry point to build multiple ValueHandle from a `Container` of Value or /// Type. template inline SmallVector makeValueHandles(Container values) { diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h index 423c92b2d06..c18307e7121 100644 --- a/mlir/include/mlir/EDSC/Helpers.h +++ b/mlir/include/mlir/EDSC/Helpers.h @@ -75,7 +75,7 @@ protected: // TODO(ntv): Support MemRefs with layoutMaps. class MemRefView : public View { public: - explicit MemRefView(Value *v); + explicit MemRefView(ValuePtr v); MemRefView(const MemRefView &) = default; MemRefView &operator=(const MemRefView &) = default; @@ -91,7 +91,7 @@ private: /// a MemRefView but for vectors. This exists purely for boilerplate avoidance. class VectorView : public View { public: - explicit VectorView(Value *v); + explicit VectorView(ValuePtr v); VectorView(const VectorView &) = default; VectorView &operator=(const VectorView &) = default; @@ -120,7 +120,7 @@ private: template class TemplatedIndexedValue { public: explicit TemplatedIndexedValue(Type t) : base(t) {} - explicit TemplatedIndexedValue(Value *v) + explicit TemplatedIndexedValue(ValuePtr v) : TemplatedIndexedValue(ValueHandle(v)) {} explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} @@ -161,8 +161,8 @@ public: return Load(getBase(), {indices.begin(), indices.end()}); } - /// Emits a `load` when converting to a Value*. - Value *operator*(void)const { + /// Emits a `load` when converting to a Value. + ValuePtr operator*(void) const { return Load(getBase(), {indices.begin(), indices.end()}).getValue(); } diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 06c75505cb7..dc0c1186c7a 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -44,7 +44,7 @@ struct IndexHandle : public ValueHandle { explicit IndexHandle() : ValueHandle(ScopedContext::getBuilder().getIndexType()) {} explicit IndexHandle(index_t v) : ValueHandle(v) {} - explicit IndexHandle(Value *v) : ValueHandle(v) { + explicit IndexHandle(ValuePtr v) : ValueHandle(v) { assert(v->getType() == ScopedContext::getBuilder().getIndexType() && "Expected index type"); } @@ -79,9 +79,9 @@ makeHandlePointers(MutableArrayRef ivs) { return pivs; } -/// Returns a vector of the underlying Value* from `ivs`. -inline SmallVector extractValues(ArrayRef ivs) { - SmallVector vals; +/// Returns a vector of the underlying Value from `ivs`. +inline SmallVector extractValues(ArrayRef ivs) { + SmallVector vals; vals.reserve(ivs.size()); for (auto &iv : ivs) { vals.push_back(iv.getValue()); @@ -96,7 +96,7 @@ namespace intrinsics { namespace detail { /// Helper structure to be used with ValueBuilder / OperationBuilder. /// It serves the purpose of removing boilerplate specialization for the sole -/// purpose of implicitly converting ArrayRef -> ArrayRef. +/// purpose of implicitly converting ArrayRef -> ArrayRef. 
class ValueHandleArray { public: ValueHandleArray(ArrayRef vals) { @@ -109,11 +109,11 @@ public: SmallVector tmp(vals.begin(), vals.end()); values.append(tmp.begin(), tmp.end()); } - operator ArrayRef() { return values; } + operator ArrayRef() { return values; } private: ValueHandleArray() = default; - SmallVector values; + SmallVector values; }; template inline T unpack(T value) { return value; } @@ -128,8 +128,8 @@ inline detail::ValueHandleArray unpack(ArrayRef values) { /// boilerplate or Tablegen. /// Arguably a builder is not a ValueHandle but in practice it is only used as /// an alias to a notional ValueHandle. -/// Implementing it as a subclass allows it to compose all the way to Value*. -/// Without subclassing, implicit conversion to Value* would fail when composing +/// Implementing it as a subclass allows it to compose all the way to Value. +/// Without subclassing, implicit conversion to Value would fail when composing /// in patterns such as: `select(a, b, select(c, d, e))`. template struct ValueBuilder : public ValueHandle { // Builder-based @@ -238,8 +238,8 @@ OperationHandle br(BlockHandle bh, ArrayRef operands); /// /// Prerequisites: /// `b` has not yet captured an mlir::Block*. -/// No `captures` have captured any mlir::Value*. -/// All `operands` have already captured an mlir::Value* +/// No `captures` have captured any mlir::Value. +/// All `operands` have already captured an mlir::Value /// captures.size() == operands.size() /// captures and operands are pairwise of the same type. OperationHandle br(BlockHandle *bh, ArrayRef captures, @@ -266,8 +266,8 @@ OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch, /// /// Prerequisites: /// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*. -/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value*. -/// All `trueOperands`/`trueOperands` have already captured an mlir::Value* +/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value. +/// All `trueOperands`/`trueOperands` have already captured an mlir::Value /// `trueCaptures`.size() == `trueOperands`.size() /// `falseCaptures`.size() == `falseOperands`.size() /// `trueCaptures` and `trueOperands` are pairwise of the same type diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 6c5099b06da..87c77160e1d 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -72,7 +72,7 @@ public: //===--------------------------------------------------------------------===// // This is the list of arguments to the block. - using BlockArgListType = ArrayRef; + using BlockArgListType = ArrayRef; BlockArgListType getArguments() { return arguments; } @@ -86,7 +86,7 @@ public: bool args_empty() { return arguments.empty(); } /// Add one value to the argument list. - BlockArgument *addArgument(Type type); + BlockArgumentPtr addArgument(Type type); /// Add one argument to the argument list for each type specified in the list. iterator_range addArguments(ArrayRef types); @@ -97,7 +97,7 @@ public: void eraseArgument(unsigned index, bool updatePredTerms = true); unsigned getNumArguments() { return arguments.size(); } - BlockArgument *getArgument(unsigned i) { return arguments[i]; } + BlockArgumentPtr getArgument(unsigned i) { return arguments[i]; } //===--------------------------------------------------------------------===// // Operation list management @@ -332,7 +332,7 @@ private: OpListType operations; /// This is the list of arguments to the block. 
- std::vector arguments; + std::vector arguments; Block(Block &) = delete; void operator=(Block &) = delete; diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h index cd15d457a77..287dd508fa6 100644 --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -37,7 +37,7 @@ public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. void map(Block *from, Block *to) { valueMap[from] = to; } - void map(Value *from, Value *to) { valueMap[from] = to; } + void map(ValuePtr from, ValuePtr to) { valueMap[from] = to; } /// Erases a mapping for 'from'. void erase(IRObjectWithUseList *from) { valueMap.erase(from); } @@ -52,8 +52,8 @@ public: Block *lookupOrNull(Block *from) const { return lookupOrValue(from, (Block *)nullptr); } - Value *lookupOrNull(Value *from) const { - return lookupOrValue(from, (Value *)nullptr); + ValuePtr lookupOrNull(ValuePtr from) const { + return lookupOrValue(from, (ValuePtr) nullptr); } /// Lookup a mapped value within the map. If a mapping for the provided value @@ -61,7 +61,7 @@ public: Block *lookupOrDefault(Block *from) const { return lookupOrValue(from, from); } - Value *lookupOrDefault(Value *from) const { + ValuePtr lookupOrDefault(ValuePtr from) const { return lookupOrValue(from, from); } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 766902fabfa..c199c09feb5 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -313,7 +313,7 @@ public: /// and immediately try to fold it. This functions populates 'results' with /// the results after folding the operation. template - void createOrFold(SmallVectorImpl &results, Location location, + void createOrFold(SmallVectorImpl &results, Location location, Args &&... args) { // Create the operation without using 'createOperation' as we don't want to // insert it yet. @@ -331,9 +331,9 @@ public: /// Overload to create or fold a single result operation. template typename std::enable_if(), - Value *>::type + ValuePtr>::type createOrFold(Location location, Args &&... args) { - SmallVector results; + SmallVector results; createOrFold(results, location, std::forward(args)...); return results.front(); } @@ -344,7 +344,7 @@ public: OpTy>::type createOrFold(Location location, Args &&... args) { auto op = create(location, std::forward(args)...); - SmallVector unused; + SmallVector unused; tryFold(op.getOperation(), unused); // Folding cannot remove a zero-result operation, so for convenience we @@ -355,7 +355,7 @@ public: /// Attempts to fold the given operation and places new results within /// 'results'. Returns success if the operation was folded, failure otherwise. /// Note: This function does not erase the operation on a successful fold. - LogicalResult tryFold(Operation *op, SmallVectorImpl &results); + LogicalResult tryFold(Operation *op, SmallVectorImpl &results); /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index b15b056a3ec..1ba85d73df9 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -183,7 +183,7 @@ public: } /// Gets argument. 
- BlockArgument *getArgument(unsigned idx) { + BlockArgumentPtr getArgument(unsigned idx) { return getBlocks().front().getArgument(idx); } diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 1261916dae2..3b36f2fb5eb 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -142,7 +142,7 @@ using has_operation_or_value_matcher_t = /// Statically switch to a Value matcher. template typename std::enable_if_t::value, + MatcherClass, ValuePtr>::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { return matcher.match(op->getOperand(idx)); @@ -161,14 +161,14 @@ matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { /// Terminal matcher, always returns true. struct AnyValueMatcher { - bool match(Value *op) const { return true; } + bool match(ValuePtr op) const { return true; } }; /// Binds to a specific value and matches it. struct PatternMatcherValue { - PatternMatcherValue(Value *val) : value(val) {} - bool match(Value *val) const { return val == value; } - Value *value; + PatternMatcherValue(ValuePtr val) : value(val) {} + bool match(ValuePtr val) const { return val == value; } + ValuePtr value; }; template @@ -235,7 +235,7 @@ inline detail::constant_int_not_value_matcher<0> m_NonZero() { /// Entry point for matching a pattern over a Value. template -inline bool matchPattern(Value *value, const Pattern &pattern) { +inline bool matchPattern(ValuePtr value, const Pattern &pattern) { // TODO: handle other cases if (auto *op = value->getDefiningOp()) return const_cast(pattern).match(op); @@ -262,7 +262,7 @@ auto m_Op(Matchers... matchers) { namespace matchers { inline auto m_Any() { return detail::AnyValueMatcher(); } -inline auto m_Val(Value *v) { return detail::PatternMatcherValue(v); } +inline auto m_Val(ValuePtr v) { return detail::PatternMatcherValue(v); } } // namespace matchers } // end namespace mlir diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c220120b337..437540117c4 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -257,8 +257,8 @@ inline bool operator!=(OpState lhs, OpState rhs) { } /// This class represents a single result from folding an operation. -class OpFoldResult : public PointerUnion { - using PointerUnion::PointerUnion; +class OpFoldResult : public PointerUnion { + using PointerUnion::PointerUnion; }; /// This template defines the foldHook as used by AbstractOperation. @@ -311,8 +311,8 @@ class FoldingHook::type> { public: /// If the operation returns a single value, then the Op can be implicitly - /// converted to an Value*. This yields the value of the only result. - operator Value *() { + /// converted to an Value. This yields the value of the only result. + operator ValuePtr() { return static_cast(this)->getOperation()->getResult(0); } @@ -326,7 +326,7 @@ public: // Check if the operation was folded in place. In this case, the operation // returns itself. - if (result.template dyn_cast() != op->getResult(0)) + if (result.template dyn_cast() != op->getResult(0)) results.push_back(result); return success(); } @@ -428,10 +428,12 @@ struct MultiOperandTraitBase : public TraitBase { unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } /// Return the operand at index 'i'. 
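Since m_Val and matchPattern now traffic in ValuePtr, matcher call sites change only in spelling. A small sketch, assuming the standard-ops header path of this era; the helper name is illustrative:

  #include "mlir/Dialect/StandardOps/Ops.h"
  #include "mlir/IR/Matchers.h"

  // Returns true if `result` is defined by an addf whose first operand is
  // exactly `lhs`; the second operand may be any value.
  static bool isAddOf(mlir::ValuePtr result, mlir::ValuePtr lhs) {
    using namespace mlir::matchers;
    return mlir::matchPattern(result,
                              mlir::m_Op<mlir::AddFOp>(m_Val(lhs), m_Any()));
  }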
- Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); } + ValuePtr getOperand(unsigned i) { + return this->getOperation()->getOperand(i); + } /// Set the operand at index 'i' to 'value'. - void setOperand(unsigned i, Value *value) { + void setOperand(unsigned i, ValuePtr value) { this->getOperation()->setOperand(i, value); } @@ -475,9 +477,11 @@ private: template class OneOperand : public TraitBase { public: - Value *getOperand() { return this->getOperation()->getOperand(0); } + ValuePtr getOperand() { return this->getOperation()->getOperand(0); } - void setOperand(Value *value) { this->getOperation()->setOperand(0, value); } + void setOperand(ValuePtr value) { + this->getOperation()->setOperand(0, value); + } static LogicalResult verifyTrait(Operation *op) { return impl::verifyOneOperand(op); @@ -550,7 +554,7 @@ struct MultiResultTraitBase : public TraitBase { unsigned getNumResults() { return this->getOperation()->getNumResults(); } /// Return the result at index 'i'. - Value *getResult(unsigned i) { return this->getOperation()->getResult(i); } + ValuePtr getResult(unsigned i) { return this->getOperation()->getResult(i); } /// Replace all uses of results of this operation with the provided 'values'. /// 'values' may correspond to an existing operation, or a range of 'Value'. @@ -586,13 +590,13 @@ struct MultiResultTraitBase : public TraitBase { template class OneResult : public TraitBase { public: - Value *getResult() { return this->getOperation()->getResult(0); } + ValuePtr getResult() { return this->getOperation()->getResult(0); } Type getType() { return getResult()->getType(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. - void replaceAllUsesWith(Value *newValue) { + void replaceAllUsesWith(ValuePtr newValue) { getResult()->replaceAllUsesWith(newValue); } @@ -820,10 +824,10 @@ public: return this->getOperation()->setSuccessor(block, index); } - void addSuccessorOperand(unsigned index, Value *value) { + void addSuccessorOperand(unsigned index, ValuePtr value) { return this->getOperation()->addSuccessorOperand(index, value); } - void addSuccessorOperands(unsigned index, ArrayRef values) { + void addSuccessorOperands(unsigned index, ArrayRef values) { return this->getOperation()->addSuccessorOperand(index, values); } }; @@ -1209,8 +1213,8 @@ namespace impl { ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, OperationState &result); -void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs, - Value *rhs); +void buildBinaryOp(Builder *builder, OperationState &result, ValuePtr lhs, + ValuePtr rhs); ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result); @@ -1223,11 +1227,11 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p); // These functions are out-of-line implementations of the methods in CastOp, // which avoids them being template instantiated/duplicated. 
namespace impl { -void buildCastOp(Builder *builder, OperationState &result, Value *source, +void buildCastOp(Builder *builder, OperationState &result, ValuePtr source, Type destType); ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); void printCastOp(Operation *op, OpAsmPrinter &p); -Value *foldCastOp(Operation *op); +ValuePtr foldCastOp(Operation *op); } // namespace impl } // end namespace mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 7dd11d089c2..fcadce9ab16 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -45,7 +45,7 @@ public: virtual raw_ostream &getStream() const = 0; /// Print implementations for various things an operation contains. - virtual void printOperand(Value *value) = 0; + virtual void printOperand(ValuePtr value) = 0; /// Print a comma separated list of operands. template @@ -121,7 +121,7 @@ public: void printFunctionalType(Operation *op) { auto &os = getStream(); os << "("; - interleaveComma(op->getNonSuccessorOperands(), os, [&](Value *operand) { + interleaveComma(op->getNonSuccessorOperands(), os, [&](ValuePtr operand) { if (operand) printType(operand->getType()); else @@ -150,18 +150,18 @@ private: }; // Make the implementations convenient to use. -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ValueRef value) { p.printOperand(&value); return p; } -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ValuePtr value) { return p << *value; } -template ::value && - !std::is_convertible::value, - T>::type * = nullptr> +template ::value && + !std::is_convertible::value, + T>::type * = nullptr> inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { p.printOperands(values); return p; @@ -181,8 +181,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) { // even if it isn't exactly one of them. For example, we want to print // FunctionType with the Type version above, not have it match this. template ::value && - !std::is_convertible::value && + !std::is_convertible::value && + !std::is_convertible::value && !std::is_convertible::value && !std::is_convertible::value && !std::is_convertible::value && @@ -467,13 +467,13 @@ public: /// Resolve an operand to an SSA value, emitting an error on failure. virtual ParseResult resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) = 0; + SmallVectorImpl &result) = 0; /// Resolve a list of operands to SSA values, emitting an error on failure, or /// appending the results to the list on success. This method should be used /// when all operands have the same type. ParseResult resolveOperands(ArrayRef operands, Type type, - SmallVectorImpl &result) { + SmallVectorImpl &result) { for (auto elt : operands) if (resolveOperand(elt, type, result)) return failure(); @@ -485,7 +485,7 @@ public: /// to the list on success. ParseResult resolveOperands(ArrayRef operands, ArrayRef types, llvm::SMLoc loc, - SmallVectorImpl &result) { + SmallVectorImpl &result) { if (operands.size() != types.size()) return emitError(loc) << operands.size() << " operands present, but expected " @@ -556,7 +556,7 @@ public: /// Parse a single operation successor and its operand list. 
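resolveOperand/resolveOperands now populate a SmallVectorImpl of ValuePtr, matching OperationState::operands. A sketch of a custom parser against the new signature; the op name and its assembly form are made up for illustration:

  #include "mlir/IR/OpImplementation.h"

  // Parses "my.binop %a, %b : type" into two operands of the same type and a
  // single result of that type.
  static mlir::ParseResult parseMyBinOp(mlir::OpAsmParser &parser,
                                        mlir::OperationState &result) {
    llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
    mlir::Type type;
    if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
        parser.parseColonType(type) ||
        parser.resolveOperands(operands, type, result.operands))
      return mlir::failure();
    result.addTypes(type);
    return mlir::success();
  }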
virtual ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) = 0; + SmallVectorImpl &operands) = 0; //===--------------------------------------------------------------------===// // Type Parsing @@ -634,7 +634,7 @@ private: /// A functor used to set the name of the start of a result group of an /// operation. See 'getAsmResultNames' below for more details. -using OpAsmSetValueNameFn = function_ref; +using OpAsmSetValueNameFn = function_ref; class OpAsmDialectInterface : public DialectInterface::Base { diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 2159d10fd2a..ad0dc600f8f 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -44,7 +44,7 @@ public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, ArrayRef attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList); @@ -53,7 +53,7 @@ public: /// unnecessarily uniquing a list of attributes. static Operation *create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList); @@ -64,7 +64,7 @@ public: /// Create a new Operation with the specific fields. static Operation * create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, NamedAttributeList attributes, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors = {}, RegionRange regions = {}, bool resizableOperandList = false); @@ -149,7 +149,7 @@ public: } /// Replace any uses of 'from' with 'to' within this operation. - void replaceUsesOfWith(Value *from, Value *to); + void replaceUsesOfWith(ValuePtr from, ValuePtr to); /// Replace all uses of results of this operation with the provided 'values'. template > decomposeSuccessorOperandIndex(unsigned operandIndex); - /// Returns the `BlockArgument*` corresponding to operand `operandIndex` in + /// Returns the `BlockArgument` corresponding to operand `operandIndex` in /// some successor, or None if `operandIndex` isn't a successor operand index. - Optional getSuccessorBlockArgument(unsigned operandIndex) { + Optional getSuccessorBlockArgument(unsigned operandIndex) { auto decomposed = decomposeSuccessorOperandIndex(operandIndex); if (!decomposed.hasValue()) return None; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 23ef0ce5937..b7f63218ba5 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -270,7 +270,7 @@ inline llvm::hash_code hash_value(OperationName arg) { struct OperationState { Location location; OperationName name; - SmallVector operands; + SmallVector operands; /// Types of the results of this operation. SmallVector types; SmallVector attributes; @@ -534,8 +534,8 @@ private: /// This class implements iteration on the types of a given range of values. template class ValueTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value *value) { return value->getType(); } + : public llvm::mapped_iterator { + static Type unwrap(ValuePtr value) { return value->getType(); } public: using reference = Type; @@ -545,7 +545,8 @@ public: /// Initializes the type iterator to the specified value iterator. 
ValueTypeIterator(ValueIteratorT it) - : llvm::mapped_iterator(it, &unwrap) {} + : llvm::mapped_iterator(it, &unwrap) { + } }; //===----------------------------------------------------------------------===// @@ -554,7 +555,7 @@ public: /// This class implements the operand iterators for the Operation class. class OperandRange final : public detail::indexed_accessor_range_base { + ValuePtr, ValuePtr, ValuePtr> { public: using RangeBaseT::RangeBaseT; OperandRange(Operation *op); @@ -569,7 +570,7 @@ private: return object + index; } /// See `detail::indexed_accessor_range_base` for details. - static Value *dereference_iterator(OpOperand *object, ptrdiff_t index) { + static ValuePtr dereference_iterator(OpOperand *object, ptrdiff_t index) { return object[index].get(); } @@ -582,8 +583,8 @@ private: /// This class implements the result iterators for the Operation class. class ResultRange final - : public detail::indexed_accessor_range_base { + : public detail::indexed_accessor_range_base { public: using RangeBaseT::RangeBaseT; ResultRange(Operation *op); @@ -594,11 +595,11 @@ public: private: /// See `detail::indexed_accessor_range_base` for details. - static OpResult *offset_base(OpResult *object, ptrdiff_t index) { + static OpResultPtr offset_base(OpResultPtr object, ptrdiff_t index) { return object + index; } /// See `detail::indexed_accessor_range_base` for details. - static Value *dereference_iterator(OpResult *object, ptrdiff_t index) { + static ValuePtr dereference_iterator(OpResultPtr object, ptrdiff_t index) { return &object[index]; } @@ -610,31 +611,31 @@ private: // ValueRange /// This class provides an abstraction over the different types of ranges over -/// Value*s. In many cases, this prevents the need to explicitly materialize a +/// Values. In many cases, this prevents the need to explicitly materialize a /// SmallVector/std::vector. This class should be used in places that are not /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. class ValueRange final : public detail::indexed_accessor_range_base< - ValueRange, PointerUnion, - Value *, Value *, Value *> { + ValueRange, PointerUnion, + ValuePtr, ValuePtr, ValuePtr> { public: using RangeBaseT::RangeBaseT; template , Arg>::value && - !std::is_convertible::value>> + std::is_constructible, Arg>::value && + !std::is_convertible::value>> ValueRange(Arg &&arg) - : ValueRange(ArrayRef(std::forward(arg))) {} - ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {} - ValueRange(const std::initializer_list &values) - : ValueRange(ArrayRef(values)) {} + : ValueRange(ArrayRef(std::forward(arg))) {} + ValueRange(ValuePtr const &value) : ValueRange(&value, /*count=*/1) {} + ValueRange(const std::initializer_list &values) + : ValueRange(ArrayRef(values)) {} ValueRange(iterator_range values) : ValueRange(OperandRange(values)) {} ValueRange(iterator_range values) : ValueRange(ResultRange(values)) {} - ValueRange(ArrayRef values = llvm::None); + ValueRange(ArrayRef values = llvm::None); ValueRange(OperandRange values); ValueRange(ResultRange values); @@ -645,12 +646,12 @@ public: private: /// The type representing the owner of this range. This is either a list of /// values, operands, or results. - using OwnerT = PointerUnion; + using OwnerT = PointerUnion; /// See `detail::indexed_accessor_range_base` for details. static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); /// See `detail::indexed_accessor_range_base` for details. 
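ValueRange, as documented above, lets helpers accept operand ranges, result ranges, or plain arrays of ValuePtr without materializing a vector. A minimal sketch of such a helper:

  #include "mlir/IR/OperationSupport.h"

  // Returns true if every value in `values` has the same type; callable with
  // op->getOperands(), op->getResults(), or a braced list of ValuePtr.
  static bool allSameType(mlir::ValueRange values) {
    if (values.empty())
      return true;
    mlir::Type type = (*values.begin())->getType();
    for (mlir::ValuePtr value : values)
      if (value->getType() != type)
        return false;
    return true;
  }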
- static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index); + static ValuePtr dereference_iterator(const OwnerT &owner, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. friend RangeBaseT; diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index 2cce4dbb6cf..af22f9c4a9f 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -41,8 +41,8 @@ Type getElementTypeOrSelf(Type type); /// Return the element type or return the type itself. Type getElementTypeOrSelf(Attribute attr); -Type getElementTypeOrSelf(Value *val); -Type getElementTypeOrSelf(Value &val); +Type getElementTypeOrSelf(ValuePtr val); +Type getElementTypeOrSelf(ValueRef val); /// Get the types within a nested Tuple. A helper for the class method that /// handles storage concerns, which is tricky to do in tablegen. @@ -72,7 +72,7 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2); // An iterator for the element types of an op's operands of shaped types. class OperandElementTypeIterator final : public llvm::mapped_iterator { + Type (*)(ValuePtr)> { public: using reference = Type; @@ -81,7 +81,7 @@ public: explicit OperandElementTypeIterator(Operation::operand_iterator it); private: - static Type unwrap(Value *value); + static Type unwrap(ValuePtr value); }; using OperandElementTypeRange = iterator_range; @@ -89,7 +89,7 @@ using OperandElementTypeRange = iterator_range; // An iterator for the tensor element types of an op's results of shaped types. class ResultElementTypeIterator final : public llvm::mapped_iterator { + Type (*)(ValuePtr)> { public: using reference = Type; @@ -98,7 +98,7 @@ public: explicit ResultElementTypeIterator(Operation::result_iterator it); private: - static Type unwrap(Value *value); + static Type unwrap(ValuePtr value); }; using ResultElementTypeRange = iterator_range; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 34c74c888cb..11cb8cdcbc7 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -28,10 +28,18 @@ namespace mlir { class Block; +class BlockArgument; class Operation; +class OpResult; class Region; class Value; +/// Using directives that simplify the transition of Value to being value typed. +using BlockArgumentPtr = BlockArgument *; +using OpResultPtr = OpResult *; +using ValueRef = Value &; +using ValuePtr = Value *; + /// Operands contain a Value. using OpOperand = IROperandImpl; @@ -48,6 +56,15 @@ public: ~Value() {} + template bool isa() const { return U::classof(this); } + template U *dyn_cast() const { + return isa() ? (U *)this : nullptr; + } + template U *cast() const { + assert(isa()); + return (U *)this; + } + Kind getKind() const { return typeAndKind.getInt(); } Type getType() const { return typeAndKind.getPointer(); } @@ -66,7 +83,7 @@ public: /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. - void replaceAllUsesWith(Value *newValue) { + void replaceAllUsesWith(ValuePtr newValue) { IRObjectWithUseList::replaceAllUsesWith(newValue); } @@ -100,7 +117,7 @@ private: llvm::PointerIntPair typeAndKind; }; -inline raw_ostream &operator<<(raw_ostream &os, Value &value) { +inline raw_ostream &operator<<(raw_ostream &os, ValueRef value) { value.print(os); return os; } @@ -160,7 +177,6 @@ private: /// through bitpacking shenanigans. 
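The new isa/dyn_cast/cast members on Value allow kind queries without inspecting Value::Kind directly. A short sketch; the helper name is illustrative:

  #include "mlir/IR/Value.h"

  // Returns true if `value` is a block argument owned by `block`.
  static bool isArgumentOf(mlir::ValuePtr value, mlir::Block *block) {
    if (auto *arg = value->dyn_cast<mlir::BlockArgument>())
      return arg->getOwner() == block;
    return false;
  }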
Operation *const owner; }; - } // namespace mlir #endif diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h index 070b3c36e8c..202e86566fc 100644 --- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h @@ -163,7 +163,7 @@ public: } virtual Operation *getOp() const = 0; - virtual Value *getValue() const = 0; + virtual ValuePtr getValue() const = 0; static bool classof(const CAGNode *n) { return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor; @@ -210,7 +210,7 @@ public: return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor; } - Value *getValue() const final { return op->getOperand(operandIdx); } + ValuePtr getValue() const final { return op->getOperand(operandIdx); } void printLabel(raw_ostream &os) const override; @@ -221,7 +221,7 @@ private: /// An anchor tied to a specific result. /// Since a result is already anchored to its defining op, result anchors refer -/// directly to the underlying Value*. +/// directly to the underlying Value. class CAGResultAnchor : public CAGAnchorNode { public: CAGResultAnchor(Operation *op, unsigned resultIdx); @@ -231,12 +231,12 @@ public: } Operation *getOp() const final { return resultValue->getDefiningOp(); } - Value *getValue() const final { return resultValue; } + ValuePtr getValue() const final { return resultValue; } void printLabel(raw_ostream &os) const override; private: - Value *resultValue; + ValuePtr resultValue; }; /// Base class for constraint nodes. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 7adb4aac2e2..7464e2a347d 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -113,7 +113,7 @@ private: protected: // Mappings between original and translated values, used for lookups. llvm::StringMap functionMapping; - DenseMap valueMapping; + DenseMap valueMapping; DenseMap blockMapping; }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 814f2202f01..f9f1207c0a0 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -60,7 +60,7 @@ public: /// remaps an existing signature input. struct InputMapping { size_t inputNo, size; - Value *replacementValue; + ValuePtr replacementValue; }; /// Return the argument types for the new signature. @@ -90,7 +90,7 @@ public: /// Remap an input of the original signature to another `replacement` /// value. This drops the original argument. - void remapInput(unsigned origInputNo, Value *replacement); + void remapInput(unsigned origInputNo, ValuePtr replacement); private: /// The remapping information for each of the original arguments. @@ -143,7 +143,7 @@ public: /// the conversion has finished. virtual Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc) { llvm_unreachable("expected 'materializeConversion' to be overridden"); } @@ -172,7 +172,7 @@ public: /// ConversionPattern ever needs to replace an operation that does not /// have successors. This function should not fail. If some specific cases of /// the operation are not supported, these cases should not be matched. 
- virtual void rewrite(Operation *op, ArrayRef operands, + virtual void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } @@ -187,18 +187,18 @@ public: /// terminator operation that has successors. This function should not fail /// the pass. If some specific cases of the operation are not supported, /// these cases should not be matched. - virtual void rewrite(Operation *op, ArrayRef properOperands, + virtual void rewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite for terminators"); } /// Hook for derived classes to implement combined matching and rewriting. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -208,7 +208,7 @@ public: /// Hook for derived classes to implement combined matching and rewriting. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -234,27 +234,27 @@ struct OpConversionPattern : public ConversionPattern { /// Wrappers around the ConversionPattern methods that pass the derived op /// type. - void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), operands, rewriter); } - void rewrite(Operation *op, ArrayRef properOperands, + void rewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), properOperands, destinations, operands, rewriter); } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), properOperands, destinations, operands, rewriter); } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); } @@ -264,22 +264,22 @@ struct OpConversionPattern : public ConversionPattern { /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. 
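The OpConversionPattern hooks above receive the already-remapped operands as ArrayRef of ValuePtr rather than reading them off the op. A sketch of a pattern written against the new signature; IdentityOp is a hypothetical single-operand, single-result op used purely for illustration:

  #include "mlir/Transforms/DialectConversion.h"

  // Folds a hypothetical identity op away by forwarding its (already
  // converted) operand to all users.
  struct IdentityOpConversion : public mlir::OpConversionPattern<IdentityOp> {
    using mlir::OpConversionPattern<IdentityOp>::OpConversionPattern;

    mlir::PatternMatchResult
    matchAndRewrite(IdentityOp op, llvm::ArrayRef<mlir::ValuePtr> operands,
                    mlir::ConversionPatternRewriter &rewriter) const override {
      rewriter.replaceOp(op.getOperation(), {operands.front()});
      return matchSuccess();
    }
  };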
- virtual void rewrite(SourceOp op, ArrayRef operands, + virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual void rewrite(SourceOp op, ArrayRef properOperands, + virtual void rewrite(SourceOp op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite for terminators"); } virtual PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef properOperands, + matchAndRewrite(SourceOp op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -288,7 +288,7 @@ struct OpConversionPattern : public ConversionPattern { } virtual PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -330,11 +330,11 @@ public: TypeConverter::SignatureConversion &conversion); /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument *from, Value *to); + void replaceUsesOfBlockArgument(BlockArgumentPtr from, ValuePtr to); /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. - Value *getRemappedValue(Value *key); + ValuePtr getRemappedValue(ValuePtr key); //===--------------------------------------------------------------------===// // PatternRewriter Hooks diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h index bdf88d3bfb2..65dd1b6df16 100644 --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -82,7 +82,7 @@ public: /// and immediately try to fold it. This function populates 'results' with /// the results after folding the operation. template - void create(OpBuilder &builder, SmallVectorImpl &results, + void create(OpBuilder &builder, SmallVectorImpl &results, Location location, Args &&... args) { Operation *op = builder.create(location, std::forward(args)...); if (failed(tryToFold(op, results))) @@ -94,9 +94,9 @@ public: /// Overload to create or fold a single result operation. template typename std::enable_if(), - Value *>::type + ValuePtr>::type create(OpBuilder &builder, Location location, Args &&... args) { - SmallVector results; + SmallVector results; create(builder, results, location, std::forward(args)...); return results.front(); } @@ -107,7 +107,7 @@ public: OpTy>::type create(OpBuilder &builder, Location location, Args &&... args) { auto op = builder.create(location, std::forward(args)...); - SmallVector unused; + SmallVector unused; (void)tryToFold(op.getOperation(), unused); // Folding cannot remove a zero-result operation, so for convenience we @@ -126,7 +126,7 @@ private: /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult tryToFold( - Operation *op, SmallVectorImpl &results, + Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants = nullptr); /// Try to get or create a new constant entry. 
On success this returns the diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index 590b46a5d12..47c4f48f468 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -105,7 +105,7 @@ public: /// operation). The given 'op' will be removed by the caller, after this /// function has been called. virtual void handleTerminator(Operation *op, - ArrayRef valuesToReplace) const { + ArrayRef valuesToReplace) const { llvm_unreachable( "must implement handleTerminator in the case of one inlined block"); } @@ -125,8 +125,8 @@ public: /// ... = foo.call @foo(%input : i32) -> i16 /// /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input, - Type resultType, + virtual Operation *materializeCallConversion(OpBuilder &builder, + ValuePtr input, Type resultType, Location conversionLoc) const { return nullptr; } @@ -165,7 +165,7 @@ public: virtual void handleTerminator(Operation *op, Block *newDest) const; virtual void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const; + ArrayRef valuesToRepl) const; }; //===----------------------------------------------------------------------===// @@ -187,7 +187,7 @@ public: /// be cloned into the 'inlinePoint' or spliced directly. LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, - ArrayRef resultsToReplace, + ArrayRef resultsToReplace, Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); @@ -196,8 +196,8 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, /// in-favor of the region arguments when inlining. LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, - ArrayRef inlinedOperands, - ArrayRef resultsToReplace, + ArrayRef inlinedOperands, + ArrayRef resultsToReplace, Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); diff --git a/mlir/include/mlir/Transforms/LoopLikeInterface.td b/mlir/include/mlir/Transforms/LoopLikeInterface.td index 5c324b79f67..583cfe26d87 100644 --- a/mlir/include/mlir/Transforms/LoopLikeInterface.td +++ b/mlir/include/mlir/Transforms/LoopLikeInterface.td @@ -38,7 +38,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { explicit capture of dependencies, an implementation could check whether the value corresponds to a captured dependency. }], - "bool", "isDefinedOutsideOfLoop", (ins "Value *":$value) + "bool", "isDefinedOutsideOfLoop", (ins "ValuePtr ":$value) >, InterfaceMethod<[{ Returns the region that makes up the body of the loop and should be diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 5ca3f7f6510..37434ea2ea8 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -85,7 +85,7 @@ void promoteSingleIterationLoops(FuncOp f); /// expression. 
void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, - SmallVectorImpl *operands, + SmallVectorImpl *operands, OpBuilder &builder); /// Skew the operations in the body of a 'affine.for' operation with the @@ -140,7 +140,7 @@ SmallVector, 8> tile(ArrayRef forOps, ArrayRef sizes, ArrayRef targets); SmallVector tile(ArrayRef forOps, - ArrayRef sizes, + ArrayRef sizes, ArrayRef targets); /// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes` @@ -149,7 +149,7 @@ SmallVector tile(ArrayRef forOps, /// `target`. SmallVector tile(ArrayRef forOps, ArrayRef sizes, AffineForOp target); -Loops tile(ArrayRef forOps, ArrayRef sizes, +Loops tile(ArrayRef forOps, ArrayRef sizes, loop::ForOp target); /// Tile a nest of loop::ForOp loops rooted at `rootForOp` with the given @@ -157,7 +157,7 @@ Loops tile(ArrayRef forOps, ArrayRef sizes, /// runtime. If more sizes than loops are provided, discard the trailing values /// in sizes. Assumes the loop nest is permutable. /// Returns the newly created intra-tile loops. -Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes); +Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes); /// Explicit copy / DMA generation options for mlir::affineDataCopyGenerate. struct AffineCopyOptions { @@ -229,8 +229,8 @@ void coalesceLoops(MutableArrayRef loops); /// ... /// } /// ``` -void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef processorId, - ArrayRef numProcessors); +void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef processorId, + ArrayRef numProcessors); } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_UTILS_H diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 48080b26c2c..63236d6a5a0 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -30,14 +30,14 @@ namespace mlir { /// of `limit`. template bool areValuesDefinedAbove(Range values, Region &limit) { - for (Value *v : values) + for (ValuePtr v : values) if (!v->getParentRegion()->isProperAncestor(&limit)) return false; return true; } /// Replace all uses of `orig` within the given region with `replacement`. -void replaceAllUsesInRegionWith(Value *orig, Value *replacement, +void replaceAllUsesInRegionWith(ValuePtr orig, ValuePtr replacement, Region ®ion); /// Calls `callback` for each use of a value within `region` or its descendants @@ -53,12 +53,12 @@ void visitUsedValuesDefinedAbove(MutableArrayRef regions, /// Fill `values` with a list of values defined at the ancestors of the `limit` /// region and used within `region` or its descendants. void getUsedValuesDefinedAbove(Region ®ion, Region &limit, - llvm::SetVector &values); + llvm::SetVector &values); /// Fill `values` with a list of values used within any of the regions provided /// but defined in one of the ancestors. void getUsedValuesDefinedAbove(MutableArrayRef regions, - llvm::SetVector &values); + llvm::SetVector &values); /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index c682b48f331..02c368ec496 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -66,22 +66,22 @@ class OpBuilder; // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). 
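getUsedValuesDefinedAbove and replaceAllUsesInRegionWith keep their shapes, only spelled with ValuePtr. A small sketch combining the two; the helper name is illustrative:

  #include "mlir/Transforms/RegionUtils.h"
  #include "llvm/ADT/SetVector.h"

  // Replace uses of `captured` inside `region` with `replacement`, but only if
  // `captured` is defined above the region and actually used within it.
  static void replaceCapturedValue(mlir::Region &region, mlir::ValuePtr captured,
                                   mlir::ValuePtr replacement) {
    llvm::SetVector<mlir::ValuePtr> usedAbove;
    mlir::getUsedValuesDefinedAbove(region, region, usedAbove);
    if (usedAbove.count(captured))
      mlir::replaceAllUsesInRegionWith(captured, replacement, region);
  }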
// TODO(bondhugula): allow extraIndices to be added at any position. -LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, - ArrayRef extraIndices = {}, +LogicalResult replaceAllMemRefUsesWith(ValuePtr oldMemRef, ValuePtr newMemRef, + ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}, + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, Operation *postDomInstFilter = nullptr); /// Performs the same replacement as the other version above but only for the /// dereferencing uses of `oldMemRef` in `op`. -LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, +LogicalResult replaceAllMemRefUsesWith(ValuePtr oldMemRef, ValuePtr newMemRef, Operation *op, - ArrayRef extraIndices = {}, + ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}); + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}); /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. Returns failure if any of its uses @@ -96,9 +96,9 @@ LogicalResult normalizeMemRef(AllocOp op); /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc, - ArrayRef operands, + ArrayRef operands, ArrayRef affineApplyOps, - SmallVectorImpl *results); + SmallVectorImpl *results); /// Given an operation, inserts one or more single result affine apply /// operations, results of which are exclusively used by this operation. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 97868a56524..60b2f17292b 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -48,15 +48,15 @@ using llvm::dbgs; // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( - ArrayRef operands, SmallVectorImpl &affineApplyOps) { + ArrayRef operands, SmallVectorImpl &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. - Value *value; + ValuePtr value; // The operand index of 'value' to explore next during DFS traversal. unsigned operandIndex; }; SmallVector worklist; - for (auto *operand : operands) { + for (auto operand : operands) { worklist.push_back({operand, 0}); } @@ -77,7 +77,7 @@ void mlir::getReachableAffineApplyOps( if (state.operandIndex < opInst->getNumOperands()) { // Visit: Add next 'affineApplyOp' operand to worklist. // Get next operand to visit at 'operandIndex'. - auto *nextOperand = opInst->getOperand(state.operandIndex); + auto nextOperand = opInst->getOperand(state.operandIndex); // Increment 'operandIndex' in 'state'. ++state.operandIndex; // Add 'nextOperand' to worklist. @@ -99,7 +99,7 @@ void mlir::getReachableAffineApplyOps( // setExprStride(ArrayRef expr, int64_t stride) LogicalResult mlir::getIndexSet(MutableArrayRef forOps, FlatAffineConstraints *domain) { - SmallVector indices; + SmallVector indices; extractForInductionVars(forOps, &indices); // Reset while associated Values in 'indices' to the domain. domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); @@ -146,25 +146,25 @@ static LogicalResult getInstIndexSet(Operation *op, // of maps to check. 
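With the index and operand lists now ArrayRef of ValuePtr, the common case of replaceAllMemRefUsesWith still needs only the two memrefs. A minimal sketch, assuming the old and new memrefs share shape and layout so the defaulted arguments suffice:

  #include "mlir/Transforms/Utils.h"

  // Redirect all dereferencing uses of `oldMemRef` to `newMemRef`; fails if
  // some use cannot be rewritten.
  static mlir::LogicalResult swapMemRef(mlir::ValuePtr oldMemRef,
                                        mlir::ValuePtr newMemRef) {
    return mlir::replaceAllMemRefUsesWith(oldMemRef, newMemRef);
  }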
So getSrcDimOrSymPos would be "getPos(value, {0, 2})". class ValuePositionMap { public: - void addSrcValue(Value *value) { + void addSrcValue(ValuePtr value) { if (addValueAt(value, &srcDimPosMap, numSrcDims)) ++numSrcDims; } - void addDstValue(Value *value) { + void addDstValue(ValuePtr value) { if (addValueAt(value, &dstDimPosMap, numDstDims)) ++numDstDims; } - void addSymbolValue(Value *value) { + void addSymbolValue(ValuePtr value) { if (addValueAt(value, &symbolPosMap, numSymbols)) ++numSymbols; } - unsigned getSrcDimOrSymPos(Value *value) const { + unsigned getSrcDimOrSymPos(ValuePtr value) const { return getDimOrSymPos(value, srcDimPosMap, 0); } - unsigned getDstDimOrSymPos(Value *value) const { + unsigned getDstDimOrSymPos(ValuePtr value) const { return getDimOrSymPos(value, dstDimPosMap, numSrcDims); } - unsigned getSymPos(Value *value) const { + unsigned getSymPos(ValuePtr value) const { auto it = symbolPosMap.find(value); assert(it != symbolPosMap.end()); return numSrcDims + numDstDims + it->second; @@ -176,7 +176,7 @@ public: unsigned getNumSymbols() const { return numSymbols; } private: - bool addValueAt(Value *value, DenseMap *posMap, + bool addValueAt(ValuePtr value, DenseMap *posMap, unsigned position) { auto it = posMap->find(value); if (it == posMap->end()) { @@ -185,8 +185,8 @@ private: } return false; } - unsigned getDimOrSymPos(Value *value, - const DenseMap &dimPosMap, + unsigned getDimOrSymPos(ValuePtr value, + const DenseMap &dimPosMap, unsigned dimPosOffset) const { auto it = dimPosMap.find(value); if (it != dimPosMap.end()) { @@ -200,9 +200,9 @@ private: unsigned numSrcDims = 0; unsigned numDstDims = 0; unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; }; // Builds a map from Value to identifier position in a new merged identifier @@ -219,9 +219,9 @@ static void buildDimAndSymbolPositionMaps( const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap, FlatAffineConstraints *dependenceConstraints) { - auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { + auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = values.size(); i < e; ++i) { - auto *value = values[i]; + auto value = values[i]; if (!isForInductionVar(values[i])) { assert(isValidSymbol(values[i]) && "access operand has to be either a loop IV or a symbol"); @@ -234,7 +234,7 @@ static void buildDimAndSymbolPositionMaps( } }; - SmallVector srcValues, destValues; + SmallVector srcValues, destValues; srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues); dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues); // Update value position map with identifiers from src iteration domain. @@ -273,7 +273,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, numLocals); // Set values corresponding to dependence constraint identifiers. - SmallVector srcLoopIVs, dstLoopIVs; + SmallVector srcLoopIVs, dstLoopIVs; srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); @@ -282,8 +282,8 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); // Set values for the symbolic identifier dimensions. 
- auto setSymbolIds = [&](ArrayRef values) { - for (auto *value : values) { + auto setSymbolIds = [&](ArrayRef values) { + for (auto value : values) { if (!isForInductionVar(value)) { assert(isValidSymbol(value) && "expected symbol"); dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value); @@ -294,7 +294,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, setSymbolIds(srcAccessMap.getOperands()); setSymbolIds(dstAccessMap.getOperands()); - SmallVector srcSymbolValues, dstSymbolValues; + SmallVector srcSymbolValues, dstSymbolValues; srcDomain.getIdValues(srcDomain.getNumDimIds(), srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); dstDomain.getIdValues(dstDomain.getNumDimIds(), @@ -398,10 +398,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, unsigned numResults = srcMap.getNumResults(); unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); - ArrayRef srcOperands = srcAccessMap.getOperands(); + ArrayRef srcOperands = srcAccessMap.getOperands(); unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); - ArrayRef dstOperands = dstAccessMap.getOperands(); + ArrayRef dstOperands = dstAccessMap.getOperands(); std::vector> srcFlatExprs; std::vector> destFlatExprs; @@ -457,11 +457,11 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { + auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (isForInductionVar(operands[i])) continue; - auto *symbol = operands[i]; + auto symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto cOp = dyn_cast_or_null(symbol->getDefiningOp())) @@ -553,7 +553,7 @@ static Block *getCommonBlock(const MemRefAccess &srcAccess, } return block; } - auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); + auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1); auto forOp = getForInductionVarOwner(commonForValue); assert(forOp && "commonForValue was not an induction variable"); return forOp.getBody(); @@ -675,7 +675,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { map = loadOp.getAffineMap(); else if (auto storeOp = dyn_cast(opInst)) map = storeOp.getAffineMap(); - SmallVector operands(indices.begin(), indices.end()); + SmallVector operands(indices.begin(), indices.end()); fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); canonicalizeMapAndOperands(&map, &operands); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index d678355880e..21c2830c016 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -204,8 +204,8 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, // AffineValueMap. 
//===----------------------------------------------------------------------===// -AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results) +AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results) : map(map), operands(operands.begin(), operands.end()), results(results.begin(), results.end()) {} @@ -219,8 +219,8 @@ AffineValueMap::AffineValueMap(AffineBound bound) : map(bound.getMap()), operands(bound.operand_begin(), bound.operand_end()) {} -void AffineValueMap::reset(AffineMap map, ArrayRef operands, - ArrayRef results) { +void AffineValueMap::reset(AffineMap map, ArrayRef operands, + ArrayRef results) { this->map.reset(map); this->operands.assign(operands.begin(), operands.end()); this->results.assign(results.begin(), results.end()); @@ -232,14 +232,14 @@ void AffineValueMap::difference(const AffineValueMap &a, // Fully compose A's map + operands. auto aMap = a.getAffineMap(); - SmallVector aOperands(a.getOperands().begin(), - a.getOperands().end()); + SmallVector aOperands(a.getOperands().begin(), + a.getOperands().end()); fullyComposeAffineMapAndOperands(&aMap, &aOperands); // Use the affine apply normalizer to get B's map into A's coordinate space. AffineApplyNormalizer normalizer(aMap, aOperands); - SmallVector bOperands(b.getOperands().begin(), - b.getOperands().end()); + SmallVector bOperands(b.getOperands().begin(), + b.getOperands().end()); auto bMap = b.getAffineMap(); normalizer.normalize(&bMap, &bOperands); @@ -263,7 +263,7 @@ void AffineValueMap::difference(const AffineValueMap &a, // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. -static bool findIndex(Value *valueToMatch, ArrayRef valuesToSearch, +static bool findIndex(ValuePtr valueToMatch, ArrayRef valuesToSearch, unsigned indexStart, unsigned *indexOfMatch) { unsigned size = valuesToSearch.size(); for (unsigned i = indexStart; i < size; ++i) { @@ -281,7 +281,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { /// This method uses the invariant that operands are always positionally aligned /// with the AffineDimExpr in the underlying AffineMap. 
-bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { +bool AffineValueMap::isFunctionOf(unsigned idx, ValuePtr value) const { unsigned index; if (!findIndex(value, operands, /*indexStart=*/0, &index)) { return false; @@ -292,12 +292,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { return expr.isFunctionOfDim(index); } -Value *AffineValueMap::getOperand(unsigned i) const { - return static_cast(operands[i]); +ValuePtr AffineValueMap::getOperand(unsigned i) const { + return static_cast(operands[i]); } -ArrayRef AffineValueMap::getOperands() const { - return ArrayRef(operands); +ArrayRef AffineValueMap::getOperands() const { + return ArrayRef(operands); } AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } @@ -378,7 +378,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && "minimum 1 column"); numReservedCols = newNumReservedCols; @@ -401,7 +401,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, newNumSymbols, newNumLocals, idArgs); } @@ -428,17 +428,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) { addId(IdKind::Local, pos); } -void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { +void FlatAffineConstraints::addDimId(unsigned pos, ValuePtr id) { addId(IdKind::Dimension, pos, id); } -void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { +void FlatAffineConstraints::addSymbolId(unsigned pos, ValuePtr id) { addId(IdKind::Symbol, pos, id); } /// Adds a dimensional identifier. The added column is initialized to /// zero. -void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, ValuePtr id) { if (kind == IdKind::Dimension) { assert(pos <= getNumDimIds()); } else if (kind == IdKind::Symbol) { @@ -527,7 +527,7 @@ bool FlatAffineConstraints::areIdsAlignedWithOther( /// Checks if the SSA values associated with `cst''s identifiers are unique. static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(const FlatAffineConstraints &cst) { - SmallPtrSet uniqueIds; + SmallPtrSet uniqueIds; for (auto id : cst.getIds()) { if (id.hasValue() && !uniqueIds.insert(id.getValue()).second) return false; @@ -571,11 +571,11 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, assert(std::all_of(A->getIds().begin() + offset, A->getIds().begin() + A->getNumDimAndSymbolIds(), - [](Optional id) { return id.hasValue(); })); + [](Optional id) { return id.hasValue(); })); assert(std::all_of(B->getIds().begin() + offset, B->getIds().begin() + B->getNumDimAndSymbolIds(), - [](Optional id) { return id.hasValue(); })); + [](Optional id) { return id.hasValue(); })); // Place local id's of A after local id's of B. 
for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) { @@ -586,13 +586,13 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, A->addLocalId(A->getNumLocalIds()); } - SmallVector aDimValues, aSymValues; + SmallVector aDimValues, aSymValues; A->getIdValues(offset, A->getNumDimIds(), &aDimValues); A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues); { // Merge dims from A into B. unsigned d = offset; - for (auto *aDimValue : aDimValues) { + for (auto aDimValue : aDimValues) { unsigned loc; if (B->findId(*aDimValue, &loc)) { assert(loc >= offset && "A's dim appears in B's aligned range"); @@ -615,7 +615,7 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, { // Merge symbols: merge A's symbols into B first. unsigned s = B->getNumDimIds(); - for (auto *aSymValue : aSymValues) { + for (auto aSymValue : aSymValues) { unsigned loc; if (B->findId(*aSymValue, &loc)) { assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() && @@ -785,7 +785,7 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) { } // Turn a dimension into a symbol. -static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) { +static void turnDimIntoSymbol(FlatAffineConstraints *cst, ValueRef id) { unsigned pos; if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { swapId(cst, pos, cst->getNumDimIds() - 1); @@ -794,7 +794,7 @@ static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) { } // Turn a symbol into a dimension. -static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) { +static void turnSymbolIntoDim(FlatAffineConstraints *cst, ValueRef id) { unsigned pos; if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && pos < cst->getNumDimAndSymbolIds()) { @@ -806,18 +806,18 @@ static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) { // Changes all symbol identifiers which are loop IVs to dim identifiers. void FlatAffineConstraints::convertLoopIVSymbolsToDims() { // Gather all symbols which are loop IVs. - SmallVector loopIVs; + SmallVector loopIVs; for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue())) loopIVs.push_back(ids[i].getValue()); } // Turn each symbol in 'loopIVs' into a dim identifier. - for (auto *iv : loopIVs) { + for (auto iv : loopIVs) { turnSymbolIntoDim(this, *iv); } } -void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { +void FlatAffineConstraints::addInductionVarOrTerminalSymbol(ValuePtr id) { if (containsId(*id)) return; @@ -876,8 +876,8 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { addConstantLowerBound(pos, forOp.getConstantLowerBound()); } else { // Non-constant lower bound case. - SmallVector lbOperands(forOp.getLowerBoundOperands().begin(), - forOp.getLowerBoundOperands().end()); + SmallVector lbOperands(forOp.getLowerBoundOperands().begin(), + forOp.getLowerBoundOperands().end()); if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands, /*eq=*/false, /*lower=*/true))) return failure(); @@ -888,8 +888,8 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { return success(); } // Non-constant upper bound case. 
- SmallVector ubOperands(forOp.getUpperBoundOperands().begin(), - forOp.getUpperBoundOperands().end()); + SmallVector ubOperands(forOp.getUpperBoundOperands().begin(), + forOp.getUpperBoundOperands().end()); return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands, /*eq=*/false, /*lower=*/false); } @@ -1757,7 +1757,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ArrayRef boundOperands, + ArrayRef boundOperands, bool eq, bool lower) { assert(pos < getNumDimAndSymbolIds() && "invalid position"); // Equality follows the logic of lower bound except that we add an equality @@ -1769,11 +1769,11 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // Fully compose map and operands; canonicalize and simplify so that we // transitively get to terminal symbols or loop IVs. auto map = boundMap; - SmallVector operands(boundOperands.begin(), boundOperands.end()); + SmallVector operands(boundOperands.begin(), boundOperands.end()); fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); canonicalizeMapAndOperands(&map, &operands); - for (auto *operand : operands) + for (auto operand : operands) addInductionVarOrTerminalSymbol(operand); FlatAffineConstraints localVarCst; @@ -1787,7 +1787,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, if (localVarCst.getNumLocalIds() > 0) { // Set values for localVarCst. localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); - for (auto *operand : operands) { + for (auto operand : operands) { unsigned pos; if (findId(*operand, &pos)) { if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { @@ -1807,7 +1807,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // this here since the constraint system changes after a bound is added. SmallVector positions; unsigned numOperands = operands.size(); - for (auto *operand : operands) { + for (auto operand : operands) { unsigned pos; if (!findId(*operand, &pos)) assert(0 && "expected to be found"); @@ -1848,8 +1848,8 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // Returns failure for unimplemented cases such as semi-affine expressions or // expressions with mod/floordiv. 
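Taken together with the reset overloads above, addAffineForOpDomain is typically used by first registering the induction variable as a dimension identifier and then adding the loop bounds, mirroring what getIndexSet does for a loop nest. A sketch under that assumption, for a single loop:

  #include "mlir/Analysis/AffineStructures.h"
  #include "mlir/Dialect/AffineOps/AffineOps.h"

  // Build the iteration domain of one affine.for loop; its induction variable
  // becomes dimension 0 of the constraint system.
  static mlir::LogicalResult buildLoopDomain(mlir::AffineForOp forOp,
                                             mlir::FlatAffineConstraints &cst) {
    cst.reset(/*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/0,
              /*idArgs=*/{forOp.getInductionVar()});
    return cst.addAffineForOpDomain(forOp);
  }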
LogicalResult FlatAffineConstraints::addSliceBounds( - ArrayRef values, ArrayRef lbMaps, - ArrayRef ubMaps, ArrayRef operands) { + ArrayRef values, ArrayRef lbMaps, + ArrayRef ubMaps, ArrayRef operands) { assert(values.size() == lbMaps.size()); assert(lbMaps.size() == ubMaps.size()); @@ -1971,7 +1971,7 @@ void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, addInequality(bound); } -bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const { +bool FlatAffineConstraints::findId(ValueRef id, unsigned *pos) const { unsigned i = 0; for (const auto &mayBeId : ids) { if (mayBeId.hasValue() && mayBeId.getValue() == &id) { @@ -1983,8 +1983,8 @@ bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const { return false; } -bool FlatAffineConstraints::containsId(Value &id) const { - return llvm::any_of(ids, [&](const Optional &mayBeId) { +bool FlatAffineConstraints::containsId(ValueRef id) const { + return llvm::any_of(ids, [&](const Optional &mayBeId) { return mayBeId.hasValue() && mayBeId.getValue() == &id; }); } @@ -2008,7 +2008,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { /// Sets the specified identifier to a constant value; asserts if the id is not /// found. -void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) { +void FlatAffineConstraints::setIdToConstant(ValueRef id, int64_t val) { unsigned pos; if (!findId(id, &pos)) // This is a pre-condition for this method. @@ -2573,7 +2573,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate( unsigned newNumDims = dimsSymbols.first; unsigned newNumSymbols = dimsSymbols.second; - SmallVector, 8> newIds; + SmallVector, 8> newIds; newIds.reserve(numIds - 1); newIds.append(ids.begin(), ids.begin() + pos); newIds.append(ids.begin() + pos + 1, ids.end()); @@ -2709,7 +2709,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { normalizeConstraintsByGCD(); } -void FlatAffineConstraints::projectOut(Value *id) { +void FlatAffineConstraints::projectOut(ValuePtr id) { unsigned pos; bool ret = findId(*id, &pos); assert(ret); diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 93017ca3b57..6ec7c059526 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -188,7 +188,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getRootReference()); else - callee = callable.get()->getDefiningOp(); + callee = callable.get()->getDefiningOp(); // If the callee is non-null and is a valid callable object, try to get the // called region from it. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index c422578320f..532972b771b 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -127,7 +127,7 @@ bool DominanceInfo::properlyDominates(Operation *a, Operation *b) { } /// Return true if value A properly dominates operation B. -bool DominanceInfo::properlyDominates(Value *a, Operation *b) { +bool DominanceInfo::properlyDominates(ValuePtr a, Operation *b) { if (auto *aOp = a->getDefiningOp()) { // The values defined by an operation do *not* dominate any nested // operations. diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index 6aaec4cc719..edb18e5645d 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -40,13 +40,13 @@ struct BlockInfoBuilder { /// Fills the block builder with initial liveness information. 
BlockInfoBuilder(Block *block) : block(block) { // Mark all block arguments (phis) as defined. - for (BlockArgument *argument : block->getArguments()) + for (BlockArgumentPtr argument : block->getArguments()) defValues.insert(argument); // Check all result values and whether their uses // are inside this block or not (see outValues). for (Operation &operation : *block) - for (Value *result : operation.getResults()) { + for (ValuePtr result : operation.getResults()) { defValues.insert(result); // Check whether this value will be in the outValues @@ -63,7 +63,7 @@ struct BlockInfoBuilder { // Check all operations for used operands. for (Operation &operation : block->getOperations()) - for (Value *operand : operation.getOperands()) { + for (ValuePtr operand : operation.getOperands()) { // If the operand is already defined in the scope of this // block, we can skip the value in the use set. if (!defValues.count(operand)) @@ -173,7 +173,7 @@ void Liveness::build(MutableArrayRef regions) { } /// Gets liveness info (if any) for the given value. -Liveness::OperationListT Liveness::resolveLiveness(Value *value) const { +Liveness::OperationListT Liveness::resolveLiveness(ValuePtr value) const { OperationListT result; SmallPtrSet visited; SmallVector toProcess; @@ -238,7 +238,7 @@ const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const { /// Returns true if the given operation represent the last use of the /// given value. -bool Liveness::isLastUse(Value *value, Operation *operation) const { +bool Liveness::isLastUse(ValuePtr value, Operation *operation) const { Block *block = operation->getBlock(); const LivenessBlockInfo *blockInfo = getLiveness(block); @@ -263,21 +263,21 @@ void Liveness::print(raw_ostream &os) const { // Builds unique block/value mappings for testing purposes. 
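// The per-block def/use/out sets assembled above feed the usual backward
// liveness fixpoint, sketched here for orientation (the iteration itself lives
// in Liveness::build):
//
//   liveOut(B) = union over successors S of liveIn(S)
//   liveIn(B)  = use(B)  union  (liveOut(B) - def(B))
//
// resolveLiveness(value) then collects, block by block, the operations spanned
// by the value's live range, and isLastUse answers whether a given operation is
// the last user of the value within its block.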
DenseMap blockIds; DenseMap operationIds; - DenseMap valueIds; + DenseMap valueIds; for (Region ®ion : operation->getRegions()) for (Block &block : region) { blockIds.insert({&block, blockIds.size()}); - for (BlockArgument *argument : block.getArguments()) + for (BlockArgumentPtr argument : block.getArguments()) valueIds.insert({argument, valueIds.size()}); for (Operation &operation : block) { operationIds.insert({&operation, operationIds.size()}); - for (Value *result : operation.getResults()) + for (ValuePtr result : operation.getResults()) valueIds.insert({result, valueIds.size()}); } } // Local printing helpers - auto printValueRef = [&](Value *value) { + auto printValueRef = [&](ValuePtr value) { if (Operation *defOp = value->getDefiningOp()) os << "val_" << defOp->getName(); else { @@ -289,12 +289,12 @@ void Liveness::print(raw_ostream &os) const { }; auto printValueRefs = [&](const ValueSetT &values) { - std::vector orderedValues(values.begin(), values.end()); + std::vector orderedValues(values.begin(), values.end()); std::sort(orderedValues.begin(), orderedValues.end(), - [&](Value *left, Value *right) { + [&](ValuePtr left, ValuePtr right) { return valueIds[left] < valueIds[right]; }); - for (Value *value : orderedValues) + for (ValuePtr value : orderedValues) printValueRef(value); }; @@ -315,7 +315,7 @@ void Liveness::print(raw_ostream &os) const { if (op.getNumResults() < 1) continue; os << "\n"; - for (Value *result : op.getResults()) { + for (ValuePtr result : op.getResults()) { os << "// "; printValueRef(result); os << ":"; @@ -340,18 +340,18 @@ void Liveness::print(raw_ostream &os) const { //===----------------------------------------------------------------------===// /// Returns true if the given value is in the live-in set. -bool LivenessBlockInfo::isLiveIn(Value *value) const { +bool LivenessBlockInfo::isLiveIn(ValuePtr value) const { return inValues.count(value); } /// Returns true if the given value is in the live-out set. -bool LivenessBlockInfo::isLiveOut(Value *value) const { +bool LivenessBlockInfo::isLiveOut(ValuePtr value) const { return outValues.count(value); } /// Gets the start operation for the given value /// (must be referenced in this block). -Operation *LivenessBlockInfo::getStartOperation(Value *value) const { +Operation *LivenessBlockInfo::getStartOperation(ValuePtr value) const { Operation *definingOp = value->getDefiningOp(); // The given value is either live-in or is defined // in the scope of this block. @@ -362,7 +362,7 @@ Operation *LivenessBlockInfo::getStartOperation(Value *value) const { /// Gets the end operation for the given value using the start operation /// provided (must be referenced in this block). -Operation *LivenessBlockInfo::getEndOperation(Value *value, +Operation *LivenessBlockInfo::getEndOperation(ValuePtr value, Operation *startOperation) const { // The given value is either dying in this block or live-out. if (isLiveOut(value)) diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index a81116579ce..9dfbfe0c542 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -43,7 +43,7 @@ using namespace mlir; // be more powerful (since both inequalities and equalities will be considered). 
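// Quick numeric illustration of the trip count map built below (values chosen
// for the example only): for `affine.for %i = 0 to 40 step 4` the map reduces
// to the constant (40 - 0) ceildiv 4 = 10, so getConstantTripCount returns 10
// and getLargestDivisorOfTripCount returns 10 as well. With a symbolic upper
// bound such as %N the map stays (%N - 0) ceildiv 4, no constant trip count is
// available, and only the largest provable divisor of that expression can be
// reported.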
void mlir::buildTripCountMapAndOperands( AffineForOp forOp, AffineMap *tripCountMap, - SmallVectorImpl *tripCountOperands) { + SmallVectorImpl *tripCountOperands) { int64_t loopSpan; int64_t step = forOp.getStep(); @@ -65,8 +65,8 @@ void mlir::buildTripCountMapAndOperands( *tripCountMap = AffineMap(); return; } - SmallVector lbOperands(forOp.getLowerBoundOperands()); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); // Difference of each upper bound expression from the single lower bound // expression (divided by the step) provides the expressions for the trip @@ -98,7 +98,7 @@ void mlir::buildTripCountMapAndOperands( // works with analysis structures (FlatAffineConstraints) and thus doesn't // update the IR. Optional mlir::getConstantTripCount(AffineForOp forOp) { - SmallVector operands; + SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -124,7 +124,7 @@ Optional mlir::getConstantTripCount(AffineForOp forOp) { /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { - SmallVector operands; + SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -173,7 +173,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -static bool isAccessIndexInvariant(Value *iv, Value *index) { +static bool isAccessIndexInvariant(ValuePtr iv, ValuePtr index) { assert(isForInductionVar(iv) && "iv must be a AffineForOp"); assert(index->getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; @@ -197,11 +197,11 @@ static bool isAccessIndexInvariant(Value *iv, Value *index) { return !(AffineValueMap(composeOp).isFunctionOf(0, iv)); } -DenseSet mlir::getInvariantAccesses(Value *iv, - ArrayRef indices) { - DenseSet res; +DenseSet mlir::getInvariantAccesses(ValuePtr iv, + ArrayRef indices) { + DenseSet res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { - auto *val = indices[idx]; + auto val = indices[idx]; if (isAccessIndexInvariant(iv, val)) { res.insert(val); } @@ -229,7 +229,7 @@ DenseSet mlir::getInvariantAccesses(Value *iv, /// // TODO(ntv): check strides. template -static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp, +static bool isContiguousAccess(ValuePtr iv, LoadOrStoreOp memoryOp, int *memRefDim) { static_assert(std::is_same::value || std::is_same::value, @@ -250,11 +250,11 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp, int uniqueVaryingIndexAlongIv = -1; auto accessMap = memoryOp.getAffineMap(); - SmallVector mapOperands(memoryOp.getMapOperands()); + SmallVector mapOperands(memoryOp.getMapOperands()); unsigned numDims = accessMap.getNumDims(); for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) { // Gather map operands used result expr 'i' in 'exprOperands'. - SmallVector exprOperands; + SmallVector exprOperands; auto resultExpr = accessMap.getResult(i); resultExpr.walk([&](AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) @@ -263,7 +263,7 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp, exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]); }); // Check access invariance of each operand in 'exprOperands'. 
- for (auto *exprOperand : exprOperands) { + for (auto exprOperand : exprOperands) { if (!isAccessIndexInvariant(iv, exprOperand)) { if (uniqueVaryingIndexAlongIv != -1) { // 2+ varying indices -> do not vectorize along iv. @@ -382,7 +382,7 @@ bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { // Validate the results of this operation if it were to be shifted. for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - Value *result = op.getResult(i); + ValuePtr result = op.getResult(i); for (auto *user : result->getUsers()) { // If an ancestor operation doesn't lie in the block of forOp, // there is no shift to check. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 700321ebb40..b09bddddd66 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -104,8 +104,8 @@ static void getBackwardSliceImpl(Operation *op, } for (auto en : llvm::enumerate(op->getOperands())) { - auto *operand = en.value(); - if (auto *blockArg = dyn_cast(operand)) { + auto operand = en.value(); + if (auto blockArg = dyn_cast(operand)) { if (auto affIv = getForInductionVarOwner(operand)) { auto *affOp = affIv.getOperation(); if (backwardSlice->count(affOp) == 0) diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 3ba27bbb299..73aa07e7d7b 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -60,7 +60,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { // Adds operands (dst ivs and symbols) as symbols in 'cst'. unsigned numSymbols = lbOperands[0].size(); - SmallVector values(ivs); + SmallVector values(ivs); // Append 'ivs' then 'operands' to 'values'. values.append(lbOperands[0].begin(), lbOperands[0].end()); cst->reset(numDims, numSymbols, 0, values); @@ -185,7 +185,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, if (rank == 0) { SmallVector ivs; getLoopIVs(*op, &ivs); - SmallVector regionSymbols; + SmallVector regionSymbols; extractForInductionVars(ivs, ®ionSymbols); // A rank 0 memref has a 0-d region. cst.reset(rank, loopDepth, 0, regionSymbols); @@ -201,7 +201,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); unsigned numOperands = accessValueMap.getNumOperands(); // Merge operands with slice operands. - SmallVector operands; + SmallVector operands; operands.resize(numOperands); for (unsigned i = 0; i < numOperands; ++i) operands[i] = accessValueMap.getOperand(i); @@ -224,7 +224,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, // Add equality constraints. // Add inequalities for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { - auto *operand = operands[i]; + auto operand = operands[i]; if (auto loop = getForInductionVarOwner(operand)) { // Note that cst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. @@ -234,7 +234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, return failure(); } else { // Has to be a valid symbol. - auto *symbol = operand; + auto symbol = operand; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
if (auto *op = symbol->getDefiningOp()) { @@ -278,9 +278,9 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, getLoopIVs(*op, &enclosingIVs); assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); enclosingIVs.resize(loopDepth); - SmallVector ids; + SmallVector ids; cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids); - for (auto *id : ids) { + for (auto id : ids) { AffineForOp iv; if ((iv = getForInductionVarOwner(id)) && llvm::is_contained(enclosingIVs, iv) == false) { @@ -345,9 +345,9 @@ Optional MemRefRegion::getRegionSize() { // Indices to use for the DmaStart op. // Indices for the original memref being DMAed from/to. - SmallVector memIndices; + SmallVector memIndices; // Indices for the faster buffer being DMAed into/from. - SmallVector bufIndices; + SmallVector bufIndices; // Compute the extents of the buffer. Optional numElements = getConstantBoundingSizeAndShape(); @@ -480,10 +480,10 @@ static Operation *getInstAtPosition(ArrayRef positions, } // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. -LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, +LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, FlatAffineConstraints *cst) { for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) { - auto *value = cst->getIdValue(i); + auto value = cst->getIdValue(i); if (ivs.count(value) == 0) { assert(isForInductionVar(value)); auto loop = getForInductionVarOwner(value); @@ -596,10 +596,10 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, // Pre-constraint id alignment: record loop IVs used in each constraint // system. - SmallPtrSet sliceUnionIVs; + SmallPtrSet sliceUnionIVs; for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k) sliceUnionIVs.insert(sliceUnionCst.getIdValue(k)); - SmallPtrSet tmpSliceIVs; + SmallPtrSet tmpSliceIVs; for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k) tmpSliceIVs.insert(tmpSliceCst.getIdValue(k)); @@ -659,7 +659,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, &sliceUnion->ubs); // Add slice bound operands of union. - SmallVector sliceBoundOperands; + SmallVector sliceBoundOperands; sliceUnionCst.getIdValues(numSliceLoopIVs, sliceUnionCst.getNumDimAndSymbolIds(), &sliceBoundOperands); @@ -725,7 +725,7 @@ void mlir::getComputationSliceState( &sliceState->lbs, &sliceState->ubs); // Set up bound operands for the slice's lower and upper bounds. - SmallVector sliceBoundOperands; + SmallVector sliceBoundOperands; unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds(); for (unsigned i = 0; i < numDimsAndSymbols; ++i) { if (i < offset || i >= offset + numSliceLoopIVs) { @@ -743,7 +743,7 @@ void mlir::getComputationSliceState( isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin() : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); - llvm::SmallDenseSet sequentialLoops; + llvm::SmallDenseSet sequentialLoops; if (isa(depSourceOp) && isa(depSinkOp)) { // For read-read access pairs, clear any slice bounds on sequential loops. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. @@ -758,7 +758,7 @@ void mlir::getComputationSliceState( return isBackwardSlice ? 
srcLoopIVs[i] : dstLoopIVs[i]; }; for (unsigned i = 0; i < numSliceLoopIVs; ++i) { - Value *iv = getSliceLoop(i).getInductionVar(); + ValuePtr iv = getSliceLoop(i).getInductionVar(); if (sequentialLoops.count(iv) == 0 && getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr) continue; @@ -846,7 +846,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); indices.reserve(loadMemrefType.getRank()); - for (auto *index : loadOp.getMapOperands()) { + for (auto index : loadOp.getMapOperands()) { indices.push_back(index); } } else { @@ -856,7 +856,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); indices.reserve(storeMemrefType.getRank()); - for (auto *index : storeOp.getMapOperands()) { + for (auto index : storeOp.getMapOperands()) { indices.push_back(index); } } @@ -919,7 +919,7 @@ static Optional getMemoryFootprintBytes(Block &block, Block::iterator start, Block::iterator end, int memorySpace) { - SmallDenseMap, 4> regions; + SmallDenseMap, 4> regions; // Walk this 'affine.for' operation to gather all memory regions. auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult { @@ -970,7 +970,7 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. void mlir::getSequentialLoops( - AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { + AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { if (auto innerFor = dyn_cast(op)) if (!isLoopParallel(innerFor)) diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 42d3f10b14c..a7917eba503 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -109,7 +109,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, /// Examples can be found in the documentation of `makePermutationMap`, in the /// header file. static AffineMap makePermutationMap( - ArrayRef indices, + ArrayRef indices, const DenseMap &enclosingLoopToVectorDim) { if (enclosingLoopToVectorDim.empty()) return AffineMap(); @@ -167,7 +167,7 @@ static SetVector getEnclosingforOps(Operation *op) { } AffineMap mlir::makePermutationMap( - Operation *op, ArrayRef indices, + Operation *op, ArrayRef indices, const DenseMap &loopToVectorDim) { DenseMap enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingforOps(op); diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 82f5aa5e01c..be499a93898 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -138,7 +138,7 @@ LogicalResult OperationVerifier::verifyRegion(Region ®ion) { } LogicalResult OperationVerifier::verifyBlock(Block &block) { - for (auto *arg : block.getArguments()) + for (auto arg : block.getArguments()) if (arg->getOwner() != &block) return emitError(block, "block argument not owned by block"); @@ -175,7 +175,7 @@ LogicalResult OperationVerifier::verifyBlock(Block &block) { LogicalResult OperationVerifier::verifyOperation(Operation &op) { // Check that operands are non-nil and structurally ok. 
- for (auto *operand : op.getOperands()) + for (auto operand : op.getOperands()) if (!operand) return op.emitError("null operand found"); @@ -244,7 +244,7 @@ LogicalResult OperationVerifier::verifyDominance(Operation &op) { // Check that operands properly dominate this use. for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e; ++operandNo) { - auto *operand = op.getOperand(operandNo); + auto operand = op.getOperand(operandNo); if (domInfo->properlyDominates(operand, &op)) continue; diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 3f613c6bfb5..144b4a97e87 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -42,16 +42,16 @@ namespace { // that correspond to it. Visitation functions return an Value of the // expression subtree they visited or `nullptr` on error. class AffineApplyExpander - : public AffineExprVisitor { + : public AffineExprVisitor { public: // This internal class expects arguments to be non-null, checks must be // performed at the call site. - AffineApplyExpander(OpBuilder &builder, ArrayRef dimValues, - ArrayRef symbolValues, Location loc) + AffineApplyExpander(OpBuilder &builder, ArrayRef dimValues, + ArrayRef symbolValues, Location loc) : builder(builder), dimValues(dimValues), symbolValues(symbolValues), loc(loc) {} - template Value *buildBinaryExpr(AffineBinaryOpExpr expr) { + template ValuePtr buildBinaryExpr(AffineBinaryOpExpr expr) { auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) @@ -60,11 +60,11 @@ public: return op.getResult(); } - Value *visitAddExpr(AffineBinaryOpExpr expr) { + ValuePtr visitAddExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } - Value *visitMulExpr(AffineBinaryOpExpr expr) { + ValuePtr visitMulExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } @@ -77,7 +77,7 @@ public: // let remainder = srem a, b; // negative = a < 0 in // select negative, remainder + b, remainder. - Value *visitModExpr(AffineBinaryOpExpr expr) { + ValuePtr visitModExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( @@ -94,13 +94,13 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value *remainder = builder.create(loc, lhs, rhs); - Value *zeroCst = builder.create(loc, 0); - Value *isRemainderNegative = + ValuePtr remainder = builder.create(loc, lhs, rhs); + ValuePtr zeroCst = builder.create(loc, 0); + ValuePtr isRemainderNegative = builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); - Value *correctedRemainder = builder.create(loc, remainder, rhs); - Value *result = builder.create(loc, isRemainderNegative, - correctedRemainder, remainder); + ValuePtr correctedRemainder = builder.create(loc, remainder, rhs); + ValuePtr result = builder.create(loc, isRemainderNegative, + correctedRemainder, remainder); return result; } @@ -114,7 +114,7 @@ public: // let absolute = negative ? -a - 1 : a in // let quotient = absolute / b in // negative ? 
-quotient - 1 : quotient - Value *visitFloorDivExpr(AffineBinaryOpExpr expr) { + ValuePtr visitFloorDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( @@ -131,16 +131,16 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value *zeroCst = builder.create(loc, 0); - Value *noneCst = builder.create(loc, -1); - Value *negative = + ValuePtr zeroCst = builder.create(loc, 0); + ValuePtr noneCst = builder.create(loc, -1); + ValuePtr negative = builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); - Value *negatedDecremented = builder.create(loc, noneCst, lhs); - Value *dividend = + ValuePtr negatedDecremented = builder.create(loc, noneCst, lhs); + ValuePtr dividend = builder.create(loc, negative, negatedDecremented, lhs); - Value *quotient = builder.create(loc, dividend, rhs); - Value *correctedQuotient = builder.create(loc, noneCst, quotient); - Value *result = + ValuePtr quotient = builder.create(loc, dividend, rhs); + ValuePtr correctedQuotient = builder.create(loc, noneCst, quotient); + ValuePtr result = builder.create(loc, negative, correctedQuotient, quotient); return result; } @@ -155,7 +155,7 @@ public: // let absolute = negative ? -a : a - 1 in // let quotient = absolute / b in // negative ? -quotient : quotient + 1 - Value *visitCeilDivExpr(AffineBinaryOpExpr expr) { + ValuePtr visitCeilDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError(loc) << "semi-affine expressions (division by non-const) are " @@ -170,23 +170,24 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value *zeroCst = builder.create(loc, 0); - Value *oneCst = builder.create(loc, 1); - Value *nonPositive = + ValuePtr zeroCst = builder.create(loc, 0); + ValuePtr oneCst = builder.create(loc, 1); + ValuePtr nonPositive = builder.create(loc, CmpIPredicate::sle, lhs, zeroCst); - Value *negated = builder.create(loc, zeroCst, lhs); - Value *decremented = builder.create(loc, lhs, oneCst); - Value *dividend = + ValuePtr negated = builder.create(loc, zeroCst, lhs); + ValuePtr decremented = builder.create(loc, lhs, oneCst); + ValuePtr dividend = builder.create(loc, nonPositive, negated, decremented); - Value *quotient = builder.create(loc, dividend, rhs); - Value *negatedQuotient = builder.create(loc, zeroCst, quotient); - Value *incrementedQuotient = builder.create(loc, quotient, oneCst); - Value *result = builder.create(loc, nonPositive, negatedQuotient, - incrementedQuotient); + ValuePtr quotient = builder.create(loc, dividend, rhs); + ValuePtr negatedQuotient = builder.create(loc, zeroCst, quotient); + ValuePtr incrementedQuotient = + builder.create(loc, quotient, oneCst); + ValuePtr result = builder.create( + loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } - Value *visitConstantExpr(AffineConstantExpr expr) { + ValuePtr visitConstantExpr(AffineConstantExpr expr) { auto valueAttr = builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); auto op = @@ -194,13 +195,13 @@ public: return op.getResult(); } - Value *visitDimExpr(AffineDimExpr expr) { + ValuePtr visitDimExpr(AffineDimExpr expr) { assert(expr.getPosition() < dimValues.size() && "affine dim position out of range"); return dimValues[expr.getPosition()]; } - Value *visitSymbolExpr(AffineSymbolExpr expr) { + ValuePtr visitSymbolExpr(AffineSymbolExpr expr) { assert(expr.getPosition() < symbolValues.size() && "symbol dim 
position out of range"); return symbolValues[expr.getPosition()]; @@ -208,8 +209,8 @@ public: private: OpBuilder &builder; - ArrayRef dimValues; - ArrayRef symbolValues; + ArrayRef dimValues; + ArrayRef symbolValues; Location loc; }; @@ -217,18 +218,18 @@ private: // Create a sequence of operations that implement the `expr` applied to the // given dimension and symbol values. -mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc, - AffineExpr expr, - ArrayRef dimValues, - ArrayRef symbolValues) { +mlir::ValuePtr mlir::expandAffineExpr(OpBuilder &builder, Location loc, + AffineExpr expr, + ArrayRef dimValues, + ArrayRef symbolValues) { return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); } // Create a sequence of operations that implement the `affineMap` applied to // the given `operands` (as it it were an AffineApplyOp). -Optional> static expandAffineMap( +Optional> static expandAffineMap( OpBuilder &builder, Location loc, AffineMap affineMap, - ArrayRef operands) { + ArrayRef operands) { auto numDims = affineMap.getNumDims(); auto expanded = functional::map( [numDims, &builder, loc, operands](AffineExpr expr) { @@ -237,7 +238,7 @@ Optional> static expandAffineMap( operands.drop_front(numDims)); }, affineMap.getResults()); - if (llvm::all_of(expanded, [](Value *v) { return v; })) + if (llvm::all_of(expanded, [](ValuePtr v) { return v; })) return expanded; return None; } @@ -253,13 +254,13 @@ Optional> static expandAffineMap( // Multiple values are scanned in a linear sequence. This creates a data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes. -static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, - ArrayRef values, - OpBuilder &builder) { +static ValuePtr buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, + ArrayRef values, + OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); - Value *value = *valueIt++; + ValuePtr value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); @@ -271,8 +272,8 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // Emit instructions that correspond to the affine map in the lower bound // applied to the respective operands, and compute the maximum value across // the results. -Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { - SmallVector boundOperands(op.getLowerBoundOperands()); +ValuePtr mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { + SmallVector boundOperands(op.getLowerBoundOperands()); auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(), boundOperands); if (!lbValues) @@ -284,8 +285,8 @@ Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { // Emit instructions that correspond to the affine map in the upper bound // applied to the respective operands, and compute the minimum value across // the results. 
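// Worked examples for the mod/floordiv/ceildiv expansions above, with the
// divisor the positive constant 2 (the code requires a positive constant
// right-hand side):
//
//   -7 mod 2      : remainder = srem(-7, 2) = -1, negative, so -1 + 2 = 1
//   -7 floordiv 2 : negative, absolute = 6, quotient = 3, result = -3 - 1 = -4
//    7 ceildiv 2  : positive, absolute = 6, quotient = 3, result = 3 + 1 = 4
//   -7 ceildiv 2  : non-positive, absolute = 7, quotient = 3, result = -3
//
// In each case the emitted compare/select/arithmetic sequence matches the
// mathematical definition for either sign of the dividend.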
-Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { - SmallVector boundOperands(op.getUpperBoundOperands()); +ValuePtr mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { + SmallVector boundOperands(op.getUpperBoundOperands()); auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(), boundOperands); if (!ubValues) @@ -314,9 +315,9 @@ public: PatternMatchResult matchAndRewrite(AffineForOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value *lowerBound = lowerAffineLowerBound(op, rewriter); - Value *upperBound = lowerAffineUpperBound(op, rewriter); - Value *step = rewriter.create(loc, op.getStep()); + ValuePtr lowerBound = lowerAffineLowerBound(op, rewriter); + ValuePtr upperBound = lowerAffineUpperBound(op, rewriter); + ValuePtr step = rewriter.create(loc, op.getStep()); auto f = rewriter.create(loc, lowerBound, upperBound, step); f.region().getBlocks().clear(); rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); @@ -335,25 +336,25 @@ public: // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); - Value *zeroConstant = rewriter.create(loc, 0); - SmallVector operands(op.getOperands()); + ValuePtr zeroConstant = rewriter.create(loc, 0); + SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Calculate cond as a conjunction without short-circuiting. - Value *cond = nullptr; + ValuePtr cond = nullptr; for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) { AffineExpr constraintExpr = integerSet.getConstraint(i); bool isEquality = integerSet.isEq(i); // Build and apply an affine expression auto numDims = integerSet.getNumDims(); - Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr, - operandsRef.take_front(numDims), - operandsRef.drop_front(numDims)); + ValuePtr affResult = expandAffineExpr(rewriter, loc, constraintExpr, + operandsRef.take_front(numDims), + operandsRef.drop_front(numDims)); if (!affResult) return matchFailure(); auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; - Value *cmpVal = + ValuePtr cmpVal = rewriter.create(loc, pred, affResult, zeroConstant); cond = cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; @@ -404,7 +405,7 @@ public: PatternMatchResult matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. - SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) @@ -426,7 +427,7 @@ public: PatternMatchResult matchAndRewrite(AffinePrefetchOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affinePrefetchOp'. - SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) @@ -450,7 +451,7 @@ public: PatternMatchResult matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. 
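// Example of the affine.if condition lowering above (integer set chosen for
// illustration): a guard with set (d0)[s0] : (d0 - 1 >= 0, s0 - d0 == 0)
// becomes, roughly,
//
//   %e0   = <expansion of d0 - 1>
//   %c0   = cmpi "sge", %e0, %zero
//   %e1   = <expansion of s0 - d0>
//   %c1   = cmpi "eq",  %e1, %zero
//   %cond = and %c0, %c1
//
// i.e. every constraint is compared against zero (sge for inequalities, eq for
// equalities) and the results are and'ed together without short-circuiting.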
- SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) @@ -472,7 +473,7 @@ public: PatternMatchResult matchAndRewrite(AffineDmaStartOp op, PatternRewriter &rewriter) const override { - SmallVector operands(op.getOperands()); + SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Expand affine map for DMA source memref. @@ -513,7 +514,7 @@ public: PatternMatchResult matchAndRewrite(AffineDmaWaitOp op, PatternRewriter &rewriter) const override { // Expand affine map for DMA tag memref. - SmallVector indices(op.getTagIndices()); + SmallVector indices(op.getTagIndices()); auto maybeExpandedTagMap = expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); if (!maybeExpandedTagMap) diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index 6a1a580e369..a408ab5b5d9 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -57,11 +57,11 @@ public: // Convert the kernel arguments to an LLVM type, preserve the rest. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto dialect = lowering.getDialect(); - Value *newOp; + ValuePtr newOp; switch (dimensionToIndex(cast(op))) { case X: newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 23bfa303708..3ab8e75633e 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -44,7 +44,7 @@ public: f32Func(f32Func), f64Func(f64Func) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; using LLVM::LLVMType; @@ -69,10 +69,10 @@ public: private: LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, - ArrayRef operands) const { + ArrayRef operands) const { using LLVM::LLVMType; SmallVector operandTypes; - for (Value *operand : operands) { + for (ValuePtr operand : operands) { operandTypes.push_back(operand->getType().cast()); } return LLVMType::getFunctionTy(resultType, operandTypes, diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index f342083bee7..840ad6ba701 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -114,7 +114,7 @@ private: } // Allocate a void pointer on the stack. 
- Value *allocatePointer(OpBuilder &builder, Location loc) { + ValuePtr allocatePointer(OpBuilder &builder, Location loc) { auto one = builder.create(loc, getInt32Type(), builder.getI32IntegerAttr(1)); return builder.create(loc, getPointerPointerType(), one, @@ -122,9 +122,9 @@ private: } void declareCudaFunctions(Location loc); - Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value *generateKernelNameConstant(StringRef name, Location loc, - OpBuilder &builder); + ValuePtr setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); + ValuePtr generateKernelNameConstant(StringRef name, Location loc, + OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); public: @@ -248,7 +248,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { // for (i : [0, NumKernelOperands)) // %array[i] = cast(KernelOperand[i]) // return %array -Value * +ValuePtr GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder) { auto numKernelOperands = launchOp.getNumKernelOperands(); @@ -264,7 +264,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, for (unsigned idx = 0; idx < numKernelOperands; ++idx) { auto operand = launchOp.getKernelOperand(idx); auto llvmType = operand->getType().cast(); - Value *memLocation = builder.create( + ValuePtr memLocation = builder.create( loc, llvmType.getPointerTo(), one, /*alignment=*/1); builder.create(loc, operand, memLocation); auto casted = @@ -280,12 +280,12 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, getModule().lookupSymbol(kMcuMemHostRegister); auto nullPtr = builder.create(loc, llvmType.getPointerTo()); auto gep = builder.create(loc, llvmType.getPointerTo(), - ArrayRef{nullPtr, one}); + ArrayRef{nullPtr, one}); auto size = builder.create(loc, getInt64Type(), gep); builder.create(loc, ArrayRef{}, builder.getSymbolRefAttr(registerFunc), - ArrayRef{casted, size}); - Value *memLocation = builder.create( + ArrayRef{casted, size}); + ValuePtr memLocation = builder.create( loc, getPointerPointerType(), one, /*alignment=*/1); builder.create(loc, casted, memLocation); casted = @@ -295,7 +295,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, auto index = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(idx)); auto gep = builder.create(loc, getPointerPointerType(), array, - ArrayRef{index}); + ArrayRef{index}); builder.create(loc, casted, gep); } return array; @@ -311,7 +311,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %1 = llvm.constant (0 : index) // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> // } -Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( +ValuePtr GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( StringRef name, Location loc, OpBuilder &builder) { // Make sure the trailing zero is included in the constant. 
std::vector kernelName(name.begin(), name.end()); @@ -367,7 +367,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( assert(kernelModule.getName() && "expected a named module"); SmallString<128> nameBuffer(*kernelModule.getName()); nameBuffer.append(kCubinStorageSuffix); - Value *data = LLVM::createGlobalString( + ValuePtr data = LLVM::createGlobalString( loc, builder, nameBuffer.str(), cubinAttr.getValue(), LLVM::Linkage::Internal, getLLVMDialect()); @@ -378,7 +378,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( getModule().lookupSymbol(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleLoad), - ArrayRef{cuModule, data}); + ArrayRef{cuModule, data}); // Get the function from the module. The name corresponds to the name of // the kernel function. auto cuOwningModuleRef = @@ -390,13 +390,13 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create( loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleGetFunction), - ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); + ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. auto cuGetStreamHelper = getModule().lookupSymbol(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, - builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef{}); + builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef{}); // Invoke the function with required arguments. auto cuLaunchKernel = getModule().lookupSymbol(cuLaunchKernelName); @@ -408,19 +408,19 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create( loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuLaunchKernel), - ArrayRef{cuFunctionRef, launchOp.getOperand(0), - launchOp.getOperand(1), launchOp.getOperand(2), - launchOp.getOperand(3), launchOp.getOperand(4), - launchOp.getOperand(5), zero, /* sharedMemBytes */ - cuStream.getResult(0), /* stream */ - paramsArray, /* kernel params */ - nullpointer /* extra */}); + ArrayRef{cuFunctionRef, launchOp.getOperand(0), + launchOp.getOperand(1), launchOp.getOperand(2), + launchOp.getOperand(3), launchOp.getOperand(4), + launchOp.getOperand(5), zero, /* sharedMemBytes */ + cuStream.getResult(0), /* stream */ + paramsArray, /* kernel params */ + nullpointer /* extra */}); // Sync on the stream to make it synchronous. auto cuStreamSync = getModule().lookupSymbol(cuStreamSynchronizeName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuStreamSync), - ArrayRef(cuStream.getResult(0))); + ArrayRef(cuStream.getResult(0))); launchOp.erase(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 220df53b977..bf18ea03dab 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -60,8 +60,8 @@ public: /// Converts all_reduce op to LLVM/NVVM ops. 
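// Shape of the all_reduce lowering that follows, condensed from the
// createBlockReduce comment further down (names are those used in the code):
//
//   partial = createWarpReduce(value)          // shuffle-based, per warp
//   if (isFirstLane) shared[warpId] = partial  // one slot per warp
//   barrier
//   if (threadIdx < numWarps)
//     shared[0] = createWarpReduce(shared[threadIdx])
//   barrier
//   result = load shared[0]                    // every thread reads the result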
struct GPUAllReduceOpLowering : public LLVMOpLowering { - using AccumulatorFactory = std::function; + using AccumulatorFactory = std::function; explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) : LLVMOpLowering(gpu::AllReduceOp::getOperationName(), @@ -69,10 +69,10 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering { int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - Value *operand = operands.front(); + ValuePtr operand = operands.front(); // TODO(csigg): Generalize to other types of accumulation. assert(op->getOperand(0)->getType().isIntOrFloat()); @@ -81,7 +81,7 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering { AccumulatorFactory factory = getFactory(cast(op), operand); assert(factory && "failed to create accumulator factory"); - Value *result = createBlockReduce(loc, operand, factory, rewriter); + ValuePtr result = createBlockReduce(loc, operand, factory, rewriter); rewriter.replaceOp(op, {result}); return matchSuccess(); @@ -91,7 +91,7 @@ private: /// Returns an accumulator factory using either the op attribute or the body /// region. AccumulatorFactory getFactory(gpu::AllReduceOp allReduce, - Value *operand) const { + ValuePtr operand) const { if (!allReduce.body().empty()) { return getFactory(allReduce.body()); } @@ -106,7 +106,7 @@ private: /// block is expected to have 2 arguments. The gpu.yield return the /// accumulated value of the same type. AccumulatorFactory getFactory(Region &body) const { - return AccumulatorFactory([&](Location loc, Value *lhs, Value *rhs, + return AccumulatorFactory([&](Location loc, ValuePtr lhs, ValuePtr rhs, ConversionPatternRewriter &rewriter) { Block *block = rewriter.getInsertionBlock(); Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint()); @@ -120,7 +120,7 @@ private: // Add branch before inserted body, into body. block = block->getNextNode(); - rewriter.create(loc, ArrayRef{}, + rewriter.create(loc, ArrayRef{}, llvm::makeArrayRef(block), ValueRange()); // Replace all gpu.yield ops with branch out of body. @@ -130,7 +130,7 @@ private: continue; rewriter.setInsertionPointToEnd(block); rewriter.replaceOpWithNewOp( - terminator, ArrayRef{}, llvm::makeArrayRef(split), + terminator, ArrayRef{}, llvm::makeArrayRef(split), ValueRange(terminator->getOperand(0))); } @@ -161,7 +161,7 @@ private: /// Returns an accumulator factory that creates an op of type T. template AccumulatorFactory getFactory() const { - return [](Location loc, Value *lhs, Value *rhs, + return [](Location loc, ValuePtr lhs, ValuePtr rhs, ConversionPatternRewriter &rewriter) { return rewriter.create(loc, lhs->getType(), lhs, rhs); }; @@ -203,60 +203,60 @@ private: /// %result = llvm.load %result_ptr /// return %result /// - Value *createBlockReduce(Location loc, Value *operand, - AccumulatorFactory &accumFactory, - ConversionPatternRewriter &rewriter) const { + ValuePtr createBlockReduce(Location loc, ValuePtr operand, + AccumulatorFactory &accumFactory, + ConversionPatternRewriter &rewriter) const { auto type = operand->getType().cast(); // Create shared memory array to store the warp reduction. 
auto module = operand->getDefiningOp()->getParentOfType(); assert(module && "op must belong to a module"); - Value *sharedMemPtr = + ValuePtr sharedMemPtr = createSharedMemoryArray(loc, module, type, kWarpSize, rewriter); - Value *zero = rewriter.create( + ValuePtr zero = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(0u)); - Value *laneId = rewriter.create(loc, int32Type); - Value *isFirstLane = rewriter.create( + ValuePtr laneId = rewriter.create(loc, int32Type); + ValuePtr isFirstLane = rewriter.create( loc, LLVM::ICmpPredicate::eq, laneId, zero); - Value *threadIdx = getLinearThreadIndex(loc, rewriter); - Value *blockSize = getBlockSize(loc, rewriter); - Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter); + ValuePtr threadIdx = getLinearThreadIndex(loc, rewriter); + ValuePtr blockSize = getBlockSize(loc, rewriter); + ValuePtr activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter); // Reduce elements within each warp to produce the intermediate results. - Value *warpReduce = createWarpReduce(loc, activeWidth, laneId, operand, - accumFactory, rewriter); + ValuePtr warpReduce = createWarpReduce(loc, activeWidth, laneId, operand, + accumFactory, rewriter); // Write the intermediate results to shared memory, using the first lane of // each warp. createPredicatedBlock(loc, rewriter, isFirstLane, [&] { - Value *warpId = getDivideByWarpSize(threadIdx, rewriter); - Value *storeDst = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, warpId})); + ValuePtr warpId = getDivideByWarpSize(threadIdx, rewriter); + ValuePtr storeDst = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, warpId})); rewriter.create(loc, warpReduce, storeDst); }); rewriter.create(loc); - Value *numWarps = getNumWarps(loc, blockSize, rewriter); - Value *isValidWarp = rewriter.create( + ValuePtr numWarps = getNumWarps(loc, blockSize, rewriter); + ValuePtr isValidWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps); - Value *resultPtr = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, zero})); + ValuePtr resultPtr = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, zero})); // Use the first numWarps threads to reduce the intermediate results from // shared memory. The final result is written to shared memory again. createPredicatedBlock(loc, rewriter, isValidWarp, [&] { - Value *loadSrc = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, threadIdx})); - Value *value = rewriter.create(loc, type, loadSrc); - Value *result = createWarpReduce(loc, numWarps, laneId, value, - accumFactory, rewriter); + ValuePtr loadSrc = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, threadIdx})); + ValuePtr value = rewriter.create(loc, type, loadSrc); + ValuePtr result = createWarpReduce(loc, numWarps, laneId, value, + accumFactory, rewriter); rewriter.create(loc, result, resultPtr); }); rewriter.create(loc); // Load and return result from shared memory. 
- Value *result = rewriter.create(loc, type, resultPtr); + ValuePtr result = rewriter.create(loc, type, resultPtr); return result; } @@ -274,7 +274,7 @@ private: /// template void createIf(Location loc, ConversionPatternRewriter &rewriter, - Value *condition, ThenOpsFactory &&thenOpsFactory, + ValuePtr condition, ThenOpsFactory &&thenOpsFactory, ElseOpsFactory &&elseOpsFactory) const { Block *currentBlock = rewriter.getInsertionBlock(); auto currentPoint = rewriter.getInsertionPoint(); @@ -288,7 +288,7 @@ private: ArrayRef{thenBlock, elseBlock}); auto addBranch = [&](ValueRange operands) { - rewriter.create(loc, ArrayRef{}, + rewriter.create(loc, ArrayRef{}, llvm::makeArrayRef(continueBlock), llvm::makeArrayRef(operands)); }; @@ -303,32 +303,32 @@ private: assert(thenOperands.size() == elseOperands.size()); rewriter.setInsertionPointToStart(continueBlock); - for (auto *operand : thenOperands) + for (auto operand : thenOperands) continueBlock->addArgument(operand->getType()); } /// Shortcut for createIf with empty else block and no block operands. template void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter, - Value *condition, + ValuePtr condition, Factory &&predicatedOpsFactory) const { createIf( loc, rewriter, condition, [&] { predicatedOpsFactory(); - return ArrayRef(); + return ArrayRef(); }, - [&] { return ArrayRef(); }); + [&] { return ArrayRef(); }); } /// Creates a reduction across the first activeWidth lanes of a warp. /// The first lane returns the result, all others return values are undefined. - Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId, - Value *operand, AccumulatorFactory accumFactory, - ConversionPatternRewriter &rewriter) const { - Value *warpSize = rewriter.create( + ValuePtr createWarpReduce(Location loc, ValuePtr activeWidth, ValuePtr laneId, + ValuePtr operand, AccumulatorFactory accumFactory, + ConversionPatternRewriter &rewriter) const { + ValuePtr warpSize = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); - Value *isPartialWarp = rewriter.create( + ValuePtr isPartialWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize); auto type = operand->getType().cast(); @@ -336,16 +336,16 @@ private: loc, rewriter, isPartialWarp, // Generate reduction over a (potentially) partial warp. [&] { - Value *value = operand; - Value *one = rewriter.create( + ValuePtr value = operand; + ValuePtr one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); // Bit mask of active lanes: `(1 << activeWidth) - 1`. - Value *activeMask = rewriter.create( + ValuePtr activeMask = rewriter.create( loc, int32Type, rewriter.create(loc, int32Type, one, activeWidth), one); // Clamp lane: `activeWidth - 1` - Value *maskAndClamp = + ValuePtr maskAndClamp = rewriter.create(loc, int32Type, activeWidth, one); auto dialect = lowering.getDialect(); auto predTy = LLVM::LLVMType::getInt1Ty(dialect); @@ -356,53 +356,53 @@ private: // lane is within the active range. All lanes contain the final // result, but only the first lane's result is used. 
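// Worked example with kWarpSize = 32: the loop below runs log2(32) = 5 rounds
// with offsets 1, 2, 4, 8, 16. Each round shuffles in the running value of a
// lane `offset` apart (clamped to the active width) and, when that source lane
// is active, folds it in through the accumulator factory, so the number of
// lanes represented in each running value roughly doubles per round; after the
// final round the first lane holds the reduction over the whole active range.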
for (int i = 1; i < kWarpSize; i <<= 1) { - Value *offset = rewriter.create( + ValuePtr offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - Value *shfl = rewriter.create( + ValuePtr shfl = rewriter.create( loc, shflTy, activeMask, value, offset, maskAndClamp, returnValueAndIsValidAttr); - Value *isActiveSrcLane = rewriter.create( + ValuePtr isActiveSrcLane = rewriter.create( loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); // Skip the accumulation if the shuffle op read from a lane outside // of the active range. createIf( loc, rewriter, isActiveSrcLane, [&] { - Value *shflValue = rewriter.create( + ValuePtr shflValue = rewriter.create( loc, type, shfl, rewriter.getIndexArrayAttr(0)); - return SmallVector{ + return SmallVector{ accumFactory(loc, value, shflValue, rewriter)}; }, [&] { return llvm::makeArrayRef(value); }); value = rewriter.getInsertionBlock()->getArgument(0); } - return SmallVector{value}; + return SmallVector{value}; }, // Generate a reduction over the entire warp. This is a specialization // of the above reduction with unconditional accumulation. [&] { - Value *value = operand; - Value *activeMask = rewriter.create( + ValuePtr value = operand; + ValuePtr activeMask = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(~0u)); - Value *maskAndClamp = rewriter.create( + ValuePtr maskAndClamp = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); for (int i = 1; i < kWarpSize; i <<= 1) { - Value *offset = rewriter.create( + ValuePtr offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - Value *shflValue = rewriter.create( + ValuePtr shflValue = rewriter.create( loc, type, activeMask, value, offset, maskAndClamp, /*return_value_and_is_valid=*/UnitAttr()); value = accumFactory(loc, value, shflValue, rewriter); } - return SmallVector{value}; + return SmallVector{value}; }); return rewriter.getInsertionBlock()->getArgument(0); } /// Creates a global array stored in shared memory. - Value *createSharedMemoryArray(Location loc, ModuleOp module, - LLVM::LLVMType elementType, int numElements, - ConversionPatternRewriter &rewriter) const { + ValuePtr createSharedMemoryArray(Location loc, ModuleOp module, + LLVM::LLVMType elementType, int numElements, + ConversionPatternRewriter &rewriter) const { OpBuilder builder(module.getBodyRegion()); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); @@ -416,31 +416,32 @@ private: } /// Returns the index of the thread within the block. 
- Value *getLinearThreadIndex(Location loc, - ConversionPatternRewriter &rewriter) const { - Value *dimX = rewriter.create(loc, int32Type); - Value *dimY = rewriter.create(loc, int32Type); - Value *idX = rewriter.create(loc, int32Type); - Value *idY = rewriter.create(loc, int32Type); - Value *idZ = rewriter.create(loc, int32Type); - Value *tmp1 = rewriter.create(loc, int32Type, idZ, dimY); - Value *tmp2 = rewriter.create(loc, int32Type, tmp1, idY); - Value *tmp3 = rewriter.create(loc, int32Type, tmp2, dimX); + ValuePtr getLinearThreadIndex(Location loc, + ConversionPatternRewriter &rewriter) const { + ValuePtr dimX = rewriter.create(loc, int32Type); + ValuePtr dimY = rewriter.create(loc, int32Type); + ValuePtr idX = rewriter.create(loc, int32Type); + ValuePtr idY = rewriter.create(loc, int32Type); + ValuePtr idZ = rewriter.create(loc, int32Type); + ValuePtr tmp1 = rewriter.create(loc, int32Type, idZ, dimY); + ValuePtr tmp2 = rewriter.create(loc, int32Type, tmp1, idY); + ValuePtr tmp3 = rewriter.create(loc, int32Type, tmp2, dimX); return rewriter.create(loc, int32Type, tmp3, idX); } /// Returns the number of threads in the block. - Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const { - Value *dimX = rewriter.create(loc, int32Type); - Value *dimY = rewriter.create(loc, int32Type); - Value *dimZ = rewriter.create(loc, int32Type); - Value *dimXY = rewriter.create(loc, int32Type, dimX, dimY); + ValuePtr getBlockSize(Location loc, + ConversionPatternRewriter &rewriter) const { + ValuePtr dimX = rewriter.create(loc, int32Type); + ValuePtr dimY = rewriter.create(loc, int32Type); + ValuePtr dimZ = rewriter.create(loc, int32Type); + ValuePtr dimXY = rewriter.create(loc, int32Type, dimX, dimY); return rewriter.create(loc, int32Type, dimXY, dimZ); } /// Returns the number of warps in the block. - Value *getNumWarps(Location loc, Value *blockSize, - ConversionPatternRewriter &rewriter) const { + ValuePtr getNumWarps(Location loc, ValuePtr blockSize, + ConversionPatternRewriter &rewriter) const { auto warpSizeMinusOne = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); auto biasedBlockSize = rewriter.create( @@ -449,19 +450,19 @@ private: } /// Returns the number of active threads in the warp, not clamped to 32. - Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize, - ConversionPatternRewriter &rewriter) const { - Value *threadIdxMask = rewriter.create( + ValuePtr getActiveWidth(Location loc, ValuePtr threadIdx, ValuePtr blockSize, + ConversionPatternRewriter &rewriter) const { + ValuePtr threadIdxMask = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1))); - Value *numThreadsWithSmallerWarpId = + ValuePtr numThreadsWithSmallerWarpId = rewriter.create(loc, threadIdx, threadIdxMask); return rewriter.create(loc, blockSize, numThreadsWithSmallerWarpId); } /// Returns value divided by the warp size (i.e. 32). 
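// Worked example for the two helpers above, with kWarpSize = 32 and a block of
// 70 threads: getNumWarps biases the block size to 70 + 31 and divides by 32,
// giving 3 partial results to combine. For threadIdx = 65, getActiveWidth
// masks off the low five bits (65 & ~31 = 64) and returns 70 - 64 = 6, i.e.
// the last warp has only 6 live lanes; for a thread in a full warp the same
// formula returns a value of at least 32, which is why createWarpReduce
// compares the active width against the warp size before choosing the
// partial-warp path.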
- Value *getDivideByWarpSize(Value *value, - ConversionPatternRewriter &rewriter) const { + ValuePtr getDivideByWarpSize(ValuePtr value, + ConversionPatternRewriter &rewriter) const { auto loc = value->getLoc(); auto warpSize = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); @@ -495,7 +496,7 @@ struct GPUShuffleOpLowering : public LLVMOpLowering { /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : /// !llvm<"{ float, i1 }"> PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); gpu::ShuffleOpOperandAdaptor adaptor(operands); @@ -506,24 +507,24 @@ struct GPUShuffleOpLowering : public LLVMOpLowering { auto predTy = LLVM::LLVMType::getInt1Ty(dialect); auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy}); - Value *one = rewriter.create( + ValuePtr one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); // Bit mask of active lanes: `(1 << activeWidth) - 1`. - Value *activeMask = rewriter.create( + ValuePtr activeMask = rewriter.create( loc, int32Type, rewriter.create(loc, int32Type, one, adaptor.width()), one); // Clamp lane: `activeWidth - 1` - Value *maskAndClamp = + ValuePtr maskAndClamp = rewriter.create(loc, int32Type, adaptor.width(), one); auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); - Value *shfl = rewriter.create( + ValuePtr shfl = rewriter.create( loc, resultTy, activeMask, adaptor.value(), adaptor.offset(), maskAndClamp, returnValueAndIsValidAttr); - Value *shflValue = rewriter.create( + ValuePtr shflValue = rewriter.create( loc, valueTy, shfl, rewriter.getIndexArrayAttr(0)); - Value *isActiveSrcLane = rewriter.create( + ValuePtr isActiveSrcLane = rewriter.create( loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); @@ -538,7 +539,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.empty() && "func op is not expected to have operands"); auto gpuFuncOp = cast(op); @@ -547,7 +548,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { SmallVector workgroupBuffers; workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { - Value *attribution = en.value(); + ValuePtr attribution = en.value(); auto type = attribution->getType().dyn_cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -604,23 +605,23 @@ struct GPUFuncOpLowering : LLVMOpLowering { unsigned numProperArguments = gpuFuncOp.getNumArguments(); auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); - Value *zero = nullptr; + ValuePtr zero = nullptr; if (!workgroupBuffers.empty()) zero = rewriter.create(loc, i32Type, rewriter.getI32IntegerAttr(0)); for (auto en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); - Value *address = rewriter.create(loc, global); + ValuePtr address = rewriter.create(loc, global); auto elementType = global.getType().getArrayElementType(); - Value *memory = rewriter.create( + ValuePtr memory = rewriter.create( loc, elementType.getPointerTo(global.addr_space().getZExtValue()), - address, ArrayRef{zero, zero}); + address, ArrayRef{zero, zero}); // Build a memref descriptor pointing to the 
buffer to plug with the // existing memref infrastructure. This may use more registers than // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. - Value *attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; + ValuePtr attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution->getType().cast(); auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, type, memory); @@ -632,7 +633,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { gpuFuncOp.getNumWorkgroupAttributions(); auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { - Value *attribution = en.value(); + ValuePtr attribution = en.value(); auto type = attribution->getType().cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -643,10 +644,10 @@ struct GPUFuncOpLowering : LLVMOpLowering { auto ptrType = lowering.convertType(type.getElementType()) .cast() .getPointerTo(); - Value *numElements = rewriter.create( + ValuePtr numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, rewriter.getI64IntegerAttr(type.getNumElements())); - Value *allocated = rewriter.create( + ValuePtr allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, type, allocated); @@ -674,8 +675,8 @@ struct GPUFuncOpLowering : LLVMOpLowering { !en.value().isa()) continue; - BlockArgument *arg = block.getArgument(en.index()); - Value *loaded = rewriter.create(loc, arg); + BlockArgumentPtr arg = block.getArgument(en.index()); + ValuePtr loaded = rewriter.create(loc, arg); rewriter.replaceUsesOfBlockArgument(arg, loaded); } } @@ -692,7 +693,7 @@ struct GPUReturnOpLowering : public LLVMOpLowering { typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, ArrayRef()); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 42483a6e5df..0c34fc2b8e1 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -36,7 +36,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(loop::ForOp forOp, ArrayRef operands, + matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -48,7 +48,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -65,7 +65,7 @@ public: } PatternMatchResult - matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; private: @@ -79,7 +79,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ModuleOp moduleOp, ArrayRef operands, + matchAndRewrite(ModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -92,7 +92,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef operands, + 
matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -103,7 +103,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -114,7 +114,7 @@ public: //===----------------------------------------------------------------------===// PatternMatchResult -ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, +ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { // loop::ForOp can be lowered to the structured control flow represented by // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop @@ -135,7 +135,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); // Create the new induction variable to use. - BlockArgument *newIndVar = + BlockArgumentPtr newIndVar = header->addArgument(forOperands.lowerBound()->getType()); Block *body = forOp.getBody(); @@ -166,7 +166,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, auto cmpOp = rewriter.create( loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); rewriter.create( - loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); + loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); // Generate instructions to increment the step of the induction variable and // branch to the header. @@ -174,7 +174,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. 
- Value *updatedIndVar = rewriter.create( + ValuePtr updatedIndVar = rewriter.create( loc, newIndVar->getType(), newIndVar, forOperands.step()); rewriter.create(loc, header, updatedIndVar); @@ -188,7 +188,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, template PatternMatchResult LaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto dimAttr = op.getOperation()->template getAttrOfType("dimension"); @@ -267,7 +267,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, PatternMatchResult KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, - ArrayRef operands, + ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!gpu::GPUDialect::isKernel(funcOp)) { return matchFailure(); @@ -297,7 +297,7 @@ KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, //===----------------------------------------------------------------------===// PatternMatchResult KernelModuleConversion::matchAndRewrite( - ModuleOp moduleOp, ArrayRef operands, + ModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!moduleOp.getAttrOfType( gpu::GPUDialect::getKernelModuleAttrName())) { @@ -327,7 +327,7 @@ PatternMatchResult KernelModuleConversion::matchAndRewrite( //===----------------------------------------------------------------------===// PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( - ModuleTerminatorOp terminatorOp, ArrayRef operands, + ModuleTerminatorOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(terminatorOp); return matchSuccess(); @@ -338,7 +338,7 @@ PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( //===----------------------------------------------------------------------===// PatternMatchResult GPUReturnOpConversion::matchAndRewrite( - gpu::ReturnOp returnOp, ArrayRef operands, + gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!operands.empty()) return matchFailure(); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 3eb23c19dc7..8b6b9fb7930 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -120,21 +120,23 @@ public: BaseViewConversionHelper(Type type) : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} - BaseViewConversionHelper(Value *v) : d(v) {} + BaseViewConversionHelper(ValuePtr v) : d(v) {} /// Wrappers around MemRefDescriptor that use EDSC builder and location. 
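// Illustrative sketch, not part of the patch: a C++ analogue of the structured
// control flow built by the loop::ForOp to spirv::LoopOp conversion earlier in
// this hunk. The entry block branches to a header that carries the induction
// variable as a block argument, the header either enters the body or exits to
// the merge block, and the continue block steps the induction variable and
// branches back to the header.
#include <cstdint>
#include <functional>

void structuredForLoop(int64_t lowerBound, int64_t upperBound, int64_t step,
                       const std::function<void(int64_t)> &body) {
  int64_t iv = lowerBound; // entry block: branch to the header with %lb
  for (;;) {
    if (!(iv < upperBound)) // header: conditional branch to body or merge
      break;
    body(iv);               // body block(s)
    iv += step;             // continue block: add step, branch back to header
  }
}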
- Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } - void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); } - Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); } - void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); } - Value *offset() { return d.offset(rewriter(), loc()); } - void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); } - Value *size(unsigned i) { return d.size(rewriter(), loc(), i); } - void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); } - Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); } - void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); } - - operator Value *() { return d; } + ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); } + ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); } + void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); } + ValuePtr offset() { return d.offset(rewriter(), loc()); } + void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); } + ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); } + void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); } + ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + void setStride(unsigned i, ValuePtr v) { + d.setStride(rewriter(), loc(), i, v); + } + + operator ValuePtr() { return d; } private: OpBuilder &rewriter() { return ScopedContext::getBuilder(); } @@ -151,7 +153,7 @@ public: : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = @@ -161,7 +163,7 @@ public: // Fill in an aggregate value of the descriptor. RangeOpOperandAdaptor adaptor(operands); - Value *desc = llvm_undef(rangeDescriptorTy); + ValuePtr desc = llvm_undef(rangeDescriptorTy); desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); @@ -184,7 +186,7 @@ public: : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); SliceOpOperandAdaptor adaptor(operands); @@ -198,7 +200,7 @@ public: BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); // TODO(ntv): extract sizes and emit asserts. - SmallVector strides(memRefType.getRank()); + SmallVector strides(memRefType.getRank()); for (int i = 0, e = memRefType.getRank(); i < e; ++i) strides[i] = baseDesc.stride(i); @@ -207,10 +209,10 @@ public: }; // Compute base offset. 
- Value *baseOffset = baseDesc.offset(); + ValuePtr baseOffset = baseDesc.offset(); for (int i = 0, e = memRefType.getRank(); i < e; ++i) { - Value *indexing = adaptor.indexings()[i]; - Value *min = indexing; + ValuePtr indexing = adaptor.indexings()[i]; + ValuePtr min = indexing; if (sliceOp.indexing(i)->getType().isa()) min = extractvalue(int64Ty, indexing, pos(0)); baseOffset = add(baseOffset, mul(min, strides[i])); @@ -227,29 +229,29 @@ public: if (sliceOp.getViewType().getRank() == 0) return rewriter.replaceOp(op, {desc}), matchSuccess(); - Value *zero = + ValuePtr zero = constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); // Compute and insert view sizes (max - min along the range) and strides. // Skip the non-range operands as they will be projected away from the view. int numNewDims = 0; for (auto en : llvm::enumerate(sliceOp.indexings())) { - Value *indexing = en.value(); + ValuePtr indexing = en.value(); if (indexing->getType().isa()) { int rank = en.index(); - Value *rangeDescriptor = adaptor.indexings()[rank]; - Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); - Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); - Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); - Value *baseSize = baseDesc.size(rank); + ValuePtr rangeDescriptor = adaptor.indexings()[rank]; + ValuePtr min = extractvalue(int64Ty, rangeDescriptor, pos(0)); + ValuePtr max = extractvalue(int64Ty, rangeDescriptor, pos(1)); + ValuePtr step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + ValuePtr baseSize = baseDesc.size(rank); // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); - Value *size = sub(max, min); + ValuePtr size = sub(max, min); // Bound lower by zero. size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); - Value *stride = mul(strides[rank], step); + ValuePtr stride = mul(strides[rank], step); desc.setSize(numNewDims, size); desc.setStride(numNewDims, stride); ++numNewDims; @@ -275,7 +277,7 @@ public: : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. edsc::ScopedContext context(rewriter, op->getLoc()); @@ -318,7 +320,7 @@ public: : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return matchSuccess(); @@ -453,7 +455,7 @@ public: op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); auto indexedGenericOp = cast(op); auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector operands; + SmallVector operands; operands.reserve(numLoops + op.getNumOperands()); for (unsigned i = 0; i < numLoops; ++i) { operands.push_back(zero); @@ -477,7 +479,7 @@ public: PatternMatchResult matchAndRewrite(CopyOp op, PatternRewriter &rewriter) const override { - Value *in = op.input(), *out = op.output(); + ValuePtr in = op.input(), out = op.output(); // If either inputPerm or outputPerm are non-identities, insert transposes. 
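// Illustrative sketch, not part of the patch: the offset/size/stride
// arithmetic the slice lowering above performs on memref descriptor fields,
// written on plain integers and assuming every indexing is a range (point
// indexings only contribute to the offset and are projected away). The Range
// struct and all names here are illustrative.
#include <algorithm>
#include <cstdint>
#include <vector>

struct Range { int64_t min, max, step; };
struct ViewDims { int64_t offset; std::vector<int64_t> sizes, strides; };

ViewDims computeSliceView(int64_t baseOffset,
                          const std::vector<int64_t> &baseSizes,
                          const std::vector<int64_t> &baseStrides,
                          const std::vector<Range> &ranges) {
  ViewDims view;
  view.offset = baseOffset;
  for (size_t i = 0, e = ranges.size(); i < e; ++i) {
    const Range &r = ranges[i];
    view.offset += r.min * baseStrides[i];
    // Size is (max - min) with max bounded by the base view's size and the
    // result bounded below by zero; the stride is scaled by the range's step.
    int64_t max = std::min(r.max, baseSizes[i]);
    view.sizes.push_back(std::max<int64_t>(max - r.min, 0));
    view.strides.push_back(baseStrides[i] * r.step);
  }
  return view;
}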
auto inputPerm = op.inputPermutation(); diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index ff93ce58fd4..d8df7487e71 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -182,22 +182,22 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { rewriter.splitBlock(conditionBlock, conditionBlock->begin()); auto *lastBodyBlock = &forOp.region().back(); rewriter.inlineRegionBefore(forOp.region(), endBlock); - auto *iv = conditionBlock->getArgument(0); + auto iv = conditionBlock->getArgument(0); // Append the induction variable stepping logic to the last body block and // branch back to the condition block. Construct an expression f : // (x -> x+step) and apply this expression to the induction variable. rewriter.setInsertionPointToEnd(lastBodyBlock); - auto *step = forOp.step(); - auto *stepped = rewriter.create(loc, iv, step).getResult(); + auto step = forOp.step(); + auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return matchFailure(); rewriter.create(loc, conditionBlock, stepped); // Compute loop bounds before branching to the condition. rewriter.setInsertionPointToEnd(initBlock); - Value *lowerBound = forOp.lowerBound(); - Value *upperBound = forOp.upperBound(); + ValuePtr lowerBound = forOp.lowerBound(); + ValuePtr upperBound = forOp.upperBound(); if (!lowerBound || !upperBound) return matchFailure(); rewriter.create(loc, conditionBlock, lowerBound); @@ -208,8 +208,8 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, - ArrayRef()); + ArrayRef(), endBlock, + ArrayRef()); // Ok, we're done! rewriter.eraseOp(forOp); return matchSuccess(); @@ -248,8 +248,8 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, ifOp.condition(), thenBlock, - /*trueArgs=*/ArrayRef(), elseBlock, - /*falseArgs=*/ArrayRef()); + /*trueArgs=*/ArrayRef(), elseBlock, + /*falseArgs=*/ArrayRef()); // Ok, we're done! rewriter.eraseOp(ifOp); diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index d663ae105f2..3cbce7caa76 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -43,7 +43,7 @@ using namespace mlir::loop; using llvm::seq; // Extract an indexed value from KernelDim3. 
-static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { +static ValuePtr getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { case 0: return dim3.x; @@ -61,8 +61,8 @@ static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) { return forOp.getLowerBoundOperands(); } -static SmallVector getLowerBoundOperands(ForOp forOp) { - SmallVector bounds(1, forOp.lowerBound()); +static SmallVector getLowerBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.lowerBound()); return bounds; } @@ -70,33 +70,35 @@ static SmallVector getLowerBoundOperands(ForOp forOp) { static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { return forOp.getUpperBoundOperands(); } -static SmallVector getUpperBoundOperands(ForOp forOp) { - SmallVector bounds(1, forOp.upperBound()); +static SmallVector getUpperBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.upperBound()); return bounds; } // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. -static Value *getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { +static ValuePtr getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { return builder.create(forOp.getLoc(), forOp.getStep()); } -static Value *getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); } +static ValuePtr getOrCreateStep(ForOp forOp, OpBuilder &) { + return forOp.step(); +} // Get a Value for the loop lower bound. If the value requires computation, // materialize the instructions using builder. -static Value *getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { +static ValuePtr getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineLowerBound(forOp, builder); } -static Value *getOrEmitLowerBound(ForOp forOp, OpBuilder &) { +static ValuePtr getOrEmitLowerBound(ForOp forOp, OpBuilder &) { return forOp.lowerBound(); } // Get a Value for the loop upper bound. If the value requires computation, // materialize the instructions using builder. -static Value *getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { +static ValuePtr getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineUpperBound(forOp, builder); } -static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) { +static ValuePtr getOrEmitUpperBound(ForOp forOp, OpBuilder &) { return forOp.upperBound(); } @@ -212,18 +214,18 @@ struct LoopToGpuConverter { unsigned numThreadDims); // Ranges of the loops mapped to blocks or threads. - SmallVector dims; + SmallVector dims; // Lower bounds of the loops mapped to blocks or threads. - SmallVector lbs; + SmallVector lbs; // Induction variables of the loops mapped to blocks or threads. - SmallVector ivs; + SmallVector ivs; // Steps of the loops mapped to blocks or threads. - SmallVector steps; + SmallVector steps; }; } // namespace // Return true if the value is obviously a constant "one". 
-static bool isConstantOne(Value *value) { +static bool isConstantOne(ValuePtr value) { if (auto def = dyn_cast_or_null(value->getDefiningOp())) return def.getValue() == 1; return false; @@ -244,15 +246,15 @@ Optional LoopToGpuConverter::collectBounds(OpTy forOp, steps.reserve(numLoops); OpTy currentLoop = forOp; for (unsigned i = 0; i < numLoops; ++i) { - Value *lowerBound = getOrEmitLowerBound(currentLoop, builder); - Value *upperBound = getOrEmitUpperBound(currentLoop, builder); + ValuePtr lowerBound = getOrEmitLowerBound(currentLoop, builder); + ValuePtr upperBound = getOrEmitUpperBound(currentLoop, builder); if (!lowerBound || !upperBound) { return llvm::None; } - Value *range = + ValuePtr range = builder.create(currentLoop.getLoc(), upperBound, lowerBound); - Value *step = getOrCreateStep(currentLoop, builder); + ValuePtr step = getOrCreateStep(currentLoop, builder); if (!isConstantOne(step)) range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); @@ -274,8 +276,8 @@ Optional LoopToGpuConverter::collectBounds(OpTy forOp, /// `nids`. The innermost loop is mapped to the x-dimension, followed by the /// next innermost loop to y-dimension, followed by z-dimension. template -OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, - ArrayRef nids) { +OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, + ArrayRef nids) { auto nDims = ids.size(); assert(nDims == nids.size()); for (auto dim : llvm::seq(0, nDims)) { @@ -295,11 +297,11 @@ OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, /// each workgroup/workitem and number of workgroup/workitems along a dimension /// of the launch into a container. void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids, - unsigned nDims, SmallVectorImpl &ids, - SmallVectorImpl &nids) { + unsigned nDims, SmallVectorImpl &ids, + SmallVectorImpl &nids) { assert(nDims <= 3 && "invalid number of launch dimensions"); - SmallVector allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; - SmallVector allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; + SmallVector allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; + SmallVector allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; ids.clear(); ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end()); nids.clear(); @@ -317,7 +319,7 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, auto returnOp = builder.create(launchOp.getLoc()); rootForOp.getOperation()->moveBefore(returnOp); - SmallVector workgroupID, numWorkGroups; + SmallVector workgroupID, numWorkGroups; packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims, workgroupID, numWorkGroups); @@ -333,7 +335,7 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, } } - SmallVector workItemID, workGroupSize; + SmallVector workItemID, workGroupSize; packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(), numThreadDims, workItemID, workGroupSize); for (auto &loopOp : threadRootForOps) { @@ -347,17 +349,17 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, // given workgroup size and number of workgroups. 
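// Illustrative sketch, not part of the patch: the per-loop trip-count value
// that collectBounds above materializes and later uses as a gpu.launch grid or
// block dimension. The division is emitted only when the step is not the
// constant 1, which is exactly what isConstantOne guards.
#include <cstdint>

int64_t launchDimForLoop(int64_t lowerBound, int64_t upperBound, int64_t step) {
  int64_t range = upperBound - lowerBound;
  if (step != 1)
    range = range / step;
  return range;
}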
template LogicalResult createLaunchFromOp(OpTy rootForOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes) { + ArrayRef numWorkGroups, + ArrayRef workGroupSizes) { OpBuilder builder(rootForOp.getOperation()); if (numWorkGroups.size() > 3) { return rootForOp.emitError("invalid ") << numWorkGroups.size() << "-D workgroup specification"; } auto loc = rootForOp.getLoc(); - Value *one = builder.create( + ValuePtr one = builder.create( loc, builder.getIntegerAttr(builder.getIndexType(), 1)); - SmallVector numWorkGroups3D(3, one), workGroupSize3D(3, one); + SmallVector numWorkGroups3D(3, one), workGroupSize3D(3, one); for (auto numWorkGroup : enumerate(numWorkGroups)) { numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value(); } @@ -367,7 +369,7 @@ LogicalResult createLaunchFromOp(OpTy rootForOp, // Get the values used within the region of the rootForOp but defined above // it. - llvm::SetVector valuesToForwardSet; + llvm::SetVector valuesToForwardSet; getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(), valuesToForwardSet); // Also add the values used for the lb, ub, and step of the rootForOp. @@ -387,8 +389,8 @@ LogicalResult createLaunchFromOp(OpTy rootForOp, // defined outside. They all are replaced with kernel arguments. for (const auto &pair : llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { - Value *from = std::get<0>(pair); - Value *to = std::get<1>(pair); + ValuePtr from = std::get<0>(pair); + ValuePtr to = std::get<1>(pair); replaceAllUsesInRegionWith(from, to, launchOp.body()); } return success(); @@ -408,22 +410,23 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, OpBuilder builder(rootForOp.getOperation()); // Prepare the grid and block sizes for the launch operation. If there is // no loop mapped to a specific dimension, use constant "1" as its size. - Value *constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) - : nullptr; - Value *gridSizeX = dims[0]; - Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne; - Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; - Value *blockSizeX = dims[numBlockDims]; - Value *blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; - Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; + ValuePtr constOne = + (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; + ValuePtr gridSizeX = dims[0]; + ValuePtr gridSizeY = numBlockDims > 1 ? dims[1] : constOne; + ValuePtr gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; + ValuePtr blockSizeX = dims[numBlockDims]; + ValuePtr blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; + ValuePtr blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; // Create a launch op and move the body region of the innermost loop to the // launch op. Pass the values defined outside the outermost loop and used // inside the innermost loop and loop lower bounds as kernel data arguments. // Still assuming perfect nesting so there are no values other than induction // variables that are defined in one loop and used in deeper loops. 
- llvm::SetVector valuesToForwardSet; + llvm::SetVector valuesToForwardSet; getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(), valuesToForwardSet); auto valuesToForward = valuesToForwardSet.takeVector(); @@ -457,15 +460,15 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, originallyForwardedValues); auto stepArgumentIt = std::next(lbArgumentIt, lbs.size()); for (auto en : llvm::enumerate(ivs)) { - Value *id = + ValuePtr id = en.index() < numBlockDims ? getDim3Value(launchOp.getBlockIds(), en.index()) : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); - Value *step = steps[en.index()]; + ValuePtr step = steps[en.index()]; if (!isConstantOne(step)) id = builder.create(rootForOp.getLoc(), step, id); - Value *ivReplacement = + ValuePtr ivReplacement = builder.create(rootForOp.getLoc(), *lbArgumentIt, id); en.value()->replaceAllUsesWith(ivReplacement); replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt, @@ -479,8 +482,8 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, // trailing positions, make sure we don't touch those. for (const auto &pair : llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { - Value *from = std::get<0>(pair); - Value *to = std::get<1>(pair); + ValuePtr from = std::get<0>(pair); + ValuePtr to = std::get<1>(pair); replaceAllUsesInRegionWith(from, to, launchOp.body()); } @@ -510,8 +513,8 @@ static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, // nested. The workgroup size and num workgroups is provided as input template static LogicalResult convertLoopToGPULaunch(OpTy forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSize) { + ArrayRef numWorkGroups, + ArrayRef workGroupSize) { if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(), workGroupSize.size()))) { return failure(); @@ -532,7 +535,7 @@ LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp, } LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes) { + ArrayRef numWorkGroups, + ArrayRef workGroupSizes) { return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes); } diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 21abc3cf99b..63836883512 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -98,7 +98,7 @@ struct ImperfectlyNestedForLoopMapper // pass is only used for testing. 
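// Illustrative sketch, not part of the patch: how createLaunch above rebuilds
// each original induction variable inside the gpu.launch body from the
// hardware id (block or thread id along the mapped dimension), the forwarded
// lower bound, and the loop step. The multiplication is skipped when the step
// is the constant 1.
#include <cstdint>

int64_t reconstructInductionVar(int64_t lowerBound, int64_t step, int64_t id) {
  if (step != 1)
    id = step * id;
  return lowerBound + id;
}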
FuncOp funcOp = getFunction(); OpBuilder builder(funcOp.getOperation()->getRegion(0)); - SmallVector numWorkGroupsVal, workGroupSizeVal; + SmallVector numWorkGroupsVal, workGroupSizeVal; for (auto val : numWorkGroups) { auto constOp = builder.create( funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index fdc90851b64..67b545c4ec8 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -256,20 +256,20 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ -StructBuilder::StructBuilder(Value *v) : value(v) { +StructBuilder::StructBuilder(ValuePtr v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value->getType().cast(); } -Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { Type type = structType.cast().getStructElementType(pos); return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, - Value *ptr) { + ValuePtr ptr) { value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } @@ -278,7 +278,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, /*============================================================================*/ /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(Value *descriptor) +MemRefDescriptor::MemRefDescriptor(ValuePtr descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value->getType().cast().getStructElementType( @@ -289,7 +289,7 @@ MemRefDescriptor::MemRefDescriptor(Value *descriptor) MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value *descriptor = + ValuePtr descriptor = builder.create(loc, descriptorType.cast()); return MemRefDescriptor(descriptor); } @@ -300,7 +300,7 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - MemRefType type, Value *memory) { + MemRefType type, ValuePtr memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); assert(type.getAffineMaps().empty() && "unexpected layout map"); @@ -325,37 +325,37 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, } /// Builds IR extracting the allocated pointer from the descriptor. -Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, - Value *ptr) { + ValuePtr ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. 
-Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - Value *ptr) { + ValuePtr ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. -static Value *createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { +static ValuePtr createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { return builder.create( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. -Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -363,7 +363,7 @@ Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - Value *offset) { + ValuePtr offset) { value = builder.create( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -377,7 +377,8 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { +ValuePtr MemRefDescriptor::size(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -385,7 +386,7 @@ Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, - Value *size) { + ValuePtr size) { value = builder.create( loc, structType, value, size, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -399,8 +400,8 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr MemRefDescriptor::stride(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -408,7 +409,7 @@ Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, - Value *stride) { + ValuePtr stride) { value = builder.create( loc, structType, value, stride, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -431,30 +432,30 @@ LLVM::LLVMType MemRefDescriptor::getElementType() { /*============================================================================*/ /// Construct a helper for the given descriptor value. 
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor) +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValuePtr descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value *descriptor = + ValuePtr descriptor = builder.create(loc, descriptorType.cast()); return UnrankedMemRefDescriptor(descriptor); } -Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { +ValuePtr UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, - Value *v) { + ValuePtr v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } -Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, - Location loc) { +ValuePtr UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, - Location loc, Value *v) { + Location loc, ValuePtr v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } namespace { @@ -495,8 +496,8 @@ public: } // Create an LLVM IR pseudo-operation defining the given index constant. - Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, - uint64_t value) const { + ValuePtr createIndexConstant(ConversionPatternRewriter &builder, Location loc, + uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } @@ -508,7 +509,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); FunctionType type = funcOp.getType(); @@ -556,8 +557,8 @@ struct FuncOpConversion : public LLVMLegalizationPattern { Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); for (unsigned idx : promotedArgIndices) { - BlockArgument *arg = firstBlock->getArgument(idx); - Value *loaded = rewriter.create(funcOp.getLoc(), arg); + BlockArgumentPtr arg = firstBlock->getArgument(idx); + ValuePtr loaded = rewriter.create(funcOp.getLoc(), arg); rewriter.replaceUsesOfBlockArgument(arg, loaded); } } @@ -656,7 +657,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); @@ -680,7 +681,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. 
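// For orientation, and not taken from the patch itself: a plain C++ sketch of
// the struct layouts that StructBuilder, MemRefDescriptor and
// UnrankedMemRefDescriptor index into via extractvalue/insertvalue. Field
// order follows the k*PosInMemRefDescriptor / k*InUnrankedMemRefDescriptor
// constants used above; treat this as an illustration of the convention, not
// as a definition.
#include <cstdint>

template <typename T, int Rank>
struct RankedMemRefDescriptorSketch {
  T *allocatedPtr;       // position 0: pointer returned by the allocator
  T *alignedPtr;         // position 1: pointer the data is accessed through
  int64_t offset;        // position 2: offset into the aligned buffer
  int64_t sizes[Rank];   // position 3: size of each dimension
  int64_t strides[Rank]; // position 4: stride of each dimension
};

struct UnrankedMemRefDescriptorSketch {
  int64_t rank;           // position 0
  void *rankedDescriptor; // position 1: pointer to a ranked descriptor
};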
- SmallVector results; + SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -721,7 +722,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ValidateOpCount(); static_assert( @@ -732,7 +733,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { "expected same operands and result type"); // Cannot convert ops if their operands are not of LLVM type. - for (Value *operand : operands) { + for (ValuePtr operand : operands) { if (!operand || !operand->getType().isa()) return this->matchFailure(); } @@ -755,16 +756,16 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return this->matchFailure(); - Value *desc = rewriter.create(loc, llvmArrayTy); + ValuePtr desc = rewriter.create(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors - SmallVector extractedOperands; + SmallVector extractedOperands; for (unsigned i = 0; i < OpCount; ++i) { extractedOperands.push_back(rewriter.create( loc, llvmVectorTy, operands[i], position)); } - Value *newVal = rewriter.create( + ValuePtr newVal = rewriter.create( loc, llvmVectorTy, extractedOperands, op->getAttrs()); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); @@ -927,7 +928,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { return matchSuccess(); } - void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto allocOp = cast(op); @@ -936,7 +937,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). - SmallVector sizes; + SmallVector sizes; sizes.reserve(type.getRank()); unsigned i = 0; for (int64_t s : type.getShape()) @@ -946,10 +947,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { sizes.push_back(createIndexConstant(rewriter, loc, 1)); // Compute the total number of memref elements. - Value *cumulativeSize = sizes.front(); + ValuePtr cumulativeSize = sizes.front(); for (unsigned i = 1, e = sizes.size(); i < e; ++i) cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); + loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) 
implementation in LLVM IR: @@ -962,17 +963,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern { auto nullPtr = rewriter.create(loc, convertedPtrType); auto one = createIndexConstant(rewriter, loc, 1); auto gep = rewriter.create(loc, convertedPtrType, - ArrayRef{nullPtr, one}); + ArrayRef{nullPtr, one}); auto elementSize = rewriter.create(loc, getIndexType(), gep); cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); + loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - Value *allocated = nullptr; + ValuePtr allocated = nullptr; int alignment = 0; - Value *alignmentValue = nullptr; + ValuePtr alignmentValue = nullptr; if (auto alignAttr = allocOp.alignment()) alignment = alignAttr.getValue().getSExtValue(); @@ -1008,8 +1009,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern { auto structElementType = lowering.convertType(elementType); auto elementPtrType = structElementType.cast().getPointerTo( type.getMemorySpace()); - Value *bitcastAllocated = rewriter.create( - loc, elementPtrType, ArrayRef(allocated)); + ValuePtr bitcastAllocated = rewriter.create( + loc, elementPtrType, ArrayRef(allocated)); int64_t offset; SmallVector strides; @@ -1031,22 +1032,22 @@ struct AllocOpLowering : public LLVMLegalizationPattern { memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); // Field 2: Actual aligned pointer to payload. - Value *bitcastAligned = bitcastAllocated; + ValuePtr bitcastAligned = bitcastAllocated; if (!useAlloca && alignment != 0) { assert(alignmentValue); // offset = (align - (ptr % align))% align - Value *intVal = rewriter.create( + ValuePtr intVal = rewriter.create( loc, this->getIndexType(), allocated); - Value *ptrModAlign = + ValuePtr ptrModAlign = rewriter.create(loc, intVal, alignmentValue); - Value *subbed = + ValuePtr subbed = rewriter.create(loc, alignmentValue, ptrModAlign); - Value *offset = + ValuePtr offset = rewriter.create(loc, subbed, alignmentValue); - Value *aligned = rewriter.create(loc, allocated->getType(), - allocated, offset); + ValuePtr aligned = rewriter.create(loc, allocated->getType(), + allocated, offset); bitcastAligned = rewriter.create( - loc, elementPtrType, ArrayRef(aligned)); + loc, elementPtrType, ArrayRef(aligned)); } memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); @@ -1061,10 +1062,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. - Value *runningStride = nullptr; + ValuePtr runningStride = nullptr; // Iterate strides in reverse order, compute runningStride and strideValues. 
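// Illustrative sketch, not part of the patch: the size and alignment
// arithmetic AllocOpLowering emits, written on the host. In the generated IR
// the element size comes from the null-GEP idiom (gep null, 1 followed by
// ptrtoint); sizeof stands in for it here. The alignment adjustment mirrors
// the emitted (align - (ptr % align)) % align sequence.
#include <cstdint>
#include <vector>

template <typename ElementT>
int64_t allocationSizeInBytes(const std::vector<int64_t> &sizes) {
  int64_t numElements = 1; // a zero-dimensional memref is a single scalar
  for (int64_t s : sizes)  // static sizes are constants, dynamic sizes are
    numElements *= s;      // the alloc op's operands
  return numElements * static_cast<int64_t>(sizeof(ElementT));
}

uintptr_t alignPointerUp(uintptr_t allocated, uintptr_t alignment) {
  uintptr_t offset = (alignment - (allocated % alignment)) % alignment;
  return allocated + offset;
}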
auto nStrides = strides.size(); - SmallVector strideValues(nStrides, nullptr); + SmallVector strideValues(nStrides, nullptr); for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) { int64_t index = nStrides - 1 - indexedStride.index(); if (strides[index] == MemRefType::getDynamicStrideOrOffset()) @@ -1101,7 +1102,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern { using Base = LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); auto callOp = cast(op); @@ -1139,7 +1140,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern { // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around // a particular interaction between MemRefType and CallOp lowering. Find a // way to avoid special casing. - SmallVector results; + SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -1173,7 +1174,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { useAlloca(useAlloca) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (useAlloca) return rewriter.eraseOp(op), matchSuccess(); @@ -1193,7 +1194,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { } MemRefDescriptor memref(transformed.memref()); - Value *casted = rewriter.create( + ValuePtr casted = rewriter.create( op->getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( @@ -1209,7 +1210,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVMFuncOpT = LLVM::LLVMFuncOp; @@ -1283,7 +1284,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { : matchFailure(); } - void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); OperandAdaptor transformed(operands); @@ -1324,7 +1325,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); - rewriter.replaceOp(op, (Value *)memRefDesc); + rewriter.replaceOp(op, (ValuePtr)memRefDesc); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. @@ -1355,7 +1356,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); OperandAdaptor transformed(operands); @@ -1397,43 +1398,45 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // by accumulating the running linearized value. // Note that `indices` and `allocSizes` are passed in the same order as they // appear in load/store operations and memref type declarations. 
- Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, - ArrayRef indices, - ArrayRef allocSizes) const { + ValuePtr linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, + ArrayRef indices, + ArrayRef allocSizes) const { assert(indices.size() == allocSizes.size() && "mismatching number of indices and allocation sizes"); assert(!indices.empty() && "cannot linearize a 0-dimensional access"); - Value *linearized = indices.front(); + ValuePtr linearized = indices.front(); for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { linearized = builder.create( loc, this->getIndexType(), - ArrayRef{linearized, allocSizes[i]}); + ArrayRef{linearized, allocSizes[i]}); linearized = builder.create( - loc, this->getIndexType(), ArrayRef{linearized, indices[i]}); + loc, this->getIndexType(), + ArrayRef{linearized, indices[i]}); } return linearized; } // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value *getStridedElementPtr(Location loc, Type elementTypePtr, - Value *descriptor, ArrayRef indices, - ArrayRef strides, int64_t offset, - ConversionPatternRewriter &rewriter) const { + ValuePtr getStridedElementPtr(Location loc, Type elementTypePtr, + ValuePtr descriptor, ArrayRef indices, + ArrayRef strides, int64_t offset, + ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); - Value *base = memRefDescriptor.alignedPtr(rewriter, loc); - Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : this->createIndexConstant(rewriter, loc, offset); + ValuePtr base = memRefDescriptor.alignedPtr(rewriter, loc); + ValuePtr offsetValue = + offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { - Value *stride = + ValuePtr stride = strides[i] == MemRefType::getDynamicStrideOrOffset() ? 
memRefDescriptor.stride(rewriter, loc, i) : this->createIndexConstant(rewriter, loc, strides[i]); - Value *additionalOffset = + ValuePtr additionalOffset = rewriter.create(loc, indices[i], stride); offsetValue = rewriter.create(loc, offsetValue, additionalOffset); @@ -1441,10 +1444,10 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { return rewriter.create(loc, elementTypePtr, base, offsetValue); } - Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc, - ArrayRef indices, - ConversionPatternRewriter &rewriter, - llvm::Module &module) const { + ValuePtr getDataPtr(Location loc, MemRefType type, ValuePtr memRefDesc, + ArrayRef indices, + ConversionPatternRewriter &rewriter, + llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); int64_t offset; SmallVector strides; @@ -1462,14 +1465,14 @@ struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, dataPtr); return matchSuccess(); } @@ -1481,13 +1484,13 @@ struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return matchSuccess(); @@ -1500,14 +1503,14 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); OperandAdaptor transformed(operands); auto type = prefetchOp.getMemRefType(); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. 
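// Illustrative sketch, not part of the patch: the two address computations
// defined above, written on plain integers. linearizeSubscripts folds indices
// with allocation sizes; getStridedElementPtr addresses
// alignedPtr[offset + sum_i index_i * stride_i], taking the offset and each
// stride from the descriptor when they are dynamic.
#include <cstdint>
#include <vector>

int64_t linearizeSubscripts(const std::vector<int64_t> &indices,
                            const std::vector<int64_t> &allocSizes) {
  int64_t linearized = indices.front();
  for (size_t i = 1, e = indices.size(); i < e; ++i)
    linearized = linearized * allocSizes[i] + indices[i];
  return linearized;
}

int64_t stridedElementIndex(int64_t offset,
                            const std::vector<int64_t> &indices,
                            const std::vector<int64_t> &strides) {
  for (size_t i = 0, e = indices.size(); i < e; ++i)
    offset += indices[i] * strides[i];
  return offset;
}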
auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); @@ -1535,7 +1538,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); auto indexCastOp = cast(op); @@ -1570,7 +1573,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpOperandAdaptor transformed(operands); @@ -1589,7 +1592,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpOperandAdaptor transformed(operands); @@ -1641,9 +1644,9 @@ struct OneToOneLLVMTerminatorLowering using Super = OneToOneLLVMTerminatorLowering; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const override { SmallVector operandRanges(operands.begin(), operands.end()); rewriter.replaceOpWithNewOp(op, properOperands, destinations, @@ -1662,19 +1665,19 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp( - op, ArrayRef(), ArrayRef(), op->getAttrs()); + op, ArrayRef(), ArrayRef(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, ArrayRef(operands.front()), ArrayRef(), + op, ArrayRef(operands.front()), ArrayRef(), op->getAttrs()); return matchSuccess(); } @@ -1684,7 +1687,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { auto packedType = lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); - Value *packed = rewriter.create(op->getLoc(), packedType); + ValuePtr packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( op->getLoc(), packedType, packed, operands[i], @@ -1712,7 +1715,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); @@ -1721,7 +1724,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern { // First insert it into an undef vector so we can shuffle it. 
auto vectorType = lowering.convertType(splatOp.getType()); - Value *undef = rewriter.create(op->getLoc(), vectorType); + ValuePtr undef = rewriter.create(op->getLoc(), vectorType); auto zero = rewriter.create( op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); @@ -1746,7 +1749,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); OperandAdaptor adaptor(operands); @@ -1763,16 +1766,16 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Construct returned value. - Value *desc = rewriter.create(loc, llvmArrayTy); + ValuePtr desc = rewriter.create(loc, llvmArrayTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. - Value *vdesc = rewriter.create(loc, llvmVectorTy); + ValuePtr vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( loc, lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value *v = rewriter.create(loc, llvmVectorTy, vdesc, - adaptor.input(), zero); + ValuePtr v = rewriter.create( + loc, llvmVectorTy, vdesc, adaptor.input(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); @@ -1800,21 +1803,21 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. - SmallVector dynamicOffsets( + SmallVector dynamicOffsets( std::next(operands.begin()), std::next(operands.begin(), 1 + viewOp.getNumOffsets())); - SmallVector dynamicSizes( + SmallVector dynamicSizes( std::next(operands.begin(), 1 + viewOp.getNumOffsets()), std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); - SmallVector dynamicStrides( + SmallVector dynamicStrides( std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), operands.end()); @@ -1851,8 +1854,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. - Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); - Value *bitcastPtr = rewriter.create( + ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1862,7 +1865,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. 
- SmallVector strideValues; + SmallVector strideValues; strideValues.reserve(viewMemRefType.getRank()); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); @@ -1879,9 +1882,9 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { } // Offset. - Value *baseOffset = sourceMemRef.offset(rewriter, loc); + ValuePtr baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - Value *min = dynamicOffsets[i]; + ValuePtr min = dynamicOffsets[i]; baseOffset = rewriter.create( loc, baseOffset, rewriter.create(loc, min, strideValues[i])); @@ -1891,7 +1894,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); - Value *newStride; + ValuePtr newStride; if (dynamicStrides.empty()) newStride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); @@ -1916,9 +1919,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. - Value *getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef shape, ArrayRef dynamicSizes, - unsigned idx) const { + ValuePtr getSize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef shape, ArrayRef dynamicSizes, + unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); @@ -1933,9 +1936,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. - Value *getStride(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef strides, Value *nextSize, - Value *runningStride, unsigned idx) const { + ValuePtr getStride(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef strides, ValuePtr nextSize, + ValuePtr runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); @@ -1948,7 +1951,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern { } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); @@ -1975,8 +1978,8 @@ struct ViewOpLowering : public LLVMLegalizationPattern { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. - Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); - Value *bitcastPtr = rewriter.create( + ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1993,10 +1996,10 @@ struct ViewOpLowering : public LLVMLegalizationPattern { auto sizeAndOffsetOperands = adaptor.operands(); assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + (hasDynamicOffset ? 1 : 0)); - Value *baseOffset = !hasDynamicOffset - ? 
createIndexConstant(rewriter, loc, offset) - // TODO(ntv): better adaptor. - : sizeAndOffsetOperands.front(); + ValuePtr baseOffset = !hasDynamicOffset + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.front(); targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. @@ -2007,14 +2010,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), matchFailure(); - Value *stride = nullptr, *nextSize = nullptr; + ValuePtr stride = nullptr, nextSize = nullptr; // Drop the dynamic stride from the operand list, if present. - ArrayRef sizeOperands(sizeAndOffsetOperands); + ArrayRef sizeOperands(sizeAndOffsetOperands); if (hasDynamicOffset) sizeOperands = sizeOperands.drop_front(); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. - Value *size = + ValuePtr size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. @@ -2058,7 +2061,7 @@ static void ensureDistinctSuccessors(Block &bb) { auto *dummyBlock = new Block(); bb.getParent()->push_back(dummyBlock); auto builder = OpBuilder(dummyBlock); - SmallVector operands( + SmallVector operands( terminator->getSuccessorOperands(*position)); builder.create(terminator->getLoc(), successor.first, operands); terminator->setSuccessor(dummyBlock, *position); @@ -2179,33 +2182,33 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } -Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, - Value *operand, - OpBuilder &builder) { +ValuePtr LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, + ValuePtr operand, + OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand->getType().cast().getPointerTo(); - Value *one = builder.create(loc, int64Ty, - IntegerAttr::get(indexType, 1)); - Value *allocated = + ValuePtr one = builder.create( + loc, int64Ty, IntegerAttr::get(indexType, 1)); + ValuePtr allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. 
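// The ViewOp lowering above rebuilds the strides of a contiguous row-major
// memref from its sizes, walking dimensions from innermost to outermost with a
// running product (the innermost stride must be 1). A minimal sketch of that
// recurrence for the static, contiguous case; the name contiguousStrides is
// illustrative only.
#include <cstdint>
#include <vector>

std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running; // innermost dimension gets stride 1
    running *= sizes[i];  // next outer stride is the running product of sizes
  }
  return strides;
}
// e.g. sizes {2, 3, 4} -> strides {12, 4, 1}.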
builder.create(loc, operand, allocated); return allocated; } -SmallVector +SmallVector LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { - SmallVector promotedOperands; + SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { - auto *operand = std::get<0>(it); - auto *llvmOperand = std::get<1>(it); + auto operand = std::get<0>(it); + auto llvmOperand = std::get<1>(it); if (!operand->getType().isa() && !operand->getType().isa()) { promotedOperands.push_back(operand); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index a14271efbb6..f7b0c9cb9bc 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -44,7 +44,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, + matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -54,7 +54,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -70,7 +70,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = this->typeConverter.convertType(operation.getResult()->getType()); @@ -89,7 +89,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(LoadOp loadOp, ArrayRef operands, + matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -100,7 +100,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -110,7 +110,7 @@ class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -123,7 +123,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(StoreOp storeOp, ArrayRef operands, + matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -141,7 +141,8 @@ public: spirv::AccessChainOp getElementPtr(OpBuilder &builder, SPIRVTypeConverter &typeConverter, Location loc, MemRefType origBaseType, - Value *basePtr, ArrayRef indices) { + ValuePtr basePtr, + ArrayRef indices) { // Get base and offset of the MemRefType and verify they are static. 
int64_t offset; SmallVector strides; @@ -152,18 +153,18 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, auto indexType = typeConverter.getIndexType(builder.getContext()); - Value *ptrLoc = nullptr; + ValuePtr ptrLoc = nullptr; assert(indices.size() == strides.size()); for (auto index : enumerate(indices)) { - Value *strideVal = builder.create( + ValuePtr strideVal = builder.create( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); - Value *update = + ValuePtr update = builder.create(loc, strideVal, index.value()); ptrLoc = (ptrLoc ? builder.create(loc, ptrLoc, update).getResult() : update); } - SmallVector linearizedIndices; + SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. linearizedIndices.push_back(builder.create( loc, indexType, IntegerAttr::get(indexType, 0))); @@ -176,7 +177,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, //===----------------------------------------------------------------------===// PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( - ConstantOp constIndexOp, ArrayRef operands, + ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!constIndexOp.getResult()->getType().isa()) { return matchFailure(); @@ -210,7 +211,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( //===----------------------------------------------------------------------===// PatternMatchResult -CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); @@ -242,7 +243,7 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, //===----------------------------------------------------------------------===// PatternMatchResult -LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, +LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), @@ -260,7 +261,7 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, PatternMatchResult ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, - ArrayRef operands, + ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { return matchFailure(); @@ -274,7 +275,7 @@ ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, //===----------------------------------------------------------------------===// PatternMatchResult -SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, +SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), @@ -288,7 +289,7 @@ SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, //===----------------------------------------------------------------------===// PatternMatchResult -StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, +StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); auto storePtr = diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 
c0c56a3b0b2..113789abe8a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -37,7 +37,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -49,7 +49,7 @@ class ConvertStandardToSPIRVPass } // namespace PatternMatchResult -FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); if (fnType.getNumResults()) { diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 4469c2802a8..2e1a7f09ff8 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -69,7 +69,7 @@ public: static LogicalResult resolveSourceIndices(Location loc, PatternRewriter &rewriter, SubViewOp subViewOp, ValueRange indices, - SmallVectorImpl &sourceIndices) { + SmallVectorImpl &sourceIndices) { // TODO: Aborting when the offsets are static. There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. @@ -77,7 +77,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter, return failure(); ValueRange opOffsets = subViewOp.offsets(); - SmallVector opStrides; + SmallVector opStrides; if (subViewOp.getNumStrides()) { // If the strides are dynamic, get the stride operands. opStrides = llvm::to_vector<2>(subViewOp.strides()); @@ -124,7 +124,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, if (!subViewOp) { return matchFailure(); } - SmallVector sourceIndices; + SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) return matchFailure(); @@ -146,7 +146,7 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, if (!subViewOp) { return matchFailure(); } - SmallVector sourceIndices; + SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) return matchFailure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9ec8ec6f88d..5099cb01bbc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -62,9 +62,10 @@ static VectorType reducedVectorTypeBack(VectorType tp) { } // Helper that picks the proper sequence for inserting. -static Value *insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value *val1, - Value *val2, Type llvmType, int64_t rank, int64_t pos) { +static ValuePtr insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, + ValuePtr val1, ValuePtr val2, Type llvmType, + int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( @@ -78,9 +79,10 @@ static Value *insertOne(ConversionPatternRewriter &rewriter, } // Helper that picks the proper sequence for extracting. 
-static Value *extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value *val, - Type llvmType, int64_t rank, int64_t pos) { +static ValuePtr extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, + ValuePtr val, Type llvmType, int64_t rank, + int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( @@ -101,7 +103,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast(op); VectorType dstVectorType = broadcastOp.getVectorType(); @@ -129,9 +131,9 @@ private: // ops once all insert/extract/shuffle operations // are available with lowering implemention. // - Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, - ConversionPatternRewriter &rewriter) const { + ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { assert((dstVectorType != nullptr) && "invalid result type in broadcast"); // Determine rank of source and destination. int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; @@ -168,23 +170,24 @@ private: // becomes: // x = [s,s] // v = [x,x,x,x] - Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + ValuePtr duplicateOneRank(ValuePtr value, Location loc, + VectorType srcVectorType, VectorType dstVectorType, + int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { - Value *undef = rewriter.create(loc, llvmType); - Value *expand = + ValuePtr undef = rewriter.create(loc, llvmType); + ValuePtr expand = insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - Value *expand = + ValuePtr expand = expandRanks(value, loc, srcVectorType, reducedVectorTypeFront(dstVectorType), rewriter); - Value *result = rewriter.create(loc, llvmType); + ValuePtr result = rewriter.create(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); @@ -209,19 +212,20 @@ private: // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> // a = [x, y] // etc. 
- Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + ValuePtr stretchOneRank(ValuePtr value, Location loc, + VectorType srcVectorType, VectorType dstVectorType, + int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); - Value *result = rewriter.create(loc, llvmType); + ValuePtr result = rewriter.create(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - Value *one = + ValuePtr one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); - Value *expand = + ValuePtr expand = insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( @@ -232,9 +236,9 @@ private: Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; - Value *one = + ValuePtr one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); - Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } @@ -250,7 +254,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ShuffleOpOperandAdaptor(operands); @@ -274,23 +278,23 @@ public: // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { - Value *shuffle = rewriter.create( + ValuePtr shuffle = rewriter.create( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); return matchSuccess(); } // For all other cases, insert the individual values individually. - Value *insert = rewriter.create(loc, llvmType); + ValuePtr insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast().getInt(); - Value *value = adaptor.v1(); + ValuePtr value = adaptor.v1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.v2(); } - Value *extract = + ValuePtr extract = extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, rank, insPos++); @@ -308,7 +312,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); auto extractEltOp = cast(op); @@ -333,7 +337,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); @@ -349,7 +353,7 @@ public: // One-shot extraction of vector from array (only requires extractvalue). 
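// The BroadcastOp lowering above expands a scalar (or lower-rank vector) one
// rank at a time: build the innermost 1-D vector, then replicate it across
// each outer dimension (s -> [s, s] -> [[s, s], [s, s], [s, s], [s, s]]).
// A minimal recursive sketch of that shape expansion on plain data; the Nested
// type and broadcastToShape name are illustrative stand-ins.
#include <cstddef>
#include <cstdint>
#include <vector>

struct Nested {
  float scalar = 0.0f;          // valid when children is empty (rank 0)
  std::vector<Nested> children; // one entry per element of this rank
};

Nested broadcastToShape(float value, const std::vector<int64_t> &shape,
                        std::size_t dim = 0) {
  Nested result;
  if (dim == shape.size()) { // rank 0: just the scalar
    result.scalar = value;
    return result;
  }
  // Build one element of the next-inner rank, then duplicate it across dim.
  Nested inner = broadcastToShape(value, shape, dim + 1);
  result.children.assign(static_cast<std::size_t>(shape[dim]), inner);
  return result;
}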
if (resultType.isa()) { - Value *extracted = rewriter.create( + ValuePtr extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -357,7 +361,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - Value *extracted = adaptor.vector(); + ValuePtr extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -388,7 +392,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpOperandAdaptor(operands); auto insertEltOp = cast(op); @@ -413,7 +417,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpOperandAdaptor(operands); @@ -429,7 +433,7 @@ public: // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { - Value *inserted = rewriter.create( + ValuePtr inserted = rewriter.create( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); @@ -438,7 +442,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - Value *extracted = adaptor.dest(); + ValuePtr extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); auto oneDVectorType = destVectorType; @@ -454,7 +458,7 @@ public: // Insertion of an element into a 1-D LLVM vector. auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); auto constant = rewriter.create(loc, i64Type, position); - Value *inserted = rewriter.create( + ValuePtr inserted = rewriter.create( loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), constant); @@ -480,7 +484,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); @@ -491,10 +495,10 @@ public: auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast(op).getResult()->getType()); - Value *desc = rewriter.create(loc, llvmArrayOfVectType); - Value *a = adaptor.lhs(), *b = adaptor.rhs(); - Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); - SmallVector lhs, accs; + ValuePtr desc = rewriter.create(loc, llvmArrayOfVectType); + ValuePtr a = adaptor.lhs(), b = adaptor.rhs(); + ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector lhs, accs; lhs.reserve(rankLHS); accs.reserve(rankLHS); for (unsigned d = 0, e = rankLHS; d < e; ++d) { @@ -502,7 +506,7 @@ public: auto attr = rewriter.getI32IntegerAttr(d); SmallVector bcastAttr(rankRHS, attr); auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); - Value *aD = nullptr, *accD = nullptr; + ValuePtr aD = nullptr, accD = nullptr; // 1. Broadcast the element a[d] into vector aD. aD = rewriter.create(loc, a, a, bcastArrayAttr); // 2. 
If acc is present, extract 1-d vector acc[d] into accD. @@ -510,7 +514,7 @@ public: accD = rewriter.create( loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). - Value *aOuterbD = + ValuePtr aOuterbD = accD ? rewriter.create(loc, vRHS, aD, b, accD) .getResult() : rewriter.create(loc, aD, b).getResult(); @@ -532,7 +536,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); vector::TypeCastOp castOp = cast(op); @@ -581,12 +585,12 @@ public: auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create(loc, llvmTargetElementTy, allocated); desc.setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. - Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); + ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create(loc, llvmTargetElementTy, ptr); desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. @@ -632,7 +636,7 @@ public: // TODO(ajcbik): rely solely on libc in future? something else? // PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast(op); auto adaptor = vector::PrintOpOperandAdaptor(operands); @@ -662,7 +666,7 @@ public: private: void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, - Value *value, VectorType vectorType, Operation *printer, + ValuePtr value, VectorType vectorType, Operation *printer, int64_t rank) const { Location loc = op->getLoc(); if (rank == 0) { @@ -678,7 +682,7 @@ private: rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; auto llvmType = lowering.convertType( rank > 1 ? reducedType : vectorType.getElementType()); - Value *nestedVal = + ValuePtr nestedVal = extractOne(rewriter, lowering, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); if (d != dim - 1) diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index ef4060d4302..3a21de389c7 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -115,8 +115,8 @@ static bool isFunctionRegion(Region *region) { /// A utility function to check if a value is defined at the top level of a /// function. A value of index type defined at the top level is always a valid /// symbol. -bool mlir::isTopLevelValue(Value *value) { - if (auto *arg = dyn_cast(value)) +bool mlir::isTopLevelValue(ValuePtr value) { + if (auto arg = dyn_cast(value)) return isFunctionRegion(arg->getOwner()->getParent()); return isFunctionRegion(value->getDefiningOp()->getParentRegion()); } @@ -124,7 +124,7 @@ bool mlir::isTopLevelValue(Value *value) { // Value can be used as a dimension id if it is valid as a symbol, or // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. -bool mlir::isValidDim(Value *value) { +bool mlir::isValidDim(ValuePtr value) { // The value must be an index type. 
if (!value->getType().isIndex()) return false; @@ -184,7 +184,7 @@ static bool isDimOpValidSymbol(DimOp dimOp) { // the top level, or it is a result of affine apply operation with symbol // arguments, or a result of the dim op on a memref satisfying certain // constraints. -bool mlir::isValidSymbol(Value *value) { +bool mlir::isValidSymbol(ValuePtr value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -207,7 +207,7 @@ bool mlir::isValidSymbol(Value *value) { // Returns true if 'value' is a valid index to an affine operation (e.g. // affine.load, affine.store, affine.dma_start, affine.dma_wait). // Returns false otherwise. -static bool isValidAffineIndexOperand(Value *value) { +static bool isValidAffineIndexOperand(ValuePtr value) { return isValidDim(value) || isValidSymbol(value); } @@ -221,7 +221,7 @@ static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims) { unsigned opIt = 0; - for (auto *operand : operands) { + for (auto operand : operands) { if (opIt++ < numDims) { if (!isValidDim(operand)) return op.emitOpError("operand cannot be used as a dimension id"); @@ -306,14 +306,14 @@ LogicalResult AffineApplyOp::verify() { // its operands are valid dimension ids. bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), - [](Value *op) { return mlir::isValidDim(op); }); + [](ValuePtr op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. bool AffineApplyOp::isValidSymbol() { return llvm::all_of(getOperands(), - [](Value *op) { return mlir::isValidSymbol(op); }); + [](ValuePtr op) { return mlir::isValidSymbol(op); }); } OpFoldResult AffineApplyOp::fold(ArrayRef operands) { @@ -333,8 +333,8 @@ OpFoldResult AffineApplyOp::fold(ArrayRef operands) { return result[0]; } -AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { - DenseMap::iterator iterPos; +AffineDimExpr AffineApplyNormalizer::renumberOneDim(ValuePtr v) { + DenseMap::iterator iterPos; bool inserted = false; std::tie(iterPos, inserted) = dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); @@ -347,7 +347,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { SmallVector dimRemapping; - for (auto *v : other.reorderedDims) { + for (auto v : other.reorderedDims) { auto kvp = other.dimValueToPosition.find(v); if (dimRemapping.size() <= kvp->second) dimRemapping.resize(kvp->second + 1); @@ -371,7 +371,7 @@ AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { // Gather the positions of the operands that are produced by an AffineApplyOp. static llvm::SetVector -indicesFromAffineApplyOp(ArrayRef operands) { +indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; for (auto en : llvm::enumerate(operands)) if (isa_and_nonnull(en.value()->getDefiningOp())) @@ -393,13 +393,13 @@ indicesFromAffineApplyOp(ArrayRef operands) { // results in better simplifications and foldings. But we should evaluate // whether this behavior is what we really want after using more. static AffineMap promoteComposedSymbolsAsDims(AffineMap map, - ArrayRef symbols) { + ArrayRef symbols) { if (symbols.empty()) { return map; } // Sanity check on symbols. 
- for (auto *sym : symbols) { + for (auto sym : symbols) { assert(isValidSymbol(sym) && "Expected only valid symbols"); (void)sym; } @@ -446,7 +446,7 @@ static AffineMap promoteComposedSymbolsAsDims(AffineMap map, /// `(d0)[s0, s1] -> (d0 + s0 + s1)`. /// /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when -/// applied to the same mlir::Value* for both s0 and s1. +/// applied to the same mlir::Value for both s0 and s1. /// As a consequence mathematical composition of AffineMap always concatenates /// symbols. /// @@ -462,7 +462,7 @@ static AffineMap promoteComposedSymbolsAsDims(AffineMap map, /// benefit potentially big: simpler and more maintainable code for a /// non-trivial, recursive, procedure. AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, - ArrayRef operands) + ArrayRef operands) : AffineApplyNormalizer() { static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); assert(map.getNumInputs() == operands.size() && @@ -495,7 +495,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, if (!furtherCompose) { // 1. Only dispatch dims or symbols. for (auto en : llvm::enumerate(operands)) { - auto *t = en.value(); + auto t = en.value(); assert(t->getType().isIndex()); bool isDim = (en.index() < map.getNumDims()); if (isDim) { @@ -511,14 +511,14 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, assert(numDimsBeforeRewrite <= operands.size()); // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { - auto *t = operands[i]; + auto t = operands[i]; auto affineApply = dyn_cast_or_null(t->getDefiningOp()); if (affineApply) { // a. Compose affine.apply operations. LLVM_DEBUG(affineApply.getOperation()->print( dbgs() << "\nCompose AffineApplyOp recursively: ")); AffineMap affineApplyMap = affineApply.getAffineMap(); - SmallVector affineApplyOperands( + SmallVector affineApplyOperands( affineApply.getOperands().begin(), affineApply.getOperands().end()); AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); @@ -569,8 +569,8 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, LLVM_DEBUG(dbgs() << "\n"); } -void AffineApplyNormalizer::normalize(AffineMap *otherMap, - SmallVectorImpl *otherOperands) { +void AffineApplyNormalizer::normalize( + AffineMap *otherMap, SmallVectorImpl *otherOperands) { AffineApplyNormalizer other(*otherMap, *otherOperands); *otherMap = renumber(other); @@ -584,7 +584,7 @@ void AffineApplyNormalizer::normalize(AffineMap *otherMap, /// on `map` and `operands` without creating an AffineApplyOp that needs to be /// immediately deleted. 
static void composeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { AffineApplyNormalizer normalizer(*map, *operands); auto normalizedMap = normalizer.getAffineMap(); auto normalizedOperands = normalizer.getOperands(); @@ -595,8 +595,8 @@ static void composeAffineMapAndOperands(AffineMap *map, } void mlir::fullyComposeAffineMapAndOperands( - AffineMap *map, SmallVectorImpl *operands) { - while (llvm::any_of(*operands, [](Value *v) { + AffineMap *map, SmallVectorImpl *operands) { + while (llvm::any_of(*operands, [](ValuePtr v) { return isa_and_nonnull(v->getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); @@ -605,9 +605,9 @@ void mlir::fullyComposeAffineMapAndOperands( AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands) { + ArrayRef operands) { AffineMap normalizedMap = map; - SmallVector normalizedOperands(operands.begin(), operands.end()); + SmallVector normalizedOperands(operands.begin(), operands.end()); composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); assert(normalizedMap); return b.create(loc, normalizedMap, normalizedOperands); @@ -617,7 +617,7 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, // canonicalizes dims that are valid symbols into actual symbols. template static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { if (!mapOrSet || operands->empty()) return; @@ -625,9 +625,9 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, "map/set inputs must match number of operands"); auto *context = mapOrSet->getContext(); - SmallVector resultOperands; + SmallVector resultOperands; resultOperands.reserve(operands->size()); - SmallVector remappedSymbols; + SmallVector remappedSymbols; remappedSymbols.reserve(operands->size()); unsigned nextDim = 0; unsigned nextSym = 0; @@ -661,7 +661,7 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, template static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { static_assert(std::is_same::value || std::is_same::value, "Argument must be either of AffineMap or IntegerSet type"); @@ -686,10 +686,10 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, auto *context = mapOrSet->getContext(); - SmallVector resultOperands; + SmallVector resultOperands; resultOperands.reserve(operands->size()); - llvm::SmallDenseMap seenDims; + llvm::SmallDenseMap seenDims; SmallVector dimRemapping(mapOrSet->getNumDims()); unsigned nextDim = 0; for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { @@ -705,7 +705,7 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, } } } - llvm::SmallDenseMap seenSymbols; + llvm::SmallDenseMap seenSymbols; SmallVector symRemapping(mapOrSet->getNumSymbols()); unsigned nextSym = 0; for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { @@ -738,12 +738,12 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, } void mlir::canonicalizeMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(map, operands); } void mlir::canonicalizeSetAndOperands(IntegerSet *set, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(set, operands); } @@ -758,7 +758,7 @@ struct SimplifyAffineOp : public OpRewritePattern { /// Replace the affine op with another instance of it with the supplied /// map and 
mapOperands. void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, - AffineMap map, ArrayRef mapOperands) const; + AffineMap map, ArrayRef mapOperands) const; PatternMatchResult matchAndRewrite(AffineOpTy affineOp, PatternRewriter &rewriter) const override { @@ -770,7 +770,7 @@ struct SimplifyAffineOp : public OpRewritePattern { auto map = affineOp.getAffineMap(); AffineMap oldMap = map; auto oldOperands = affineOp.getMapOperands(); - SmallVector resultOperands(oldOperands); + SmallVector resultOperands(oldOperands); composeAffineMapAndOperands(&map, &resultOperands); if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), resultOperands.begin())) @@ -786,14 +786,14 @@ struct SimplifyAffineOp : public OpRewritePattern { template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(load, load.getMemRef(), map, mapOperands); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint().getZExtValue(), prefetch.isWrite(), @@ -802,14 +802,14 @@ void SimplifyAffineOp::replaceAffineOp( template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( store, store.getValueToStore(), store.getMemRef(), map, mapOperands); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(apply, map, mapOperands); } } // end anonymous namespace. @@ -844,12 +844,12 @@ static LogicalResult foldMemRefCast(Operation *op) { // TODO(b/133776335) Check that map operands are loop IVs or symbols. 
void AffineDmaStartOp::build(Builder *builder, OperationState &result, - Value *srcMemRef, AffineMap srcMap, - ValueRange srcIndices, Value *destMemRef, + ValuePtr srcMemRef, AffineMap srcMap, + ValueRange srcIndices, ValuePtr destMemRef, AffineMap dstMap, ValueRange destIndices, - Value *tagMemRef, AffineMap tagMap, - ValueRange tagIndices, Value *numElements, - Value *stride, Value *elementsPerStride) { + ValuePtr tagMemRef, AffineMap tagMap, + ValueRange tagIndices, ValuePtr numElements, + ValuePtr stride, ValuePtr elementsPerStride) { result.addOperands(srcMemRef); result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); result.addOperands(srcIndices); @@ -980,19 +980,19 @@ LogicalResult AffineDmaStartOp::verify() { return emitOpError("incorrect number of operands"); } - for (auto *idx : getSrcIndices()) { + for (auto idx : getSrcIndices()) { if (!idx->getType().isIndex()) return emitOpError("src index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("src index must be a dimension or symbol identifier"); } - for (auto *idx : getDstIndices()) { + for (auto idx : getDstIndices()) { if (!idx->getType().isIndex()) return emitOpError("dst index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("dst index must be a dimension or symbol identifier"); } - for (auto *idx : getTagIndices()) { + for (auto idx : getTagIndices()) { if (!idx->getType().isIndex()) return emitOpError("tag index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -1013,8 +1013,8 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef cstOperands, // TODO(b/133776335) Check that map operands are loop IVs or symbols. void AffineDmaWaitOp::build(Builder *builder, OperationState &result, - Value *tagMemRef, AffineMap tagMap, - ValueRange tagIndices, Value *numElements) { + ValuePtr tagMemRef, AffineMap tagMap, + ValueRange tagIndices, ValuePtr numElements) { result.addOperands(tagMemRef); result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); @@ -1023,7 +1023,7 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result, void AffineDmaWaitOp::print(OpAsmPrinter &p) { p << "affine.dma_wait " << *getTagMemRef() << '['; - SmallVector operands(getTagIndices()); + SmallVector operands(getTagIndices()); p.printAffineMapOfSSAIds(getTagMapAttr(), operands); p << "], "; p.printOperand(getNumElements()); @@ -1068,7 +1068,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, LogicalResult AffineDmaWaitOp::verify() { if (!getOperand(0)->getType().isa()) return emitOpError("expected DMA tag to be of memref type"); - for (auto *idx : getTagIndices()) { + for (auto idx : getTagIndices()) { if (!idx->getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -1368,7 +1368,7 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) { SmallVector operandConstants; auto boundOperands = lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); - for (auto *operand : boundOperands) { + for (auto operand : boundOperands) { Attribute operandCst; matchPattern(operand, m_Constant(&operandCst)); operandConstants.push_back(operandCst); @@ -1408,8 +1408,8 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) { /// Canonicalize the bounds of the given loop. 
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { - SmallVector lbOperands(forOp.getLowerBoundOperands()); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); auto lbMap = forOp.getLowerBoundMap(); auto ubMap = forOp.getUpperBoundMap(); @@ -1474,7 +1474,7 @@ void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector newOperands(lbOperands.begin(), lbOperands.end()); + SmallVector newOperands(lbOperands.begin(), lbOperands.end()); auto ubOperands = getUpperBoundOperands(); newOperands.append(ubOperands.begin(), ubOperands.end()); @@ -1487,7 +1487,7 @@ void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector newOperands(getLowerBoundOperands()); + SmallVector newOperands(getLowerBoundOperands()); newOperands.append(ubOperands.begin(), ubOperands.end()); getOperation()->setOperands(newOperands); @@ -1553,7 +1553,7 @@ bool AffineForOp::matchingBoundOperandList() { unsigned numOperands = lbMap.getNumInputs(); for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. + // Compare ValuePtr 's. if (getOperand(i) != getOperand(numOperands + i)) return false; } @@ -1562,7 +1562,7 @@ bool AffineForOp::matchingBoundOperandList() { Region &AffineForOp::getLoopBody() { return region(); } -bool AffineForOp::isDefinedOutsideOfLoop(Value *value) { +bool AffineForOp::isDefinedOutsideOfLoop(ValuePtr value) { return !region().isAncestor(value->getParentRegion()); } @@ -1573,14 +1573,14 @@ LogicalResult AffineForOp::moveOutOfLoop(ArrayRef ops) { } /// Returns if the provided value is the induction variable of a AffineForOp. -bool mlir::isForInductionVar(Value *val) { +bool mlir::isForInductionVar(ValuePtr val) { return getForInductionVarOwner(val) != AffineForOp(); } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -AffineForOp mlir::getForInductionVarOwner(Value *val) { - auto *ivArg = dyn_cast(val); +AffineForOp mlir::getForInductionVarOwner(ValuePtr val) { + auto ivArg = dyn_cast(val); if (!ivArg || !ivArg->getOwner()) return AffineForOp(); auto *containingInst = ivArg->getOwner()->getParent()->getParentOp(); @@ -1590,7 +1590,7 @@ AffineForOp mlir::getForInductionVarOwner(Value *val) { /// Extracts the induction variables from a list of AffineForOps and returns /// them. 
void mlir::extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs) { + SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) ivs->push_back(forInst.getInductionVar()); @@ -1729,7 +1729,7 @@ void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set, LogicalResult AffineIfOp::fold(ArrayRef, SmallVectorImpl &) { auto set = getIntegerSet(); - SmallVector operands(getOperands()); + SmallVector operands(getOperands()); canonicalizeSetAndOperands(&set, &operands); // Any canonicalization change always leads to either a reduction in the @@ -1759,7 +1759,8 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, } void AffineLoadOp::build(Builder *builder, OperationState &result, - Value *memref, AffineMap map, ValueRange mapOperands) { + ValuePtr memref, AffineMap map, + ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); @@ -1769,7 +1770,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, } void AffineLoadOp::build(Builder *builder, OperationState &result, - Value *memref, ValueRange indices) { + ValuePtr memref, ValueRange indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () @@ -1825,7 +1826,7 @@ LogicalResult AffineLoadOp::verify() { "expects the number of subscripts to be equal to memref rank"); } - for (auto *idx : getMapOperands()) { + for (auto idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -1851,7 +1852,7 @@ OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { //===----------------------------------------------------------------------===// void AffineStoreOp::build(Builder *builder, OperationState &result, - Value *valueToStore, Value *memref, AffineMap map, + ValuePtr valueToStore, ValuePtr memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(valueToStore); @@ -1862,7 +1863,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result, // Use identity map. void AffineStoreOp::build(Builder *builder, OperationState &result, - Value *valueToStore, Value *memref, + ValuePtr valueToStore, ValuePtr memref, ValueRange indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); @@ -1923,7 +1924,7 @@ LogicalResult AffineStoreOp::verify() { "expects the number of subscripts to be equal to memref rank"); } - for (auto *idx : getMapOperands()) { + for (auto idx : getMapOperands()) { if (!idx->getType().isIndex()) return emitOpError("index to store must have 'index' type"); if (!isValidAffineIndexOperand(idx)) @@ -2072,7 +2073,7 @@ void print(OpAsmPrinter &p, AffinePrefetchOp op) { p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '['; AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName()); if (mapAttr) { - SmallVector operands(op.getMapOperands()); + SmallVector operands(op.getMapOperands()); p.printAffineMapOfSSAIds(mapAttr, operands); } p << ']' << ", " << (op.isWrite() ? 
"write" : "read") << ", " @@ -2099,7 +2100,7 @@ LogicalResult verify(AffinePrefetchOp op) { return op.emitOpError("too few operands"); } - for (auto *idx : op.getMapOperands()) { + for (auto idx : op.getMapOperands()) { if (!isValidAffineIndexOperand(idx)) return op.emitOpError("index must be a dimension or symbol identifier"); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 3982a6a4713..e1951ff900b 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -46,9 +46,9 @@ struct LowerUniformCastsPass : public FunctionPass { // Dequantize //===----------------------------------------------------------------------===// -static Value *emitUniformPerLayerDequantize(Location loc, Value *input, - UniformQuantizedType elementType, - PatternRewriter &rewriter) { +static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, + UniformQuantizedType elementType, + PatternRewriter &rewriter) { // Pre-conditions. if (!elementType.isSigned()) { // TODO: Support unsigned storage type. @@ -71,7 +71,7 @@ static Value *emitUniformPerLayerDequantize(Location loc, Value *input, // Apply zero-point offset. if (elementType.getZeroPoint() != 0) { - Value *negZeroPointConst = rewriter.create( + ValuePtr negZeroPointConst = rewriter.create( loc, broadcastScalarConstIntValue(intermediateType, -elementType.getZeroPoint())); input = rewriter.create(loc, input, negZeroPointConst); @@ -81,14 +81,14 @@ static Value *emitUniformPerLayerDequantize(Location loc, Value *input, input = rewriter.create(loc, realType, input); // Mul by scale. - Value *scaleConst = rewriter.create( + ValuePtr scaleConst = rewriter.create( loc, broadcastScalarConstFloatValue(realType, APFloat(elementType.getScale()))); return rewriter.create(loc, input, scaleConst); } -static Value * -emitUniformPerAxisDequantize(Location loc, Value *input, +static ValuePtr +emitUniformPerAxisDequantize(Location loc, ValuePtr input, UniformQuantizedPerAxisType elementType, PatternRewriter &rewriter) { // TODO: Support per-axis dequantize. @@ -97,8 +97,8 @@ emitUniformPerAxisDequantize(Location loc, Value *input, return nullptr; } -static Value *emitDequantize(Location loc, Value *input, - PatternRewriter &rewriter) { +static ValuePtr emitDequantize(Location loc, ValuePtr input, + PatternRewriter &rewriter) { Type inputType = input->getType(); QuantizedType qElementType = QuantizedType::getQuantizedElementType(inputType); @@ -133,7 +133,7 @@ struct UniformDequantizePattern : public OpRewritePattern { return matchFailure(); } - Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); + ValuePtr dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -170,14 +170,14 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. 
- Value *lhsValue = rewriter - .create(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - Value *rhsValue = rewriter - .create(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + ValuePtr lhsValue = rewriter + .create(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + ValuePtr rhsValue = rewriter + .create(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, @@ -186,7 +186,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, rhsValue); // Add. - Value *resultValue = + ValuePtr resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Zero point offset adjustment. @@ -194,7 +194,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, // zpOffset = -zp int zpOffset = -1 * info.resultType.getZeroPoint(); if (zpOffset != 0) { - Value *zpOffsetConst = rewriter.create( + ValuePtr zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, zpOffset)); resultValue = @@ -246,14 +246,14 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. - Value *lhsValue = rewriter - .create(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - Value *rhsValue = rewriter - .create(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + ValuePtr lhsValue = rewriter + .create(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + ValuePtr rhsValue = rewriter + .create(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, @@ -263,7 +263,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Apply argument zeroPoints. if (info.lhsType.getZeroPoint() != 0) { - Value *zpOffsetConst = rewriter.create( + ValuePtr zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.lhsType.getZeroPoint())); lhsValue = @@ -271,7 +271,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } if (info.rhsType.getZeroPoint() != 0) { - Value *zpOffsetConst = rewriter.create( + ValuePtr zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.rhsType.getZeroPoint())); rhsValue = @@ -279,7 +279,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } // Mul. - Value *resultValue = + ValuePtr resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Scale output. @@ -293,7 +293,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Zero point offset adjustment. 
if (info.resultType.getZeroPoint() != 0) { - Value *zpOffsetConst = rewriter.create( + ValuePtr zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, info.resultType.getZeroPoint())); diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 955e2ecc88c..57a8422b362 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -59,7 +59,7 @@ template bool integralLog2(F x, int &log2Result) { /// Helper class for operating on binary operations where all operands /// and the result are a UniformQuantizedType. struct UniformBinaryOpInfo { - UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs, + UniformBinaryOpInfo(Operation *op, ValuePtr lhs, ValuePtr rhs, Optional clampMin, Optional clampMax) : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax), lhsType(getUniformElementType(lhs->getType())), @@ -128,8 +128,8 @@ struct UniformBinaryOpInfo { } Operation *op; - Value *lhs; - Value *rhs; + ValuePtr lhs; + ValuePtr rhs; Optional clampMin; Optional clampMax; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 9c0183eb90f..349c1fa4644 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -145,7 +145,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { if (!allReduce.body().empty()) { if (allReduce.body().front().getNumArguments() != 2) return allReduce.emitError("expected two region arguments"); - for (auto *argument : allReduce.body().front().getArguments()) { + for (auto argument : allReduce.body().front().getArguments()) { if (argument->getType() != allReduce.getType()) return allReduce.emitError("incorrect region argument type"); } @@ -213,15 +213,15 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { static SmallVector getValueTypes(ValueRange values) { SmallVector types; types.reserve(values.size()); - for (Value *v : values) + for (ValuePtr v : values) types.push_back(v->getType()); return types; } -void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX, - Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, - Value *blockSizeY, Value *blockSizeZ, - ValueRange operands) { +void LaunchOp::build(Builder *builder, OperationState &result, + ValuePtr gridSizeX, ValuePtr gridSizeY, ValuePtr gridSizeZ, + ValuePtr blockSizeX, ValuePtr blockSizeY, + ValuePtr blockSizeZ, ValueRange operands) { // Add grid and block sizes as op operands, followed by the data operands. result.addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -489,22 +489,22 @@ class PropagateConstantBounds : public OpRewritePattern { // and use it instead of passing the value from the parent region. Perform // the traversal in the inverse order to simplify index arithmetics when // dropping arguments. 
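The hunk above keeps the existing strategy of walking the kernel operands from the highest index down, so that erasing an argument never shifts the indices still to be visited. A minimal, self-contained sketch of that reverse-index idiom, using toy types rather than the GPU dialect's launch op:

#include <vector>

// Erase every negative entry by index. Visiting positions back to front means
// an erase at `index` only moves elements we have already looked at.
void eraseNegativesInReverse(std::vector<int> &values) {
  for (unsigned i = values.size(); i > 0; --i) {
    unsigned index = i - 1;
    if (values[index] < 0)
      values.erase(values.begin() + index);
  }
}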
- SmallVector operands(launchOp.getKernelOperandValues().begin(), - launchOp.getKernelOperandValues().end()); - SmallVector kernelArgs(launchOp.getKernelArguments().begin(), - launchOp.getKernelArguments().end()); + SmallVector operands(launchOp.getKernelOperandValues().begin(), + launchOp.getKernelOperandValues().end()); + SmallVector kernelArgs(launchOp.getKernelArguments().begin(), + launchOp.getKernelArguments().end()); bool found = false; for (unsigned i = operands.size(); i > 0; --i) { unsigned index = i - 1; - Value *operand = operands[index]; + ValuePtr operand = operands[index]; if (!isa_and_nonnull(operand->getDefiningOp())) { continue; } found = true; - Value *internalConstant = + ValuePtr internalConstant = rewriter.clone(*operand->getDefiningOp())->getResult(0); - Value *kernelArg = kernelArgs[index]; + ValuePtr kernelArg = kernelArgs[index]; kernelArg->replaceAllUsesWith(internalConstant); launchOp.eraseKernelArgument(index); } @@ -529,10 +529,10 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState &result, - GPUFuncOp kernelFunc, Value *gridSizeX, - Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, - Value *blockSizeY, Value *blockSizeZ, - ValueRange kernelOperands) { + GPUFuncOp kernelFunc, ValuePtr gridSizeX, + ValuePtr gridSizeY, ValuePtr gridSizeZ, + ValuePtr blockSizeX, ValuePtr blockSizeY, + ValuePtr blockSizeZ, ValueRange kernelOperands) { // Add grid and block sizes as op operands, followed by the data operands. result.addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -565,7 +565,7 @@ StringRef LaunchFuncOp::getKernelModuleName() { .getRootReference(); } -Value *LaunchFuncOp::getKernelOperand(unsigned i) { +ValuePtr LaunchFuncOp::getKernelOperand(unsigned i) { return getOperation()->getOperand(i + kNumConfigOperands); } @@ -728,13 +728,14 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { } static void printAttributions(OpAsmPrinter &p, StringRef keyword, - ArrayRef values) { + ArrayRef values) { if (values.empty()) return; p << ' ' << keyword << '('; - interleaveComma(values, p, - [&p](BlockArgument *v) { p << *v << " : " << v->getType(); }); + interleaveComma(values, p, [&p](BlockArgumentPtr v) { + p << *v << " : " << v->getType(); + }); p << ')'; } @@ -781,9 +782,9 @@ LogicalResult GPUFuncOp::verifyType() { } static LogicalResult verifyAttributions(Operation *op, - ArrayRef attributions, + ArrayRef attributions, unsigned memorySpace) { - for (Value *v : attributions) { + for (ValuePtr v : attributions) { auto type = v->getType().dyn_cast(); if (!type) return op->emitOpError() << "expected memref type in attribution"; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 0a6a5915633..8f5f50e4909 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -31,10 +31,10 @@ using namespace mlir; template static void createForAllDimensions(OpBuilder &builder, Location loc, - SmallVectorImpl &values) { + SmallVectorImpl &values) { for (StringRef dim : {"x", "y", "z"}) { - Value *v = builder.create(loc, builder.getIndexType(), - builder.getStringAttr(dim)); + ValuePtr v = builder.create(loc, builder.getIndexType(), + builder.getStringAttr(dim)); values.push_back(v); } } @@ -46,7 +46,7 @@ static 
void injectGpuIndexOperations(Location loc, Region &body) { OpBuilder builder(loc->getContext()); Block &firstBlock = body.front(); builder.setInsertionPointToStart(&firstBlock); - SmallVector indexOps; + SmallVector indexOps; createForAllDimensions(builder, loc, indexOps); createForAllDimensions(builder, loc, indexOps); createForAllDimensions(builder, loc, indexOps); @@ -69,7 +69,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, gpu::LaunchFuncOp launch) { OpBuilder kernelBuilder(kernelFunc.getBody()); auto &firstBlock = kernelFunc.getBody().front(); - SmallVector newLaunchArgs; + SmallVector newLaunchArgs; BlockAndValueMapping map; for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) { map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i)); @@ -82,7 +82,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, } // Only inline operations that do not create new arguments. if (!llvm::all_of(operandOp->getOperands(), - [map](Value *value) { return map.contains(value); })) { + [map](ValuePtr value) { return map.contains(value); })) { continue; } auto clone = kernelBuilder.clone(*operandOp, map); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 1813b30165f..b94ee335bd2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -415,7 +415,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { // Expects vector to be of wrapped LLVM vector type and position to be of // wrapped LLVM i32 type. void LLVM::ExtractElementOp::build(Builder *b, OperationState &result, - Value *vector, Value *position, + ValuePtr vector, ValuePtr position, ArrayRef attrs) { auto wrappedVectorType = vector->getType().cast(); auto llvmType = wrappedVectorType.getVectorElementType(); @@ -681,7 +681,7 @@ static void printBrOp(OpAsmPrinter &p, BrOp &op) { // attribute-dict? static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) { Block *dest; - SmallVector operands; + SmallVector operands; if (parser.parseSuccessorAndUseList(dest, operands) || parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -708,8 +708,8 @@ static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) { static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) { Block *trueDest; Block *falseDest; - SmallVector trueOperands; - SmallVector falseOperands; + SmallVector trueOperands; + SmallVector falseOperands; OpAsmParser::OperandType condition; Builder &builder = parser.getBuilder(); @@ -1066,8 +1066,8 @@ static LogicalResult verify(GlobalOp op) { //===----------------------------------------------------------------------===// // Expects vector to be of wrapped LLVM vector type and position to be of // wrapped LLVM i32 type. -void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value *v1, - Value *v2, ArrayAttr mask, +void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, + ValuePtr v1, ValuePtr v2, ArrayAttr mask, ArrayRef attrs) { auto wrappedContainerType1 = v1->getType().cast(); auto vType = LLVMType::getVectorTy( @@ -1664,10 +1664,10 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { // Utility functions. 
//===----------------------------------------------------------------------===// -Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect) { +ValuePtr mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect) { assert(builder.getInsertionBlock() && builder.getInsertionBlock()->getParentOp() && "expected builder to point to a block constrained in an op"); @@ -1684,13 +1684,13 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, builder.getStringAttr(value)); // Get the pointer to the first character in the global string. - Value *globalPtr = builder.create(loc, global); - Value *cst0 = builder.create( + ValuePtr globalPtr = builder.create(loc, global); + ValuePtr cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + ArrayRef({cst0, cst0})); } bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index d7e4d08527d..ee122e16037 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -49,7 +49,7 @@ static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) { llvm_unreachable("Unexpected DependenceType"); } -Value *Aliases::find(Value *v) { +ValuePtr Aliases::find(ValuePtr v) { if (isa(v)) return v; @@ -147,9 +147,9 @@ LinalgDependenceGraph::getDependencesInto( } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - for (auto *srcView : src.getOutputs()) { // W + for (auto srcView : src.getOutputs()) { // W // RAW graph - for (auto *dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputs()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW addDependenceElem(DependenceType::RAW, LinalgOpView{src.getOperation(), srcView}, @@ -157,7 +157,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAW graph - for (auto *dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputs()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAW addDependenceElem(DependenceType::WAW, LinalgOpView{src.getOperation(), srcView}, @@ -165,9 +165,9 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } } - for (auto *srcView : src.getInputs()) { // R + for (auto srcView : src.getInputs()) { // R // RAR graph - for (auto *dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputs()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAR addDependenceElem(DependenceType::RAR, LinalgOpView{src.getOperation(), srcView}, @@ -175,7 +175,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAR graph - for (auto *dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputs()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAR addDependenceElem(DependenceType::WAR, LinalgOpView{src.getOperation(), srcView}, @@ -194,14 +194,14 @@ LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, } SmallVector LinalgDependenceGraph::findCoveringWrites( - LinalgOp 
srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const { + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::WAW, DependenceType::WAR}); } SmallVector LinalgDependenceGraph::findCoveringReads( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const { + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::RAR, DependenceType::RAW}); @@ -209,7 +209,7 @@ SmallVector LinalgDependenceGraph::findCoveringReads( SmallVector LinalgDependenceGraph::findOperationsWithCoveringDependences( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view, + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view, ArrayRef types) const { auto *src = srcLinalgOp.getOperation(); auto *dst = dstLinalgOp.getOperation(); diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index ba96186da38..7b530d7f0df 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -44,8 +44,8 @@ static void getMaxDimIndex(ArrayRef structuredIndices, Operation *mlir::edsc::makeLinalgGenericOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - function_ref)> regionBuilder, - ArrayRef otherValues, ArrayRef otherAttributes) { + function_ref)> regionBuilder, + ArrayRef otherValues, ArrayRef otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); unsigned nInputs = inputs.size(); @@ -66,7 +66,7 @@ Operation *mlir::edsc::makeLinalgGenericOp( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs())); unsigned nViews = nInputs + nOutputs; - SmallVector values; + SmallVector values; values.reserve(nViews); values.append(inputs.begin(), inputs.end()); values.append(outputs.begin(), outputs.end()); @@ -109,7 +109,7 @@ Operation *mlir::edsc::makeLinalgGenericOp( return op; } -void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { +void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 3 && "expected 3 block arguments"); @@ -122,7 +122,7 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); - auto fun = [&unaryOp](ArrayRef args) { + auto fun = [&unaryOp](ArrayRef args) { assert(args.size() == 2 && "expected 2 block arguments"); ValueHandle a(args[0]); linalg_yield(unaryOp(a)); @@ -135,7 +135,7 @@ Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, ; using edsc::intrinsics::tanh; UnaryPointwiseOpBuilder unOp( - [](ValueHandle a) -> Value * { return tanh(a); }); + [](ValueHandle a) -> ValuePtr { return tanh(a); }); return linalg_pointwise(unOp, I, O); } @@ -146,7 +146,7 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); - auto fun = [&binaryOp](ArrayRef args) { + auto fun = [&binaryOp](ArrayRef args) { assert(args.size() == 3 && "expected 3 block arguments"); ValueHandle a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); @@ -159,14 +159,14 @@ Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed O) { using edsc::op::operator+; BinaryPointwiseOpBuilder binOp( - [](ValueHandle a, ValueHandle 
b) -> Value * { return a + b; }); + [](ValueHandle a, ValueHandle b) -> ValuePtr { return a + b; }); return linalg_pointwise(binOp, I1, I2, O); } Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O) { - BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * { + BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> ValuePtr { using edsc::intrinsics::select; using edsc::op::operator>; return select(a > b, a, b).getValue(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 6eca181e9b4..c5f30b7e10b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -318,7 +318,7 @@ static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) { // SliceOp //===----------------------------------------------------------------------===// void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, - Value *base, ValueRange indexings) { + ValuePtr base, ValueRange indexings) { result.addOperands(base); result.addOperands(indexings); @@ -394,7 +394,7 @@ static LogicalResult verify(SliceOp op) { // TransposeOp //===----------------------------------------------------------------------===// void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result, - Value *view, AffineMapAttr permutation, + ValuePtr view, AffineMapAttr permutation, ArrayRef attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 453daba204c..49cea7e4170 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -77,16 +77,16 @@ static llvm::cl::list clTileSizes( static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { auto maps = loopToOperandRangesMaps(op); - SmallVector clonedViews; + SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "map: " << map << "\n"); - Value *view = en.value(); + ValuePtr view = en.value(); SmallVector viewRanges(map.getNumResults()); for (auto en2 : llvm::enumerate(map.getResults())) { unsigned d = en2.index(); @@ -99,7 +99,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, } // Construct a new subview for the tile. unsigned rank = viewRanges.size(); - SmallVector offsets, sizes, strides; + SmallVector offsets, sizes, strides; offsets.reserve(rank); sizes.reserve(rank); strides.reserve(rank); @@ -117,7 +117,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, } struct ViewDimension { - Value *view; + ValuePtr view; unsigned dimension; }; @@ -130,14 +130,14 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { auto maps = loopToOperandRangesMaps(op); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. 
- SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); - Value *view = en.value(); - SmallVector viewRanges(map.getNumResults(), nullptr); + ValuePtr view = en.value(); + SmallVector viewRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { if (loopDepth == en2.value().cast().getPosition()) { LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth @@ -151,9 +151,9 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { llvm_unreachable("Expect to be able to extract a view defining loop range"); } -static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx, unsigned producerIdx, - OperationFolder *folder) { +static LinalgOp fuse(ValuePtr producedView, LinalgOp producer, + LinalgOp consumer, unsigned consumerIdx, + unsigned producerIdx, OperationFolder *folder) { auto subView = dyn_cast_or_null( consumer.getInput(consumerIdx)->getDefiningOp()); auto slice = dyn_cast_or_null( @@ -206,7 +206,7 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, - Value *consumedView, + ValuePtr consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); @@ -226,7 +226,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, LinalgOp consumer, - Value *consumedView, + ValuePtr consumedView, LinalgOp producer) { // Make some simple structural checks that alleviate the need for more // complex analyses. @@ -245,7 +245,7 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, } bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, - LinalgOp consumer, Value *consumedView, + LinalgOp consumer, ValuePtr consumedView, LinalgOp producer) { if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; @@ -272,13 +272,13 @@ Optional mlir::linalg::fuseProducerOf( auto producer = cast(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto *consumedView = dependence.indexingView; + auto consumedView = dependence.indexingView; if (consumer.getInput(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks // whether it is a strict subview of the producer view. - auto *producedView = dependence.dependentOpView.view; + auto producedView = dependence.dependentOpView.view; auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. 
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index c50c495750f..e468c19a0b4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -49,7 +49,7 @@ using edsc::op::operator==; static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals) { + ArrayRef vals) { assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector res; @@ -57,35 +57,35 @@ makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, auto dims = map.getNumDims(); for (auto e : map.getResults()) { auto exprMap = AffineMap::get(dims, 0, e); - SmallVector operands(vals.begin(), vals.end()); + SmallVector operands(vals.begin(), vals.end()); canonicalizeMapAndOperands(&exprMap, &operands); res.push_back(affine_apply(exprMap, operands)); } return res; } -static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation) { +static SmallVector permuteIvs(ArrayRef ivs, + Optional permutation) { return permutation ? applyMapToValues(ScopedContext::getBuilder(), ScopedContext::getLocation(), permutation.getValue(), ivs) - : SmallVector(ivs.begin(), ivs.end()); + : SmallVector(ivs.begin(), ivs.end()); } // Creates a number of ranges equal to the number of results in `map`. // The returned ranges correspond to the loop ranges, in the proper order, for // which new loops will be created. -static SmallVector emitLoopRanges(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef allViewSizes); -SmallVector emitLoopRanges(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef allViewSizes) { +static SmallVector emitLoopRanges(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef allViewSizes); +SmallVector emitLoopRanges(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef allViewSizes) { // Apply `map` to get view sizes in loop order. auto sizes = applyMapToValues(b, loc, map, allViewSizes); // Create a new range with the applied tile sizes. 
ScopedContext scope(b, loc); - SmallVector res; + SmallVector res; for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { res.push_back(range(constant_index(0), sizes[idx], constant_index(1))); } @@ -98,7 +98,7 @@ class LinalgScopedEmitter {}; template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); @@ -121,7 +121,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); @@ -138,7 +138,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { + static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), @@ -151,7 +151,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); @@ -165,7 +165,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, MatmulOp matmulOp) { assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); @@ -179,7 +179,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); @@ -229,14 +229,14 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, GenericOp genericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; unsigned nInputs = genericOp.getNumInputs(); unsigned nOutputs = genericOp.getNumOutputs(); - SmallVector indexedValues(nInputs + nOutputs); + SmallVector indexedValues(nInputs + nOutputs); // 1.a. Emit std_load from input views. 
for (unsigned i = 0; i < nInputs; ++i) { @@ -324,7 +324,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, IndexedGenericOp indexedGenericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); @@ -332,7 +332,7 @@ public: unsigned nInputs = indexedGenericOp.getNumInputs(); unsigned nOutputs = indexedGenericOp.getNumOutputs(); unsigned nLoops = allIvs.size(); - SmallVector indexedValues(nLoops + nInputs + nOutputs); + SmallVector indexedValues(nLoops + nInputs + nOutputs); for (unsigned i = 0; i < nLoops; ++i) { indexedValues[i] = allIvs[i]; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index f4364928af8..999406e05cf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -99,7 +99,7 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( } bool mlir::linalg::detail::isProducedByOpOfTypeImpl( - Operation *consumerOp, Value *consumedView, + Operation *consumerOp, ValuePtr consumedView, function_ref isaOpType) { LinalgOp consumer = dyn_cast(consumerOp); if (!consumer) @@ -175,7 +175,7 @@ LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, return failure(); // TODO(ntv): non-identity layout. - auto isStaticMemRefWithIdentityLayout = [](Value *v) { + auto isStaticMemRefWithIdentityLayout = [](ValuePtr v) { auto m = v->getType().dyn_cast(); if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) return false; @@ -235,7 +235,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op) { LinalgOp linOp = dyn_cast(op); - SetVector subViews; + SetVector subViews; for (auto it : linOp.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) subViews.insert(sv); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index c7fbebce383..b1dae455194 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -55,14 +55,15 @@ static llvm::cl::opt clPromoteDynamic( llvm::cl::desc("Test generation of dynamic promoted buffers"), llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); -static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) { +static ValuePtr allocBuffer(Type elementType, ValuePtr size, + bool dynamicBuffers) { auto *ctx = size->getContext(); auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size->getDefiningOp())) return alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); - Value *mul = muli(constant_index(width), size); + ValuePtr mul = muli(constant_index(width), size); return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); } @@ -92,20 +93,20 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, auto viewType = subView.getType(); auto rank = viewType.getRank(); - Value *allocSize = one; - SmallVector fullRanges, partialRanges; + ValuePtr allocSize = one; + SmallVector fullRanges, partialRanges; fullRanges.reserve(rank); partialRanges.reserve(rank); for (auto en : llvm::enumerate(subView.getRanges())) { auto rank = en.index(); auto rangeValue = en.value(); 
- Value *d = rangeValue.size; + ValuePtr d = rangeValue.size; allocSize = muli(folder, allocSize, d).getValue(); fullRanges.push_back(d); partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); } SmallVector dynSizes(fullRanges.size(), -1); - auto *buffer = + auto buffer = allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); auto fullLocalView = view( MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); @@ -115,7 +116,7 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, SmallVector mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, - ArrayRef subViews, bool dynamicBuffers, + ArrayRef subViews, bool dynamicBuffers, OperationFolder *folder) { if (subViews.empty()) return {}; @@ -123,8 +124,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, ScopedContext scope(b, loc); SmallVector res; res.reserve(subViews.size()); - DenseMap promotionInfoMap; - for (auto *v : subViews) { + DenseMap promotionInfoMap; + for (auto v : subViews) { SubViewOp subView = cast(v->getDefiningOp()); auto viewType = subView.getType(); // TODO(ntv): support more cases than just float. @@ -136,7 +137,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, res.push_back(promotionInfo); } - for (auto *v : subViews) { + for (auto v : subViews) { SubViewOp subView = cast(v->getDefiningOp()); auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) @@ -144,14 +145,14 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, // TODO(ntv): value to fill with should be related to the operation. // For now, just use APFloat(0.0f). auto t = subView.getType().getElementType().cast(); - Value *fillVal = constant_float(folder, APFloat(0.0f), t); + ValuePtr fillVal = constant_float(folder, APFloat(0.0f), t); // TODO(ntv): fill is only necessary if `promotionInfo` has a full local // view that is different from the partial local view and we are on the // boundary. fill(info->second.fullLocalView, fillVal); } - for (auto *v : subViews) { + for (auto v : subViews) { auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; @@ -161,19 +162,19 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, } LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, - SetVector subViews, + SetVector subViews, bool dynamicBuffers, OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); - SmallVector opViews; + SmallVector opViews; opViews.reserve(op.getNumInputsAndOutputs()); - SmallVector, 8> writebackViews; + SmallVector, 8> writebackViews; writebackViews.reserve(subViews.size()); unsigned promotedIdx = 0; - for (auto *view : op.getInputsAndOutputs()) { + for (auto view : op.getInputsAndOutputs()) { if (subViews.count(view) != 0) { opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); writebackViews.emplace_back(std::make_pair( @@ -214,7 +215,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. 
- SetVector subViews; + SetVector subViews; OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 4d8a24cb6cb..07d559918cf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -53,7 +53,7 @@ static llvm::cl::list llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)); -static bool isZero(Value *v) { +static bool isZero(ValuePtr v) { return isa_and_nonnull(v->getDefiningOp()) && cast(v->getDefiningOp()).getValue() == 0; } @@ -71,12 +71,12 @@ using LoopIndexToRangeIndexMap = DenseMap; // indices of newly created loops. static std::tuple, LoopIndexToRangeIndexMap> makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - ArrayRef allTileSizes, OperationFolder *folder) { + ArrayRef allViewSizes, + ArrayRef allTileSizes, OperationFolder *folder) { assert(allTileSizes.size() == map.getNumResults()); // Apply `map` to get view sizes in loop order. auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder); - SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); + SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); // Traverse the tile sizes, which are in loop order, erase zeros everywhere. LoopIndexToRangeIndexMap loopIndexToRangeIndex; @@ -110,7 +110,7 @@ namespace { // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] // struct TileCheck : public AffineExprVisitor { - TileCheck(ArrayRef tileSizes) + TileCheck(ArrayRef tileSizes) : isTiled(false), tileSizes(tileSizes) {} void visitDimExpr(AffineDimExpr expr) { @@ -124,7 +124,7 @@ struct TileCheck : public AffineExprVisitor { "nonpositive multiplying coefficient"); } bool isTiled; - ArrayRef tileSizes; + ArrayRef tileSizes; }; } // namespace @@ -206,11 +206,11 @@ void transformIndexedGenericOpIndices( auto rangeIndex = loopIndexToRangeIndex.find(i); if (rangeIndex == loopIndexToRangeIndex.end()) continue; - Value *oldIndex = block.getArgument(i); + ValuePtr oldIndex = block.getArgument(i); // Offset the index argument `i` by the value of the corresponding induction // variable and replace all uses of the previous value. - Value *newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, - pivs[rangeIndex->second]->getValue()); + ValuePtr newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, + pivs[rangeIndex->second]->getValue()); for (auto &use : oldIndex->getUses()) { if (use.getOwner() == newIndex->getDefiningOp()) continue; @@ -219,7 +219,7 @@ void transformIndexedGenericOpIndices( } } -static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { +static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { if (!expr) return false; TileCheck t(tileSizes); @@ -229,7 +229,7 @@ static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { // Checks whether the view with index `viewIndex` within `linalgOp` varies with // respect to a non-zero `tileSize`. 
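The TileCheck visitor above encodes the convention that a tile size of zero leaves the corresponding loop untiled, and the tiling code further down asserts that the number of induction variables matches the number of non-zero sizes. A small self-contained illustration of that bookkeeping, with plain integers standing in for the SSA values:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// One tile size per loop; zero means "do not tile this loop". The tiled loop
// nest then introduces exactly one induction variable per non-zero entry.
void checkIvCount(const std::vector<long> &tileSizes, std::size_t numIvs) {
  auto numTiled = std::count_if(tileSizes.begin(), tileSizes.end(),
                                [](long ts) { return ts != 0; });
  assert(static_cast<std::size_t>(numTiled) == numIvs &&
         "expected as many ivs as non-zero tile sizes");
  (void)numTiled; // keep release builds warning-free when assert compiles out
}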
-static bool isTiled(AffineMap map, ArrayRef tileSizes) { +static bool isTiled(AffineMap map, ArrayRef tileSizes) { if (!map) return false; for (unsigned r = 0; r < map.getNumResults(); ++r) @@ -238,13 +238,13 @@ static bool isTiled(AffineMap map, ArrayRef tileSizes) { return false; } -static SmallVector +static SmallVector makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, - ArrayRef ivs, ArrayRef tileSizes, - ArrayRef viewSizes, OperationFolder *folder) { + ArrayRef ivs, ArrayRef tileSizes, + ArrayRef viewSizes, OperationFolder *folder) { assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](Value *v) { return !isZero(v); })) && + [](ValuePtr v) { return !isZero(v); })) && "expected as many ivs as non-zero sizes"); using edsc::intrinsics::select; @@ -253,21 +253,22 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // Construct (potentially temporary) mins and maxes on which to apply maps // that define tile subviews. - SmallVector lbs, subViewSizes; + SmallVector lbs, subViewSizes; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { bool isTiled = !isZero(tileSizes[idx]); - lbs.push_back(isTiled ? ivs[idxIvs++] : (Value *)constant_index(folder, 0)); + lbs.push_back(isTiled ? ivs[idxIvs++] + : (ValuePtr)constant_index(folder, 0)); subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]); } auto *op = linalgOp.getOperation(); - SmallVector res; + SmallVector res; res.reserve(op->getNumOperands()); auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { - Value *view = *(viewIteratorBegin + viewIndex); + ValuePtr view = *(viewIteratorBegin + viewIndex); unsigned rank = view->getType().cast().getRank(); auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; // If the view is not tiled, we can use it as is. @@ -277,7 +278,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } // Construct a new subview for the tile. - SmallVector offsets, sizes, strides; + SmallVector offsets, sizes, strides; offsets.reserve(rank); sizes.reserve(rank); strides.reserve(rank); @@ -292,9 +293,9 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the slice view does not subsample, stepping occurs in the loop). auto m = map.getSubMap({r}); - auto *offset = applyMapToValues(b, loc, m, lbs, folder).front(); + auto offset = applyMapToValues(b, loc, m, lbs, folder).front(); offsets.push_back(offset); - auto *size = applyMapToValues(b, loc, m, subViewSizes, folder).front(); + auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front(); sizes.push_back(size); strides.push_back(constant_index(folder, 1)); } @@ -308,7 +309,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // This is a special type of folding that we only apply when `folder` is // defined. if (folder) - for (auto *v : llvm::concat(lbs, subViewSizes)) + for (auto v : llvm::concat(lbs, subViewSizes)) if (v->use_empty()) v->getDefiningOp()->erase(); @@ -316,7 +317,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } Optional mlir::linalg::tileLinalgOp( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. 
This convention is significantly simpler to handle instead of @@ -360,7 +361,7 @@ Optional mlir::linalg::tileLinalgOp( LoopNestRangeBuilder(pivs, loopRanges)([&] { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - SmallVector ivValues(ivs.begin(), ivs.end()); + SmallVector ivValues(ivs.begin(), ivs.end()); // If we have to apply a permutation to the tiled loop nest, we have to // reorder the induction variables This permutation is the right one @@ -411,7 +412,7 @@ Optional mlir::linalg::tileLinalgOp( ScopedContext scope(b, op.getLoc()); // Materialize concrete tile size values to pass the generic tiling function. - SmallVector tileSizeValues; + SmallVector tileSizeValues; tileSizeValues.reserve(tileSizes.size()); for (auto ts : tileSizes) tileSizeValues.push_back(constant_index(folder, ts)); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index eb501f9b5b5..125937807f4 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -92,7 +92,7 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) + ArrayRef ivs, ArrayRef ranges) : LoopNestRangeBuilder( ivs, SmallVector(ranges.begin(), ranges.end())) {} @@ -106,26 +106,26 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( return ValueHandle::null(); } -static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef operandsRef, - OperationFolder *folder) { - SmallVector operands(operandsRef.begin(), operandsRef.end()); +static ValuePtr emitOrFoldComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef operandsRef, + OperationFolder *folder) { + SmallVector operands(operandsRef.begin(), operandsRef.end()); fullyComposeAffineMapAndOperands(&map, &operands); canonicalizeMapAndOperands(&map, &operands); return folder ? folder->create(b, loc, map, operands) : b.create(loc, map, operands); } -SmallVector +SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map, - ArrayRef values, + ArrayRef values, OperationFolder *folder) { - SmallVector res; + SmallVector res; res.reserve(map.getNumResults()); unsigned numDims = map.getNumDims(); // For each `expr` in `map`, applies the `expr` to the values extracted from - // ranges. If the resulting application can be folded into a Value*, the + // ranges. If the resulting application can be folded into a Value, the // folding occurs eagerly. Otherwise, an affine.apply operation is emitted. for (auto expr : map.getResults()) { AffineMap map = AffineMap::get(numDims, 0, expr); @@ -137,12 +137,12 @@ mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map, /// Returns all the operands of `linalgOp` that are not views. /// Asserts that these operands are value types to allow transformations like /// tiling to just use the values when cloning `linalgOp`. 
-SmallVector +SmallVector mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) { auto *op = linalgOp.getOperation(); unsigned numViews = linalgOp.getNumInputsAndOutputs(); unsigned nOperands = op->getNumOperands() - numViews; - SmallVector res; + SmallVector res; res.reserve(nOperands); for (unsigned i = 0; i < nOperands; ++i) { res.push_back(op->getOperand(numViews + i)); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp index fc8832e9a46..9610a1ac270 100644 --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -69,8 +69,8 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context) // ForOp //===----------------------------------------------------------------------===// -void ForOp::build(Builder *builder, OperationState &result, Value *lb, - Value *ub, Value *step) { +void ForOp::build(Builder *builder, OperationState &result, ValuePtr lb, + ValuePtr ub, ValuePtr step) { result.addOperands({lb, ub, step}); Region *bodyRegion = result.addRegion(); ForOp::ensureTerminator(*bodyRegion, *builder, result.location); @@ -134,7 +134,7 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) { Region &ForOp::getLoopBody() { return region(); } -bool ForOp::isDefinedOutsideOfLoop(Value *value) { +bool ForOp::isDefinedOutsideOfLoop(ValuePtr value) { return !region().isAncestor(value->getParentRegion()); } @@ -144,8 +144,8 @@ LogicalResult ForOp::moveOutOfLoop(ArrayRef ops) { return success(); } -ForOp mlir::loop::getForInductionVarOwner(Value *val) { - auto *ivArg = dyn_cast(val); +ForOp mlir::loop::getForInductionVarOwner(ValuePtr val) { + auto ivArg = dyn_cast(val); if (!ivArg) return ForOp(); assert(ivArg->getOwner() && "unlinked block argument"); @@ -157,7 +157,7 @@ ForOp mlir::loop::getForInductionVarOwner(Value *val) { // IfOp //===----------------------------------------------------------------------===// -void IfOp::build(Builder *builder, OperationState &result, Value *cond, +void IfOp::build(Builder *builder, OperationState &result, ValuePtr cond, bool withElseRegion) { result.addOperands(cond); Region *thenRegion = result.addRegion(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index def8ee810fe..4416e1e6b04 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -94,7 +94,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only spv.ReturnValue needs to be handled here. auto retValOp = dyn_cast(op); if (!retValOp) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 284fe915029..ca9b883a703 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -229,9 +229,9 @@ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, /// Gets the global variable associated with a builtin and add /// it if it doesn't exist. 
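getOrInsertBuiltinVariable, referenced just above, follows the usual get-or-insert shape: look the symbol up first and create it only when it is missing. A toy, self-contained version of that shape, with a std::map standing in for the SPIR-V module's symbol table:

#include <map>
#include <string>

// Returns the entry for `name`, creating it with a default value on first use.
// try_emplace constructs the mapped value only when the key is absent.
int &getOrInsertCounter(std::map<std::string, int> &table,
                        const std::string &name) {
  return table.try_emplace(name, 0).first->second;
}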
-Value *mlir::spirv::getBuiltinVariableValue(Operation *op, - spirv::BuiltIn builtin, - OpBuilder &builder) { +ValuePtr mlir::spirv::getBuiltinVariableValue(Operation *op, + spirv::BuiltIn builtin, + OpBuilder &builder) { auto moduleOp = op->getParentOfType(); if (!moduleOp) { op->emitError("expected operation to be within a SPIR-V module"); @@ -239,7 +239,7 @@ Value *mlir::spirv::getBuiltinVariableValue(Operation *op, } spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder); - Value *ptr = builder.create(op->getLoc(), varOp); + ValuePtr ptr = builder.create(op->getLoc(), varOp); return builder.create(op->getLoc(), ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 0df4525bac6..a20c18056e1 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -273,8 +273,8 @@ static LogicalResult verifyMemorySemantics(BarrierOp op) { } template -static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr, - Value *val) { +static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, + ValuePtr ptr, ValuePtr val) { // ODS already checks ptr is spirv::PointerType. Just check that the pointee // type of the pointer and the type of the value are the same // @@ -664,8 +664,8 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { } static void printShiftOp(Operation *op, OpAsmPrinter &printer) { - Value *base = op->getOperand(0); - Value *shift = op->getOperand(1); + ValuePtr base = op->getOperand(0); + ValuePtr shift = op->getOperand(1); printer << op->getName() << ' ' << *base << ", " << *shift << " : " << base->getType() << ", " << shift->getType(); } @@ -742,7 +742,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { } void spirv::AccessChainOp::build(Builder *builder, OperationState &state, - Value *basePtr, ValueRange indices) { + ValuePtr basePtr, ValueRange indices) { auto type = getElementPtrType(basePtr->getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); @@ -782,8 +782,8 @@ static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { - SmallVector indices(accessChainOp.indices().begin(), - accessChainOp.indices().end()); + SmallVector indices(accessChainOp.indices().begin(), + accessChainOp.indices().end()); auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(), indices, accessChainOp.getLoc()); if (!resultType) { @@ -824,7 +824,7 @@ struct CombineChainedAccessChain } // Combine indices. 
- SmallVector indices(parentAccessChainOp.indices()); + SmallVector indices(parentAccessChainOp.indices()); indices.append(accessChainOp.indices().begin(), accessChainOp.indices().end()); @@ -1060,7 +1060,7 @@ static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) { static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) { Block *dest; - SmallVector destOperands; + SmallVector destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); state.addSuccessor(dest, destOperands); @@ -1089,7 +1089,7 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, auto &builder = parser.getBuilder(); OpAsmParser::OperandType condInfo; Block *dest; - SmallVector destOperands; + SmallVector destOperands; // Parse the condition. Type boolTy = builder.getI1Type(); @@ -1214,7 +1214,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { auto cType = compositeConstructOp.getType().cast(); - SmallVector constituents(compositeConstructOp.constituents()); + SmallVector constituents(compositeConstructOp.constituents()); if (constituents.size() != cType.getNumElements()) { return compositeConstructOp.emitError( "has incorrect number of operands: expected ") @@ -1239,7 +1239,7 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { //===----------------------------------------------------------------------===// void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state, - Value *composite, + ValuePtr composite, ArrayRef indices) { auto indexAttr = builder->getI32ArrayAttr(indices); auto elementType = @@ -1963,7 +1963,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// void spirv::LoadOp::build(Builder *builder, OperationState &state, - Value *basePtr, IntegerAttr memory_access, + ValuePtr basePtr, IntegerAttr memory_access, IntegerAttr alignment) { auto ptrType = basePtr->getType().cast(); build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, @@ -2497,7 +2497,8 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { //===----------------------------------------------------------------------===// void spirv::SelectOp::build(Builder *builder, OperationState &state, - Value *cond, Value *trueValue, Value *falseValue) { + ValuePtr cond, ValuePtr trueValue, + ValuePtr falseValue) { build(builder, state, trueValue->getType(), cond, trueValue, falseValue); } @@ -2698,9 +2699,9 @@ struct ConvertSelectionOpToSelect return matchFailure(); } - auto *trueValue = getSrcValue(trueBlock); - auto *falseValue = getSrcValue(falseBlock); - auto *ptrValue = getDstPtr(trueBlock); + auto trueValue = getSrcValue(trueBlock); + auto falseValue = getSrcValue(falseBlock); + auto ptrValue = getDstPtr(trueBlock); auto storeOpAttributes = cast(trueBlock->front()).getOperation()->getAttrs(); @@ -2747,13 +2748,13 @@ private: } // Returns a soruce value for the given block. - Value *getSrcValue(Block *block) const { + ValuePtr getSrcValue(Block *block) const { auto storeOp = cast(block->front()); return storeOp.value(); } // Returns a destination value for the given block. 
- Value *getDstPtr(Block *block) const { + ValuePtr getDstPtr(Block *block) const { auto storeOp = cast(block->front()); return storeOp.ptr(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index df9cb47a562..799828cb629 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -327,7 +327,7 @@ private: /// This method materializes normal constants and inserts "casting" ops /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA /// value for handling uses of module scope constants/variables in functions. - Value *getValue(uint32_t id); + ValuePtr getValue(uint32_t id); /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if @@ -446,7 +446,7 @@ private: DenseMap blockPhiInfo; // Result to value mapping. - DenseMap valueMap; + DenseMap valueMap; // Mapping from result to undef value of a type. DenseMap undefMap; @@ -1520,7 +1520,7 @@ Deserializer::processBranchConditional(ArrayRef operands) { "false label, and optionally two branch weights"); } - auto *condition = getValue(operands[0]); + auto condition = getValue(operands[0]); auto *trueBlock = getOrCreateBlock(operands[1]); auto *falseBlock = getOrCreateBlock(operands[2]); @@ -1531,8 +1531,8 @@ Deserializer::processBranchConditional(ArrayRef operands) { opBuilder.create( unknownLoc, condition, trueBlock, - /*trueArguments=*/ArrayRef(), falseBlock, - /*falseArguments=*/ArrayRef(), weights); + /*trueArguments=*/ArrayRef(), falseBlock, + /*falseArguments=*/ArrayRef(), weights); return success(); } @@ -1626,7 +1626,7 @@ LogicalResult Deserializer::processPhi(ArrayRef operands) { // Create a block argument for this OpPhi instruction. Type blockArgType = getType(operands[0]); - BlockArgument *blockArg = curBlock->addArgument(blockArgType); + BlockArgumentPtr blockArg = curBlock->addArgument(blockArgType); valueMap[operands[1]] = blockArg; LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg << " id = " << operands[1] << " of type " @@ -1783,8 +1783,8 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock << " from block " << block << "\n"); if (!isFnEntryBlock(block)) { - for (BlockArgument *blockArg : block->getArguments()) { - auto *newArg = newBlock->addArgument(blockArg->getType()); + for (BlockArgumentPtr blockArg : block->getArguments()) { + auto newArg = newBlock->addArgument(blockArg->getType()); mapper.map(blockArg, newArg); LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg << " to " << newArg << '\n'); @@ -1801,10 +1801,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // Go through all ops and remap the operands. 
auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) - if (auto *mappedOp = mapper.lookupOrNull(operand.get())) + if (auto mappedOp = mapper.lookupOrNull(operand.get())) operand.set(mappedOp); for (auto &succOp : op->getBlockOperands()) - if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) + if (auto mappedOp = mapper.lookupOrNull(succOp.get())) succOp.set(mappedOp); }; for (auto &block : body) { @@ -1824,13 +1824,13 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // we place the selection/loop op inside the old merge block, we need to // make sure the old merge block has the same block argument list. assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); - for (BlockArgument *blockArg : headerBlock->getArguments()) { + for (BlockArgumentPtr blockArg : headerBlock->getArguments()) { mergeBlock->addArgument(blockArg->getType()); } // If the loop header block has block arguments, make sure the spv.branch op // matches. - SmallVector blockArgs; + SmallVector blockArgs; if (!headerBlock->args_empty()) blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; @@ -1838,7 +1838,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // loop header block. builder.setInsertionPointToEnd(&body.front()); builder.create(location, mapper.lookupOrNull(headerBlock), - ArrayRef(blockArgs)); + ArrayRef(blockArgs)); } // All the blocks cloned into the SelectionOp/LoopOp's region can now be @@ -1924,10 +1924,10 @@ LogicalResult Deserializer::wireUpBlockArgument() { auto *op = block->getTerminator(); opBuilder.setInsertionPoint(op); - SmallVector blockArgs; + SmallVector blockArgs; blockArgs.reserve(phiInfo.size()); for (uint32_t valueId : phiInfo) { - if (Value *value = getValue(valueId)) { + if (ValuePtr value = getValue(valueId)) { blockArgs.push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value << " id = " << valueId << '\n'); @@ -1996,7 +1996,7 @@ LogicalResult Deserializer::structurizeControlFlow() { // Instruction //===----------------------------------------------------------------------===// -Value *Deserializer::getValue(uint32_t id) { +ValuePtr Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spv.constant` op at every use site. 
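The deserializer's central data structure in these hunks is the result-&lt;id&gt; to SSA value map declared earlier in this file's diff; `getValue` consults it and only falls back to materializing `spv.constant` or `spv._address_of` ops for module-scope symbols. A hedged sketch of the lookup side (the helper name is illustrative; only `valueMap` comes from the patch):

// Result <id> -> SSA value mapping populated as instructions are deserialized.
DenseMap<uint32_t, ValuePtr> valueMap;

// Illustrative lookup: a null result means the id has no definition yet,
// which callers report as an error (e.g. "unknown ... used by OpFunctionCall").
ValuePtr lookupValue(uint32_t id) {
  auto it = valueMap.find(id);
  return it == valueMap.end() ? nullptr : it->second;
}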
return opBuilder.create(unknownLoc, constInfo->second, @@ -2192,7 +2192,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef words) { } } valueID = words[wordIndex++]; - SmallVector operands; + SmallVector operands; SmallVector attributes; if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); @@ -2366,9 +2366,9 @@ Deserializer::processOp(ArrayRef operands) { auto functionName = getFunctionSymbol(functionID); - SmallVector arguments; + SmallVector arguments; for (auto operand : llvm::drop_begin(operands, 3)) { - auto *value = getValue(operand); + auto value = getValue(operand); if (!value) { return emitError(unknownLoc, "unknown ") << operand << " used by OpFunctionCall"; diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 4baac53b89f..9b47045ea61 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -323,7 +323,7 @@ private: uint32_t opcode, ArrayRef operands); - uint32_t getValueID(Value *val) const { return valueIDMap.lookup(val); } + uint32_t getValueID(ValuePtr val) const { return valueIDMap.lookup(val); } LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); @@ -414,7 +414,7 @@ private: DenseMap undefValIDMap; /// Map from results of normal operations to their s. - DenseMap valueIDMap; + DenseMap valueIDMap; /// Map from extended instruction set name to s. llvm::StringMap extendedInstSetIDMap; @@ -457,7 +457,7 @@ private: /// placed inside `functions`) here. And then after emitting all blocks, we /// replace the dummy 0 with the real result by overwriting /// `functions[offset]`. - DenseMap> deferredPhiValues; + DenseMap> deferredPhiValues; }; } // namespace @@ -513,12 +513,12 @@ void Serializer::collect(SmallVectorImpl &binary) { void Serializer::printValueIDMap(raw_ostream &os) { os << "\n= Value Map =\n\n"; for (auto valueIDPair : valueIDMap) { - Value *val = valueIDPair.first; + ValuePtr val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; if (auto *op = val->getDefiningOp()) { os << "from op '" << op->getName() << "'"; - } else if (auto *arg = dyn_cast(val)) { + } else if (auto arg = dyn_cast(val)) { Block *block = arg->getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; @@ -752,7 +752,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { // There might be OpPhi instructions who have value references needing to fix. for (auto deferredValue : deferredPhiValues) { - Value *value = deferredValue.first; + ValuePtr value = deferredValue.first; uint32_t id = getValueID(value); LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value << " to id = " << id << '\n'); @@ -1402,7 +1402,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // Then create OpPhi instruction for each of the block argument. for (auto argIndex : llvm::seq(0, block->getNumArguments())) { - BlockArgument *arg = block->getArgument(argIndex); + BlockArgumentPtr arg = block->getArgument(argIndex); // Get the type and result for this OpPhi instruction. 
uint32_t phiTypeID = 0; @@ -1418,7 +1418,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { phiArgs.push_back(phiID); for (auto predIndex : llvm::seq(0, predecessors.size())) { - Value *value = *(predecessors[predIndex].second + argIndex); + ValuePtr value = *(predecessors[predIndex].second + argIndex); uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); @@ -1784,7 +1784,7 @@ Serializer::processOp(spirv::FunctionCallOp op) { auto funcCallID = getNextID(); SmallVector operands{resTypeID, funcCallID, funcID}; - for (auto *value : op.arguments()) { + for (auto value : op.arguments()) { auto valueID = getValueID(value); assert(valueID && "cannot find a value for spv.FunctionCall"); operands.push_back(valueID); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index d48b31fe491..93ce2c0a0d5 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -140,7 +140,7 @@ class FuncOpLowering final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -153,7 +153,7 @@ private: } // namespace PatternMatchResult -FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!funcOp.getAttrOfType( spirv::getEntryPointABIAttrName())) { @@ -183,7 +183,7 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, OpBuilder::InsertionGuard funcInsertionGuard(rewriter); rewriter.setInsertionPointToStart(&funcOp.front()); // Insert spirv::AddressOf and spirv::AccessChain operations. - Value *replacement = + ValuePtr replacement = rewriter.create(funcOp.getLoc(), var); // Check if the arg is a scalar or vector type. In that case, the value // needs to be loaded into registers. diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 4116f6f14ae..94166b5a7dd 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -81,7 +81,7 @@ struct StdInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "std.return" needs to be handled here. auto returnOp = cast(op); @@ -184,7 +184,7 @@ void mlir::printDimAndSymbolList(Operation::operand_iterator begin, // dimension operands parsed. // Returns 'false' on success and 'true' on error. ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, + SmallVectorImpl &operands, unsigned &numDims) { SmallVector opInfos; if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) @@ -325,7 +325,7 @@ struct SimplifyAllocConst : public OpRewritePattern { PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. 
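The std-dialect canonicalizations in this file (SimplifyAllocConst and the view/subview folders further down) all start, as in the hunk that follows, by checking whether any operand is produced by a constant index op, and then fold each such operand into the static type and drop it. A small sketch of that per-operand step, reusing names from the surrounding pattern (the real control flow is more involved):

ValuePtr operand = alloc.getOperand(dynamicDimPos);
if (auto constantIndexOp =
        dyn_cast_or_null<ConstantIndexOp>(operand->getDefiningOp())) {
  // Fold the constant into the result memref shape and drop the operand.
  newShapeConstants.push_back(constantIndexOp.getValue());
  droppedOperands.push_back(operand);
} else {
  // The dimension stays dynamic and its operand is kept.
  newOperands.push_back(operand);
}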
- if (llvm::none_of(alloc.getOperands(), [](Value *operand) { + if (llvm::none_of(alloc.getOperands(), [](ValuePtr operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); @@ -336,8 +336,8 @@ struct SimplifyAllocConst : public OpRewritePattern { // and keep track of the resultant memref type to build. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); - SmallVector newOperands; - SmallVector droppedOperands; + SmallVector newOperands; + SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { @@ -429,7 +429,7 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { Block *dest; - SmallVector destOperands; + SmallVector destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); @@ -623,7 +623,7 @@ static Type getI1SameShape(Builder *build, Type type) { //===----------------------------------------------------------------------===// static void buildCmpIOp(Builder *build, OperationState &result, - CmpIPredicate predicate, Value *lhs, Value *rhs) { + CmpIPredicate predicate, ValuePtr lhs, ValuePtr rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( @@ -777,7 +777,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { } static void buildCmpFOp(Builder *build, OperationState &result, - CmpFPredicate predicate, Value *lhs, Value *rhs) { + CmpFPredicate predicate, ValuePtr lhs, ValuePtr rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( @@ -946,7 +946,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern { static ParseResult parseCondBranchOp(OpAsmParser &parser, OperationState &result) { - SmallVector destOperands; + SmallVector destOperands; Block *dest; OpAsmParser::OperandType condInfo; @@ -1088,7 +1088,7 @@ OpFoldResult ConstantOp::fold(ArrayRef operands) { } void ConstantOp::getAsmResultNames( - function_ref setNameFn) { + function_ref setNameFn) { Type type = getType(); if (auto intCst = getValue().dyn_cast()) { IntegerType intTy = type.dyn_cast(); @@ -1183,7 +1183,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern { PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. 
- Value *memref = dealloc.memref(); + ValuePtr memref = dealloc.memref(); if (!isa_and_nonnull(memref->getDefiningOp())) return matchFailure(); @@ -1362,11 +1362,11 @@ OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState &result, - Value *srcMemRef, ValueRange srcIndices, - Value *destMemRef, ValueRange destIndices, - Value *numElements, Value *tagMemRef, - ValueRange tagIndices, Value *stride, - Value *elementsPerStride) { + ValuePtr srcMemRef, ValueRange srcIndices, + ValuePtr destMemRef, ValueRange destIndices, + ValuePtr numElements, ValuePtr tagMemRef, + ValueRange tagIndices, ValuePtr stride, + ValuePtr elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); @@ -1507,8 +1507,8 @@ LogicalResult DmaStartOp::fold(ArrayRef cstOperands, // --------------------------------------------------------------------------- void DmaWaitOp::build(Builder *builder, OperationState &result, - Value *tagMemRef, ValueRange tagIndices, - Value *numElements) { + ValuePtr tagMemRef, ValueRange tagIndices, + ValuePtr numElements) { result.addOperands(tagMemRef); result.addOperands(tagIndices); result.addOperands(numElements); @@ -2025,7 +2025,7 @@ static LogicalResult verify(SelectOp op) { } OpFoldResult SelectOp::fold(ArrayRef operands) { - auto *condition = getCondition(); + auto condition = getCondition(); // select true, %0, %1 => %0 if (matchPattern(condition, m_One())) @@ -2357,7 +2357,7 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, ViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - auto *dynamicOffset = op.getDynamicOffset(); + auto dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); p << "][" << op.getDynamicSizes() << ']'; @@ -2365,7 +2365,7 @@ static void print(OpAsmPrinter &p, ViewOp op) { p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } -Value *ViewOp::getDynamicOffset() { +ValuePtr ViewOp::getDynamicOffset() { int64_t offset; SmallVector strides; auto result = @@ -2440,7 +2440,7 @@ struct ViewOpShapeFolder : public OpRewritePattern { PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { // Return if none of the operands are constants. - if (llvm::none_of(viewOp.getOperands(), [](Value *operand) { + if (llvm::none_of(viewOp.getOperands(), [](ValuePtr operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); @@ -2457,11 +2457,11 @@ struct ViewOpShapeFolder : public OpRewritePattern { if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return matchFailure(); - SmallVector newOperands; - SmallVector droppedOperands; + SmallVector newOperands; + SmallVector droppedOperands; // Fold dynamic offset operand if it is produced by a constant. 
- auto *dynamicOffset = viewOp.getDynamicOffset(); + auto dynamicOffset = viewOp.getDynamicOffset(); int64_t newOffset = oldOffset; unsigned dynamicOffsetOperandCount = 0; if (dynamicOffset != nullptr) { @@ -2576,7 +2576,7 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, +void mlir::SubViewOp::build(Builder *b, OperationState &result, ValuePtr source, ValueRange offsets, ValueRange sizes, ValueRange strides, Type resultType, ArrayRef attrs) { @@ -2590,7 +2590,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value *source) { + ValuePtr source) { build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, resultType); } @@ -2826,7 +2826,7 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.sizes(), [](Value *operand) { + llvm::any_of(subViewOp.sizes(), [](ValuePtr operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); @@ -2842,7 +2842,7 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - ArrayRef(), subViewOp.strides(), newMemRefType); + ArrayRef(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType()); @@ -2871,7 +2871,7 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.strides(), [](Value *stride) { + llvm::any_of(subViewOp.strides(), [](ValuePtr stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); @@ -2892,7 +2892,7 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - subViewOp.sizes(), ArrayRef(), newMemRefType); + subViewOp.sizes(), ArrayRef(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType()); @@ -2922,7 +2922,7 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.offsets(), [](Value *stride) { + llvm::any_of(subViewOp.offsets(), [](ValuePtr stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); @@ -2943,7 +2943,7 @@ public: MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), ArrayRef(), + subViewOp.getLoc(), subViewOp.source(), ArrayRef(), subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. 
rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 6a3ff74afcd..18c1714f403 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -72,7 +72,7 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, //===----------------------------------------------------------------------===// void vector::ContractionOp::build(Builder *builder, OperationState &result, - Value *lhs, Value *rhs, Value *acc, + ValuePtr lhs, ValuePtr rhs, ValuePtr acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes) { result.addOperands({lhs, rhs, acc}); @@ -404,7 +404,7 @@ static Type inferExtractOpResultType(VectorType vectorType, } void vector::ExtractOp::build(Builder *builder, OperationState &result, - Value *source, ArrayRef position) { + ValuePtr source, ArrayRef position) { result.addOperands(source); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(inferExtractOpResultType(source->getType().cast(), @@ -471,7 +471,7 @@ static LogicalResult verify(vector::ExtractOp op) { //===----------------------------------------------------------------------===// void ExtractSlicesOp::build(Builder *builder, OperationState &result, - TupleType tupleType, Value *vector, + TupleType tupleType, ValuePtr vector, ArrayRef sizes, ArrayRef strides) { result.addOperands(vector); @@ -647,8 +647,8 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, // ShuffleOp //===----------------------------------------------------------------------===// -void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1, - Value *v2, ArrayRef mask) { +void ShuffleOp::build(Builder *builder, OperationState &result, ValuePtr v1, + ValuePtr v2, ArrayRef mask) { result.addOperands({v1, v2}); auto maskAttr = getVectorSubscriptAttr(*builder, mask); result.addTypes(v1->getType()); @@ -771,8 +771,8 @@ static LogicalResult verify(InsertElementOp op) { // InsertOp //===----------------------------------------------------------------------===// -void InsertOp::build(Builder *builder, OperationState &result, Value *source, - Value *dest, ArrayRef position) { +void InsertOp::build(Builder *builder, OperationState &result, ValuePtr source, + ValuePtr dest, ArrayRef position) { result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(dest->getType()); @@ -893,7 +893,7 @@ void InsertSlicesOp::getStrides(SmallVectorImpl &results) { //===----------------------------------------------------------------------===// void InsertStridedSliceOp::build(Builder *builder, OperationState &result, - Value *source, Value *dest, + ValuePtr source, ValuePtr dest, ArrayRef offsets, ArrayRef strides) { result.addOperands({source, dest}); @@ -1201,17 +1201,17 @@ static LogicalResult verify(ReshapeOp op) { // If all shape operands are produced by constant ops, verify that product // of dimensions for input/output shape match. 
- auto isDefByConstant = [](Value *operand) { + auto isDefByConstant = [](ValuePtr operand) { return isa_and_nonnull(operand->getDefiningOp()); }; if (llvm::all_of(op.input_shape(), isDefByConstant) && llvm::all_of(op.output_shape(), isDefByConstant)) { int64_t numInputElements = 1; - for (auto *operand : op.input_shape()) + for (auto operand : op.input_shape()) numInputElements *= cast(operand->getDefiningOp()).getValue(); int64_t numOutputElements = 1; - for (auto *operand : op.output_shape()) + for (auto operand : op.output_shape()) numOutputElements *= cast(operand->getDefiningOp()).getValue(); if (numInputElements != numOutputElements) @@ -1247,7 +1247,7 @@ static Type inferStridedSliceOpResultType(VectorType vectorType, } void StridedSliceOp::build(Builder *builder, OperationState &result, - Value *source, ArrayRef offsets, + ValuePtr source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { result.addOperands(source); auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets); @@ -1603,7 +1603,7 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) { } void TypeCastOp::build(Builder *builder, OperationState &result, - Value *source) { + ValuePtr source) { result.addOperands(source); result.addTypes( inferVectorTypeCastResultType(source->getType().cast())); @@ -1793,14 +1793,14 @@ public: PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. - auto is_not_def_by_constant = [](Value *operand) { + auto is_not_def_by_constant = [](ValuePtr operand) { return !isa_and_nonnull(operand->getDefiningOp()); }; if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant)) return matchFailure(); // Gather constant mask dimension sizes. SmallVector maskDimSizes; - for (auto *operand : createMaskOp.operands()) { + for (auto operand : createMaskOp.operands()) { auto defOp = operand->getDefiningOp(); maskDimSizes.push_back(cast(defOp).getValue()); } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 64cacb28720..e5c281cbf64 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -106,17 +106,17 @@ static SmallVector delinearize(int64_t linearIndex, // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder, Location loc, Operation *op, - ArrayRef operands, + ArrayRef operands, ArrayRef resultTypes) { OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, op->getAttrs()); return builder.createOperation(res); } -static Value *makeSplatZero(Location loc, PatternRewriter &rewriter, - VectorType vt) { +static ValuePtr makeSplatZero(Location loc, PatternRewriter &rewriter, + VectorType vt) { auto t = vt.getElementType(); - Value *f = nullptr; + ValuePtr f = nullptr; if (t.isBF16() || t.isF16()) f = rewriter.create(loc, t, rewriter.getF64FloatAttr(0.0f)); else if (t.isF32()) @@ -190,12 +190,12 @@ struct UnrolledVectorState { SmallVector unrollFactors; SmallVector basis; int64_t numInstances; - Value *slicesTuple; + ValuePtr slicesTuple; }; // Populates 'state' with unrolled shape, unroll factors, basis and // num unrolled instances for 'vectorType'. 
-static void initUnrolledVectorState(VectorType vectorType, Value *initValue, +static void initUnrolledVectorState(VectorType vectorType, ValuePtr initValue, const DenseMap &indexMap, ArrayRef targetShape, UnrolledVectorState &state, @@ -239,10 +239,10 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state, // Returns an unrolled vector at 'vectorOffsets' within the vector // represented by 'state'. The vector is created from a slice of 'initValue' // if not present in 'cache'. -static Value *getOrCreateUnrolledVectorSlice( +static ValuePtr getOrCreateUnrolledVectorSlice( Location loc, UnrolledVectorState &state, ArrayRef vectorOffsets, ArrayRef offsets, DenseMap &indexMap, - Value *initValue, SmallVectorImpl &cache, + ValuePtr initValue, SmallVectorImpl &cache, PatternRewriter &builder) { // Compute slice offsets. SmallVector sliceOffsets(state.unrolledShape.size()); @@ -253,7 +253,7 @@ static Value *getOrCreateUnrolledVectorSlice( int64_t sliceLinearIndex = getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap); assert(sliceLinearIndex < static_cast(cache.size())); - auto *valueSlice = cache[sliceLinearIndex]; + auto valueSlice = cache[sliceLinearIndex]; if (valueSlice == nullptr) { // Return tuple element at 'sliceLinearIndex'. auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex); @@ -330,12 +330,10 @@ struct VectorState { // TODO(andydavis) Generalize this to support structured ops beyond // vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType' -static Value *unrollSingleResultStructuredOp(Operation *op, - ArrayRef iterationBounds, - std::vector &vectors, - unsigned resultIndex, - ArrayRef targetShape, - PatternRewriter &builder) { +static ValuePtr unrollSingleResultStructuredOp( + Operation *op, ArrayRef iterationBounds, + std::vector &vectors, unsigned resultIndex, + ArrayRef targetShape, PatternRewriter &builder) { auto shapedType = op->getResult(0)->getType().dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) assert(false && "Expected a statically shaped result type"); @@ -351,7 +349,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op, SmallVector unrolledVectorState(numVectors); for (unsigned i = 0; i < numVectors; ++i) { int64_t operandIndex = vectors[i].operandIndex; - auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr; + auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr; initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap, targetShape, unrolledVectorState[i], builder); } @@ -364,7 +362,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op, shapedType.getElementType()); // Initialize caches for intermediate vector results. - std::vector> caches(numVectors); + std::vector> caches(numVectors); for (unsigned i = 0; i < numVectors; ++i) caches[i].resize(unrolledVectorState[i].numInstances); @@ -376,13 +374,13 @@ static Value *unrollSingleResultStructuredOp(Operation *op, auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, targetShape); // Get cached slice (or create slice) for each operand at 'offsets'. 
- SmallVector operands; + SmallVector operands; operands.resize(op->getNumOperands()); for (unsigned i = 0; i < numVectors; ++i) { int64_t operandIndex = vectors[i].operandIndex; if (operandIndex < 0) continue; // Output - auto *operand = op->getOperand(operandIndex); + auto operand = op->getOperand(operandIndex); operands[operandIndex] = getOrCreateUnrolledVectorSlice( op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets, vectors[i].indexMap, operand, caches[i], builder); @@ -402,21 +400,21 @@ static Value *unrollSingleResultStructuredOp(Operation *op, // Create TupleOp of unrolled result vectors. SmallVector vectorTupleTypes(resultValueState.numInstances); - SmallVector vectorTupleValues(resultValueState.numInstances); + SmallVector vectorTupleValues(resultValueState.numInstances); for (unsigned i = 0; i < resultValueState.numInstances; ++i) { vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast(); vectorTupleValues[i] = caches[resultIndex][i]; } TupleType tupleType = builder.getTupleType(vectorTupleTypes); - Value *tupleOp = builder.create(op->getLoc(), tupleType, - vectorTupleValues); + ValuePtr tupleOp = builder.create(op->getLoc(), tupleType, + vectorTupleValues); // Create InsertSlicesOp(Tuple(result_vectors)). auto resultVectorType = op->getResult(0)->getType().cast(); SmallVector sizes(resultValueState.unrolledShape); SmallVector strides(resultValueState.unrollFactors.size(), 1); - Value *insertSlicesOp = builder.create( + ValuePtr insertSlicesOp = builder.create( op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes), builder.getI64ArrayAttr(strides)); return insertSlicesOp; @@ -487,7 +485,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef targetShape, } // Entry point for unrolling declarative pattern rewrites. -Value *mlir::vector::unrollSingleResultOpMatchingType( +ValuePtr mlir::vector::unrollSingleResultOpMatchingType( PatternRewriter &builder, Operation *op, ArrayRef targetShape) { assert(op->getNumResults() == 1 && "Expected single result operation"); @@ -516,8 +514,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( static void generateTransferOpSlices(VectorType vectorType, TupleType tupleType, ArrayRef sizes, ArrayRef strides, - ArrayRef indices, PatternRewriter &rewriter, - function_ref)> fn) { + ArrayRef indices, PatternRewriter &rewriter, + function_ref)> fn) { // Compute strides w.r.t. to slice counts in each dimension. auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes); assert(maybeDimSliceCounts.hasValue()); @@ -534,13 +532,13 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType, auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes); // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. - SmallVector sliceIndices(numSliceIndices); + SmallVector sliceIndices(numSliceIndices); for (auto it : llvm::enumerate(indices)) { auto expr = getAffineDimExpr(0, ctx) + getAffineConstantExpr(offsets[it.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); sliceIndices[it.index()] = rewriter.create( - it.value()->getLoc(), map, ArrayRef(it.value())); + it.value()->getLoc(), map, ArrayRef(it.value())); } // Call 'fn' to generate slice 'i' at 'sliceIndices'. fn(i, sliceIndices); @@ -559,7 +557,7 @@ struct SplitTransferReadOp : public OpRewritePattern { if (!xferReadOp.permutation_map().isIdentity()) return matchFailure(); // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. 
- Value *xferReadResult = xferReadOp.getResult(); + ValuePtr xferReadResult = xferReadOp.getResult(); auto extractSlicesOp = dyn_cast(*xferReadResult->getUsers().begin()); if (!xferReadResult->hasOneUse() || !extractSlicesOp) @@ -576,10 +574,10 @@ struct SplitTransferReadOp : public OpRewritePattern { Location loc = xferReadOp.getLoc(); int64_t numSlices = resultTupleType.size(); - SmallVector vectorTupleValues(numSlices); - SmallVector indices(xferReadOp.indices().begin(), - xferReadOp.indices().end()); - auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + SmallVector vectorTupleValues(numSlices); + SmallVector indices(xferReadOp.indices().begin(), + xferReadOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Get VectorType for slice 'i'. auto sliceVectorType = resultTupleType.getType(index); // Create split TransferReadOp for 'sliceUser'. @@ -591,8 +589,8 @@ struct SplitTransferReadOp : public OpRewritePattern { indices, rewriter, createSlice); // Create tuple of splice xfer read operations. - Value *tupleOp = rewriter.create(loc, resultTupleType, - vectorTupleValues); + ValuePtr tupleOp = rewriter.create(loc, resultTupleType, + vectorTupleValues); // Replace 'xferReadOp' with result 'insertSlicesResult'. rewriter.replaceOpWithNewOp( xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), @@ -632,9 +630,9 @@ struct SplitTransferWriteOp : public OpRewritePattern { insertSlicesOp.getStrides(strides); Location loc = xferWriteOp.getLoc(); - SmallVector indices(xferWriteOp.indices().begin(), - xferWriteOp.indices().end()); - auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + SmallVector indices(xferWriteOp.indices().begin(), + xferWriteOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'. rewriter.create( loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices, @@ -676,7 +674,7 @@ struct TupleGetFolderOp : public OpRewritePattern { return matchFailure(); // Forward Value from 'tupleOp' at 'tupleGetOp.index'. 
- Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); + ValuePtr tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); rewriter.replaceOp(tupleGetOp, tupleValue); return matchSuccess(); } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 47e2dfed55e..35108ed5666 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -88,9 +88,8 @@ ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { return *this; } -ValueHandle -mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, - ArrayRef operands) { +ValueHandle mlir::edsc::ValueHandle::createComposedAffineApply( + AffineMap map, ArrayRef operands) { Operation *op = makeComposedAffineApply(ScopedContext::getBuilder(), ScopedContext::getLocation(), map, operands) @@ -118,7 +117,7 @@ OperationHandle OperationHandle::create(StringRef name, ArrayRef resultTypes, ArrayRef attributes) { OperationState state(ScopedContext::getLocation(), name); - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); state.addOperands(ops); state.addTypes(resultTypes); for (const auto &attr : attributes) { @@ -169,8 +168,8 @@ mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine( if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { *iv = staticFor.getValue(); } else { - SmallVector lbs(lbHandles.begin(), lbHandles.end()); - SmallVector ubs(ubHandles.begin(), ubHandles.end()); + SmallVector lbs(lbHandles.begin(), lbHandles.end()); + SmallVector ubs(ubHandles.begin(), ubHandles.end()); *iv = ValueHandle::create( lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()), ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()), @@ -309,11 +308,11 @@ static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { return ValueHandle::create(lhs.getValue(), rhs.getValue()); } -static std::pair -categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, - unsigned &numSymbols) { +static std::pair +categorizeValueByAffineType(MLIRContext *context, ValuePtr val, + unsigned &numDims, unsigned &numSymbols) { AffineExpr d; - Value *resultVal = nullptr; + ValuePtr resultVal = nullptr; if (auto constant = dyn_cast_or_null(val->getDefiningOp())) { d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { @@ -332,12 +331,12 @@ static ValueHandle createBinaryIndexHandle( MLIRContext *context = ScopedContext::getContext(); unsigned numDims = 0, numSymbols = 0; AffineExpr d0, d1; - Value *v0, *v1; + ValuePtr v0, v1; std::tie(d0, v0) = categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); std::tie(d1, v1) = categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); - SmallVector operands; + SmallVector operands; if (v0) { operands.push_back(v0); } diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index eeb28668a34..1771eb0a427 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -22,7 +22,7 @@ using namespace mlir; using namespace mlir::edsc; -static SmallVector getMemRefSizes(Value *memRef) { +static SmallVector getMemRefSizes(ValuePtr memRef) { MemRefType memRefType = memRef->getType().cast(); assert(isStrided(memRefType) && "Expected strided MemRef type"); @@ -39,7 +39,7 @@ static SmallVector getMemRefSizes(Value *memRef) { return res; } -mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) { +mlir::edsc::MemRefView::MemRefView(ValuePtr v) : base(v) { 
assert(v->getType().isa() && "MemRefType expected"); auto memrefSizeValues = getMemRefSizes(v); @@ -50,7 +50,7 @@ mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) { } } -mlir::edsc::VectorView::VectorView(Value *v) : base(v) { +mlir::edsc::VectorView::VectorView(ValuePtr v) : base(v) { auto vectorType = v->getType().cast(); for (auto s : vectorType.getShape()) { diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index 1b19f9aa0bf..c6738c42993 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -29,7 +29,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle bh, (void)o; assert(o && "Expected already captured ValueHandle"); } - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh.getBlock(), ops); } static void enforceEmptyCapturesMatchOperands(ArrayRef captures, @@ -52,7 +52,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh, assert(!*bh && "Unexpected already captured BlockHandle"); enforceEmptyCapturesMatchOperands(captures, operands); BlockBuilder(bh, captures)(/* no body */); - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh->getBlock(), ops); } @@ -61,8 +61,8 @@ mlir::edsc::intrinsics::cond_br(ValueHandle cond, BlockHandle trueBranch, ArrayRef trueOperands, BlockHandle falseBranch, ArrayRef falseOperands) { - SmallVector trueOps(trueOperands.begin(), trueOperands.end()); - SmallVector falseOps(falseOperands.begin(), falseOperands.end()); + SmallVector trueOps(trueOperands.begin(), trueOperands.end()); + SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps); } @@ -78,8 +78,8 @@ OperationHandle mlir::edsc::intrinsics::cond_br( enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands); BlockBuilder(trueBranch, trueCaptures)(/* no body */); BlockBuilder(falseBranch, falseCaptures)(/* no body */); - SmallVector trueOps(trueOperands.begin(), trueOperands.end()); - SmallVector falseOps(falseOperands.begin(), falseOperands.end()); + SmallVector trueOps(trueOperands.begin(), trueOperands.end()); + SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f3c92ada0a0..177d8a5ef05 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -319,7 +319,7 @@ void ModuleState::visitOperation(Operation *op) { visitType(type); for (auto ®ion : op->getRegions()) for (auto &block : region) - for (auto *arg : block.getArguments()) + for (auto arg : block.getArguments()) visitType(arg->getType()); // Visit each of the attributes. 
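The AsmPrinter hunks that follow are about how SSA values get their printed names: every value either receives a numeric ID (%0, %1, ...) or a special name (%arg0, %cst, ...), tracked in two maps keyed by the value, with ~0 as the sentinel meaning "look in the name map". A hedged sketch of that bookkeeping (the struct and method names here are illustrative; the two maps mirror the code below):

struct ValueNaming {
  DenseMap<ValuePtr, unsigned> valueIDs;    // ~0U means "see valueNames"
  DenseMap<ValuePtr, StringRef> valueNames; // special names like "arg0", "cst"
  unsigned nextValueID = 0;

  void setName(ValuePtr value, StringRef name) {
    if (name.empty()) {
      valueIDs[value] = nextValueID++;      // default numbering: %0, %1, ...
      return;
    }
    valueIDs[value] = ~0U;                  // sentinel: use the special name
    valueNames[value] = name;
  }
};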
@@ -1437,7 +1437,7 @@ public: void printAttribute(Attribute attr) override { ModulePrinter::printAttribute(attr); } - void printOperand(Value *value) override { printValueID(value); } + void printOperand(ValuePtr value) override { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { @@ -1519,7 +1519,7 @@ protected: void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); void numberValuesInOp(Operation &op); - void printValueID(Value *value, bool printResultNo = true) const { + void printValueID(ValuePtr value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } @@ -1528,13 +1528,13 @@ private: /// 'lookupValue' and the result of 'result' within that group in /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group /// has more than 1 result. - void getResultIDAndNumber(OpResult *result, Value *&lookupValue, + void getResultIDAndNumber(OpResultPtr result, ValuePtr &lookupValue, int &lookupResultNo) const; - void printValueIDImpl(Value *value, bool printResultNo, + void printValueIDImpl(ValuePtr value, bool printResultNo, raw_ostream &stream) const; /// Set a special value name for the given value. - void setValueName(Value *value, StringRef name); + void setValueName(ValuePtr value, StringRef name); /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. @@ -1542,8 +1542,8 @@ private: /// This is the value ID for each SSA value. If this returns ~0, then the /// valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; + DenseMap valueIDs; + DenseMap valueNames; /// This is a map of operations that contain multiple named result groups, /// i.e. there may be multiple names for the results of the operation. The key @@ -1619,7 +1619,7 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { } void OperationPrinter::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](Value *arg, StringRef name) { + auto setArgNameFn = [&](ValuePtr arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); assert(cast(arg)->getOwner() == &block && "arg not defined in 'block'"); @@ -1638,7 +1638,7 @@ void OperationPrinter::numberValuesInBlock(Block &block) { // 'arg'. SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); llvm::raw_svector_ostream specialName(specialNameBuffer); - for (auto *arg : block.getArguments()) { + for (auto arg : block.getArguments()) { if (valueIDs.count(arg)) continue; if (isEntryBlock) { @@ -1657,11 +1657,11 @@ void OperationPrinter::numberValuesInOp(Operation &op) { unsigned numResults = op.getNumResults(); if (numResults == 0) return; - Value *resultBegin = op.getResult(0); + ValuePtr resultBegin = op.getResult(0); // Function used to set the special result names for the operation. SmallVector resultGroups(/*Size=*/1, /*Value=*/0); - auto setResultNameFn = [&](Value *result, StringRef name) { + auto setResultNameFn = [&](ValuePtr result, StringRef name) { assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result->getDefiningOp() == &op && "result not defined by 'op'"); setValueName(result, name); @@ -1690,7 +1690,7 @@ void OperationPrinter::numberValuesInOp(Operation &op) { } /// Set a special value name for the given value. 
-void OperationPrinter::setValueName(Value *value, StringRef name) { +void OperationPrinter::setValueName(ValuePtr value, StringRef name) { // If the name is empty, the value uses the default numbering. if (name.empty()) { valueIDs[value] = nextValueID++; @@ -1737,7 +1737,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, // Print the argument list if non-empty. if (!block->args_empty()) { os << '('; - interleaveComma(block->getArguments(), [&](BlockArgument *arg) { + interleaveComma(block->getArguments(), [&](BlockArgumentPtr arg) { printValueID(arg); os << ": "; printType(arg->getType()); @@ -1788,8 +1788,8 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } -void OperationPrinter::getResultIDAndNumber(OpResult *result, - Value *&lookupValue, +void OperationPrinter::getResultIDAndNumber(OpResultPtr result, + ValuePtr &lookupValue, int &lookupResultNo) const { Operation *owner = result->getOwner(); if (owner->getNumResults() == 1) @@ -1827,7 +1827,7 @@ void OperationPrinter::getResultIDAndNumber(OpResult *result, lookupValue = owner->getResult(groupResultNo); } -void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, +void OperationPrinter::printValueIDImpl(ValuePtr value, bool printResultNo, raw_ostream &stream) const { if (!value) { stream << "<>"; @@ -1840,7 +1840,7 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. - if (OpResult *result = dyn_cast(value)) + if (OpResultPtr result = dyn_cast(value)) getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); @@ -1875,11 +1875,11 @@ void OperationPrinter::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { SmallVector nameStr; for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { - auto *nameToUse = namesToUse[i]; + auto nameToUse = namesToUse[i]; if (nameToUse == nullptr) continue; - auto *nameToReplace = region.front().getArgument(i); + auto nameToReplace = region.front().getArgument(i); nameStr.clear(); llvm::raw_svector_ostream nameStream(nameStr); @@ -1951,10 +1951,10 @@ void OperationPrinter::printGenericOp(Operation *op) { for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( + SmallVector properOperands( op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - interleaveComma(properOperands, [&](Value *value) { printValueID(value); }); + interleaveComma(properOperands, [&](ValuePtr value) { printValueID(value); }); os << ')'; @@ -1997,10 +1997,10 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term, os << '('; interleaveComma(succOperands, - [this](Value *operand) { printValueID(operand); }); + [this](ValuePtr operand) { printValueID(operand); }); os << " : "; interleaveComma(succOperands, - [this](Value *operand) { printType(operand->getType()); }); + [this](ValuePtr operand) { printType(operand->getType()); }); os << ')'; } @@ -2072,7 +2072,7 @@ void Value::print(raw_ostream &os) { if (auto *op = getDefiningOp()) return op->print(os); // TODO: Improve this. 
- assert(isa(*this)); + assert(isa()); os << "\n"; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 4dac32ae0c0..894f9ba38d0 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -98,7 +98,7 @@ void Block::dropAllReferences() { } void Block::dropAllDefinedValueUses() { - for (auto *arg : getArguments()) + for (auto arg : getArguments()) arg->dropAllUses(); for (auto &op : *this) op.dropAllDefinedValueUses(); @@ -151,7 +151,7 @@ void Block::recomputeOpOrder() { // Argument list management. //===----------------------------------------------------------------------===// -BlockArgument *Block::addArgument(Type type) { +BlockArgumentPtr Block::addArgument(Type type) { auto *arg = new BlockArgument(type, this); arguments.push_back(arg); return arg; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 691b2ad99c4..733fcd13994 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -343,7 +343,7 @@ Operation *OpBuilder::createOperation(const OperationState &state) { /// 'results'. Returns success if the operation was folded, failure otherwise. /// Note: This function does not erase the operation on a successful fold. LogicalResult OpBuilder::tryFold(Operation *op, - SmallVectorImpl &results) { + SmallVectorImpl &results) { results.reserve(op->getNumResults()); auto cleanupFailure = [&] { results.assign(op->result_begin(), op->result_end()); @@ -374,7 +374,7 @@ LogicalResult OpBuilder::tryFold(Operation *op, Dialect *dialect = op->getDialect(); for (auto &it : llvm::enumerate(foldResults)) { // Normal values get pushed back directly. - if (auto *value = it.value().dyn_cast()) { + if (auto value = it.value().dyn_cast()) { results.push_back(value); continue; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 9df10791046..53399ce00a3 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -114,7 +114,7 @@ template <> unsigned BlockOperand::getOperandNumber() { /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, ArrayRef attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { @@ -134,7 +134,7 @@ Operation *Operation::create(const OperationState &state) { /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, RegionRange regions, bool resizableOperandList) { @@ -151,7 +151,7 @@ Operation *Operation::create(Location location, OperationName name, /// unnecessarily uniquing a list of attributes. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { @@ -314,7 +314,7 @@ bool Operation::isProperAncestor(Operation *other) { } /// Replace any uses of 'from' with 'to' within this operation. -void Operation::replaceUsesOfWith(Value *from, Value *to) { +void Operation::replaceUsesOfWith(ValuePtr from, ValuePtr to) { if (from == to) return; for (auto &operand : getOpOperands()) @@ -585,7 +585,7 @@ void Operation::dropAllDefinedValueUses() { /// Return true if there are no users of any results of this operation. 
bool Operation::use_empty() { - for (auto *result : getResults()) + for (auto result : getResults()) if (!result->use_empty()) return false; return true; @@ -672,14 +672,14 @@ InFlightDiagnostic Operation::emitOpError(const Twine &message) { /// Operands are remapped using `mapper` (if present), and `mapper` is updated /// to contain the results. Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { - SmallVector operands; + SmallVector operands; SmallVector successors; operands.reserve(getNumOperands() + getNumSuccessors()); if (getNumSuccessors() == 0) { // Non-branching operations can just add all the operands. - for (auto *opValue : getOperands()) + for (auto opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(opValue)); } else { // We add the operands separated by nullptr's for each successor. @@ -699,7 +699,7 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { operands.push_back(nullptr); // Remap the successors operands. - for (auto *operand : getSuccessorOperands(succ)) + for (auto operand : getSuccessorOperands(succ)) operands.push_back(mapper.lookupOrDefault(operand)); } } @@ -1092,8 +1092,8 @@ LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, // These functions are out-of-line implementations of the methods in BinaryOp, // which avoids them being template instantiated/duplicated. -void impl::buildBinaryOp(Builder *builder, OperationState &result, Value *lhs, - Value *rhs) { +void impl::buildBinaryOp(Builder *builder, OperationState &result, ValuePtr lhs, + ValuePtr rhs) { assert(lhs->getType() == rhs->getType()); result.addOperands({lhs, rhs}); result.types.push_back(lhs->getType()); @@ -1133,8 +1133,8 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { // CastOp implementation //===----------------------------------------------------------------------===// -void impl::buildCastOp(Builder *builder, OperationState &result, Value *source, - Type destType) { +void impl::buildCastOp(Builder *builder, OperationState &result, + ValuePtr source, Type destType) { result.addOperands(source); result.addTypes(destType); } @@ -1157,7 +1157,7 @@ void impl::printCastOp(Operation *op, OpAsmPrinter &p) { << op->getResult(0)->getType(); } -Value *impl::foldCastOp(Operation *op) { +ValuePtr impl::foldCastOp(Operation *op) { // Identity cast if (op->getOperand(0)->getType() == op->getResult(0)->getType()) return op->getOperand(0); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 256a261acd8..333685a16fd 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -164,7 +164,7 @@ ResultRange::ResultRange(Operation *op) //===----------------------------------------------------------------------===// // ValueRange -ValueRange::ValueRange(ArrayRef values) +ValueRange::ValueRange(ArrayRef values) : ValueRange(values.data(), values.size()) {} ValueRange::ValueRange(OperandRange values) : ValueRange(values.begin().getBase(), values.size()) {} @@ -176,18 +176,19 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, ptrdiff_t index) { if (OpOperand *operand = owner.dyn_cast()) return operand + index; - if (OpResult *result = owner.dyn_cast()) + if (OpResultPtr result = owner.dyn_cast()) return result + index; - return owner.get() + index; + return owner.get() + index; } /// See `detail::indexed_accessor_range_base` for details. 
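`ValueRange`, whose constructors from `ArrayRef<ValuePtr>` and `OperandRange` appear just above, is the type-erased range that lets builder and printer APIs accept any container of values. A short usage sketch under that assumption (names are placeholders):

// A helper taking ValueRange accepts any of these containers directly.
void printOperandTypes(ValueRange values) {
  for (ValuePtr value : values)        // iteration hands back ValuePtr
    llvm::errs() << value->getType() << "\n";
}

SmallVector<ValuePtr, 4> collected = {lhs, rhs}; // lhs/rhs: placeholder values
printOperandTypes(collected);                    // via ArrayRef<ValuePtr>
printOperandTypes(someOp->getOperands());        // via OperandRange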
-Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { +ValuePtr ValueRange::dereference_iterator(const OwnerT &owner, + ptrdiff_t index) { // Operands access the held value via 'get'. if (OpOperand *operand = owner.dyn_cast()) return operand[index].get(); // An OpResult is a value, so we can return it directly. - if (OpResult *result = owner.dyn_cast()) + if (OpResultPtr result = owner.dyn_cast()) return &result[index]; // Otherwise, this is a raw value array so just index directly. - return owner.get()[index]; + return owner.get()[index]; } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 6cec021b6a1..26f14c43424 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -91,7 +91,7 @@ void Region::cloneInto(Region *dest, Region::iterator destPos, // Clone the block arguments. The user might be deleting arguments to the // block by specifying them in the mapper. If so, we don't add the // argument to the cloned block. - for (auto *arg : block.getArguments()) + for (auto arg : block.getArguments()) if (!mapper.contains(arg)) mapper.map(arg, newBlock->addArgument(arg->getType())); @@ -106,7 +106,7 @@ void Region::cloneInto(Region *dest, Region::iterator destPos, // operands of each of the operations. auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) - if (auto *mappedOp = mapper.lookupOrNull(operand.get())) + if (auto mappedOp = mapper.lookupOrNull(operand.get())) operand.set(mappedOp); for (auto &succOp : op->getBlockOperands()) if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) @@ -143,7 +143,7 @@ static bool isIsolatedAbove(Region ®ion, Region &limit, while (!pendingRegions.empty()) { for (Block &block : *pendingRegions.pop_back_val()) { for (Operation &op : block) { - for (Value *operand : op.getOperands()) { + for (ValuePtr operand : op.getOperands()) { // operand should be non-null here if the IR is well-formed. But // we don't assert here as this function is called from the verifier // and so could be called on invalid IR. 
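Region::cloneInto above relies on BlockAndValueMapping to carry old-value to new-value correspondences: block arguments are mapped as the blocks are cloned, and operands of the cloned operations are then rewired through lookupOrNull. A compact sketch of that idiom, built only from the calls visible in the hunk (variable names illustrative):

BlockAndValueMapping mapper;
// While cloning a block, record the correspondence for each argument.
for (auto arg : oldBlock->getArguments())
  mapper.map(arg, newBlock->addArgument(arg->getType()));

// Later, rewire the operands of each cloned operation.
for (auto &operand : clonedOp->getOpOperands())
  if (ValuePtr repl = mapper.lookupOrNull(operand.get()))
    operand.set(repl);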
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 54b1bf6329b..8200e3a3bc6 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -33,11 +33,11 @@ Type mlir::getElementTypeOrSelf(Type type) { return type; } -Type mlir::getElementTypeOrSelf(Value *val) { +Type mlir::getElementTypeOrSelf(ValuePtr val) { return getElementTypeOrSelf(val->getType()); } -Type mlir::getElementTypeOrSelf(Value &val) { +Type mlir::getElementTypeOrSelf(ValueRef val) { return getElementTypeOrSelf(val.getType()); } @@ -101,18 +101,18 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { OperandElementTypeIterator::OperandElementTypeIterator( Operation::operand_iterator it) - : llvm::mapped_iterator( + : llvm::mapped_iterator( it, &unwrap) {} -Type OperandElementTypeIterator::unwrap(Value *value) { +Type OperandElementTypeIterator::unwrap(ValuePtr value) { return value->getType().cast().getElementType(); } ResultElementTypeIterator::ResultElementTypeIterator( Operation::result_iterator it) - : llvm::mapped_iterator( + : llvm::mapped_iterator( it, &unwrap) {} -Type ResultElementTypeIterator::unwrap(Value *value) { +Type ResultElementTypeIterator::unwrap(ValuePtr value) { return value->getType().cast().getElementType(); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 4c2ea5ac69c..660d8ae3248 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -23,7 +23,7 @@ using namespace mlir; /// If this value is the result of an Operation, return the operation that /// defines it. Operation *Value::getDefiningOp() { - if (auto *result = dyn_cast(this)) + if (auto *result = dyn_cast()) return result->getOwner(); return nullptr; } @@ -38,7 +38,7 @@ Location Value::getLoc() { Region *Value::getParentRegion() { if (auto *op = getDefiningOp()) return op->getParentRegion(); - return cast(this)->getOwner()->getParent(); + return cast()->getOwner()->getParent(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 498a64d70c2..f78704842fe 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3093,7 +3093,7 @@ public: ParseResult popSSANameScope(); /// Register a definition of a value with the symbol table. - ParseResult addDefinition(SSAUseInfo useInfo, Value *value); + ParseResult addDefinition(SSAUseInfo useInfo, ValuePtr value); /// Parse an optional list of SSA uses into 'results'. ParseResult parseOptionalSSAUseList(SmallVectorImpl &results); @@ -3103,12 +3103,13 @@ public: /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. - Value *resolveSSAUse(SSAUseInfo useInfo, Type type); + ValuePtr resolveSSAUse(SSAUseInfo useInfo, Type type); ParseResult parseSSADefOrUseAndType( const std::function &action); - ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl &results); + ParseResult + parseOptionalSSAUseAndTypeList(SmallVectorImpl &results); /// Return the location of the value identified by its name and number if it /// has been already reference. @@ -3130,12 +3131,12 @@ public: /// Parse a single operation successor and its operand list. ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands); + SmallVectorImpl &operands); /// Parse a comma-separated list of operation successors in brackets. 
ParseResult parseSuccessors(SmallVectorImpl &destinations, - SmallVectorImpl> &operands); + SmallVectorImpl> &operands); /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); @@ -3174,7 +3175,7 @@ public: /// Parse a (possibly empty) list of block arguments. ParseResult - parseOptionalBlockArgList(SmallVectorImpl &results, + parseOptionalBlockArgList(SmallVectorImpl &results, Block *owner); /// Get the block with the specified name, creating it if it doesn't @@ -3204,14 +3205,14 @@ private: void recordDefinition(StringRef def); /// Get the value entry for the given SSA name. - SmallVectorImpl> &getSSAValueEntry(StringRef name); + SmallVectorImpl> &getSSAValueEntry(StringRef name); /// Create a forward reference placeholder value with the given location and /// result type. - Value *createForwardRefPlaceholder(SMLoc loc, Type type); + ValuePtr createForwardRefPlaceholder(SMLoc loc, Type type); /// Return true if this is a forward reference. - bool isForwardRefPlaceholder(Value *value) { + bool isForwardRefPlaceholder(ValuePtr value) { return forwardRefPlaceholders.count(value); } @@ -3236,7 +3237,7 @@ private: /// This keeps track of all of the SSA values we are tracking for each name /// scope, indexed by their name. This has one entry per result number. - llvm::StringMap, 1>> values; + llvm::StringMap, 1>> values; /// This keeps track of all of the values defined by a specific name scope. SmallVector, 2> definitionsPerScope; @@ -3253,7 +3254,7 @@ private: /// These are all of the placeholders we've made along with the location of /// their first reference, to allow checking for use of undefined values. - DenseMap forwardRefPlaceholders; + DenseMap forwardRefPlaceholders; /// The builder used when creating parsed operation instances. OpBuilder opBuilder; @@ -3278,7 +3279,7 @@ ParseResult OperationParser::finalize() { // Check for any forward references that are left. If we find any, error // out. if (!forwardRefPlaceholders.empty()) { - SmallVector, 4> errors; + SmallVector, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. for (auto entry : forwardRefPlaceholders) errors.push_back({entry.second.getPointer(), entry.first}); @@ -3342,7 +3343,7 @@ ParseResult OperationParser::popSSANameScope() { } /// Register a definition of a value with the symbol table. -ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { +ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, ValuePtr value) { auto &entries = getSSAValueEntry(useInfo.name); // Make sure there is a slot for this value. @@ -3351,7 +3352,7 @@ ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { // If we already have an entry for this, check to see if it was a definition // or a forward reference. - if (auto *existing = entries[useInfo.number].first) { + if (auto existing = entries[useInfo.number].first) { if (!isForwardRefPlaceholder(existing)) { return emitError(useInfo.loc) .append("redefinition of SSA value '", useInfo.name, "'") @@ -3416,12 +3417,12 @@ ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { /// Given an unbound reference to an SSA value and its type, return the value /// it specifies. This returns null on failure. -Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { +ValuePtr OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = getSSAValueEntry(useInfo.name); // If we have already seen a value of this name, return it. 
if (useInfo.number < entries.size() && entries[useInfo.number].first) { - auto *result = entries[useInfo.number].first; + auto result = entries[useInfo.number].first; // Check that the type matches the other uses. if (result->getType() == type) return result; @@ -3447,7 +3448,7 @@ Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { // Otherwise, this is a forward reference. Create a placeholder and remember // that we did so. - auto *result = createForwardRefPlaceholder(useInfo.loc, type); + auto result = createForwardRefPlaceholder(useInfo.loc, type); entries[useInfo.number].first = result; entries[useInfo.number].second = useInfo.loc; return result; @@ -3477,7 +3478,7 @@ ParseResult OperationParser::parseSSADefOrUseAndType( /// ::= ssa-use-list ':' type-list-no-parens /// ParseResult OperationParser::parseOptionalSSAUseAndTypeList( - SmallVectorImpl &results) { + SmallVectorImpl &results) { SmallVector valueIDs; if (parseOptionalSSAUseList(valueIDs)) return failure(); @@ -3497,7 +3498,7 @@ ParseResult OperationParser::parseOptionalSSAUseAndTypeList( results.reserve(valueIDs.size()); for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { - if (auto *value = resolveSSAUse(valueIDs[i], types[i])) + if (auto value = resolveSSAUse(valueIDs[i], types[i])) results.push_back(value); else return failure(); @@ -3512,13 +3513,13 @@ void OperationParser::recordDefinition(StringRef def) { } /// Get the value entry for the given SSA name. -SmallVectorImpl> & +SmallVectorImpl> & OperationParser::getSSAValueEntry(StringRef name) { return isolatedNameScopes.back().values[name]; } /// Create and remember a new placeholder for a forward reference. -Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { +ValuePtr OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { // Forward references are always created as operations, because we just need // something with a def/use chain. // @@ -3632,7 +3633,7 @@ ParseResult OperationParser::parseOperation() { /// ParseResult OperationParser::parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) { + SmallVectorImpl &operands) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::caret_identifier)) return emitError("expected block name"); @@ -3655,13 +3656,13 @@ OperationParser::parseSuccessorAndUseList(Block *&dest, /// ParseResult OperationParser::parseSuccessors( SmallVectorImpl &destinations, - SmallVectorImpl> &operands) { + SmallVectorImpl> &operands) { if (parseToken(Token::l_square, "expected '['")) return failure(); auto parseElt = [this, &destinations, &operands]() { Block *dest; - SmallVector destOperands; + SmallVector destOperands; auto res = parseSuccessorAndUseList(dest, destOperands); destinations.push_back(dest); operands.push_back(destOperands); @@ -3718,7 +3719,7 @@ Operation *OperationParser::parseGenericOperation() { // Parse the successor list but don't add successors to the result yet to // avoid messing up with the argument order. SmallVector successors; - SmallVector, 2> successorOperands; + SmallVector, 2> successorOperands; if (getToken().is(Token::l_square)) { // Check if the operation is a known terminator. const AbstractOperation *abstractOp = result.name.getAbstractOperation(); @@ -3779,7 +3780,7 @@ Operation *OperationParser::parseGenericOperation() { // Add the successors, and their operands after the proper operands. 
for (const auto &succ : llvm::zip(successors, successorOperands)) { Block *successor = std::get<0>(succ); - const SmallVector &operands = std::get<1>(succ); + const SmallVector &operands = std::get<1>(succ); result.addSuccessor(successor, operands); } @@ -4129,10 +4130,10 @@ public: /// Resolve an operand to an SSA value, emitting an error on failure. ParseResult resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) override { + SmallVectorImpl &result) override { OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; - if (auto *value = parser.resolveSSAUse(operandInfo, type)) { + if (auto value = parser.resolveSSAUse(operandInfo, type)) { result.push_back(value); return success(); } @@ -4242,7 +4243,7 @@ public: /// Parse a single operation successor and its operand list. ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) override { + SmallVectorImpl &operands) override { return parser.parseSuccessorAndUseList(dest, operands); } @@ -4470,7 +4471,7 @@ ParseResult OperationParser::parseBlock(Block *&block) { // If an argument list is present, parse it. if (consumeIf(Token::l_paren)) { - SmallVector bbArgs; + SmallVector bbArgs; if (parseOptionalBlockArgList(bbArgs, block) || parseToken(Token::r_paren, "expected ')' to end argument list")) return failure(); @@ -4534,7 +4535,7 @@ Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* /// ParseResult OperationParser::parseOptionalBlockArgList( - SmallVectorImpl &results, Block *owner) { + SmallVectorImpl &results, Block *owner) { if (getToken().is(Token::r_brace)) return success(); @@ -4555,7 +4556,7 @@ ParseResult OperationParser::parseOptionalBlockArgList( return emitError("too many arguments specified in argument list"); // Finally, make sure the existing argument has the correct type. 
- auto *arg = owner->getArgument(nextArgument++); + auto arg = owner->getArgument(nextArgument++); if (arg->getType() != type) return emitError("argument and block argument type mismatch"); return addDefinition(useInfo, arg); diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 8e172156f05..9d1c1f0d391 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -48,14 +48,14 @@ public: for (Region ®ion : op->getRegions()) { for (Block &block : region) { addDataToHash(hasher, &block); - for (BlockArgument *arg : block.getArguments()) + for (BlockArgumentPtr arg : block.getArguments()) addDataToHash(hasher, arg); } } // - Location addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); // - Operands - for (Value *operand : op->getOperands()) + for (ValuePtr operand : op->getOperands()) addDataToHash(hasher, operand); // - Successors for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp index d38c76255f0..13fed0f9b1c 100644 --- a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -102,7 +102,7 @@ void CAGSlice::enumerateImpliedConnections( std::vector> impliedPairs; for (auto &resultAnchorPair : resultAnchors) { CAGResultAnchor *resultAnchor = resultAnchorPair.second; - Value *resultValue = resultAnchor->getValue(); + ValuePtr resultValue = resultAnchor->getValue(); for (auto &use : resultValue->getUses()) { Operation *operandOp = use.getOwner(); unsigned operandIdx = use.getOperandNumber(); diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index a32bb2c9b3c..a3cbe214040 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -74,7 +74,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, auto func = getFunction(); // Insert stats for each argument. - for (auto *arg : func.getArguments()) { + for (auto arg : func.getArguments()) { if (!config.isHandledType(arg->getType())) continue; OpBuilder b(func.getBody()); diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index 511df0a463f..68c263bc423 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -181,17 +181,17 @@ void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, Type newType) { - Value *inputValue = anchor->getValue(); + ValuePtr inputValue = anchor->getValue(); Operation *op = anchor->getOp(); OpBuilder b(op->getBlock(), Block::iterator(op)); - SmallVector removeValuesIfDead; + SmallVector removeValuesIfDead; // Because we've already run the result transforms at this phase, it is // very likely that inputValue points to a dcast op whose input matches // our type. We detect that situation and route around just to save some // bulk in the IR. 
- Value *newTypedInputValue = inputValue; + ValuePtr newTypedInputValue = inputValue; auto inputDcastOp = dyn_cast_or_null(inputValue->getDefiningOp()); if (inputDcastOp && inputDcastOp.arg()->getType() == newType) { @@ -228,7 +228,7 @@ void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, break; } - for (Value *removeValueIfDead : removeValuesIfDead) { + for (ValuePtr removeValueIfDead : removeValuesIfDead) { if (removeValueIfDead->use_empty()) { removeValueIfDead->getDefiningOp()->erase(); } @@ -237,12 +237,12 @@ void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, Type newType) { - Value *origResultValue = anchor->getValue(); + ValuePtr origResultValue = anchor->getValue(); Operation *op = origResultValue->getDefiningOp(); OpBuilder b(op->getBlock(), ++Block::iterator(op)); - Value *replacedResultValue = nullptr; - Value *newResultValue = nullptr; + ValuePtr replacedResultValue = nullptr; + ValuePtr newResultValue = nullptr; switch (anchor->getTypeTransformRule()) { case CAGAnchorNode::TypeTransformRule::Direct: origResultValue->setType(newType); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 098dba3ae6e..e8f44087b85 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -224,7 +224,7 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { return formatv("Operation::operand_range {0}(op0->getOperands());\n", name); } case Kind::Value: { - return formatv("ArrayRef {0};\n", name); + return formatv("ArrayRef {0};\n", name); } case Kind::Result: { // Use the op itself for captured results. diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 6cf975bcce2..7273d3dfd7b 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -76,7 +76,7 @@ private: /// `value` is an SSA-use. Return the remapped version of `value` or a /// placeholder that will be remapped later if this is an instruction that /// has not yet been visited. - Value *processValue(llvm::Value *value); + ValuePtr processValue(llvm::Value *value); /// Create the most accurate Location possible using a llvm::DebugLoc and /// possibly an llvm::Instruction to narrow the Location if debug information /// is unavailable. @@ -85,14 +85,14 @@ private: /// `br` branches to `target`. Return the block arguments to attach to the /// generated branch op. These should be in the same order as the PHIs in /// `target`. - SmallVector processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target); + SmallVector processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target); /// Return `value` as an attribute to attach to a GlobalOp. Attribute getConstantAsAttr(llvm::Constant *value); /// Return `c` as an MLIR Value. This could either be a ConstantOp, or /// an expanded sequence of ops in the current function's entry block (for /// ConstantExprs or ConstantGEPs). - Value *processConstant(llvm::Constant *c); + ValuePtr processConstant(llvm::Constant *c); /// The current builder, pointing at where the next Instruction should be /// generated. @@ -120,7 +120,7 @@ private: /// Remapped blocks, for the current function. DenseMap blocks; /// Remapped values. These are function-local. - DenseMap instMap; + DenseMap instMap; /// Instructions that had not been defined when first encountered as a use. 
/// Maps to the dummy Operation that was created in processValue(). DenseMap unknownInstMap; @@ -263,13 +263,13 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { Region &r = op.getInitializerRegion(); currentEntryBlock = b.createBlock(&r); b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); - Value *v = processConstant(GV->getInitializer()); - b.create(op.getLoc(), ArrayRef({v})); + ValuePtr v = processConstant(GV->getInitializer()); + b.create(op.getLoc(), ArrayRef({v})); } return globals[GV] = op; } -Value *Importer::processConstant(llvm::Constant *c) { +ValuePtr Importer::processConstant(llvm::Constant *c) { if (Attribute attr = getConstantAsAttr(c)) { // These constants can be represented as attributes. OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); @@ -298,7 +298,7 @@ Value *Importer::processConstant(llvm::Constant *c) { return nullptr; } -Value *Importer::processValue(llvm::Value *value) { +ValuePtr Importer::processValue(llvm::Value *value) { auto it = instMap.find(value); if (it != instMap.end()) return it->second; @@ -407,9 +407,9 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. -SmallVector Importer::processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target) { - SmallVector v; +SmallVector Importer::processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target) { + SmallVector v; for (auto inst = target->begin(); isa(inst); ++inst) { auto *PN = cast(&*inst); v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent()))); @@ -421,7 +421,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math // flags and call / operand attributes are not supported. Location loc = processDebugLoc(inst->getDebugLoc(), inst); - Value *&v = instMap[inst]; + ValuePtr &v = instMap[inst]; assert(!v && "processInstruction must be called only once per instruction!"); switch (inst->getOpcode()) { default: @@ -462,7 +462,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { case llvm::Instruction::AddrSpaceCast: case llvm::Instruction::BitCast: { OperationState state(loc, opcMap.lookup(inst->getOpcode())); - SmallVector ops; + SmallVector ops; ops.reserve(inst->getNumOperands()); for (auto *op : inst->operand_values()) ops.push_back(processValue(op)); @@ -484,7 +484,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { auto *brInst = cast(inst); OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); - SmallVector ops; + SmallVector ops; if (brInst->isConditional()) ops.push_back(processValue(brInst->getCondition())); state.addOperands(ops); @@ -500,7 +500,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { } case llvm::Instruction::Call: { llvm::CallInst *ci = cast(inst); - SmallVector ops; + SmallVector ops; ops.reserve(inst->getNumOperands()); for (auto &op : ci->arg_operands()) ops.push_back(processValue(op.get())); @@ -523,7 +523,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { case llvm::Instruction::GetElementPtr: { // FIXME: Support inbounds GEPs. 
llvm::GetElementPtrInst *gep = cast(inst); - SmallVector ops; + SmallVector ops; for (auto *op : gep->operand_values()) ops.push_back(processValue(op)); v = b.create(loc, processType(inst->getType()), ops, @@ -565,8 +565,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) { // any unknown uses we encountered are remapped. for (auto &llvmAndUnknown : unknownInstMap) { assert(instMap.count(llvmAndUnknown.first)); - Value *newValue = instMap[llvmAndUnknown.first]; - Value *oldValue = llvmAndUnknown.second->getResult(0); + ValuePtr newValue = instMap[llvmAndUnknown.first]; + ValuePtr oldValue = llvmAndUnknown.second->getResult(0); oldValue->replaceAllUsesWith(newValue); llvmAndUnknown.second->erase(); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index e59c69aa25b..ec28434b823 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -248,7 +248,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { auto predecessors = bb.getPredecessors(); unsigned numPredecessors = std::distance(predecessors.begin(), predecessors.end()); - for (auto *arg : bb.getArguments()) { + for (auto arg : bb.getArguments()) { auto wrappedType = arg->getType().dyn_cast(); if (!wrappedType) return emitError(bb.front().getLoc(), @@ -342,8 +342,8 @@ void ModuleTranslation::convertGlobals() { /// Get the SSA value passed to the current block from the terminator operation /// of its predecessor. -static Value *getPHISourceValue(Block *current, Block *pred, - unsigned numArguments, unsigned index) { +static ValuePtr getPHISourceValue(Block *current, Block *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (isa(terminator)) { return terminator.getOperand(index); @@ -420,7 +420,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { unsigned int argIdx = 0; for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { llvm::Argument &llvmArg = std::get<1>(kvp); - BlockArgument *mlirArg = std::get<0>(kvp); + BlockArgumentPtr mlirArg = std::get<0>(kvp); if (auto attr = func.getArgAttrOfType(argIdx, "llvm.noalias")) { // NB: Attribute already verified to be boolean, so check if we can indeed @@ -497,7 +497,7 @@ SmallVector ModuleTranslation::lookupValues(ValueRange values) { SmallVector remapped; remapped.reserve(values.size()); - for (Value *v : values) + for (ValuePtr v : values) remapped.push_back(valueMapping.lookup(v)); return remapped; } diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 7fb356f3ad2..5bc33943e50 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -130,7 +130,7 @@ struct AffineDataCopyGeneration bool skipNonUnitStrideLoops; // Constant zero index to avoid too many duplicates. 
- Value *zeroIndex = nullptr; + ValuePtr zeroIndex = nullptr; }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp index f384f6d3fb1..23199dd8a39 100644 --- a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -58,15 +58,15 @@ struct LoopInvariantCodeMotion : public FunctionPass { } // end anonymous namespace static bool -checkInvarianceOfNestedIfOps(Operation *op, Value *indVar, +checkInvarianceOfNestedIfOps(Operation *op, ValuePtr indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); -static bool isOpLoopInvariant(Operation &op, Value *indVar, +static bool isOpLoopInvariant(Operation &op, ValuePtr indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); static bool -areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar, +areAllOpsInTheBlockListInvariant(Region &blockList, ValuePtr indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); @@ -79,7 +79,7 @@ static bool isMemRefDereferencingOp(Operation &op) { } // Returns true if the individual op is loop invariant. -bool isOpLoopInvariant(Operation &op, Value *indVar, +bool isOpLoopInvariant(Operation &op, ValuePtr indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); @@ -97,9 +97,9 @@ bool isOpLoopInvariant(Operation &op, Value *indVar, return false; } else if (!isa(op)) { if (isMemRefDereferencingOp(op)) { - Value *memref = isa(op) - ? cast(op).getMemRef() - : cast(op).getMemRef(); + ValuePtr memref = isa(op) + ? cast(op).getMemRef() + : cast(op).getMemRef(); for (auto *user : memref->getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. @@ -163,7 +163,8 @@ bool isOpLoopInvariant(Operation &op, Value *indVar, // Checks if all ops in a region (i.e. list of blocks) are loop invariant. bool areAllOpsInTheBlockListInvariant( - Region &blockList, Value *indVar, SmallPtrSetImpl &definedOps, + Region &blockList, ValuePtr indVar, + SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { for (auto &b : blockList) { @@ -178,7 +179,7 @@ bool areAllOpsInTheBlockListInvariant( } // Returns true if the affine.if op can be hoisted. -bool checkInvarianceOfNestedIfOps(Operation *op, Value *indVar, +bool checkInvarianceOfNestedIfOps(Operation *op, ValuePtr indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { assert(isa(op)); @@ -199,7 +200,7 @@ bool checkInvarianceOfNestedIfOps(Operation *op, Value *indVar, void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { auto *loopBody = forOp.getBody(); - auto *indVar = forOp.getInductionVar(); + auto indVar = forOp.getInductionVar(); SmallPtrSet definedOps; // This is the place where hoisted instructions would reside. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 37c918fe9be..05066ef599c 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -86,13 +86,13 @@ namespace { struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return the provided value. - Value *lookupOrDefault(Value *from) const; + ValuePtr lookupOrDefault(ValuePtr from) const; /// Map a value to the one provided. 
- void map(Value *oldVal, Value *newVal) { mapping.map(oldVal, newVal); } + void map(ValuePtr oldVal, ValuePtr newVal) { mapping.map(oldVal, newVal); } /// Drop the last mapping for the given value. - void erase(Value *value) { mapping.erase(value); } + void erase(ValuePtr value) { mapping.erase(value); } private: /// Current value mappings. @@ -102,10 +102,10 @@ private: /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return the provided value. -Value *ConversionValueMapping::lookupOrDefault(Value *from) const { +ValuePtr ConversionValueMapping::lookupOrDefault(ValuePtr from) const { // If this value had a valid mapping, unmap that value as well in the case // that it was also replaced. - while (auto *mappedValue = mapping.lookupOrNull(from)) + while (auto mappedValue = mapping.lookupOrNull(from)) from = mappedValue; return from; } @@ -127,7 +127,7 @@ struct ArgConverter { /// been converted. struct ConvertedArgInfo { ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - Value *castValue = nullptr) + ValuePtr castValue = nullptr) : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} /// The start index of in the new argument list that contains arguments that @@ -139,7 +139,7 @@ struct ArgConverter { /// The cast value that was created to cast from the new arguments to the /// old. This only used if 'newArgSize' > 1. - Value *castValue; + ValuePtr castValue; }; /// This structure contains information pertaining to a block that has had its @@ -235,7 +235,7 @@ void ArgConverter::notifyOpRemoved(Operation *op) { // Drop all uses of the original arguments and delete the original block. Block *origBlock = it->second.origBlock; - for (BlockArgument *arg : origBlock->getArguments()) + for (BlockArgumentPtr arg : origBlock->getArguments()) arg->dropAllUses(); conversionInfo.erase(it); } @@ -270,7 +270,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { // Process the remapping for each of the original arguments. for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { Optional &argInfo = blockInfo.argInfo[i]; - BlockArgument *origArg = origBlock->getArgument(i); + BlockArgumentPtr origArg = origBlock->getArgument(i); // Handle the case of a 1->0 value mapping. if (!argInfo) { @@ -305,7 +305,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { } // Otherwise this is a 1->N value mapping. - Value *castValue = argInfo->castValue; + ValuePtr castValue = argInfo->castValue; assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); // If the argument is still used, replace it with the generated cast. @@ -344,8 +344,8 @@ Block *ArgConverter::applySignatureConversion( Block *newBlock = block->splitBlock(block->begin()); block->replaceAllUsesWith(newBlock); - SmallVector newArgRange(newBlock->addArguments(convertedTypes)); - ArrayRef newArgs(newArgRange); + SmallVector newArgRange(newBlock->addArguments(convertedTypes)); + ArrayRef newArgs(newArgRange); // Remap each of the original arguments as determined by the signature // conversion. @@ -358,7 +358,7 @@ Block *ArgConverter::applySignatureConversion( auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap) continue; - BlockArgument *origArg = block->getArgument(i); + BlockArgumentPtr origArg = block->getArgument(i); // If inputMap->replacementValue is not nullptr, then the argument is // dropped and a replacement value is provided to be the remappedValue. 
@@ -445,7 +445,7 @@ struct ConversionPatternRewriterImpl { : op(op), newValues(newValues.begin(), newValues.end()) {} Operation *op; - SmallVector newValues; + SmallVector newValues; }; /// The kind of the block action performed during the rewrite. Actions can be @@ -542,7 +542,7 @@ struct ConversionPatternRewriterImpl { /// Remap the given operands to those with potentially different types. void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped); + SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be /// converted. @@ -591,7 +591,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { // Reset any replaced operations and undo any saved mappings. for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) - for (auto *result : repl.op->getResults()) + for (auto result : repl.op->getResults()) mapping.erase(result); replacements.resize(state.numReplacements); @@ -660,7 +660,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { - if (auto *newValue = repl.newValues[i]) + if (auto newValue = repl.newValues[i]) repl.op->getResult(i)->replaceAllUsesWith( mapping.lookupOrDefault(newValue)); } @@ -715,7 +715,7 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op, // Create mappings for each of the new result values. for (unsigned i = 0, e = newValues.size(); i < e; ++i) - if (auto *repl = newValues[i]) + if (auto repl = newValues[i]) mapping.map(op->getResult(i), repl); // Record the requested operation replacement. @@ -755,9 +755,9 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( } void ConversionPatternRewriterImpl::remapValues( - Operation::operand_range operands, SmallVectorImpl &remapped) { + Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); - for (Value *operand : operands) + for (ValuePtr operand : operands) remapped.push_back(mapping.lookupOrDefault(operand)); } @@ -803,7 +803,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, void ConversionPatternRewriter::eraseOp(Operation *op) { LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() << "\n"); - SmallVector nullRepls(op->getNumResults(), nullptr); + SmallVector nullRepls(op->getNumResults(), nullptr); impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); } @@ -813,8 +813,8 @@ Block *ConversionPatternRewriter::applySignatureConversion( return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from, - Value *to) { +void ConversionPatternRewriter::replaceUsesOfBlockArgument( + BlockArgumentPtr from, ValuePtr to) { for (auto &u : from->getUses()) { if (u.getOwner() == to->getDefiningOp()) continue; @@ -825,7 +825,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from, /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. 
-Value *ConversionPatternRewriter::getRemappedValue(Value *key) { +ValuePtr ConversionPatternRewriter::getRemappedValue(ValuePtr key) { return impl->mapping.lookupOrDefault(key); } @@ -896,7 +896,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { PatternMatchResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - SmallVector operands; + SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.getImpl().remapValues(op->getOperands(), operands); @@ -908,7 +908,7 @@ ConversionPattern::matchAndRewrite(Operation *op, SmallVector destinations; destinations.reserve(op->getNumSuccessors()); - SmallVector, 2> operandsPerDestination; + SmallVector, 2> operandsPerDestination; unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { destinations.push_back(op->getSuccessor(i)); @@ -1059,7 +1059,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. - SmallVector replacementValues; + SmallVector replacementValues; rewriter.setInsertionPoint(op); if (failed(rewriter.tryFold(op, replacementValues))) return failure(); @@ -1459,7 +1459,7 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, /// Remap an input of the original signature to another `replacementValue` /// value. This would make the signature converter drop this argument. void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, - Value *replacementValue) { + ValuePtr replacementValue) { assert(!remappedInputs[origInputNo] && "input has already been remapped"); remappedInputs[origInputNo] = InputMapping{origInputNo, /*size=*/0, replacementValue}; @@ -1528,7 +1528,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern { /// Hook for derived classes to implement combined matching and rewriting. PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { FunctionType type = funcOp.getType(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 5694c990b9b..60f0264eb35 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -172,7 +172,7 @@ public: Node(unsigned id, Operation *op) : id(id), op(op) {} // Returns the load op count for 'memref'. - unsigned getLoadOpCount(Value *memref) { + unsigned getLoadOpCount(ValuePtr memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) @@ -182,7 +182,7 @@ public: } // Returns the store op count for 'memref'. - unsigned getStoreOpCount(Value *memref) { + unsigned getStoreOpCount(ValuePtr memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) @@ -192,7 +192,7 @@ public: } // Returns all store ops in 'storeOps' which access 'memref'. - void getStoreOpsForMemref(Value *memref, + void getStoreOpsForMemref(ValuePtr memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) @@ -201,7 +201,7 @@ public: } // Returns all load ops in 'loadOps' which access 'memref'. 
- void getLoadOpsForMemref(Value *memref, + void getLoadOpsForMemref(ValuePtr memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) @@ -211,13 +211,13 @@ public: // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node // has at least one load and store operation. - void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { - llvm::SmallDenseSet loadMemrefs; + void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { + llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { loadMemrefs.insert(cast(loadOpInst).getMemRef()); } for (auto *storeOpInst : stores) { - auto *memref = cast(storeOpInst).getMemRef(); + auto memref = cast(storeOpInst).getMemRef(); if (loadMemrefs.count(memref) > 0) loadAndStoreMemrefSet->insert(memref); } @@ -239,7 +239,7 @@ public: // defines an SSA value and another graph node which uses the SSA value // (e.g. a constant operation defining a value which is used inside a loop // nest). - Value *value; + ValuePtr value; }; // Map from node id to Node. @@ -250,7 +250,7 @@ public: DenseMap> outEdges; // Map from memref to a count on the dependence edges associated with that // memref. - DenseMap memrefEdgeCount; + DenseMap memrefEdgeCount; // The next unique identifier to use for newly created graph nodes. unsigned nextNodeId = 0; @@ -309,7 +309,7 @@ public: bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { - auto *memref = cast(storeOpInst).getMemRef(); + auto memref = cast(storeOpInst).getMemRef(); auto *op = memref->getDefiningOp(); // Return true if 'memref' is a block argument. if (!op) @@ -338,7 +338,7 @@ public: const auto &nodeOutEdges = outEdgeIt->second; for (auto *op : node->stores) { auto storeOp = cast(op); - auto *memref = storeOp.getMemRef(); + auto memref = storeOp.getMemRef(); // Skip this store if there are no dependences on its memref. This means // that store either: // *) writes to a memref that is only read within the same loop nest @@ -381,7 +381,7 @@ public: // Returns true iff there is an edge from node 'srcId' to node 'dstId' which // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. - bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) { + bool hasEdge(unsigned srcId, unsigned dstId, ValuePtr value = nullptr) { if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { return false; } @@ -395,7 +395,7 @@ public: } // Adds an edge from node 'srcId' to node 'dstId' for 'value'. - void addEdge(unsigned srcId, unsigned dstId, Value *value) { + void addEdge(unsigned srcId, unsigned dstId, ValuePtr value) { if (!hasEdge(srcId, dstId, value)) { outEdges[srcId].push_back({dstId, value}); inEdges[dstId].push_back({srcId, value}); @@ -405,7 +405,7 @@ public: } // Removes an edge from node 'srcId' to node 'dstId' for 'value'. - void removeEdge(unsigned srcId, unsigned dstId, Value *value) { + void removeEdge(unsigned srcId, unsigned dstId, ValuePtr value) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); if (value->getType().isa()) { @@ -459,7 +459,7 @@ public: // Returns the input edge count for node 'id' and 'memref' from src nodes // which access 'memref' with a store operation. 
- unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) { + unsigned getIncomingMemRefAccesses(unsigned id, ValuePtr memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) for (auto &inEdge : inEdges[id]) @@ -474,7 +474,7 @@ public: // Returns the output edge count for node 'id' and 'memref' (if non-null), // otherwise returns the total output edge count from node 'id'. - unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) { + unsigned getOutEdgeCount(unsigned id, ValuePtr memref = nullptr) { unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) @@ -548,7 +548,7 @@ public: // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' // has been replaced in node at 'dstId' by a private memref depending // on the value of 'createPrivateMemRef'. - void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef, + void updateEdges(unsigned srcId, unsigned dstId, ValuePtr oldMemRef, bool createPrivateMemRef) { // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. if (inEdges.count(srcId) > 0) { @@ -681,7 +681,7 @@ public: // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(FuncOp f) { - DenseMap> memrefAccesses; + DenseMap> memrefAccesses; // TODO: support multi-block functions. if (f.getBlocks().size() != 1) @@ -701,12 +701,12 @@ bool MemRefDependenceGraph::init(FuncOp f) { Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); - auto *memref = cast(opInst).getMemRef(); + auto memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opInst : collector.storeOpInsts) { node.stores.push_back(opInst); - auto *memref = cast(opInst).getMemRef(); + auto memref = cast(opInst).getMemRef(); memrefAccesses[memref].insert(node.id); } forToNodeMap[&op] = node.id; @@ -715,14 +715,14 @@ bool MemRefDependenceGraph::init(FuncOp f) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); - auto *memref = cast(op).getMemRef(); + auto memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (auto storeOp = dyn_cast(op)) { // Create graph node for top-level store op. Node node(nextNodeId++, &op); node.stores.push_back(&op); - auto *memref = cast(op).getMemRef(); + auto memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } else if (op.getNumRegions() != 0) { @@ -743,7 +743,7 @@ bool MemRefDependenceGraph::init(FuncOp f) { if (!node.loads.empty() || !node.stores.empty()) continue; auto *opInst = node.op; - for (auto *value : opInst->getResults()) { + for (auto value : opInst->getResults()) { for (auto *user : value->getUsers()) { SmallVector loops; getLoopIVs(*user, &loops); @@ -777,7 +777,7 @@ bool MemRefDependenceGraph::init(FuncOp f) { // Removes load operations from 'srcLoads' which operate on 'memref', and // adds them to 'dstLoads'. -static void moveLoadsAccessingMemrefTo(Value *memref, +static void moveLoadsAccessingMemrefTo(ValuePtr memref, SmallVectorImpl *srcLoads, SmallVectorImpl *dstLoads) { dstLoads->clear(); @@ -893,10 +893,11 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. 
-static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, - unsigned dstLoopDepth, - Optional fastMemorySpace, - uint64_t localBufSizeThreshold) { +static ValuePtr createPrivateMemRef(AffineForOp forOp, + Operation *srcStoreOpInst, + unsigned dstLoopDepth, + Optional fastMemorySpace, + uint64_t localBufSizeThreshold) { auto *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. @@ -904,7 +905,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Builder to create constants at the top level. OpBuilder top(forInst->getParentOfType().getBody()); // Create new memref type based on slice bounds. - auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); + auto oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); unsigned rank = oldMemRefType.getRank(); @@ -928,7 +929,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // 'outerIVs' holds the values that this memory region is symbolic/parametric // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. - SmallVector outerIVs; + SmallVector outerIVs; cst->getIdValues(rank, cst->getNumIds(), &outerIVs); // Build 'rank' AffineExprs from MemRefRegion 'lbs' @@ -960,7 +961,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), {}, newMemSpace); // Gather alloc operands for the dynamic dimensions of the memref. - SmallVector allocOperands; + SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) @@ -973,7 +974,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the function, because loop nests can be reordered // during the fusion pass. - Value *newMemRef = + ValuePtr newMemRef = top.create(forOp.getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. @@ -1016,7 +1017,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, MemRefDependenceGraph *mdg) { assert(srcLiveOutStoreOp && "Expected a valid store op"); auto *dstNode = mdg->getNode(dstId); - Value *memref = srcLiveOutStoreOp.getMemRef(); + ValuePtr memref = srcLiveOutStoreOp.getMemRef(); // Return false if 'srcNode' has more than one output edge on 'memref'. if (mdg->getOutEdgeCount(srcId, memref) > 1) return false; @@ -1495,10 +1496,10 @@ public: SmallVector loads = dstNode->loads; SmallVector dstLoadOpInsts; - DenseSet visitedMemrefs; + DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. - auto *memref = cast(loads.back()).getMemRef(); + auto memref = cast(loads.back()).getMemRef(); if (visitedMemrefs.count(memref) > 0) continue; visitedMemrefs.insert(memref); @@ -1653,7 +1654,7 @@ public: } // TODO(andydavis) Use union of memref write regions to compute // private memref footprint. - auto *newMemRef = createPrivateMemRef( + auto newMemRef = createPrivateMemRef( dstAffineForOp, storesForMemref[0], bestDstLoopDepth, fastMemorySpace, localBufSizeThreshold); visitedMemrefs.insert(newMemRef); @@ -1671,7 +1672,7 @@ public: // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. 
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - auto *loadMemRef = cast(loadOpInst).getMemRef(); + auto loadMemRef = cast(loadOpInst).getMemRef(); if (visitedMemrefs.count(loadMemRef) == 0) loads.push_back(loadOpInst); } @@ -1737,10 +1738,10 @@ public: // Attempt to fuse 'dstNode' with sibling nodes in the graph. void fuseWithSiblingNodes(Node *dstNode) { DenseSet visitedSibNodeIds; - std::pair idAndMemref; + std::pair idAndMemref; while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { unsigned sibId = idAndMemref.first; - Value *memref = idAndMemref.second; + ValuePtr memref = idAndMemref.second; // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other // stores to the same memref in 'sibNode' loop nest. auto *sibNode = mdg->getNode(sibId); @@ -1804,10 +1805,10 @@ public: // 'idAndMemrefToFuse' on success. Returns false otherwise. bool findSiblingNodeToFuse(Node *dstNode, DenseSet *visitedSibNodeIds, - std::pair *idAndMemrefToFuse) { + std::pair *idAndMemrefToFuse) { // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse // on 'memref'. - auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) { + auto canFuseWithSibNode = [&](Node *sibNode, ValuePtr memref) { // Skip if 'outEdge' is not a read-after-write dependence. // TODO(andydavis) Remove restrict to single load op restriction. if (sibNode->getLoadOpCount(memref) != 1) @@ -1819,15 +1820,15 @@ public: return false; // Skip sib node if it loads to (and stores from) the same memref on // which it also has an input dependence edge. - DenseSet loadAndStoreMemrefSet; + DenseSet loadAndStoreMemrefSet; sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); - if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { + if (llvm::any_of(loadAndStoreMemrefSet, [=](ValuePtr memref) { return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; })) return false; // Check that all stores are to the same memref. - DenseSet storeMemrefs; + DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { storeMemrefs.insert(cast(storeOpInst).getMemRef()); } @@ -1856,7 +1857,7 @@ public: if (visitedSibNodeIds->count(sibNode->id) > 0) continue; // Skip 'use' if it does not load from the same memref as 'dstNode'. - auto *memref = loadOp.getMemRef(); + auto memref = loadOp.getMemRef(); if (dstNode->getLoadOpCount(memref) == 0) continue; // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. @@ -1950,7 +1951,7 @@ public: for (auto &pair : mdg->memrefEdgeCount) { if (pair.second > 0) continue; - auto *memref = pair.first; + auto memref = pair.first; // Skip if there exist other uses (return operation or function calls). if (!memref->use_empty()) continue; diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 4932494a04b..bd58827d001 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -50,7 +50,7 @@ public: // - the op has no side-effects. If sideEffecting is Never, sideeffects of this // op and its nested ops are ignored. static bool canBeHoisted(Operation *op, - function_ref definedOutside, + function_ref definedOutside, SideEffecting sideEffecting, SideEffectsInterface &interface) { // Check that dependencies are defined outside of loop. @@ -92,7 +92,7 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, SmallVector opsToMove; // Helper to check whether an operation is loop invariant wrt. SSA properties. 
- auto isDefinedOutsideOfBody = [&](Value *value) { + auto isDefinedOutsideOfBody = [&](ValuePtr value) { auto definingOp = value->getDefiningOp(); return (definingOp && !!willBeMovedSet.count(definingOp)) || looplike.isDefinedOutsideOfLoop(value); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 10654783aa9..361a4d8ecb9 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -120,8 +120,8 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, for (unsigned i = 0; i < width; i++) { auto lbOperands = origLoops[i].getLowerBoundOperands(); auto ubOperands = origLoops[i].getUpperBoundOperands(); - SmallVector newLbOperands(lbOperands); - SmallVector newUbOperands(ubOperands); + SmallVector newLbOperands(lbOperands); + SmallVector newUbOperands(ubOperands); newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap()); newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap()); newLoops[i].setStep(tileSizes[i]); @@ -147,7 +147,7 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. auto ub = origLoops[i].getUpperBound(); - SmallVector ubOperands; + SmallVector ubOperands; ubOperands.reserve(ub.getNumOperands() + 1); auto origUbMap = ub.getMap(); // Add dim operands from original upper bound. @@ -235,9 +235,10 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector origLoopIVs; + SmallVector origLoopIVs; extractForInductionVars(band, &origLoopIVs); - SmallVector, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); + SmallVector, 6> ids(origLoopIVs.begin(), + origLoopIVs.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 230869abcd5..a857b8ec95a 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -191,7 +191,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, // Adjust the lower bound of the cleanup loop; its upper bound is the same // as the original loop's upper bound. AffineMap cleanupMap; - SmallVector cleanupOperands; + SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, &cleanupOperands, builder); cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); @@ -208,7 +208,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, int64_t step = forOp.getStep(); forOp.setStep(step * unrollJamFactor); - auto *forOpIV = forOp.getInductionVar(); + auto forOpIV = forOp.getInductionVar(); // Unroll and jam (appends unrollJamFactor - 1 additional copies). for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { // Operand map persists across all sub-blocks. diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index c531ca551b4..0695aafe171 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -76,7 +76,7 @@ struct MemRefDataFlowOpt : public FunctionPass { void forwardStoreToLoad(AffineLoadOp loadOp); // A list of memref's that are potentially dead / could be eliminated. 
- SmallPtrSet memrefsToErase; + SmallPtrSet memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. SmallVector loadOpsToErase; @@ -180,7 +180,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { return; // Perform the actual store to load forwarding. - Value *storeVal = cast(lastWriteStoreOp).getValueToStore(); + ValuePtr storeVal = cast(lastWriteStoreOp).getValueToStore(); loadOp.replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); @@ -213,7 +213,7 @@ void MemRefDataFlowOpt::runOnFunction() { // Check if the store fwd'ed memrefs are now left with only stores and can // thus be completely deleted. Note: the canonicalize pass should be able // to do this as well, but we'll do it here since we collected these anyway. - for (auto *memref : memrefsToErase) { + for (auto memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. Operation *defInst = memref->getDefiningOp(); if (!defInst || !isa(defInst)) diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index fdf01351549..4162936ea2d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -70,7 +70,7 @@ static unsigned getTagMemRefPos(Operation &dmaInst) { /// Replaces all uses of the old memref by the new one while indexing the newly /// added dimension by the loop IV of the specified 'affine.for' operation /// modulo 2. Returns false if such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { +static bool doubleBuffer(ValuePtr oldMemRef, AffineForOp forOp) { auto *forBody = forOp.getBody(); OpBuilder bInner(forBody, forBody->begin()); @@ -94,7 +94,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { auto *forInst = forOp.getOperation(); OpBuilder bOuter(forInst); // Put together alloc operands for any dynamic dimensions of the memref. - SmallVector allocOperands; + SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) @@ -103,7 +103,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) { } // Create and place the alloc right before the 'affine.for' operation. - Value *newMemRef = + ValuePtr newMemRef = bOuter.create(forInst->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. @@ -212,7 +212,7 @@ static void findMatchingStartFinishInsts( continue; // We only double buffer if the buffer is not live out of loop. - auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos()); + auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos()); bool escapingUses = false; for (auto *user : memref->getUsers()) { // We can double buffer regardless of dealloc's outside the loop. @@ -270,7 +270,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // dimension. for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; - Value *oldMemRef = dmaStartInst->getOperand( + ValuePtr oldMemRef = dmaStartInst->getOperand( cast(dmaStartInst).getFasterMemPos()); if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked @@ -301,7 +301,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // Double the buffers for tag memrefs. 
for (auto &pair : startWaitPairs) { auto *dmaFinishInst = pair.second; - Value *oldTagMemRef = + ValuePtr oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); @@ -342,7 +342,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // If a slice wasn't created, the reachable affine.apply op's from its // operands are the ones that go with it. SmallVector affineApplyInsts; - SmallVector operands(dmaStartInst->getOperands()); + SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); for (auto *op : affineApplyInsts) { instShiftMap[op] = 0; diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index d4b7caae527..85d1f21305e 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -90,7 +90,7 @@ LogicalResult OperationFolder::tryToFold( return failure(); // Try to fold the operation. - SmallVector results; + SmallVector results; if (failed(tryToFold(op, results, processGeneratedConstants))) return failure(); @@ -138,7 +138,7 @@ void OperationFolder::notifyRemoval(Operation *op) { /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult OperationFolder::tryToFold( - Operation *op, SmallVectorImpl &results, + Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants) { SmallVector operandConstants; SmallVector foldResults; @@ -181,13 +181,13 @@ LogicalResult OperationFolder::tryToFold( assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); // Check if the result was an SSA value. - if (auto *repl = foldResults[i].dyn_cast()) { + if (auto repl = foldResults[i].dyn_cast()) { results.emplace_back(repl); continue; } // Check to see if there is a canonicalized version of this constant. - auto *res = op->getResult(i); + auto res = op->getResult(i); Attribute attrRepl = foldResults[i].get(); if (auto *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index e2ca3f8fc5e..fe4a6f9f9e0 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -107,7 +107,7 @@ protected: // simplifications to its users - make sure to add them to the worklist // before the root is changed. void notifyRootReplaced(Operation *op) override { - for (auto *result : op->getResults()) + for (auto result : op->getResults()) for (auto *user : result->getUsers()) addToWorklist(user); } @@ -118,7 +118,7 @@ private: // operation is modified or removed, as it may trigger further // simplifications. template void addToWorklist(Operands &&operands) { - for (Value *operand : operands) { + for (ValuePtr operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. // TODO(riverriddle) This is based on the fact that zero use operations @@ -160,7 +160,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, region.walk(collectOps); // These are scratch vectors used in the folding loop below. 
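// Aside on the GreedyPatternRewriteDriver hunks above: the driver is a
// worklist algorithm, so when a root is replaced or an operand's use count
// drops, the affected users and defining ops are pushed back onto the
// worklist. A minimal generic C++ skeleton of that fixed-point loop
// (illustrative only, not the MLIR driver; Item must be orderable for the
// dedup set):
#include <deque>
#include <functional>
#include <set>
#include <vector>

template <typename Item>
void runToFixedPoint(std::deque<Item> worklist,
                     const std::function<bool(Item)> &trySimplify,
                     const std::function<std::vector<Item>(Item)> &usersOf) {
  std::set<Item> enqueued(worklist.begin(), worklist.end());
  while (!worklist.empty()) {
    Item item = worklist.front();
    worklist.pop_front();
    enqueued.erase(item);
    if (!trySimplify(item))
      continue; // nothing changed, no need to revisit users
    for (Item user : usersOf(item))
      if (enqueued.insert(user).second) // avoid duplicate worklist entries
        worklist.push_back(user);
  }
}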
- SmallVector originalOperands, resultValues; + SmallVector originalOperands, resultValues; changed = false; while (!worklist.empty()) { @@ -189,7 +189,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, // Add all the users of the result to the worklist so we make sure // to revisit them. - for (auto *result : op->getResults()) + for (auto result : op->getResults()) for (auto *operand : result->getUsers()) addToWorklist(operand); diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index e8466aa3fd6..048130c0d3a 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -55,7 +55,7 @@ static void remapInlinedOperands(iterator_range inlinedBlocks, BlockAndValueMapping &mapper) { auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) - if (auto *mappedOp = mapper.lookupOrNull(operand.get())) + if (auto mappedOp = mapper.lookupOrNull(operand.get())) operand.set(mappedOp); }; for (auto &block : inlinedBlocks) @@ -98,7 +98,7 @@ void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void InlinerInterface::handleTerminator(Operation *op, - ArrayRef valuesToRepl) const { + ArrayRef valuesToRepl) const { auto *handler = getInterfaceFor(op); assert(handler && "expected valid dialect handler"); handler->handleTerminator(op, valuesToRepl); @@ -137,7 +137,7 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, - ArrayRef resultsToReplace, + ArrayRef resultsToReplace, Optional inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. @@ -147,7 +147,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, // Check that all of the region arguments have been mapped. auto *srcEntryBlock = &src->front(); if (llvm::any_of(srcEntryBlock->getArguments(), - [&](BlockArgument *arg) { return !mapper.contains(arg); })) + [&](BlockArgumentPtr arg) { return !mapper.contains(arg); })) return failure(); // The insertion point must be within a block. @@ -207,7 +207,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, } else { // Otherwise, there were multiple blocks inlined. Add arguments to the post // insertion block to represent the results to replace. - for (Value *resultToRepl : resultsToReplace) { + for (ValuePtr resultToRepl : resultsToReplace) { resultToRepl->replaceAllUsesWith( postInsertBlock->addArgument(resultToRepl->getType())); } @@ -229,8 +229,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, /// in-favor of the region arguments when inlining. LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, - ArrayRef inlinedOperands, - ArrayRef resultsToReplace, + ArrayRef inlinedOperands, + ArrayRef resultsToReplace, Optional inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. @@ -246,7 +246,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { // Verify that the types of the provided values match the function argument // types. 
- BlockArgument *regionArg = entryBlock->getArgument(i); + BlockArgumentPtr regionArg = entryBlock->getArgument(i); if (inlinedOperands[i]->getType() != regionArg->getType()) return failure(); mapper.map(regionArg, inlinedOperands[i]); @@ -259,10 +259,10 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, /// Utility function used to generate a cast operation from the given interface, /// or return nullptr if a cast could not be generated. -static Value *materializeConversion(const DialectInlinerInterface *interface, - SmallVectorImpl &castOps, - OpBuilder &castBuilder, Value *arg, - Type type, Location conversionLoc) { +static ValuePtr materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl &castOps, + OpBuilder &castBuilder, ValuePtr arg, + Type type, Location conversionLoc) { if (!interface) return nullptr; @@ -297,8 +297,8 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Make sure that the number of arguments and results matchup between the call // and the region. - SmallVector callOperands(call.getArgOperands()); - SmallVector callResults(call.getOperation()->getResults()); + SmallVector callOperands(call.getArgOperands()); + SmallVector callResults(call.getOperation()->getResults()); if (callOperands.size() != entryBlock->getNumArguments() || callResults.size() != callableResultTypes.size()) return failure(); @@ -325,8 +325,8 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Map the provided call operands to the arguments of the region. BlockAndValueMapping mapper; for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { - BlockArgument *regionArg = entryBlock->getArgument(i); - Value *operand = callOperands[i]; + BlockArgumentPtr regionArg = entryBlock->getArgument(i); + ValuePtr operand = callOperands[i]; // If the call operand doesn't match the expected region argument, try to // generate a cast. @@ -342,13 +342,13 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Ensure that the resultant values of the call, match the callable. castBuilder.setInsertionPointAfter(call); for (unsigned i = 0, e = callResults.size(); i != e; ++i) { - Value *callResult = callResults[i]; + ValuePtr callResult = callResults[i]; if (callResult->getType() == callableResultTypes[i]) continue; // Generate a conversion that will produce the original type, so that the IR // is still valid after the original call gets replaced. - Value *castResult = + ValuePtr castResult = materializeConversion(callInterface, castOps, castBuilder, callResult, callResult->getType(), castLoc); if (!castResult) diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index fd803390ce7..d5cda3265de 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -45,7 +45,7 @@ using namespace mlir; // Gathers all load and store memref accesses in 'opA' into 'values', where // 'values[memref] == true' for each store operation. static void getLoadAndStoreMemRefAccesses(Operation *opA, - DenseMap &values) { + DenseMap &values) { opA->walk([&](Operation *op) { if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) @@ -60,7 +60,7 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA, // accessed 'values' and at least one of the access is a store operation. // Returns false otherwise. 
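// Aside on the LoopFusionUtils.cpp dependence helpers in this hunk: the map
// records, for every memref touched by the first loop nest, whether that
// access was a store; an op in between conflicts if it touches a recorded
// memref and at least one of the two accesses is a store (read-after-read is
// harmless). A minimal standalone C++ restatement with string keys standing
// in for memref SSA values (illustrative, not the MLIR implementation):
#include <string>
#include <unordered_map>

struct Access {
  std::string memref;
  bool isStore;
};

// 'firstNestAccesses[memref]' is true iff the first nest stored to 'memref'.
bool isDependent(
    const std::unordered_map<std::string, bool> &firstNestAccesses,
    const Access &access) {
  auto it = firstNestAccesses.find(access.memref);
  if (it == firstNestAccesses.end())
    return false;                      // memref not touched by the first nest
  return it->second || access.isStore; // at least one of the accesses stores
}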
static bool isDependentLoadOrStoreOp(Operation *op, - DenseMap &values) { + DenseMap &values) { if (auto loadOp = dyn_cast(op)) { return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()] == true; @@ -75,7 +75,7 @@ static bool isDependentLoadOrStoreOp(Operation *op, static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opA'. // Map from memref value to bool which is true if store, false otherwise. - DenseMap values; + DenseMap values; getLoadAndStoreMemRefAccesses(opA, values); // For each 'opX' in block in range ('opA', 'opB'), check if there is a data @@ -101,7 +101,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opB'. // Map from memref value to bool which is true if store, false otherwise. - DenseMap values; + DenseMap values; getLoadAndStoreMemRefAccesses(opB, values); // For each 'opX' in block in range ('opA', 'opB') in reverse order, @@ -121,8 +121,8 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { } return WalkResult::advance(); } - for (auto *value : op->getResults()) { - for (auto *user : value->getUsers()) { + for (auto value : op->getResults()) { + for (auto user : value->getUsers()) { SmallVector loops; // Check if any loop in loop nest surrounding 'user' is 'opB'. getLoopIVs(*user, &loops); @@ -443,7 +443,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, // Subtract from operation count the loads/store we expect load/store // forwarding to remove. unsigned storeCount = 0; - llvm::SmallDenseSet storeMemrefs; + llvm::SmallDenseSet storeMemrefs; srcForOp.walk([&](Operation *op) { if (auto storeOp = dyn_cast(op)) { storeMemrefs.insert(storeOp.getMemRef()); @@ -455,7 +455,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, computeCostMap[insertPointParent] = -storeCount; // Subtract out any load users of 'storeMemrefs' nested below // 'insertPointParent'. - for (auto *value : storeMemrefs) { + for (auto value : storeMemrefs) { for (auto *user : value->getUsers()) { if (auto loadOp = dyn_cast(user)) { SmallVector loops; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 3691aee4870..bc1ced408a9 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -52,7 +52,7 @@ using llvm::SmallMapVector; /// expression. void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, - SmallVectorImpl *operands, + SmallVectorImpl *operands, OpBuilder &b) { auto lbMap = forOp.getLowerBoundMap(); @@ -63,7 +63,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, } AffineMap tripCountMap; - SmallVector tripCountOperands; + SmallVector tripCountOperands; buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands); // Sometimes the trip count cannot be expressed as an affine expression. @@ -82,7 +82,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all // these affine.apply's make up the cleanup loop lower bound. 
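// Restating the arithmetic from the comment above in plain C++ (a sketch
// assuming a single constant trip-count expression; the real code builds the
// equivalent affine map and applies it with affine.apply ops): the unrolled
// loop covers the first tripCount - tripCount % unrollFactor iterations, so
// the cleanup loop starts just past them.
#include <cstdint>

int64_t cleanupLowerBound(int64_t lb, int64_t step, int64_t tripCount,
                          int64_t unrollFactor) {
  return lb + (tripCount - tripCount % unrollFactor) * step;
}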
SmallVector bumpExprs(tripCountMap.getNumResults()); - SmallVector bumpValues(tripCountMap.getNumResults()); + SmallVector bumpValues(tripCountMap.getNumResults()); for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) { auto tripCountExpr = tripCountMap.getResult(i); bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step; @@ -105,7 +105,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, *map = simplifyAffineMap(*map); canonicalizeMapAndOperands(map, operands); // Remove any affine.apply's that became dead from the simplification above. - for (auto *v : bumpValues) { + for (auto v : bumpValues) { if (v->use_empty()) { v->getDefiningOp()->erase(); } @@ -127,7 +127,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { return failure(); // Replaces all IV uses to its single iteration value. - auto *iv = forOp.getInductionVar(); + auto iv = forOp.getInductionVar(); Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { @@ -137,7 +137,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { iv->replaceAllUsesWith(constOp); } else { AffineBound lb = forOp.getLowerBound(); - SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); + SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); OpBuilder builder(op->getBlock(), Block::iterator(op)); if (lb.getMap() == builder.getDimIdentityMap()) { // No need of generating an affine.apply. @@ -178,8 +178,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, unsigned offset, AffineForOp srcForInst, OpBuilder b) { - SmallVector lbOperands(srcForInst.getLowerBoundOperands()); - SmallVector ubOperands(srcForInst.getUpperBoundOperands()); + SmallVector lbOperands(srcForInst.getLowerBoundOperands()); + SmallVector ubOperands(srcForInst.getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); @@ -187,8 +187,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, auto loopChunk = b.create(srcForInst.getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForInst.getStep()); - auto *loopChunkIV = loopChunk.getInductionVar(); - auto *srcIV = srcForInst.getInductionVar(); + auto loopChunkIV = loopChunk.getInductionVar(); + auto srcIV = srcForInst.getInductionVar(); BlockAndValueMapping operandMap; @@ -449,7 +449,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, OpBuilder builder(op->getBlock(), ++Block::iterator(op)); auto cleanupForInst = cast(builder.clone(*op)); AffineMap cleanupMap; - SmallVector cleanupOperands; + SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands, builder); assert(cleanupMap && @@ -477,7 +477,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2); // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). - auto *forOpIV = forOp.getInductionVar(); + auto forOpIV = forOp.getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; @@ -669,8 +669,8 @@ void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { // ... 
// } // ``` -static void augmentMapAndBounds(OpBuilder &b, Value *iv, AffineMap *map, - SmallVector *operands, +static void augmentMapAndBounds(OpBuilder &b, ValuePtr iv, AffineMap *map, + SmallVector *operands, int64_t offset = 0) { auto bounds = llvm::to_vector<4>(map->getResults()); bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset); @@ -699,16 +699,16 @@ stripmineSink(AffineForOp forOp, uint64_t factor, // Lower-bound map creation. auto lbMap = forOp.getLowerBoundMap(); - SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands); // Upper-bound map creation. auto ubMap = forOp.getUpperBoundMap(); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands, /*offset=*/scaledStep); - auto *iv = forOp.getInductionVar(); + auto iv = forOp.getInductionVar(); SmallVector innerLoops; for (auto t : targets) { // Insert newForOp before the terminator of `t`. @@ -729,10 +729,10 @@ stripmineSink(AffineForOp forOp, uint64_t factor, return innerLoops; } -static Loops stripmineSink(loop::ForOp forOp, Value *factor, +static Loops stripmineSink(loop::ForOp forOp, ValuePtr factor, ArrayRef targets) { - auto *originalStep = forOp.step(); - auto *iv = forOp.getInductionVar(); + auto originalStep = forOp.step(); + auto iv = forOp.getInductionVar(); OpBuilder b(forOp); forOp.setStep(b.create(forOp.getLoc(), originalStep, factor)); @@ -745,10 +745,10 @@ static Loops stripmineSink(loop::ForOp forOp, Value *factor, // Insert newForOp before the terminator of `t`. OpBuilder b(t.getBodyBuilder()); - Value *stepped = b.create(t.getLoc(), iv, forOp.step()); - Value *less = b.create(t.getLoc(), CmpIPredicate::slt, - forOp.upperBound(), stepped); - Value *ub = + ValuePtr stepped = b.create(t.getLoc(), iv, forOp.step()); + ValuePtr less = b.create(t.getLoc(), CmpIPredicate::slt, + forOp.upperBound(), stepped); + ValuePtr ub = b.create(t.getLoc(), less, forOp.upperBound(), stepped); // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. @@ -799,7 +799,7 @@ mlir::tile(ArrayRef forOps, ArrayRef sizes, } SmallVector mlir::tile(ArrayRef forOps, - ArrayRef sizes, + ArrayRef sizes, ArrayRef targets) { return tileImpl(forOps, sizes, targets); } @@ -821,13 +821,13 @@ SmallVector mlir::tile(ArrayRef forOps, return tileImpl(forOps, sizes, target); } -Loops mlir::tile(ArrayRef forOps, ArrayRef sizes, +Loops mlir::tile(ArrayRef forOps, ArrayRef sizes, loop::ForOp target) { return tileImpl(forOps, sizes, target); } Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, - ArrayRef sizes) { + ArrayRef sizes) { // Collect perfectly nested loops. If more size values provided than nested // loops available, truncate `sizes`. SmallVector forOps; @@ -842,14 +842,15 @@ Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, // Build the IR that performs ceil division of a positive value by a constant: // ceildiv(a, B) = divis(a + (B-1), B) // where divis is rounding-to-zero division. 
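// The same ceiling division, restated directly in C++ for non-negative
// dividends and positive divisors (a sketch of the arithmetic only; the real
// code below materializes it as add/divide ops so it works on SSA values).
// extractFixedOuterLoops applies it twice: once as ceildiv(ub - lb, step) to
// get the iteration count and once as ceildiv(numIterations, numTiles) to get
// the per-tile size.
#include <cassert>
#include <cstdint>

int64_t ceilDivPositive(int64_t a, int64_t b) {
  assert(a >= 0 && b > 0 && "expected non-negative dividend, positive divisor");
  return (a + b - 1) / b; // rounding-to-zero division on a non-negative sum
}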
-static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend, - int64_t divisor) { +static ValuePtr ceilDivPositive(OpBuilder &builder, Location loc, + ValuePtr dividend, int64_t divisor) { assert(divisor > 0 && "expected positive divisor"); assert(dividend->getType().isIndex() && "expected index-typed value"); - Value *divisorMinusOneCst = builder.create(loc, divisor - 1); - Value *divisorCst = builder.create(loc, divisor); - Value *sum = builder.create(loc, dividend, divisorMinusOneCst); + ValuePtr divisorMinusOneCst = + builder.create(loc, divisor - 1); + ValuePtr divisorCst = builder.create(loc, divisor); + ValuePtr sum = builder.create(loc, dividend, divisorMinusOneCst); return builder.create(loc, sum, divisorCst); } @@ -857,13 +858,13 @@ static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend, // positive value: // ceildiv(a, b) = divis(a + (b - 1), b) // where divis is rounding-to-zero division. -static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend, - Value *divisor) { +static ValuePtr ceilDivPositive(OpBuilder &builder, Location loc, + ValuePtr dividend, ValuePtr divisor) { assert(dividend->getType().isIndex() && "expected index-typed value"); - Value *cstOne = builder.create(loc, 1); - Value *divisorMinusOne = builder.create(loc, divisor, cstOne); - Value *sum = builder.create(loc, dividend, divisorMinusOne); + ValuePtr cstOne = builder.create(loc, 1); + ValuePtr divisorMinusOne = builder.create(loc, divisor, cstOne); + ValuePtr sum = builder.create(loc, dividend, divisorMinusOne); return builder.create(loc, sum, divisor); } @@ -945,7 +946,7 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, // iterations. Given that the loop current executes // numIterations = ceildiv((upperBound - lowerBound), step) // iterations, we need to tile with size ceildiv(numIterations, size[i]). - SmallVector tileSizes; + SmallVector tileSizes; tileSizes.reserve(sizes.size()); for (unsigned i = 0, e = sizes.size(); i < e; ++i) { assert(sizes[i] > 0 && "expected strictly positive size for strip-mining"); @@ -953,10 +954,10 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, auto forOp = forOps[i]; OpBuilder builder(forOp); auto loc = forOp.getLoc(); - Value *diff = + ValuePtr diff = builder.create(loc, forOp.upperBound(), forOp.lowerBound()); - Value *numIterations = ceilDivPositive(builder, loc, diff, forOp.step()); - Value *iterationsPerBlock = + ValuePtr numIterations = ceilDivPositive(builder, loc, diff, forOp.step()); + ValuePtr iterationsPerBlock = ceilDivPositive(builder, loc, numIterations, sizes[i]); tileSizes.push_back(iterationsPerBlock); } @@ -976,7 +977,7 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, // Replaces all uses of `orig` with `replacement` except if the user is listed // in `exceptions`. static void -replaceAllUsesExcept(Value *orig, Value *replacement, +replaceAllUsesExcept(ValuePtr orig, ValuePtr replacement, const SmallPtrSetImpl &exceptions) { for (auto &use : llvm::make_early_inc_range(orig->getUses())) { if (exceptions.count(use.getOwner()) == 0) @@ -1018,30 +1019,30 @@ static void normalizeLoop(loop::ForOp loop, loop::ForOp outer, // of the loop to go from 0 to the number of iterations, if necessary. // TODO(zinenko): introduce support for negative steps or emit dynamic asserts // on step positivity, whatever gets implemented first. 
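// Aside on normalizeLoop/coalesceLoops around this point: after normalization
// every loop runs from 0 to its iteration count with step 1, the coalesced
// loop's range is the product of those counts, and the original IVs are
// recovered from the single linear IV by repeated division and modulo from
// the innermost loop outwards. A standalone C++ restatement of that
// arithmetic (plain integers, illustrative names):
#include <cstdint>
#include <vector>

// Original IV of one normalized loop, recomputed inside the body; the
// multiplication is skipped when step == 1 and the addition when lb == 0,
// mirroring the isStepOne/isZeroBased cases above.
int64_t denormalizedIV(int64_t normalizedIV, int64_t lb, int64_t step) {
  int64_t scaled = (step == 1) ? normalizedIV : normalizedIV * step;
  return (lb == 0) ? scaled : scaled + lb;
}

// Per-loop IVs recovered from the coalesced linear IV; ranges[0] is the
// outermost loop's iteration count.
std::vector<int64_t> delinearize(int64_t linearIV,
                                 const std::vector<int64_t> &ranges) {
  std::vector<int64_t> ivs(ranges.size());
  int64_t remaining = linearIV;
  for (int64_t i = static_cast<int64_t>(ranges.size()) - 1; i >= 0; --i) {
    ivs[i] = (i == 0) ? remaining : remaining % ranges[i];
    remaining /= ranges[i];
  }
  return ivs;
}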
- Value *diff = + ValuePtr diff = builder.create(loc, loop.upperBound(), loop.lowerBound()); - Value *numIterations = ceilDivPositive(builder, loc, diff, loop.step()); + ValuePtr numIterations = ceilDivPositive(builder, loc, diff, loop.step()); loop.setUpperBound(numIterations); - Value *lb = loop.lowerBound(); + ValuePtr lb = loop.lowerBound(); if (!isZeroBased) { - Value *cst0 = builder.create(loc, 0); + ValuePtr cst0 = builder.create(loc, 0); loop.setLowerBound(cst0); } - Value *step = loop.step(); + ValuePtr step = loop.step(); if (!isStepOne) { - Value *cst1 = builder.create(loc, 1); + ValuePtr cst1 = builder.create(loc, 1); loop.setStep(cst1); } // Insert code computing the value of the original loop induction variable // from the "normalized" one. builder.setInsertionPointToStart(inner.getBody()); - Value *scaled = + ValuePtr scaled = isStepOne ? loop.getInductionVar() : builder.create(loc, loop.getInductionVar(), step); - Value *shifted = + ValuePtr shifted = isZeroBased ? scaled : builder.create(loc, scaled, lb); SmallPtrSet preserve{scaled->getDefiningOp(), @@ -1065,7 +1066,7 @@ void mlir::coalesceLoops(MutableArrayRef loops) { // of the number of iterations of all loops. OpBuilder builder(outermost); Location loc = outermost.getLoc(); - Value *upperBound = outermost.upperBound(); + ValuePtr upperBound = outermost.upperBound(); for (auto loop : loops.drop_front()) upperBound = builder.create(loc, upperBound, loop.upperBound()); outermost.setUpperBound(upperBound); @@ -1080,16 +1081,16 @@ void mlir::coalesceLoops(MutableArrayRef loops) { // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i. // Compute these iteratively from the innermost loop by creating a "running // quotient" of division by the range. - Value *previous = outermost.getInductionVar(); + ValuePtr previous = outermost.getInductionVar(); for (unsigned i = 0, e = loops.size(); i < e; ++i) { unsigned idx = loops.size() - i - 1; if (i != 0) previous = builder.create(loc, previous, loops[idx + 1].upperBound()); - Value *iv = (i == e - 1) ? previous - : builder.create( - loc, previous, loops[idx].upperBound()); + ValuePtr iv = (i == e - 1) ? 
previous + : builder.create( + loc, previous, loops[idx].upperBound()); replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv, loops.back().region()); } @@ -1105,24 +1106,24 @@ void mlir::coalesceLoops(MutableArrayRef loops) { } void mlir::mapLoopToProcessorIds(loop::ForOp forOp, - ArrayRef processorId, - ArrayRef numProcessors) { + ArrayRef processorId, + ArrayRef numProcessors) { assert(processorId.size() == numProcessors.size()); if (processorId.empty()) return; OpBuilder b(forOp); Location loc(forOp.getLoc()); - Value *mul = processorId.front(); + ValuePtr mul = processorId.front(); for (unsigned i = 1, e = processorId.size(); i < e; ++i) mul = b.create(loc, b.create(loc, mul, numProcessors[i]), processorId[i]); - Value *lb = b.create(loc, forOp.lowerBound(), - b.create(loc, forOp.step(), mul)); + ValuePtr lb = b.create(loc, forOp.lowerBound(), + b.create(loc, forOp.step(), mul)); forOp.setLowerBound(lb); - Value *step = forOp.step(); - for (auto *numProcs : numProcessors) + ValuePtr step = forOp.step(); + for (auto numProcs : numProcessors) step = b.create(loc, step, numProcs); forOp.setStep(step); } @@ -1139,7 +1140,7 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, Block::iterator *copyInPlacementStart, Block::iterator *copyOutPlacementStart) { const auto *cst = region.getConstraints(); - SmallVector symbols; + SmallVector symbols; cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols); SmallVector enclosingFors; @@ -1202,10 +1203,10 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart' /// holds the lower coordinates of the region in the original memref to copy /// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in. -static AffineForOp generatePointWiseCopy(Location loc, Value *memref, - Value *fastMemRef, +static AffineForOp generatePointWiseCopy(Location loc, ValuePtr memref, + ValuePtr fastMemRef, AffineMap memAffineMap, - ArrayRef memIndicesStart, + ArrayRef memIndicesStart, ArrayRef fastBufferShape, bool isCopyOut, OpBuilder b) { assert(!memIndicesStart.empty() && "only 1-d or more memrefs"); @@ -1215,7 +1216,7 @@ static AffineForOp generatePointWiseCopy(Location loc, Value *memref, // for y = ... // fast_buf[x][y] = buf[mem_x + x][mem_y + y] - SmallVector fastBufIndices, memIndices; + SmallVector fastBufIndices, memIndices; AffineForOp copyNestRoot; for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) { auto forOp = b.create(loc, 0, fastBufferShape[d]); @@ -1224,7 +1225,7 @@ static AffineForOp generatePointWiseCopy(Location loc, Value *memref, b = forOp.getBodyBuilder(); fastBufIndices.push_back(forOp.getInductionVar()); - Value *memBase = + ValuePtr memBase = (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims())) ? 
memIndicesStart[d] : b.create( @@ -1277,7 +1278,7 @@ static LogicalResult generateCopy( const MemRefRegion ®ion, Block *block, Block::iterator begin, Block::iterator end, Block *copyPlacementBlock, Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart, - AffineCopyOptions copyOptions, DenseMap &fastBufferMap, + AffineCopyOptions copyOptions, DenseMap &fastBufferMap, DenseSet ©Nests, uint64_t *sizeInBytes, Block::iterator *nBegin, Block::iterator *nEnd) { *nBegin = begin; @@ -1285,7 +1286,7 @@ static LogicalResult generateCopy( FuncOp f = begin->getParentOfType(); OpBuilder topBuilder(f.getBody()); - Value *zeroIndex = topBuilder.create(f.getLoc(), 0); + ValuePtr zeroIndex = topBuilder.create(f.getLoc(), 0); if (begin == end) return success(); @@ -1305,7 +1306,7 @@ static LogicalResult generateCopy( OpBuilder top(func.getBody()); auto loc = region.loc; - auto *memref = region.memref; + auto memref = region.memref; auto memRefType = memref->getType().cast(); auto layoutMaps = memRefType.getAffineMaps(); @@ -1317,9 +1318,9 @@ static LogicalResult generateCopy( // Indices to use for the copying. // Indices for the original memref being copied from/to. - SmallVector memIndices; + SmallVector memIndices; // Indices for the faster buffer being copied into/from. - SmallVector bufIndices; + SmallVector bufIndices; unsigned rank = memRefType.getRank(); SmallVector fastBufferShape; @@ -1345,7 +1346,7 @@ static LogicalResult generateCopy( // 'regionSymbols' hold values that this memory region is symbolic/parametric // on; these typically include loop IVs surrounding the level at which the // copy generation is being done or other valid symbols in MLIR. - SmallVector regionSymbols; + SmallVector regionSymbols; cst->getIdValues(rank, cst->getNumIds(), ®ionSymbols); // Construct the index expressions for the fast memory buffer. The index @@ -1393,7 +1394,7 @@ static LogicalResult generateCopy( } // The faster memory space buffer. - Value *fastMemRef; + ValuePtr fastMemRef; // Check if a buffer was already created. bool existingBuf = fastBufferMap.count(memref) > 0; @@ -1433,8 +1434,8 @@ static LogicalResult generateCopy( return failure(); } - Value *stride = nullptr; - Value *numEltPerStride = nullptr; + ValuePtr stride = nullptr; + ValuePtr numEltPerStride = nullptr; if (!strideInfos.empty()) { stride = top.create(loc, strideInfos[0].stride); numEltPerStride = @@ -1473,7 +1474,7 @@ static LogicalResult generateCopy( copyOptions.tagMemorySpace); auto tagMemRef = prologue.create(loc, tagMemRefType); - SmallVector tagIndices({zeroIndex}); + SmallVector tagIndices({zeroIndex}); auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { @@ -1582,7 +1583,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, SmallVector ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); - SmallVector symbols; + SmallVector symbols; extractForInductionVars(ivs, &symbols); regionCst->reset(rank, numParamLoopIVs, 0); regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols); @@ -1629,12 +1630,12 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, // List of memory regions to copy for. We need a map vector to have a // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here // since the alloc's for example are identical except for the SSA id. 
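// Aside on generatePointWiseCopy earlier in this file: the generated nest
// indexes the fast buffer with zero-based IVs and the original memref with
// the region's lower bound added to each IV, i.e.
// fast_buf[x][y] = buf[mem_x + x][mem_y + y] for a copy-in (and the mirrored
// assignment for a copy-out). A 2-D standalone C++ sketch (plain vectors and
// illustrative names, not the affine copy nest itself):
#include <cstddef>
#include <cstdint>
#include <vector>

using Buffer2D = std::vector<std::vector<int64_t>>;

void pointWiseCopyIn(const Buffer2D &buf, Buffer2D &fastBuf, size_t memX,
                     size_t memY) {
  for (size_t x = 0; x < fastBuf.size(); ++x)
    for (size_t y = 0; y < fastBuf[x].size(); ++y)
      fastBuf[x][y] = buf[memX + x][memY + y];
}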
- SmallMapVector, 4> readRegions; - SmallMapVector, 4> writeRegions; + SmallMapVector, 4> readRegions; + SmallMapVector, 4> writeRegions; // Map from original memref's to the fast buffers that their accesses are // replaced with. - DenseMap fastBufferMap; + DenseMap fastBufferMap; // To check for errors when walking the block. bool error = false; @@ -1684,7 +1685,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, // Attempts to update; returns true if 'region' exists in targetRegions. auto updateRegion = - [&](const SmallMapVector, 4> + [&](const SmallMapVector, 4> &targetRegions) { auto it = targetRegions.find(region->memref); if (it == targetRegions.end()) @@ -1736,7 +1737,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, uint64_t totalCopyBuffersSizeInBytes = 0; bool ret = true; auto processRegions = - [&](const SmallMapVector, 4> + [&](const SmallMapVector, 4> ®ions) { for (const auto ®ionEntry : regions) { // For each region, hoist copy in/out past all hoistable diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index b91b189b381..749d5bf1dd0 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -27,9 +27,9 @@ using namespace mlir; -void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement, +void mlir::replaceAllUsesInRegionWith(ValuePtr orig, ValuePtr replacement, Region ®ion) { - for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) { + for (auto &use : llvm::make_early_inc_range(orig->getUses())) { if (region.isAncestor(use.getOwner()->getParentRegion())) use.set(replacement); } @@ -63,14 +63,14 @@ void mlir::visitUsedValuesDefinedAbove( } void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, - llvm::SetVector &values) { + llvm::SetVector &values) { visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { values.insert(operand->get()); }); } void mlir::getUsedValuesDefinedAbove(MutableArrayRef regions, - llvm::SetVector &values) { + llvm::SetVector &values) { for (Region ®ion : regions) getUsedValuesDefinedAbove(region, region, values); } @@ -146,8 +146,8 @@ namespace { class LiveMap { public: /// Value methods. 
- bool wasProvenLive(Value *value) { return liveValues.count(value); } - void setProvedLive(Value *value) { + bool wasProvenLive(ValuePtr value) { return liveValues.count(value); } + void setProvedLive(ValuePtr value) { changed |= liveValues.insert(value).second; } @@ -161,7 +161,7 @@ public: private: bool changed = false; - DenseSet liveValues; + DenseSet liveValues; DenseSet liveOps; }; } // namespace @@ -188,7 +188,7 @@ static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { return false; } -static void processValue(Value *value, LiveMap &liveMap) { +static void processValue(ValuePtr value, LiveMap &liveMap) { bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) { if (isUseSpeciallyKnownDead(use, liveMap)) return false; @@ -222,9 +222,9 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) { liveMap.setProvedLive(op); return; } - for (Value *value : op->getResults()) + for (ValuePtr value : op->getResults()) processValue(value, liveMap); - bool provedLive = llvm::any_of(op->getResults(), [&](Value *value) { + bool provedLive = llvm::any_of(op->getResults(), [&](ValuePtr value) { return liveMap.wasProvenLive(value); }); if (provedLive) @@ -240,7 +240,7 @@ static void propagateLiveness(Region ®ion, LiveMap &liveMap) { // faster convergence to a fixed point (we try to visit uses before defs). for (Operation &op : llvm::reverse(block->getOperations())) propagateLiveness(&op, liveMap); - for (Value *value : block->getArguments()) + for (ValuePtr value : block->getArguments()) processValue(value, liveMap); } } @@ -259,7 +259,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, // Iterating args in reverse is needed for correctness, to avoid // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; - Value *value = terminator->getSuccessor(succ)->getArgument(arg); + ValuePtr value = terminator->getSuccessor(succ)->getArgument(arg); if (!liveMap.wasProvenLive(value)) { terminator->eraseSuccessorOperand(succ, arg); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 57a92531163..96a6cdc544f 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -47,7 +47,8 @@ static bool isMemRefDereferencingOp(Operation &op) { } /// Return the AffineMapAttr associated with memory 'op' on 'memref'. -static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { +static NamedAttribute getAffineMapAttrForMemRef(Operation *op, + ValuePtr memref) { return TypeSwitch(op) .Case( @@ -55,12 +56,10 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { } // Perform the replacement in `op`. 
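// Aside on the RegionUtils.cpp liveness hunks above: the dead-code analysis
// is a fixed-point propagation where an op is proven live if it cannot be
// erased (for example, it has side effects) or any of its results has a live
// user. A minimal standalone C++ sketch of that propagation over an
// illustrative 'Op' struct (not an MLIR type):
#include <vector>

struct Op {
  bool hasSideEffects = false;
  std::vector<Op *> users; // users of any result of this op
  bool live = false;
};

void propagateLiveness(const std::vector<Op *> &ops) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (Op *op : ops) {
      bool provedLive = op->hasSideEffects;
      for (Op *user : op->users)
        provedLive = provedLive || user->live;
      if (provedLive && !op->live) {
        op->live = true;
        changed = true;
      }
    }
  }
}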
-LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, - Operation *op, - ArrayRef extraIndices, - AffineMap indexRemap, - ArrayRef extraOperands, - ArrayRef symbolOperands) { +LogicalResult mlir::replaceAllMemRefUsesWith( + ValuePtr oldMemRef, ValuePtr newMemRef, Operation *op, + ArrayRef extraIndices, AffineMap indexRemap, + ArrayRef extraOperands, ArrayRef symbolOperands) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -106,13 +105,13 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); - SmallVector oldMapOperands( + SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. - SmallVector oldMemRefOperands; - SmallVector affineApplyOps; + SmallVector oldMemRefOperands; + SmallVector affineApplyOps; oldMemRefOperands.reserve(oldMemRefRank); if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { for (auto resultExpr : oldMap.getResults()) { @@ -130,14 +129,14 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // Construct new indices as a remap of the old ones if a remapping has been // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. - SmallVector remapOperands; + SmallVector remapOperands; remapOperands.reserve(extraOperands.size() + oldMemRefRank + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); remapOperands.append(symbolOperands.begin(), symbolOperands.end()); - SmallVector remapOutputs; + SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); if (indexRemap && @@ -156,11 +155,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, remapOutputs.append(remapOperands.begin(), remapOperands.end()); } - SmallVector newMapOperands; + SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. - for (auto *extraIndex : extraIndices) { + for (auto extraIndex : extraIndices) { assert(extraIndex->getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && @@ -179,7 +178,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, newMap = simplifyAffineMap(newMap); canonicalizeMapAndOperands(&newMap, &newMapOperands); // Remove any affine.apply's that became dead as a result of composition. - for (auto *value : affineApplyOps) + for (auto value : affineApplyOps) if (value->use_empty()) value->getDefiningOp()->erase(); @@ -203,7 +202,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // Result types don't change. Both memref's are of the same elemental type. state.types.reserve(op->getNumResults()); - for (auto *result : op->getResults()) + for (auto result : op->getResults()) state.types.push_back(result->getType()); // Add attribute for 'newMap', other Attributes do not change. 
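// Aside on the index construction in replaceAllMemRefUsesWith above: the new
// access operands are the caller's extraIndices followed by the results of
// applying indexRemap to (extraOperands, old access indices, symbolOperands);
// with no remap, the concatenated operands pass through unchanged. A compact
// standalone C++ sketch with the remap modeled as a callable over plain
// integers (illustrative and simplified; the real code composes affine maps):
#include <cstdint>
#include <functional>
#include <vector>

using Indices = std::vector<int64_t>;

Indices buildNewMapOperands(
    const Indices &extraIndices, const Indices &extraOperands,
    const Indices &oldIndices, const Indices &symbolOperands,
    const std::function<Indices(const Indices &)> &indexRemap) {
  Indices remapOperands;
  remapOperands.insert(remapOperands.end(), extraOperands.begin(),
                       extraOperands.end());
  remapOperands.insert(remapOperands.end(), oldIndices.begin(),
                       oldIndices.end());
  remapOperands.insert(remapOperands.end(), symbolOperands.begin(),
                       symbolOperands.end());

  Indices newOperands = extraIndices;
  Indices remapped = indexRemap ? indexRemap(remapOperands) : remapOperands;
  newOperands.insert(newOperands.end(), remapped.begin(), remapped.end());
  return newOperands;
}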
@@ -224,13 +223,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, return success(); } -LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, - ArrayRef extraIndices, - AffineMap indexRemap, - ArrayRef extraOperands, - ArrayRef symbolOperands, - Operation *domInstFilter, - Operation *postDomInstFilter) { +LogicalResult mlir::replaceAllMemRefUsesWith( + ValuePtr oldMemRef, ValuePtr newMemRef, ArrayRef extraIndices, + AffineMap indexRemap, ArrayRef extraOperands, + ArrayRef symbolOperands, Operation *domInstFilter, + Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -331,9 +328,9 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, void mlir::createAffineComputationSlice( Operation *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. - SmallVector subOperands; + SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); - for (auto *operand : opInst->getOperands()) + for (auto operand : opInst->getOperands()) if (isa_and_nonnull(operand->getDefiningOp())) subOperands.push_back(operand); @@ -348,7 +345,7 @@ void mlir::createAffineComputationSlice( // which case there would be nothing to do. bool localized = true; for (auto *op : affineApplyOps) { - for (auto *result : op->getResults()) { + for (auto result : op->getResults()) { for (auto *user : result->getUsers()) { if (user != opInst) { localized = false; @@ -361,7 +358,7 @@ void mlir::createAffineComputationSlice( return; OpBuilder builder(opInst); - SmallVector composedOpOperands(subOperands); + SmallVector composedOpOperands(subOperands); auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); @@ -378,7 +375,7 @@ void mlir::createAffineComputationSlice( // affine apply op above instead of existing ones (subOperands). So, they // differ from opInst's operands only for those operands in 'subOperands', for // which they will be replaced by the corresponding one from 'sliceOps'. - SmallVector newOperands(opInst->getOperands()); + SmallVector newOperands(opInst->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. unsigned j, f; @@ -451,8 +448,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { newShape[d] = ubConst.getValue() + 1; } - auto *oldMemRef = allocOp.getResult(); - SmallVector symbolOperands(allocOp.getSymbolicOperands()); + auto oldMemRef = allocOp.getResult(); + SmallVector symbolOperands(allocOp.getSymbolicOperands()); auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index e3212d54e42..d8f5b1dc0e4 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -705,7 +705,7 @@ struct VectorizationState { // Map of old scalar Operation to new vectorized Operation. DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. - DenseMap replacementMap; + DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; // Use-def roots. 
These represent the starting points for the worklist in the @@ -728,7 +728,7 @@ struct VectorizationState { OperationFolder *folder; private: - void registerReplacement(Value *key, Value *value); + void registerReplacement(ValuePtr key, ValuePtr value); }; } // end namespace @@ -768,7 +768,7 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(Value *key, Value *value) { +void VectorizationState::registerReplacement(ValuePtr key, ValuePtr value) { assert(replacementMap.count(key) == 0 && "replacement already registered"); replacementMap.insert(std::make_pair(key, value)); } @@ -776,7 +776,7 @@ void VectorizationState::registerReplacement(Value *key, Value *value) { // Apply 'map' with 'mapOperands' returning resulting values in 'results'. static void computeMemoryOpIndices(Operation *op, AffineMap map, ValueRange mapOperands, - SmallVectorImpl &results) { + SmallVectorImpl &results) { OpBuilder builder(op); for (auto resultExpr : map.getResults()) { auto singleResMap = @@ -803,7 +803,7 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, /// Such special cases force us to delay the vectorization of the stores until /// the last step. Here we merely register the store operation. template -static LogicalResult vectorizeRootOrTerminal(Value *iv, +static LogicalResult vectorizeRootOrTerminal(ValuePtr iv, LoadOrStoreOpPointer memoryOp, VectorizationState *state) { auto memRefType = memoryOp.getMemRef()->getType().template cast(); @@ -823,7 +823,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, if (auto load = dyn_cast(opInst)) { OpBuilder b(opInst); ValueRange mapOperands = load.getMapOperands(); - SmallVector indices; + SmallVector indices; indices.reserve(load.getMemRefType().getRank()); if (load.getAffineMap() != b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { @@ -838,8 +838,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create( - opInst->getLoc(), vectorType, memoryOp.getMemRef(), - map(makePtrDynCaster(), indices), + opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices, AffineMapAttr::get(permutationMap), // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 state->folder->create(b, opInst->getLoc(), @@ -951,7 +950,8 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. -static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { +static ValuePtr vectorizeConstant(Operation *op, ConstantOp constant, + Type type) { if (!type || !type.isa() || !VectorType::isValidElementType(constant.getType())) { return nullptr; @@ -989,8 +989,8 @@ static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static Value *vectorizeOperand(Value *operand, Operation *op, - VectorizationState *state) { +static ValuePtr vectorizeOperand(ValuePtr operand, Operation *op, + VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); // 1. If this value has already been vectorized this round, we are done. 
@@ -1004,7 +1004,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, // been vectorized. This would be invalid IR. auto it = state->replacementMap.find(operand); if (it != state->replacementMap.end()) { - auto *res = it->second; + auto res = it->second; LLVM_DEBUG(dbgs() << "-> delayed replacement by: "); LLVM_DEBUG(res->print(dbgs())); return res; @@ -1047,12 +1047,12 @@ static Operation *vectorizeOneOperation(Operation *opInst, if (auto store = dyn_cast(opInst)) { OpBuilder b(opInst); - auto *memRef = store.getMemRef(); - auto *value = store.getValueToStore(); - auto *vectorValue = vectorizeOperand(value, opInst, state); + auto memRef = store.getMemRef(); + auto value = store.getValueToStore(); + auto vectorValue = vectorizeOperand(value, opInst, state); ValueRange mapOperands = store.getMapOperands(); - SmallVector indices; + SmallVector indices; indices.reserve(store.getMemRefType().getRank()); if (store.getAffineMap() != b.getMultiDimIdentityMap(store.getMemRefType().getRank())) { @@ -1081,16 +1081,16 @@ static Operation *vectorizeOneOperation(Operation *opInst, return nullptr; SmallVector vectorTypes; - for (auto *v : opInst->getResults()) { + for (auto v : opInst->getResults()) { vectorTypes.push_back( VectorType::get(state->strategy->vectorSizes, v->getType())); } - SmallVector vectorOperands; - for (auto *v : opInst->getOperands()) { + SmallVector vectorOperands; + for (auto v : opInst->getOperands()) { vectorOperands.push_back(vectorizeOperand(v, opInst, state)); } // Check whether a single operand is null. If so, vectorization failed. - bool success = llvm::all_of(vectorOperands, [](Value *op) { return op; }); + bool success = llvm::all_of(vectorOperands, [](ValuePtr op) { return op; }); if (!success) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize"); return nullptr; diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 0b105eadf5a..376fc249a18 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -484,7 +484,7 @@ TEST_FUNC(select_op_i32) { IndexedValue A(f.getArgument(0)); IndexHandle i, j; AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ - // This test exercises IndexedValue::operator Value*. + // This test exercises IndexedValue::operator Value. // Without it, one must force conversion to ValueHandle as such: // edsc::intrinsics::select( // i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j))) @@ -802,7 +802,7 @@ TEST_FUNC(affine_if_op) { }; auto intSet = IntegerSet::get(2, 2, affineExprs, isEq); - SmallVector affineIfArgs = {zero, zero, ten, ten}; + SmallVector affineIfArgs = {zero, zero, ten, ten}; intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false); intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true); diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 7462db4544f..12d024f6593 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -100,7 +100,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only handle "test.return" here. 
auto returnOp = dyn_cast(op); if (!returnOp) @@ -117,7 +117,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, Type resultType, Location conversionLoc) const final { // Only allow conversion for i16/i32 types. @@ -231,7 +231,7 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser, // Create a return terminator in the inner region, pass as operand to the // terminator the returned values from the wrapped operation. - SmallVector return_operands(wrapped_op->getResults()); + SmallVector return_operands(wrapped_op->getResults()); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrapped_op->getLoc(), return_operands); @@ -297,7 +297,7 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { - for (Value *input : this->operands()) { + for (ValuePtr input : this->operands()) { results.push_back(input); } return success(); diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index e33d9c26c7f..ea071f0ddf4 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -644,7 +644,7 @@ def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> { let builders = [ OpBuilder< - "Builder *builder, OperationState &state, Value *operand", + "Builder *builder, OperationState &state, ValuePtr operand", [{ state.types.assign({builder->getIntegerType(32)}); state.addOperands({operand}); diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 94eb792cc66..1f6224dba3a 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -22,11 +22,12 @@ using namespace mlir; // Native function for testing NativeCodeCall -static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) { +static ValuePtr chooseOperand(ValuePtr input1, ValuePtr input2, + BoolAttr choice) { return choice.getValue() ? input1 : input2; } -static void createOpI(PatternRewriter &rewriter, Value *input) { +static void createOpI(PatternRewriter &rewriter, ValuePtr input) { rewriter.create(rewriter.getUnknownLoc(), input); } @@ -73,7 +74,7 @@ struct ReturnTypeOpMatch : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { if (auto retTypeFn = dyn_cast(op)) { - SmallVector values(op->getOperands()); + SmallVector values(op->getOperands()); SmallVector inferedReturnTypes; if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(), op->getRegions(), @@ -132,7 +133,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern { : ConversionPattern("test.region", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. auto &parentRegion = *op->getParentRegion(); @@ -165,7 +166,7 @@ struct TestRegionRewriteUndo : public RewritePattern { // Add an explicitly illegal operation to ensure the conversion fails. 
rewriter.create(op->getLoc(), rewriter.getIntegerType(32)); - rewriter.create(op->getLoc(), ArrayRef()); + rewriter.create(op->getLoc(), ArrayRef()); // Drop this operation. rewriter.eraseOp(op); @@ -182,7 +183,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern { : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Region ®ion = op->getRegion(0); Block *entry = ®ion.front(); @@ -208,7 +209,7 @@ struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, llvm::None, operands, llvm::None); @@ -220,7 +221,7 @@ struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) @@ -245,7 +246,7 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { TestChangeProducerTypeI32ToF32(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is I32, change the type to F32. if (!(*op->result_type_begin()).isInteger(32)) @@ -258,7 +259,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { TestChangeProducerTypeF32ToF64(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is F32, change the type to F64. if (!(*op->result_type_begin()).isF32()) @@ -271,7 +272,7 @@ struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) : ConversionPattern("test.type_producer", 10, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Always convert to B16, even though it is not a legal type. This tests // that values are unmapped correctly. @@ -283,7 +284,7 @@ struct TestUpdateConsumerType : public ConversionPattern { TestUpdateConsumerType(MLIRContext *ctx) : ConversionPattern("test.type_consumer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Verify that the incoming operand has been successfully remapped to F64. if (!operands[0]->getType().isF64()) @@ -344,7 +345,7 @@ struct TestTypeConverter : public TypeConverter { /// Override the hook to materialize a conversion. This is necessary because /// we generate 1->N type mappings. 
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc) override { return rewriter.create(loc, resultType, inputs); } @@ -467,13 +468,13 @@ struct OneVResOneVOperandOp1Converter using OpConversionPattern::OpConversionPattern; PatternMatchResult - matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto origOps = op.getOperands(); assert(std::distance(origOps.begin(), origOps.end()) == 1 && "One operand expected"); - Value *origOp = *origOps.begin(); - SmallVector remappedOperands; + ValuePtr origOp = *origOps.begin(); + SmallVector remappedOperands; // Replicate the remapped original operand twice. Note that we don't used // the remapped 'operand' since the goal is testing 'getRemappedValue'. remappedOperands.push_back(rewriter.getRemappedValue(origOp)); diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index c25fea9aa13..7f587fc3170 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -41,7 +41,7 @@ public: // SSA values for the transformation are created out of thin air by // unregistered "new_processor_id_and_range" operations. This is enough to // emulate mapping conditions. - SmallVector processorIds, numProcessors; + SmallVector processorIds, numProcessors; func.walk([&processorIds, &numProcessors](Operation *op) { if (op->getName().getStringRef() != "new_processor_id_and_range") return; diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 7efc74f2304..35df0631ca7 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -245,7 +245,7 @@ void VectorizerTestPass::testNormalizeMaps() { for (auto m : matches) { auto app = cast(m.getMatchedOperation()); OpBuilder b(m.getMatchedOperation()); - SmallVector operands(app.getOperands()); + SmallVector operands(app.getOperands()); makeComposedAffineApply(b, app.getLoc(), app.getAffineMap(), operands); } } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index fa73697dba8..004e7662299 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -216,9 +216,9 @@ def MixOperandsAndAttrs : NS_Op<"mix_operands_and_attrs", []> { } // DEF-LABEL: MixOperandsAndAttrs definitions -// DEF-DAG: Value *MixOperandsAndAttrs::operand() -// DEF-DAG: Value *MixOperandsAndAttrs::otherArg() -// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value *operand, FloatAttr otherAttr, Value *otherArg) +// DEF-DAG: ValuePtr MixOperandsAndAttrs::operand() +// DEF-DAG: ValuePtr MixOperandsAndAttrs::otherArg() +// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, ValuePtr operand, FloatAttr otherAttr, ValuePtr otherArg) // DEF-DAG: APFloat MixOperandsAndAttrs::attr() // DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr() diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index a217a139848..55952236429 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -26,7 +26,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { ); let regions = (region 
AnyRegion:$someRegion); - let builders = [OpBuilder<"Value *val">]; + let builders = [OpBuilder<"ValuePtr val">]; let parser = [{ foo }]; let printer = [{ bar }]; let verifier = [{ baz }]; @@ -46,12 +46,12 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { // CHECK: class AOpOperandAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ArrayRef values); -// CHECK: ArrayRef getODSOperands(unsigned index); -// CHECK: Value *a(); -// CHECK: ArrayRef b(); +// CHECK: AOpOperandAdaptor(ArrayRef values); +// CHECK: ArrayRef getODSOperands(unsigned index); +// CHECK: ValuePtr a(); +// CHECK: ArrayRef b(); // CHECK: private: -// CHECK: ArrayRef tblgen_operands; +// CHECK: ArrayRef tblgen_operands; // CHECK: }; // CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl @@ -60,18 +60,18 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { // CHECK: using OperandAdaptor = AOpOperandAdaptor; // CHECK: static StringRef getOperationName(); // CHECK: Operation::operand_range getODSOperands(unsigned index); -// CHECK: Value *a(); +// CHECK: ValuePtr a(); // CHECK: Operation::operand_range b(); // CHECK: Operation::result_range getODSResults(unsigned index); -// CHECK: Value *r(); +// CHECK: ValuePtr r(); // CHECK: Region &someRegion(); // CHECK: IntegerAttr attr1Attr() // CHECK: APInt attr1(); // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); -// CHECK: static void build(Value *val); -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value *a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(ValuePtr val); +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, ValuePtr a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, ValuePtr a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) // CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); @@ -111,7 +111,7 @@ def NS_DOp : NS_Op<"op_with_two_operands", []> { def NS_SkipDefaultBuildersOp : NS_Op<"skip_default_builders", []> { let skipDefaultBuilders = 1; - let builders = [OpBuilder<"Value *val">]; + let builders = [OpBuilder<"ValuePtr val">]; } // CHECK-LABEL: NS::SkipDefaultBuildersOp declarations diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 872cc474a06..c592686ebd3 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -18,7 +18,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK-NEXT: tblgen_operands = values // CHECK: void OpA::build -// CHECK: Value *input +// CHECK: ValuePtr input // CHECK: tblgen_state.addOperands(input); // CHECK: void OpA::build @@ -39,19 +39,19 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3); } -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input1 +// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input1 // CHECK-NEXT: return 
getODSOperands(0); -// CHECK-LABEL: Value *OpDOperandAdaptor::input2 +// CHECK-LABEL: ValuePtr OpDOperandAdaptor::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 +// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); // CHECK-LABEL: Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); -// CHECK-LABEL: Value *OpD::input2 +// CHECK-LABEL: ValuePtr OpD::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); // CHECK-LABEL: OpD::build diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 4ee631986cc..f9a77ea492e 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -23,9 +23,9 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { } // CHECK-LABEL: OpB definitions -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value *x) +// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, ValuePtr x) // CHECK: tblgen_state.addTypes(y); -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value *x) +// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, ValuePtr x) // CHECK: tblgen_state.addTypes({x->getType()}); def OpC : NS_Op<"three_normal_result_op", []> { @@ -89,7 +89,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> // CHECK-LABEL: Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); -// CHECK-LABEL: Value *OpI::output2 +// CHECK-LABEL: ValuePtr OpI::output2 // CHECK-NEXT: return *getODSResults(1).begin(); // CHECK-LABEL: OpI::build diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 26a5b746fb4..fef1b139dc9 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,7 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> { } // CHECK-LABEL: OpA::verify -// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: for (ValuePtr v : getODSOperands(0)) { // CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ @@ -90,5 +90,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> { } // CHECK-LABEL: OpK::verify -// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: for (ValuePtr v : getODSOperands(0)) { // CHECK: if (!(((v->getType().isa())) && (((v->getType().cast().getElementType().isF32())) || ((v->getType().cast().getElementType().isInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index dd56458ccb3..df8feb855c5 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -713,11 +713,12 @@ void OpEmitter::genAttrGetters() { // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that -// return a range of operands (individual operands are `Value *` and each -// element in the range must also be `Value *`); use `rangeBeginCall` to get an -// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain -// the number of operands. 
`getOperandCallPattern` contains the code necessary -// to obtain a single operand whose position will be substituted instead of +// return a range of operands (individual operands are `ValuePtr ` and each +// element in the range must also be `ValuePtr `); use `rangeBeginCall` to get +// an iterator to the beginning of the operand range; use `rangeSizeCall` to +// obtain the number of operands. `getOperandCallPattern` contains the code +// necessary to obtain a single operand whose position will be substituted +// instead of // "{0}" marker in the pattern. Note that the pattern should work for any kind // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. @@ -790,7 +791,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, auto &m = opClass.newMethod(rangeType, operand.name); m.body() << " return getODSOperands(" << i << ");"; } else { - auto &m = opClass.newMethod("Value *", operand.name); + auto &m = opClass.newMethod("ValuePtr ", operand.name); m.body() << " return *getODSOperands(" << i << ").begin();"; } } @@ -868,7 +869,7 @@ void OpEmitter::genNamedResultGetters() { auto &m = opClass.newMethod("Operation::result_range", result.name); m.body() << " return getODSResults(" << i << ");"; } else { - auto &m = opClass.newMethod("Value *", result.name); + auto &m = opClass.newMethod("ValuePtr ", result.name); m.body() << " return *getODSResults(" << i << ").begin();"; } } @@ -1246,7 +1247,7 @@ void OpEmitter::buildParamList(std::string ¶mList, auto argument = op.getArg(i); if (argument.is()) { const auto &operand = op.getOperand(numOperands); - paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value *"); + paramList.append(operand.isVariadic() ? ", ValueRange " : ", ValuePtr "); paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { @@ -1535,7 +1536,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body, continue; // Emit a loop to check all the dynamic values in the pack. - body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n", + body << formatv(" for (ValuePtr v : getODS{0}{1}s({2})) {{\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), staticValue.index()); @@ -1690,7 +1691,7 @@ void OpEmitter::genOpAsmInterface() { namespace { // Helper class to emit Op operand adaptors to an output stream. Operand -// adaptors are wrappers around ArrayRef that provide named operand +// adaptors are wrappers around ArrayRef that provide named operand // getters identical to those defined in the Op. 
class OpOperandAdaptorEmitter { public: @@ -1706,12 +1707,12 @@ private: OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { - adapterClass.newField("ArrayRef", "tblgen_operands"); - auto &constructor = adapterClass.newConstructor("ArrayRef values"); + adapterClass.newField("ArrayRef", "tblgen_operands"); + auto &constructor = adapterClass.newConstructor("ArrayRef values"); constructor.body() << " tblgen_operands = values;\n"; generateNamedOperandGetters(op, adapterClass, - /*rangeType=*/"ArrayRef", + /*rangeType=*/"ArrayRef", /*rangeBeginCall=*/"tblgen_operands.begin()", /*rangeSizeCall=*/"tblgen_operands.size()", /*getOperandCallPattern=*/"tblgen_operands[{0}]"); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index b2376e8739c..a74bc23a95a 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -576,14 +576,14 @@ void PatternEmitter::emitRewriteLogic() { os.indent(4) << "rewriter.eraseOp(op0);\n"; } else { // Process replacement result patterns. - os.indent(4) << "SmallVector tblgen_repl_values;\n"; + os.indent(4) << "SmallVector tblgen_repl_values;\n"; for (int i = replStartIndex; i < numResultPatterns; ++i) { DagNode resultTree = pattern.getResultPattern(i); auto val = handleResultPattern(resultTree, offsets[i], 0); os.indent(4) << "\n"; // Resolve each symbol for all range use so that we can loop over them. os << symbolInfoMap.getAllRangeUse( - val, " for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }", + val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }", "\n"); } os.indent(4) << "\n"; @@ -819,7 +819,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, int numResults = resultOp.getNumResults(); if (numResults != 0) { for (int i = 0; i < numResults; ++i) - os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{" + os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{" "tblgen_types.push_back(v->getType()); }\n", resultIndex + i); } @@ -835,8 +835,8 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( Operator &resultOp = node.getDialectOp(opMap); // Now prepare operands used for building this op: - // * If the operand is non-variadic, we create a `Value*` local variable. - // * If the operand is variadic, we create a `SmallVector` local + // * If the operand is non-variadic, we create a `Value` local variable. + // * If the operand is variadic, we create a `SmallVector` local // variable. int valueIndex = 0; // An index for uniquing local variable names. @@ -851,7 +851,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( std::string varName; if (operand->isVariadic()) { varName = formatv("tblgen_values_{0}", valueIndex++); - os.indent(6) << formatv("SmallVector {0};\n", varName); + os.indent(6) << formatv("SmallVector {0};\n", varName); std::string range; if (node.isNestedDagArg(argIndex)) { range = childNodeNames[argIndex]; @@ -861,11 +861,11 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( // Resolve the symbol for all range use so that we have a uniform way of // capturing the values. 
range = symbolInfoMap.getValueAndRangeUse(range); - os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range, + os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range, varName); } else { varName = formatv("tblgen_value_{0}", valueIndex++); - os.indent(6) << formatv("Value *{0} = ", varName); + os.indent(6) << formatv("ValuePtr {0} = ", varName); if (node.isNestedDagArg(argIndex)) { os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); } else { @@ -934,7 +934,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( Operator &resultOp = node.getDialectOp(opMap); os.indent(6) << formatv( - "SmallVector tblgen_values; (void)tblgen_values;\n"); + "SmallVector tblgen_values; (void)tblgen_values;\n"); os.indent(6) << formatv( "SmallVector tblgen_attrs; (void)tblgen_attrs;\n"); @@ -975,7 +975,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( // capturing the values. range = symbolInfoMap.getValueAndRangeUse(range); os.indent(6) << formatv( - "for (auto *v : {0}) tblgen_values.push_back(v);\n", range); + "for (auto v : {0}) tblgen_values.push_back(v);\n", range); } else { os.indent(6) << formatv("tblgen_values.push_back(", varName); if (node.isNestedDagArg(argIndex)) { diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index f1712efb319..6d5bcc116ad 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -470,7 +470,7 @@ static void emitDeserializationFunction(const Record *attrClass, emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex, resultTypes, valueID, os); - os << formatv(" SmallVector {0};\n", operands); + os << formatv(" SmallVector {0};\n", operands); os << formatv(" SmallVector {0};\n", attributes); // Operand deserialization emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 80f82ac3e5d..d7dae4648fe 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -25,7 +25,7 @@ using namespace mlir::detail; namespace { Operation *createOp(MLIRContext *context, bool resizableOperands, - ArrayRef operands = llvm::None, + ArrayRef operands = llvm::None, ArrayRef resultTypes = llvm::None) { return Operation::create( UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes, @@ -39,7 +39,7 @@ TEST(OperandStorageTest, NonResizable) { Operation *useOp = createOp(&context, /*resizableOperands=*/false, /*operands=*/llvm::None, builder.getIntegerType(16)); - Value *operand = useOp->getResult(0); + ValuePtr operand = useOp->getResult(0); // Create a non-resizable operation with one operand. Operation *user = createOp(&context, /*resizableOperands=*/false, operand, @@ -68,7 +68,7 @@ TEST(OperandStorageDeathTest, AddToNonResizable) { Operation *useOp = createOp(&context, /*resizableOperands=*/false, /*operands=*/llvm::None, builder.getIntegerType(16)); - Value *operand = useOp->getResult(0); + ValuePtr operand = useOp->getResult(0); // Create a non-resizable operation with one operand. 
Operation *user = createOp(&context, /*resizableOperands=*/false, operand, @@ -88,7 +88,7 @@ TEST(OperandStorageTest, Resizable) { Operation *useOp = createOp(&context, /*resizableOperands=*/false, /*operands=*/llvm::None, builder.getIntegerType(16)); - Value *operand = useOp->getResult(0); + ValuePtr operand = useOp->getResult(0); // Create a resizable operation with one operand. Operation *user = createOp(&context, /*resizableOperands=*/true, operand, -- cgit v1.2.3 From 56222a0694e4caf35e892d70591417c39fef1185 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 23 Dec 2019 09:35:36 -0800 Subject: Adjust License.txt file to use the LLVM license PiperOrigin-RevId: 286906740 --- mlir/LICENSE.TXT | 128 ++++++++++++++++----- mlir/bindings/python/pybind.cpp | 17 +-- mlir/examples/toy/Ch1/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch1/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch1/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch1/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch1/toyc.cpp | 17 +-- mlir/examples/toy/Ch2/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch2/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch2/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch2/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch2/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch2/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 17 +-- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch2/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch2/toyc.cpp | 17 +-- mlir/examples/toy/Ch3/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch3/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch3/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch3/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch3/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch3/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 17 +-- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch3/mlir/ToyCombine.cpp | 17 +-- mlir/examples/toy/Ch3/mlir/ToyCombine.td | 17 +-- mlir/examples/toy/Ch3/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch3/toyc.cpp | 17 +-- mlir/examples/toy/Ch4/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch4/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch4/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch4/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch4/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch4/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch4/include/toy/Passes.h | 17 +-- .../toy/Ch4/include/toy/ShapeInferenceInterface.h | 17 +-- .../toy/Ch4/include/toy/ShapeInferenceInterface.td | 17 +-- .../toy/Ch4/mlir/DeadFunctionEliminationPass.cpp | 17 +-- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 17 +-- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp | 17 +-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 17 +-- mlir/examples/toy/Ch4/mlir/ToyCombine.td | 17 +-- mlir/examples/toy/Ch4/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch4/toyc.cpp | 17 +-- mlir/examples/toy/Ch5/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch5/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch5/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch5/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch5/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch5/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch5/include/toy/Passes.h | 17 +-- .../toy/Ch5/include/toy/ShapeInferenceInterface.h | 17 +-- .../toy/Ch5/include/toy/ShapeInferenceInterface.td | 17 +-- .../toy/Ch5/mlir/DeadFunctionEliminationPass.cpp | 17 +-- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 17 
+-- mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp | 17 +-- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 17 +-- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 17 +-- mlir/examples/toy/Ch5/mlir/ToyCombine.td | 17 +-- mlir/examples/toy/Ch5/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch5/toyc.cpp | 17 +-- mlir/examples/toy/Ch6/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch6/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch6/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch6/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch6/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch6/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch6/include/toy/Passes.h | 17 +-- .../toy/Ch6/include/toy/ShapeInferenceInterface.h | 17 +-- .../toy/Ch6/include/toy/ShapeInferenceInterface.td | 17 +-- .../toy/Ch6/mlir/DeadFunctionEliminationPass.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/ToyCombine.cpp | 17 +-- mlir/examples/toy/Ch6/mlir/ToyCombine.td | 17 +-- mlir/examples/toy/Ch6/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch6/toyc.cpp | 17 +-- mlir/examples/toy/Ch7/include/toy/AST.h | 17 +-- mlir/examples/toy/Ch7/include/toy/Dialect.h | 17 +-- mlir/examples/toy/Ch7/include/toy/Lexer.h | 17 +-- mlir/examples/toy/Ch7/include/toy/MLIRGen.h | 17 +-- mlir/examples/toy/Ch7/include/toy/Ops.td | 17 +-- mlir/examples/toy/Ch7/include/toy/Parser.h | 17 +-- mlir/examples/toy/Ch7/include/toy/Passes.h | 17 +-- .../toy/Ch7/include/toy/ShapeInferenceInterface.h | 17 +-- .../toy/Ch7/include/toy/ShapeInferenceInterface.td | 17 +-- .../toy/Ch7/mlir/DeadFunctionEliminationPass.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | 17 +-- mlir/examples/toy/Ch7/mlir/ToyCombine.td | 17 +-- mlir/examples/toy/Ch7/parser/AST.cpp | 17 +-- mlir/examples/toy/Ch7/toyc.cpp | 17 +-- mlir/include/mlir-c/Core.h | 17 +-- mlir/include/mlir/ADT/TypeSwitch.h | 17 +-- mlir/include/mlir/Analysis/AffineAnalysis.h | 17 +-- mlir/include/mlir/Analysis/AffineStructures.h | 17 +-- mlir/include/mlir/Analysis/CallGraph.h | 17 +-- mlir/include/mlir/Analysis/CallInterfaces.h | 17 +-- mlir/include/mlir/Analysis/CallInterfaces.td | 17 +-- mlir/include/mlir/Analysis/Dominance.h | 17 +-- mlir/include/mlir/Analysis/InferTypeOpInterface.h | 17 +-- mlir/include/mlir/Analysis/InferTypeOpInterface.td | 17 +-- mlir/include/mlir/Analysis/Liveness.h | 17 +-- mlir/include/mlir/Analysis/LoopAnalysis.h | 17 +-- mlir/include/mlir/Analysis/NestedMatcher.h | 17 +-- mlir/include/mlir/Analysis/Passes.h | 17 +-- mlir/include/mlir/Analysis/SliceAnalysis.h | 17 +-- mlir/include/mlir/Analysis/Utils.h | 17 +-- mlir/include/mlir/Analysis/Verifier.h | 17 +-- .../Conversion/AffineToStandard/AffineToStandard.h | 17 +-- .../mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h | 17 +-- .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 17 +-- .../mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h | 17 +-- .../mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h | 17 +-- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h | 17 +-- 
.../mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h | 17 +-- .../LoopToStandard/ConvertLoopToStandard.h | 17 +-- .../mlir/Conversion/LoopsToGPU/LoopsToGPU.h | 17 +-- .../mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h | 17 +-- .../StandardToLLVM/ConvertStandardToLLVM.h | 17 +-- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 17 +-- .../StandardToSPIRV/ConvertStandardToSPIRV.h | 17 +-- .../StandardToSPIRV/ConvertStandardToSPIRVPass.h | 17 +-- .../Conversion/VectorToLLVM/ConvertVectorToLLVM.h | 17 +-- .../VectorToLoops/ConvertVectorToLoops.h | 17 +-- mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 17 +-- mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 17 +-- .../mlir/Dialect/AffineOps/AffineOpsBase.td | 17 +-- mlir/include/mlir/Dialect/CommonFolders.h | 17 +-- mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h | 17 +-- mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td | 17 +-- mlir/include/mlir/Dialect/FxpMathOps/Passes.h | 17 +-- mlir/include/mlir/Dialect/GPU/GPUDialect.h | 17 +-- mlir/include/mlir/Dialect/GPU/GPUOps.td | 17 +-- mlir/include/mlir/Dialect/GPU/Passes.h | 17 +-- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 17 +-- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 17 +-- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 17 +-- mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h | 17 +-- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 17 +-- mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h | 17 +-- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 +-- .../Dialect/Linalg/Analysis/DependenceAnalysis.h | 17 +-- mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 17 +-- mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td | 17 +-- .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 17 +-- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 17 +-- mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h | 17 +-- mlir/include/mlir/Dialect/Linalg/Passes.h | 17 +-- .../Linalg/Transforms/LinalgTransformPatterns.td | 17 +-- .../Dialect/Linalg/Transforms/LinalgTransforms.h | 17 +-- .../include/mlir/Dialect/Linalg/Utils/Intrinsics.h | 17 +-- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 17 +-- mlir/include/mlir/Dialect/LoopOps/LoopOps.h | 17 +-- mlir/include/mlir/Dialect/LoopOps/LoopOps.td | 17 +-- .../mlir/Dialect/QuantOps/FakeQuantSupport.h | 17 +-- mlir/include/mlir/Dialect/QuantOps/Passes.h | 17 +-- mlir/include/mlir/Dialect/QuantOps/QuantOps.h | 17 +-- mlir/include/mlir/Dialect/QuantOps/QuantOps.td | 17 +-- .../mlir/Dialect/QuantOps/QuantPredicates.td | 17 +-- mlir/include/mlir/Dialect/QuantOps/QuantTypes.h | 17 +-- mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h | 17 +-- .../include/mlir/Dialect/QuantOps/UniformSupport.h | 17 +-- mlir/include/mlir/Dialect/SDBM/SDBM.h | 17 +-- mlir/include/mlir/Dialect/SDBM/SDBMDialect.h | 17 +-- mlir/include/mlir/Dialect/SDBM/SDBMExpr.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/Passes.h | 17 +-- .../mlir/Dialect/SPIRV/SPIRVArithmeticOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td | 17 +-- 
.../mlir/Dialect/SPIRV/SPIRVCompositeOps.td | 17 +-- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td | 17 +-- .../mlir/Dialect/SPIRV/SPIRVNonUniformOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 17 +-- .../mlir/Dialect/SPIRV/SPIRVStructureOps.td | 17 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h | 17 +-- mlir/include/mlir/Dialect/SPIRV/Serialization.h | 17 +-- mlir/include/mlir/Dialect/StandardOps/Ops.h | 17 +-- mlir/include/mlir/Dialect/StandardOps/Ops.td | 17 +-- mlir/include/mlir/Dialect/Traits.h | 17 +-- .../mlir/Dialect/Utils/StructuredOpsUtils.h | 17 +-- mlir/include/mlir/Dialect/VectorOps/Utils.h | 17 +-- mlir/include/mlir/Dialect/VectorOps/VectorOps.h | 17 +-- mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 17 +-- .../Dialect/VectorOps/VectorTransformPatterns.td | 17 +-- .../mlir/Dialect/VectorOps/VectorTransforms.h | 17 +-- mlir/include/mlir/EDSC/Builders.h | 17 +-- mlir/include/mlir/EDSC/Helpers.h | 17 +-- mlir/include/mlir/EDSC/Intrinsics.h | 17 +-- .../include/mlir/ExecutionEngine/ExecutionEngine.h | 17 +-- mlir/include/mlir/ExecutionEngine/OptUtils.h | 17 +-- mlir/include/mlir/IR/AffineExpr.h | 17 +-- mlir/include/mlir/IR/AffineExprVisitor.h | 17 +-- mlir/include/mlir/IR/AffineMap.h | 17 +-- mlir/include/mlir/IR/AttributeSupport.h | 17 +-- mlir/include/mlir/IR/Attributes.h | 17 +-- mlir/include/mlir/IR/Block.h | 17 +-- mlir/include/mlir/IR/BlockAndValueMapping.h | 17 +-- mlir/include/mlir/IR/BlockSupport.h | 17 +-- mlir/include/mlir/IR/Builders.h | 17 +-- mlir/include/mlir/IR/Diagnostics.h | 17 +-- mlir/include/mlir/IR/Dialect.h | 17 +-- mlir/include/mlir/IR/DialectHooks.h | 17 +-- mlir/include/mlir/IR/DialectImplementation.h | 17 +-- mlir/include/mlir/IR/DialectInterface.h | 17 +-- mlir/include/mlir/IR/DialectSymbolRegistry.def | 17 +-- mlir/include/mlir/IR/Function.h | 17 +-- mlir/include/mlir/IR/FunctionImplementation.h | 17 +-- mlir/include/mlir/IR/FunctionSupport.h | 17 +-- mlir/include/mlir/IR/Identifier.h | 17 +-- mlir/include/mlir/IR/IntegerSet.h | 17 +-- mlir/include/mlir/IR/Location.h | 17 +-- mlir/include/mlir/IR/MLIRContext.h | 17 +-- mlir/include/mlir/IR/Matchers.h | 17 +-- mlir/include/mlir/IR/Module.h | 17 +-- mlir/include/mlir/IR/OpAsmInterface.td | 17 +-- mlir/include/mlir/IR/OpBase.td | 17 +-- mlir/include/mlir/IR/OpDefinition.h | 17 +-- mlir/include/mlir/IR/OpImplementation.h | 17 +-- mlir/include/mlir/IR/Operation.h | 17 +-- mlir/include/mlir/IR/OperationSupport.h | 17 +-- mlir/include/mlir/IR/PatternMatch.h | 17 +-- mlir/include/mlir/IR/Region.h | 17 +-- mlir/include/mlir/IR/RegionGraphTraits.h | 17 +-- mlir/include/mlir/IR/StandardTypes.h | 17 +-- mlir/include/mlir/IR/StorageUniquerSupport.h | 17 +-- mlir/include/mlir/IR/SymbolTable.h | 17 +-- mlir/include/mlir/IR/TypeSupport.h | 17 +-- mlir/include/mlir/IR/TypeUtilities.h | 17 +-- mlir/include/mlir/IR/Types.h | 17 +-- mlir/include/mlir/IR/UseDefLists.h | 17 +-- mlir/include/mlir/IR/Value.h | 17 +-- mlir/include/mlir/IR/Visitors.h | 17 +-- mlir/include/mlir/Parser.h | 17 +-- mlir/include/mlir/Pass/AnalysisManager.h | 17 +-- mlir/include/mlir/Pass/Pass.h | 17 +-- 
mlir/include/mlir/Pass/PassInstrumentation.h | 17 +-- mlir/include/mlir/Pass/PassManager.h | 17 +-- mlir/include/mlir/Pass/PassOptions.h | 17 +-- mlir/include/mlir/Pass/PassRegistry.h | 17 +-- .../mlir/Quantizer/Configurations/FxpMathConfig.h | 17 +-- .../include/mlir/Quantizer/Support/Configuration.h | 17 +-- .../Quantizer/Support/ConstraintAnalysisGraph.h | 17 +-- .../Support/ConstraintAnalysisGraphTraits.h | 17 +-- mlir/include/mlir/Quantizer/Support/Metadata.h | 17 +-- mlir/include/mlir/Quantizer/Support/Rules.h | 17 +-- mlir/include/mlir/Quantizer/Support/Statistics.h | 17 +-- mlir/include/mlir/Quantizer/Support/TypeUtils.h | 17 +-- .../mlir/Quantizer/Support/UniformConstraints.h | 17 +-- .../mlir/Quantizer/Support/UniformSolvers.h | 17 +-- mlir/include/mlir/Quantizer/Transforms/Passes.h | 17 +-- mlir/include/mlir/Support/DebugStringHelper.h | 17 +-- mlir/include/mlir/Support/FileUtilities.h | 17 +-- mlir/include/mlir/Support/Functional.h | 17 +-- mlir/include/mlir/Support/JitRunner.h | 17 +-- mlir/include/mlir/Support/LLVM.h | 17 +-- mlir/include/mlir/Support/LogicalResult.h | 17 +-- mlir/include/mlir/Support/MathExtras.h | 17 +-- mlir/include/mlir/Support/MlirOptMain.h | 17 +-- mlir/include/mlir/Support/STLExtras.h | 17 +-- mlir/include/mlir/Support/StorageUniquer.h | 17 +-- mlir/include/mlir/Support/StringExtras.h | 17 +-- mlir/include/mlir/Support/ToolUtilities.h | 17 +-- mlir/include/mlir/Support/TranslateClParser.h | 17 +-- mlir/include/mlir/TableGen/Argument.h | 17 +-- mlir/include/mlir/TableGen/Attribute.h | 17 +-- mlir/include/mlir/TableGen/Constraint.h | 17 +-- mlir/include/mlir/TableGen/Dialect.h | 17 +-- mlir/include/mlir/TableGen/Format.h | 17 +-- mlir/include/mlir/TableGen/GenInfo.h | 17 +-- mlir/include/mlir/TableGen/GenNameParser.h | 17 +-- mlir/include/mlir/TableGen/OpInterfaces.h | 17 +-- mlir/include/mlir/TableGen/OpTrait.h | 17 +-- mlir/include/mlir/TableGen/Operator.h | 17 +-- mlir/include/mlir/TableGen/Pattern.h | 17 +-- mlir/include/mlir/TableGen/Predicate.h | 17 +-- mlir/include/mlir/TableGen/Region.h | 17 +-- mlir/include/mlir/TableGen/Type.h | 17 +-- mlir/include/mlir/Target/LLVMIR.h | 17 +-- .../include/mlir/Target/LLVMIR/ModuleTranslation.h | 17 +-- mlir/include/mlir/Target/NVVMIR.h | 17 +-- mlir/include/mlir/Target/ROCDLIR.h | 17 +-- mlir/include/mlir/Transforms/DialectConversion.h | 17 +-- mlir/include/mlir/Transforms/FoldUtils.h | 17 +-- mlir/include/mlir/Transforms/InliningUtils.h | 17 +-- mlir/include/mlir/Transforms/LoopFusionUtils.h | 17 +-- mlir/include/mlir/Transforms/LoopLikeInterface.h | 17 +-- mlir/include/mlir/Transforms/LoopLikeInterface.td | 17 +-- mlir/include/mlir/Transforms/LoopUtils.h | 17 +-- mlir/include/mlir/Transforms/Passes.h | 17 +-- mlir/include/mlir/Transforms/RegionUtils.h | 17 +-- .../include/mlir/Transforms/SideEffectsInterface.h | 17 +-- mlir/include/mlir/Transforms/Utils.h | 17 +-- mlir/include/mlir/Transforms/ViewOpGraph.h | 17 +-- mlir/include/mlir/Transforms/ViewRegionGraph.h | 17 +-- mlir/include/mlir/Translation.h | 17 +-- mlir/lib/Analysis/AffineAnalysis.cpp | 17 +-- mlir/lib/Analysis/AffineStructures.cpp | 17 +-- mlir/lib/Analysis/CallGraph.cpp | 17 +-- mlir/lib/Analysis/Dominance.cpp | 17 +-- mlir/lib/Analysis/InferTypeOpInterface.cpp | 17 +-- mlir/lib/Analysis/Liveness.cpp | 17 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 17 +-- mlir/lib/Analysis/MemRefBoundCheck.cpp | 17 +-- mlir/lib/Analysis/NestedMatcher.cpp | 17 +-- mlir/lib/Analysis/OpStats.cpp | 17 +-- mlir/lib/Analysis/SliceAnalysis.cpp | 17 +-- 
mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 17 +-- mlir/lib/Analysis/TestParallelismDetection.cpp | 17 +-- mlir/lib/Analysis/Utils.cpp | 17 +-- mlir/lib/Analysis/VectorAnalysis.cpp | 17 +-- mlir/lib/Analysis/Verifier.cpp | 17 +-- .../AffineToStandard/AffineToStandard.cpp | 17 +-- .../GPUCommon/IndexIntrinsicsOpLowering.h | 17 +-- .../Conversion/GPUCommon/OpToFuncCallLowering.h | 17 +-- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 17 +-- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 17 +-- mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td | 17 +-- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 17 +-- .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 17 +-- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 17 +-- .../GPUToSPIRV/ConvertGPUToSPIRVPass.cpp | 17 +-- mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 17 +-- .../LoopToStandard/ConvertLoopToStandard.cpp | 17 +-- mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 17 +-- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 17 +-- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 17 +-- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 17 +-- .../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 17 +-- .../StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 17 +-- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 17 +-- .../VectorToLoops/ConvertVectorToLoops.cpp | 17 +-- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 17 +-- mlir/lib/Dialect/AffineOps/DialectRegistration.cpp | 17 +-- .../Dialect/FxpMathOps/IR/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp | 17 +-- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 17 +-- .../FxpMathOps/Transforms/UniformKernelUtils.h | 17 +-- mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 17 +-- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 17 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 +-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 17 +-- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 17 +-- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 17 +-- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 17 +-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 +-- mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp | 17 +-- mlir/lib/Dialect/Linalg/LinalgRegistration.cpp | 17 +-- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 17 +-- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 17 +-- .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 17 +-- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 17 +-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 17 +-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 17 +-- mlir/lib/Dialect/LoopOps/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/LoopOps/LoopOps.cpp | 17 +-- .../Dialect/QuantOps/IR/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 17 +-- mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp | 17 +-- mlir/lib/Dialect/QuantOps/IR/TypeDetail.h | 17 +-- mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp | 17 +-- .../Dialect/QuantOps/Transforms/ConvertConst.cpp | 17 +-- .../QuantOps/Transforms/ConvertSimQuant.cpp | 17 +-- .../Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 17 +-- mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 17 +-- mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp | 17 +-- mlir/lib/Dialect/SDBM/SDBM.cpp | 17 +-- mlir/lib/Dialect/SDBM/SDBMDialect.cpp | 17 +-- mlir/lib/Dialect/SDBM/SDBMExpr.cpp | 17 +-- mlir/lib/Dialect/SDBM/SDBMExprDetail.h | 17 +-- mlir/lib/Dialect/SPIRV/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/SPIRV/LayoutUtils.cpp | 17 +-- 
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 17 +-- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 17 +-- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp | 17 +-- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 17 +-- .../SPIRV/Serialization/SPIRVBinaryUtils.cpp | 17 +-- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 17 +-- .../SPIRV/Serialization/TranslateRegistration.cpp | 17 +-- .../DecorateSPIRVCompositeTypeLayoutPass.cpp | 17 +-- .../SPIRV/Transforms/LowerABIAttributesPass.cpp | 17 +-- .../Dialect/StandardOps/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/StandardOps/Ops.cpp | 17 +-- mlir/lib/Dialect/Traits.cpp | 17 +-- mlir/lib/Dialect/VectorOps/DialectRegistration.cpp | 17 +-- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 17 +-- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 17 +-- mlir/lib/EDSC/Builders.cpp | 17 +-- mlir/lib/EDSC/CoreAPIs.cpp | 17 +-- mlir/lib/EDSC/Helpers.cpp | 17 +-- mlir/lib/EDSC/Intrinsics.cpp | 17 +-- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 17 +-- mlir/lib/ExecutionEngine/OptUtils.cpp | 17 +-- mlir/lib/IR/AffineExpr.cpp | 17 +-- mlir/lib/IR/AffineExprDetail.h | 17 +-- mlir/lib/IR/AffineMap.cpp | 17 +-- mlir/lib/IR/AffineMapDetail.h | 17 +-- mlir/lib/IR/AsmPrinter.cpp | 17 +-- mlir/lib/IR/AttributeDetail.h | 17 +-- mlir/lib/IR/Attributes.cpp | 17 +-- mlir/lib/IR/Block.cpp | 17 +-- mlir/lib/IR/Builders.cpp | 17 +-- mlir/lib/IR/Diagnostics.cpp | 17 +-- mlir/lib/IR/Dialect.cpp | 17 +-- mlir/lib/IR/Function.cpp | 17 +-- mlir/lib/IR/FunctionImplementation.cpp | 17 +-- mlir/lib/IR/IntegerSet.cpp | 17 +-- mlir/lib/IR/IntegerSetDetail.h | 17 +-- mlir/lib/IR/Location.cpp | 17 +-- mlir/lib/IR/LocationDetail.h | 17 +-- mlir/lib/IR/MLIRContext.cpp | 17 +-- mlir/lib/IR/Module.cpp | 17 +-- mlir/lib/IR/Operation.cpp | 17 +-- mlir/lib/IR/OperationSupport.cpp | 17 +-- mlir/lib/IR/PatternMatch.cpp | 17 +-- mlir/lib/IR/Region.cpp | 17 +-- mlir/lib/IR/StandardTypes.cpp | 17 +-- mlir/lib/IR/SymbolTable.cpp | 17 +-- mlir/lib/IR/TypeDetail.h | 17 +-- mlir/lib/IR/TypeUtilities.cpp | 17 +-- mlir/lib/IR/Types.cpp | 17 +-- mlir/lib/IR/Value.cpp | 17 +-- mlir/lib/IR/Visitors.cpp | 17 +-- mlir/lib/Parser/Lexer.cpp | 17 +-- mlir/lib/Parser/Lexer.h | 17 +-- mlir/lib/Parser/Parser.cpp | 17 +-- mlir/lib/Parser/Token.cpp | 17 +-- mlir/lib/Parser/Token.h | 17 +-- mlir/lib/Parser/TokenKinds.def | 17 +-- mlir/lib/Pass/IRPrinting.cpp | 17 +-- mlir/lib/Pass/Pass.cpp | 17 +-- mlir/lib/Pass/PassDetail.h | 17 +-- mlir/lib/Pass/PassManagerOptions.cpp | 17 +-- mlir/lib/Pass/PassRegistry.cpp | 17 +-- mlir/lib/Pass/PassStatistics.cpp | 17 +-- mlir/lib/Pass/PassTiming.cpp | 17 +-- .../lib/Quantizer/Configurations/FxpMathConfig.cpp | 17 +-- mlir/lib/Quantizer/Support/Configuration.cpp | 17 +-- .../Quantizer/Support/ConstraintAnalysisGraph.cpp | 17 +-- mlir/lib/Quantizer/Support/Metadata.cpp | 17 +-- mlir/lib/Quantizer/Support/Statistics.cpp | 17 +-- mlir/lib/Quantizer/Support/TypeUtils.cpp | 17 +-- mlir/lib/Quantizer/Support/UniformConstraints.cpp | 17 +-- mlir/lib/Quantizer/Support/UniformSolvers.cpp | 17 +-- .../Transforms/AddDefaultStatsTestPass.cpp | 17 +-- .../Transforms/InferQuantizedTypesPass.cpp | 17 +-- .../Transforms/RemoveInstrumentationPass.cpp | 17 +-- mlir/lib/Support/FileUtilities.cpp | 17 +-- mlir/lib/Support/JitRunner.cpp | 17 +-- mlir/lib/Support/MlirOptMain.cpp | 17 +-- mlir/lib/Support/StorageUniquer.cpp | 17 +-- mlir/lib/Support/ToolUtilities.cpp | 17 +-- mlir/lib/Support/TranslateClParser.cpp | 17 +-- mlir/lib/TableGen/Argument.cpp | 17 +-- mlir/lib/TableGen/Attribute.cpp | 17 +-- 
mlir/lib/TableGen/Constraint.cpp | 17 +-- mlir/lib/TableGen/Dialect.cpp | 17 +-- mlir/lib/TableGen/Format.cpp | 17 +-- mlir/lib/TableGen/OpInterfaces.cpp | 17 +-- mlir/lib/TableGen/OpTrait.cpp | 17 +-- mlir/lib/TableGen/Operator.cpp | 17 +-- mlir/lib/TableGen/Pattern.cpp | 17 +-- mlir/lib/TableGen/Predicate.cpp | 17 +-- mlir/lib/TableGen/Type.cpp | 17 +-- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 17 +-- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 17 +-- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 17 +-- mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp | 17 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 17 +-- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 17 +-- .../Transforms/AffineLoopInvariantCodeMotion.cpp | 17 +-- mlir/lib/Transforms/CSE.cpp | 17 +-- mlir/lib/Transforms/Canonicalizer.cpp | 17 +-- mlir/lib/Transforms/DialectConversion.cpp | 17 +-- mlir/lib/Transforms/Inliner.cpp | 17 +-- mlir/lib/Transforms/LoopCoalescing.cpp | 17 +-- mlir/lib/Transforms/LoopFusion.cpp | 17 +-- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 17 +-- mlir/lib/Transforms/LoopTiling.cpp | 17 +-- mlir/lib/Transforms/LoopUnroll.cpp | 17 +-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 17 +-- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 17 +-- mlir/lib/Transforms/PipelineDataTransfer.cpp | 17 +-- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 17 +-- mlir/lib/Transforms/StripDebugInfo.cpp | 17 +-- mlir/lib/Transforms/Utils/FoldUtils.cpp | 17 +-- .../Utils/GreedyPatternRewriteDriver.cpp | 17 +-- mlir/lib/Transforms/Utils/InliningUtils.cpp | 17 +-- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 17 +-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 17 +-- mlir/lib/Transforms/Utils/RegionUtils.cpp | 17 +-- mlir/lib/Transforms/Utils/Utils.cpp | 17 +-- mlir/lib/Transforms/Vectorize.cpp | 17 +-- mlir/lib/Transforms/ViewOpGraph.cpp | 17 +-- mlir/lib/Transforms/ViewRegionGraph.cpp | 17 +-- mlir/lib/Translation/Translation.cpp | 17 +-- mlir/test/APITest.h | 17 +-- mlir/test/EDSC/builder-api-test.cpp | 17 +-- mlir/test/SDBM/sdbm-api-test.cpp | 17 +-- .../TestLinalgTransformPatterns.td | 17 +-- .../TestVectorTransformPatterns.td | 17 +-- mlir/test/lib/IR/TestFunc.cpp | 17 +-- mlir/test/lib/IR/TestMatchers.cpp | 17 +-- mlir/test/lib/IR/TestSymbolUses.cpp | 17 +-- mlir/test/lib/Pass/TestPassManager.cpp | 17 +-- mlir/test/lib/TestDialect/TestDialect.cpp | 17 +-- mlir/test/lib/TestDialect/TestDialect.h | 17 +-- mlir/test/lib/TestDialect/TestOps.td | 17 +-- mlir/test/lib/TestDialect/TestPatterns.cpp | 17 +-- mlir/test/lib/Transforms/TestCallGraph.cpp | 17 +-- mlir/test/lib/Transforms/TestConstantFold.cpp | 17 +-- mlir/test/lib/Transforms/TestInlining.cpp | 17 +-- mlir/test/lib/Transforms/TestLinalgTransforms.cpp | 17 +-- mlir/test/lib/Transforms/TestLiveness.cpp | 17 +-- mlir/test/lib/Transforms/TestLoopFusion.cpp | 17 +-- mlir/test/lib/Transforms/TestLoopMapping.cpp | 17 +-- .../lib/Transforms/TestLoopParametricTiling.cpp | 17 +-- .../lib/Transforms/TestMemRefStrideCalculation.cpp | 17 +-- mlir/test/lib/Transforms/TestOpaqueLoc.cpp | 17 +-- .../lib/Transforms/TestVectorToLoopsConversion.cpp | 17 +-- mlir/test/lib/Transforms/TestVectorTransforms.cpp | 17 +-- .../test/lib/Transforms/TestVectorizationUtils.cpp | 17 +-- mlir/test/mlir-cpu-runner/cblas.cpp | 17 +-- mlir/test/mlir-cpu-runner/cblas_interface.cpp | 17 +-- mlir/test/mlir-cpu-runner/include/cblas.h | 17 +-- .../mlir-cpu-runner/include/mlir_runner_utils.h | 17 +-- mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp | 17 +-- 
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp | 17 +-- .../mlir-cuda-runner/cuda-runtime-wrappers.cpp | 17 +-- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 17 +-- mlir/tools/mlir-opt/mlir-opt.cpp | 17 +-- mlir/tools/mlir-tblgen/DocGenUtilities.h | 17 +-- mlir/tools/mlir-tblgen/EnumsGen.cpp | 17 +-- mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp | 17 +-- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 17 +-- mlir/tools/mlir-tblgen/OpDocGen.cpp | 17 +-- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 17 +-- mlir/tools/mlir-tblgen/ReferenceImplGen.cpp | 17 +-- mlir/tools/mlir-tblgen/RewriterGen.cpp | 17 +-- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 17 +-- mlir/tools/mlir-tblgen/StructsGen.cpp | 17 +-- mlir/tools/mlir-tblgen/mlir-tblgen.cpp | 17 +-- mlir/tools/mlir-translate/mlir-translate.cpp | 17 +-- mlir/unittests/ADT/TypeSwitchTest.cpp | 17 +-- mlir/unittests/Dialect/BroadcastShapeTest.cpp | 17 +-- .../Dialect/QuantOps/QuantizationUtilsTest.cpp | 17 +-- .../Dialect/SPIRV/DeserializationTest.cpp | 17 +-- mlir/unittests/Dialect/SPIRV/SerializationTest.cpp | 17 +-- mlir/unittests/IR/AttributeTest.cpp | 17 +-- mlir/unittests/IR/DialectTest.cpp | 17 +-- mlir/unittests/IR/OperationSupportTest.cpp | 17 +-- mlir/unittests/IR/StringExtrasTest.cpp | 17 +-- mlir/unittests/Pass/AnalysisManagerTest.cpp | 17 +-- mlir/unittests/Quantizer/Support/RulesTest.cpp | 17 +-- .../Quantizer/Support/UniformSolversTest.cpp | 17 +-- mlir/unittests/SDBM/SDBMTest.cpp | 17 +-- mlir/unittests/TableGen/EnumsGenTest.cpp | 17 +-- mlir/unittests/TableGen/FormatTest.cpp | 17 +-- mlir/unittests/TableGen/StructsGenTest.cpp | 17 +-- mlir/unittests/TableGen/enums.td | 17 +-- mlir/unittests/TableGen/structs.td | 17 +-- mlir/utils/generate-test-checks.py | 16 +-- mlir/utils/spirv/define_enum.sh | 16 +-- mlir/utils/spirv/define_inst.sh | 16 +-- mlir/utils/spirv/define_opcodes.sh | 16 +-- mlir/utils/spirv/gen_spirv_dialect.py | 16 +-- 593 files changed, 2464 insertions(+), 7723 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/LICENSE.TXT b/mlir/LICENSE.TXT index a4b160b6e33..fa6ac540007 100644 --- a/mlir/LICENSE.TXT +++ b/mlir/LICENSE.TXT @@ -1,12 +1,14 @@ -Copyright 2019 The MLIR Authors. +============================================================================== +The LLVM Project is under the Apache License v2.0 with LLVM Exceptions: +============================================================================== Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - 1. Definitions. + 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. @@ -65,14 +67,14 @@ Copyright 2019 The MLIR Authors. on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - 2. Grant of Copyright License. Subject to the terms and conditions of + 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. - 3. Grant of Patent License. Subject to the terms and conditions of + 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, @@ -88,7 +90,7 @@ Copyright 2019 The MLIR Authors. granted to You under this License for that Work shall terminate as of the date such litigation is filed. - 4. Redistribution. You may reproduce and distribute copies of the + 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: @@ -129,7 +131,7 @@ Copyright 2019 The MLIR Authors. reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. - 5. Submission of Contributions. Unless You explicitly state otherwise, + 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. @@ -137,12 +139,12 @@ Copyright 2019 The MLIR Authors. the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. - 6. Trademarks. This License does not grant permission to use the trade + 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. - 7. Disclaimer of Warranty. Unless required by applicable law or + 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or @@ -152,7 +154,7 @@ Copyright 2019 The MLIR Authors. appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. - 8. Limitation of Liability. In no event and under no legal theory, + 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be @@ -164,7 +166,7 @@ Copyright 2019 The MLIR Authors. other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. - 9. Accepting Warranty or Additional Liability. While redistributing + 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this @@ -175,9 +177,9 @@ Copyright 2019 The MLIR Authors. incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - END OF TERMS AND CONDITIONS + END OF TERMS AND CONDITIONS - APPENDIX: How to apply the Apache License to your work. + APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" @@ -188,18 +190,90 @@ Copyright 2019 The MLIR Authors. same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + +============================================================================== +Software from third parties included in the LLVM Project: +============================================================================== +The LLVM Project contains third party software which is under different license +terms. All such code will be identified clearly using at least one of two +mechanisms: +1) It will be in a separate directory tree with its own `LICENSE.txt` or + `LICENSE` file at the top containing the specific license and restrictions + which apply to that software, or +2) It will contain specific license and restriction terms at the top of every + file. + +============================================================================== +Legacy LLVM License (https://llvm.org/docs/DeveloperPolicy.html#legacy): +============================================================================== +University of Illinois/NCSA +Open Source License + +Copyright (c) 2003-2019 University of Illinois at Urbana-Champaign. +All rights reserved. 
+ +Developed by: + + LLVM Team + + University of Illinois at Urbana-Champaign + + http://llvm.org + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal with +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. + + * Neither the names of the LLVM Team, University of Illinois at + Urbana-Champaign, nor the names of its contributors may be used to + endorse or promote products derived from this Software without specific + prior written permission. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE +SOFTWARE. diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 54646cbe800..10445edaf12 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -1,19 +1,10 @@ //===- pybind.cpp - MLIR Python bindings ----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch1/include/toy/AST.h +++ b/mlir/examples/toy/Ch1/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h index 2e19cd09b20..a77a91bb564 100644 --- a/mlir/examples/toy/Ch1/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch1/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch1/include/toy/Parser.h +++ b/mlir/examples/toy/Ch1/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch1/parser/AST.cpp +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp index 37794d5c4d9..48863fa931c 100644 --- a/mlir/examples/toy/Ch1/toyc.cpp +++ b/mlir/examples/toy/Ch1/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch2/include/toy/AST.h b/mlir/examples/toy/Ch2/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch2/include/toy/AST.h +++ b/mlir/examples/toy/Ch2/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch2/include/toy/Dialect.h b/mlir/examples/toy/Ch2/include/toy/Dialect.h index 91dd631d2ff..385d6ddb95a 100644 --- a/mlir/examples/toy/Ch2/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch2/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch2/include/toy/Lexer.h b/mlir/examples/toy/Ch2/include/toy/Lexer.h index 144388c460c..6eff64ee5f0 100644 --- a/mlir/examples/toy/Ch2/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch2/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index dd88b097ab1..20c4a7463d9 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch2/include/toy/Parser.h +++ b/mlir/examples/toy/Ch2/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index 4a3232dabe3..b33cb5cbfe9 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 902c634a954..e9987ff2c77 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch2/parser/AST.cpp +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp index 19def702589..3e3db97b4ae 100644 --- a/mlir/examples/toy/Ch2/toyc.cpp +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch3/include/toy/AST.h +++ b/mlir/examples/toy/Ch3/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch3/include/toy/Dialect.h b/mlir/examples/toy/Ch3/include/toy/Dialect.h index 91dd631d2ff..385d6ddb95a 100644 --- a/mlir/examples/toy/Ch3/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch3/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h index 144388c460c..6eff64ee5f0 100644 --- a/mlir/examples/toy/Ch3/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index 6c400169da2..a6c93ccba10 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch3/include/toy/Parser.h b/mlir/examples/toy/Ch3/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch3/include/toy/Parser.h +++ b/mlir/examples/toy/Ch3/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 4a3232dabe3..b33cb5cbfe9 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 902c634a954..e9987ff2c77 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp index 42a10397513..d52a2c173c1 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -1,19 +1,10 @@ //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a set of simple combiners for optimizing operations in // the Toy dialect. diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.td b/mlir/examples/toy/Ch3/mlir/ToyCombine.td index 1ca143a913c..e6e33e84d7e 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.td @@ -1,19 +1,10 @@ //===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines language-specific pattern match optimizations for Toy using // Declarative Rewrite Rules (DRR) specified using TableGen records. diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch3/parser/AST.cpp +++ b/mlir/examples/toy/Ch3/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp index 410a9d677e8..e8b6e94786b 100644 --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch4/include/toy/AST.h +++ b/mlir/examples/toy/Ch4/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h index 556ae972b84..5e8b91dcf48 100644 --- a/mlir/examples/toy/Ch4/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h index 144388c460c..6eff64ee5f0 100644 --- a/mlir/examples/toy/Ch4/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch4/include/toy/MLIRGen.h b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch4/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index ef5b30a862b..71167664bbc 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch4/include/toy/Parser.h b/mlir/examples/toy/Ch4/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch4/include/toy/Parser.h +++ b/mlir/examples/toy/Ch4/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch4/include/toy/Passes.h b/mlir/examples/toy/Ch4/include/toy/Passes.h index 8c8365d6882..93c51309008 100644 --- a/mlir/examples/toy/Ch4/include/toy/Passes.h +++ b/mlir/examples/toy/Ch4/include/toy/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Toy Passes Definition -----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file exposes the entry points to create compiler passes for Toy. // diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h index fc36b5b100d..da0fb66018e 100644 --- a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the declarations of the shape inference interfaces defined // in ShapeInferenceInterface.td. 
diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td index 6974575a63c..1b38ada1622 100644 --- a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Shape Inference Op Interface. // diff --git a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp index b58adb5d52f..1ee34547860 100644 --- a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp +++ b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp @@ -1,19 +1,10 @@ //===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Module level pass performing dead function // elimination. This is required as a post-processing step after function diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 8be1094cf15..50116b14bea 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 902c634a954..e9987ff2c77 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 1f572015c39..517a1f07530 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -1,19 +1,10 @@ //===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Function level pass performing interprocedural // propagation of array shapes through function specialization. diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 604e9fa6c83..2cbf8bdac9b 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -1,19 +1,10 @@ //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a set of simple combiners for optimizing operations in // the Toy dialect. diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.td b/mlir/examples/toy/Ch4/mlir/ToyCombine.td index 1ca143a913c..e6e33e84d7e 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.td @@ -1,19 +1,10 @@ //===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines language-specific pattern match optimizations for Toy using // Declarative Rewrite Rules (DRR) specified using TableGen records. diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch4/parser/AST.cpp +++ b/mlir/examples/toy/Ch4/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp index 5ec514ac5b9..e7b584407f6 100644 --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch5/include/toy/AST.h +++ b/mlir/examples/toy/Ch5/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h index 556ae972b84..5e8b91dcf48 100644 --- a/mlir/examples/toy/Ch5/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h index 144388c460c..6eff64ee5f0 100644 --- a/mlir/examples/toy/Ch5/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch5/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index b3bda1d647b..bb98ae19a09 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch5/include/toy/Parser.h +++ b/mlir/examples/toy/Ch5/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch5/include/toy/Passes.h b/mlir/examples/toy/Ch5/include/toy/Passes.h index b6a79eda176..97a5d0db46c 100644 --- a/mlir/examples/toy/Ch5/include/toy/Passes.h +++ b/mlir/examples/toy/Ch5/include/toy/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Toy Passes Definition -----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file exposes the entry points to create compiler passes for Toy. // diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h index fc36b5b100d..da0fb66018e 100644 --- a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the declarations of the shape inference interfaces defined // in ShapeInferenceInterface.td. 
diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td index 6974575a63c..1b38ada1622 100644 --- a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Shape Inference Op Interface. // diff --git a/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp index b58adb5d52f..1ee34547860 100644 --- a/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp +++ b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp @@ -1,19 +1,10 @@ //===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Module level pass performing dead function // elimination. This is required as a post-processing step after function diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 8be1094cf15..50116b14bea 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 3fa761c7404..cba838a2928 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -1,19 +1,10 @@ //====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops and standard operations. This lowering expects that all calls diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 902c634a954..e9987ff2c77 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 1f572015c39..517a1f07530 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -1,19 +1,10 @@ //===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Function level pass performing interprocedural // propagation of array shapes through function specialization. diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 604e9fa6c83..2cbf8bdac9b 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -1,19 +1,10 @@ //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a set of simple combiners for optimizing operations in // the Toy dialect. diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.td b/mlir/examples/toy/Ch5/mlir/ToyCombine.td index 1ca143a913c..e6e33e84d7e 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.td @@ -1,19 +1,10 @@ //===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines language-specific pattern match optimizations for Toy using // Declarative Rewrite Rules (DRR) specified using TableGen records. diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch5/parser/AST.cpp +++ b/mlir/examples/toy/Ch5/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index e1ab8c0ce55..836968e2188 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch6/include/toy/AST.h b/mlir/examples/toy/Ch6/include/toy/AST.h index 901164b0f39..820600b5b1c 100644 --- a/mlir/examples/toy/Ch6/include/toy/AST.h +++ b/mlir/examples/toy/Ch6/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch6/include/toy/Dialect.h b/mlir/examples/toy/Ch6/include/toy/Dialect.h index 556ae972b84..5e8b91dcf48 100644 --- a/mlir/examples/toy/Ch6/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch6/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch6/include/toy/Lexer.h b/mlir/examples/toy/Ch6/include/toy/Lexer.h index 144388c460c..6eff64ee5f0 100644 --- a/mlir/examples/toy/Ch6/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch6/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch6/include/toy/MLIRGen.h b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch6/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch6/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index b3bda1d647b..bb98ae19a09 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch6/include/toy/Parser.h b/mlir/examples/toy/Ch6/include/toy/Parser.h index 9e219e56551..4557ea26859 100644 --- a/mlir/examples/toy/Ch6/include/toy/Parser.h +++ b/mlir/examples/toy/Ch6/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch6/include/toy/Passes.h b/mlir/examples/toy/Ch6/include/toy/Passes.h index 00fe4ffe49b..33c2021c8db 100644 --- a/mlir/examples/toy/Ch6/include/toy/Passes.h +++ b/mlir/examples/toy/Ch6/include/toy/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Toy Passes Definition -----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file exposes the entry points to create compiler passes for Toy. 
// diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h index fc36b5b100d..da0fb66018e 100644 --- a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h +++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the declarations of the shape inference interfaces defined // in ShapeInferenceInterface.td. diff --git a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td index 6974575a63c..1b38ada1622 100644 --- a/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td +++ b/mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Shape Inference Op Interface. // diff --git a/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp index b58adb5d52f..1ee34547860 100644 --- a/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp +++ b/mlir/examples/toy/Ch6/mlir/DeadFunctionEliminationPass.cpp @@ -1,19 +1,10 @@ //===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Module level pass performing dead function // elimination. This is required as a post-processing step after function diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 8be1094cf15..50116b14bea 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index 3fa761c7404..cba838a2928 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -1,19 +1,10 @@ //====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops and standard operations. This lowering expects that all calls diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index c3180b4a92d..377bc11dd27 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -1,19 +1,10 @@ //====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops and standard operations. This lowering expects that all calls diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index 902c634a954..e9987ff2c77 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. 
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp index 1f572015c39..517a1f07530 100644 --- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp @@ -1,19 +1,10 @@ //===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Function level pass performing interprocedural // propagation of array shapes through function specialization. diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp index 604e9fa6c83..2cbf8bdac9b 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -1,19 +1,10 @@ //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a set of simple combiners for optimizing operations in // the Toy dialect. diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.td b/mlir/examples/toy/Ch6/mlir/ToyCombine.td index 1ca143a913c..e6e33e84d7e 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.td @@ -1,19 +1,10 @@ //===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines language-specific pattern match optimizations for Toy using // Declarative Rewrite Rules (DRR) specified using TableGen records. diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp index 3ec91a4300d..0d6d9359529 100644 --- a/mlir/examples/toy/Ch6/parser/AST.cpp +++ b/mlir/examples/toy/Ch6/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index 60e3d0f9791..4e5b2afb7c6 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/examples/toy/Ch7/include/toy/AST.h b/mlir/examples/toy/Ch7/include/toy/AST.h index 558d9deab8e..3d3ae89dbeb 100644 --- a/mlir/examples/toy/Ch7/include/toy/AST.h +++ b/mlir/examples/toy/Ch7/include/toy/AST.h @@ -1,19 +1,10 @@ //===- AST.h - Node definition for the Toy AST ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST for the Toy language. It is optimized for // simplicity, not efficiency. The AST forms a tree structure where each node diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h index b96ff99a5b6..77481b1884f 100644 --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - Dialect definition for the Toy IR ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the IR Dialect for the Toy language. // See g3doc/Tutorials/Toy/Ch-2.md for more information. diff --git a/mlir/examples/toy/Ch7/include/toy/Lexer.h b/mlir/examples/toy/Ch7/include/toy/Lexer.h index 89dc6cba9ff..b41b82f2a0a 100644 --- a/mlir/examples/toy/Ch7/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch7/include/toy/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - Lexer for the Toy language -------------------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple Lexer for the Toy language. // diff --git a/mlir/examples/toy/Ch7/include/toy/MLIRGen.h b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h index 287f432c847..e1c8ca1201d 100644 --- a/mlir/examples/toy/Ch7/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch7/include/toy/MLIRGen.h @@ -1,19 +1,10 @@ //===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a simple interface to perform IR generation targeting MLIR // from a Module AST for the Toy language. diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 94f1bcf3e82..801aef06934 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Toy dialect. // diff --git a/mlir/examples/toy/Ch7/include/toy/Parser.h b/mlir/examples/toy/Ch7/include/toy/Parser.h index df6c4fb2f60..d2659e04dac 100644 --- a/mlir/examples/toy/Ch7/include/toy/Parser.h +++ b/mlir/examples/toy/Ch7/include/toy/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - Toy Language Parser -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the Toy language. It processes the Token // provided by the Lexer and returns an AST. diff --git a/mlir/examples/toy/Ch7/include/toy/Passes.h b/mlir/examples/toy/Ch7/include/toy/Passes.h index 00fe4ffe49b..33c2021c8db 100644 --- a/mlir/examples/toy/Ch7/include/toy/Passes.h +++ b/mlir/examples/toy/Ch7/include/toy/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Toy Passes Definition -----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file exposes the entry points to create compiler passes for Toy. 
// diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h index fc36b5b100d..da0fb66018e 100644 --- a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the declarations of the shape inference interfaces defined // in ShapeInferenceInterface.td. diff --git a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td index 6974575a63c..1b38ada1622 100644 --- a/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td +++ b/mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td @@ -1,19 +1,10 @@ //===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the operations of the Shape Inference Op Interface. // diff --git a/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp index b58adb5d52f..1ee34547860 100644 --- a/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp +++ b/mlir/examples/toy/Ch7/mlir/DeadFunctionEliminationPass.cpp @@ -1,19 +1,10 @@ //===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Module level pass performing dead function // elimination. This is required as a post-processing step after function diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 0ce896db5de..4f4cbdf2f0f 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 3fa761c7404..cba838a2928 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -1,19 +1,10 @@ //====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops and standard operations. This lowering expects that all calls diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index c3180b4a92d..377bc11dd27 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -1,19 +1,10 @@ //====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops and standard operations. This lowering expects that all calls diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 590b21e53a1..62e8c553709 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -1,19 +1,10 @@ //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple IR generation targeting MLIR from a Module AST // for the Toy language. 
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp index 1f572015c39..517a1f07530 100644 --- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -1,19 +1,10 @@ //===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a Function level pass performing interprocedural // propagation of array shapes through function specialization. diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp index d18396c63bb..2fb0a1c5b69 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -1,19 +1,10 @@ //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a set of simple combiners for optimizing operations in // the Toy dialect. diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.td b/mlir/examples/toy/Ch7/mlir/ToyCombine.td index 1ca143a913c..e6e33e84d7e 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.td @@ -1,19 +1,10 @@ //===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines language-specific pattern match optimizations for Toy using // Declarative Rewrite Rules (DRR) specified using TableGen records. diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp index 391757f711f..669bc9dbec2 100644 --- a/mlir/examples/toy/Ch7/parser/AST.cpp +++ b/mlir/examples/toy/Ch7/parser/AST.cpp @@ -1,19 +1,10 @@ //===- AST.cpp - Helper for printing out the Toy AST ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the AST dump for the Toy language. // diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp index ec5a4f8056b..c6afab594e1 100644 --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -1,19 +1,10 @@ //===- toyc.cpp - The Toy Compiler ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the entry point for the Toy compiler. // diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index c205e898901..5e3e2087f8b 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -1,18 +1,9 @@ /*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\ |* *| -|* Copyright 2019 The MLIR Authors. *| -|* *| -|* Licensed under the Apache License, Version 2.0 (the "License"); *| -|* you may not use this file except in compliance with the License. *| -|* You may obtain a copy of the License at *| -|* *| -|* http://www.apache.org/licenses/LICENSE-2.0 *| -|* *| -|* Unless required by applicable law or agreed to in writing, software *| -|* distributed under the License is distributed on an "AS IS" BASIS, *| -|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *| -|* See the License for the specific language governing permissions and *| -|* limitations under the License. *| +|* Part of the MLIR Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| |* *| |*===----------------------------------------------------------------------===*| |* *| diff --git a/mlir/include/mlir/ADT/TypeSwitch.h b/mlir/include/mlir/ADT/TypeSwitch.h index 75051b6a539..2dbc611f557 100644 --- a/mlir/include/mlir/ADT/TypeSwitch.h +++ b/mlir/include/mlir/ADT/TypeSwitch.h @@ -1,19 +1,10 @@ //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the TypeSwitch template, which mimics a switch() // statement whose cases are type names. diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index f506470f36a..5d9422883c1 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -1,19 +1,10 @@ //===- AffineAnalysis.h - analyses for affine structures --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for methods that perform analysis // involving affine structures (AffineExprStorage, AffineMap, IntegerSet, etc.) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 65cf13a0ce6..770bf686f50 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -1,19 +1,10 @@ //===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Structures for affine/polyhedral analysis of ML functions. // diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h index 700a016e836..8f954161921 100644 --- a/mlir/include/mlir/Analysis/CallGraph.h +++ b/mlir/include/mlir/Analysis/CallGraph.h @@ -1,19 +1,10 @@ //===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains an analysis for computing the multi-level callgraph from a // given top-level operation. This nodes within this callgraph are defined by diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h index a18cfa7aba4..a9806bfb8c6 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.h +++ b/mlir/include/mlir/Analysis/CallInterfaces.h @@ -1,19 +1,10 @@ //===- CallInterfaces.h - Call Interfaces for MLIR --------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the definitions of the call interfaces defined in // `CallInterfaces.td`. diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td index 043f009a8e2..3e5b599baf8 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -1,19 +1,10 @@ //===- CallInterfaces.td - Call Interfaces for ops -*- tablegen ---------*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains a set of interfaces that can be used to define information // related to call-like and callable operations. 
Each of which are defined along diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index f46241e2af0..5c42dbe12c2 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -1,19 +1,10 @@ //===- Dominance.h - Dominator analysis for CFGs ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_ANALYSIS_DOMINANCE_H #define MLIR_ANALYSIS_DOMINANCE_H diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h index 2d68ada0d13..baf16162a0b 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -1,19 +1,10 @@ //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the definitions of the infer op interfaces defined in // `InferTypeOpInterface.td`. diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td index 14d580962e1..bbcea6be7eb 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -1,19 +1,10 @@ //===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains a set of interfaces that can be used to define information // related to type inference. diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h index 0aa9d9693e4..791c164c7d2 100644 --- a/mlir/include/mlir/Analysis/Liveness.h +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -1,19 +1,10 @@ //===- Liveness.h - Liveness analysis for MLIR ------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains an analysis for computing liveness information from a // given top-level operation. The current version of the analysis uses a diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index ad7dc6d6092..66f0033bf2f 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -1,19 +1,10 @@ //===- LoopAnalysis.h - loop analysis methods -------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for methods to analyze loops. // diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 9af26e8842a..2da64e88e14 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -1,19 +1,10 @@ //===- NestedMacher.h - Nested matcher for Function -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ diff --git a/mlir/include/mlir/Analysis/Passes.h b/mlir/include/mlir/Analysis/Passes.h index b233ab5f209..0bbc850e6c9 100644 --- a/mlir/include/mlir/Analysis/Passes.h +++ b/mlir/include/mlir/Analysis/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes that expose pass constructors in the // analysis library. diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index ad6b65387be..d7b6e957014 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -1,19 +1,10 @@ //===- SliceAnalysis.h - Analysis for Transitive UseDef chains --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_ANALYSIS_SLICEANALYSIS_H_ #define MLIR_ANALYSIS_SLICEANALYSIS_H_ diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index ea0987df3fe..d06e003faae 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -1,19 +1,10 @@ //===- Utils.h - General analysis utilities ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for various transformation utilities for // memref's and non-loop IR structures. These are not passes by themselves but diff --git a/mlir/include/mlir/Analysis/Verifier.h b/mlir/include/mlir/Analysis/Verifier.h index daaff57683e..b7075b4f157 100644 --- a/mlir/include/mlir/Analysis/Verifier.h +++ b/mlir/include/mlir/Analysis/Verifier.h @@ -1,19 +1,10 @@ //===- Verifier.h - Verifier analysis for MLIR structures -------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_ANALYSIS_VERIFIER_H #define MLIR_ANALYSIS_VERIFIER_H diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h index 4bbe6610e31..8e873bfb1c3 100644 --- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -1,19 +1,10 @@ //===- AffineToStandard.h - Convert Affine to Standard dialect --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H #define MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h index 6b9b08ed7d5..4eb6379adf6 100644 --- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h +++ b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h @@ -1,19 +1,10 @@ //===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ #define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 635d4366e83..75e4f7e374c 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -1,19 +1,10 @@ //===- GPUToNVVMPass.h - Convert GPU kernel to NVVM dialect -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h index 54cda41afa1..e913c2e1131 100644 --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -1,19 +1,10 @@ //===- GPUToROCDLPass.h - Convert GPU kernel to ROCDL dialect ---*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h index 134dbf40b4d..762a6e502d4 100644 --- a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h @@ -1,19 +1,10 @@ //===- ConvertGPUToSPIRV.h - GPU Ops to SPIR-V dialect patterns ----C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides patterns for lowering GPU Ops to SPIR-V dialect. // diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h index 8f0a910c74d..37230f4c0e1 100644 --- a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h @@ -1,19 +1,10 @@ //===- ConvertGPUToSPIRVPass.h - GPU to SPIR-V conversion pass --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides a pass to convert GPU ops to SPIRV ops. 
// diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h index 6bae08e13be..27950177c1d 100644 --- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h +++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h @@ -1,19 +1,10 @@ //===- LinalgToLLVM.h - Utils to convert from the linalg dialect ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_ #define MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h b/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h index 095c9f470b3..5cb8f59e6f7 100644 --- a/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h +++ b/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h @@ -1,19 +1,10 @@ //===- ConvertLoopToStandard.h - Pass entrypoint ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_ #define MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_ diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h index 58d49a13391..5f3ea87f3cc 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -1,19 +1,10 @@ //===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ #define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index a42320c9bdf..a3d663ae3d7 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -1,19 +1,10 @@ //===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ #define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 6f41fb68633..5c8a8e6e494 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -1,19 +1,10 @@ //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides a dialect conversion targeting the LLVM IR dialect. By default, it // converts Standard ops and types and provides hooks for dialect-specific diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index d49c1c22530..a4d95da6a75 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -1,19 +1,10 @@ //===- ConvertStandardToLLVMPass.h - Pass entrypoint ------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h index 4caa6d9de77..e0e874027bf 100644 --- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h @@ -1,19 +1,10 @@ //===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides patterns to lower StandardOps to SPIR-V dialect. // diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h index e8a71feb8b2..7dbaf1c0418 100644 --- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h @@ -1,19 +1,10 @@ //===- ConvertStandardToSPIRVPass.h - StdOps to SPIR-V pass -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides a pass to lower from StandardOps to SPIR-V dialect. // diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index a87e1c658a6..b8b97c21a3e 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -1,19 +1,10 @@ //===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ #define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h b/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h index 198eaceda41..4f7d0843b73 100644 --- a/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h +++ b/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h @@ -1,19 +1,10 @@ //===- ConvertVectorToLoops.h - Utils to convert from the vector dialect --===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ #define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_ diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 764f439e020..09408d2efc8 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -1,19 +1,10 @@ //===- AffineOps.h - MLIR Affine Operations -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines convenience types for working with Affine operations // in the MLIR operation set. 
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index befdc2f6237..715e3807a95 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -1,19 +1,10 @@ //===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines MLIR affine operations. // diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td index 755f65c338e..6aee5f3cd4a 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td @@ -1,19 +1,10 @@ //===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines base support for MLIR affine operations. // diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h index 45552945f0d..d667de73d41 100644 --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -1,19 +1,10 @@ //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file declares various common operation folders. These folders // are intended to be used by dialects to support common folding behavior diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h index 88a42344c3b..8c0e7aa1aad 100644 --- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h @@ -1,19 +1,10 @@ //===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_ #define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_ diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td index b1bfb2706cf..d527b759a10 100644 --- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td +++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td @@ -1,19 +1,10 @@ //===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the operation definition file for fixed point ops (and real // equivalents). diff --git a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h index 415b1c0b253..aec21c4c186 100644 --- a/mlir/include/mlir/Dialect/FxpMathOps/Passes.h +++ b/mlir/include/mlir/Dialect/FxpMathOps/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines all of the passes owned by the FxpMathOps dialect. // diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h index 12c2aa1bbd1..c3ab6ec5729 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -1,19 +1,10 @@ //===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the GPU kernel-related operations and puts them in the // corresponding dialect. diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index def1ff2b8a1..037664d0d9b 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -1,19 +1,10 @@ //===-- GPUOps.td - GPU dialect operation definitions ------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines some operations of the GPU dialect. // diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h index 7c8ce02db90..daf6d28d452 100644 --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes that expose pass constructors. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index a599d51b31f..bef1f2dbf20 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -1,19 +1,10 @@ //===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the LLVM IR dialect in MLIR, containing LLVM operations and // LLVM type system. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 6257b4a51d9..ed935d5b7f7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -1,19 +1,10 @@ //===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains shared definitions for the LLVM IR dialect and its // subdialects. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index cfbbf7da65d..46f63206ef5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1,19 +1,10 @@ //===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the LLVM IR operation definition file. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index 0328cf4ba94..afb6d4ab627 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -1,19 +1,10 @@ //===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the NVVM IR dialect in MLIR, containing NVVM operations and // NVVM specific extensions to the LLVM type system. diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index bc6887da8e4..f35b7798149 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1,19 +1,10 @@ //===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the NVVM IR operation definition file. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h index a34c11223f3..dab32d30e8f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -1,19 +1,10 @@ //===- ROCDLDialect.h - MLIR ROCDL IR dialect -------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the ROCDL dialect in MLIR, containing ROCDL operations // and ROCDL specific extensions to the LLVM type system. diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 79d4136d6f5..697ff9740a8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -1,19 +1,10 @@ //===-- ROCDLOps.td - ROCDL IR dialect op definition file --*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the ROCDL IR operation definition file. // diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 426708b14a8..1a2d6b9b3ba 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -1,19 +1,10 @@ //===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index 8375e750a5c..d0f6c942b95 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -1,19 +1,10 @@ //===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides intuitive composable interfaces for building structured MLIR // snippets in a declarative fashion. diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h index f1acab69a4d..b04c11f22bb 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -1,19 +1,10 @@ //===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ #define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 4e77b0ac0a8..c1adc8b4d05 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -1,19 +1,10 @@ //===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the definition file for base linear algebra support. // diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td index a3163f50476..819d02d396d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td @@ -1,19 +1,10 @@ //===- LinalgDoc.td - Linalg documentation -----------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This documentation files exists to circumvent limitations on mixing different // .td files in cases one does not want to have all ops belong to the same diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 18ca31cc376..e52019d7992 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -1,19 +1,10 @@ //===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the operation definition file for linear algebra operations that // correspond to underlying library calls (e.g. BLAS). diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h index c5f1f01d0c7..3249edb48e0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -1,19 +1,10 @@ //===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_ #define MLIR_DIALECT_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 5d402a9ded9..728fa619dbe 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -1,19 +1,10 @@ //===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the operation definition file for linear algebra operations. // diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 774be6616cd..8674c277e4a 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -1,19 +1,10 @@ //===- LinalgStructuredOps.td - Linalg dialect library ops -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the operation definition file for structured operations on buffers // that correspond to underlying library calls (e.g. BLAS). 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index d196e6ccf94..7399aad6663 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -1,19 +1,10 @@ //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_ #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h index f779c3de6ae..abeda3e0552 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -1,19 +1,10 @@ //===- LinalgTypes.h - Linalg Types ---------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_ #define MLIR_DIALECT_LINALG_LINALGTYPES_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 7ae3877f01e..86cf6fdd027 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Linalg pass entry points ----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes that expose pass constructors. // diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index dbc162f4132..448ffdf7d4b 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -1,19 +1,10 @@ //===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the pattern definition file for declarative Linalg transformation. // diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index a1a7458ae7f..a88dc4105e2 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -1,19 +1,10 @@ //===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ #define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h index 5a815ba158e..778d853aeef 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h @@ -1,19 +1,10 @@ //===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_INTRINSICS_H_ #define MLIR_DIALECT_LINALG_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 50039dd9336..1b45179bc9e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -1,19 +1,10 @@ //===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_LINALG_UTILS_H_ #define MLIR_DIALECT_LINALG_UTILS_H_ diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h index e7ff6f84977..dba5e819986 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -1,19 +1,10 @@ //===- Ops.h - Loop MLIR Operations -----------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines convenience types for working with loop operations. // diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td index e0f5b896309..3b0f120441a 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -1,19 +1,10 @@ //===- Ops.td - Loop operation definitions ---------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines MLIR loop operations. // diff --git a/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h b/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h index 23e2967bd77..1a141e3b1b3 100644 --- a/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h +++ b/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h @@ -1,19 +1,10 @@ //===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines support utilities for interoperating with FakeQuant* based // QAT (Quantized Aware Training) computations, as implemented by TFLite. Note diff --git a/mlir/include/mlir/Dialect/QuantOps/Passes.h b/mlir/include/mlir/Dialect/QuantOps/Passes.h index c57d7bf41fe..d3109775db2 100644 --- a/mlir/include/mlir/Dialect/QuantOps/Passes.h +++ b/mlir/include/mlir/Dialect/QuantOps/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines all of the passes owned by the quantization dialect. As // things mature, it is expected that passes specific to certain frontend or diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h index 020d34918d4..9a4eec67c74 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h @@ -1,19 +1,10 @@ //===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_QUANTOPS_QUANTOPS_H_ #define MLIR_DIALECT_QUANTOPS_QUANTOPS_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index 072715d65aa..bbeb9419cc4 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -1,19 +1,10 @@ //===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the operation definition file for Quantization. // diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td index 2fbb7995dd4..7225dcc72db 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td @@ -1,19 +1,10 @@ //===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Predicates for types in the Quantization dialect. // diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h index 55e921ff8fb..daeb0374460 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h @@ -1,19 +1,10 @@ //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ #define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h b/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h index de87ca1e67c..c40b9e6f026 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h @@ -1,19 +1,10 @@ //===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_ #define MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_ diff --git a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h index 0416db34e17..7c74fc56b8f 100644 --- a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h +++ b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h @@ -1,19 +1,10 @@ //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ #define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h index f95a51e407a..c8a0eec8ca8 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBM.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBM.h @@ -1,19 +1,10 @@ //===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined // as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression. diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h index e3573ba604d..501c66140f0 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h @@ -1,19 +1,10 @@ //===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H #define MLIR_DIALECT_SDBM_SDBMDIALECT_H diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h index 8cb5ef0be10..84a9a8405a8 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -1,19 +1,10 @@ //===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // A striped difference-bound matrix (SDBM) expression is a constant expression, // an identifier, a binary expression with constant RHS and +, stripe operators diff --git a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h index 7537e5f654b..329caa2d3aa 100644 --- a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h +++ b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h @@ -1,19 +1,10 @@ //===-- LayoutUtils.h - Decorate composite type with layout information ---===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines utilities used to get alignment and layout information for // types in SPIR-V dialect. diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h index fe029ff27ea..68f149b54d5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes that expose pass constructors. // diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td index f15d274922a..39858f357ff 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -1,19 +1,10 @@ //===-- SPIRVArithmeticOps.td - MLIR SPIR-V Arithmetic Ops -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains arithmetic ops for the SPIR-V dialect. It corresponds // to "3.32.13. Arithmetic Instructions" of the SPIR-V specification. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td index 15b6ab0105c..c2ea100c121 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td @@ -1,19 +1,10 @@ //===-- SPIRVAtomicOps.td - MLIR SPIR-V Atomic Ops ---------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains atomic ops for the SPIR-V dialect. It corresponds to // "3.32.18. Atomic Instructions" of the SPIR-V specification. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 838398823ad..5751a32e169 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -1,19 +1,10 @@ //===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the base file for SPIR-V operation definition specification. // This file defines the SPIR-V dialect, common SPIR-V types, and utilities diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h index 3229e28ef1a..6a426488423 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h @@ -1,19 +1,10 @@ //===- SPIRVBinaryUtils.cpp - SPIR-V Binary Module Utils --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares common utilities for SPIR-V binary module. // diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td index d76a1e3854b..360edeec52d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td @@ -1,19 +1,10 @@ //===-- SPIRVBitOps.td - MLIR SPIR-V Bit Ops -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains bit ops for the SPIR-V dialect. It corresponds // to "3.32.13. Bit Instructions" of the SPIR-V specification. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td index e4fe526e420..99fe0bbbf5f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -1,19 +1,10 @@ //===-- SPIRVCastOps.td - MLIR SPIR-V Cast Ops -------*- tablegen -*-------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains cast ops for the SPIR-V dialect. It corresponds // to "3.32.11. Convertion Instructions" of the SPIR-V specification. 
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td index d19fd974684..7bd88ab66e0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -1,19 +1,10 @@ //===-- SPIRVCompositeOps.td - MLIR SPIR-V Composite Ops ---*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains composite ops for SPIR-V dialect. It corresponds // to "3.32.12. Composite Instructions" of the SPIR-V spec. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 32a78024560..bc06c0289db 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -1,19 +1,10 @@ //===-- SPIRVControlFlowOps.td - SPIR-V Control Flow Ops ---*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains control flow ops for the SPIR-V dialect. It corresponds // to "3.32.17. Control-Flow Instructions" of the SPIR-V specification. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h index 2571e5d8928..0c0eebd34d1 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h @@ -1,19 +1,10 @@ //===- SPIRVDialect.h - MLIR SPIR-V dialect ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the SPIR-V dialect in MLIR. // diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td index a031facdf5a..b2eacbf306a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td @@ -1,19 +1,10 @@ //===- SPIRVGLSLOps.td - GLSL extended insts spec file -----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the op definition spec of GLSL extension ops. // diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td index c0388fe4e23..827636afbaf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -1,19 +1,10 @@ //===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains group and subgroup ops for the SPIR-V dialect. It // corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index e1e94bcd861..4057f47931c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -1,19 +1,10 @@ //===-- SPIRVLogicalOps.td - MLIR SPIR-V Logical Ops -------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains arithmetic ops for the SPIR-V dialect. It corresponds // to "3.32.15. Relational and Logical Instructions" of the SPIR-V spec. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index 37b4ee24237..e7cf250cc3a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -1,19 +1,10 @@ //===- SPIRVLowering.h - SPIR-V lowering utilities -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines utilities to use while targeting SPIR-V dialect. 
// diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td index d9cf0a752b8..91a8ff68bbf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td @@ -1,19 +1,10 @@ //===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the base file for supporting lowering to SPIR-V dialect. This // file defines SPIR-V attributes used for specifying the shader diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td index 1b3174c9e9f..f3a9a61a9e9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -1,19 +1,10 @@ //===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains non-uniform ops for the SPIR-V dialect. It corresponds to // "3.32.24. Non-Uniform Instructions" of the SPIR-V specification. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h index cb33146286a..2fa417bfe25 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -1,19 +1,10 @@ //===- SPIRVOps.h - MLIR SPIR-V operations ----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the operations in the SPIR-V dialect. // diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 777e5750486..f657d5847d0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -1,19 +1,10 @@ //===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the main operation definition specification file for SPIR-V // operations. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index d1dacf3d63d..c37796b9f60 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -1,19 +1,10 @@ //===-- SPIRVStructureOps.td - MLIR SPIR-V Structure Ops ---*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains ops for defining the SPIR-V structure: module, function, // and module-level operations. The representational form of these ops deviate diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index bc3083e8d7c..001d3130778 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -1,19 +1,10 @@ //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the types in the SPIR-V dialect. // diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h index bad7355791f..e8240b0072e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h +++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h @@ -1,19 +1,10 @@ //===- Serialization.h - MLIR SPIR-V (De)serialization ----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the entry points for serialize and deserialize SPIR-V // binary modules. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index 563116823d9..e3ec6f1f7d6 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -1,19 +1,10 @@ //===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines convenience types for working with standard operations // in the MLIR operation set. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index e00674708f6..c31b3dc9395 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1,19 +1,10 @@ //===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines some MLIR standard operations. // diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h index e04eb829e88..87c8e662a65 100644 --- a/mlir/include/mlir/Dialect/Traits.h +++ b/mlir/include/mlir/Dialect/Traits.h @@ -1,19 +1,10 @@ //===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares common op traits that are not core to MLIR but can be // shared by multiple dialects. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index b7e3990a333..9e7cbba0f43 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -1,19 +1,10 @@ //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file define utilities that operate on standard types and are // useful across multiple dialects that use structured ops abstractions. These diff --git a/mlir/include/mlir/Dialect/VectorOps/Utils.h b/mlir/include/mlir/Dialect/VectorOps/Utils.h index 68c62cc7ec7..b4d8ad65e60 100644 --- a/mlir/include/mlir/Dialect/VectorOps/Utils.h +++ b/mlir/include/mlir/Dialect/VectorOps/Utils.h @@ -1,19 +1,10 @@ //===- Utils.h - VectorOps Utils ----------------------------*- C++ -*-=======// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_VECTOROPS_UTILS_H_ #define MLIR_DIALECT_VECTOROPS_UTILS_H_ diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h index 29ad6eecaf9..7234d46b765 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -1,19 +1,10 @@ //===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the Vector dialect. // diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 94262e6f1ff..87ed28caf80 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -1,19 +1,10 @@ //===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines MLIR vector operations. // diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td index 86ff9b505d5..5d0244f6989 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td @@ -1,19 +1,10 @@ //===- VectorTransformPatterns.td - Vector-Vector patterns -*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the pattern definition file for declarative Vector transformations. // diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h index b48cb51533f..a73444d2023 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -1,19 +1,10 @@ //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ #define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_ diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 11ee0bff342..6607f267057 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -1,19 +1,10 @@ //===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides intuitive composable interfaces for building structured MLIR // snippets in a declarative fashion. diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h index c18307e7121..0be8a6045f7 100644 --- a/mlir/include/mlir/EDSC/Helpers.h +++ b/mlir/include/mlir/EDSC/Helpers.h @@ -1,19 +1,10 @@ //===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides helper classes and syntactic sugar for declarative builders. // diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index dc0c1186c7a..5edbf9600fb 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -1,19 +1,10 @@ //===- Intrinsics.h - MLIR Operations for Declarative Builders ---*- C++-*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides intuitive composable intrinsics for building snippets of MLIR // declaratively diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 4e70a21f6ec..4f218bd0d9b 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -1,19 +1,10 @@ //===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file provides a JIT-backed execution engine for MLIR modules. // diff --git a/mlir/include/mlir/ExecutionEngine/OptUtils.h b/mlir/include/mlir/ExecutionEngine/OptUtils.h index 8c0249d5c09..7b7b2598db5 100644 --- a/mlir/include/mlir/ExecutionEngine/OptUtils.h +++ b/mlir/include/mlir/ExecutionEngine/OptUtils.h @@ -1,19 +1,10 @@ //===- OptUtils.h - MLIR Execution Engine opt pass utilities ----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the utility functions to trigger LLVM optimizations from // MLIR Execution Engine. diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index b66933df408..7059489ed4c 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -1,19 +1,10 @@ //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // An affine expression is an affine combination of dimension identifiers and // symbols, including ceildiv/floordiv/mod by a constant integer. diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index 9fa40218b5f..7866d6bb996 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -1,19 +1,10 @@ //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the AffineExpr visitor class. // diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index abd3712b0e1..3f9116cb168 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -1,19 +1,10 @@ //===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Affine maps are mathematical functions which map a list of dimension // identifiers and symbols, to multidimensional affine expressions. diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h index 78b3a2779d3..9804d6866f8 100644 --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -1,19 +1,10 @@ //===- AttributeSupport.h ---------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines support types for registering dialect extended attributes. // diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index b5f4b1a7d7c..b8398580f61 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -1,19 +1,10 @@ //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_ATTRIBUTES_H #define MLIR_IR_ATTRIBUTES_H diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 87c77160e1d..b5189b48a85 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -1,19 +1,10 @@ //===- Block.h - MLIR Block Class -------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the Block class. // diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h index 287dd508fa6..82173c34368 100644 --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -1,19 +1,10 @@ //===- BlockAndValueMapping.h -----------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a utility class for maintaining a mapping for multiple // value types. diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index fd30c36aaa3..7cefe870c22 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -1,19 +1,10 @@ //===- BlockSupport.h -------------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a number of support types for the Block class. // diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index c199c09feb5..038664f0186 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -1,19 +1,10 @@ //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_BUILDERS_H #define MLIR_IR_BUILDERS_H diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 9385de9ac4f..e3d0f838208 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -1,19 +1,10 @@ //===- Diagnostics.h - MLIR Diagnostics -------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines utilities for emitting diagnostics. // diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index a1855e797e8..d3b4b055bc0 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -1,19 +1,10 @@ //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the 'dialect' abstraction. 
// diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h index c51fafb6180..7e4e1d8335b 100644 --- a/mlir/include/mlir/IR/DialectHooks.h +++ b/mlir/include/mlir/IR/DialectHooks.h @@ -1,19 +1,10 @@ //===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines abstraction and registration mechanism for dialect hooks. // diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h index c645a2427b2..1eada8f264b 100644 --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -1,19 +1,10 @@ //===- DialectImplementation.h ----------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains utilities classes for implementing dialect attributes and // types. diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h index 4eb41105032..ff1f8fb015a 100644 --- a/mlir/include/mlir/IR/DialectInterface.h +++ b/mlir/include/mlir/IR/DialectInterface.h @@ -1,19 +1,10 @@ //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_DIALECTINTERFACE_H #define MLIR_IR_DIALECTINTERFACE_H diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def index c1056bd4da0..14b876a2ce9 100644 --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -1,19 +1,10 @@ //===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file enumerates the different dialects that define custom classes // within the attribute or type system. diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 6731f5430fa..3f788bbeeba 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -1,19 +1,10 @@ //===- Function.h - MLIR Function Class -------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Functions are the basic unit of composition in MLIR. 
// diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h index c557d58429c..9d3e438f67e 100644 --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -1,19 +1,10 @@ //===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file provides utility functions for implementing function-like // operations, in particular, parsing, printing and verification components diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index 1ba85d73df9..49175ba5e75 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -1,19 +1,10 @@ //===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines support types for Operations that represent function-like // constructs to use. diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h index bc84c200545..604eebf341e 100644 --- a/mlir/include/mlir/IR/Identifier.h +++ b/mlir/include/mlir/IR/Identifier.h @@ -1,19 +1,10 @@ //===- Identifier.h - MLIR Identifier Class ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_IDENTIFIER_H #define MLIR_IR_IDENTIFIER_H diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index 6ffe830883b..1238511df34 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -1,19 +1,10 @@ //===- IntegerSet.h - MLIR Integer Set Class --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Integer sets are sets of points from the integer lattice constrained by // affine equality/inequality constraints. This class is meant to represent diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h index bb55ad69057..c36bcb30735 100644 --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -1,19 +1,10 @@ //===- Location.h - MLIR Location Classes -----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // These classes provide the ability to relate MLIR objects back to source // location position information. 
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index a93cb8b3353..e0761bcaaf1 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -1,19 +1,10 @@ //===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_MLIRCONTEXT_H #define MLIR_IR_MLIRCONTEXT_H diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 3b36f2fb5eb..5ce2cc7a8a8 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -1,19 +1,10 @@ //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file provides a simple and efficient mechanism for performing general // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index 52d2455c7ae..babc51aad0d 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -1,19 +1,10 @@ //===- Module.h - MLIR Module Class -----------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Module is the top-level container for code in an MLIR program. // diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td index 85726a8c64d..7e31c07575e 100644 --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -1,19 +1,10 @@ //===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains Interfaces for interacting with the AsmParser and // AsmPrinter. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 24e48b329d5..c457d25fc51 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1,19 +1,10 @@ //===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the base operation definition file. 
// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 437540117c4..84f3cf2f444 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1,19 +1,10 @@ //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements helper classes for implementing the "Op" types. This // includes the Op type, which is the base class for Op class definitions, diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index fcadce9ab16..e58a5b07038 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1,19 +1,10 @@ //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This classes used by the implementation details of Op types. // diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index ad0dc600f8f..9ab900c8761 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -1,19 +1,10 @@ //===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the Operation class. // diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index b7f63218ba5..14681663372 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1,19 +1,10 @@ //===- OperationSupport.h ---------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a number of support types that Operation and related // classes build on top of. diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 707bb7c139f..e6b5e7a5eb7 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1,19 +1,10 @@ //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PATTERNMATCHER_H #define MLIR_PATTERNMATCHER_H diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index c1390adb40b..00f3ca7fba1 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -1,19 +1,10 @@ //===- Region.h - MLIR Region Class -----------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the Region class. // diff --git a/mlir/include/mlir/IR/RegionGraphTraits.h b/mlir/include/mlir/IR/RegionGraphTraits.h index f45dcc41a4a..b11c87dbd0c 100644 --- a/mlir/include/mlir/IR/RegionGraphTraits.h +++ b/mlir/include/mlir/IR/RegionGraphTraits.h @@ -1,19 +1,10 @@ //===- RegionGraphTraits.h - llvm::GraphTraits for CFGs ---------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements specializations of llvm::GraphTraits for various MLIR // CFG data types. This allows the generic LLVM graph algorithms to be applied diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index b6b4b6ea52c..89ffc45e547 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -1,19 +1,10 @@ //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_STANDARDTYPES_H #define MLIR_IR_STANDARDTYPES_H diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 1a730731f32..f9288197072 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -1,19 +1,10 @@ //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines utility classes for interfacing with StorageUniquer. // diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index e04beac6bc6..07829186cbf 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -1,19 +1,10 @@ //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_SYMBOLTABLE_H #define MLIR_IR_SYMBOLTABLE_H diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index 86620da0b5c..8cc811cb916 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -1,19 +1,10 @@ //===- TypeSupport.h --------------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines support types for registering dialect extended types. // diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index af22f9c4a9f..b4713226559 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -1,19 +1,10 @@ //===- TypeUtilities.h - Helper function for type queries -------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines generic type utilities. // diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 2ab36353dc4..6246e9bedd0 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -1,19 +1,10 @@ //===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_IR_TYPES_H #define MLIR_IR_TYPES_H diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 96e4ace2529..898d0da2b28 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -1,19 +1,10 @@ //===- UseDefLists.h --------------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines generic use/def list machinery and manipulation utilities. // diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 11cb8cdcbc7..030e6fa58b1 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -1,19 +1,10 @@ //===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines generic Value type and manipulation utilities. 
// diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h index 50d65627f1a..aaab933d239 100644 --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -1,19 +1,10 @@ //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines utilities for walking and visiting operations. // diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h index 3a818ffa9d8..cae1e8b9ab1 100644 --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -1,19 +1,10 @@ //===- Parser.h - MLIR Parser Library Interface -----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file is contains the interface to the MLIR parser library. // diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index e233a4a5676..471cd011c40 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -1,19 +1,10 @@ //===- AnalysisManager.h - Analysis Management Infrastructure ---*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PASS_ANALYSISMANAGER_H #define MLIR_PASS_ANALYSISMANAGER_H diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 380b097c78c..b4e8db86ff0 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -1,19 +1,10 @@ //===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PASS_PASS_H #define MLIR_PASS_PASS_H diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h index 4b61850c661..ef75e56ae62 100644 --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -1,19 +1,10 @@ //===- PassInstrumentation.h ------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PASS_PASSINSTRUMENTATION_H_ #define MLIR_PASS_PASSINSTRUMENTATION_H_ diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 9de8ace435c..d4f3683f031 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -1,19 +1,10 @@ //===- PassManager.h - Pass Management Interface ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PASS_PASSMANAGER_H #define MLIR_PASS_PASSMANAGER_H diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index eabfa73a1b6..8ebeead90c8 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -1,19 +1,10 @@ //===- PassOptions.h - Pass Option Utilities --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains utilities for registering options with compiler passes and // pipelines. diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index deb80ef765e..e07b9855c8d 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -1,19 +1,10 @@ //===- PassRegistry.h - Pass Registration Utilities -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains utilities for registering information about compiler // passes. diff --git a/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h b/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h index 467512f2b77..f27d12d7f52 100644 --- a/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h +++ b/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h @@ -1,19 +1,10 @@ //===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a TargetConfiguration for reference fixed-point math // quantization scheme based on the FxpMathOps (plus a small category of diff --git a/mlir/include/mlir/Quantizer/Support/Configuration.h b/mlir/include/mlir/Quantizer/Support/Configuration.h index 17a472de30a..3732fbad3a2 100644 --- a/mlir/include/mlir/Quantizer/Support/Configuration.h +++ b/mlir/include/mlir/Quantizer/Support/Configuration.h @@ -1,19 +1,10 @@ //===- Configuration.h - Configuration object base classes ------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // The quantizer is relatively agnostic to source and target dialects, with // the specific represented by configuration policy objects derived from diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h index 202e86566fc..fe66848b906 100644 --- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h @@ -1,19 +1,10 @@ //===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file provides graph-based data structures for representing anchors // and constraints between them. diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h index 7e2b61d0496..35ec85f13b2 100644 --- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h @@ -1,19 +1,10 @@ //===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Provides graph traits for constraint analysis graphs. 
// diff --git a/mlir/include/mlir/Quantizer/Support/Metadata.h b/mlir/include/mlir/Quantizer/Support/Metadata.h index 6c327d9df7a..0545e78f917 100644 --- a/mlir/include/mlir/Quantizer/Support/Metadata.h +++ b/mlir/include/mlir/Quantizer/Support/Metadata.h @@ -1,19 +1,10 @@ //===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains top level types needed to construct constraint graphs, // including context/allocator support and concrete metadata structs for diff --git a/mlir/include/mlir/Quantizer/Support/Rules.h b/mlir/include/mlir/Quantizer/Support/Rules.h index 9d1e53df5c0..536dd7ea07e 100644 --- a/mlir/include/mlir/Quantizer/Support/Rules.h +++ b/mlir/include/mlir/Quantizer/Support/Rules.h @@ -1,19 +1,10 @@ //===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines helper classes and functions for managing state (facts), // merging and tracking modification for various data types important for diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h index 744c5b640ec..a24eecd3427 100644 --- a/mlir/include/mlir/Quantizer/Support/Statistics.h +++ b/mlir/include/mlir/Quantizer/Support/Statistics.h @@ -1,19 +1,10 @@ //===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines adapters for extracting various (per layer and per axis) // statistics over tensors. diff --git a/mlir/include/mlir/Quantizer/Support/TypeUtils.h b/mlir/include/mlir/Quantizer/Support/TypeUtils.h index 074f8b9e854..64ae5d65b57 100644 --- a/mlir/include/mlir/Quantizer/Support/TypeUtils.h +++ b/mlir/include/mlir/Quantizer/Support/TypeUtils.h @@ -1,19 +1,10 @@ //===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines various helper functions for manipulating types. The // process of quantizing typically involves a number of type manipulations diff --git a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h index 90b5fe12153..70c022c96a1 100644 --- a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h +++ b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h @@ -1,19 +1,10 @@ //===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a builder that lets you attach constraints necessary to // perform a variety of uniform quantization conversions to CAG anchors. diff --git a/mlir/include/mlir/Quantizer/Support/UniformSolvers.h b/mlir/include/mlir/Quantizer/Support/UniformSolvers.h index 98df671f81d..d6bd1a25ec3 100644 --- a/mlir/include/mlir/Quantizer/Support/UniformSolvers.h +++ b/mlir/include/mlir/Quantizer/Support/UniformSolvers.h @@ -1,19 +1,10 @@ //===- UniformSolvers.h - Uniform type solver algorithms --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines algorithms for solving uniform type parameters for various // conditions (i.e. fixed-point, affine, scale matching, etc). diff --git a/mlir/include/mlir/Quantizer/Transforms/Passes.h b/mlir/include/mlir/Quantizer/Transforms/Passes.h index 4fdea58daf4..3490f2953a4 100644 --- a/mlir/include/mlir/Quantizer/Transforms/Passes.h +++ b/mlir/include/mlir/Quantizer/Transforms/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines entry points to create passes to perform various kinds // of quantization related transforms. 
diff --git a/mlir/include/mlir/Support/DebugStringHelper.h b/mlir/include/mlir/Support/DebugStringHelper.h index 230ed231458..0fa342686ba 100644 --- a/mlir/include/mlir/Support/DebugStringHelper.h +++ b/mlir/include/mlir/Support/DebugStringHelper.h @@ -1,19 +1,10 @@ //===- DebugStringHelper.h - helpers to generate debug strings --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Convenience functions to make it easier to get a string representation for // ops that have a print method. For use in debugging output and errors diff --git a/mlir/include/mlir/Support/FileUtilities.h b/mlir/include/mlir/Support/FileUtilities.h index 5ce97223176..c13b39efc4f 100644 --- a/mlir/include/mlir/Support/FileUtilities.h +++ b/mlir/include/mlir/Support/FileUtilities.h @@ -1,19 +1,10 @@ //===- FileUtilities.h - utilities for working with files -------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Common utilities for working with files. // diff --git a/mlir/include/mlir/Support/Functional.h b/mlir/include/mlir/Support/Functional.h index e8bf394b110..f18677f806b 100644 --- a/mlir/include/mlir/Support/Functional.h +++ b/mlir/include/mlir/Support/Functional.h @@ -1,19 +1,10 @@ //===- Functional.h - Helpers for functional-style Combinators --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_SUPPORT_FUNCTIONAL_H_ #define MLIR_SUPPORT_FUNCTIONAL_H_ diff --git a/mlir/include/mlir/Support/JitRunner.h b/mlir/include/mlir/Support/JitRunner.h index 14b66a8cebd..71c1d7d5105 100644 --- a/mlir/include/mlir/Support/JitRunner.h +++ b/mlir/include/mlir/Support/JitRunner.h @@ -1,19 +1,10 @@ //===- JitRunner.h - MLIR CPU Execution Driver Library ----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is a library that provides a shared implementation for command line // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 91d145dd3ca..1885ebe609b 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -1,19 +1,10 @@ //===- LLVM.h - Import and forward declare core LLVM types ------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file forward declares and imports various common LLVM datatypes that // MLIR wants to use unqualified. 
diff --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h index a9fc77ceef8..418293c0f80 100644 --- a/mlir/include/mlir/Support/LogicalResult.h +++ b/mlir/include/mlir/Support/LogicalResult.h @@ -1,19 +1,10 @@ //===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_SUPPORT_LOGICAL_RESULT_H #define MLIR_SUPPORT_LOGICAL_RESULT_H diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h index 767677fbc5d..1fd0634e9e8 100644 --- a/mlir/include/mlir/Support/MathExtras.h +++ b/mlir/include/mlir/Support/MathExtras.h @@ -1,19 +1,10 @@ //===- MathExtras.h - Math functions relevant to MLIR -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains math functions relevant to MLIR. // diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h index be8e4328fb1..eac5ee765c2 100644 --- a/mlir/include/mlir/Support/MlirOptMain.h +++ b/mlir/include/mlir/Support/MlirOptMain.h @@ -1,19 +1,10 @@ //===- MlirOptMain.h - MLIR Optimizer Driver main ---------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Main entry function for mlir-opt for when built as standalone binary. // diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 9bae7acadd6..9a128611c6e 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -1,19 +1,10 @@ //===- STLExtras.h - STL-like extensions that are used by MLIR --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains stuff that should be arguably sunk down to the LLVM // Support/STLExtras.h file over time. diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index fe1f898957a..f505731a649 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -1,19 +1,10 @@ //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H #define MLIR_SUPPORT_STORAGEUNIQUER_H diff --git a/mlir/include/mlir/Support/StringExtras.h b/mlir/include/mlir/Support/StringExtras.h index 2f75c8e5d20..5fc6769c124 100644 --- a/mlir/include/mlir/Support/StringExtras.h +++ b/mlir/include/mlir/Support/StringExtras.h @@ -1,19 +1,10 @@ //===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains string utility functions used within MLIR. // diff --git a/mlir/include/mlir/Support/ToolUtilities.h b/mlir/include/mlir/Support/ToolUtilities.h index 13a3742f849..3175ebbdba5 100644 --- a/mlir/include/mlir/Support/ToolUtilities.h +++ b/mlir/include/mlir/Support/ToolUtilities.h @@ -1,19 +1,10 @@ //===- ToolUtilities.h - MLIR Tool Utilities --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares common utilities for implementing MLIR tools. // diff --git a/mlir/include/mlir/Support/TranslateClParser.h b/mlir/include/mlir/Support/TranslateClParser.h index ccd4fb97676..822d4b1a0a4 100644 --- a/mlir/include/mlir/Support/TranslateClParser.h +++ b/mlir/include/mlir/Support/TranslateClParser.h @@ -1,19 +1,10 @@ //===- TranslateClParser.h - Translations command line parser ---*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains custom command line parser for translations. // diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h index 83909392a43..6a0787e1b6c 100644 --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -1,19 +1,10 @@ //===- Argument.h - Argument definitions ------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file contains definitions for TableGen operation's arguments. // Operation arguments fall into two categories: diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 242376e24ff..747df945cea 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -1,19 +1,10 @@ //===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Attribute wrapper to simplify using TableGen Record defining a MLIR // Attribute. diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h index 17b60da6027..fb7c1d74b64 100644 --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -1,19 +1,10 @@ //===- Constraint.h - Constraint class --------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Constraint wrapper to simplify using TableGen Record for constraints. // diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h index 6861da46e88..56d17f41b56 100644 --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -1,18 +1,9 @@ // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Dialect wrapper to simplify using TableGen Record defining a MLIR dialect. // diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h index 6f02c283cad..160ba5f036a 100644 --- a/mlir/include/mlir/TableGen/Format.h +++ b/mlir/include/mlir/TableGen/Format.h @@ -1,19 +1,10 @@ //===- Format.h - Utilities for String Format -------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares utilities for formatting strings. They are specially // tailored to the needs of TableGen'ing op definitions and rewrite rules, diff --git a/mlir/include/mlir/TableGen/GenInfo.h b/mlir/include/mlir/TableGen/GenInfo.h index 0b0bd192ae5..3c732c2ff49 100644 --- a/mlir/include/mlir/TableGen/GenInfo.h +++ b/mlir/include/mlir/TableGen/GenInfo.h @@ -1,19 +1,10 @@ //===- GenInfo.h - Generator info -------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_TABLEGEN_GENINFO_H_ #define MLIR_TABLEGEN_GENINFO_H_ diff --git a/mlir/include/mlir/TableGen/GenNameParser.h b/mlir/include/mlir/TableGen/GenNameParser.h index 7b1e8a36d03..65f4a8ceace 100644 --- a/mlir/include/mlir/TableGen/GenNameParser.h +++ b/mlir/include/mlir/TableGen/GenNameParser.h @@ -1,19 +1,10 @@ //===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // The GenNameParser class adds all passes linked in to the system that are // creatable to the tool. diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h index 0959f6be9bb..9bf18161564 100644 --- a/mlir/include/mlir/TableGen/OpInterfaces.h +++ b/mlir/include/mlir/TableGen/OpInterfaces.h @@ -1,19 +1,10 @@ //===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpInterfaces wrapper to simplify using TableGen OpInterfaces. // diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h index c3ea9a7bda0..59fc7acbfd7 100644 --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -1,19 +1,10 @@ //===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait. // diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 89fd4ed8d2e..dd5ff353bf9 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -1,19 +1,10 @@ //===- Operator.h - Operator class ------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Operator wrapper to simplify using TableGen Record defining a MLIR Op. // diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 8bd1c918e31..bf89f6e7c82 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -1,19 +1,10 @@ //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Pattern wrapper class to simplify using TableGen Record defining a MLIR // Pattern. diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h index 49f7ebcfe52..045b7fece2e 100644 --- a/mlir/include/mlir/TableGen/Predicate.h +++ b/mlir/include/mlir/TableGen/Predicate.h @@ -1,19 +1,10 @@ //===- Predicate.h - Predicate class ----------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Wrapper around predicates defined in TableGen. // diff --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h index 21dffe687f4..778f68622bf 100644 --- a/mlir/include/mlir/TableGen/Region.h +++ b/mlir/include/mlir/TableGen/Region.h @@ -1,19 +1,10 @@ //===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_TABLEGEN_REGION_H_ #define MLIR_TABLEGEN_REGION_H_ diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h index 03cbd104dc1..35de70f52fd 100644 --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -1,19 +1,10 @@ //===- Type.h - Type class --------------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Type wrapper to simplify using TableGen Record defining a MLIR Type. // diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h index 7ed7b39c4db..1cdc26ccee6 100644 --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -1,19 +1,10 @@ //===- LLVMIR.h - MLIR to LLVM IR conversion --------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the entry point for the MLIR to LLVM IR conversion. // diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 7464e2a347d..4a5010ea09a 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -1,19 +1,10 @@ //===- ModuleTranslation.h - MLIR to LLVM conversion ------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the translation between an MLIR LLVM dialect module and // the corresponding LLVMIR module. It only handles core LLVM IR operations. diff --git a/mlir/include/mlir/Target/NVVMIR.h b/mlir/include/mlir/Target/NVVMIR.h index ec9858e0fd7..377ee16d4e4 100644 --- a/mlir/include/mlir/Target/NVVMIR.h +++ b/mlir/include/mlir/Target/NVVMIR.h @@ -1,19 +1,10 @@ //===- NVVMIR.h - MLIR to LLVM + NVVM IR conversion -------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the entry point for the MLIR to LLVM + NVVM IR conversion. // diff --git a/mlir/include/mlir/Target/ROCDLIR.h b/mlir/include/mlir/Target/ROCDLIR.h index fd00e9458ef..25937eedd5a 100644 --- a/mlir/include/mlir/Target/ROCDLIR.h +++ b/mlir/include/mlir/Target/ROCDLIR.h @@ -1,19 +1,10 @@ //===- ROCDLIR.h - MLIR to LLVM + ROCDL IR conversion -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the entry point for the MLIR to LLVM + ROCDL IR // conversion. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f9f1207c0a0..dca26348689 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1,19 +1,10 @@ //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares a generic pass for converting between MLIR dialects. // diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h index 65dd1b6df16..ed18619c44a 100644 --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -1,19 +1,10 @@ //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file declares various operation folding utilities. These // utilities are intended to be used by passes to unify and simply their logic. diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index 47c4f48f468..e4739bba66b 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -1,19 +1,10 @@ //===- InliningUtils.h - Inliner utilities ----------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines interfaces for various inlining utility methods. // diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h index af84b8911eb..4c307ffeda3 100644 --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -1,19 +1,10 @@ //===- LoopFusionUtils.h - Loop fusion utilities ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for various loop fusion utility // methods: these are not passes by themselves but are used either by passes, diff --git a/mlir/include/mlir/Transforms/LoopLikeInterface.h b/mlir/include/mlir/Transforms/LoopLikeInterface.h index a8bc0d11378..cba9ae78122 100644 --- a/mlir/include/mlir/Transforms/LoopLikeInterface.h +++ b/mlir/include/mlir/Transforms/LoopLikeInterface.h @@ -1,19 +1,10 @@ //===- LoopLikeInterface.h - Loop-like operations interface ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the operation interface for loop like operations. // diff --git a/mlir/include/mlir/Transforms/LoopLikeInterface.td b/mlir/include/mlir/Transforms/LoopLikeInterface.td index 583cfe26d87..089a3e19c35 100644 --- a/mlir/include/mlir/Transforms/LoopLikeInterface.td +++ b/mlir/include/mlir/Transforms/LoopLikeInterface.td @@ -1,19 +1,10 @@ //===- LoopLikeInterface.td - LoopLike interface -----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines the interface for loop-like operations as used by LICM. // diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 37434ea2ea8..a08a3fc8307 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -1,19 +1,10 @@ //===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for various loop transformation utility // methods: these are not passes by themselves but are used either by passes, diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 5480a9a4fe1..1ea8f060e39 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -1,19 +1,10 @@ //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes that expose pass constructors in the loop // transformation library. diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 63236d6a5a0..9639dfad857 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -1,19 +1,10 @@ //===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_TRANSFORMS_REGIONUTILS_H_ #define MLIR_TRANSFORMS_REGIONUTILS_H_ diff --git a/mlir/include/mlir/Transforms/SideEffectsInterface.h b/mlir/include/mlir/Transforms/SideEffectsInterface.h index 443596b60c1..69c2a272c70 100644 --- a/mlir/include/mlir/Transforms/SideEffectsInterface.h +++ b/mlir/include/mlir/Transforms/SideEffectsInterface.h @@ -1,19 +1,10 @@ //===- SideEffectsInterface.h - dialect interface modeling side effects ---===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file specifies a dialect interface to model side-effects. // diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index 02c368ec496..a8268c1daa2 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -1,19 +1,10 @@ //===- Utils.h - General transformation utilities ---------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This header file defines prototypes for various transformation utilities for // memref's and non-loop IR structures. 
These are not passes by themselves but diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h index 41f5eb5838d..c1782081adc 100644 --- a/mlir/include/mlir/Transforms/ViewOpGraph.h +++ b/mlir/include/mlir/Transforms/ViewOpGraph.h @@ -1,19 +1,10 @@ //===- ViewOpGraph.h - View/write op graphviz graphs ------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines interface to produce Graphviz outputs of MLIR op within block. // diff --git a/mlir/include/mlir/Transforms/ViewRegionGraph.h b/mlir/include/mlir/Transforms/ViewRegionGraph.h index 4378d38fae1..e8c47500c74 100644 --- a/mlir/include/mlir/Transforms/ViewRegionGraph.h +++ b/mlir/include/mlir/Transforms/ViewRegionGraph.h @@ -1,19 +1,10 @@ //===- ViewRegionGraph.h - View/write graphviz graphs -----------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines interface to produce Graphviz outputs of MLIR Regions. // diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h index 0bf8178146a..9244b971753 100644 --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -1,19 +1,10 @@ //===- Translation.h - Translation registry ---------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Registry for user-provided translations. // diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 60b2f17292b..27aa0748711 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -1,19 +1,10 @@ //===- AffineAnalysis.cpp - Affine structures analysis routines -----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous analysis routines for affine structures // (expressions, maps, sets), and other utilities relying on such analysis. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 21c2830c016..7ab547483cd 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1,19 +1,10 @@ //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Structures for affine/polyhedral analysis of MLIR functions. 
// diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 6ec7c059526..65f6e83bcdf 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -1,19 +1,10 @@ //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains interfaces and analyses for defining a nested callgraph. // diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 532972b771b..060a505593a 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -1,19 +1,10 @@ //===- Dominance.cpp - Dominator analysis for CFGs ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Implementation of dominance related classes and instantiations of extern // templates. diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp index cbbd44681ba..2e52de2b3fa 100644 --- a/mlir/lib/Analysis/InferTypeOpInterface.cpp +++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp @@ -1,19 +1,10 @@ //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the definitions of the infer op interfaces defined in // `InferTypeOpInterface.td`. diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index edb18e5645d..bef0b9fa385 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -1,19 +1,10 @@ //===- Liveness.cpp - Liveness analysis for MLIR --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Implementation of the liveness analysis. // diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 9dfbfe0c542..5499f887c1e 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -1,19 +1,10 @@ //===- LoopAnalysis.cpp - Misc loop analysis routines //-------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous loop analysis routines. 
// diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 4696ce64c22..1f7c1a1ae31 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -1,19 +1,10 @@ //===- MemRefBoundCheck.cpp - MLIR Affine Structures Class ----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to check memref accesses for out of bound // accesses. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 5f2be48b327..97eaafd37ce 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -1,19 +1,10 @@ //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/AffineOps/AffineOps.h" diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 1c9f6211a84..dbd938710ef 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -1,19 +1,10 @@ //===- OpStats.cpp - Prints stats of operations in module -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index b09bddddd66..befe3d39759 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -1,19 +1,10 @@ //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements Analysis functions specific to slicing in Function. // diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index 80a579d163f..c6d7519740e 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -1,19 +1,10 @@ //===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to run pair-wise memref access dependence checks. 
// diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index a9f9ea94a45..6cfc5431df3 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -1,19 +1,10 @@ //===- ParallelismDetection.cpp - Parallelism Detection pass ------------*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to detect parallel affine 'affine.for' ops. // diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 73aa07e7d7b..0e7d10e78cf 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -1,19 +1,10 @@ //===- Utils.cpp ---- Misc utilities for analysis -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous analysis routines for non-loop IR // structures. diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index a7917eba503..cd77eff9e40 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -1,19 +1,10 @@ //===- VectorAnalysis.cpp - Analysis for Vectorization --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index be499a93898..d4861b1a2e7 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -1,19 +1,10 @@ //===- Verifier.cpp - MLIR Verifier Implementation ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the verify() methods on the various IR types, performing // (potentially expensive) checks on the holistic structure of the code. This diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 144b4a97e87..ce1e5c4a2af 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -1,19 +1,10 @@ //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file lowers affine constructs (If and For statements, AffineApply // operations) within a function into their standard If and For equivalent ops. diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index a408ab5b5d9..2ca9717ad86 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -1,19 +1,10 @@ //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 3ab8e75633e..97881d359f6 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -1,19 +1,10 @@ //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index a91c43e1e92..66a2e66f99a 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -1,19 +1,10 @@ //===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert gpu kernel functions into a // corresponding binary blob that can be executed on a CUDA GPU. Currently diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index 840ad6ba701..3383cf13d36 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -1,19 +1,10 @@ //===- ConvertLaunchFuncToCudaCalls.cpp - MLIR CUDA lowering passes -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert gpu.launch_func op into a sequence of // CUDA runtime calls. 
As the CUDA runtime does not have a stable published ABI, diff --git a/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td index 8c27ba49686..0a6aec07041 100644 --- a/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td +++ b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td @@ -1,19 +1,10 @@ //==-- GPUToNVVM.td - GPU Ops to NVVM Patterns ---------------*- tablegen -*==// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Defines Patterns to lower GPU ops to NVVM. // diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index bf18ea03dab..e15ad823a2b 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -1,19 +1,10 @@ //===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to generate NVVMIR operations for higher-level // GPU operations. diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 59892dbcee8..83770641bd4 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -1,19 +1,10 @@ //===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to generate ROCDLIR operations for higher-level // GPU operations. diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 0c34fc2b8e1..95c46853b1f 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -1,19 +1,10 @@ //===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the conversion patterns from GPU ops to SPIR-V dialect. // diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index b8fe27e92a2..115096003e1 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -1,19 +1,10 @@ //===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert a kernel function in the GPU Dialect // into a spv.module operation diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 8b6b9fb7930..1b70df6f8bd 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -1,19 +1,10 @@ //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index d8df7487e71..59dac73de9c 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -1,19 +1,10 @@ //===- ConvertLoopToStandard.cpp - ControlFlow to CFG conversion ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert loop.for, loop.if and loop.terminator // ops into standard CFG ops. 
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 3cbce7caa76..24bb8ffc462 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -1,19 +1,10 @@ //===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This implements a straightforward conversion of an loop nest into a GPU // kernel. The caller is expected to guarantee that the conversion is correct diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 63836883512..4dfd26a4392 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -1,19 +1,10 @@ //===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 67b545c4ec8..160678efe9f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1,19 +1,10 @@ //===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index f7b0c9cb9bc..af1c92ef11d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -1,19 +1,10 @@ //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements patterns to convert Standard Ops to the SPIR-V dialect. // diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 113789abe8a..41deec1f6ab 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -1,19 +1,10 @@ //===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard ops into the SPIR-V // ops. diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 2e1a7f09ff8..5d693336c3f 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -1,19 +1,10 @@ //===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This transformation pass legalizes operations before the conversion to SPIR-V // dialect to handle ops that cannot be lowered directly. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5099cb01bbc..56005220d3f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1,19 +1,10 @@ //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp index 33778e42329..3ed031b985a 100644 --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -1,19 +1,10 @@ //===- VectorToLoops.cpp - Conversion from Vector to mix of Loops and Std -===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements target-dependent lowering of vector transfer operations. // diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 3a21de389c7..bfe72101e85 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -1,19 +1,10 @@ //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp index 9197e3c619f..775e25ec8ea 100644 --- a/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp +++ b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp @@ -1,19 +1,10 @@ //===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/AffineOps/AffineOps.h" using namespace mlir; diff --git a/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp index aa6782e1464..57d5ae8e789 100644 --- a/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp +++ b/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp @@ -1,19 +1,10 @@ //===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" diff --git a/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp index 18c07b07117..30e7dc04104 100644 --- a/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp +++ b/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp @@ -1,19 +1,10 @@ //===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index e1951ff900b..725751eb6c1 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -1,19 +1,10 @@ //===- LowerUniformRealMath.cpp ------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "UniformKernelUtils.h" diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 57a8422b362..bce5285a8b0 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -1,19 +1,10 @@ //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_ #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_ diff --git a/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp b/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp index af50d0270cf..511c69e0695 100644 --- a/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp +++ b/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp @@ -1,19 +1,10 @@ //===- DialectRegistration.cpp - MLIR GPU dialect registration ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/GPUDialect.h" diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 349c1fa4644..62d6a4b7ea4 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1,19 +1,10 @@ //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the GPU kernel-related dialect and its operations. // diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 8f5f50e4909..6a7cd290dd2 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -1,19 +1,10 @@ //===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the GPU dialect kernel outlining pass. // diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index b94ee335bd2..b8d2d242657 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1,19 +1,10 @@ //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the types and operation details for the LLVM IR dialect in // MLIR, and the LLVM IR dialect. It also registers the dialect. diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e4708fbe535..3a8e84ea918 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1,19 +1,10 @@ //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the types and operation details for the NVVM IR dialect in // MLIR, and the LLVM IR dialect. It also registers the dialect. diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 30c55b52e59..c11572cf5a2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -1,19 +1,10 @@ //===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the types and operation details for the ROCDL IR dialect in // MLIR, and the LLVM IR dialect. It also registers the dialect. diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index ee122e16037..5fbbdea60c2 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -1,19 +1,10 @@ //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements view-based alias and dependence analyses. 
 //
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 7b530d7f0df..af5e576b290 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -1,19 +1,10 @@
 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c5f30b7e10b..10c37c0ec43 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1,19 +1,10 @@
 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements a the Linalg operations.
 //
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 263a64c5cdc..32b1620f67c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -1,19 +1,10 @@
 //===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements the Linalg dialect types and dialect.
 //
diff --git a/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp b/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
index df21ffa88ac..768b18b57f0 100644
--- a/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
+++ b/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
@@ -1,19 +1,10 @@
 //===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 49cea7e4170..27dcf663d23 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -1,19 +1,10 @@
 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements the linalg dialect Fusion pass.
 //
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index e468c19a0b4..0f333791dd7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -1,19 +1,10 @@
 //===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 999406e05cf..451803797f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -1,19 +1,10 @@
 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements logic for transforming Linalg operations.
 //
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index b1dae455194..08bc1518a19 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -1,19 +1,10 @@
 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements the linalg dialect Promotion pass.
 //
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 07d559918cf..99645a23100 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -1,19 +1,10 @@
 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements the linalg dialect Tiling pass.
 //
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 125937807f4..ae02af0ecc8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1,19 +1,10 @@
 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements utilities for the Linalg dialect.
 //
diff --git a/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp b/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
index 5724402e690..6564e78855c 100644
--- a/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
+++ b/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
 //===- DialectRegistration.cpp - Register loop dialect --------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/LoopOps/LoopOps.h"
 using namespace mlir;
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index 9610a1ac270..d3040c1bbb2 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -1,19 +1,10 @@
 //===- Ops.cpp - Loop MLIR Operations -------------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp b/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
index b071248f4bb..1738d6d7277 100644
--- a/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
@@ -1,19 +1,10 @@
 //===- DialectRegistration.cpp - Register Quantization dialect ------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/QuantOps.h"
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index 51f19940dcb..faeff246bd2 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -1,19 +1,10 @@
 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "TypeDetail.h"
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
index bc8290cda16..2e33963602c 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
@@ -1,19 +1,10 @@
 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 #include "TypeDetail.h"
diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h b/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
index 13a88da3043..801a0de32b4 100644
--- a/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
@@ -1,19 +1,10 @@
 //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #ifndef TYPE_DETAIL_H_
 #define TYPE_DETAIL_H_
diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
index 2bdde1f94f8..2689a2dff89 100644
--- a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -1,19 +1,10 @@
 //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 61636dcdd8b..08a5ec59e8d 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -1,19 +1,10 @@
 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/Passes.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 83fa9237dee..2a4c14f2231 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -1,19 +1,10 @@
 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/Passes.h"
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index f4256cf25c8..cbd4315f832 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -1,19 +1,10 @@
 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
index 56e2cbae4f0..094fefee486 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -1,19 +1,10 @@
 //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/QuantizeUtils.h"
 #include "mlir/Dialect/QuantOps/UniformSupport.h"
diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
index 34e767dfee3..df002336c16 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -1,19 +1,10 @@
 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/QuantOps/UniformSupport.h"
 #include "mlir/IR/StandardTypes.h"
diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp
index 510e13e8028..03ffe3ffbb9 100644
--- a/mlir/lib/Dialect/SDBM/SDBM.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBM.cpp
@@ -1,19 +1,10 @@
 //===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
 // as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
index d3d895fec88..fab9463a866 100644
--- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
@@ -1,19 +1,10 @@
 //===- SDBMDialect.cpp - Dialect for striped difference-bound matrices ----===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/SDBM/SDBMDialect.h"
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index 44cdd18cf98..68e3e1c278e 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -1,19 +1,10 @@
 //===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // A striped difference-bound matrix (SDBM) expression is a constant expression,
 // an identifier, a binary expression with constant RHS and +, stripe operators
diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
index 0441200754c..fb80b45902e 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
+++ b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
@@ -1,19 +1,10 @@
 //===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This holds implementation details of SDBMExpr, in particular underlying
 // storage types.
diff --git a/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp b/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
index 63e9e812c39..431b40ef022 100644
--- a/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
@@ -1,19 +1,10 @@
 //===- DialectRegistration.cpp - MLIR SPIR-V dialect registration ---------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index 5db478d388b..a12d04edd68 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -1,19 +1,10 @@
 //===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements Utilities used to get alignment and layout information
 // for types in SPIR-V dialect.
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index ca9b883a703..7b6c013f9ed 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -1,19 +1,10 @@
 //===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements utilities used to lower to SPIR-V dialect.
 //
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index a20c18056e1..e42dc10f55d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1,19 +1,10 @@
 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines the operations in the SPIR-V dialect.
 //
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 15621aa5fde..18e027afb4c 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -1,19 +1,10 @@
 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines the types in the SPIR-V dialect.
 //
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 799828cb629..9e820c6f42b 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1,19 +1,10 @@
 //===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines the SPIR-V binary to MLIR SPIR-V module deserialization.
 //
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
index ba383b2cc6c..13405c9883d 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
@@ -1,19 +1,10 @@
 //===- SPIRVBinaryUtils.cpp - MLIR SPIR-V Binary Module Utilities ---------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines common utilities for SPIR-V binary module.
 //
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 9b47045ea61..7ff471dfda5 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1,19 +1,10 @@
 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
 //
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
index e9b4f23cca4..750710fa3d9 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -1,19 +1,10 @@
 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements a translation from SPIR-V binary module to MLIR SPIR-V
 // ModuleOp.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index be486f858fe..07621d6fa80 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -1,19 +1,10 @@
 //===- DecorateSPIRVCompositeTypeLayoutPass.cpp - Decorate composite type -===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements a pass to decorate the composite types used by
 // composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 93ce2c0a0d5..76e1b9b716e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -1,19 +1,10 @@
 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements a pass to lower attributes that specify the shader ABI
 // for the functions in the generated SPIR-V module.
diff --git a/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp b/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
index 6b5578f93cf..684806009e5 100644
--- a/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
+++ b/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
 //===- DialectRegistration.cpp - Register standard Op dialect -------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/StandardOps/Ops.h"
 using namespace mlir;
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 94166b5a7dd..55da59a0c74 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1,19 +1,10 @@
 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 0ac07c2c4f5..3aea206c07e 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -1,19 +1,10 @@
 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/StandardTypes.h"
diff --git a/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp b/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
index 0caa1cf629e..edd6abb4e2e 100644
--- a/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
+++ b/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
 //===- DialectRegistration.cpp - Register super vectorization dialect -----===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//

 #include "mlir/Dialect/VectorOps/VectorOps.h"
 using namespace mlir;
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 18c1714f403..8ceff014029 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1,19 +1,10 @@
 //===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file implements convenience types for working with super-vectorization
 // operations, in particular super-vector loads and stores.
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index e5c281cbf64..927aeda4ecd 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -1,19 +1,10 @@ //===- VectorToLoops.cpp - Conversion within the Vector dialect -----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements target-independent rewrites as 1->N patterns. // diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 35108ed5666..b25eb987a9e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -1,19 +1,10 @@ //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index 46199c29c14..6f7c1728bb0 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -1,19 +1,10 @@ //===- Types.cpp - Implementations of MLIR Core C APIs --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir-c/Core.h" diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index 1771eb0a427..79888334cd9 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -1,19 +1,10 @@ //===- Helpers.cpp - MLIR Declarative Helper Functionality ----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/EDSC/Helpers.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index c6738c42993..1bb32b97867 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -1,19 +1,10 @@ //===- Intrinsics.cpp - MLIR Operations for Declarative Builders ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/EDSC/Intrinsics.h" #include "mlir/EDSC/Builders.h" diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 5098ba81762..1537018076a 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -1,19 +1,10 @@ //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the execution engine for MLIR modules based on LLVM Orc // JIT engine. diff --git a/mlir/lib/ExecutionEngine/OptUtils.cpp b/mlir/lib/ExecutionEngine/OptUtils.cpp index dc3bd20794e..ec2ae5f2dcc 100644 --- a/mlir/lib/ExecutionEngine/OptUtils.cpp +++ b/mlir/lib/ExecutionEngine/OptUtils.cpp @@ -1,19 +1,10 @@ //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the utility functions to trigger LLVM optimizations from // MLIR Execution Engine. diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 009c1a1485c..dd8ce00c82a 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -1,19 +1,10 @@ //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/AffineExpr.h" #include "AffineExprDetail.h" diff --git a/mlir/lib/IR/AffineExprDetail.h b/mlir/lib/IR/AffineExprDetail.h index 214fee65056..8824ddd8682 100644 --- a/mlir/lib/IR/AffineExprDetail.h +++ b/mlir/lib/IR/AffineExprDetail.h @@ -1,19 +1,10 @@ //===- AffineExprDetail.h - MLIR Affine Expr storage details ----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of AffineExpr. Ideally it would not be // exposed and would be kept local to AffineExpr.cpp however, MLIRContext.cpp diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 6cfef363985..50624afa3eb 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -1,19 +1,10 @@ //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/AffineMap.h" #include "AffineMapDetail.h" diff --git a/mlir/lib/IR/AffineMapDetail.h b/mlir/lib/IR/AffineMapDetail.h index a247783540c..f00c4ba216e 100644 --- a/mlir/lib/IR/AffineMapDetail.h +++ b/mlir/lib/IR/AffineMapDetail.h @@ -1,19 +1,10 @@ //===- AffineMapDetail.h - MLIR Affine Map details Class --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of AffineMap. // diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 177d8a5ef05..a574f87c530 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1,19 +1,10 @@ //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the MLIR AsmPrinter class, which is used to implement // the various print() methods on the core IR objects. diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index da4aa69dda4..c78d49c0f87 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -1,19 +1,10 @@ //===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of Attribute. // diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index bb35a63bf5d..3a9c91f6f77 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -1,19 +1,10 @@ //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Attributes.h" #include "AttributeDetail.h" diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 894f9ba38d0..b168a8facd2 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -1,19 +1,10 @@ //===- Block.cpp - MLIR Block Class ---------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 733fcd13994..2ef10b6e669 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -1,19 +1,10 @@ //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 59e16a48865..6ec92f05370 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -1,19 +1,10 @@ //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Attributes.h" diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index c6266b09668..b2485a368fd 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Dialect implementation -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index b51c77f34c2..72b5ac46a8f 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -1,19 +1,10 @@ //===- Function.cpp - MLIR Function Classes -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Function.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index 9cec216468d..79863bc74f4 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -1,19 +1,10 @@ //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/IR/IntegerSet.cpp b/mlir/lib/IR/IntegerSet.cpp index ce50fa7cc5b..835b4c3a7e2 100644 --- a/mlir/lib/IR/IntegerSet.cpp +++ b/mlir/lib/IR/IntegerSet.cpp @@ -1,19 +1,10 @@ //===- IntegerSet.cpp - MLIR Integer Set class ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/IntegerSet.h" #include "IntegerSetDetail.h" diff --git a/mlir/lib/IR/IntegerSetDetail.h b/mlir/lib/IR/IntegerSetDetail.h index b3eda5205fb..54ffd47bd47 100644 --- a/mlir/lib/IR/IntegerSetDetail.h +++ b/mlir/lib/IR/IntegerSetDetail.h @@ -1,19 +1,10 @@ //===- IntegerSetDetail.h - MLIR IntegerSet storage details -----*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of IntegerSet. // diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index 1ea75d5e30e..e23a73647a4 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -1,19 +1,10 @@ //===- Location.cpp - MLIR Location Classes -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Location.h" #include "LocationDetail.h" diff --git a/mlir/lib/IR/LocationDetail.h b/mlir/lib/IR/LocationDetail.h index 6ccaa17018c..a47a2111c4f 100644 --- a/mlir/lib/IR/LocationDetail.h +++ b/mlir/lib/IR/LocationDetail.h @@ -1,19 +1,10 @@ //===- LocationDetail.h - MLIR Location storage details ---------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of the location attributes. // diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index d3feca14477..42d77ae2a3d 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -1,19 +1,10 @@ //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/MLIRContext.h" #include "AffineExprDetail.h" diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp index c52a55b20fe..c5af227459c 100644 --- a/mlir/lib/IR/Module.cpp +++ b/mlir/lib/IR/Module.cpp @@ -1,19 +1,10 @@ //===- Module.cpp - MLIR Module Operation ---------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Module.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 53399ce00a3..1dc7cb4bafd 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1,19 +1,10 @@ //===- Operation.cpp - Operation support code -----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Operation.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 333685a16fd..1c68686a0cb 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -1,19 +1,10 @@ //===- OperationSupport.cpp -----------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains out-of-line implementations of the support types that // Operation and related classes build on top of. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 3887a0308b0..d5749fabc07 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -1,19 +1,10 @@ //===- PatternMatch.cpp - Base classes for pattern match ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/PatternMatch.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 26f14c43424..935854a5365 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -1,19 +1,10 @@ //===- Region.cpp - MLIR Region Class -------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Region.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 7c494e219e8..441b59ed9cd 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -1,19 +1,10 @@ //===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/StandardTypes.h" #include "TypeDetail.h" diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index bd8cb59cea7..83e5802093c 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -1,19 +1,10 @@ //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/SymbolTable.h" #include "llvm/ADT/SmallString.h" diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 5bcb0b61aa5..b3e0edd3a57 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -1,19 +1,10 @@ //===- TypeDetail.h - MLIR Type storage details -----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This holds implementation details of Type. // diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 8200e3a3bc6..8bc67e46fdc 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -1,19 +1,10 @@ //===- TypeUtilities.cpp - Helper function for type queries ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines generic type utilities. // diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 23c80c96aad..923d6e16f57 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -1,19 +1,10 @@ //===- Types.cpp - MLIR Type Classes --------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Types.h" #include "TypeDetail.h" diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 660d8ae3248..d723eec8b29 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -1,19 +1,10 @@ //===- Value.cpp - MLIR Value Classes -------------------------------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Value.h" #include "mlir/IR/Block.h" diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp index ea2a6d69418..404e74a82c9 100644 --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -1,19 +1,10 @@ //===- Visitors.cpp - MLIR Visitor Utilties -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Visitors.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index 29104c82e23..7d8337a9cb3 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -1,19 +1,10 @@ //===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the lexer for the MLIR textual form. // diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h index a7a2ac4214c..a760dca9396 100644 --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -1,19 +1,10 @@ //===- Lexer.h - MLIR Lexer Interface ---------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file declares the MLIR Lexer class. // diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f78704842fe..e25f4d19654 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1,19 +1,10 @@ //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the parser for the MLIR textual form. // diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp index c01d6032cbd..84de4c396f4 100644 --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -1,19 +1,10 @@ //===- Token.cpp - MLIR Token Implementation ------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the Token class for the MLIR textual form. // diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index 333c4d29aad..7487736fac7 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -1,19 +1,10 @@ //===- Token.h - MLIR Token Interface ---------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_LIB_PARSER_TOKEN_H #define MLIR_LIB_PARSER_TOKEN_H diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 19cd343274d..fc9f7821f1a 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -1,19 +1,10 @@ //===- TokenKinds.def - MLIR Token Description ------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file is intended to be #include'd multiple times to extract information // about tokens for various clients in the lexer. 
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 9d1c1f0d391..132a0bec4b7 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -1,19 +1,10 @@ //===- IRPrinting.cpp -----------------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/Module.h" diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index f893c7babf9..22e58cc5b63 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -1,19 +1,10 @@ //===- Pass.cpp - Pass infrastructure implementation ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements common pass infrastructure. // diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index d0a2ea63e7d..9a52535bedf 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -1,19 +1,10 @@ //===- PassDetail.h - MLIR Pass details -------------------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_PASS_PASSDETAIL_H_ #define MLIR_PASS_PASSDETAIL_H_ diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index c29e0d08869..87487069d97 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -1,19 +1,10 @@ //===- PassManagerOptions.cpp - PassManager Command Line Options ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 1a321d666c4..93753d363db 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -1,19 +1,10 @@ //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp index 530697421ef..0ab656c2054 100644 --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -1,19 +1,10 @@ //===- PassStatistics.cpp -------------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index 113b65a09b5..93e640e7890 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -1,19 +1,10 @@ //===- PassTiming.cpp -----------------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp index 94e364238c5..ba9c078a765 100644 --- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -1,19 +1,10 @@ //===- FxpMathConfig.cpp - Reference fixed point config -------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a TargetConfiguration for reference fixed-point math // quantization scheme based on the FxpMathOps (plus a small category of diff --git a/mlir/lib/Quantizer/Support/Configuration.cpp b/mlir/lib/Quantizer/Support/Configuration.cpp index 78a74514f8b..f64cc85f0f7 100644 --- a/mlir/lib/Quantizer/Support/Configuration.cpp +++ b/mlir/lib/Quantizer/Support/Configuration.cpp @@ -1,19 +1,10 @@ //===- Configuration.cpp - Configuration object base classes --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/Configuration.h" diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp index 13fed0f9b1c..38aa5dc811b 100644 --- a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -1,19 +1,10 @@ //===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" diff --git a/mlir/lib/Quantizer/Support/Metadata.cpp b/mlir/lib/Quantizer/Support/Metadata.cpp index 89478c4209d..b7badfd5f87 100644 --- a/mlir/lib/Quantizer/Support/Metadata.cpp +++ b/mlir/lib/Quantizer/Support/Metadata.cpp @@ -1,19 +1,10 @@ //===- Metadata.cpp - Top level types and metadata ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/Metadata.h" diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index 6753898dbdc..3c8b041e244 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -1,19 +1,10 @@ //===- Statistics.cpp - Collects statistics over tensors ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/Statistics.h" diff --git a/mlir/lib/Quantizer/Support/TypeUtils.cpp b/mlir/lib/Quantizer/Support/TypeUtils.cpp index fab4e565308..a1f52c585a1 100644 --- a/mlir/lib/Quantizer/Support/TypeUtils.cpp +++ b/mlir/lib/Quantizer/Support/TypeUtils.cpp @@ -1,19 +1,10 @@ //===- TypeUtils.cpp - Helper function for manipulating types -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/TypeUtils.h" diff --git a/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/mlir/lib/Quantizer/Support/UniformConstraints.cpp index 1a800dad4ac..b20213568a1 100644 --- a/mlir/lib/Quantizer/Support/UniformConstraints.cpp +++ b/mlir/lib/Quantizer/Support/UniformConstraints.cpp @@ -1,19 +1,10 @@ //===- UniformConstraints.cpp - Constraints for uniform quant -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/UniformConstraints.h" diff --git a/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/mlir/lib/Quantizer/Support/UniformSolvers.cpp index 77d69be8382..2f6bb20792f 100644 --- a/mlir/lib/Quantizer/Support/UniformSolvers.cpp +++ b/mlir/lib/Quantizer/Support/UniformSolvers.cpp @@ -1,19 +1,10 @@ //===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/UniformSolvers.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index a3cbe214040..a27f09bf942 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -1,19 +1,10 @@ //===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a testing pass to add default statistics nodes to every // quantization eligible op. Useful for unit testing. diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index 68c263bc423..c8569c2fe19 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -1,19 +1,10 @@ //===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines the primary pass for instantiating a CAG, running it to // convergence on a module to determine eligible quantized type transforms, and diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index 0266520bec3..da5bd12ea1c 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -1,19 +1,10 @@ //===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a pass to remove any instrumentation ops. It is often one // of the final steps when performing quantization and is run after any diff --git a/mlir/lib/Support/FileUtilities.cpp b/mlir/lib/Support/FileUtilities.cpp index 6f0dc93b235..a56ae57ba25 100644 --- a/mlir/lib/Support/FileUtilities.cpp +++ b/mlir/lib/Support/FileUtilities.cpp @@ -1,19 +1,10 @@ //===- FileUtilities.cpp - utilities for working with files ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Definitions of common utilities for working with files. // diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp index dcd23437401..b327d3d4756 100644 --- a/mlir/lib/Support/JitRunner.cpp +++ b/mlir/lib/Support/JitRunner.cpp @@ -1,19 +1,10 @@ //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is a library that provides a shared implementation for command line // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp index c256e970c95..4a76801211c 100644 --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -1,19 +1,10 @@ //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is a utility that runs an optimization pass and prints the result back // out. It is designed to support unit testing. diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index cae4dce143f..d6f6bac4236 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -1,19 +1,10 @@ //===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Support/StorageUniquer.h" diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp index 60d0eee6b8a..cd2df7809b7 100644 --- a/mlir/lib/Support/ToolUtilities.cpp +++ b/mlir/lib/Support/ToolUtilities.cpp @@ -1,19 +1,10 @@ //===- ToolUtilities.cpp - MLIR Tool Utilities ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines common utilities for implementing MLIR tools. // diff --git a/mlir/lib/Support/TranslateClParser.cpp b/mlir/lib/Support/TranslateClParser.cpp index 115c0c03f50..1f538cb531d 100644 --- a/mlir/lib/Support/TranslateClParser.cpp +++ b/mlir/lib/Support/TranslateClParser.cpp @@ -1,19 +1,10 @@ //===- TranslateClParser.h - Translations command line parser -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains custom command line parser for translations. // diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp index 17dba054e4f..080e717092e 100644 --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -1,19 +1,10 @@ //===- Argument.cpp - Argument definitions --------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/TableGen/Argument.h" #include "llvm/TableGen/Record.h" diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index ec946a855fc..92f5b1f7d9f 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -1,19 +1,10 @@ //===- Attribute.cpp - Attribute wrapper class ----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Attribute wrapper to simplify using TableGen Record defining a MLIR // Attribute. diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp index ef3fa5271fa..022c5ad04df 100644 --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -1,19 +1,10 @@ //===- Constraint.cpp - Constraint class ----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Constraint wrapper to simplify using TableGen Record for constraints. // diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp index ace4ce3d0f6..d9e8e2f7154 100644 --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -1,19 +1,10 @@ //===- Dialect.cpp - Dialect wrapper class --------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Dialect wrapper to simplify using TableGen Record defining a MLIR dialect. // diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp index 967d51a61f7..07742ab6a40 100644 --- a/mlir/lib/TableGen/Format.cpp +++ b/mlir/lib/TableGen/Format.cpp @@ -1,19 +1,10 @@ //===- Format.cpp - Utilities for String Format ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines utilities for formatting strings. They are specially // tailored to the needs of TableGen'ing op definitions and rewrite rules, diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp index 1687f3ac795..b1e56efc029 100644 --- a/mlir/lib/TableGen/OpInterfaces.cpp +++ b/mlir/lib/TableGen/OpInterfaces.cpp @@ -1,19 +1,10 @@ //===- OpInterfaces.cpp - OpInterfaces class ------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpInterfaces wrapper to simplify using TableGen OpInterfaces. // diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp index 0e436a87497..86e34cd46b5 100644 --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -1,19 +1,10 @@ //===- OpTrait.cpp - OpTrait class ----------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait. // diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 3825363bec0..d61eec4ad44 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -1,19 +1,10 @@ //===- Operator.cpp - Operator class --------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 
// diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index e8f44087b85..1045b784ae2 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -1,19 +1,10 @@ //===- Pattern.cpp - Pattern wrapper class --------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Pattern wrapper class to simplify using TableGen Record defining a MLIR // Pattern. diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp index f8f23e04c3f..c52e15dbdea 100644 --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -1,19 +1,10 @@ //===- Predicate.cpp - Predicate class ------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Wrapper around predicates defined in TableGen. // diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index a558be4c89d..9a309bdde46 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -1,19 +1,10 @@ //===- Type.cpp - Type class ----------------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Type wrapper to simplify using TableGen Record defining a MLIR Type. // diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 7273d3dfd7b..6f3e2ef21aa 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1,19 +1,10 @@ //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a translation between LLVM IR and the MLIR LLVM dialect. // diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index e69dce7b59b..4cc59974960 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -1,19 +1,10 @@ //===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion -------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a translation between the MLIR LLVM dialect and LLVM IR. 
// diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 8baed9854f1..a5992174df3 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -1,19 +1,10 @@ //===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion -------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a translation between the MLIR LLVM + NVVM dialects and // LLVM IR with NVVM intrinsics and metadata. diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index f119b138e13..881d165e0c8 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -1,19 +1,10 @@ //===- ConvertToROCDLIR.cpp - MLIR to LLVM IR conversion ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a translation between the MLIR LLVM + ROCDL dialects and // LLVM IR with ROCDL intrinsics and metadata. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ec28434b823..e8376364c41 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1,19 +1,10 @@ //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements the translation between an MLIR LLVM dialect module and // the corresponding LLVMIR module. It only handles core LLVM IR operations. diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 5bc33943e50..1e1b8775d32 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -1,19 +1,10 @@ //===- AffineDataCopyGeneration.cpp - Explicit memref copying pass ------*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to automatically promote accessed memref regions // to buffers in a faster memory space that is explicitly managed, with the diff --git a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp index 23199dd8a39..1f33c0f5dca 100644 --- a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -1,19 +1,10 @@ //===- AffineLoopInvariantCodeMotion.cpp - Code to perform loop fusion-----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop invariant code motion. // diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 18f9fce5e46..714fb1d0109 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -1,19 +1,10 @@ //===- CSE.cpp - Common Sub-expression Elimination ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This transformation pass performs a simple common sub-expression elimination // algorithm on operations within a function. diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 7dcdeb67cdc..5b3a1eb1cf3 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -1,19 +1,10 @@ //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This transformation pass converts operations into their canonical forms by // folding constants, applying operation identity transformations etc. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 05066ef599c..a19274acd1b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1,19 +1,10 @@ //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" #include "mlir/IR/Block.h" diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index b158948069e..b2cee7da083 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -1,19 +1,10 @@ //===- Inliner.cpp - Pass to inline function calls ------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a basic inlining algorithm that operates bottom up over // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp index c1eec56526e..2aee688c6c1 100644 --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -1,19 +1,10 @@ //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 60f0264eb35..51e30ba7163 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1,19 +1,10 @@ //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop fusion. // diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index bd58827d001..93c80822fb3 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -1,19 +1,10 @@ //===- LoopInvariantCodeMotion.cpp - Code to perform loop fusion-----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop invariant code motion. // diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 361a4d8ecb9..5389c7e4429 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -1,19 +1,10 @@ //===- LoopTiling.cpp --- Loop tiling pass ------------------------------*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to tile loop nests. // diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 40f48ada4d7..e94c6c8b0bb 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -1,19 +1,10 @@ //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop unrolling. // diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index a857b8ec95a..3cefcaacadc 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -1,19 +1,10 @@ //===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop unroll and jam. 
Unroll and jam is a transformation // that improves locality, in particular, register reuse, while also improving diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 0695aafe171..957f41a9d3e 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -1,19 +1,10 @@ //===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to forward memref stores to loads, thereby // potentially getting rid of intermediate memref's entirely. diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 4162936ea2d..12ce6c66abd 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -1,19 +1,10 @@ //===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to pipeline data transfers. // diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 9512ff738aa..217e06bc877 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -1,19 +1,10 @@ //===- SimplifyAffineStructures.cpp ---------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to simplify affine structures. // diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index 772df3da3c7..cdfc7fd7e41 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -1,19 +1,10 @@ //===- StripDebugInfo.cpp - Pass to strip debug information ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index 85d1f21305e..ce39625831a 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -1,19 +1,10 @@ //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines various operation fold utilities. These utilities are // intended to be used by passes to unify and simply their logic. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index fe4a6f9f9e0..3ab4e287bb2 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -1,19 +1,10 @@ //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements mlir::applyPatternsGreedily. // diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 048130c0d3a..e7b34bb3956 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -1,19 +1,10 @@ //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous inlining utilities. // diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index d5cda3265de..4745a26e168 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -1,19 +1,10 @@ //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements loop fusion transformation utility functions. // diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index bc1ced408a9..3d4db22c866 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1,19 +1,10 @@ //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous loop transformation routines. // diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 749d5bf1dd0..569c5416edd 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1,19 +1,10 @@ //===- RegionUtils.cpp - Region-related transformation utilities ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" #include "mlir/IR/Block.h" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 96a6cdc544f..409729a5f20 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -1,19 +1,10 @@ //===- Utils.cpp ---- Misc utilities for code and data transformation -----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements miscellaneous transformation routines for non-loop IR // structures. diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d8f5b1dc0e4..2dbac868cc0 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1,19 +1,10 @@ //===- Vectorize.cpp - Vectorize Pass Impl --------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements vectorization of loops, operations and data types to // a target-independent, n-D super-vector abstraction. diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 591562d0245..508c547a52b 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -1,19 +1,10 @@ //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Transforms/ViewOpGraph.h" #include "mlir/IR/Block.h" diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp index db55415d62e..77111087d07 100644 --- a/mlir/lib/Transforms/ViewRegionGraph.cpp +++ b/mlir/lib/Transforms/ViewRegionGraph.cpp @@ -1,19 +1,10 @@ //===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Transforms/ViewRegionGraph.h" #include "mlir/IR/RegionGraphTraits.h" diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp index 8b5f98714e4..80c1e483731 100644 --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -1,19 +1,10 @@ //===- Translation.cpp - Translation registry -----------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Definitions of the translation registry. // diff --git a/mlir/test/APITest.h b/mlir/test/APITest.h index 9475bae2b58..08d64a0e48d 100644 --- a/mlir/test/APITest.h +++ b/mlir/test/APITest.h @@ -1,19 +1,10 @@ //===- Test.h - Simple macros for API unit tests ----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file define simple macros for declaring test functions and running them. // The actual checking must be performed on the outputs with FileCheck. diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 376fc249a18..cfe703281e2 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -1,19 +1,10 @@ //===- builder-api-test.cpp - Tests for Declarative Builder APIs ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // RUN: mlir-edsc-builder-api-test | FileCheck %s diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp index 12aff301d88..a672290d01d 100644 --- a/mlir/test/SDBM/sdbm-api-test.cpp +++ b/mlir/test/SDBM/sdbm-api-test.cpp @@ -1,19 +1,10 @@ //===- sdbm-api-test.cpp - Tests for SDBM expression APIs -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // RUN: mlir-sdbm-api-test | FileCheck %s diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index d2313927398..d07f6060c3b 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -1,19 +1,10 @@ //===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the pattern definition file for declarative Linalg transformations // tests. diff --git a/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td index 228a8a018d6..29875ccd543 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td @@ -1,19 +1,10 @@ //===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is the pattern definition file for declarative Vector transformations // tests. diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp index 880d0785bb5..3e131590fae 100644 --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -1,19 +1,10 @@ //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp index 5985a88ffa6..b62daa8437c 100644 --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -1,19 +1,10 @@ //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Function.h" diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp index 8ef4bb48a1c..c8fb1d8eecf 100644 --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -1,19 +1,10 @@ //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/IR/Function.h" diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index d1e1a6d13ee..2e811634880 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -1,19 +1,10 @@ //===- TestPassManager.cpp - Test pass manager functionality --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 12d024f6593..976a1976f01 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -1,19 +1,10 @@ //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/IR/Function.h" diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h index 783b8a1bcdd..20db0f39b81 100644 --- a/mlir/test/lib/TestDialect/TestDialect.h +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -1,19 +1,10 @@ //===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines a fake 'test' dialect that can be used for testing things // that do not have a respective counterpart in the main source directories. diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index ea071f0ddf4..a8709ddca27 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -1,19 +1,10 @@ //===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef TEST_OPS #define TEST_OPS diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 1f6224dba3a..b886097202d 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -1,19 +1,10 @@ //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp index debf5e77645..6378d953648 100644 --- a/mlir/test/lib/Transforms/TestCallGraph.cpp +++ b/mlir/test/lib/Transforms/TestCallGraph.cpp @@ -1,19 +1,10 @@ //===- TestCallGraph.cpp - Test callgraph construction and iteration ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains test passes for constructing and iterating over a // callgraph. diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index 5a0e9ed3f3c..f660bccca1d 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -1,19 +1,10 @@ //===- TestConstantFold.cpp - Pass to test constant folding ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp index 0571dc62b73..36378283f8e 100644 --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -1,19 +1,10 @@ //===- TestInlining.cpp - Pass to inline calls in the test dialect --------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // TODO(riverriddle) This pass is only necessary because the main inlining pass // has no abstracted away the call+callee relationship. When the inlining diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 37030ca2059..6ea995d3dfe 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -1,19 +1,10 @@ //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg transformations. // diff --git a/mlir/test/lib/Transforms/TestLiveness.cpp b/mlir/test/lib/Transforms/TestLiveness.cpp index d97060247f4..23725740df4 100644 --- a/mlir/test/lib/Transforms/TestLiveness.cpp +++ b/mlir/test/lib/Transforms/TestLiveness.cpp @@ -1,20 +1,11 @@ //===- TestLiveness.cpp - Test liveness construction and information //-------===// // -// Copyright 2019 The MLIR Authors. 
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains test passes for constructing and resolving liveness // information. diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 7dc722f21f6..23e5035153e 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -1,19 +1,10 @@ //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to test various loop fusion utility functions. // diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index 7f587fc3170..5b1394d5996 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -1,19 +1,10 @@ //===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to parametrically map loop.for loops to virtual // processing element dimensions. diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index 9a8e1917e1f..7b0cdcade4d 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -1,19 +1,10 @@ //===- TestLoopParametricTiling.cpp --- Parametric loop tiling pass -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a pass to parametrically tile nests of standard loops. // diff --git a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp index 40788b259c5..d5e0b7df02b 100644 --- a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp @@ -1,19 +1,10 @@ //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp index 0db53322fb8..9a261c0bb3b 100644 --- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp +++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp @@ -1,19 +1,10 @@ //===- TestOpaqueLoc.cpp - Pass to test opaque locations ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" diff --git a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp index e5f5f749bd0..a31f8e474b4 100644 --- a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp @@ -1,19 +1,10 @@ //===- TestVectorToLoopsConversion.cpp - Test VectorTransfers lowering ----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 1d513065330..664d49ab4e5 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -1,19 +1,10 @@ //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index 35df0631ca7..e131f4803ef 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -1,19 +1,10 @@ //===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file implements a simple testing pass for vectorization functionality. // diff --git a/mlir/test/mlir-cpu-runner/cblas.cpp b/mlir/test/mlir-cpu-runner/cblas.cpp index d219b7b1256..aebb8f212b4 100644 --- a/mlir/test/mlir-cpu-runner/cblas.cpp +++ b/mlir/test/mlir-cpu-runner/cblas.cpp @@ -1,19 +1,10 @@ //===- cblas.cpp - Simple Blas subset implementation ----------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Simple Blas subset implementation. 
// diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp index 831702f594e..5e3a00e7fd1 100644 --- a/mlir/test/mlir-cpu-runner/cblas_interface.cpp +++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp @@ -1,19 +1,10 @@ //===- cblas_interface.cpp - Simple Blas subset interface -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Simple Blas subset interface implementation. // diff --git a/mlir/test/mlir-cpu-runner/include/cblas.h b/mlir/test/mlir-cpu-runner/include/cblas.h index 522ac8380b9..ccd316ff52e 100644 --- a/mlir/test/mlir-cpu-runner/include/cblas.h +++ b/mlir/test/mlir-cpu-runner/include/cblas.h @@ -1,19 +1,10 @@ //===- cblas.h - Simple Blas subset ---------------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CPU_RUNNER_CBLAS_H_ #define MLIR_CPU_RUNNER_CBLAS_H_ diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h index d4b6e1fedb0..1f4e638c33e 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h +++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @@ -1,19 +1,10 @@ //===- mlir_runner_utils.h - Utils for debugging MLIR CPU execution -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #ifndef MLIR_CPU_RUNNER_MLIRUTILS_H_ #define MLIR_CPU_RUNNER_MLIRUTILS_H_ diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp index 56829c629bb..fc3a782c080 100644 --- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp +++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp @@ -1,19 +1,10 @@ //===- mlir_runner_utils.cpp - Utils for MLIR CPU execution ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Utilities for interfacing MLIR types with C code as well as printing, // debugging etc. diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp index f7023c4cf61..144f73d9c97 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -1,19 +1,10 @@ //===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // Main entry point to a command line utility that executes an MLIR file on the // CPU by translating MLIR to LLVM IR before JIT-compiling and executing the diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp index 0698095afcf..9f1591b5a8c 100644 --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -1,19 +1,10 @@ //===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Implements C wrappers around the CUDA library for easy linking in ORC jit. // Also adds some debugging helpers that are helpful when writing MLIR code to diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index c1ca4ebd8e1..d6160d6d6e0 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -1,19 +1,10 @@ //===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // This is a command line utility that executes an MLIR file on the GPU by // translating MLIR to NVVM/LVVM IR before JIT-compiling and executing the diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d01f66d4e0b..b0dd1b59ce7 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -1,19 +1,10 @@ //===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // Main entry function for mlir-opt for when built as standalone binary. // diff --git a/mlir/tools/mlir-tblgen/DocGenUtilities.h b/mlir/tools/mlir-tblgen/DocGenUtilities.h index b7617742727..1b3c8541aee 100644 --- a/mlir/tools/mlir-tblgen/DocGenUtilities.h +++ b/mlir/tools/mlir-tblgen/DocGenUtilities.h @@ -1,19 +1,10 @@ //===- DocGenUtilities.h - MLIR doc gen utilities ---------------*- C++ -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file defines common utilities for generating documents from tablegen // structures. diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index e278fdd80e8..610a380dab3 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -1,19 +1,10 @@ //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // EnumsGen generates common utility functions for enums. // diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index f4b1279f11e..30f720e8d73 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -1,19 +1,10 @@ //===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file uses tablegen definitions of the LLVM IR Dialect operations to // generate the code building the LLVM IR from it. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index df8feb855c5..52cbc08c429 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1,19 +1,10 @@ //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // OpDefinitionsGen uses the description of operations to generate C++ // definitions for ops. diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp index 8b048d9ea94..87a27238ce3 100644 --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -1,19 +1,10 @@ //===- OpDocGen.cpp - MLIR operation documentation generator --------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpDocGen uses the description of operations to generate documentation for the // operations. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index a48bd2509bc..a96736cd2c5 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -1,19 +1,10 @@ //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // OpInterfacesGen generates definitions for operation interfaces. // diff --git a/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp b/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp index 9181d0e90ed..90b60e5efed 100644 --- a/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp +++ b/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp @@ -1,19 +1,10 @@ //===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // ReferenceImplGen uses the description of operations to generate reference // implementations for the ops. diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index a74bc23a95a..8cfd454d629 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1,19 +1,10 @@ //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. // diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 6d5bcc116ad..1aa7d5968d2 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -1,19 +1,10 @@ //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// // // SPIRVSerializationGen generates common utility functions for SPIR-V // serialization. diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp index d8844957ece..576085e41eb 100644 --- a/mlir/tools/mlir-tblgen/StructsGen.cpp +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -1,19 +1,10 @@ //===- StructsGen.cpp - MLIR struct utility generator ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // StructsGen generates common utility functions for grouping attributes into a // set of structured data. diff --git a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp index 993a05d7095..3c9778b3ec7 100644 --- a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp +++ b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp @@ -1,19 +1,10 @@ //===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains the main function for MLIR's TableGen. // diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp index b5622e3ecf8..3b15c5f3875 100644 --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -1,19 +1,10 @@ //===- mlir-translate.cpp - MLIR Translate Driver -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This is a command line utility that translates a file from/to MLIR using one // of the registered translations. diff --git a/mlir/unittests/ADT/TypeSwitchTest.cpp b/mlir/unittests/ADT/TypeSwitchTest.cpp index b6a78de892e..549fb9b221e 100644 --- a/mlir/unittests/ADT/TypeSwitchTest.cpp +++ b/mlir/unittests/ADT/TypeSwitchTest.cpp @@ -1,19 +1,10 @@ //===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/ADT/TypeSwitch.h" #include "gtest/gtest.h" diff --git a/mlir/unittests/Dialect/BroadcastShapeTest.cpp b/mlir/unittests/Dialect/BroadcastShapeTest.cpp index c475fa79476..594e98741e1 100644 --- a/mlir/unittests/Dialect/BroadcastShapeTest.cpp +++ b/mlir/unittests/Dialect/BroadcastShapeTest.cpp @@ -1,19 +1,10 @@ //===- BroadcastShapeTest.cpp - broadcasting shape unit tests -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Traits.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp index d10623e3d1d..4f6ad302c7c 100644 --- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp @@ -1,19 +1,10 @@ //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/QuantOps/QuantizeUtils.h" #include "mlir/Dialect/QuantOps/UniformSupport.h" diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp index ed5bd9ebecc..72fee15ac90 100644 --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -1,19 +1,10 @@ //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // The purpose of this file is to provide negative deserialization tests. 
// For positive deserialization tests, please use serialization and diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index 2728ab820b6..61f5fcbfb7b 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -1,19 +1,10 @@ //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// // // This file contains corner case tests for the SPIR-V serializer that are not // covered by normal serialization and deserialization roundtripping. diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 3db87b29c2d..5a1750e1123 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -1,19 +1,10 @@ //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp index e48d2d7d710..1438d322652 100644 --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -1,19 +1,10 @@ //===- DialectTest.cpp - Dialect unit tests -------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Dialect.h" #include "gtest/gtest.h" diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index d7dae4648fe..004a940ca6c 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -1,19 +1,10 @@ //===- OperationSupportTest.cpp - Operation support unit tests ------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Builders.h" diff --git a/mlir/unittests/IR/StringExtrasTest.cpp b/mlir/unittests/IR/StringExtrasTest.cpp index def65950365..3773006faee 100644 --- a/mlir/unittests/IR/StringExtrasTest.cpp +++ b/mlir/unittests/IR/StringExtrasTest.cpp @@ -1,19 +1,10 @@ //===- StringExtrasTest.cpp - Tests for utility methods in StringExtras.h -===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Support/StringExtras.h" #include "gtest/gtest.h" diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index 790ad9c2589..0ea2c3f66e5 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -1,19 +1,10 @@ //===- AnalysisManagerTest.cpp - AnalysisManager unit tests ---------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Pass/AnalysisManager.h" #include "mlir/IR/Builders.h" diff --git a/mlir/unittests/Quantizer/Support/RulesTest.cpp b/mlir/unittests/Quantizer/Support/RulesTest.cpp index 7ddfb715751..e2593848cbf 100644 --- a/mlir/unittests/Quantizer/Support/RulesTest.cpp +++ b/mlir/unittests/Quantizer/Support/RulesTest.cpp @@ -1,19 +1,10 @@ //===- RulesTest.cpp - Rules unit tests -----------------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/Rules.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/unittests/Quantizer/Support/UniformSolversTest.cpp b/mlir/unittests/Quantizer/Support/UniformSolversTest.cpp index 4a53f923d3d..4e27cdc1d66 100644 --- a/mlir/unittests/Quantizer/Support/UniformSolversTest.cpp +++ b/mlir/unittests/Quantizer/Support/UniformSolversTest.cpp @@ -1,19 +1,10 @@ //===- UniformSolversTest.cpp - Tests for uniform solvers -----------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Quantizer/Support/UniformSolvers.h" #include "gtest/gtest.h" diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index 99756dcaa96..aa55ce58477 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -1,19 +1,10 @@ //===- SDBMTest.cpp - SDBM expression unit tests --------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/Dialect/SDBM/SDBM.h" #include "mlir/Dialect/SDBM/SDBMDialect.h" diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index 3934e8b0ed2..e4fe68482ef 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -1,19 +1,10 @@ //===- EnumsGenTest.cpp - TableGen EnumsGen Tests -------------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= +//===----------------------------------------------------------------------===// #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" diff --git a/mlir/unittests/TableGen/FormatTest.cpp b/mlir/unittests/TableGen/FormatTest.cpp index 7338a8f7554..0566c8a5a7b 100644 --- a/mlir/unittests/TableGen/FormatTest.cpp +++ b/mlir/unittests/TableGen/FormatTest.cpp @@ -1,19 +1,10 @@ //===- FormatTest.cpp - TableGen Format Utility Tests ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/TableGen/Format.h" #include "gmock/gmock.h" diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp index b446ca9558a..45455d7e2aa 100644 --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -1,19 +1,10 @@ //===- StructsGenTest.cpp - TableGen StructsGen Tests ---------------------===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index 806d4a9d7b9..7d44856f9dc 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -1,19 +1,10 @@ //===-- enums.td - EnumsGen test definition file -----------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// include "mlir/IR/OpBase.td" diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td index efa0a6024c5..88551751182 100644 --- a/mlir/unittests/TableGen/structs.td +++ b/mlir/unittests/TableGen/structs.td @@ -1,19 +1,10 @@ //===-- structs.td - StructsGen test definition file -------*- tablegen -*-===// // -// Copyright 2019 The MLIR Authors. +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= +//===----------------------------------------------------------------------===// include "mlir/IR/OpBase.td" diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py index 3bb4ffe4a4f..6dc40c797e2 100755 --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -17,19 +17,9 @@ adding checks to a test case fast, it is *not* designed to be authoritative about what constitutes a good test! """ -# Copyright 2019 The MLIR Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse import os # Used to advertise this file's name ("autogenerated_note"). 
diff --git a/mlir/utils/spirv/define_enum.sh b/mlir/utils/spirv/define_enum.sh index 9da898f7d4c..87b88c93133 100755 --- a/mlir/utils/spirv/define_enum.sh +++ b/mlir/utils/spirv/define_enum.sh @@ -1,18 +1,8 @@ #!/bin/bash -# Copyright 2019 The MLIR Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for defining a new enum attr using SPIR-V spec from the Internet. # diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh index f11078a8e76..322c67e8da8 100755 --- a/mlir/utils/spirv/define_inst.sh +++ b/mlir/utils/spirv/define_inst.sh @@ -1,17 +1,7 @@ #!/bin/bash -# Copyright 2019 The MLIR Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for defining a new op using SPIR-V spec from the Internet. # diff --git a/mlir/utils/spirv/define_opcodes.sh b/mlir/utils/spirv/define_opcodes.sh index 05c36571115..7b9aeab9c08 100755 --- a/mlir/utils/spirv/define_opcodes.sh +++ b/mlir/utils/spirv/define_opcodes.sh @@ -1,18 +1,8 @@ #!/bin/bash -# Copyright 2019 The MLIR Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for defining map for opname to opcode using SPIR-V spec from the # Internet diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index be7116c211f..2433cf4e6da 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -1,19 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2019 The MLIR Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for updating SPIR-V dialect by scraping information from SPIR-V # HTML and JSON specs from the Internet. -- cgit v1.2.3 From e62a69561fb9d7b1013d2853da68d79a7907fead Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 23 Dec 2019 14:45:01 -0800 Subject: NFC: Replace ValuePtr with Value and remove it now that Value is value-typed. ValuePtr was a temporary typedef during the transition to a value-typed Value. PiperOrigin-RevId: 286945714 --- mlir/bindings/python/pybind.cpp | 2 +- mlir/examples/toy/Ch2/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch2/mlir/Dialect.cpp | 9 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 35 ++- mlir/examples/toy/Ch3/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 9 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 35 ++- mlir/examples/toy/Ch3/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch4/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 35 ++- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch5/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch5/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 35 ++- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch6/include/toy/Ops.td | 8 +- mlir/examples/toy/Ch6/mlir/Dialect.cpp | 13 +- mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 27 +- mlir/examples/toy/Ch6/mlir/MLIRGen.cpp | 35 ++- mlir/examples/toy/Ch6/mlir/ToyCombine.cpp | 2 +- mlir/examples/toy/Ch7/include/toy/Ops.td | 10 +- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 15 +- mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp | 36 +-- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 27 +- mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 34 +-- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | 2 +- mlir/g3doc/DeclarativeRewrites.md | 6 +- mlir/g3doc/DialectConversion.md | 6 +- mlir/g3doc/GenericDAGRewriter.md | 2 +- mlir/g3doc/OpDefinitions.md | 19 +- mlir/g3doc/Tutorials/Toy/Ch-3.md | 2 +- mlir/g3doc/Tutorials/Toy/Ch-4.md | 4 +- mlir/g3doc/Tutorials/Toy/Ch-5.md | 10 +- mlir/g3doc/UsageOfConst.md | 8 +- mlir/include/mlir/Analysis/AffineAnalysis.h | 9 +- mlir/include/mlir/Analysis/AffineStructures.h | 72 ++--- mlir/include/mlir/Analysis/CallInterfaces.h 
| 4 +- mlir/include/mlir/Analysis/Dominance.h | 4 +- mlir/include/mlir/Analysis/Liveness.h | 17 +- mlir/include/mlir/Analysis/LoopAnalysis.h | 9 +- mlir/include/mlir/Analysis/Utils.h | 10 +- .../Conversion/AffineToStandard/AffineToStandard.h | 12 +- .../mlir/Conversion/LoopsToGPU/LoopsToGPU.h | 7 +- .../StandardToLLVM/ConvertStandardToLLVM.h | 57 ++-- mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 99 ++++--- mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 8 +- mlir/include/mlir/Dialect/GPU/GPUDialect.h | 6 +- mlir/include/mlir/Dialect/GPU/GPUOps.td | 16 +- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 6 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 22 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.h | 16 +- mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 30 +-- .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 10 +- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 14 +- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 10 +- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 8 +- .../Linalg/Transforms/LinalgTransformPatterns.td | 2 +- .../Dialect/Linalg/Transforms/LinalgTransforms.h | 4 +- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 33 ++- mlir/include/mlir/Dialect/LoopOps/LoopOps.h | 2 +- mlir/include/mlir/Dialect/LoopOps/LoopOps.td | 12 +- .../mlir/Dialect/SPIRV/SPIRVCompositeOps.td | 2 +- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 2 +- mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td | 4 +- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h | 4 +- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 6 +- mlir/include/mlir/Dialect/StandardOps/Ops.h | 34 ++- mlir/include/mlir/Dialect/StandardOps/Ops.td | 78 +++--- mlir/include/mlir/Dialect/VectorOps/Utils.h | 5 +- mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 20 +- .../mlir/Dialect/VectorOps/VectorTransforms.h | 5 +- mlir/include/mlir/EDSC/Builders.h | 10 +- mlir/include/mlir/EDSC/Helpers.h | 8 +- mlir/include/mlir/EDSC/Intrinsics.h | 10 +- mlir/include/mlir/IR/Block.h | 8 +- mlir/include/mlir/IR/Builders.h | 10 +- mlir/include/mlir/IR/FunctionSupport.h | 2 +- mlir/include/mlir/IR/Matchers.h | 14 +- mlir/include/mlir/IR/OpDefinition.h | 38 ++- mlir/include/mlir/IR/OpImplementation.h | 17 +- mlir/include/mlir/IR/Operation.h | 31 ++- mlir/include/mlir/IR/OperationSupport.h | 2 +- mlir/include/mlir/IR/TypeUtilities.h | 10 +- mlir/include/mlir/IR/Value.h | 6 - .../Quantizer/Support/ConstraintAnalysisGraph.h | 8 +- .../include/mlir/Target/LLVMIR/ModuleTranslation.h | 2 +- mlir/include/mlir/Transforms/DialectConversion.h | 46 ++-- mlir/include/mlir/Transforms/FoldUtils.h | 10 +- mlir/include/mlir/Transforms/InliningUtils.h | 14 +- mlir/include/mlir/Transforms/LoopLikeInterface.td | 2 +- mlir/include/mlir/Transforms/LoopUtils.h | 14 +- mlir/include/mlir/Transforms/RegionUtils.h | 9 +- mlir/include/mlir/Transforms/Utils.h | 20 +- mlir/lib/Analysis/AffineAnalysis.cpp | 48 ++-- mlir/lib/Analysis/AffineStructures.cpp | 85 +++--- mlir/lib/Analysis/CallGraph.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 2 +- mlir/lib/Analysis/Liveness.cpp | 34 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 25 +- mlir/lib/Analysis/Utils.cpp | 32 +-- mlir/lib/Analysis/VectorAnalysis.cpp | 4 +- .../AffineToStandard/AffineToStandard.cpp | 138 +++++----- .../GPUCommon/IndexIntrinsicsOpLowering.h | 4 +- .../Conversion/GPUCommon/OpToFuncCallLowering.h | 6 +- .../GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 49 ++-- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 215 ++++++++------- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 30 +-- 
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 68 +++-- .../LoopToStandard/ConvertLoopToStandard.cpp | 11 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 112 ++++---- mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 300 ++++++++++----------- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 39 ++- .../StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 4 +- .../StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 8 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 111 ++++---- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 124 +++++---- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 64 ++--- .../FxpMathOps/Transforms/UniformKernelUtils.h | 6 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 29 +- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 12 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 30 +-- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 8 +- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 19 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 29 +- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 46 ++-- .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 6 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 25 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 55 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 26 +- mlir/lib/Dialect/LoopOps/LoopOps.cpp | 10 +- mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 2 +- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 8 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 35 ++- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 28 +- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 14 +- .../SPIRV/Transforms/LowerABIAttributesPass.cpp | 6 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 60 ++--- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 25 +- mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 71 ++--- mlir/lib/EDSC/Builders.cpp | 23 +- mlir/lib/EDSC/Helpers.cpp | 6 +- mlir/lib/EDSC/Intrinsics.cpp | 12 +- mlir/lib/IR/AsmPrinter.cpp | 39 ++- mlir/lib/IR/Block.cpp | 2 +- mlir/lib/IR/Builders.cpp | 4 +- mlir/lib/IR/Operation.cpp | 20 +- mlir/lib/IR/Region.cpp | 2 +- mlir/lib/IR/TypeUtilities.cpp | 10 +- mlir/lib/Parser/Parser.cpp | 59 ++-- mlir/lib/Pass/IRPrinting.cpp | 4 +- .../Quantizer/Support/ConstraintAnalysisGraph.cpp | 2 +- .../Transforms/InferQuantizedTypesPass.cpp | 14 +- mlir/lib/TableGen/Pattern.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 38 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 8 +- mlir/lib/Transforms/AffineDataCopyGeneration.cpp | 2 +- .../Transforms/AffineLoopInvariantCodeMotion.cpp | 19 +- mlir/lib/Transforms/DialectConversion.cpp | 50 ++-- mlir/lib/Transforms/LoopFusion.cpp | 65 +++-- mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 4 +- mlir/lib/Transforms/LoopTiling.cpp | 11 +- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 +- mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 4 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 12 +- mlir/lib/Transforms/Utils/FoldUtils.cpp | 6 +- .../Utils/GreedyPatternRewriteDriver.cpp | 4 +- mlir/lib/Transforms/Utils/InliningUtils.cpp | 34 +-- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 10 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 151 +++++------ mlir/lib/Transforms/Utils/RegionUtils.cpp | 22 +- mlir/lib/Transforms/Utils/Utils.cpp | 45 ++-- mlir/lib/Transforms/Vectorize.cpp | 25 +- mlir/test/EDSC/builder-api-test.cpp | 2 +- mlir/test/lib/TestDialect/TestDialect.cpp | 8 +- mlir/test/lib/TestDialect/TestOps.td | 2 +- mlir/test/lib/TestDialect/TestPatterns.cpp | 33 ++- 
mlir/test/lib/Transforms/TestLoopMapping.cpp | 2 +- .../test/lib/Transforms/TestVectorizationUtils.cpp | 2 +- mlir/test/mlir-tblgen/op-attribute.td | 6 +- mlir/test/mlir-tblgen/op-decl.td | 24 +- mlir/test/mlir-tblgen/op-operand.td | 10 +- mlir/test/mlir-tblgen/op-result.td | 6 +- mlir/test/mlir-tblgen/predicate.td | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 20 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 8 +- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 2 +- mlir/unittests/IR/OperationSupportTest.cpp | 8 +- 192 files changed, 2160 insertions(+), 2271 deletions(-) (limited to 'mlir/lib/Transforms/LoopFusion.cpp') diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index caff9af59ac..7a3864704ba 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -95,7 +95,7 @@ struct PythonValueHandle { assert(value.hasType() && value.getType().isa() && "can only call function-typed values"); - std::vector argValues; + std::vector argValues; argValues.reserve(args.size()); for (auto arg : args) argValues.push_back(arg.value.getValue()); diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index 20c4a7463d9..aa7e94fcae7 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -89,7 +89,7 @@ def AddOp : Toy_Op<"add"> { // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -120,7 +120,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -136,7 +136,7 @@ def MulOp : Toy_Op<"mul"> { // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -210,7 +210,7 @@ def TransposeOp : Toy_Op<"transpose"> { // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index b33cb5cbfe9..6b4d669d18e 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -85,7 +85,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -94,8 +94,7 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -106,7 +105,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -153,7 +152,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index e9987ff2c77..d9c960c79f4 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -90,7 +90,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -100,7 +100,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -123,8 +123,7 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope( - symbolTable); + ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -175,7 +174,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -187,10 +186,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -211,7 +210,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName())) return variable; @@ -225,7 +224,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. 
- mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -233,7 +232,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -255,7 +254,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -301,12 +300,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -343,12 +342,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -372,7 +371,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -380,7 +379,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; @@ -400,7 +399,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index a6c93ccba10..80717119b2f 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -89,7 +89,7 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> { // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -120,7 +120,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // Add custom build methods for the generic call operation. 
let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -136,7 +136,7 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> { // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -216,7 +216,7 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index b33cb5cbfe9..6b4d669d18e 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -85,7 +85,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -94,8 +94,7 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -106,7 +105,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -153,7 +152,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index e9987ff2c77..d9c960c79f4 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -90,7 +90,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -100,7 +100,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -123,8 +123,7 @@ private: /// Emit a new function and add it to the MLIR module. 
mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope( - symbolTable); + ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -175,7 +174,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -187,10 +186,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -211,7 +210,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName())) return variable; @@ -225,7 +224,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -233,7 +232,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -255,7 +254,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -301,12 +300,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -343,12 +342,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -372,7 +371,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. 
- mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -380,7 +379,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; @@ -400,7 +399,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp index d52a2c173c1..e3205402179 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -39,7 +39,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 71167664bbc..dfb11cf23b9 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -91,7 +91,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -142,7 +142,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -159,7 +159,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -236,7 +236,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 50116b14bea..0a9ded0c3d3 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -46,7 +46,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. 
auto returnOp = cast(op); @@ -61,7 +61,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -135,7 +135,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -155,8 +155,7 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -177,7 +176,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -228,7 +227,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index e9987ff2c77..d9c960c79f4 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -90,7 +90,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -100,7 +100,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -123,8 +123,7 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope( - symbolTable); + ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. 
mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -175,7 +174,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -187,10 +186,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -211,7 +210,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName())) return variable; @@ -225,7 +224,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -233,7 +232,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -255,7 +254,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -301,12 +300,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -343,12 +342,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -372,7 +371,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. 
- mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -380,7 +379,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; @@ -400,7 +399,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 2cbf8bdac9b..82c247c1be2 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -44,7 +44,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index bb98ae19a09..410c5df2461 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -91,7 +91,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -142,7 +142,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -159,7 +159,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -237,7 +237,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 50116b14bea..0a9ded0c3d3 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -46,7 +46,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. 
auto returnOp = cast(op); @@ -61,7 +61,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -135,7 +135,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -155,8 +155,7 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -177,7 +176,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -228,7 +227,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index cba838a2928..2d6e76de069 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -34,8 +34,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -54,11 +54,11 @@ static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -69,7 +69,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. 
- SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -85,7 +85,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. - ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); + Value valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -104,13 +104,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -154,7 +154,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -222,22 +222,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - ValuePtr input = transposeAdaptor.input(); + Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. 
- SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index e9987ff2c77..d9c960c79f4 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -90,7 +90,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -100,7 +100,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -123,8 +123,7 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope( - symbolTable); + ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -175,7 +174,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -187,10 +186,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -211,7 +210,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName())) return variable; @@ -225,7 +224,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -233,7 +232,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? 
makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -255,7 +254,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -301,12 +300,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -343,12 +342,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -372,7 +371,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -380,7 +379,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; @@ -400,7 +399,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 2cbf8bdac9b..82c247c1be2 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -44,7 +44,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index bb98ae19a09..410c5df2461 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -91,7 +91,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. 
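For orientation only: the `OpBuilder` entry declared just below corresponds to a C++ `build` method whose operand parameters are now plain `Value`. The body here mirrors the `AddOp::build` definition elsewhere in this same patch; the trailing comment shows the unchanged call-site shape from MLIRGen.

```c++
// Sketch matching the builder declared below; mirrors AddOp::build from
// Dialect.cpp in this same patch.
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
  state.addOperands({lhs, rhs});
}

// Call sites are unchanged apart from the operand type; e.g. in MLIRGen's
// binary-expression codegen:
//   return builder.create<AddOp>(location, lhs, rhs);
```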
let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -142,7 +142,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -159,7 +159,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -237,7 +237,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 50116b14bea..0a9ded0c3d3 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -46,7 +46,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -61,7 +61,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -135,7 +135,7 @@ static mlir::LogicalResult verify(ConstantOp op) { // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -155,8 +155,7 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -177,7 +176,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -228,7 +227,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index cba838a2928..2d6e76de069 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -34,8 +34,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -54,11 +54,11 @@ static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -69,7 +69,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -85,7 +85,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. 
- ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); + Value valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -104,13 +104,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -154,7 +154,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -222,22 +222,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - ValuePtr input = transposeAdaptor.input(); + Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. 
- SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 377bc11dd27..2f1a6ae8bbe 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -42,7 +42,7 @@ public: : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); @@ -55,14 +55,14 @@ public: // Get a symbol reference to the printf function, inserting it if necessary. auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); - ValuePtr formatSpecifierCst = getOrCreateGlobalString( + Value formatSpecifierCst = getOrCreateGlobalString( loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, llvmDialect); - ValuePtr newLineCst = getOrCreateGlobalString( + Value newLineCst = getOrCreateGlobalString( loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); // Create a loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { auto lowerBound = rewriter.create(loc, 0); auto upperBound = rewriter.create(loc, memRefShape[i]); @@ -86,9 +86,8 @@ public: // Generate a call to printf for the current element of the loop. auto printOp = cast(op); auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); - rewriter.create( - loc, printfRef, rewriter.getIntegerType(32), - ArrayRef({formatSpecifierCst, elementLoad})); + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -121,10 +120,10 @@ private: /// Return a value representing an access into a global string with the given /// name, creating the string if necessary. - static ValuePtr getOrCreateGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { @@ -138,13 +137,13 @@ private: } // Get the pointer to the first character in the global string. - ValuePtr globalPtr = builder.create(loc, global); - ValuePtr cst0 = builder.create( + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index e9987ff2c77..d9c960c79f4 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -90,7 +90,7 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. 
When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable symbolTable; + llvm::ScopedHashTable symbolTable; /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { @@ -100,7 +100,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) { + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { if (symbolTable.count(var)) return mlir::failure(); symbolTable.insert(var, value); @@ -123,8 +123,7 @@ private: /// Emit a new function and add it to the MLIR module. mlir::FuncOp mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. - ScopedHashTableScope var_scope( - symbolTable); + ScopedHashTableScope var_scope(symbolTable); // Create an MLIR function for the given prototype. mlir::FuncOp function(mlirGen(*funcAST.getProto())); @@ -175,7 +174,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -187,10 +186,10 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; - mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -211,7 +210,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName())) return variable; @@ -225,7 +224,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -233,7 +232,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -255,7 +254,7 @@ private: /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> /// - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { auto type = getType(lit.getDims()); // The attribute is a vector with a floating point value per element @@ -301,12 +300,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. 
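A condensed, illustrative sketch of how the symbol table is used once it maps names directly to `Value`; the `declare` member and the lookup behaviour appear in this hunk, while `lookupOrNull` is a hypothetical helper added only for illustration.

```c++
// The MLIRGen symbol table now stores Value directly.
llvm::ScopedHashTable<llvm::StringRef, mlir::Value> symbolTable;

// Declare a variable in the current scope; fails if it already exists.
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
  if (symbolTable.count(var))
    return mlir::failure();
  symbolTable.insert(var, value);
  return mlir::success();
}

// Hypothetical helper: ScopedHashTable::lookup returns a default-constructed
// (null) Value for unknown names, so the result still tests like a pointer.
mlir::Value lookupOrNull(llvm::StringRef name) { return symbolTable.lookup(name); }
```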
- SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -343,12 +342,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -372,7 +371,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. - mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -380,7 +379,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; @@ -400,7 +399,7 @@ private: /// Codegen a list of expression, return failure if one of them hit an error. mlir::LogicalResult mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp index 2cbf8bdac9b..82c247c1be2 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -44,7 +44,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 801aef06934..15395c6da4e 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -103,7 +103,7 @@ def AddOp : Toy_Op<"add", // Allow building an AddOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -155,7 +155,7 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "StringRef callee, ArrayRef arguments"> + "StringRef callee, ArrayRef arguments"> ]; } @@ -172,7 +172,7 @@ def MulOp : Toy_Op<"mul", // Allow building a MulOp with from the two input operands. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs"> + OpBuilder<"Builder *b, OperationState &state, Value lhs, Value rhs"> ]; } @@ -251,7 +251,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> { // Allow building a StructAccessOp with just a struct value and an index. 
let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input, size_t index"> + OpBuilder<"Builder *b, OperationState &state, Value input, size_t index"> ]; let verifier = [{ return ::verify(*this); }]; @@ -290,7 +290,7 @@ def TransposeOp : Toy_Op<"transpose", // Allow building a TransposeOp with from the input operand. let builders = [ - OpBuilder<"Builder *b, OperationState &state, ValuePtr input"> + OpBuilder<"Builder *b, OperationState &state, Value input"> ]; // Invoke a static verify method to verify this transpose operation. diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 4f4cbdf2f0f..7e37f61a473 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -47,7 +47,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator(toy.return) by replacing it with a new /// operation as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -62,7 +62,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); @@ -186,7 +186,7 @@ void ConstantOp::inferShapes() { getResult()->setType(value().getType()); } // AddOp void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -206,8 +206,7 @@ void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } // GenericCallOp void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, - ArrayRef arguments) { + StringRef callee, ArrayRef arguments) { // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(arguments); @@ -228,7 +227,7 @@ Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } // MulOp void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr lhs, mlir::ValuePtr rhs) { + mlir::Value lhs, mlir::Value rhs) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands({lhs, rhs}); } @@ -279,7 +278,7 @@ static mlir::LogicalResult verify(ReturnOp op) { // StructAccessOp void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state, - mlir::ValuePtr input, size_t index) { + mlir::Value input, size_t index) { // Extract the result type from the input type. 
StructType structTy = input->getType().cast(); assert(index < structTy.getNumElementTypes()); @@ -306,7 +305,7 @@ static mlir::LogicalResult verify(StructAccessOp op) { // TransposeOp void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::ValuePtr value) { + mlir::Value value) { state.addTypes(UnrankedTensorType::get(builder->getF64Type())); state.addOperands(value); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index cba838a2928..2d6e76de069 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -34,8 +34,8 @@ static MemRefType convertTensorToMemRef(TensorType type) { } /// Insert an allocation and deallocation for the given MemRefType. -static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. @@ -54,11 +54,11 @@ static ValuePtr insertAllocAndDealloc(MemRefType type, Location loc, /// to the operands of the input operation, and the set of loop induction /// variables for the iteration. It returns a value to store at the current /// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -69,7 +69,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (auto dim : tensorType.getShape()) { auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); loop.getBody()->clear(); @@ -85,7 +85,7 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, // Generate a call to the processing function with the rewriter, the memref // operands, and the loop induction variables. This function will return the // value to store at the current index. - ValuePtr valueToStore = processIteration(rewriter, operands, loopIvs); + Value valueToStore = processIteration(rewriter, operands, loopIvs); rewriter.create(loc, valueToStore, alloc, llvm::makeArrayRef(loopIvs)); @@ -104,13 +104,13 @@ struct BinaryOpLowering : public ConversionPattern { : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -154,7 +154,7 @@ struct ConstantOpLowering : public OpRewritePattern { // Create these constants up-front to avoid large amounts of redundant // operations. 
auto valueShape = memRefType.getShape(); - SmallVector constantIndices; + SmallVector constantIndices; for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) constantIndices.push_back(rewriter.create(loc, i)); @@ -163,7 +163,7 @@ struct ConstantOpLowering : public OpRewritePattern { // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. - SmallVector indices; + SmallVector indices; auto valueIt = constantValue.getValues().begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point @@ -222,22 +222,22 @@ struct TransposeOpLowering : public ConversionPattern { : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - ValuePtr input = transposeAdaptor.input(); + Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 377bc11dd27..2f1a6ae8bbe 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -42,7 +42,7 @@ public: : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); @@ -55,14 +55,14 @@ public: // Get a symbol reference to the printf function, inserting it if necessary. auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); - ValuePtr formatSpecifierCst = getOrCreateGlobalString( + Value formatSpecifierCst = getOrCreateGlobalString( loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, llvmDialect); - ValuePtr newLineCst = getOrCreateGlobalString( + Value newLineCst = getOrCreateGlobalString( loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); // Create a loop for each of the dimensions within the shape. - SmallVector loopIvs; + SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { auto lowerBound = rewriter.create(loc, 0); auto upperBound = rewriter.create(loc, memRefShape[i]); @@ -86,9 +86,8 @@ public: // Generate a call to printf for the current element of the loop. 
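For orientation, a rough sketch of the loop scaffolding this pattern emits before the printf call. It is not part of the patch, and the `ConstantIndexOp`/`loop::ForOp` names are assumptions based on the standard and loop dialects of this period.

```c++
// Assumed sketch: one loop per memref dimension, collecting induction
// variables so the innermost body can load the element to print.
SmallVector<Value, 4> loopIvs;
for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) {
  Value lowerBound = rewriter.create<ConstantIndexOp>(loc, 0);
  Value upperBound = rewriter.create<ConstantIndexOp>(loc, memRefShape[i]);
  Value step = rewriter.create<ConstantIndexOp>(loc, 1);
  auto loop = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
  loopIvs.push_back(loop.getInductionVar());
  // Emit the next (inner) loop inside this one.
  rewriter.setInsertionPointToStart(loop.getBody());
}
```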
auto printOp = cast(op); auto elementLoad = rewriter.create(loc, printOp.input(), loopIvs); - rewriter.create( - loc, printfRef, rewriter.getIntegerType(32), - ArrayRef({formatSpecifierCst, elementLoad})); + rewriter.create(loc, printfRef, rewriter.getIntegerType(32), + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -121,10 +120,10 @@ private: /// Return a value representing an access into a global string with the given /// name, creating the string if necessary. - static ValuePtr getOrCreateGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { @@ -138,13 +137,13 @@ private: } // Get the pointer to the first character in the global string. - ValuePtr globalPtr = builder.create(loc, global); - ValuePtr cst0 = builder.create( + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 62e8c553709..3d543f69bdc 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -99,11 +99,11 @@ private: /// Entering a function creates a new scope, and the function arguments are /// added to the mapping. When the processing of a function is terminated, the /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable> + llvm::ScopedHashTable> symbolTable; using SymbolTableScopeT = llvm::ScopedHashTableScope>; + std::pair>; /// A mapping for the functions that have been code generated to MLIR. llvm::StringMap functionMap; @@ -120,7 +120,7 @@ private: /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - mlir::LogicalResult declare(VarDeclExprAST &var, mlir::ValuePtr value) { + mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { if (symbolTable.count(var.getName())) return mlir::failure(); symbolTable.insert(var.getName(), {value, &var}); @@ -292,7 +292,7 @@ private: } /// Emit a binary operation - mlir::ValuePtr mlirGen(BinaryExprAST &binop) { + mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the @@ -304,7 +304,7 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::ValuePtr lhs = mlirGen(*binop.getLHS()); + mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; auto location = loc(binop.loc()); @@ -320,7 +320,7 @@ private: } // Otherwise, this is a normal binary op. 
- mlir::ValuePtr rhs = mlirGen(*binop.getRHS()); + mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; @@ -340,7 +340,7 @@ private: /// This is a reference to a variable in an expression. The variable is /// expected to have been declared and so should have a value in the symbol /// table, otherwise emit an error and return nullptr. - mlir::ValuePtr mlirGen(VariableExprAST &expr) { + mlir::Value mlirGen(VariableExprAST &expr) { if (auto variable = symbolTable.lookup(expr.getName()).first) return variable; @@ -354,7 +354,7 @@ private: auto location = loc(ret.loc()); // 'return' takes an optional expression, handle that case here. - mlir::ValuePtr expr = nullptr; + mlir::Value expr = nullptr; if (ret.getExpr().hasValue()) { if (!(expr = mlirGen(*ret.getExpr().getValue()))) return mlir::failure(); @@ -362,7 +362,7 @@ private: // Otherwise, this return operation has zero operands. builder.create(location, expr ? makeArrayRef(expr) - : ArrayRef()); + : ArrayRef()); return mlir::success(); } @@ -441,7 +441,7 @@ private: } /// Emit an array literal. - mlir::ValuePtr mlirGen(LiteralExprAST &lit) { + mlir::Value mlirGen(LiteralExprAST &lit) { mlir::Type type = getType(lit.getDims()); mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); @@ -453,7 +453,7 @@ private: /// Emit a struct literal. It will be emitted as an array of /// other literals in an Attribute attached to a `toy.struct_constant` /// operation. - mlir::ValuePtr mlirGen(StructLiteralExprAST &lit) { + mlir::Value mlirGen(StructLiteralExprAST &lit) { mlir::ArrayAttr dataAttr; mlir::Type dataType; std::tie(dataAttr, dataType) = getConstantAttr(lit); @@ -484,12 +484,12 @@ private: /// Emit a call expression. It emits specific operations for the `transpose` /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::ValuePtr mlirGen(CallExprAST &call) { + mlir::Value mlirGen(CallExprAST &call) { llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); // Codegen the operands first. - SmallVector operands; + SmallVector operands; for (auto &expr : call.getArgs()) { auto arg = mlirGen(*expr); if (!arg) @@ -534,12 +534,12 @@ private: } /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::ValuePtr mlirGen(NumberExprAST &num) { + mlir::Value mlirGen(NumberExprAST &num) { return builder.create(loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. - mlir::ValuePtr mlirGen(ExprAST &expr) { + mlir::Value mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: return mlirGen(cast(expr)); @@ -565,7 +565,7 @@ private: /// initializer and record the value in the symbol table before returning it. /// Future expressions will be able to reference this variable through symbol /// table lookup. 
- mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) { + mlir::Value mlirGen(VarDeclExprAST &vardecl) { auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), @@ -573,7 +573,7 @@ private: return nullptr; } - mlir::ValuePtr value = mlirGen(*init); + mlir::Value value = mlirGen(*init); if (!value) return nullptr; diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp index 2fb0a1c5b69..c688a53d86f 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -62,7 +62,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); diff --git a/mlir/g3doc/DeclarativeRewrites.md b/mlir/g3doc/DeclarativeRewrites.md index 9fcd4341611..67ff102fef9 100644 --- a/mlir/g3doc/DeclarativeRewrites.md +++ b/mlir/g3doc/DeclarativeRewrites.md @@ -233,7 +233,7 @@ In the above, we are using `BOp`'s result for building `COp`. Given that `COp` was specified with table-driven op definition, there will be several `build()` methods generated for it. One of them has aggregated parameters for result types, operands, and attributes in the signature: `void -COp::build(..., ArrayRef resultTypes, Array operands, +COp::build(..., ArrayRef resultTypes, Array operands, ArrayRef attr)`. The pattern in the above calls this `build()` method for constructing the `COp`. @@ -266,7 +266,7 @@ For example, for the above `AOp`, a possible builder is: ```c++ void AOp::build(Builder *builder, OperationState &state, - ValuePtr input, Attribute attr) { + Value input, Attribute attr) { state.addOperands({input}); state.addAttribute("a_attr", attr); Type type = ...; // Deduce result type here @@ -422,7 +422,7 @@ op; it can be also used to specify how to build an op entirely. An example: If we have a C++ function for building an op: ```c++ -Operation *createMyOp(OpBuilder builder, ValuePtr input, Attribute attr); +Operation *createMyOp(OpBuilder builder, Value input, Attribute attr); ``` We can wrap it up and invoke it like: diff --git a/mlir/g3doc/DialectConversion.md b/mlir/g3doc/DialectConversion.md index 6771860366c..e6b652f2191 100644 --- a/mlir/g3doc/DialectConversion.md +++ b/mlir/g3doc/DialectConversion.md @@ -209,7 +209,7 @@ class TypeConverter { /// the conversion has finished. virtual Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc); }; ``` @@ -232,7 +232,7 @@ struct MyConversionPattern : public ConversionPattern { /// `operands` parameter, containing the remapped operands of the original /// operation. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const; }; ``` @@ -269,7 +269,7 @@ public: /// Remap an input of the original signature to another `replacement` /// value. This drops the original argument. 
- void remapInput(unsigned origInputNo, ValuePtr replacement); + void remapInput(unsigned origInputNo, Value replacement); }; ``` diff --git a/mlir/g3doc/GenericDAGRewriter.md b/mlir/g3doc/GenericDAGRewriter.md index 64b8f4f7ade..8cc09f7d17f 100644 --- a/mlir/g3doc/GenericDAGRewriter.md +++ b/mlir/g3doc/GenericDAGRewriter.md @@ -128,7 +128,7 @@ complicated :) if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) if (C1->countTrailingZeros() == 0) if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { - ValuePtr NewOr = Builder.CreateOr(Z, ~(*C2)); + Value NewOr = Builder.CreateOr(Z, ~(*C2)); return Builder.CreateSub(RHS, NewOr, "sub"); } ``` diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index 1db18266ee0..ff3a21fa1bb 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -360,7 +360,7 @@ def MyInterface : OpInterface<"MyInterface"> { // A new non-static method accepting an input argument. InterfaceMethod<"/*insert doc here*/", - "ValuePtr ", "bar", (ins "unsigned":$i) + "Value ", "bar", (ins "unsigned":$i) >, // Query a static property of the derived operation. @@ -438,7 +438,7 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, // for attributes are of mlir::Attribute types. static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type i32_result, Type f32_result, ..., - ValuePtr i32_operand, ValuePtr f32_operand, ..., + Value i32_operand, Value f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); // Each result-type/operand/attribute has a separate parameter. The parameters @@ -447,13 +447,13 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, // explanation for more details.) static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type i32_result, Type f32_result, ..., - ValuePtr i32_operand, ValuePtr f32_operand, ..., + Value i32_operand, Value f32_operand, ..., APInt i32_attr, StringRef f32_attr, ...); // Each operand/attribute has a separate parameter but result type is aggregate. static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef resultTypes, - ValuePtr i32_operand, ValuePtr f32_operand, ..., + Value i32_operand, Value f32_operand, ..., IntegerAttr i32_attr, FloatAttr f32_attr, ...); // All operands/attributes have aggregate parameters. @@ -615,10 +615,9 @@ coding style requirements. For each operation, we automatically generate an _operand adaptor_. This class solves the problem of accessing operands provided as a list of `Value`s without using "magic" constants. The operand adaptor takes a reference to an array of -`ValuePtr` and provides methods with the same names as those in the operation -class to access them. For example, for a binary arithmetic operation, it may -provide `.lhs()` to access the first operand and `.rhs()` to access the second -operand. +`Value` and provides methods with the same names as those in the operation class +to access them. For example, for a binary arithmetic operation, it may provide +`.lhs()` to access the first operand and `.rhs()` to access the second operand. The operand adaptor class lives in the same namespace as the operation class, and has the name of the operation followed by `OperandAdaptor`. 
A template @@ -629,11 +628,11 @@ Operand adaptors can be used in function templates that also process operations: ```c++ template -std::pair zip(BinaryOpTy &&op) { +std::pair zip(BinaryOpTy &&op) { return std::make_pair(op.lhs(), op.rhs());; } -void process(AddOp op, ArrayRef newOperands) { +void process(AddOp op, ArrayRef newOperands) { zip(op); zip(OperandAdaptor(newOperands)); /*...*/ diff --git a/mlir/g3doc/Tutorials/Toy/Ch-3.md b/mlir/g3doc/Tutorials/Toy/Ch-3.md index fb470434d6f..615c2c1bbec 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-3.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-3.md @@ -90,7 +90,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. - mlir::ValuePtr transposeInput = op.getOperand(); + mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! diff --git a/mlir/g3doc/Tutorials/Toy/Ch-4.md b/mlir/g3doc/Tutorials/Toy/Ch-4.md index 921e5cdc52a..4a4e11c68e6 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-4.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-4.md @@ -75,7 +75,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// previously returned by the call operation with the operands of the /// return. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "toy.return" needs to be handled here. auto returnOp = cast(op); @@ -207,7 +207,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { return builder.create(conversionLoc, resultType, input); diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md index ed62f8954b7..8a4268b498f 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-5.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-5.md @@ -101,7 +101,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern { /// Match and rewrite the given `toy.transpose` operation, with the given /// operands that have been remapped from `tensor<...>` to `memref<...>`. mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, ArrayRef operands, + matchAndRewrite(mlir::Operation *op, ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -112,18 +112,18 @@ struct TransposeOpLowering : public mlir::ConversionPattern { lowerOpToLoops( op, operands, rewriter, [loc](mlir::PatternRewriter &rewriter, - ArrayRef memRefOperands, - ArrayRef loopIvs) { + ArrayRef memRefOperands, + ArrayRef loopIvs) { // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. This adaptor is automatically provided by the ODS // framework. TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); - mlir::ValuePtr input = transposeAdaptor.input(); + mlir::Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse // indices. 
- SmallVector reverseIvs(llvm::reverse(loopIvs)); + SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); return matchSuccess(); diff --git a/mlir/g3doc/UsageOfConst.md b/mlir/g3doc/UsageOfConst.md index 5f6d3793164..6e8ce78e960 100644 --- a/mlir/g3doc/UsageOfConst.md +++ b/mlir/g3doc/UsageOfConst.md @@ -10,8 +10,8 @@ understood (even though the LLVM implementation is flawed in many ways). The design team since decided to change to a different module, which eschews `const` entirely for the core IR types: you should never see a `const` method on -`Operation`, should never see the type `const ValuePtr`, and you shouldn't feel -bad about this. That said, you *should* use `const` for non-IR types, like +`Operation`, should never see the type `const Value`, and you shouldn't feel bad +about this. That said, you *should* use `const` for non-IR types, like `SmallVector`'s and many other things. The document below explains this design point from the viewpoint of "why make a @@ -130,7 +130,7 @@ const. operand_iterator operand_begin(); operand_iterator operand_end(); - /// Returns an iterator on the underlying Value's (ValuePtr ). + /// Returns an iterator on the underlying Value's (Value ). operand_range getOperands(); // Support const operand iteration. @@ -141,7 +141,7 @@ const. const_operand_iterator operand_begin() const; const_operand_iterator operand_end() const; - /// Returns a const iterator on the underlying Value's (ValuePtr ). + /// Returns a const iterator on the underlying Value's (Value ). llvm::iterator_range getOperands() const; ArrayRef getOpOperands() const { diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 6029a9ccdaa..d0bcb932c04 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -27,13 +27,10 @@ class AffineValueMap; class FlatAffineConstraints; class Operation; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp /// Operations that are reachable via a search starting from `operands` and /// ending at those operands that are not the result of an AffineApplyOp. -void getReachableAffineApplyOps(ArrayRef operands, +void getReachableAffineApplyOps(ArrayRef operands, SmallVectorImpl &affineApplyOps); /// Builds a system of constraints with dimensional identifiers corresponding to @@ -47,9 +44,9 @@ LogicalResult getIndexSet(MutableArrayRef forOps, /// Encapsulates a memref load or store access information. struct MemRefAccess { - ValuePtr memref; + Value memref; Operation *opInst; - SmallVector indices; + SmallVector indices; /// Constructs a MemRefAccess from a load or store operation. // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 770bf686f50..47e0ddab547 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -114,8 +114,8 @@ public: // Creates an empty AffineValueMap (users should call 'reset' to reset map // and operands). 
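An aside, illustrative only: with the `ValuePtr` alias removed from AffineAnalysis.h, code populating the `MemRefAccess` struct above deals in plain `Value`. The `loadOp` accessors in this sketch are assumed from the standard load op and are not introduced by this patch.

```c++
// Hypothetical fill-in of MemRefAccess for a standard LoadOp.
MemRefAccess access;
access.opInst = loadOp.getOperation();
access.memref = loadOp.getMemRef();      // plain Value, no ValuePtr alias
for (Value index : loadOp.getIndices())  // indices are Values as well
  access.indices.push_back(index);
```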
AffineValueMap() {} - AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); + AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); explicit AffineValueMap(AffineApplyOp applyOp); explicit AffineValueMap(AffineBound bound); @@ -123,8 +123,8 @@ public: ~AffineValueMap(); // Resets this AffineValueMap with 'map', 'operands', and 'results'. - void reset(AffineMap map, ArrayRef operands, - ArrayRef results = llvm::None); + void reset(AffineMap map, ArrayRef operands, + ArrayRef results = llvm::None); /// Return the value map that is the difference of value maps 'a' and 'b', /// represented as an affine map and its operands. The output map + operands @@ -137,7 +137,7 @@ public: inline bool isMultipleOf(unsigned idx, int64_t factor) const; /// Return true if the idx^th result depends on 'value', false otherwise. - bool isFunctionOf(unsigned idx, ValuePtr value) const; + bool isFunctionOf(unsigned idx, Value value) const; /// Return true if the result at 'idx' is a constant, false /// otherwise. @@ -153,8 +153,8 @@ public: inline unsigned getNumSymbols() const { return map.getNumSymbols(); } inline unsigned getNumResults() const { return map.getNumResults(); } - ValuePtr getOperand(unsigned i) const; - ArrayRef getOperands() const; + Value getOperand(unsigned i) const; + ArrayRef getOperands() const; AffineMap getAffineMap() const; private: @@ -163,9 +163,9 @@ private: // TODO: make these trailing objects? /// The SSA operands binding to the dim's and symbols of 'map'. - SmallVector operands; + SmallVector operands; /// The SSA results binding to the results of 'map'. - SmallVector results; + SmallVector results; }; /// An IntegerValueSet is an integer set plus its operands. @@ -198,7 +198,7 @@ private: // 'AffineCondition'. MutableIntegerSet set; /// The SSA operands binding to the dim's and symbols of 'set'. - SmallVector operands; + SmallVector operands; }; /// A flat list of affine equalities and inequalities in the form. @@ -236,7 +236,7 @@ public: unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numReservedCols), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -255,7 +255,7 @@ public: /// dimensions and symbols. FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0, - ArrayRef> idArgs = {}) + ArrayRef> idArgs = {}) : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), numSymbols(numSymbols) { assert(numReservedCols >= numDims + numSymbols + 1); @@ -295,10 +295,10 @@ public: // Clears any existing data and reserves memory for the specified constraints. void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0, ArrayRef idArgs = {}); + unsigned numLocals = 0, ArrayRef idArgs = {}); /// Appends constraints from 'other' into this. This is equivalent to an /// intersection with no simplification of any sort attempted. @@ -387,7 +387,7 @@ public: /// operands. If `eq` is true, add a single equality equal to the bound map's /// first result expr. 
LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ArrayRef operands, bool eq, + ArrayRef operands, bool eq, bool lower = true); /// Computes the lower and upper bounds of the first 'num' dimensional @@ -406,10 +406,10 @@ public: /// operand list 'operands'. /// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'. /// Note that both lower/upper bounds use operands from 'operands'. - LogicalResult addSliceBounds(ArrayRef values, + LogicalResult addSliceBounds(ArrayRef values, ArrayRef lbMaps, ArrayRef ubMaps, - ArrayRef operands); + ArrayRef operands); // Adds an inequality (>= 0) from the coefficients specified in inEq. void addInequality(ArrayRef inEq); @@ -438,25 +438,25 @@ public: /// Sets the identifier corresponding to the specified Value id to a /// constant. Asserts if the 'id' is not found. - void setIdToConstant(ValueRef id, int64_t val); + void setIdToConstant(Value id, int64_t val); /// Looks up the position of the identifier with the specified Value. Returns /// true if found (false otherwise). `pos' is set to the (column) position of /// the identifier. - bool findId(ValueRef id, unsigned *pos) const; + bool findId(Value id, unsigned *pos) const; /// Returns true if an identifier with the specified Value exists, false /// otherwise. - bool containsId(ValueRef id) const; + bool containsId(Value id) const; // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. The coefficient column corresponding to the added // identifier is initialized to zero. 'id' is the Value corresponding to the // identifier that can optionally be provided. - void addDimId(unsigned pos, ValuePtr id = nullptr); - void addSymbolId(unsigned pos, ValuePtr id = nullptr); + void addDimId(unsigned pos, Value id = nullptr); + void addSymbolId(unsigned pos, Value id = nullptr); void addLocalId(unsigned pos); - void addId(IdKind kind, unsigned pos, ValuePtr id = nullptr); + void addId(IdKind kind, unsigned pos, Value id = nullptr); /// Add the specified values as a dim or symbol id depending on its nature, if /// it already doesn't exist in the system. `id' has to be either a terminal @@ -464,7 +464,7 @@ public: /// symbols or loop IVs. The identifier is added to the end of the existing /// dims or symbols. Additional information on the identifier is extracted /// from the IR and added to the constraint system. - void addInductionVarOrTerminalSymbol(ValuePtr id); + void addInductionVarOrTerminalSymbol(Value id); /// Composes the affine value map with this FlatAffineConstrains, adding the /// results of the map as dimensions at the front [0, vMap->getNumResults()) @@ -491,8 +491,8 @@ public: void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } - /// Projects out the identifier that is associate with ValuePtr . - void projectOut(ValuePtr id); + /// Projects out the identifier that is associate with Value . + void projectOut(Value id); void removeId(IdKind idKind, unsigned pos); void removeId(unsigned pos); @@ -568,20 +568,20 @@ public: return numIds - numDims - numSymbols; } - inline ArrayRef> getIds() const { + inline ArrayRef> getIds() const { return {ids.data(), ids.size()}; } - inline MutableArrayRef> getIds() { + inline MutableArrayRef> getIds() { return {ids.data(), ids.size()}; } /// Returns the optional Value corresponding to the pos^th identifier. 
- inline Optional getId(unsigned pos) const { return ids[pos]; } - inline Optional &getId(unsigned pos) { return ids[pos]; } + inline Optional getId(unsigned pos) const { return ids[pos]; } + inline Optional &getId(unsigned pos) { return ids[pos]; } /// Returns the Value associated with the pos^th identifier. Asserts if /// no Value identifier was associated. - inline ValuePtr getIdValue(unsigned pos) const { + inline Value getIdValue(unsigned pos) const { assert(ids[pos].hasValue() && "identifier's Value not set"); return ids[pos].getValue(); } @@ -589,7 +589,7 @@ public: /// Returns the Values associated with identifiers in range [start, end). /// Asserts if no Value was associated with one of these identifiers. void getIdValues(unsigned start, unsigned end, - SmallVectorImpl *values) const { + SmallVectorImpl *values) const { assert((start < numIds || start == end) && "invalid start position"); assert(end <= numIds && "invalid end position"); values->clear(); @@ -598,17 +598,17 @@ public: values->push_back(getIdValue(i)); } } - inline void getAllIdValues(SmallVectorImpl *values) const { + inline void getAllIdValues(SmallVectorImpl *values) const { getIdValues(0, numIds, values); } /// Sets Value associated with the pos^th identifier. - inline void setIdValue(unsigned pos, ValuePtr val) { + inline void setIdValue(unsigned pos, Value val) { assert(pos < numIds && "invalid id position"); ids[pos] = val; } /// Sets Values associated with identifiers in the range [start, end). - void setIdValues(unsigned start, unsigned end, ArrayRef values) { + void setIdValues(unsigned start, unsigned end, ArrayRef values) { assert((start < numIds || end == start) && "invalid start position"); assert(end <= numIds && "invalid end position"); assert(values.size() == end - start); @@ -757,7 +757,7 @@ private: /// system appearing in the order the identifiers correspond to columns. /// Temporary ones or those that aren't associated to any Value are set to /// None. - SmallVector, 8> ids; + SmallVector, 8> ids; /// A parameter that controls detection of an unrealistic number of /// constraints. If the number of constraints is this many times the number of diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h index a9806bfb8c6..b5870bac142 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.h +++ b/mlir/include/mlir/Analysis/CallInterfaces.h @@ -21,8 +21,8 @@ namespace mlir { /// A callable is either a symbol, or an SSA value, that is referenced by a /// call-like operation. This represents the destination of the call. -struct CallInterfaceCallable : public PointerUnion { - using PointerUnion::PointerUnion; +struct CallInterfaceCallable : public PointerUnion { + using PointerUnion::PointerUnion; }; #include "mlir/Analysis/CallInterfaces.h.inc" diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 5c42dbe12c2..ead54b93e80 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -65,10 +65,10 @@ public: } /// Return true if value A properly dominates operation B. - bool properlyDominates(ValuePtr a, Operation *b); + bool properlyDominates(Value a, Operation *b); /// Return true if operation A dominates operation B. 
- bool dominates(ValuePtr a, Operation *b) { + bool dominates(Value a, Operation *b) { return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b); } diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h index cbd2e63fd3e..7e1dc2903ae 100644 --- a/mlir/include/mlir/Analysis/Liveness.h +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -32,9 +32,6 @@ class Operation; class Region; class Value; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - /// Represents an analysis for computing liveness information from a /// given top-level operation. The analysis iterates over all associated /// regions that are attached to the given top-level operation. It @@ -51,7 +48,7 @@ class Liveness { public: using OperationListT = std::vector; using BlockMapT = DenseMap; - using ValueSetT = SmallPtrSet; + using ValueSetT = SmallPtrSet; public: /// Creates a new Liveness analysis that computes liveness @@ -66,7 +63,7 @@ public: /// Note that the operations in this list are not ordered and the current /// implementation is computationally expensive (as it iterates over all /// blocks in which the given value is live). - OperationListT resolveLiveness(ValuePtr value) const; + OperationListT resolveLiveness(Value value) const; /// Gets liveness info (if any) for the block. const LivenessBlockInfo *getLiveness(Block *block) const; @@ -79,7 +76,7 @@ public: /// Returns true if the given operation represent the last use of the /// given value. - bool isLastUse(ValuePtr value, Operation *operation) const; + bool isLastUse(Value value, Operation *operation) const; /// Dumps the liveness information in a human readable format. void dump() const; @@ -118,20 +115,20 @@ public: const ValueSetT &out() const { return outValues; } /// Returns true if the given value is in the live-in set. - bool isLiveIn(ValuePtr value) const; + bool isLiveIn(Value value) const; /// Returns true if the given value is in the live-out set. - bool isLiveOut(ValuePtr value) const; + bool isLiveOut(Value value) const; /// Gets the start operation for the given value. This is the first operation /// the given value is considered to be live. This could either be the start /// operation of the current block (in case the value is live-in) or the /// operation that defines the given value (must be referenced in this block). - Operation *getStartOperation(ValuePtr value) const; + Operation *getStartOperation(Value value) const; /// Gets the end operation for the given value using the start operation /// provided (must be referenced in this block). - Operation *getEndOperation(ValuePtr value, Operation *startOperation) const; + Operation *getEndOperation(Value value, Operation *startOperation) const; private: /// The underlying block. diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 75d7b98e20f..0dd89e454a8 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -27,9 +27,6 @@ class NestedPattern; class Operation; class Value; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - /// Returns the trip count of the loop as an affine map with its corresponding /// operands if the latter is expressible as an affine expression, and nullptr /// otherwise. 
This method always succeeds as long as the lower bound is not a @@ -39,7 +36,7 @@ using ValuePtr = Value; // TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a // pure analysis method relying on FlatAffineConstraints void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Returns the trip count of the loop if it's a constant, None otherwise. This /// uses affine expression analysis and is able to determine constant trip count @@ -60,8 +57,8 @@ uint64_t getLargestDivisorOfTripCount(AffineForOp forOp); /// /// Emits a note if it encounters a chain of affine.apply and conservatively /// those cases. -DenseSet> -getInvariantAccesses(ValuePtr iv, ArrayRef indices); +DenseSet> +getInvariantAccesses(Value iv, ArrayRef indices); using VectorizableLoopFun = std::function; diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index d06e003faae..7cf1e5c971a 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -46,7 +46,7 @@ unsigned getNestingDepth(Operation &op); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. void getSequentialLoops(AffineForOp forOp, - llvm::SmallDenseSet *sequentialLoops); + llvm::SmallDenseSet *sequentialLoops); /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their /// associated operands for a set of loops within a loop nest (typically the @@ -55,15 +55,15 @@ void getSequentialLoops(AffineForOp forOp, struct ComputationSliceState { // List of sliced loop IVs (ordered from outermost to innermost). // EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'. - SmallVector ivs; + SmallVector ivs; // List of lower bound AffineMaps. SmallVector lbs; // List of upper bound AffineMaps. SmallVector ubs; // List of lower bound operands (lbOperands[i] are used by 'lbs[i]'). - std::vector> lbOperands; + std::vector> lbOperands; // List of upper bound operands (ubOperands[i] are used by 'ubs[i]'). - std::vector> ubOperands; + std::vector> ubOperands; // Slice loop nest insertion point in target loop nest. Block::iterator insertPoint; // Adds to 'cst' with constraints which represent the slice bounds on 'ivs' @@ -248,7 +248,7 @@ struct MemRefRegion { unsigned getRank() const; /// Memref that this region corresponds to. - ValuePtr memref; + Value memref; /// Read or write. bool write; diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h index c8298760bad..c6a2fac6ec9 100644 --- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -21,17 +21,13 @@ class OpBuilder; class RewritePattern; class Value; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - // Owning list of rewriting patterns. class OwningRewritePatternList; /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. 
-ValuePtr expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, - ArrayRef dimValues, - ArrayRef symbolValues); +Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, + ArrayRef dimValues, ArrayRef symbolValues); /// Collect a set of patterns to convert from the Affine dialect to the Standard /// dialect, in particular convert structured affine control flow into CFG @@ -41,11 +37,11 @@ void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns, /// Emit code that computes the lower bound of the given affine loop using /// standard arithmetic operations. -ValuePtr lowerAffineLowerBound(AffineForOp op, OpBuilder &builder); +Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder); /// Emit code that computes the upper bound of the given affine loop using /// standard arithmetic operations. -ValuePtr lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); +Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); } // namespace mlir #endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h index b7423a58f2a..80faa03f313 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -15,9 +15,6 @@ class AffineForOp; struct LogicalResult; class Value; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - namespace loop { class ForOp; } // end namespace loop @@ -72,8 +69,8 @@ LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp, /// The above conditions are assumed to be satisfied by the computation rooted /// at `forOp`. LogicalResult convertLoopToGPULaunch(loop::ForOp forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes); + ArrayRef numWorkGroups, + ArrayRef workGroupSizes); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 5c8a8e6e494..e78859f992b 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -65,16 +65,16 @@ public: /// Promote the LLVM struct representation of all MemRef descriptors to stack /// and use pointers to struct to avoid the complexity of the /// platform-specific C/C++ ABI lowering related to struct argument passing. - SmallVector promoteMemRefDescriptors(Location loc, - ValueRange opOperands, - ValueRange operands, - OpBuilder &builder); + SmallVector promoteMemRefDescriptors(Location loc, + ValueRange opOperands, + ValueRange operands, + OpBuilder &builder); /// Promote the LLVM struct representation of one MemRef descriptor to stack /// and use pointer to struct to avoid the complexity of the platform-specific /// C/C++ ABI lowering related to struct argument passing. - ValuePtr promoteOneMemRefDescriptor(Location loc, ValuePtr operand, - OpBuilder &builder); + Value promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder); protected: /// LLVM IR module used to parse/create types. @@ -130,24 +130,24 @@ private: class StructBuilder { public: /// Construct a helper for the given value. - explicit StructBuilder(ValuePtr v); + explicit StructBuilder(Value v); /// Builds IR creating an `undef` value of the descriptor type. 
static StructBuilder undef(OpBuilder &builder, Location loc, Type descriptorType); - /*implicit*/ operator ValuePtr() { return value; } + /*implicit*/ operator Value() { return value; } protected: // LLVM value - ValuePtr value; + Value value; // Cached struct type. Type structType; protected: /// Builds IR to extract a value from the struct at position pos - ValuePtr extractPtr(OpBuilder &builder, Location loc, unsigned pos); + Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR to set a value in the struct at position pos - void setPtr(OpBuilder &builder, Location loc, unsigned pos, ValuePtr ptr); + void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); }; /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. @@ -155,7 +155,7 @@ protected: class MemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(ValuePtr descriptor); + explicit MemRefDescriptor(Value descriptor); /// Builds IR creating an `undef` value of the descriptor type. static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); @@ -164,40 +164,39 @@ public: /// type. static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - MemRefType type, ValuePtr memory); + MemRefType type, Value memory); /// Builds IR extracting the allocated pointer from the descriptor. - ValuePtr allocatedPtr(OpBuilder &builder, Location loc); + Value allocatedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the allocated pointer into the descriptor. - void setAllocatedPtr(OpBuilder &builder, Location loc, ValuePtr ptr); + void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); /// Builds IR extracting the aligned pointer from the descriptor. - ValuePtr alignedPtr(OpBuilder &builder, Location loc); + Value alignedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the aligned pointer into the descriptor. - void setAlignedPtr(OpBuilder &builder, Location loc, ValuePtr ptr); + void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); /// Builds IR extracting the offset from the descriptor. - ValuePtr offset(OpBuilder &builder, Location loc); + Value offset(OpBuilder &builder, Location loc); /// Builds IR inserting the offset into the descriptor. - void setOffset(OpBuilder &builder, Location loc, ValuePtr offset); + void setOffset(OpBuilder &builder, Location loc, Value offset); void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); /// Builds IR extracting the pos-th size from the descriptor. - ValuePtr size(OpBuilder &builder, Location loc, unsigned pos); + Value size(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR inserting the pos-th size into the descriptor - void setSize(OpBuilder &builder, Location loc, unsigned pos, ValuePtr size); + void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size); /// Builds IR extracting the pos-th size from the descriptor. 
- ValuePtr stride(OpBuilder &builder, Location loc, unsigned pos); + Value stride(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR inserting the pos-th stride into the descriptor - void setStride(OpBuilder &builder, Location loc, unsigned pos, - ValuePtr stride); + void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride); @@ -212,19 +211,19 @@ private: class UnrankedMemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. - explicit UnrankedMemRefDescriptor(ValuePtr descriptor); + explicit UnrankedMemRefDescriptor(Value descriptor); /// Builds IR creating an `undef` value of the descriptor type. static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); /// Builds IR extracting the rank from the descriptor - ValuePtr rank(OpBuilder &builder, Location loc); + Value rank(OpBuilder &builder, Location loc); /// Builds IR setting the rank in the descriptor - void setRank(OpBuilder &builder, Location loc, ValuePtr value); + void setRank(OpBuilder &builder, Location loc, Value value); /// Builds IR extracting ranked memref descriptor ptr - ValuePtr memRefDescPtr(OpBuilder &builder, Location loc); + Value memRefDescPtr(OpBuilder &builder, Location loc); /// Builds IR setting ranked memref descriptor ptr - void setMemRefDescPtr(OpBuilder &builder, Location loc, ValuePtr value); + void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides /// conversion patterns with an access to the containing LLVMLowering for the diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 09408d2efc8..b884ac5c2ce 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -32,7 +32,7 @@ class OpBuilder; /// A utility function to check if a value is defined at the top level of a /// function. A value of index type defined at the top level is always a valid /// symbol. -bool isTopLevelValue(ValuePtr value); +bool isTopLevelValue(Value value); class AffineOpsDialect : public Dialect { public: @@ -139,19 +139,17 @@ class AffineDmaStartOp : public OpgetType().cast(); } @@ -183,7 +181,7 @@ public: } /// Returns the destination MemRefType for this DMA operations. - ValuePtr getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } + Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } MemRefType getDstMemRefType() { return getDstMemRef()->getType().cast(); } @@ -217,7 +215,7 @@ public: } /// Returns the Tag MemRef for this DMA operation. - ValuePtr getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } + Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } MemRefType getTagMemRefType() { return getTagMemRef()->getType().cast(); } @@ -241,13 +239,13 @@ public: } /// Returns the number of elements being transferred by this DMA operation. - ValuePtr getNumElements() { + Value getNumElements() { return getOperand(getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs()); } /// Returns the AffineMapAttr associated with 'memref'. 
- NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { + NamedAttribute getAffineMapAttrForMemRef(Value memref) { if (memref == getSrcMemRef()) return {Identifier::get(getSrcMapAttrName(), getContext()), getSrcMapAttr()}; @@ -297,14 +295,14 @@ public: } /// Returns the stride value for this DMA operation. - ValuePtr getStride() { + Value getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } /// Returns the number of elements to transfer per stride for this DMA op. - ValuePtr getNumElementsPerStride() { + Value getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); @@ -329,14 +327,13 @@ class AffineDmaWaitOp : public OpgetType().cast(); } @@ -359,16 +356,14 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { + NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getTagMemRef()); return {Identifier::get(getTagMapAttrName(), getContext()), getTagMapAttr()}; } /// Returns the number of elements transferred in the associated DMA op. - ValuePtr getNumElements() { - return getOperand(1 + getTagMap().getNumInputs()); - } + Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } static StringRef getTagMapAttrName() { return "tag_map"; } static ParseResult parse(OpAsmParser &parser, OperationState &result); @@ -403,18 +398,18 @@ public: static void build(Builder *builder, OperationState &result, AffineMap map, ValueRange operands); /// Builds an affine load op with an identity map and operands. - static void build(Builder *builder, OperationState &result, ValuePtr memref, + static void build(Builder *builder, OperationState &result, Value memref, ValueRange indices = {}); /// Builds an affine load op with the specified map and its operands. - static void build(Builder *builder, OperationState &result, ValuePtr memref, + static void build(Builder *builder, OperationState &result, Value memref, AffineMap map, ValueRange mapOperands); /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 0; } /// Get memref operand. - ValuePtr getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(ValuePtr value) { setOperand(getMemRefOperandIndex(), value); } + Value getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -429,7 +424,7 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { + NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getMemRef()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; @@ -470,21 +465,21 @@ public: /// Builds an affine store operation with the provided indices (identity map). static void build(Builder *builder, OperationState &result, - ValuePtr valueToStore, ValuePtr memref, ValueRange indices); + Value valueToStore, Value memref, ValueRange indices); /// Builds an affine store operation with the specified map and its operands. static void build(Builder *builder, OperationState &result, - ValuePtr valueToStore, ValuePtr memref, AffineMap map, + Value valueToStore, Value memref, AffineMap map, ValueRange mapOperands); /// Get value to be stored by store operation. 
- ValuePtr getValueToStore() { return getOperand(0); } + Value getValueToStore() { return getOperand(0); } /// Returns the operand index of the memref. unsigned getMemRefOperandIndex() { return 1; } /// Get memref operand. - ValuePtr getMemRef() { return getOperand(getMemRefOperandIndex()); } - void setMemRef(ValuePtr value) { setOperand(getMemRefOperandIndex(), value); } + Value getMemRef() { return getOperand(getMemRefOperandIndex()); } + void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); @@ -500,7 +495,7 @@ public: } /// Returns the AffineMapAttr associated with 'memref'. - NamedAttribute getAffineMapAttrForMemRef(ValuePtr memref) { + NamedAttribute getAffineMapAttrForMemRef(Value memref) { assert(memref == getMemRef()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; @@ -520,10 +515,10 @@ public: }; /// Returns true if the given Value can be used as a dimension id. -bool isValidDim(ValuePtr value); +bool isValidDim(Value value); /// Returns true if the given Value can be used as a symbol. -bool isValidSymbol(ValuePtr value); +bool isValidSymbol(Value value); /// Modifies both `map` and `operands` in-place so as to: /// 1. drop duplicate operands @@ -532,17 +527,17 @@ bool isValidSymbol(ValuePtr value); /// dimensional operands /// 4. propagate constant operands and drop them void canonicalizeMapAndOperands(AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does /// for affine maps. void canonicalizeSetAndOperands(IntegerSet *set, - SmallVectorImpl *operands); + SmallVectorImpl *operands); /// Returns a composed AffineApplyOp by composing `map` and `operands` with /// other AffineApplyOps supplying those operands. The operands of the resulting /// AffineApplyOp do not change the length of AffineApplyOp chains. AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands); + ArrayRef operands); /// Given an affine map `map` and its input `operands`, this method composes /// into `map`, maps of AffineApplyOps whose results are the values in @@ -552,22 +547,22 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, /// terminal symbol, i.e., a symbol defined at the top level or a block/function /// argument. void fullyComposeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands); #define GET_OP_CLASSES #include "mlir/Dialect/AffineOps/AffineOps.h.inc" /// Returns if the provided value is the induction variable of a AffineForOp. -bool isForInductionVar(ValuePtr val); +bool isForInductionVar(Value val); /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -AffineForOp getForInductionVarOwner(ValuePtr val); +AffineForOp getForInductionVarOwner(Value val); /// Extracts the induction variables from a list of AffineForOps and places them /// in the output argument `ivs`. void extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs); + SmallVectorImpl *ivs); /// AffineBound represents a lower or upper bound in the for operation. /// This class does not own the underlying operands. 
Instead, it refers @@ -582,7 +577,7 @@ public: AffineValueMap getAsAffineValueMap(); unsigned getNumOperands() { return opEnd - opStart; } - ValuePtr getOperand(unsigned idx) { return op.getOperand(opStart + idx); } + Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } using operand_iterator = AffineForOp::operand_iterator; using operand_range = AffineForOp::operand_range; @@ -620,13 +615,13 @@ private: /// %1 = affine.apply () -> (0) /// ``` struct AffineApplyNormalizer { - AffineApplyNormalizer(AffineMap map, ArrayRef operands); + AffineApplyNormalizer(AffineMap map, ArrayRef operands); /// Returns the AffineMap resulting from normalization. AffineMap getAffineMap() { return affineMap; } - SmallVector getOperands() { - SmallVector res(reorderedDims); + SmallVector getOperands() { + SmallVector res(reorderedDims); res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); return res; } @@ -636,13 +631,13 @@ struct AffineApplyNormalizer { /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this /// normalizer's coordinate space. - void normalize(AffineMap *otherMap, SmallVectorImpl *otherOperands); + void normalize(AffineMap *otherMap, SmallVectorImpl *otherOperands); private: /// Helper function to insert `v` into the coordinate system of the current /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding /// renumbered position. - AffineDimExpr renumberOneDim(ValuePtr v); + AffineDimExpr renumberOneDim(Value v); /// Given an `other` normalizer, this rewrites `other.affineMap` in the /// coordinate system of the current AffineApplyNormalizer. @@ -651,12 +646,12 @@ private: AffineMap renumber(const AffineApplyNormalizer &other); /// Maps of Value to position in `affineMap`. - DenseMap dimValueToPosition; + DenseMap dimValueToPosition; /// Ordered dims and symbols matching positional dims and symbols in /// `affineMap`. - SmallVector reorderedDims; - SmallVector concatenatedSymbols; + SmallVector reorderedDims; + SmallVector concatenatedSymbols; AffineMap affineMap; diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index 715e3807a95..114e20513b2 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -92,7 +92,7 @@ def AffineForOp : Affine_Op<"for", static StringRef getUpperBoundAttrName() { return "upper_bound"; } Block *getBody() { return ®ion().front(); } - ValuePtr getInductionVar() { return getBody()->getArgument(0); } + Value getInductionVar() { return getBody()->getArgument(0); } OpBuilder getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); } @@ -277,8 +277,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> { BoolAttr:$isDataCache); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr memref," - "AffineMap map, ArrayRef mapOperands, bool isWrite," + "Builder *builder, OperationState &result, Value memref," + "AffineMap map, ArrayRef mapOperands, bool isWrite," "unsigned localityHint, bool isDataCache", [{ assert(map.getNumInputs() == mapOperands.size() @@ -306,7 +306,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> { } /// Returns the AffineMapAttr associated with 'memref'. 
- NamedAttribute getAffineMapAttrForMemRef(ValuePtr mref) { + NamedAttribute getAffineMapAttrForMemRef(Value mref) { assert(mref == memref()); return {Identifier::get(getMapAttrName(), getContext()), getAffineMapAttr()}; diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h index c3ab6ec5729..1776ff71980 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -68,9 +68,9 @@ public: /// Utility class for the GPU dialect to represent triples of `Value`s /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. struct KernelDim3 { - ValuePtr x; - ValuePtr y; - ValuePtr z; + Value x; + Value y; + Value z; }; #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 037664d0d9b..b5b93e9b553 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -148,7 +148,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> { /// Returns a list of block arguments that correspond to buffers located in /// the workgroup memory - ArrayRef getWorkgroupAttributions() { + ArrayRef getWorkgroupAttributions() { auto begin = std::next(getBody().front().args_begin(), getType().getNumInputs()); auto end = std::next(begin, getNumWorkgroupAttributions()); @@ -157,7 +157,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [FunctionLike, IsolatedFromAbove, Symbol]> { /// Returns a list of block arguments that correspond to buffers located in /// the private memory. - ArrayRef getPrivateAttributions() { + ArrayRef getPrivateAttributions() { auto begin = std::next(getBody().front().args_begin(), getType().getNumInputs() + getNumWorkgroupAttributions()); @@ -273,8 +273,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">, let builders = [ OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " - "ValuePtr gridSizeX, ValuePtr gridSizeY, ValuePtr gridSizeZ, " - "ValuePtr blockSizeX, ValuePtr blockSizeY, ValuePtr blockSizeZ, " + "Value gridSizeX, Value gridSizeY, Value gridSizeZ, " + "Value blockSizeX, Value blockSizeY, Value blockSizeZ, " "ValueRange kernelOperands">, OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, " "KernelDim3 gridSize, KernelDim3 blockSize, " @@ -293,7 +293,7 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">, StringRef getKernelModuleName(); /// The i-th operand passed to the kernel function. - ValuePtr getKernelOperand(unsigned i); + Value getKernelOperand(unsigned i); /// Get the SSA values passed as operands to specify the grid size. KernelDim3 getGridSizeOperandValues(); @@ -406,9 +406,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>, let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"Builder *builder, OperationState &result, ValuePtr gridSizeX," - "ValuePtr gridSizeY, ValuePtr gridSizeZ, ValuePtr blockSizeX," - "ValuePtr blockSizeY, ValuePtr blockSizeZ," + OpBuilder<"Builder *builder, OperationState &result, Value gridSizeX," + "Value gridSizeY, Value gridSizeZ, Value blockSizeX," + "Value blockSizeY, Value blockSizeZ," "ValueRange operands"> ]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index bef1f2dbf20..d36619bb9a9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -185,9 +185,9 @@ private: /// surrounding the insertion point of builder. 
Obtain the address of that /// global and use it to compute the address of the first character in the /// string (operations inserted at the builder insertion point). -ValuePtr createGlobalString(Location loc, OpBuilder &builder, StringRef name, - StringRef value, LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect); +Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, + StringRef value, LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect); /// LLVM requires some operations to be inside of a Module operation. This /// function confirms that the Operation has the desired properties. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 46f63206ef5..2e47eb03474 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -176,8 +176,8 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ICmpPredicate predicate, ValuePtr lhs, " - "ValuePtr rhs", [{ + "Builder *b, OperationState &result, ICmpPredicate predicate, Value lhs, " + "Value rhs", [{ LLVMDialect *dialect = &lhs->getType().cast().getDialect(); build(b, result, LLVMType::getInt1Ty(dialect), b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); @@ -223,8 +223,8 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, FCmpPredicate predicate, ValuePtr lhs, " - "ValuePtr rhs", [{ + "Builder *b, OperationState &result, FCmpPredicate predicate, Value lhs, " + "Value rhs", [{ LLVMDialect *dialect = &lhs->getType().cast().getDialect(); build(b, result, LLVMType::getInt1Ty(dialect), b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); @@ -256,7 +256,7 @@ def LLVM_AllocaOp : $res = alloca; }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, Type resultType, ValuePtr arraySize, " + "Builder *b, OperationState &result, Type resultType, Value arraySize, " "unsigned alignment", [{ if (alignment == 0) @@ -283,7 +283,7 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, LLVM_Builder<"$res = builder.CreateLoad($addr);"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr addr", + "Builder *b, OperationState &result, Value addr", [{ auto type = addr->getType().cast().getPointerElementTy(); build(b, result, type, addr); @@ -344,7 +344,7 @@ def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, $res = builder.CreateExtractElement($vector, $position); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr vector, ValuePtr position," + "Builder *b, OperationState &result, Value vector, Value position," "ArrayRef attrs = {}">]; let parser = [{ return parseExtractElementOp(parser, result); }]; let printer = [{ printExtractElementOp(p, *this); }]; @@ -375,7 +375,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, extractPosition($position)); }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr container, ValuePtr value, " + "Builder *b, OperationState &result, Value container, Value value, " "ArrayAttr position", [{ build(b, result, container->getType(), container, value, position); @@ 
-389,7 +389,7 @@ def LLVM_ShuffleVectorOp LLVM_Builder< "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr v1, ValuePtr v2, " + "Builder *b, OperationState &result, Value v1, Value v2, " "ArrayAttr mask, ArrayRef attrs = {}">]; let verifier = [{ auto wrappedVectorType1 = v1()->getType().cast(); @@ -413,8 +413,8 @@ def LLVM_SelectOp LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr condition, ValuePtr lhs, " - "ValuePtr rhs", [{ + "Builder *b, OperationState &result, Value condition, Value lhs, " + "Value rhs", [{ build(b, result, lhs->getType(), condition, lhs, rhs); }]>]; let parser = [{ return parseSelectOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 1a2d6b9b3ba..dd5034e823c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -28,15 +28,15 @@ class LinalgOp; class Aliases { public: /// Returns true if v1 and v2 alias. - bool alias(ValuePtr v1, ValuePtr v2) { return find(v1) == find(v2); } + bool alias(Value v1, Value v2) { return find(v1) == find(v2); } private: /// Returns the base buffer or block argument into which the view `v` aliases. /// This lazily records the new aliases discovered while walking back the /// use-def chain. - ValuePtr find(ValuePtr v); + Value find(Value v); - DenseMap aliases; + DenseMap aliases; }; /// Data structure for holding a dependence graph that operates on LinalgOp and @@ -45,7 +45,7 @@ class LinalgDependenceGraph { public: struct LinalgOpView { Operation *op; - ValuePtr view; + Value view; }; struct LinalgDependenceGraphElem { // dependentOpView may be either: @@ -55,7 +55,7 @@ public: // View in the op that is used to index in the graph: // 1. src in the case of dependencesFromDstGraphs. // 2. dst in the case of dependencesIntoGraphs. - ValuePtr indexingView; + Value indexingView; }; using LinalgDependences = SmallVector; using DependenceGraph = DenseMap; @@ -88,14 +88,14 @@ public: /// Dependences are restricted to views aliasing `view`. SmallVector findCoveringReads(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - ValuePtr view) const; + Value view) const; /// Returns the operations that are interleaved between `srcLinalgOp` and /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`. /// Dependences are restricted to views aliasing `view`. SmallVector findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - ValuePtr view) const; + Value view) const; private: // Keep dependences in both directions, this is not just a performance gain @@ -121,7 +121,7 @@ private: /// Implementation detail for findCoveringxxx. SmallVector findOperationsWithCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, ValuePtr view, + LinalgOp dstLinalgOp, Value view, ArrayRef types) const; Aliases &aliases; diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index d0f6c942b95..97fbede1cc7 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -46,35 +46,34 @@ inline StringRef toString(IterType t) { /// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... 
); /// ``` struct StructuredIndexed { - StructuredIndexed(ValuePtr v) : value(v) {} + StructuredIndexed(Value v) : value(v) {} StructuredIndexed operator()(ArrayRef indexings) { return StructuredIndexed(value, indexings); } - operator ValuePtr() const /* implicit */ { return value; } + operator Value() const /* implicit */ { return value; } ArrayRef getExprs() { return exprs; } private: - StructuredIndexed(ValuePtr v, ArrayRef indexings) + StructuredIndexed(Value v, ArrayRef indexings) : value(v), exprs(indexings.begin(), indexings.end()) { assert(v->getType().isa() && "MemRefType expected"); } StructuredIndexed(ValueHandle v, ArrayRef indexings) : StructuredIndexed(v.getValue(), indexings) {} - ValuePtr value; + Value value; SmallVector exprs; }; -inline void defaultRegionBuilder(ArrayRef args) {} +inline void defaultRegionBuilder(ArrayRef args) {} -Operation *makeLinalgGenericOp(ArrayRef iteratorTypes, - ArrayRef inputs, - ArrayRef outputs, - function_ref)> - regionBuilder = defaultRegionBuilder, - ArrayRef otherValues = {}, - ArrayRef otherAttributes = {}); +Operation *makeLinalgGenericOp( + ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef outputs, + function_ref)> regionBuilder = + defaultRegionBuilder, + ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); namespace ops { using edsc::StructuredIndexed; @@ -87,7 +86,7 @@ using edsc::intrinsics::linalg_yield; /// Build the body of a region to compute a multiply-accumulate, under the /// current ScopedContext, at the current insert point. -void macRegionBuilder(ArrayRef args); +void macRegionBuilder(ArrayRef args); /// TODO(ntv): In the future we should tie these implementations to something in /// Tablegen that generates the proper interfaces and the proper sugared named @@ -111,7 +110,7 @@ void macRegionBuilder(ArrayRef args); /// with in-place semantics and parallelism. /// Unary pointwise operation (with broadcast) entry point. -using UnaryPointwiseOpBuilder = function_ref; +using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O); @@ -121,8 +120,7 @@ Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); /// Binary pointwise operation (with broadcast) entry point. -using BinaryPointwiseOpBuilder = - function_ref; +using BinaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index e52019d7992..6fdb8a644af 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -83,22 +83,22 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> { "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod<"Query the input view at the given index.", - "ValuePtr ", "getInput", (ins "unsigned":$i) + "Value ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<"Query the output view at the given index.", - "ValuePtr ", "getOutput", (ins "unsigned":$i) + "Value ", "getOutput", (ins "unsigned":$i) >, InterfaceMethod<[{ Query the index of the given input value, or `None` if the value is not an input. 
}], - "Optional", "getIndexOfInput", (ins "ValuePtr ":$view) + "Optional", "getIndexOfInput", (ins "Value ":$view) >, InterfaceMethod<[{ Query the index of the given view value, or `None` if the value is not an view. }], - "Optional", "getIndexOfOutput", (ins "ValuePtr ":$view) + "Optional", "getIndexOfOutput", (ins "Value ":$view) >, InterfaceMethod<[{ Query the type of the input view at the given index. @@ -219,7 +219,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> { // TODO(ntv) this should go away once the usage of OptionalAttr triggers // emission of builders with default arguments left unspecified. let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr input, ValuePtr output", [{ + "Builder *builder, OperationState &result, Value input, Value output", [{ return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 728fa619dbe..0445968ee80 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -47,8 +47,8 @@ def Linalg_RangeOp : ```` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr min, ValuePtr max, " - "ValuePtr step", + "Builder *builder, OperationState &result, Value min, Value max, " + "Value step", [{ auto rangeType = RangeType::get(builder->getContext()); build(builder, result, rangeType, min, max, step); @@ -103,7 +103,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr base, " + "Builder *b, OperationState &result, Value base, " "ValueRange indexings">]; let extraClassDeclaration = [{ @@ -115,11 +115,11 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, MemRefType getBaseViewType() { return view()->getType().cast(); } // Get the underlying indexing at a given rank. - ValuePtr indexing(unsigned rank) { return *(indexings().begin() + rank); } + Value indexing(unsigned rank) { return *(indexings().begin() + rank); } // Get the subset of indexings that are of RangeType. 
- SmallVector getRanges() { - SmallVector res; + SmallVector getRanges() { + SmallVector res; for (auto operand : indexings()) if (!operand->getType().isa()) res.push_back(operand); @@ -145,7 +145,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, }]; let builders = [OpBuilder< - "Builder *b, OperationState &result, ValuePtr view, " + "Builder *b, OperationState &result, Value view, " "AffineMapAttr permutation, ArrayRef attrs = {}">]; let verifier = [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 8674c277e4a..dd9e09b8eae 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -83,22 +83,22 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod<"Query the input view at the given index.", - "ValuePtr ", "getInput", (ins "unsigned":$i) + "Value ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<"Query the output view at the given index.", - "ValuePtr ", "getOutput", (ins "unsigned":$i) + "Value ", "getOutput", (ins "unsigned":$i) >, InterfaceMethod<[{ Query the index of the given input value, or `None` if the value is not an input. }], - "llvm::Optional", "getIndexOfInput", (ins "ValuePtr ":$view) + "llvm::Optional", "getIndexOfInput", (ins "Value ":$view) >, InterfaceMethod<[{ Query the index of the given view value, or `None` if the value is not an view. }], - "llvm::Optional", "getIndexOfOutput", (ins "ValuePtr ":$view) + "llvm::Optional", "getIndexOfOutput", (ins "Value ":$view) >, InterfaceMethod<[{ Query the type of the input view at the given index. @@ -219,7 +219,7 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { // TODO(ntv) this should go away once the usage of OptionalAttr triggers // emission of builders with default arguments left unspecified. let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr input, ValuePtr output", [{ + "Builder *builder, OperationState &result, Value input, Value output", [{ return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index 7399aad6663..e0d651806d3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -68,13 +68,13 @@ private: public: /// Return the `i`-th input view. - ValuePtr getInput(unsigned i) { + Value getInput(unsigned i) { assert(i < nInputs()); return this->getOperation()->getOperand(i); } /// Return the index of `view` in the list of input views if found, llvm::None /// otherwise. - Optional getIndexOfInput(ValuePtr view) { + Optional getIndexOfInput(Value view) { auto it = llvm::find(getInputs(), view); if (it != getInputs().end()) return it - getInputs().begin(); @@ -90,12 +90,12 @@ public: return {range.begin(), range.begin() + nInputs()}; } /// Return the `i`-th output view. - ValuePtr getOutput(unsigned i) { + Value getOutput(unsigned i) { return this->getOperation()->getOperand(nInputs() + i); } /// Return the index of `view` in the list of output views if found, /// llvm::None otherwise. 
- Optional getIndexOfOutput(ValuePtr view) { + Optional getIndexOfOutput(Value view) { auto it = llvm::find(getOutputs(), view); if (it != getOutputs().end()) return it - getOutputs().begin(); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 448ffdf7d4b..8f6762f0048 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -36,7 +36,7 @@ class AffineMapDomainHasDim : CPred<[{ class HasOperandsOfType: CPred<[{ llvm::any_of($0.getOperands(), - [](ValuePtr v) { + [](Value v) { return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); }) }]>; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index a88dc4105e2..757ee3ad1a7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -29,7 +29,7 @@ struct LinalgTransforms { namespace detail { // Implementation detail of isProducedByOpOfType avoids the need for explicit // template instantiations. -bool isProducedByOpOfTypeImpl(Operation *consumerOp, ValuePtr consumedView, +bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView, function_ref isaOpType); } // namespace detail @@ -37,7 +37,7 @@ bool isProducedByOpOfTypeImpl(Operation *consumerOp, ValuePtr consumedView, // an op of type `OpTy`. This is used to implement use-def type information on // buffers. template -bool isProducedByOpOfType(Operation *consumerOp, ValuePtr consumedView) { +bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) { return detail::isProducedByOpOfTypeImpl( consumerOp, consumedView, [](Operation *op) { return isa(op); }); } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 1b45179bc9e..996658b4c5c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -33,7 +33,7 @@ public: /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. LoopRangeBuilder(ValueHandle *iv, ValueHandle range); - LoopRangeBuilder(ValueHandle *iv, ValuePtr range); + LoopRangeBuilder(ValueHandle *iv, Value range); LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range); LoopRangeBuilder(const LoopRangeBuilder &) = delete; @@ -56,7 +56,7 @@ public: LoopNestRangeBuilder(ArrayRef ivs, ArrayRef ranges); LoopNestRangeBuilder(ArrayRef ivs, - ArrayRef ranges); + ArrayRef ranges); LoopNestRangeBuilder(ArrayRef ivs, ArrayRef ranges); edsc::ValueHandle operator()(std::function fun = nullptr); @@ -79,14 +79,14 @@ struct FusionInfo { /// whole `consumedView`. This checks structural dominance, that the dependence /// is a RAW without any interleaved write to any piece of `consumedView`. bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, - LinalgOp consumer, ValuePtr consumedView, + LinalgOp consumer, Value consumedView, LinalgOp producer); /// Checks whether fusing the specific `producer` of the `consumedView` is /// feasible. This checks `producer` is the last write of `consumedView` and /// that no interleaved dependence would be violated (RAW, WAR or WAW). 
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, - ValuePtr consumedView, LinalgOp producer); + Value consumedView, LinalgOp producer); /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. @@ -102,8 +102,8 @@ Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, /// the inverse, concatenated loopToOperandRangeMaps to this list allows the /// derivation of loop ranges for any linalgOp. template -SmallVector getViewSizes(ConcreteOp linalgOp) { - SmallVector res; +SmallVector getViewSizes(ConcreteOp linalgOp) { + SmallVector res; for (auto v : linalgOp.getInputsAndOutputs()) { MemRefType t = v->getType().template cast(); for (unsigned i = 0; i < t.getRank(); ++i) @@ -116,10 +116,9 @@ SmallVector getViewSizes(ConcreteOp linalgOp) { /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. -SmallVector applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef values, - OperationFolder *folder = nullptr); +SmallVector applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, ArrayRef values, + OperationFolder *folder = nullptr); struct TiledLinalgOp { LinalgOp op; @@ -142,7 +141,7 @@ struct TiledLinalgOp { /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, + ArrayRef tileSizes, ArrayRef permutation = {}, OperationFolder *folder = nullptr); @@ -173,9 +172,9 @@ Optional tileLinalgOperation(OpBuilder &b, Operation *op, } struct PromotionInfo { - ValuePtr buffer; - ValuePtr fullLocalView; - ValuePtr partialLocalView; + Value buffer; + Value fullLocalView; + Value partialLocalView; }; /// Promotes the `subViews` into a new buffer allocated at the insertion point @@ -190,13 +189,13 @@ struct PromotionInfo { /// Returns a list of PromotionInfo which hold the promoted buffer and the /// full and partial views indexing into the buffer. SmallVector -promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, +promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, bool dynamicBuffers = false, OperationFolder *folder = nullptr); /// Returns all the operands of `linalgOp` that are not views. /// Asserts that these operands are value types to allow transformations like /// tiling to just use the values when cloning `linalgOp`. -SmallVector getAssumedNonViewOperands(LinalgOp linalgOp); +SmallVector getAssumedNonViewOperands(LinalgOp linalgOp); /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. @@ -217,7 +216,7 @@ void applyPermutationToVector(SmallVector &inVec, /// It is the entry point for declarative transformation /// Returns the cloned `LinalgOp` with the new operands LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, - llvm::SetVector subViews, + llvm::SetVector subViews, bool dynamicBuffers = false, OperationFolder *folder = nullptr); diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h index dba5e819986..2617d7fd783 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -41,7 +41,7 @@ void ensureLoopTerminator(Region ®ion, Builder &builder, Location loc); /// Returns the loop parent of an induction variable. 
If the provided value is /// not an induction variable, then return nullptr. -ForOp getForInductionVarOwner(ValuePtr val); +ForOp getForInductionVarOwner(Value val); } // end namespace loop } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td index 3b0f120441a..707b788aaa8 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -65,18 +65,18 @@ def ForOp : Loop_Op<"for", let skipDefaultBuilders = 1; let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "ValuePtr lowerBound, ValuePtr upperBound, ValuePtr step"> + "Value lowerBound, Value upperBound, Value step"> ]; let extraClassDeclaration = [{ Block *getBody() { return ®ion().front(); } - ValuePtr getInductionVar() { return getBody()->getArgument(0); } + Value getInductionVar() { return getBody()->getArgument(0); } OpBuilder getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); } - void setLowerBound(ValuePtr bound) { getOperation()->setOperand(0, bound); } - void setUpperBound(ValuePtr bound) { getOperation()->setOperand(1, bound); } - void setStep(ValuePtr step) { getOperation()->setOperand(2, step); } + void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); } + void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); } + void setStep(Value step) { getOperation()->setOperand(2, step); } }]; } @@ -107,7 +107,7 @@ def IfOp : Loop_Op<"if", let skipDefaultBuilders = 1; let builders = [ OpBuilder<"Builder *builder, OperationState &result, " - "ValuePtr cond, bool withElseRegion"> + "Value cond, bool withElseRegion"> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td index 7bd88ab66e0..5a8235fff1a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -111,7 +111,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let builders = [ OpBuilder<[{Builder *builder, OperationState &state, - ValuePtr composite, ArrayRef indices}]> + Value composite, ArrayRef indices}]> ]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index bc06c0289db..be095579451 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -123,7 +123,7 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", let builders = [ OpBuilder< - "Builder *builder, OperationState &state, ValuePtr condition, " + "Builder *builder, OperationState &state, Value condition, " "Block *trueBlock, ValueRange trueArguments, " "Block *falseBlock, ValueRange falseArguments, " "Optional> weights = {}", diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index 4057f47931c..ac377d5e866 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -849,8 +849,8 @@ def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - ValuePtr cond, ValuePtr trueValue, - ValuePtr falseValue}]>]; + Value cond, Value trueValue, + Value falseValue}]>]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h 
b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index e7cf250cc3a..0f481f5956d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -55,8 +55,8 @@ protected: namespace spirv { /// Returns a value that represents a builtin variable value within the SPIR-V /// module. -ValuePtr getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, - OpBuilder &builder); +Value getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, + OpBuilder &builder); /// Attribute name for specifying argument ABI information. StringRef getInterfaceVarABIAttrName(); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index f657d5847d0..1ce28928c41 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -93,7 +93,7 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - ValuePtr basePtr, ValueRange indices}]>]; + Value basePtr, ValueRange indices}]>]; let hasCanonicalizer = 1; } @@ -263,7 +263,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> { ); let builders = [OpBuilder<[{Builder *builder, OperationState &state, - ValuePtr basePtr, /*optional*/IntegerAttr memory_access, + Value basePtr, /*optional*/IntegerAttr memory_access, /*optional*/IntegerAttr alignment}]>]; } @@ -358,7 +358,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> { let builders = [ OpBuilder<"Builder *builder, OperationState &state, " - "ValuePtr ptr, ValuePtr value, ArrayRef namedAttrs", [{ + "Value ptr, Value value, ArrayRef namedAttrs", [{ state.addOperands(ptr); state.addOperands(value); state.addAttributes(namedAttrs); diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index e3ec6f1f7d6..0ba16c56f8e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -173,15 +173,14 @@ class DmaStartOp public: using Op::Op; - static void build(Builder *builder, OperationState &result, - ValuePtr srcMemRef, ValueRange srcIndices, - ValuePtr destMemRef, ValueRange destIndices, - ValuePtr numElements, ValuePtr tagMemRef, - ValueRange tagIndices, ValuePtr stride = nullptr, - ValuePtr elementsPerStride = nullptr); + static void build(Builder *builder, OperationState &result, Value srcMemRef, + ValueRange srcIndices, Value destMemRef, + ValueRange destIndices, Value numElements, Value tagMemRef, + ValueRange tagIndices, Value stride = nullptr, + Value elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. - ValuePtr getSrcMemRef() { return getOperand(0); } + Value getSrcMemRef() { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() { return getSrcMemRef()->getType().cast().getRank(); @@ -193,7 +192,7 @@ public: } // Returns the destination MemRefType for this DMA operations. - ValuePtr getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } + Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() { return getDstMemRef()->getType().cast().getRank(); @@ -213,12 +212,12 @@ public: } // Returns the number of elements being transferred by this DMA operation. 
- ValuePtr getNumElements() { + Value getNumElements() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. - ValuePtr getTagMemRef() { + Value getTagMemRef() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. @@ -267,13 +266,13 @@ public: 1 + 1 + getTagMemRefRank(); } - ValuePtr getStride() { + Value getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } - ValuePtr getNumElementsPerStride() { + Value getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); @@ -298,14 +297,13 @@ class DmaWaitOp public: using Op::Op; - static void build(Builder *builder, OperationState &result, - ValuePtr tagMemRef, ValueRange tagIndices, - ValuePtr numElements); + static void build(Builder *builder, OperationState &result, Value tagMemRef, + ValueRange tagIndices, Value numElements); static StringRef getOperationName() { return "std.dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. - ValuePtr getTagMemRef() { return getOperand(0); } + Value getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. operand_range getTagIndices() { @@ -319,7 +317,7 @@ public: } // Returns the number of elements transferred in the associated DMA operation. - ValuePtr getNumElements() { return getOperand(1 + getTagMemRefRank()); } + Value getNumElements() { return getOperand(1 + getTagMemRefRank()); } static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); @@ -334,7 +332,7 @@ void printDimAndSymbolList(Operation::operand_iterator begin, /// Parses dimension and symbol list and returns true if parsing failed. 
ParseResult parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, + SmallVectorImpl &operands, unsigned &numDims); raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range); diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index c31b3dc9395..1c8bb251c02 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -43,7 +43,7 @@ class CastOp traits = []> : let results = (outs AnyType); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source, Type destType", [{ + "Builder *builder, OperationState &result, Value source, Type destType", [{ impl::buildCastOp(builder, result, source, destType); }]>]; @@ -182,7 +182,7 @@ def AllocOp : Std_Op<"alloc"> { }]>, OpBuilder< "Builder *builder, OperationState &result, MemRefType memrefType, " # - "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ + "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{ result.addOperands(operands); result.types.push_back(memrefType); if (alignment) @@ -321,7 +321,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { let results = (outs Variadic); let builders = [OpBuilder< - "Builder *, OperationState &result, ValuePtr callee," + "Builder *, OperationState &result, Value callee," "ValueRange operands = {}", [{ result.operands.push_back(callee); result.addOperands(operands); @@ -329,7 +329,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { }]>]; let extraClassDeclaration = [{ - ValuePtr getCallee() { return getOperand(0); } + Value getCallee() { return getOperand(0); } /// Get the argument operands to the called function. operand_range getArgOperands() { @@ -386,7 +386,7 @@ def CmpFOp : Std_Op<"cmpf", let builders = [OpBuilder< "Builder *builder, OperationState &result, CmpFPredicate predicate," - "ValuePtr lhs, ValuePtr rhs", [{ + "Value lhs, Value rhs", [{ ::buildCmpFOp(builder, result, predicate, lhs, rhs); }]>]; @@ -454,7 +454,7 @@ def CmpIOp : Std_Op<"cmpi", let builders = [OpBuilder< "Builder *builder, OperationState &result, CmpIPredicate predicate," - "ValuePtr lhs, ValuePtr rhs", [{ + "Value lhs, Value rhs", [{ ::buildCmpIOp(builder, result, predicate, lhs, rhs); }]>]; @@ -493,7 +493,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { let arguments = (ins I1:$condition, Variadic:$branchOperands); let builders = [OpBuilder< - "Builder *, OperationState &result, ValuePtr condition," + "Builder *, OperationState &result, Value condition," "Block *trueDest, ValueRange trueOperands," "Block *falseDest, ValueRange falseOperands", [{ result.addOperands(condition); @@ -509,7 +509,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { enum { trueIndex = 0, falseIndex = 1 }; // The condition operand is the first operand in the list. - ValuePtr getCondition() { return getOperand(0); } + Value getCondition() { return getOperand(0); } /// Return the destination if the condition is true. Block *getTrueDest() { @@ -522,12 +522,12 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { } // Accessors for operands to the 'true' destination. 
- ValuePtr getTrueOperand(unsigned idx) { + Value getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } - void setTrueOperand(unsigned idx, ValuePtr value) { + void setTrueOperand(unsigned idx, Value value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); } @@ -552,11 +552,11 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { } // Accessors for operands to the 'false' destination. - ValuePtr getFalseOperand(unsigned idx) { + Value getFalseOperand(unsigned idx) { assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } - void setFalseOperand(unsigned idx, ValuePtr value) { + void setFalseOperand(unsigned idx, Value value) { assert(idx < getNumFalseOperands()); setOperand(getFalseDestOperandIndex() + idx, value); } @@ -669,7 +669,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> { let results = (outs Index); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr memrefOrTensor," + "Builder *builder, OperationState &result, Value memrefOrTensor," "unsigned index", [{ auto indexType = builder->getIndexType(); auto indexAttr = builder->getIntegerAttr(indexType, index); @@ -721,7 +721,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let results = (outs AnyType); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr aggregate," + "Builder *builder, OperationState &result, Value aggregate," "ValueRange indices = {}", [{ auto resType = aggregate->getType().cast() .getElementType(); @@ -729,7 +729,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { }]>]; let extraClassDeclaration = [{ - ValuePtr getAggregate() { return getOperand(0); } + Value getAggregate() { return getOperand(0); } operand_range getIndices() { return {operand_begin() + 1, operand_end()}; @@ -807,7 +807,7 @@ def LoadOp : Std_Op<"load"> { let results = (outs AnyType); let builders = [OpBuilder< - "Builder *, OperationState &result, ValuePtr memref," + "Builder *, OperationState &result, Value memref," "ValueRange indices = {}", [{ auto memrefType = memref->getType().cast(); result.addOperands(memref); @@ -816,8 +816,8 @@ def LoadOp : Std_Op<"load"> { }]>]; let extraClassDeclaration = [{ - ValuePtr getMemRef() { return getOperand(0); } - void setMemRef(ValuePtr value) { setOperand(0, value); } + Value getMemRef() { return getOperand(0); } + void setMemRef(Value value) { setOperand(0, value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -943,8 +943,8 @@ def PrefetchOp : Std_Op<"prefetch"> { BoolAttr:$isDataCache); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr memref," - "ArrayRef indices, bool isWrite, unsigned hint, bool isData", + "Builder *builder, OperationState &result, Value memref," + "ArrayRef indices, bool isWrite, unsigned hint, bool isData", [{ auto hintAttr = builder->getI32IntegerAttr(hint); auto isWriteAttr = builder->getBoolAttr(isWrite); @@ -981,7 +981,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> { let verifier = ?; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr tensor", [{ + "Builder *builder, OperationState &result, Value tensor", [{ auto indexType = builder->getIndexType(); build(builder, result, indexType, tensor); }]>]; @@ -1043,16 +1043,16 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> { let results = (outs AnyType); let builders = [OpBuilder< - 
"Builder *builder, OperationState &result, ValuePtr condition," - "ValuePtr trueValue, ValuePtr falseValue", [{ + "Builder *builder, OperationState &result, Value condition," + "Value trueValue, Value falseValue", [{ result.addOperands({condition, trueValue, falseValue}); result.addTypes(trueValue->getType()); }]>]; let extraClassDeclaration = [{ - ValuePtr getCondition() { return condition(); } - ValuePtr getTrueValue() { return true_value(); } - ValuePtr getFalseValue() { return false_value(); } + Value getCondition() { return condition(); } + Value getTrueValue() { return true_value(); } + Value getFalseValue() { return false_value(); } }]; let hasFolder = 1; @@ -1080,7 +1080,7 @@ def SignExtendIOp : Std_Op<"sexti", let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ + "Builder *builder, OperationState &result, Value value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; @@ -1180,7 +1180,7 @@ def SplatOp : Std_Op<"splat", [NoSideEffect]> { let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); let builders = - [OpBuilder<"Builder *builder, OperationState &result, ValuePtr element, " + [OpBuilder<"Builder *builder, OperationState &result, Value element, " "Type aggregateType", [{ build(builder, result, aggregateType, element); }]>]; @@ -1204,16 +1204,16 @@ def StoreOp : Std_Op<"store"> { Variadic:$indices); let builders = [OpBuilder< - "Builder *, OperationState &result, ValuePtr valueToStore, ValuePtr memref", [{ + "Builder *, OperationState &result, Value valueToStore, Value memref", [{ result.addOperands(valueToStore); result.addOperands(memref); }]>]; let extraClassDeclaration = [{ - ValuePtr getValueToStore() { return getOperand(0); } + Value getValueToStore() { return getOperand(0); } - ValuePtr getMemRef() { return getOperand(1); } - void setMemRef(ValuePtr value) { setOperand(1, value); } + Value getMemRef() { return getOperand(1); } + void setMemRef(Value value) { setOperand(1, value); } MemRefType getMemRefType() { return getMemRef()->getType().cast(); } @@ -1355,13 +1355,13 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let builders = [ OpBuilder< - "Builder *b, OperationState &result, ValuePtr source, " + "Builder *b, OperationState &result, Value source, " "ValueRange offsets, ValueRange sizes, " "ValueRange strides, Type resultType = Type(), " "ArrayRef attrs = {}">, OpBuilder< "Builder *builder, OperationState &result, " - "Type resultType, ValuePtr source"> + "Type resultType, Value source"> ]; let extraClassDeclaration = [{ @@ -1394,7 +1394,7 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { // offset, size and stride operands of the SubViewOp into a list of triples. // Such a list of triple is sometimes more convenient to manipulate. 
struct Range { - ValuePtr offset, size, stride; + Value offset, size, stride; }; SmallVector getRanges(); }]; @@ -1456,7 +1456,7 @@ def TensorLoadOp : Std_Op<"tensor_load", let verifier = ?; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr memref", [{ + "Builder *builder, OperationState &result, Value memref", [{ auto memrefType = memref->getType().cast(); auto resultType = RankedTensorType::get(memrefType.getShape(), memrefType.getElementType()); @@ -1510,7 +1510,7 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ + "Builder *builder, OperationState &result, Value value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; @@ -1569,7 +1569,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { /// Returns the dynamic offset for this view operation if specified. /// Returns nullptr if no dynamic offset was specified. - ValuePtr getDynamicOffset(); + Value getDynamicOffset(); /// Returns the starting operand list position of the dynamic size operands. unsigned getDynamicSizesOperandStart() { @@ -1610,7 +1610,7 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> let results = (outs IntegerLike); let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr value, Type destType", [{ + "Builder *builder, OperationState &result, Value value, Type destType", [{ result.addOperands(value); result.addTypes(destType); }]>]; diff --git a/mlir/include/mlir/Dialect/VectorOps/Utils.h b/mlir/include/mlir/Dialect/VectorOps/Utils.h index 04bd8b50fb6..5f19f849e3f 100644 --- a/mlir/include/mlir/Dialect/VectorOps/Utils.h +++ b/mlir/include/mlir/Dialect/VectorOps/Utils.h @@ -25,9 +25,6 @@ class Operation; class Value; class VectorType; -// TODO(riverriddle) Remove this after Value is value-typed. -using ValuePtr = Value; - /// Computes and returns the multi-dimensional ratio of `superShape` to /// `subShape`. This is calculated by performing a traversal from minor to major /// dimensions (i.e. in reverse shape order). If integral division is not @@ -116,7 +113,7 @@ Optional> shapeRatio(VectorType superVectorType, /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. 
/// AffineMap -makePermutationMap(Operation *op, ArrayRef indices, +makePermutationMap(Operation *op, ArrayRef indices, const DenseMap &loopToVectorDim); namespace matcher { diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 87ed28caf80..8726b162fd6 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -119,8 +119,8 @@ def Vector_ContractionOp : : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr lhs, ValuePtr rhs, " - "ValuePtr acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">]; + "Builder *builder, OperationState &result, Value lhs, Value rhs, " + "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">]; let extraClassDeclaration = [{ VectorType getLhsType() { return lhs()->getType().cast(); @@ -244,7 +244,7 @@ def Vector_ShuffleOp : ``` }]; let builders = [OpBuilder<"Builder *builder, OperationState &result," - "ValuePtr v1, ValuePtr v2, ArrayRef">]; + "Value v1, Value v2, ArrayRef">]; let extraClassDeclaration = [{ static StringRef getMaskAttrName() { return "mask"; } VectorType getV1VectorType() { @@ -304,7 +304,7 @@ def Vector_ExtractOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source," + "Builder *builder, OperationState &result, Value source," "ArrayRef">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } @@ -350,7 +350,7 @@ def Vector_ExtractSlicesOp : }]; let builders = [OpBuilder< "Builder *builder, OperationState &result, TupleType tupleType, " # - "ValuePtr vector, ArrayRef sizes, " # + "Value vector, ArrayRef sizes, " # "ArrayRef strides">]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { @@ -421,8 +421,8 @@ def Vector_InsertOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source, " # - "ValuePtr dest, ArrayRef">]; + "Builder *builder, OperationState &result, Value source, " # + "Value dest, ArrayRef">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } Type getSourceType() { return source()->getType(); } @@ -514,7 +514,7 @@ def Vector_InsertStridedSliceOp : ``` }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source, ValuePtr dest, " # + "Builder *builder, OperationState &result, Value source, Value dest, " # "ArrayRef offsets, ArrayRef strides">]; let extraClassDeclaration = [{ static StringRef getOffsetsAttrName() { return "offsets"; } @@ -716,7 +716,7 @@ def Vector_StridedSliceOp : vector<4x8x16xf32> to vector<2x4x16xf32> }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source, " # + "Builder *builder, OperationState &result, Value source, " # "ArrayRef offsets, ArrayRef sizes, " # "ArrayRef strides">]; let extraClassDeclaration = [{ @@ -968,7 +968,7 @@ def Vector_TypeCastOp : }]; let builders = [OpBuilder< - "Builder *builder, OperationState &result, ValuePtr source">]; + "Builder *builder, OperationState &result, Value source">]; let parser = [{ return impl::parseCastOp(parser, result); diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h index a73444d2023..feb8bd60445 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h +++ 
b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -64,9 +64,8 @@ namespace vector { // // This will be extended in the future to support more advanced use cases than // simple pointwise ops. -ValuePtr unrollSingleResultOpMatchingType(PatternRewriter &builder, - Operation *op, - ArrayRef targetShape); +Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, + ArrayRef targetShape); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index f9629a8d99e..d598c1cfb23 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -303,7 +303,7 @@ public: /// Value. An eager Value represents both the declaration and the definition /// (in the PL sense) of a placeholder for an mlir::Value that has already /// been constructed in the past and that is captured "now" in the program. - explicit ValueHandle(ValuePtr v) : t(v->getType()), v(v) {} + explicit ValueHandle(Value v) : t(v->getType()), v(v) {} /// Builds a ConstantIndexOp of value `cst`. The constant is created at the /// current insertion point. @@ -328,7 +328,7 @@ public: } /// Implicit conversion useful for automatic conversion to Container. - operator ValuePtr() const { return getValue(); } + operator Value() const { return getValue(); } operator bool() const { return hasValue(); } /// Generic mlir::Op create. This is the key to being extensible to the whole @@ -347,7 +347,7 @@ public: /// Special case to build composed AffineApply operations. // TODO: createOrFold when available and move inside of the `create` method. static ValueHandle createComposedAffineApply(AffineMap map, - ArrayRef operands); + ArrayRef operands); /// Generic create for a named operation producing a single value. static ValueHandle create(StringRef name, ArrayRef operands, @@ -355,7 +355,7 @@ public: ArrayRef attributes = {}); bool hasValue() const { return v != nullptr; } - ValuePtr getValue() const { + Value getValue() const { assert(hasValue() && "Unexpected null value;"); return v; } @@ -372,7 +372,7 @@ protected: ValueHandle() : t(), v(nullptr) {} Type t; - ValuePtr v; + Value v; }; /// An OperationHandle can be used in lieu of ValueHandle to capture the diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h index 0be8a6045f7..a7c0365225a 100644 --- a/mlir/include/mlir/EDSC/Helpers.h +++ b/mlir/include/mlir/EDSC/Helpers.h @@ -66,7 +66,7 @@ protected: // TODO(ntv): Support MemRefs with layoutMaps. class MemRefView : public View { public: - explicit MemRefView(ValuePtr v); + explicit MemRefView(Value v); MemRefView(const MemRefView &) = default; MemRefView &operator=(const MemRefView &) = default; @@ -82,7 +82,7 @@ private: /// a MemRefView but for vectors. This exists purely for boilerplate avoidance. class VectorView : public View { public: - explicit VectorView(ValuePtr v); + explicit VectorView(Value v); VectorView(const VectorView &) = default; VectorView &operator=(const VectorView &) = default; @@ -111,7 +111,7 @@ private: template class TemplatedIndexedValue { public: explicit TemplatedIndexedValue(Type t) : base(t) {} - explicit TemplatedIndexedValue(ValuePtr v) + explicit TemplatedIndexedValue(Value v) : TemplatedIndexedValue(ValueHandle(v)) {} explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} @@ -153,7 +153,7 @@ public: } /// Emits a `load` when converting to a Value. 
- ValuePtr operator*(void) const { + Value operator*(void) const { return Load(getBase(), {indices.begin(), indices.end()}).getValue(); } diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 5edbf9600fb..30cce6bb8d6 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -35,7 +35,7 @@ struct IndexHandle : public ValueHandle { explicit IndexHandle() : ValueHandle(ScopedContext::getBuilder().getIndexType()) {} explicit IndexHandle(index_t v) : ValueHandle(v) {} - explicit IndexHandle(ValuePtr v) : ValueHandle(v) { + explicit IndexHandle(Value v) : ValueHandle(v) { assert(v->getType() == ScopedContext::getBuilder().getIndexType() && "Expected index type"); } @@ -71,8 +71,8 @@ makeHandlePointers(MutableArrayRef ivs) { } /// Returns a vector of the underlying Value from `ivs`. -inline SmallVector extractValues(ArrayRef ivs) { - SmallVector vals; +inline SmallVector extractValues(ArrayRef ivs) { + SmallVector vals; vals.reserve(ivs.size()); for (auto &iv : ivs) { vals.push_back(iv.getValue()); @@ -100,11 +100,11 @@ public: SmallVector tmp(vals.begin(), vals.end()); values.append(tmp.begin(), tmp.end()); } - operator ArrayRef() { return values; } + operator ArrayRef() { return values; } private: ValueHandleArray() = default; - SmallVector values; + SmallVector values; }; template inline T unpack(T value) { return value; } diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 33feea7bcbb..934eed93c3b 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -63,7 +63,7 @@ public: //===--------------------------------------------------------------------===// // This is the list of arguments to the block. - using BlockArgListType = MutableArrayRef; + using BlockArgListType = MutableArrayRef; BlockArgListType getArguments() { return arguments; } @@ -77,7 +77,7 @@ public: bool args_empty() { return arguments.empty(); } /// Add one value to the argument list. - BlockArgumentPtr addArgument(Type type); + BlockArgument addArgument(Type type); /// Add one argument to the argument list for each type specified in the list. iterator_range addArguments(ArrayRef types); @@ -88,7 +88,7 @@ public: void eraseArgument(unsigned index, bool updatePredTerms = true); unsigned getNumArguments() { return arguments.size(); } - BlockArgumentPtr getArgument(unsigned i) { return arguments[i]; } + BlockArgument getArgument(unsigned i) { return arguments[i]; } //===--------------------------------------------------------------------===// // Operation list management @@ -323,7 +323,7 @@ private: OpListType operations; /// This is the list of arguments to the block. - std::vector arguments; + std::vector arguments; Block(Block &) = delete; void operator=(Block &) = delete; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 038664f0186..2db44cbfa2e 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -304,7 +304,7 @@ public: /// and immediately try to fold it. This functions populates 'results' with /// the results after folding the operation. template - void createOrFold(SmallVectorImpl &results, Location location, + void createOrFold(SmallVectorImpl &results, Location location, Args &&... args) { // Create the operation without using 'createOperation' as we don't want to // insert it yet. @@ -322,9 +322,9 @@ public: /// Overload to create or fold a single result operation. 
template typename std::enable_if(), - ValuePtr>::type + Value>::type createOrFold(Location location, Args &&... args) { - SmallVector results; + SmallVector results; createOrFold(results, location, std::forward(args)...); return results.front(); } @@ -335,7 +335,7 @@ public: OpTy>::type createOrFold(Location location, Args &&... args) { auto op = create(location, std::forward(args)...); - SmallVector unused; + SmallVector unused; tryFold(op.getOperation(), unused); // Folding cannot remove a zero-result operation, so for convenience we @@ -346,7 +346,7 @@ public: /// Attempts to fold the given operation and places new results within /// 'results'. Returns success if the operation was folded, failure otherwise. /// Note: This function does not erase the operation on a successful fold. - LogicalResult tryFold(Operation *op, SmallVectorImpl &results); + LogicalResult tryFold(Operation *op, SmallVectorImpl &results); /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index 49175ba5e75..e6cba2c7404 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -174,7 +174,7 @@ public: } /// Gets argument. - BlockArgumentPtr getArgument(unsigned idx) { + BlockArgument getArgument(unsigned idx) { return getBlocks().front().getArgument(idx); } diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 5ce2cc7a8a8..2cfa2428bd5 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -133,7 +133,7 @@ using has_operation_or_value_matcher_t = /// Statically switch to a Value matcher. template typename std::enable_if_t::value, + MatcherClass, Value>::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { return matcher.match(op->getOperand(idx)); @@ -152,14 +152,14 @@ matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { /// Terminal matcher, always returns true. struct AnyValueMatcher { - bool match(ValuePtr op) const { return true; } + bool match(Value op) const { return true; } }; /// Binds to a specific value and matches it. struct PatternMatcherValue { - PatternMatcherValue(ValuePtr val) : value(val) {} - bool match(ValuePtr val) const { return val == value; } - ValuePtr value; + PatternMatcherValue(Value val) : value(val) {} + bool match(Value val) const { return val == value; } + Value value; }; template @@ -226,7 +226,7 @@ inline detail::constant_int_not_value_matcher<0> m_NonZero() { /// Entry point for matching a pattern over a Value. template -inline bool matchPattern(ValuePtr value, const Pattern &pattern) { +inline bool matchPattern(Value value, const Pattern &pattern) { // TODO: handle other cases if (auto *op = value->getDefiningOp()) return const_cast(pattern).match(op); @@ -253,7 +253,7 @@ auto m_Op(Matchers... 
matchers) { namespace matchers { inline auto m_Any() { return detail::AnyValueMatcher(); } -inline auto m_Val(ValuePtr v) { return detail::PatternMatcherValue(v); } +inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } } // namespace matchers } // end namespace mlir diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 84f3cf2f444..1abf82f37ee 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -248,8 +248,8 @@ inline bool operator!=(OpState lhs, OpState rhs) { } /// This class represents a single result from folding an operation. -class OpFoldResult : public PointerUnion { - using PointerUnion::PointerUnion; +class OpFoldResult : public PointerUnion { + using PointerUnion::PointerUnion; }; /// This template defines the foldHook as used by AbstractOperation. @@ -303,7 +303,7 @@ class FoldingHook(this)->getOperation()->getResult(0); } @@ -317,7 +317,7 @@ public: // Check if the operation was folded in place. In this case, the operation // returns itself. - if (result.template dyn_cast() != op->getResult(0)) + if (result.template dyn_cast() != op->getResult(0)) results.push_back(result); return success(); } @@ -419,12 +419,10 @@ struct MultiOperandTraitBase : public TraitBase { unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } /// Return the operand at index 'i'. - ValuePtr getOperand(unsigned i) { - return this->getOperation()->getOperand(i); - } + Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } /// Set the operand at index 'i' to 'value'. - void setOperand(unsigned i, ValuePtr value) { + void setOperand(unsigned i, Value value) { this->getOperation()->setOperand(i, value); } @@ -468,11 +466,9 @@ private: template class OneOperand : public TraitBase { public: - ValuePtr getOperand() { return this->getOperation()->getOperand(0); } + Value getOperand() { return this->getOperation()->getOperand(0); } - void setOperand(ValuePtr value) { - this->getOperation()->setOperand(0, value); - } + void setOperand(Value value) { this->getOperation()->setOperand(0, value); } static LogicalResult verifyTrait(Operation *op) { return impl::verifyOneOperand(op); @@ -545,7 +541,7 @@ struct MultiResultTraitBase : public TraitBase { unsigned getNumResults() { return this->getOperation()->getNumResults(); } /// Return the result at index 'i'. - ValuePtr getResult(unsigned i) { return this->getOperation()->getResult(i); } + Value getResult(unsigned i) { return this->getOperation()->getResult(i); } /// Replace all uses of results of this operation with the provided 'values'. /// 'values' may correspond to an existing operation, or a range of 'Value'. @@ -581,13 +577,13 @@ struct MultiResultTraitBase : public TraitBase { template class OneResult : public TraitBase { public: - ValuePtr getResult() { return this->getOperation()->getResult(0); } + Value getResult() { return this->getOperation()->getResult(0); } Type getType() { return getResult()->getType(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. 
- void replaceAllUsesWith(ValuePtr newValue) { + void replaceAllUsesWith(Value newValue) { getResult()->replaceAllUsesWith(newValue); } @@ -815,10 +811,10 @@ public: return this->getOperation()->setSuccessor(block, index); } - void addSuccessorOperand(unsigned index, ValuePtr value) { + void addSuccessorOperand(unsigned index, Value value) { return this->getOperation()->addSuccessorOperand(index, value); } - void addSuccessorOperands(unsigned index, ArrayRef values) { + void addSuccessorOperands(unsigned index, ArrayRef values) { return this->getOperation()->addSuccessorOperand(index, values); } }; @@ -1204,8 +1200,8 @@ namespace impl { ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, OperationState &result); -void buildBinaryOp(Builder *builder, OperationState &result, ValuePtr lhs, - ValuePtr rhs); +void buildBinaryOp(Builder *builder, OperationState &result, Value lhs, + Value rhs); ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result); @@ -1218,11 +1214,11 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p); // These functions are out-of-line implementations of the methods in CastOp, // which avoids them being template instantiated/duplicated. namespace impl { -void buildCastOp(Builder *builder, OperationState &result, ValuePtr source, +void buildCastOp(Builder *builder, OperationState &result, Value source, Type destType); ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); void printCastOp(Operation *op, OpAsmPrinter &p); -ValuePtr foldCastOp(Operation *op); +Value foldCastOp(Operation *op); } // namespace impl } // end namespace mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 8e2aed29500..41acdba1a05 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -36,7 +36,7 @@ public: virtual raw_ostream &getStream() const = 0; /// Print implementations for various things an operation contains. - virtual void printOperand(ValuePtr value) = 0; + virtual void printOperand(Value value) = 0; /// Print a comma separated list of operands. template @@ -112,7 +112,7 @@ public: void printFunctionalType(Operation *op) { auto &os = getStream(); os << "("; - interleaveComma(op->getNonSuccessorOperands(), os, [&](ValuePtr operand) { + interleaveComma(op->getNonSuccessorOperands(), os, [&](Value operand) { if (operand) printType(operand->getType()); else @@ -141,7 +141,7 @@ private: }; // Make the implementations convenient to use. -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ValueRef value) { +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { p.printOperand(value); return p; } @@ -454,13 +454,13 @@ public: /// Resolve an operand to an SSA value, emitting an error on failure. virtual ParseResult resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) = 0; + SmallVectorImpl &result) = 0; /// Resolve a list of operands to SSA values, emitting an error on failure, or /// appending the results to the list on success. This method should be used /// when all operands have the same type. ParseResult resolveOperands(ArrayRef operands, Type type, - SmallVectorImpl &result) { + SmallVectorImpl &result) { for (auto elt : operands) if (resolveOperand(elt, type, result)) return failure(); @@ -472,7 +472,7 @@ public: /// to the list on success. 
ParseResult resolveOperands(ArrayRef operands, ArrayRef types, llvm::SMLoc loc, - SmallVectorImpl &result) { + SmallVectorImpl &result) { if (operands.size() != types.size()) return emitError(loc) << operands.size() << " operands present, but expected " @@ -542,8 +542,7 @@ public: /// Parse a single operation successor and its operand list. virtual ParseResult - parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) = 0; + parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) = 0; //===--------------------------------------------------------------------===// // Type Parsing @@ -621,7 +620,7 @@ private: /// A functor used to set the name of the start of a result group of an /// operation. See 'getAsmResultNames' below for more details. -using OpAsmSetValueNameFn = function_ref; +using OpAsmSetValueNameFn = function_ref; class OpAsmDialectInterface : public DialectInterface::Base { diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 47085f361ca..9ef1636d3d0 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -34,8 +34,7 @@ class Operation final public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, - ArrayRef resultTypes, - ArrayRef operands, + ArrayRef resultTypes, ArrayRef operands, ArrayRef attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList); @@ -43,8 +42,7 @@ public: /// Overload of create that takes an existing NamedAttributeList to avoid /// unnecessarily uniquing a list of attributes. static Operation *create(Location location, OperationName name, - ArrayRef resultTypes, - ArrayRef operands, + ArrayRef resultTypes, ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList); @@ -53,11 +51,12 @@ public: static Operation *create(const OperationState &state); /// Create a new Operation with the specific fields. - static Operation * - create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, NamedAttributeList attributes, - ArrayRef successors = {}, RegionRange regions = {}, - bool resizableOperandList = false); + static Operation *create(Location location, OperationName name, + ArrayRef resultTypes, ArrayRef operands, + NamedAttributeList attributes, + ArrayRef successors = {}, + RegionRange regions = {}, + bool resizableOperandList = false); /// The name of an operation is the key identifier for it. OperationName getName() { return name; } @@ -140,7 +139,7 @@ public: } /// Replace any uses of 'from' with 'to' within this operation. - void replaceUsesOfWith(ValuePtr from, ValuePtr to); + void replaceUsesOfWith(Value from, Value to); /// Replace all uses of results of this operation with the provided 'values'. template getSuccessorBlockArgument(unsigned operandIndex) { + Optional getSuccessorBlockArgument(unsigned operandIndex) { auto decomposed = decomposeSuccessorOperandIndex(operandIndex); if (!decomposed.hasValue()) return None; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index ef2ff44ef6e..30376b8b599 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -261,7 +261,7 @@ inline llvm::hash_code hash_value(OperationName arg) { struct OperationState { Location location; OperationName name; - SmallVector operands; + SmallVector operands; /// Types of the results of this operation. 
SmallVector types; SmallVector attributes; diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index fd9d317ed35..b095683ae5b 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -32,7 +32,7 @@ Type getElementTypeOrSelf(Type type); /// Return the element type or return the type itself. Type getElementTypeOrSelf(Attribute attr); -Type getElementTypeOrSelf(ValuePtr val); +Type getElementTypeOrSelf(Value val); /// Get the types within a nested Tuple. A helper for the class method that /// handles storage concerns, which is tricky to do in tablegen. @@ -62,7 +62,7 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2); // An iterator for the element types of an op's operands of shaped types. class OperandElementTypeIterator final : public llvm::mapped_iterator { + Type (*)(Value)> { public: using reference = Type; @@ -71,7 +71,7 @@ public: explicit OperandElementTypeIterator(Operation::operand_iterator it); private: - static Type unwrap(ValuePtr value); + static Type unwrap(Value value); }; using OperandElementTypeRange = iterator_range; @@ -79,7 +79,7 @@ using OperandElementTypeRange = iterator_range; // An iterator for the tensor element types of an op's results of shaped types. class ResultElementTypeIterator final : public llvm::mapped_iterator { + Type (*)(Value)> { public: using reference = Type; @@ -88,7 +88,7 @@ public: explicit ResultElementTypeIterator(Operation::result_iterator it); private: - static Type unwrap(ValuePtr value); + static Type unwrap(Value value); }; using ResultElementTypeRange = iterator_range; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 26703a25306..c4356b16840 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -292,12 +292,6 @@ inline ::llvm::hash_code hash_value(Value arg) { return ::llvm::hash_value(arg.impl); } -/// Using directives that simplify the transition of Value to being value typed. -using BlockArgumentPtr = BlockArgument; -using OpResultPtr = OpResult; -using ValueRef = Value; -using ValuePtr = Value; - } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h index fe66848b906..d99db65b015 100644 --- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h @@ -154,7 +154,7 @@ public: } virtual Operation *getOp() const = 0; - virtual ValuePtr getValue() const = 0; + virtual Value getValue() const = 0; static bool classof(const CAGNode *n) { return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor; @@ -201,7 +201,7 @@ public: return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor; } - ValuePtr getValue() const final { return op->getOperand(operandIdx); } + Value getValue() const final { return op->getOperand(operandIdx); } void printLabel(raw_ostream &os) const override; @@ -222,12 +222,12 @@ public: } Operation *getOp() const final { return resultValue->getDefiningOp(); } - ValuePtr getValue() const final { return resultValue; } + Value getValue() const final { return resultValue; } void printLabel(raw_ostream &os) const override; private: - ValuePtr resultValue; + Value resultValue; }; /// Base class for constraint nodes. 
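For a sense of how downstream code reads once the ValuePtr/ValueRef transition aliases removed above are gone, here is a minimal sketch. It is not part of this patch; collectMemRefOperands is a hypothetical helper chosen only for illustration. The point is that mlir::Value is now passed and stored by value, while the pointer-style accessors used throughout this change continue to work during the transition.

// Sketch only: `collectMemRefOperands` is a hypothetical helper, not part of
// this patch. It shows the post-change idiom: mlir::Value is a value type,
// so it is passed and stored by value rather than through a ValuePtr alias.
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static SmallVector<Value, 4> collectMemRefOperands(Operation *op) {
  SmallVector<Value, 4> memrefs;
  for (Value operand : op->getOperands()) {
    // At this point in the transition Value still exposes the pointer-like
    // syntax used throughout this patch (operand->getType()).
    if (operand->getType().isa<MemRefType>())
      memrefs.push_back(operand);
  }
  return memrefs;
}
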
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 4a5010ea09a..d0b13a669fa 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -104,7 +104,7 @@ private: protected: // Mappings between original and translated values, used for lookups. llvm::StringMap functionMapping; - DenseMap valueMapping; + DenseMap valueMapping; DenseMap blockMapping; }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index becb95f1f4e..5cbbcae4543 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -51,7 +51,7 @@ public: /// remaps an existing signature input. struct InputMapping { size_t inputNo, size; - ValuePtr replacementValue; + Value replacementValue; }; /// Return the argument types for the new signature. @@ -81,7 +81,7 @@ public: /// Remap an input of the original signature to another `replacement` /// value. This drops the original argument. - void remapInput(unsigned origInputNo, ValuePtr replacement); + void remapInput(unsigned origInputNo, Value replacement); private: /// The remapping information for each of the original arguments. @@ -134,7 +134,7 @@ public: /// the conversion has finished. virtual Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc) { llvm_unreachable("expected 'materializeConversion' to be overridden"); } @@ -163,7 +163,7 @@ public: /// ConversionPattern ever needs to replace an operation that does not /// have successors. This function should not fail. If some specific cases of /// the operation are not supported, these cases should not be matched. - virtual void rewrite(Operation *op, ArrayRef operands, + virtual void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } @@ -178,18 +178,18 @@ public: /// terminator operation that has successors. This function should not fail /// the pass. If some specific cases of the operation are not supported, /// these cases should not be matched. - virtual void rewrite(Operation *op, ArrayRef properOperands, + virtual void rewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite for terminators"); } /// Hook for derived classes to implement combined matching and rewriting. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -199,7 +199,7 @@ public: /// Hook for derived classes to implement combined matching and rewriting. virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -225,27 +225,27 @@ struct OpConversionPattern : public ConversionPattern { /// Wrappers around the ConversionPattern methods that pass the derived op /// type. 
- void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), operands, rewriter); } - void rewrite(Operation *op, ArrayRef properOperands, + void rewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), properOperands, destinations, operands, rewriter); } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), properOperands, destinations, operands, rewriter); } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); } @@ -255,22 +255,22 @@ struct OpConversionPattern : public ConversionPattern { /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. - virtual void rewrite(SourceOp op, ArrayRef operands, + virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual void rewrite(SourceOp op, ArrayRef properOperands, + virtual void rewrite(SourceOp op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite for terminators"); } virtual PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef properOperands, + matchAndRewrite(SourceOp op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -279,7 +279,7 @@ struct OpConversionPattern : public ConversionPattern { } virtual PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); @@ -321,11 +321,11 @@ public: TypeConverter::SignatureConversion &conversion); /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgumentPtr from, ValuePtr to); + void replaceUsesOfBlockArgument(BlockArgument from, Value to); /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. - ValuePtr getRemappedValue(ValuePtr key); + Value getRemappedValue(Value key); //===--------------------------------------------------------------------===// // PatternRewriter Hooks diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h index ed18619c44a..6b0e82794f5 100644 --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -73,7 +73,7 @@ public: /// and immediately try to fold it. This function populates 'results' with /// the results after folding the operation. template - void create(OpBuilder &builder, SmallVectorImpl &results, + void create(OpBuilder &builder, SmallVectorImpl &results, Location location, Args &&... 
args) { Operation *op = builder.create(location, std::forward(args)...); if (failed(tryToFold(op, results))) @@ -85,9 +85,9 @@ public: /// Overload to create or fold a single result operation. template typename std::enable_if(), - ValuePtr>::type + Value>::type create(OpBuilder &builder, Location location, Args &&... args) { - SmallVector results; + SmallVector results; create(builder, results, location, std::forward(args)...); return results.front(); } @@ -98,7 +98,7 @@ public: OpTy>::type create(OpBuilder &builder, Location location, Args &&... args) { auto op = builder.create(location, std::forward(args)...); - SmallVector unused; + SmallVector unused; (void)tryToFold(op.getOperation(), unused); // Folding cannot remove a zero-result operation, so for convenience we @@ -117,7 +117,7 @@ private: /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult tryToFold( - Operation *op, SmallVectorImpl &results, + Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants = nullptr); /// Try to get or create a new constant entry. On success this returns the diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index e4739bba66b..e3631c21c30 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -96,7 +96,7 @@ public: /// operation). The given 'op' will be removed by the caller, after this /// function has been called. virtual void handleTerminator(Operation *op, - ArrayRef valuesToReplace) const { + ArrayRef valuesToReplace) const { llvm_unreachable( "must implement handleTerminator in the case of one inlined block"); } @@ -116,8 +116,8 @@ public: /// ... = foo.call @foo(%input : i32) -> i16 /// /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, - ValuePtr input, Type resultType, + virtual Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, Location conversionLoc) const { return nullptr; } @@ -156,7 +156,7 @@ public: virtual void handleTerminator(Operation *op, Block *newDest) const; virtual void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const; + ArrayRef valuesToRepl) const; }; //===----------------------------------------------------------------------===// @@ -178,7 +178,7 @@ public: /// be cloned into the 'inlinePoint' or spliced directly. LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, - ArrayRef resultsToReplace, + ArrayRef resultsToReplace, Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); @@ -187,8 +187,8 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, /// in-favor of the region arguments when inlining. 
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, - ArrayRef inlinedOperands, - ArrayRef resultsToReplace, + ArrayRef inlinedOperands, + ArrayRef resultsToReplace, Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); diff --git a/mlir/include/mlir/Transforms/LoopLikeInterface.td b/mlir/include/mlir/Transforms/LoopLikeInterface.td index 089a3e19c35..c110b192987 100644 --- a/mlir/include/mlir/Transforms/LoopLikeInterface.td +++ b/mlir/include/mlir/Transforms/LoopLikeInterface.td @@ -29,7 +29,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { explicit capture of dependencies, an implementation could check whether the value corresponds to a captured dependency. }], - "bool", "isDefinedOutsideOfLoop", (ins "ValuePtr ":$value) + "bool", "isDefinedOutsideOfLoop", (ins "Value ":$value) >, InterfaceMethod<[{ Returns the region that makes up the body of the loop and should be diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index a08a3fc8307..402a336cf1c 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -75,8 +75,7 @@ void promoteSingleIterationLoops(FuncOp f); /// operands or a null map when the trip count can't be expressed as an affine /// expression. void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, - AffineMap *map, - SmallVectorImpl *operands, + AffineMap *map, SmallVectorImpl *operands, OpBuilder &builder); /// Skew the operations in the body of a 'affine.for' operation with the @@ -130,8 +129,7 @@ using TileLoops = std::pair; SmallVector, 8> tile(ArrayRef forOps, ArrayRef sizes, ArrayRef targets); -SmallVector tile(ArrayRef forOps, - ArrayRef sizes, +SmallVector tile(ArrayRef forOps, ArrayRef sizes, ArrayRef targets); /// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes` @@ -140,7 +138,7 @@ SmallVector tile(ArrayRef forOps, /// `target`. SmallVector tile(ArrayRef forOps, ArrayRef sizes, AffineForOp target); -Loops tile(ArrayRef forOps, ArrayRef sizes, +Loops tile(ArrayRef forOps, ArrayRef sizes, loop::ForOp target); /// Tile a nest of loop::ForOp loops rooted at `rootForOp` with the given @@ -148,7 +146,7 @@ Loops tile(ArrayRef forOps, ArrayRef sizes, /// runtime. If more sizes than loops are provided, discard the trailing values /// in sizes. Assumes the loop nest is permutable. /// Returns the newly created intra-tile loops. -Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes); +Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes); /// Explicit copy / DMA generation options for mlir::affineDataCopyGenerate. struct AffineCopyOptions { @@ -220,8 +218,8 @@ void coalesceLoops(MutableArrayRef loops); /// ... /// } /// ``` -void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef processorId, - ArrayRef numProcessors); +void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef processorId, + ArrayRef numProcessors); } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_UTILS_H diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 9639dfad857..bd71553e96b 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -21,15 +21,14 @@ namespace mlir { /// of `limit`. 
template bool areValuesDefinedAbove(Range values, Region &limit) { - for (ValuePtr v : values) + for (Value v : values) if (!v->getParentRegion()->isProperAncestor(&limit)) return false; return true; } /// Replace all uses of `orig` within the given region with `replacement`. -void replaceAllUsesInRegionWith(ValuePtr orig, ValuePtr replacement, - Region ®ion); +void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion); /// Calls `callback` for each use of a value within `region` or its descendants /// that was defined at the ancestors of the `limit`. @@ -44,12 +43,12 @@ void visitUsedValuesDefinedAbove(MutableArrayRef regions, /// Fill `values` with a list of values defined at the ancestors of the `limit` /// region and used within `region` or its descendants. void getUsedValuesDefinedAbove(Region ®ion, Region &limit, - llvm::SetVector &values); + llvm::SetVector &values); /// Fill `values` with a list of values used within any of the regions provided /// but defined in one of the ancestors. void getUsedValuesDefinedAbove(MutableArrayRef regions, - llvm::SetVector &values); + llvm::SetVector &values); /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index a8268c1daa2..3b7f6cd3909 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -57,22 +57,22 @@ class OpBuilder; // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). // TODO(bondhugula): allow extraIndices to be added at any position. -LogicalResult replaceAllMemRefUsesWith(ValuePtr oldMemRef, ValuePtr newMemRef, - ArrayRef extraIndices = {}, +LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, + ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}, + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, Operation *postDomInstFilter = nullptr); /// Performs the same replacement as the other version above but only for the /// dereferencing uses of `oldMemRef` in `op`. -LogicalResult replaceAllMemRefUsesWith(ValuePtr oldMemRef, ValuePtr newMemRef, +LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, Operation *op, - ArrayRef extraIndices = {}, + ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}); + ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}); /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. Returns failure if any of its uses @@ -87,9 +87,9 @@ LogicalResult normalizeMemRef(AllocOp op); /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc, - ArrayRef operands, + ArrayRef operands, ArrayRef affineApplyOps, - SmallVectorImpl *results); + SmallVectorImpl *results); /// Given an operation, inserts one or more single result affine apply /// operations, results of which are exclusively used by this operation. 
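A minimal sketch (not part of this patch) of a call site for the updated replaceAllMemRefUsesWith signature, where the memrefs and the extra index operand are now plain Value. The helper and variable names are hypothetical; only the function and its defaulted ArrayRef<Value> parameters come from the header above.

  #include "mlir/Transforms/Utils.h"

  // Redirect all dereferencing uses of 'oldMemRef' to 'newMemRef', adding the
  // loop IV 'iv' as an extra index; existing indices are left unchanged.
  static LogicalResult redirectAccesses(Value oldMemRef, Value newMemRef,
                                        Value iv) {
    SmallVector<Value, 1> extraIndices = {iv};
    return replaceAllMemRefUsesWith(oldMemRef, newMemRef, extraIndices);
  }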
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 27aa0748711..3358bb437ff 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -39,10 +39,10 @@ using llvm::dbgs; // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( - ArrayRef operands, SmallVectorImpl &affineApplyOps) { + ArrayRef operands, SmallVectorImpl &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. - ValuePtr value; + Value value; // The operand index of 'value' to explore next during DFS traversal. unsigned operandIndex; }; @@ -90,7 +90,7 @@ void mlir::getReachableAffineApplyOps( // setExprStride(ArrayRef expr, int64_t stride) LogicalResult mlir::getIndexSet(MutableArrayRef forOps, FlatAffineConstraints *domain) { - SmallVector indices; + SmallVector indices; extractForInductionVars(forOps, &indices); // Reset while associated Values in 'indices' to the domain. domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); @@ -137,25 +137,25 @@ static LogicalResult getInstIndexSet(Operation *op, // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". class ValuePositionMap { public: - void addSrcValue(ValuePtr value) { + void addSrcValue(Value value) { if (addValueAt(value, &srcDimPosMap, numSrcDims)) ++numSrcDims; } - void addDstValue(ValuePtr value) { + void addDstValue(Value value) { if (addValueAt(value, &dstDimPosMap, numDstDims)) ++numDstDims; } - void addSymbolValue(ValuePtr value) { + void addSymbolValue(Value value) { if (addValueAt(value, &symbolPosMap, numSymbols)) ++numSymbols; } - unsigned getSrcDimOrSymPos(ValuePtr value) const { + unsigned getSrcDimOrSymPos(Value value) const { return getDimOrSymPos(value, srcDimPosMap, 0); } - unsigned getDstDimOrSymPos(ValuePtr value) const { + unsigned getDstDimOrSymPos(Value value) const { return getDimOrSymPos(value, dstDimPosMap, numSrcDims); } - unsigned getSymPos(ValuePtr value) const { + unsigned getSymPos(Value value) const { auto it = symbolPosMap.find(value); assert(it != symbolPosMap.end()); return numSrcDims + numDstDims + it->second; @@ -167,7 +167,7 @@ public: unsigned getNumSymbols() const { return numSymbols; } private: - bool addValueAt(ValuePtr value, DenseMap *posMap, + bool addValueAt(Value value, DenseMap *posMap, unsigned position) { auto it = posMap->find(value); if (it == posMap->end()) { @@ -176,8 +176,8 @@ private: } return false; } - unsigned getDimOrSymPos(ValuePtr value, - const DenseMap &dimPosMap, + unsigned getDimOrSymPos(Value value, + const DenseMap &dimPosMap, unsigned dimPosOffset) const { auto it = dimPosMap.find(value); if (it != dimPosMap.end()) { @@ -191,9 +191,9 @@ private: unsigned numSrcDims = 0; unsigned numDstDims = 0; unsigned numSymbols = 0; - DenseMap srcDimPosMap; - DenseMap dstDimPosMap; - DenseMap symbolPosMap; + DenseMap srcDimPosMap; + DenseMap dstDimPosMap; + DenseMap symbolPosMap; }; // Builds a map from Value to identifier position in a new merged identifier @@ -210,7 +210,7 @@ static void buildDimAndSymbolPositionMaps( const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap, FlatAffineConstraints *dependenceConstraints) { - auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { + auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = 
values.size(); i < e; ++i) { auto value = values[i]; if (!isForInductionVar(values[i])) { @@ -225,7 +225,7 @@ static void buildDimAndSymbolPositionMaps( } }; - SmallVector srcValues, destValues; + SmallVector srcValues, destValues; srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues); dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues); // Update value position map with identifiers from src iteration domain. @@ -264,7 +264,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, numLocals); // Set values corresponding to dependence constraint identifiers. - SmallVector srcLoopIVs, dstLoopIVs; + SmallVector srcLoopIVs, dstLoopIVs; srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); @@ -273,7 +273,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); // Set values for the symbolic identifier dimensions. - auto setSymbolIds = [&](ArrayRef values) { + auto setSymbolIds = [&](ArrayRef values) { for (auto value : values) { if (!isForInductionVar(value)) { assert(isValidSymbol(value) && "expected symbol"); @@ -285,7 +285,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, setSymbolIds(srcAccessMap.getOperands()); setSymbolIds(dstAccessMap.getOperands()); - SmallVector srcSymbolValues, dstSymbolValues; + SmallVector srcSymbolValues, dstSymbolValues; srcDomain.getIdValues(srcDomain.getNumDimIds(), srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); dstDomain.getIdValues(dstDomain.getNumDimIds(), @@ -389,10 +389,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, unsigned numResults = srcMap.getNumResults(); unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); - ArrayRef srcOperands = srcAccessMap.getOperands(); + ArrayRef srcOperands = srcAccessMap.getOperands(); unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); - ArrayRef dstOperands = dstAccessMap.getOperands(); + ArrayRef dstOperands = dstAccessMap.getOperands(); std::vector> srcFlatExprs; std::vector> destFlatExprs; @@ -448,7 +448,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, } // Add equality constraints for any operands that are defined by constant ops. - auto addEqForConstOperands = [&](ArrayRef operands) { + auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (isForInductionVar(operands[i])) continue; @@ -666,7 +666,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { map = loadOp.getAffineMap(); else if (auto storeOp = dyn_cast(opInst)) map = storeOp.getAffineMap(); - SmallVector operands(indices.begin(), indices.end()); + SmallVector operands(indices.begin(), indices.end()); fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); canonicalizeMapAndOperands(&map, &operands); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index ce96a19751f..78a869884ee 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -195,8 +195,8 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, // AffineValueMap. 
//===----------------------------------------------------------------------===// -AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, - ArrayRef results) +AffineValueMap::AffineValueMap(AffineMap map, ArrayRef operands, + ArrayRef results) : map(map), operands(operands.begin(), operands.end()), results(results.begin(), results.end()) {} @@ -210,8 +210,8 @@ AffineValueMap::AffineValueMap(AffineBound bound) : map(bound.getMap()), operands(bound.operand_begin(), bound.operand_end()) {} -void AffineValueMap::reset(AffineMap map, ArrayRef operands, - ArrayRef results) { +void AffineValueMap::reset(AffineMap map, ArrayRef operands, + ArrayRef results) { this->map.reset(map); this->operands.assign(operands.begin(), operands.end()); this->results.assign(results.begin(), results.end()); @@ -223,14 +223,14 @@ void AffineValueMap::difference(const AffineValueMap &a, // Fully compose A's map + operands. auto aMap = a.getAffineMap(); - SmallVector aOperands(a.getOperands().begin(), - a.getOperands().end()); + SmallVector aOperands(a.getOperands().begin(), + a.getOperands().end()); fullyComposeAffineMapAndOperands(&aMap, &aOperands); // Use the affine apply normalizer to get B's map into A's coordinate space. AffineApplyNormalizer normalizer(aMap, aOperands); - SmallVector bOperands(b.getOperands().begin(), - b.getOperands().end()); + SmallVector bOperands(b.getOperands().begin(), + b.getOperands().end()); auto bMap = b.getAffineMap(); normalizer.normalize(&bMap, &bOperands); @@ -254,7 +254,7 @@ void AffineValueMap::difference(const AffineValueMap &a, // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. -static bool findIndex(ValuePtr valueToMatch, ArrayRef valuesToSearch, +static bool findIndex(Value valueToMatch, ArrayRef valuesToSearch, unsigned indexStart, unsigned *indexOfMatch) { unsigned size = valuesToSearch.size(); for (unsigned i = indexStart; i < size; ++i) { @@ -272,7 +272,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { /// This method uses the invariant that operands are always positionally aligned /// with the AffineDimExpr in the underlying AffineMap. 
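An illustrative sketch (not part of this patch) of the Value-based isFunctionOf query defined just below; it mirrors how isAccessIndexInvariant in LoopAnalysis.cpp uses AffineValueMap elsewhere in this change. The helper name is hypothetical, and the AffineOps header path is assumed from this revision's layout.

  #include "mlir/Analysis/AffineStructures.h"
  #include "mlir/Dialect/AffineOps/AffineOps.h"

  static bool dependsOnInductionVar(AffineApplyOp applyOp, Value iv) {
    // Operands are positionally aligned with the map's dimension expressions.
    return AffineValueMap(applyOp).isFunctionOf(/*idx=*/0, iv);
  }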
-bool AffineValueMap::isFunctionOf(unsigned idx, ValuePtr value) const { +bool AffineValueMap::isFunctionOf(unsigned idx, Value value) const { unsigned index; if (!findIndex(value, operands, /*indexStart=*/0, &index)) { return false; @@ -283,12 +283,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, ValuePtr value) const { return expr.isFunctionOfDim(index); } -ValuePtr AffineValueMap::getOperand(unsigned i) const { - return static_cast(operands[i]); +Value AffineValueMap::getOperand(unsigned i) const { + return static_cast(operands[i]); } -ArrayRef AffineValueMap::getOperands() const { - return ArrayRef(operands); +ArrayRef AffineValueMap::getOperands() const { + return ArrayRef(operands); } AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } @@ -369,7 +369,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && "minimum 1 column"); numReservedCols = newNumReservedCols; @@ -392,7 +392,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities, void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, unsigned newNumLocals, - ArrayRef idArgs) { + ArrayRef idArgs) { reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, newNumSymbols, newNumLocals, idArgs); } @@ -419,17 +419,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) { addId(IdKind::Local, pos); } -void FlatAffineConstraints::addDimId(unsigned pos, ValuePtr id) { +void FlatAffineConstraints::addDimId(unsigned pos, Value id) { addId(IdKind::Dimension, pos, id); } -void FlatAffineConstraints::addSymbolId(unsigned pos, ValuePtr id) { +void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) { addId(IdKind::Symbol, pos, id); } /// Adds a dimensional identifier. The added column is initialized to /// zero. -void FlatAffineConstraints::addId(IdKind kind, unsigned pos, ValuePtr id) { +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) { if (kind == IdKind::Dimension) { assert(pos <= getNumDimIds()); } else if (kind == IdKind::Symbol) { @@ -518,7 +518,7 @@ bool FlatAffineConstraints::areIdsAlignedWithOther( /// Checks if the SSA values associated with `cst''s identifiers are unique. static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(const FlatAffineConstraints &cst) { - SmallPtrSet uniqueIds; + SmallPtrSet uniqueIds; for (auto id : cst.getIds()) { if (id.hasValue() && !uniqueIds.insert(id.getValue()).second) return false; @@ -562,11 +562,11 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, assert(std::all_of(A->getIds().begin() + offset, A->getIds().begin() + A->getNumDimAndSymbolIds(), - [](Optional id) { return id.hasValue(); })); + [](Optional id) { return id.hasValue(); })); assert(std::all_of(B->getIds().begin() + offset, B->getIds().begin() + B->getNumDimAndSymbolIds(), - [](Optional id) { return id.hasValue(); })); + [](Optional id) { return id.hasValue(); })); // Place local id's of A after local id's of B. 
for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) { @@ -577,7 +577,7 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, A->addLocalId(A->getNumLocalIds()); } - SmallVector aDimValues, aSymValues; + SmallVector aDimValues, aSymValues; A->getIdValues(offset, A->getNumDimIds(), &aDimValues); A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues); { @@ -776,7 +776,7 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) { } // Turn a dimension into a symbol. -static void turnDimIntoSymbol(FlatAffineConstraints *cst, ValueRef id) { +static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) { unsigned pos; if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { swapId(cst, pos, cst->getNumDimIds() - 1); @@ -785,7 +785,7 @@ static void turnDimIntoSymbol(FlatAffineConstraints *cst, ValueRef id) { } // Turn a symbol into a dimension. -static void turnSymbolIntoDim(FlatAffineConstraints *cst, ValueRef id) { +static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) { unsigned pos; if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && pos < cst->getNumDimAndSymbolIds()) { @@ -797,7 +797,7 @@ static void turnSymbolIntoDim(FlatAffineConstraints *cst, ValueRef id) { // Changes all symbol identifiers which are loop IVs to dim identifiers. void FlatAffineConstraints::convertLoopIVSymbolsToDims() { // Gather all symbols which are loop IVs. - SmallVector loopIVs; + SmallVector loopIVs; for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue())) loopIVs.push_back(ids[i].getValue()); @@ -808,7 +808,7 @@ void FlatAffineConstraints::convertLoopIVSymbolsToDims() { } } -void FlatAffineConstraints::addInductionVarOrTerminalSymbol(ValuePtr id) { +void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) { if (containsId(*id)) return; @@ -867,8 +867,8 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { addConstantLowerBound(pos, forOp.getConstantLowerBound()); } else { // Non-constant lower bound case. - SmallVector lbOperands(forOp.getLowerBoundOperands().begin(), - forOp.getLowerBoundOperands().end()); + SmallVector lbOperands(forOp.getLowerBoundOperands().begin(), + forOp.getLowerBoundOperands().end()); if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands, /*eq=*/false, /*lower=*/true))) return failure(); @@ -879,8 +879,8 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { return success(); } // Non-constant upper bound case. 
- SmallVector ubOperands(forOp.getUpperBoundOperands().begin(), - forOp.getUpperBoundOperands().end()); + SmallVector ubOperands(forOp.getUpperBoundOperands().begin(), + forOp.getUpperBoundOperands().end()); return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands, /*eq=*/false, /*lower=*/false); } @@ -1748,7 +1748,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ArrayRef boundOperands, + ArrayRef boundOperands, bool eq, bool lower) { assert(pos < getNumDimAndSymbolIds() && "invalid position"); // Equality follows the logic of lower bound except that we add an equality @@ -1760,7 +1760,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // Fully compose map and operands; canonicalize and simplify so that we // transitively get to terminal symbols or loop IVs. auto map = boundMap; - SmallVector operands(boundOperands.begin(), boundOperands.end()); + SmallVector operands(boundOperands.begin(), boundOperands.end()); fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); canonicalizeMapAndOperands(&map, &operands); @@ -1838,9 +1838,10 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, // Note that both lower/upper bounds use operands from 'operands'. // Returns failure for unimplemented cases such as semi-affine expressions or // expressions with mod/floordiv. -LogicalResult FlatAffineConstraints::addSliceBounds( - ArrayRef values, ArrayRef lbMaps, - ArrayRef ubMaps, ArrayRef operands) { +LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef values, + ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands) { assert(values.size() == lbMaps.size()); assert(lbMaps.size() == ubMaps.size()); @@ -1962,7 +1963,7 @@ void FlatAffineConstraints::addLocalFloorDiv(ArrayRef dividend, addInequality(bound); } -bool FlatAffineConstraints::findId(ValueRef id, unsigned *pos) const { +bool FlatAffineConstraints::findId(Value id, unsigned *pos) const { unsigned i = 0; for (const auto &mayBeId : ids) { if (mayBeId.hasValue() && mayBeId.getValue() == id) { @@ -1974,8 +1975,8 @@ bool FlatAffineConstraints::findId(ValueRef id, unsigned *pos) const { return false; } -bool FlatAffineConstraints::containsId(ValueRef id) const { - return llvm::any_of(ids, [&](const Optional &mayBeId) { +bool FlatAffineConstraints::containsId(Value id) const { + return llvm::any_of(ids, [&](const Optional &mayBeId) { return mayBeId.hasValue() && mayBeId.getValue() == id; }); } @@ -1999,7 +2000,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { /// Sets the specified identifier to a constant value; asserts if the id is not /// found. -void FlatAffineConstraints::setIdToConstant(ValueRef id, int64_t val) { +void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) { unsigned pos; if (!findId(id, &pos)) // This is a pre-condition for this method. 
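A small sketch (not part of this patch) of building a loop-nest iteration domain with the Value-based FlatAffineConstraints API; it essentially mirrors getIndexSet from AffineAnalysis.cpp above, with the identifier arguments now held in a SmallVector<Value, 4>. The helper name is hypothetical.

  #include "mlir/Analysis/AffineStructures.h"
  #include "mlir/Dialect/AffineOps/AffineOps.h"

  static LogicalResult buildIterationDomain(MutableArrayRef<AffineForOp> loops,
                                            FlatAffineConstraints &domain) {
    // One dimension per induction variable, then each loop's bound constraints.
    SmallVector<Value, 4> ivs;
    extractForInductionVars(loops, &ivs);
    domain.reset(loops.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
    for (AffineForOp forOp : loops)
      if (failed(domain.addAffineForOpDomain(forOp)))
        return failure();
    return success();
  }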
@@ -2564,7 +2565,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate( unsigned newNumDims = dimsSymbols.first; unsigned newNumSymbols = dimsSymbols.second; - SmallVector, 8> newIds; + SmallVector, 8> newIds; newIds.reserve(numIds - 1); newIds.append(ids.begin(), ids.begin() + pos); newIds.append(ids.begin() + pos + 1, ids.end()); @@ -2700,7 +2701,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { normalizeConstraintsByGCD(); } -void FlatAffineConstraints::projectOut(ValuePtr id) { +void FlatAffineConstraints::projectOut(Value id) { unsigned pos; bool ret = findId(*id, &pos); assert(ret); diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 65f6e83bcdf..c35421d55eb 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -179,7 +179,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getRootReference()); else - callee = callable.get()->getDefiningOp(); + callee = callable.get()->getDefiningOp(); // If the callee is non-null and is a valid callable object, try to get the // called region from it. diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index ea1501e8998..e4af4c0d69b 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -118,7 +118,7 @@ bool DominanceInfo::properlyDominates(Operation *a, Operation *b) { } /// Return true if value A properly dominates operation B. -bool DominanceInfo::properlyDominates(ValuePtr a, Operation *b) { +bool DominanceInfo::properlyDominates(Value a, Operation *b) { if (auto *aOp = a->getDefiningOp()) { // The values defined by an operation do *not* dominate any nested // operations. diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index 9b7b806c558..7ba31365f1a 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -31,13 +31,13 @@ struct BlockInfoBuilder { /// Fills the block builder with initial liveness information. BlockInfoBuilder(Block *block) : block(block) { // Mark all block arguments (phis) as defined. - for (BlockArgumentPtr argument : block->getArguments()) + for (BlockArgument argument : block->getArguments()) defValues.insert(argument); // Check all result values and whether their uses // are inside this block or not (see outValues). for (Operation &operation : *block) - for (ValuePtr result : operation.getResults()) { + for (Value result : operation.getResults()) { defValues.insert(result); // Check whether this value will be in the outValues @@ -54,7 +54,7 @@ struct BlockInfoBuilder { // Check all operations for used operands. for (Operation &operation : block->getOperations()) - for (ValuePtr operand : operation.getOperands()) { + for (Value operand : operation.getOperands()) { // If the operand is already defined in the scope of this // block, we can skip the value in the use set. if (!defValues.count(operand)) @@ -164,7 +164,7 @@ void Liveness::build(MutableArrayRef regions) { } /// Gets liveness info (if any) for the given value. -Liveness::OperationListT Liveness::resolveLiveness(ValuePtr value) const { +Liveness::OperationListT Liveness::resolveLiveness(Value value) const { OperationListT result; SmallPtrSet visited; SmallVector toProcess; @@ -229,7 +229,7 @@ const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const { /// Returns true if the given operation represent the last use of the /// given value. 
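An illustrative sketch (not part of this patch) of querying the liveness analysis with plain Value keys, including the isLastUse entry point defined just below. The walk and the remark are only for demonstration; Liveness and isLastUse come from this file.

  #include "mlir/Analysis/Liveness.h"

  static void reportLastUses(Operation *funcLikeOp) {
    Liveness liveness(funcLikeOp);
    funcLikeOp->walk([&](Operation *op) {
      for (Value result : op->getResults())
        for (Operation *user : result->getUsers())
          if (liveness.isLastUse(result, user))
            user->emitRemark("operand is the last use of its value");
    });
  }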
-bool Liveness::isLastUse(ValuePtr value, Operation *operation) const { +bool Liveness::isLastUse(Value value, Operation *operation) const { Block *block = operation->getBlock(); const LivenessBlockInfo *blockInfo = getLiveness(block); @@ -254,21 +254,21 @@ void Liveness::print(raw_ostream &os) const { // Builds unique block/value mappings for testing purposes. DenseMap blockIds; DenseMap operationIds; - DenseMap valueIds; + DenseMap valueIds; for (Region ®ion : operation->getRegions()) for (Block &block : region) { blockIds.insert({&block, blockIds.size()}); - for (BlockArgumentPtr argument : block.getArguments()) + for (BlockArgument argument : block.getArguments()) valueIds.insert({argument, valueIds.size()}); for (Operation &operation : block) { operationIds.insert({&operation, operationIds.size()}); - for (ValuePtr result : operation.getResults()) + for (Value result : operation.getResults()) valueIds.insert({result, valueIds.size()}); } } // Local printing helpers - auto printValueRef = [&](ValuePtr value) { + auto printValueRef = [&](Value value) { if (Operation *defOp = value->getDefiningOp()) os << "val_" << defOp->getName(); else { @@ -280,12 +280,12 @@ void Liveness::print(raw_ostream &os) const { }; auto printValueRefs = [&](const ValueSetT &values) { - std::vector orderedValues(values.begin(), values.end()); + std::vector orderedValues(values.begin(), values.end()); std::sort(orderedValues.begin(), orderedValues.end(), - [&](ValuePtr left, ValuePtr right) { + [&](Value left, Value right) { return valueIds[left] < valueIds[right]; }); - for (ValuePtr value : orderedValues) + for (Value value : orderedValues) printValueRef(value); }; @@ -306,7 +306,7 @@ void Liveness::print(raw_ostream &os) const { if (op.getNumResults() < 1) continue; os << "\n"; - for (ValuePtr result : op.getResults()) { + for (Value result : op.getResults()) { os << "// "; printValueRef(result); os << ":"; @@ -331,18 +331,18 @@ void Liveness::print(raw_ostream &os) const { //===----------------------------------------------------------------------===// /// Returns true if the given value is in the live-in set. -bool LivenessBlockInfo::isLiveIn(ValuePtr value) const { +bool LivenessBlockInfo::isLiveIn(Value value) const { return inValues.count(value); } /// Returns true if the given value is in the live-out set. -bool LivenessBlockInfo::isLiveOut(ValuePtr value) const { +bool LivenessBlockInfo::isLiveOut(Value value) const { return outValues.count(value); } /// Gets the start operation for the given value /// (must be referenced in this block). -Operation *LivenessBlockInfo::getStartOperation(ValuePtr value) const { +Operation *LivenessBlockInfo::getStartOperation(Value value) const { Operation *definingOp = value->getDefiningOp(); // The given value is either live-in or is defined // in the scope of this block. @@ -353,7 +353,7 @@ Operation *LivenessBlockInfo::getStartOperation(ValuePtr value) const { /// Gets the end operation for the given value using the start operation /// provided (must be referenced in this block). -Operation *LivenessBlockInfo::getEndOperation(ValuePtr value, +Operation *LivenessBlockInfo::getEndOperation(Value value, Operation *startOperation) const { // The given value is either dying in this block or live-out. 
if (isLiveOut(value)) diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 5499f887c1e..18c86dc63b4 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -34,7 +34,7 @@ using namespace mlir; // be more powerful (since both inequalities and equalities will be considered). void mlir::buildTripCountMapAndOperands( AffineForOp forOp, AffineMap *tripCountMap, - SmallVectorImpl *tripCountOperands) { + SmallVectorImpl *tripCountOperands) { int64_t loopSpan; int64_t step = forOp.getStep(); @@ -56,8 +56,8 @@ void mlir::buildTripCountMapAndOperands( *tripCountMap = AffineMap(); return; } - SmallVector lbOperands(forOp.getLowerBoundOperands()); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); // Difference of each upper bound expression from the single lower bound // expression (divided by the step) provides the expressions for the trip @@ -89,7 +89,7 @@ void mlir::buildTripCountMapAndOperands( // works with analysis structures (FlatAffineConstraints) and thus doesn't // update the IR. Optional mlir::getConstantTripCount(AffineForOp forOp) { - SmallVector operands; + SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -115,7 +115,7 @@ Optional mlir::getConstantTripCount(AffineForOp forOp) { /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { - SmallVector operands; + SmallVector operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -164,7 +164,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { /// /// Returns false in cases with more than one AffineApplyOp, this is /// conservative. -static bool isAccessIndexInvariant(ValuePtr iv, ValuePtr index) { +static bool isAccessIndexInvariant(Value iv, Value index) { assert(isForInductionVar(iv) && "iv must be a AffineForOp"); assert(index->getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; @@ -188,9 +188,8 @@ static bool isAccessIndexInvariant(ValuePtr iv, ValuePtr index) { return !(AffineValueMap(composeOp).isFunctionOf(0, iv)); } -DenseSet mlir::getInvariantAccesses(ValuePtr iv, - ArrayRef indices) { - DenseSet res; +DenseSet mlir::getInvariantAccesses(Value iv, ArrayRef indices) { + DenseSet res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { auto val = indices[idx]; if (isAccessIndexInvariant(iv, val)) { @@ -220,7 +219,7 @@ DenseSet mlir::getInvariantAccesses(ValuePtr iv, /// // TODO(ntv): check strides. template -static bool isContiguousAccess(ValuePtr iv, LoadOrStoreOp memoryOp, +static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp, int *memRefDim) { static_assert(std::is_same::value || std::is_same::value, @@ -241,11 +240,11 @@ static bool isContiguousAccess(ValuePtr iv, LoadOrStoreOp memoryOp, int uniqueVaryingIndexAlongIv = -1; auto accessMap = memoryOp.getAffineMap(); - SmallVector mapOperands(memoryOp.getMapOperands()); + SmallVector mapOperands(memoryOp.getMapOperands()); unsigned numDims = accessMap.getNumDims(); for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) { // Gather map operands used result expr 'i' in 'exprOperands'. 
- SmallVector exprOperands; + SmallVector exprOperands; auto resultExpr = accessMap.getResult(i); resultExpr.walk([&](AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) @@ -373,7 +372,7 @@ bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { // Validate the results of this operation if it were to be shifted. for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - ValuePtr result = op.getResult(i); + Value result = op.getResult(i); for (auto *user : result->getUsers()) { // If an ancestor operation doesn't lie in the block of forOp, // there is no shift to check. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e7d10e78cf..8ddf2e274eb 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -51,7 +51,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { // Adds operands (dst ivs and symbols) as symbols in 'cst'. unsigned numSymbols = lbOperands[0].size(); - SmallVector values(ivs); + SmallVector values(ivs); // Append 'ivs' then 'operands' to 'values'. values.append(lbOperands[0].begin(), lbOperands[0].end()); cst->reset(numDims, numSymbols, 0, values); @@ -176,7 +176,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, if (rank == 0) { SmallVector ivs; getLoopIVs(*op, &ivs); - SmallVector regionSymbols; + SmallVector regionSymbols; extractForInductionVars(ivs, ®ionSymbols); // A rank 0 memref has a 0-d region. cst.reset(rank, loopDepth, 0, regionSymbols); @@ -192,7 +192,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); unsigned numOperands = accessValueMap.getNumOperands(); // Merge operands with slice operands. - SmallVector operands; + SmallVector operands; operands.resize(numOperands); for (unsigned i = 0; i < numOperands; ++i) operands[i] = accessValueMap.getOperand(i); @@ -269,7 +269,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, getLoopIVs(*op, &enclosingIVs); assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); enclosingIVs.resize(loopDepth); - SmallVector ids; + SmallVector ids; cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids); for (auto id : ids) { AffineForOp iv; @@ -336,9 +336,9 @@ Optional MemRefRegion::getRegionSize() { // Indices to use for the DmaStart op. // Indices for the original memref being DMAed from/to. - SmallVector memIndices; + SmallVector memIndices; // Indices for the faster buffer being DMAed into/from. - SmallVector bufIndices; + SmallVector bufIndices; // Compute the extents of the buffer. Optional numElements = getConstantBoundingSizeAndShape(); @@ -471,7 +471,7 @@ static Operation *getInstAtPosition(ArrayRef positions, } // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. -LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, +LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, FlatAffineConstraints *cst) { for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) { auto value = cst->getIdValue(i); @@ -587,10 +587,10 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, // Pre-constraint id alignment: record loop IVs used in each constraint // system. 
- SmallPtrSet sliceUnionIVs; + SmallPtrSet sliceUnionIVs; for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k) sliceUnionIVs.insert(sliceUnionCst.getIdValue(k)); - SmallPtrSet tmpSliceIVs; + SmallPtrSet tmpSliceIVs; for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k) tmpSliceIVs.insert(tmpSliceCst.getIdValue(k)); @@ -650,7 +650,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, &sliceUnion->ubs); // Add slice bound operands of union. - SmallVector sliceBoundOperands; + SmallVector sliceBoundOperands; sliceUnionCst.getIdValues(numSliceLoopIVs, sliceUnionCst.getNumDimAndSymbolIds(), &sliceBoundOperands); @@ -716,7 +716,7 @@ void mlir::getComputationSliceState( &sliceState->lbs, &sliceState->ubs); // Set up bound operands for the slice's lower and upper bounds. - SmallVector sliceBoundOperands; + SmallVector sliceBoundOperands; unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds(); for (unsigned i = 0; i < numDimsAndSymbols; ++i) { if (i < offset || i >= offset + numSliceLoopIVs) { @@ -734,7 +734,7 @@ void mlir::getComputationSliceState( isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin() : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); - llvm::SmallDenseSet sequentialLoops; + llvm::SmallDenseSet sequentialLoops; if (isa(depSourceOp) && isa(depSinkOp)) { // For read-read access pairs, clear any slice bounds on sequential loops. // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. @@ -749,7 +749,7 @@ void mlir::getComputationSliceState( return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i]; }; for (unsigned i = 0; i < numSliceLoopIVs; ++i) { - ValuePtr iv = getSliceLoop(i).getInductionVar(); + Value iv = getSliceLoop(i).getInductionVar(); if (sequentialLoops.count(iv) == 0 && getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr) continue; @@ -910,7 +910,7 @@ static Optional getMemoryFootprintBytes(Block &block, Block::iterator start, Block::iterator end, int memorySpace) { - SmallDenseMap, 4> regions; + SmallDenseMap, 4> regions; // Walk this 'affine.for' operation to gather all memory regions. auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult { @@ -960,8 +960,8 @@ Optional mlir::getMemoryFootprintBytes(AffineForOp forOp, /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. -void mlir::getSequentialLoops( - AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops) { +void mlir::getSequentialLoops(AffineForOp forOp, + llvm::SmallDenseSet *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { if (auto innerFor = dyn_cast(op)) if (!isLoopParallel(innerFor)) diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index cd77eff9e40..1c7dbed5fac 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -100,7 +100,7 @@ Optional> mlir::shapeRatio(VectorType superVectorType, /// Examples can be found in the documentation of `makePermutationMap`, in the /// header file. 
static AffineMap makePermutationMap( - ArrayRef indices, + ArrayRef indices, const DenseMap &enclosingLoopToVectorDim) { if (enclosingLoopToVectorDim.empty()) return AffineMap(); @@ -158,7 +158,7 @@ static SetVector getEnclosingforOps(Operation *op) { } AffineMap mlir::makePermutationMap( - Operation *op, ArrayRef indices, + Operation *op, ArrayRef indices, const DenseMap &loopToVectorDim) { DenseMap enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingforOps(op); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index ce1e5c4a2af..e9a9ca82f51 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -33,16 +33,16 @@ namespace { // that correspond to it. Visitation functions return an Value of the // expression subtree they visited or `nullptr` on error. class AffineApplyExpander - : public AffineExprVisitor { + : public AffineExprVisitor { public: // This internal class expects arguments to be non-null, checks must be // performed at the call site. - AffineApplyExpander(OpBuilder &builder, ArrayRef dimValues, - ArrayRef symbolValues, Location loc) + AffineApplyExpander(OpBuilder &builder, ArrayRef dimValues, + ArrayRef symbolValues, Location loc) : builder(builder), dimValues(dimValues), symbolValues(symbolValues), loc(loc) {} - template ValuePtr buildBinaryExpr(AffineBinaryOpExpr expr) { + template Value buildBinaryExpr(AffineBinaryOpExpr expr) { auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) @@ -51,11 +51,11 @@ public: return op.getResult(); } - ValuePtr visitAddExpr(AffineBinaryOpExpr expr) { + Value visitAddExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } - ValuePtr visitMulExpr(AffineBinaryOpExpr expr) { + Value visitMulExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } @@ -68,7 +68,7 @@ public: // let remainder = srem a, b; // negative = a < 0 in // select negative, remainder + b, remainder. - ValuePtr visitModExpr(AffineBinaryOpExpr expr) { + Value visitModExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( @@ -85,13 +85,13 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - ValuePtr remainder = builder.create(loc, lhs, rhs); - ValuePtr zeroCst = builder.create(loc, 0); - ValuePtr isRemainderNegative = + Value remainder = builder.create(loc, lhs, rhs); + Value zeroCst = builder.create(loc, 0); + Value isRemainderNegative = builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); - ValuePtr correctedRemainder = builder.create(loc, remainder, rhs); - ValuePtr result = builder.create(loc, isRemainderNegative, - correctedRemainder, remainder); + Value correctedRemainder = builder.create(loc, remainder, rhs); + Value result = builder.create(loc, isRemainderNegative, + correctedRemainder, remainder); return result; } @@ -105,7 +105,7 @@ public: // let absolute = negative ? -a - 1 : a in // let quotient = absolute / b in // negative ? 
-quotient - 1 : quotient - ValuePtr visitFloorDivExpr(AffineBinaryOpExpr expr) { + Value visitFloorDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( @@ -122,16 +122,16 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - ValuePtr zeroCst = builder.create(loc, 0); - ValuePtr noneCst = builder.create(loc, -1); - ValuePtr negative = + Value zeroCst = builder.create(loc, 0); + Value noneCst = builder.create(loc, -1); + Value negative = builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); - ValuePtr negatedDecremented = builder.create(loc, noneCst, lhs); - ValuePtr dividend = + Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value dividend = builder.create(loc, negative, negatedDecremented, lhs); - ValuePtr quotient = builder.create(loc, dividend, rhs); - ValuePtr correctedQuotient = builder.create(loc, noneCst, quotient); - ValuePtr result = + Value quotient = builder.create(loc, dividend, rhs); + Value correctedQuotient = builder.create(loc, noneCst, quotient); + Value result = builder.create(loc, negative, correctedQuotient, quotient); return result; } @@ -146,7 +146,7 @@ public: // let absolute = negative ? -a : a - 1 in // let quotient = absolute / b in // negative ? -quotient : quotient + 1 - ValuePtr visitCeilDivExpr(AffineBinaryOpExpr expr) { + Value visitCeilDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError(loc) << "semi-affine expressions (division by non-const) are " @@ -161,24 +161,23 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - ValuePtr zeroCst = builder.create(loc, 0); - ValuePtr oneCst = builder.create(loc, 1); - ValuePtr nonPositive = + Value zeroCst = builder.create(loc, 0); + Value oneCst = builder.create(loc, 1); + Value nonPositive = builder.create(loc, CmpIPredicate::sle, lhs, zeroCst); - ValuePtr negated = builder.create(loc, zeroCst, lhs); - ValuePtr decremented = builder.create(loc, lhs, oneCst); - ValuePtr dividend = + Value negated = builder.create(loc, zeroCst, lhs); + Value decremented = builder.create(loc, lhs, oneCst); + Value dividend = builder.create(loc, nonPositive, negated, decremented); - ValuePtr quotient = builder.create(loc, dividend, rhs); - ValuePtr negatedQuotient = builder.create(loc, zeroCst, quotient); - ValuePtr incrementedQuotient = - builder.create(loc, quotient, oneCst); - ValuePtr result = builder.create( - loc, nonPositive, negatedQuotient, incrementedQuotient); + Value quotient = builder.create(loc, dividend, rhs); + Value negatedQuotient = builder.create(loc, zeroCst, quotient); + Value incrementedQuotient = builder.create(loc, quotient, oneCst); + Value result = builder.create(loc, nonPositive, negatedQuotient, + incrementedQuotient); return result; } - ValuePtr visitConstantExpr(AffineConstantExpr expr) { + Value visitConstantExpr(AffineConstantExpr expr) { auto valueAttr = builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); auto op = @@ -186,13 +185,13 @@ public: return op.getResult(); } - ValuePtr visitDimExpr(AffineDimExpr expr) { + Value visitDimExpr(AffineDimExpr expr) { assert(expr.getPosition() < dimValues.size() && "affine dim position out of range"); return dimValues[expr.getPosition()]; } - ValuePtr visitSymbolExpr(AffineSymbolExpr expr) { + Value visitSymbolExpr(AffineSymbolExpr expr) { assert(expr.getPosition() < symbolValues.size() && "symbol dim position out of 
range"); return symbolValues[expr.getPosition()]; @@ -200,8 +199,8 @@ public: private: OpBuilder &builder; - ArrayRef dimValues; - ArrayRef symbolValues; + ArrayRef dimValues; + ArrayRef symbolValues; Location loc; }; @@ -209,18 +208,17 @@ private: // Create a sequence of operations that implement the `expr` applied to the // given dimension and symbol values. -mlir::ValuePtr mlir::expandAffineExpr(OpBuilder &builder, Location loc, - AffineExpr expr, - ArrayRef dimValues, - ArrayRef symbolValues) { +mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, + AffineExpr expr, ArrayRef dimValues, + ArrayRef symbolValues) { return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); } // Create a sequence of operations that implement the `affineMap` applied to // the given `operands` (as it it were an AffineApplyOp). -Optional> static expandAffineMap( +Optional> static expandAffineMap( OpBuilder &builder, Location loc, AffineMap affineMap, - ArrayRef operands) { + ArrayRef operands) { auto numDims = affineMap.getNumDims(); auto expanded = functional::map( [numDims, &builder, loc, operands](AffineExpr expr) { @@ -229,7 +227,7 @@ Optional> static expandAffineMap( operands.drop_front(numDims)); }, affineMap.getResults()); - if (llvm::all_of(expanded, [](ValuePtr v) { return v; })) + if (llvm::all_of(expanded, [](Value v) { return v; })) return expanded; return None; } @@ -245,13 +243,13 @@ Optional> static expandAffineMap( // Multiple values are scanned in a linear sequence. This creates a data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes. -static ValuePtr buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, - ArrayRef values, - OpBuilder &builder) { +static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, + ArrayRef values, + OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); - ValuePtr value = *valueIt++; + Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); @@ -263,8 +261,8 @@ static ValuePtr buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // Emit instructions that correspond to the affine map in the lower bound // applied to the respective operands, and compute the maximum value across // the results. -ValuePtr mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { - SmallVector boundOperands(op.getLowerBoundOperands()); +Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { + SmallVector boundOperands(op.getLowerBoundOperands()); auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(), boundOperands); if (!lbValues) @@ -276,8 +274,8 @@ ValuePtr mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { // Emit instructions that correspond to the affine map in the upper bound // applied to the respective operands, and compute the minimum value across // the results. 
-ValuePtr mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { - SmallVector boundOperands(op.getUpperBoundOperands()); +Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { + SmallVector boundOperands(op.getUpperBoundOperands()); auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(), boundOperands); if (!ubValues) @@ -306,9 +304,9 @@ public: PatternMatchResult matchAndRewrite(AffineForOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - ValuePtr lowerBound = lowerAffineLowerBound(op, rewriter); - ValuePtr upperBound = lowerAffineUpperBound(op, rewriter); - ValuePtr step = rewriter.create(loc, op.getStep()); + Value lowerBound = lowerAffineLowerBound(op, rewriter); + Value upperBound = lowerAffineUpperBound(op, rewriter); + Value step = rewriter.create(loc, op.getStep()); auto f = rewriter.create(loc, lowerBound, upperBound, step); f.region().getBlocks().clear(); rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); @@ -327,25 +325,25 @@ public: // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); - ValuePtr zeroConstant = rewriter.create(loc, 0); - SmallVector operands(op.getOperands()); + Value zeroConstant = rewriter.create(loc, 0); + SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Calculate cond as a conjunction without short-circuiting. - ValuePtr cond = nullptr; + Value cond = nullptr; for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) { AffineExpr constraintExpr = integerSet.getConstraint(i); bool isEquality = integerSet.isEq(i); // Build and apply an affine expression auto numDims = integerSet.getNumDims(); - ValuePtr affResult = expandAffineExpr(rewriter, loc, constraintExpr, - operandsRef.take_front(numDims), - operandsRef.drop_front(numDims)); + Value affResult = expandAffineExpr(rewriter, loc, constraintExpr, + operandsRef.take_front(numDims), + operandsRef.drop_front(numDims)); if (!affResult) return matchFailure(); auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; - ValuePtr cmpVal = + Value cmpVal = rewriter.create(loc, pred, affResult, zeroConstant); cond = cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; @@ -396,7 +394,7 @@ public: PatternMatchResult matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. - SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) @@ -418,7 +416,7 @@ public: PatternMatchResult matchAndRewrite(AffinePrefetchOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affinePrefetchOp'. - SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) @@ -442,7 +440,7 @@ public: PatternMatchResult matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. 
- SmallVector indices(op.getMapOperands()); + SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) @@ -464,7 +462,7 @@ public: PatternMatchResult matchAndRewrite(AffineDmaStartOp op, PatternRewriter &rewriter) const override { - SmallVector operands(op.getOperands()); + SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Expand affine map for DMA source memref. @@ -505,7 +503,7 @@ public: PatternMatchResult matchAndRewrite(AffineDmaWaitOp op, PatternRewriter &rewriter) const override { // Expand affine map for DMA tag memref. - SmallVector indices(op.getTagIndices()); + SmallVector indices(op.getTagIndices()); auto maybeExpandedTagMap = expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); if (!maybeExpandedTagMap) diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index 2ca9717ad86..63bc15173be 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -48,11 +48,11 @@ public: // Convert the kernel arguments to an LLVM type, preserve the rest. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto dialect = lowering.getDialect(); - ValuePtr newOp; + Value newOp; switch (dimensionToIndex(cast(op))) { case X: newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 97881d359f6..b75c1bf2d7b 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -35,7 +35,7 @@ public: f32Func(f32Func), f64Func(f64Func) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; using LLVM::LLVMType; @@ -60,10 +60,10 @@ public: private: LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, - ArrayRef operands) const { + ArrayRef operands) const { using LLVM::LLVMType; SmallVector operandTypes; - for (ValuePtr operand : operands) { + for (Value operand : operands) { operandTypes.push_back(operand->getType().cast()); } return LLVMType::getFunctionTy(resultType, operandTypes, diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index 3383cf13d36..19dabcdafee 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -105,7 +105,7 @@ private: } // Allocate a void pointer on the stack. 
- ValuePtr allocatePointer(OpBuilder &builder, Location loc) { + Value allocatePointer(OpBuilder &builder, Location loc) { auto one = builder.create(loc, getInt32Type(), builder.getI32IntegerAttr(1)); return builder.create(loc, getPointerPointerType(), one, @@ -113,9 +113,9 @@ private: } void declareCudaFunctions(Location loc); - ValuePtr setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - ValuePtr generateKernelNameConstant(StringRef name, Location loc, - OpBuilder &builder); + Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); + Value generateKernelNameConstant(StringRef name, Location loc, + OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); public: @@ -239,9 +239,8 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { // for (i : [0, NumKernelOperands)) // %array[i] = cast(KernelOperand[i]) // return %array -ValuePtr -GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, - OpBuilder &builder) { +Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, + OpBuilder &builder) { auto numKernelOperands = launchOp.getNumKernelOperands(); Location loc = launchOp.getLoc(); auto one = builder.create(loc, getInt32Type(), @@ -255,7 +254,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, for (unsigned idx = 0; idx < numKernelOperands; ++idx) { auto operand = launchOp.getKernelOperand(idx); auto llvmType = operand->getType().cast(); - ValuePtr memLocation = builder.create( + Value memLocation = builder.create( loc, llvmType.getPointerTo(), one, /*alignment=*/1); builder.create(loc, operand, memLocation); auto casted = @@ -271,12 +270,12 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, getModule().lookupSymbol(kMcuMemHostRegister); auto nullPtr = builder.create(loc, llvmType.getPointerTo()); auto gep = builder.create(loc, llvmType.getPointerTo(), - ArrayRef{nullPtr, one}); + ArrayRef{nullPtr, one}); auto size = builder.create(loc, getInt64Type(), gep); builder.create(loc, ArrayRef{}, builder.getSymbolRefAttr(registerFunc), - ArrayRef{casted, size}); - ValuePtr memLocation = builder.create( + ArrayRef{casted, size}); + Value memLocation = builder.create( loc, getPointerPointerType(), one, /*alignment=*/1); builder.create(loc, casted, memLocation); casted = @@ -286,7 +285,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, auto index = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(idx)); auto gep = builder.create(loc, getPointerPointerType(), array, - ArrayRef{index}); + ArrayRef{index}); builder.create(loc, casted, gep); } return array; @@ -302,7 +301,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %1 = llvm.constant (0 : index) // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> // } -ValuePtr GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( +Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( StringRef name, Location loc, OpBuilder &builder) { // Make sure the trailing zero is included in the constant. 
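setupParamsArray above stores each kernel operand into its own stack slot, casts the slot address to an opaque pointer, and writes it into an array of pointers; that array is what the launch call later receives as its kernel-params argument. A host-side sketch of the same packing, assuming ordinary C++ objects instead of LLVM IR allocas (names are illustrative):

  #include <array>

  // Illustrative model of the params array built by setupParamsArray: one
  // void* per kernel operand, each pointing at storage that holds the value.
  template <typename... Args>
  std::array<void *, sizeof...(Args)> packKernelParams(Args &... args) {
    // Matches the "store the operand, then take the slot's address" pattern.
    return {static_cast<void *>(&args)...};
  }

  // Usage sketch:
  //   float alpha = 2.0f; int n = 128;
  //   auto params = packKernelParams(alpha, n); // params.data() is the array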
std::vector kernelName(name.begin(), name.end()); @@ -358,7 +357,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( assert(kernelModule.getName() && "expected a named module"); SmallString<128> nameBuffer(*kernelModule.getName()); nameBuffer.append(kCubinStorageSuffix); - ValuePtr data = LLVM::createGlobalString( + Value data = LLVM::createGlobalString( loc, builder, nameBuffer.str(), cubinAttr.getValue(), LLVM::Linkage::Internal, getLLVMDialect()); @@ -369,7 +368,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( getModule().lookupSymbol(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleLoad), - ArrayRef{cuModule, data}); + ArrayRef{cuModule, data}); // Get the function from the module. The name corresponds to the name of // the kernel function. auto cuOwningModuleRef = @@ -381,13 +380,13 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create( loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleGetFunction), - ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); + ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. auto cuGetStreamHelper = getModule().lookupSymbol(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, - builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef{}); + builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef{}); // Invoke the function with required arguments. auto cuLaunchKernel = getModule().lookupSymbol(cuLaunchKernelName); @@ -399,19 +398,19 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create( loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuLaunchKernel), - ArrayRef{cuFunctionRef, launchOp.getOperand(0), - launchOp.getOperand(1), launchOp.getOperand(2), - launchOp.getOperand(3), launchOp.getOperand(4), - launchOp.getOperand(5), zero, /* sharedMemBytes */ - cuStream.getResult(0), /* stream */ - paramsArray, /* kernel params */ - nullpointer /* extra */}); + ArrayRef{cuFunctionRef, launchOp.getOperand(0), + launchOp.getOperand(1), launchOp.getOperand(2), + launchOp.getOperand(3), launchOp.getOperand(4), + launchOp.getOperand(5), zero, /* sharedMemBytes */ + cuStream.getResult(0), /* stream */ + paramsArray, /* kernel params */ + nullpointer /* extra */}); // Sync on the stream to make it synchronous. auto cuStreamSync = getModule().lookupSymbol(cuStreamSynchronizeName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuStreamSync), - ArrayRef(cuStream.getResult(0))); + ArrayRef(cuStream.getResult(0))); launchOp.erase(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index e15ad823a2b..08c18c1ec83 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -51,8 +51,8 @@ public: /// Converts all_reduce op to LLVM/NVVM ops. 
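translateGpuLaunchCalls above emits, in IR, the usual CUDA driver sequence: load the embedded cubin into a module, look up the kernel function by name, fetch a stream, launch with the params array, and synchronize. The pass actually calls small runtime wrapper functions rather than the driver API directly, but the runtime behaviour it encodes corresponds roughly to the following sketch (error handling omitted; the helper and its arguments are illustrative, not part of the pass):

  #include <cuda.h>

  // Rough host-side equivalent of the generated call sequence.
  void launchEmbeddedKernel(const void *cubin, const char *kernelName,
                            CUstream stream, unsigned gridX, unsigned gridY,
                            unsigned gridZ, unsigned blockX, unsigned blockY,
                            unsigned blockZ, void **params) {
    CUmodule module;
    CUfunction function;
    cuModuleLoadData(&module, cubin);                   // cuModuleLoad in the pass
    cuModuleGetFunction(&function, module, kernelName);
    cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ,
                   /*sharedMemBytes=*/0, stream, params, /*extra=*/nullptr);
    cuStreamSynchronize(stream);                        // make the launch synchronous
  }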
struct GPUAllReduceOpLowering : public LLVMOpLowering { - using AccumulatorFactory = std::function; + using AccumulatorFactory = + std::function; explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) : LLVMOpLowering(gpu::AllReduceOp::getOperationName(), @@ -60,10 +60,10 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering { int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - ValuePtr operand = operands.front(); + Value operand = operands.front(); // TODO(csigg): Generalize to other types of accumulation. assert(op->getOperand(0)->getType().isIntOrFloat()); @@ -72,7 +72,7 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering { AccumulatorFactory factory = getFactory(cast(op), operand); assert(factory && "failed to create accumulator factory"); - ValuePtr result = createBlockReduce(loc, operand, factory, rewriter); + Value result = createBlockReduce(loc, operand, factory, rewriter); rewriter.replaceOp(op, {result}); return matchSuccess(); @@ -82,7 +82,7 @@ private: /// Returns an accumulator factory using either the op attribute or the body /// region. AccumulatorFactory getFactory(gpu::AllReduceOp allReduce, - ValuePtr operand) const { + Value operand) const { if (!allReduce.body().empty()) { return getFactory(allReduce.body()); } @@ -97,7 +97,7 @@ private: /// block is expected to have 2 arguments. The gpu.yield return the /// accumulated value of the same type. AccumulatorFactory getFactory(Region &body) const { - return AccumulatorFactory([&](Location loc, ValuePtr lhs, ValuePtr rhs, + return AccumulatorFactory([&](Location loc, Value lhs, Value rhs, ConversionPatternRewriter &rewriter) { Block *block = rewriter.getInsertionBlock(); Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint()); @@ -111,7 +111,7 @@ private: // Add branch before inserted body, into body. block = block->getNextNode(); - rewriter.create(loc, ArrayRef{}, + rewriter.create(loc, ArrayRef{}, llvm::makeArrayRef(block), ValueRange()); // Replace all gpu.yield ops with branch out of body. @@ -121,7 +121,7 @@ private: continue; rewriter.setInsertionPointToEnd(block); rewriter.replaceOpWithNewOp( - terminator, ArrayRef{}, llvm::makeArrayRef(split), + terminator, ArrayRef{}, llvm::makeArrayRef(split), ValueRange(terminator->getOperand(0))); } @@ -152,7 +152,7 @@ private: /// Returns an accumulator factory that creates an op of type T. template AccumulatorFactory getFactory() const { - return [](Location loc, ValuePtr lhs, ValuePtr rhs, + return [](Location loc, Value lhs, Value rhs, ConversionPatternRewriter &rewriter) { return rewriter.create(loc, lhs->getType(), lhs, rhs); }; @@ -194,60 +194,60 @@ private: /// %result = llvm.load %result_ptr /// return %result /// - ValuePtr createBlockReduce(Location loc, ValuePtr operand, - AccumulatorFactory &accumFactory, - ConversionPatternRewriter &rewriter) const { + Value createBlockReduce(Location loc, Value operand, + AccumulatorFactory &accumFactory, + ConversionPatternRewriter &rewriter) const { auto type = operand->getType().cast(); // Create shared memory array to store the warp reduction. 
auto module = operand->getDefiningOp()->getParentOfType(); assert(module && "op must belong to a module"); - ValuePtr sharedMemPtr = + Value sharedMemPtr = createSharedMemoryArray(loc, module, type, kWarpSize, rewriter); - ValuePtr zero = rewriter.create( + Value zero = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(0u)); - ValuePtr laneId = rewriter.create(loc, int32Type); - ValuePtr isFirstLane = rewriter.create( + Value laneId = rewriter.create(loc, int32Type); + Value isFirstLane = rewriter.create( loc, LLVM::ICmpPredicate::eq, laneId, zero); - ValuePtr threadIdx = getLinearThreadIndex(loc, rewriter); - ValuePtr blockSize = getBlockSize(loc, rewriter); - ValuePtr activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter); + Value threadIdx = getLinearThreadIndex(loc, rewriter); + Value blockSize = getBlockSize(loc, rewriter); + Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter); // Reduce elements within each warp to produce the intermediate results. - ValuePtr warpReduce = createWarpReduce(loc, activeWidth, laneId, operand, - accumFactory, rewriter); + Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand, + accumFactory, rewriter); // Write the intermediate results to shared memory, using the first lane of // each warp. createPredicatedBlock(loc, rewriter, isFirstLane, [&] { - ValuePtr warpId = getDivideByWarpSize(threadIdx, rewriter); - ValuePtr storeDst = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, warpId})); + Value warpId = getDivideByWarpSize(threadIdx, rewriter); + Value storeDst = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, warpId})); rewriter.create(loc, warpReduce, storeDst); }); rewriter.create(loc); - ValuePtr numWarps = getNumWarps(loc, blockSize, rewriter); - ValuePtr isValidWarp = rewriter.create( + Value numWarps = getNumWarps(loc, blockSize, rewriter); + Value isValidWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps); - ValuePtr resultPtr = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, zero})); + Value resultPtr = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, zero})); // Use the first numWarps threads to reduce the intermediate results from // shared memory. The final result is written to shared memory again. createPredicatedBlock(loc, rewriter, isValidWarp, [&] { - ValuePtr loadSrc = rewriter.create( - loc, type, sharedMemPtr, ArrayRef({zero, threadIdx})); - ValuePtr value = rewriter.create(loc, type, loadSrc); - ValuePtr result = createWarpReduce(loc, numWarps, laneId, value, - accumFactory, rewriter); + Value loadSrc = rewriter.create( + loc, type, sharedMemPtr, ArrayRef({zero, threadIdx})); + Value value = rewriter.create(loc, type, loadSrc); + Value result = createWarpReduce(loc, numWarps, laneId, value, + accumFactory, rewriter); rewriter.create(loc, result, resultPtr); }); rewriter.create(loc); // Load and return result from shared memory. 
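createBlockReduce above works in two rounds: every warp reduces its own values and its first lane writes the partial result into a shared-memory array, then the first numWarps threads reduce those partials into the final value. A CPU-side model of the same dataflow, with `reduce` standing in for the accumulator factory (illustrative only; assumes at least one value):

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  constexpr std::size_t kWarpSize = 32;

  // `values` holds one element per thread of the block; the vector of partial
  // results plays the role of the shared-memory array in the lowering.
  int blockReduce(const std::vector<int> &values, int (*reduce)(int, int)) {
    std::vector<int> warpPartials;
    for (std::size_t base = 0; base < values.size(); base += kWarpSize) {
      std::size_t end = std::min(base + kWarpSize, values.size());
      int acc = values[base];
      for (std::size_t i = base + 1; i < end; ++i)
        acc = reduce(acc, values[i]);   // round 1: per-warp reduction
      warpPartials.push_back(acc);      // "first lane" stores its partial
    }
    int result = warpPartials.front();  // round 2: reduce the partials
    for (std::size_t i = 1; i < warpPartials.size(); ++i)
      result = reduce(result, warpPartials[i]);
    return result;
  }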
- ValuePtr result = rewriter.create(loc, type, resultPtr); + Value result = rewriter.create(loc, type, resultPtr); return result; } @@ -265,7 +265,7 @@ private: /// template void createIf(Location loc, ConversionPatternRewriter &rewriter, - ValuePtr condition, ThenOpsFactory &&thenOpsFactory, + Value condition, ThenOpsFactory &&thenOpsFactory, ElseOpsFactory &&elseOpsFactory) const { Block *currentBlock = rewriter.getInsertionBlock(); auto currentPoint = rewriter.getInsertionPoint(); @@ -279,7 +279,7 @@ private: ArrayRef{thenBlock, elseBlock}); auto addBranch = [&](ValueRange operands) { - rewriter.create(loc, ArrayRef{}, + rewriter.create(loc, ArrayRef{}, llvm::makeArrayRef(continueBlock), llvm::makeArrayRef(operands)); }; @@ -301,25 +301,25 @@ private: /// Shortcut for createIf with empty else block and no block operands. template void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter, - ValuePtr condition, + Value condition, Factory &&predicatedOpsFactory) const { createIf( loc, rewriter, condition, [&] { predicatedOpsFactory(); - return ArrayRef(); + return ArrayRef(); }, - [&] { return ArrayRef(); }); + [&] { return ArrayRef(); }); } /// Creates a reduction across the first activeWidth lanes of a warp. /// The first lane returns the result, all others return values are undefined. - ValuePtr createWarpReduce(Location loc, ValuePtr activeWidth, ValuePtr laneId, - ValuePtr operand, AccumulatorFactory accumFactory, - ConversionPatternRewriter &rewriter) const { - ValuePtr warpSize = rewriter.create( + Value createWarpReduce(Location loc, Value activeWidth, Value laneId, + Value operand, AccumulatorFactory accumFactory, + ConversionPatternRewriter &rewriter) const { + Value warpSize = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); - ValuePtr isPartialWarp = rewriter.create( + Value isPartialWarp = rewriter.create( loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize); auto type = operand->getType().cast(); @@ -327,16 +327,16 @@ private: loc, rewriter, isPartialWarp, // Generate reduction over a (potentially) partial warp. [&] { - ValuePtr value = operand; - ValuePtr one = rewriter.create( + Value value = operand; + Value one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); // Bit mask of active lanes: `(1 << activeWidth) - 1`. - ValuePtr activeMask = rewriter.create( + Value activeMask = rewriter.create( loc, int32Type, rewriter.create(loc, int32Type, one, activeWidth), one); // Clamp lane: `activeWidth - 1` - ValuePtr maskAndClamp = + Value maskAndClamp = rewriter.create(loc, int32Type, activeWidth, one); auto dialect = lowering.getDialect(); auto predTy = LLVM::LLVMType::getInt1Ty(dialect); @@ -347,53 +347,53 @@ private: // lane is within the active range. All lanes contain the final // result, but only the first lane's result is used. for (int i = 1; i < kWarpSize; i <<= 1) { - ValuePtr offset = rewriter.create( + Value offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - ValuePtr shfl = rewriter.create( + Value shfl = rewriter.create( loc, shflTy, activeMask, value, offset, maskAndClamp, returnValueAndIsValidAttr); - ValuePtr isActiveSrcLane = rewriter.create( + Value isActiveSrcLane = rewriter.create( loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); // Skip the accumulation if the shuffle op read from a lane outside // of the active range. 
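In the partial-warp branch above, the shuffle parameters are derived from the number of active lanes: the participation mask is (1 << activeWidth) - 1 and the clamp on the source lane is activeWidth - 1; the reduction then walks power-of-two offsets 1, 2, 4, 8, 16 and only accumulates when the shuffled-from lane is still inside the active range. The mask arithmetic in isolation (this branch is only taken when fewer than 32 lanes are active; names are illustrative):

  #include <cassert>
  #include <cstdint>

  struct ShuffleParams {
    uint32_t activeMask;   // bit mask of the lanes that participate
    uint32_t maskAndClamp; // largest lane index a shuffle may read from
  };

  // activeWidth is the number of live lanes; the partial-warp path only runs
  // for 1 <= activeWidth < 32.
  ShuffleParams partialWarpShuffleParams(uint32_t activeWidth) {
    assert(activeWidth >= 1 && activeWidth < 32);
    return {(1u << activeWidth) - 1u, activeWidth - 1u};
  }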
createIf( loc, rewriter, isActiveSrcLane, [&] { - ValuePtr shflValue = rewriter.create( + Value shflValue = rewriter.create( loc, type, shfl, rewriter.getIndexArrayAttr(0)); - return SmallVector{ + return SmallVector{ accumFactory(loc, value, shflValue, rewriter)}; }, [&] { return llvm::makeArrayRef(value); }); value = rewriter.getInsertionBlock()->getArgument(0); } - return SmallVector{value}; + return SmallVector{value}; }, // Generate a reduction over the entire warp. This is a specialization // of the above reduction with unconditional accumulation. [&] { - ValuePtr value = operand; - ValuePtr activeMask = rewriter.create( + Value value = operand; + Value activeMask = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(~0u)); - ValuePtr maskAndClamp = rewriter.create( + Value maskAndClamp = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); for (int i = 1; i < kWarpSize; i <<= 1) { - ValuePtr offset = rewriter.create( + Value offset = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(i)); - ValuePtr shflValue = rewriter.create( + Value shflValue = rewriter.create( loc, type, activeMask, value, offset, maskAndClamp, /*return_value_and_is_valid=*/UnitAttr()); value = accumFactory(loc, value, shflValue, rewriter); } - return SmallVector{value}; + return SmallVector{value}; }); return rewriter.getInsertionBlock()->getArgument(0); } /// Creates a global array stored in shared memory. - ValuePtr createSharedMemoryArray(Location loc, ModuleOp module, - LLVM::LLVMType elementType, int numElements, - ConversionPatternRewriter &rewriter) const { + Value createSharedMemoryArray(Location loc, ModuleOp module, + LLVM::LLVMType elementType, int numElements, + ConversionPatternRewriter &rewriter) const { OpBuilder builder(module.getBodyRegion()); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); @@ -407,32 +407,31 @@ private: } /// Returns the index of the thread within the block. - ValuePtr getLinearThreadIndex(Location loc, - ConversionPatternRewriter &rewriter) const { - ValuePtr dimX = rewriter.create(loc, int32Type); - ValuePtr dimY = rewriter.create(loc, int32Type); - ValuePtr idX = rewriter.create(loc, int32Type); - ValuePtr idY = rewriter.create(loc, int32Type); - ValuePtr idZ = rewriter.create(loc, int32Type); - ValuePtr tmp1 = rewriter.create(loc, int32Type, idZ, dimY); - ValuePtr tmp2 = rewriter.create(loc, int32Type, tmp1, idY); - ValuePtr tmp3 = rewriter.create(loc, int32Type, tmp2, dimX); + Value getLinearThreadIndex(Location loc, + ConversionPatternRewriter &rewriter) const { + Value dimX = rewriter.create(loc, int32Type); + Value dimY = rewriter.create(loc, int32Type); + Value idX = rewriter.create(loc, int32Type); + Value idY = rewriter.create(loc, int32Type); + Value idZ = rewriter.create(loc, int32Type); + Value tmp1 = rewriter.create(loc, int32Type, idZ, dimY); + Value tmp2 = rewriter.create(loc, int32Type, tmp1, idY); + Value tmp3 = rewriter.create(loc, int32Type, tmp2, dimX); return rewriter.create(loc, int32Type, tmp3, idX); } /// Returns the number of threads in the block. 
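getLinearThreadIndex above flattens the three-dimensional thread id into a single index with x varying fastest, i.e. ((tid.z * dim.y + tid.y) * dim.x + tid.x). Stated as a tiny standalone helper:

  #include <cstdint>

  struct Dim3 { uint32_t x, y, z; };

  // Same flattening as getLinearThreadIndex, with x as the fastest dimension.
  uint32_t linearThreadIndex(Dim3 tid, Dim3 dim) {
    return (tid.z * dim.y + tid.y) * dim.x + tid.x;
  }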
- ValuePtr getBlockSize(Location loc, - ConversionPatternRewriter &rewriter) const { - ValuePtr dimX = rewriter.create(loc, int32Type); - ValuePtr dimY = rewriter.create(loc, int32Type); - ValuePtr dimZ = rewriter.create(loc, int32Type); - ValuePtr dimXY = rewriter.create(loc, int32Type, dimX, dimY); + Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const { + Value dimX = rewriter.create(loc, int32Type); + Value dimY = rewriter.create(loc, int32Type); + Value dimZ = rewriter.create(loc, int32Type); + Value dimXY = rewriter.create(loc, int32Type, dimX, dimY); return rewriter.create(loc, int32Type, dimXY, dimZ); } /// Returns the number of warps in the block. - ValuePtr getNumWarps(Location loc, ValuePtr blockSize, - ConversionPatternRewriter &rewriter) const { + Value getNumWarps(Location loc, Value blockSize, + ConversionPatternRewriter &rewriter) const { auto warpSizeMinusOne = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); auto biasedBlockSize = rewriter.create( @@ -441,19 +440,19 @@ private: } /// Returns the number of active threads in the warp, not clamped to 32. - ValuePtr getActiveWidth(Location loc, ValuePtr threadIdx, ValuePtr blockSize, - ConversionPatternRewriter &rewriter) const { - ValuePtr threadIdxMask = rewriter.create( + Value getActiveWidth(Location loc, Value threadIdx, Value blockSize, + ConversionPatternRewriter &rewriter) const { + Value threadIdxMask = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1))); - ValuePtr numThreadsWithSmallerWarpId = + Value numThreadsWithSmallerWarpId = rewriter.create(loc, threadIdx, threadIdxMask); return rewriter.create(loc, blockSize, numThreadsWithSmallerWarpId); } /// Returns value divided by the warp size (i.e. 32). - ValuePtr getDivideByWarpSize(ValuePtr value, - ConversionPatternRewriter &rewriter) const { + Value getDivideByWarpSize(Value value, + ConversionPatternRewriter &rewriter) const { auto loc = value->getLoc(); auto warpSize = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); @@ -487,7 +486,7 @@ struct GPUShuffleOpLowering : public LLVMOpLowering { /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : /// !llvm<"{ float, i1 }"> PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); gpu::ShuffleOpOperandAdaptor adaptor(operands); @@ -498,24 +497,24 @@ struct GPUShuffleOpLowering : public LLVMOpLowering { auto predTy = LLVM::LLVMType::getInt1Ty(dialect); auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy}); - ValuePtr one = rewriter.create( + Value one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); // Bit mask of active lanes: `(1 << activeWidth) - 1`. 
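The helpers above are plain integer arithmetic on the block shape: getBlockSize multiplies the three dimensions, getNumWarps is a rounded-up division by 32, getActiveWidth counts the threads from the start of the caller's warp to the end of the block (not clamped to 32, so only the last, partially filled warp yields a value below 32), and getDivideByWarpSize yields the warp index. Restated standalone:

  #include <cstdint>

  constexpr uint32_t kWarpSize = 32;

  // getBlockSize: total number of threads in the block.
  uint32_t blockSize(uint32_t dimX, uint32_t dimY, uint32_t dimZ) {
    return dimX * dimY * dimZ;
  }

  // getNumWarps: ceil(threads / 32), computed as (threads + 31) / 32.
  uint32_t numWarps(uint32_t threads) {
    return (threads + kWarpSize - 1) / kWarpSize;
  }

  // getActiveWidth: threads from the start of this thread's warp to the end
  // of the block.
  uint32_t activeWidth(uint32_t threadIdx, uint32_t threads) {
    return threads - (threadIdx & ~(kWarpSize - 1));
  }

  // getDivideByWarpSize: index of the warp the thread belongs to.
  uint32_t warpId(uint32_t threadIdx) { return threadIdx / kWarpSize; }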
- ValuePtr activeMask = rewriter.create( + Value activeMask = rewriter.create( loc, int32Type, rewriter.create(loc, int32Type, one, adaptor.width()), one); // Clamp lane: `activeWidth - 1` - ValuePtr maskAndClamp = + Value maskAndClamp = rewriter.create(loc, int32Type, adaptor.width(), one); auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); - ValuePtr shfl = rewriter.create( + Value shfl = rewriter.create( loc, resultTy, activeMask, adaptor.value(), adaptor.offset(), maskAndClamp, returnValueAndIsValidAttr); - ValuePtr shflValue = rewriter.create( + Value shflValue = rewriter.create( loc, valueTy, shfl, rewriter.getIndexArrayAttr(0)); - ValuePtr isActiveSrcLane = rewriter.create( + Value isActiveSrcLane = rewriter.create( loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); @@ -530,7 +529,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.empty() && "func op is not expected to have operands"); auto gpuFuncOp = cast(op); @@ -539,7 +538,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { SmallVector workgroupBuffers; workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { - ValuePtr attribution = en.value(); + Value attribution = en.value(); auto type = attribution->getType().dyn_cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -596,23 +595,23 @@ struct GPUFuncOpLowering : LLVMOpLowering { unsigned numProperArguments = gpuFuncOp.getNumArguments(); auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); - ValuePtr zero = nullptr; + Value zero = nullptr; if (!workgroupBuffers.empty()) zero = rewriter.create(loc, i32Type, rewriter.getI32IntegerAttr(0)); for (auto en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); - ValuePtr address = rewriter.create(loc, global); + Value address = rewriter.create(loc, global); auto elementType = global.getType().getArrayElementType(); - ValuePtr memory = rewriter.create( + Value memory = rewriter.create( loc, elementType.getPointerTo(global.addr_space().getZExtValue()), - address, ArrayRef{zero, zero}); + address, ArrayRef{zero, zero}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. 
- ValuePtr attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; + Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution->getType().cast(); auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, type, memory); @@ -624,7 +623,7 @@ struct GPUFuncOpLowering : LLVMOpLowering { gpuFuncOp.getNumWorkgroupAttributions(); auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { - ValuePtr attribution = en.value(); + Value attribution = en.value(); auto type = attribution->getType().cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -635,10 +634,10 @@ struct GPUFuncOpLowering : LLVMOpLowering { auto ptrType = lowering.convertType(type.getElementType()) .cast() .getPointerTo(); - ValuePtr numElements = rewriter.create( + Value numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, rewriter.getI64IntegerAttr(type.getNumElements())); - ValuePtr allocated = rewriter.create( + Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, type, allocated); @@ -666,8 +665,8 @@ struct GPUFuncOpLowering : LLVMOpLowering { !en.value().isa()) continue; - BlockArgumentPtr arg = block.getArgument(en.index()); - ValuePtr loaded = rewriter.create(loc, arg); + BlockArgument arg = block.getArgument(en.index()); + Value loaded = rewriter.create(loc, arg); rewriter.replaceUsesOfBlockArgument(arg, loaded); } } @@ -684,7 +683,7 @@ struct GPUReturnOpLowering : public LLVMOpLowering { typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, ArrayRef()); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 95c46853b1f..509457d076a 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -27,7 +27,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(loop::ForOp forOp, ArrayRef operands, + matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -39,7 +39,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -56,7 +56,7 @@ public: } PatternMatchResult - matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; private: @@ -70,7 +70,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ModuleOp moduleOp, ArrayRef operands, + matchAndRewrite(ModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -83,7 +83,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef operands, + matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -94,7 +94,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - 
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -105,7 +105,7 @@ public: //===----------------------------------------------------------------------===// PatternMatchResult -ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, +ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { // loop::ForOp can be lowered to the structured control flow represented by // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop @@ -126,7 +126,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); // Create the new induction variable to use. - BlockArgumentPtr newIndVar = + BlockArgument newIndVar = header->addArgument(forOperands.lowerBound()->getType()); Block *body = forOp.getBody(); @@ -157,7 +157,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, auto cmpOp = rewriter.create( loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); rewriter.create( - loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); + loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); // Generate instructions to increment the step of the induction variable and // branch to the header. @@ -165,7 +165,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. - ValuePtr updatedIndVar = rewriter.create( + Value updatedIndVar = rewriter.create( loc, newIndVar->getType(), newIndVar, forOperands.step()); rewriter.create(loc, header, updatedIndVar); @@ -179,7 +179,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, template PatternMatchResult LaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto dimAttr = op.getOperation()->template getAttrOfType("dimension"); @@ -258,7 +258,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, PatternMatchResult KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, - ArrayRef operands, + ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!gpu::GPUDialect::isKernel(funcOp)) { return matchFailure(); @@ -288,7 +288,7 @@ KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, //===----------------------------------------------------------------------===// PatternMatchResult KernelModuleConversion::matchAndRewrite( - ModuleOp moduleOp, ArrayRef operands, + ModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!moduleOp.getAttrOfType( gpu::GPUDialect::getKernelModuleAttrName())) { @@ -318,7 +318,7 @@ PatternMatchResult KernelModuleConversion::matchAndRewrite( //===----------------------------------------------------------------------===// PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( - ModuleTerminatorOp terminatorOp, ArrayRef operands, + ModuleTerminatorOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(terminatorOp); return matchSuccess(); @@ -329,7 +329,7 @@ PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( //===----------------------------------------------------------------------===// 
PatternMatchResult GPUReturnOpConversion::matchAndRewrite( - gpu::ReturnOp returnOp, ArrayRef operands, + gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!operands.empty()) return matchFailure(); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 1b70df6f8bd..2a034fd15c5 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -111,23 +111,21 @@ public: BaseViewConversionHelper(Type type) : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} - BaseViewConversionHelper(ValuePtr v) : d(v) {} + BaseViewConversionHelper(Value v) : d(v) {} /// Wrappers around MemRefDescriptor that use EDSC builder and location. - ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } - void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); } - ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); } - void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); } - ValuePtr offset() { return d.offset(rewriter(), loc()); } - void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); } - ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); } - void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); } - ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); } - void setStride(unsigned i, ValuePtr v) { - d.setStride(rewriter(), loc(), i, v); - } - - operator ValuePtr() { return d; } + Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } + Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } + void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } + Value offset() { return d.offset(rewriter(), loc()); } + void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } + Value size(unsigned i) { return d.size(rewriter(), loc(), i); } + void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } + Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } + + operator Value() { return d; } private: OpBuilder &rewriter() { return ScopedContext::getBuilder(); } @@ -144,7 +142,7 @@ public: : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = @@ -154,7 +152,7 @@ public: // Fill in an aggregate value of the descriptor. 
RangeOpOperandAdaptor adaptor(operands); - ValuePtr desc = llvm_undef(rangeDescriptorTy); + Value desc = llvm_undef(rangeDescriptorTy); desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); @@ -177,7 +175,7 @@ public: : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); SliceOpOperandAdaptor adaptor(operands); @@ -191,7 +189,7 @@ public: BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); // TODO(ntv): extract sizes and emit asserts. - SmallVector strides(memRefType.getRank()); + SmallVector strides(memRefType.getRank()); for (int i = 0, e = memRefType.getRank(); i < e; ++i) strides[i] = baseDesc.stride(i); @@ -200,10 +198,10 @@ public: }; // Compute base offset. - ValuePtr baseOffset = baseDesc.offset(); + Value baseOffset = baseDesc.offset(); for (int i = 0, e = memRefType.getRank(); i < e; ++i) { - ValuePtr indexing = adaptor.indexings()[i]; - ValuePtr min = indexing; + Value indexing = adaptor.indexings()[i]; + Value min = indexing; if (sliceOp.indexing(i)->getType().isa()) min = extractvalue(int64Ty, indexing, pos(0)); baseOffset = add(baseOffset, mul(min, strides[i])); @@ -220,29 +218,29 @@ public: if (sliceOp.getViewType().getRank() == 0) return rewriter.replaceOp(op, {desc}), matchSuccess(); - ValuePtr zero = + Value zero = constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); // Compute and insert view sizes (max - min along the range) and strides. // Skip the non-range operands as they will be projected away from the view. int numNewDims = 0; for (auto en : llvm::enumerate(sliceOp.indexings())) { - ValuePtr indexing = en.value(); + Value indexing = en.value(); if (indexing->getType().isa()) { int rank = en.index(); - ValuePtr rangeDescriptor = adaptor.indexings()[rank]; - ValuePtr min = extractvalue(int64Ty, rangeDescriptor, pos(0)); - ValuePtr max = extractvalue(int64Ty, rangeDescriptor, pos(1)); - ValuePtr step = extractvalue(int64Ty, rangeDescriptor, pos(2)); - ValuePtr baseSize = baseDesc.size(rank); + Value rangeDescriptor = adaptor.indexings()[rank]; + Value min = extractvalue(int64Ty, rangeDescriptor, pos(0)); + Value max = extractvalue(int64Ty, rangeDescriptor, pos(1)); + Value step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value baseSize = baseDesc.size(rank); // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); - ValuePtr size = sub(max, min); + Value size = sub(max, min); // Bound lower by zero. size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); - ValuePtr stride = mul(strides[rank], step); + Value stride = mul(strides[rank], step); desc.setSize(numNewDims, size); desc.setStride(numNewDims, stride); ++numNewDims; @@ -268,7 +266,7 @@ public: : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. 
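The SliceOp lowering above folds each indexing into the base view's strided descriptor: the offset grows by min_i * stride_i for every indexing, and for range indexings the new size is min(max_i, baseSize_i) - min_i clamped at zero while the stride is scaled by the range's step (point indexings instead project the dimension away). A standalone sketch of the range-only case over a plain strided-view struct; the names are illustrative, not the MLIR descriptor API:

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  struct Range { int64_t min, max, step; };

  struct StridedView {
    int64_t offset;
    std::vector<int64_t> sizes;
    std::vector<int64_t> strides;
  };

  // Applies one range per dimension of `base`, mirroring the slice lowering.
  StridedView slice(const StridedView &base, const std::vector<Range> &ranges) {
    StridedView view;
    view.offset = base.offset;
    for (std::size_t i = 0; i < ranges.size(); ++i) {
      const Range &r = ranges[i];
      view.offset += r.min * base.strides[i];            // shift the origin
      int64_t max = std::min(r.max, base.sizes[i]);      // bound by the base view
      int64_t size = std::max<int64_t>(max - r.min, 0);  // never negative
      view.sizes.push_back(size);
      view.strides.push_back(base.strides[i] * r.step);  // stride scaled by step
    }
    return view;
  }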
edsc::ScopedContext context(rewriter, op->getLoc()); @@ -311,7 +309,7 @@ public: : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return matchSuccess(); @@ -446,7 +444,7 @@ public: op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); auto indexedGenericOp = cast(op); auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector operands; + SmallVector operands; operands.reserve(numLoops + op.getNumOperands()); for (unsigned i = 0; i < numLoops; ++i) { operands.push_back(zero); @@ -470,7 +468,7 @@ public: PatternMatchResult matchAndRewrite(CopyOp op, PatternRewriter &rewriter) const override { - ValuePtr in = op.input(), out = op.output(); + Value in = op.input(), out = op.output(); // If either inputPerm or outputPerm are non-identities, insert transposes. auto inputPerm = op.inputPermutation(); diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index 59dac73de9c..b257e9b482b 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -187,8 +187,8 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { // Compute loop bounds before branching to the condition. rewriter.setInsertionPointToEnd(initBlock); - ValuePtr lowerBound = forOp.lowerBound(); - ValuePtr upperBound = forOp.upperBound(); + Value lowerBound = forOp.lowerBound(); + Value upperBound = forOp.upperBound(); if (!lowerBound || !upperBound) return matchFailure(); rewriter.create(loc, conditionBlock, lowerBound); @@ -199,8 +199,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, - ArrayRef()); + ArrayRef(), endBlock, ArrayRef()); // Ok, we're done! rewriter.eraseOp(forOp); return matchSuccess(); @@ -239,8 +238,8 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, ifOp.condition(), thenBlock, - /*trueArgs=*/ArrayRef(), elseBlock, - /*falseArgs=*/ArrayRef()); + /*trueArgs=*/ArrayRef(), elseBlock, + /*falseArgs=*/ArrayRef()); // Ok, we're done! rewriter.eraseOp(ifOp); diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 24bb8ffc462..e500d10983c 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -34,7 +34,7 @@ using namespace mlir::loop; using llvm::seq; // Extract an indexed value from KernelDim3. 
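ForLowering above replaces loop.for with explicit blocks: the init block branches to a condition block that carries the induction variable as a block argument, the condition block compares iv with the upper bound and conditionally branches into the body or past the loop, and the last body block adds the step and branches back to the condition block. The same control flow, written as ordinary C++ for reference:

  #include <cstdint>
  #include <functional>

  // Control-flow shape produced by ForLowering (blocks noted in the comments).
  void loweredFor(int64_t lowerBound, int64_t upperBound, int64_t step,
                  const std::function<void(int64_t)> &body) {
    int64_t iv = lowerBound;        // init block: branch to condition with lb
    for (;;) {
      if (!(iv < upperBound))       // condition block: cmpi "slt" + cond_br
        break;                      //   ...to the end block
      body(iv);                     // first .. last body blocks
      iv += step;                   // last body block: add the step and
    }                               //   branch back to the condition block
  }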
-static ValuePtr getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { +static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { case 0: return dim3.x; @@ -52,8 +52,8 @@ static ValuePtr getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) { return forOp.getLowerBoundOperands(); } -static SmallVector getLowerBoundOperands(ForOp forOp) { - SmallVector bounds(1, forOp.lowerBound()); +static SmallVector getLowerBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.lowerBound()); return bounds; } @@ -61,35 +61,33 @@ static SmallVector getLowerBoundOperands(ForOp forOp) { static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { return forOp.getUpperBoundOperands(); } -static SmallVector getUpperBoundOperands(ForOp forOp) { - SmallVector bounds(1, forOp.upperBound()); +static SmallVector getUpperBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.upperBound()); return bounds; } // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. -static ValuePtr getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { +static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { return builder.create(forOp.getLoc(), forOp.getStep()); } -static ValuePtr getOrCreateStep(ForOp forOp, OpBuilder &) { - return forOp.step(); -} +static Value getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); } // Get a Value for the loop lower bound. If the value requires computation, // materialize the instructions using builder. -static ValuePtr getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { +static Value getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineLowerBound(forOp, builder); } -static ValuePtr getOrEmitLowerBound(ForOp forOp, OpBuilder &) { +static Value getOrEmitLowerBound(ForOp forOp, OpBuilder &) { return forOp.lowerBound(); } // Get a Value for the loop upper bound. If the value requires computation, // materialize the instructions using builder. -static ValuePtr getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { +static Value getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineUpperBound(forOp, builder); } -static ValuePtr getOrEmitUpperBound(ForOp forOp, OpBuilder &) { +static Value getOrEmitUpperBound(ForOp forOp, OpBuilder &) { return forOp.upperBound(); } @@ -205,18 +203,18 @@ struct LoopToGpuConverter { unsigned numThreadDims); // Ranges of the loops mapped to blocks or threads. - SmallVector dims; + SmallVector dims; // Lower bounds of the loops mapped to blocks or threads. - SmallVector lbs; + SmallVector lbs; // Induction variables of the loops mapped to blocks or threads. - SmallVector ivs; + SmallVector ivs; // Steps of the loops mapped to blocks or threads. - SmallVector steps; + SmallVector steps; }; } // namespace // Return true if the value is obviously a constant "one". 
-static bool isConstantOne(ValuePtr value) { +static bool isConstantOne(Value value) { if (auto def = dyn_cast_or_null(value->getDefiningOp())) return def.getValue() == 1; return false; @@ -237,15 +235,15 @@ Optional LoopToGpuConverter::collectBounds(OpTy forOp, steps.reserve(numLoops); OpTy currentLoop = forOp; for (unsigned i = 0; i < numLoops; ++i) { - ValuePtr lowerBound = getOrEmitLowerBound(currentLoop, builder); - ValuePtr upperBound = getOrEmitUpperBound(currentLoop, builder); + Value lowerBound = getOrEmitLowerBound(currentLoop, builder); + Value upperBound = getOrEmitUpperBound(currentLoop, builder); if (!lowerBound || !upperBound) { return llvm::None; } - ValuePtr range = + Value range = builder.create(currentLoop.getLoc(), upperBound, lowerBound); - ValuePtr step = getOrCreateStep(currentLoop, builder); + Value step = getOrCreateStep(currentLoop, builder); if (!isConstantOne(step)) range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); @@ -267,8 +265,8 @@ Optional LoopToGpuConverter::collectBounds(OpTy forOp, /// `nids`. The innermost loop is mapped to the x-dimension, followed by the /// next innermost loop to y-dimension, followed by z-dimension. template -OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, - ArrayRef nids) { +OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, + ArrayRef nids) { auto nDims = ids.size(); assert(nDims == nids.size()); for (auto dim : llvm::seq(0, nDims)) { @@ -288,11 +286,11 @@ OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef ids, /// each workgroup/workitem and number of workgroup/workitems along a dimension /// of the launch into a container. void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids, - unsigned nDims, SmallVectorImpl &ids, - SmallVectorImpl &nids) { + unsigned nDims, SmallVectorImpl &ids, + SmallVectorImpl &nids) { assert(nDims <= 3 && "invalid number of launch dimensions"); - SmallVector allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; - SmallVector allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; + SmallVector allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; + SmallVector allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; ids.clear(); ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end()); nids.clear(); @@ -310,7 +308,7 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, auto returnOp = builder.create(launchOp.getLoc()); rootForOp.getOperation()->moveBefore(returnOp); - SmallVector workgroupID, numWorkGroups; + SmallVector workgroupID, numWorkGroups; packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims, workgroupID, numWorkGroups); @@ -326,7 +324,7 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, } } - SmallVector workItemID, workGroupSize; + SmallVector workItemID, workGroupSize; packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(), numThreadDims, workItemID, workGroupSize); for (auto &loopOp : threadRootForOps) { @@ -339,18 +337,17 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, // Convert the computation rooted at the `rootForOp`, into a GPU kernel with the // given workgroup size and number of workgroups. 
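collectBounds above turns each mapped loop into one launch extent: the difference of the bounds, divided by the step unless the step is the constant one. As arithmetic (assuming a positive step and bounds in the usual lb <= ub form):

  #include <cstdint>

  // Extent contributed by one loop to a grid or block dimension.
  int64_t launchExtent(int64_t lowerBound, int64_t upperBound, int64_t step) {
    int64_t range = upperBound - lowerBound;
    return step == 1 ? range : range / step;
  }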
template -LogicalResult createLaunchFromOp(OpTy rootForOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes) { +LogicalResult createLaunchFromOp(OpTy rootForOp, ArrayRef numWorkGroups, + ArrayRef workGroupSizes) { OpBuilder builder(rootForOp.getOperation()); if (numWorkGroups.size() > 3) { return rootForOp.emitError("invalid ") << numWorkGroups.size() << "-D workgroup specification"; } auto loc = rootForOp.getLoc(); - ValuePtr one = builder.create( + Value one = builder.create( loc, builder.getIntegerAttr(builder.getIndexType(), 1)); - SmallVector numWorkGroups3D(3, one), workGroupSize3D(3, one); + SmallVector numWorkGroups3D(3, one), workGroupSize3D(3, one); for (auto numWorkGroup : enumerate(numWorkGroups)) { numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value(); } @@ -360,7 +357,7 @@ LogicalResult createLaunchFromOp(OpTy rootForOp, // Get the values used within the region of the rootForOp but defined above // it. - llvm::SetVector valuesToForwardSet; + llvm::SetVector valuesToForwardSet; getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(), valuesToForwardSet); // Also add the values used for the lb, ub, and step of the rootForOp. @@ -380,8 +377,8 @@ LogicalResult createLaunchFromOp(OpTy rootForOp, // defined outside. They all are replaced with kernel arguments. for (const auto &pair : llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { - ValuePtr from = std::get<0>(pair); - ValuePtr to = std::get<1>(pair); + Value from = std::get<0>(pair); + Value to = std::get<1>(pair); replaceAllUsesInRegionWith(from, to, launchOp.body()); } return success(); @@ -401,23 +398,22 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, OpBuilder builder(rootForOp.getOperation()); // Prepare the grid and block sizes for the launch operation. If there is // no loop mapped to a specific dimension, use constant "1" as its size. - ValuePtr constOne = - (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) - : nullptr; - ValuePtr gridSizeX = dims[0]; - ValuePtr gridSizeY = numBlockDims > 1 ? dims[1] : constOne; - ValuePtr gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; - ValuePtr blockSizeX = dims[numBlockDims]; - ValuePtr blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; - ValuePtr blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; + Value constOne = (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; + Value gridSizeX = dims[0]; + Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; + Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; + Value blockSizeX = dims[numBlockDims]; + Value blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; + Value blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; // Create a launch op and move the body region of the innermost loop to the // launch op. Pass the values defined outside the outermost loop and used // inside the innermost loop and loop lower bounds as kernel data arguments. // Still assuming perfect nesting so there are no values other than induction // variables that are defined in one loop and used in deeper loops. 
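createLaunch above maps the collected extents onto the six launch sizes, taking the outer extents as grid dimensions and the inner ones as block dimensions and padding any unmapped dimension with a constant one. The selection in isolation (illustrative helper, not part of the pass):

  #include <array>
  #include <cstdint>
  #include <vector>

  // `dims` holds numBlockDims block extents followed by numThreadDims thread
  // extents; returns {gridX, gridY, gridZ, blockX, blockY, blockZ}.
  std::array<int64_t, 6> launchSizes(const std::vector<int64_t> &dims,
                                     unsigned numBlockDims,
                                     unsigned numThreadDims) {
    std::array<int64_t, 6> sizes = {
        dims[0],
        numBlockDims > 1 ? dims[1] : 1,
        numBlockDims > 2 ? dims[2] : 1,
        dims[numBlockDims],
        numThreadDims > 1 ? dims[numBlockDims + 1] : 1,
        numThreadDims > 2 ? dims[numBlockDims + 2] : 1};
    return sizes;
  }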
- llvm::SetVector valuesToForwardSet; + llvm::SetVector valuesToForwardSet; getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(), valuesToForwardSet); auto valuesToForward = valuesToForwardSet.takeVector(); @@ -451,15 +447,15 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, originallyForwardedValues); auto stepArgumentIt = std::next(lbArgumentIt, lbs.size()); for (auto en : llvm::enumerate(ivs)) { - ValuePtr id = + Value id = en.index() < numBlockDims ? getDim3Value(launchOp.getBlockIds(), en.index()) : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); - ValuePtr step = steps[en.index()]; + Value step = steps[en.index()]; if (!isConstantOne(step)) id = builder.create(rootForOp.getLoc(), step, id); - ValuePtr ivReplacement = + Value ivReplacement = builder.create(rootForOp.getLoc(), *lbArgumentIt, id); en.value()->replaceAllUsesWith(ivReplacement); replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt, @@ -473,8 +469,8 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, // trailing positions, make sure we don't touch those. for (const auto &pair : llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { - ValuePtr from = std::get<0>(pair); - ValuePtr to = std::get<1>(pair); + Value from = std::get<0>(pair); + Value to = std::get<1>(pair); replaceAllUsesInRegionWith(from, to, launchOp.body()); } @@ -504,8 +500,8 @@ static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, // nested. The workgroup size and num workgroups is provided as input template static LogicalResult convertLoopToGPULaunch(OpTy forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSize) { + ArrayRef numWorkGroups, + ArrayRef workGroupSize) { if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(), workGroupSize.size()))) { return failure(); @@ -526,7 +522,7 @@ LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp, } LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp, - ArrayRef numWorkGroups, - ArrayRef workGroupSizes) { + ArrayRef numWorkGroups, + ArrayRef workGroupSizes) { return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes); } diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 4dfd26a4392..c3bbf274818 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -89,7 +89,7 @@ struct ImperfectlyNestedForLoopMapper // pass is only used for testing. 
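Inside the launch region, each original induction variable is rebuilt from the corresponding block or thread id: the id is scaled by the loop step (the multiply is skipped when the step is the constant one) and added to the lower bound, and all uses of the old induction variable are redirected to that sum. As a formula:

  #include <cstdint>

  // Value substituted for a loop induction variable inside the launch body.
  int64_t rebuiltInductionVar(int64_t lowerBound, int64_t id, int64_t step) {
    int64_t scaled = step == 1 ? id : step * id;
    return lowerBound + scaled;
  }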
FuncOp funcOp = getFunction(); OpBuilder builder(funcOp.getOperation()->getRegion(0)); - SmallVector numWorkGroupsVal, workGroupSizeVal; + SmallVector numWorkGroupsVal, workGroupSizeVal; for (auto val : numWorkGroups) { auto constOp = builder.create( funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 160678efe9f..0c96cc5e9c7 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -247,20 +247,20 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ -StructBuilder::StructBuilder(ValuePtr v) : value(v) { +StructBuilder::StructBuilder(Value v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value->getType().cast(); } -ValuePtr StructBuilder::extractPtr(OpBuilder &builder, Location loc, - unsigned pos) { +Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { Type type = structType.cast().getStructElementType(pos); return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, - ValuePtr ptr) { + Value ptr) { value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } @@ -269,7 +269,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, /*============================================================================*/ /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(ValuePtr descriptor) +MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value->getType().cast().getStructElementType( @@ -280,7 +280,7 @@ MemRefDescriptor::MemRefDescriptor(ValuePtr descriptor) MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - ValuePtr descriptor = + Value descriptor = builder.create(loc, descriptorType.cast()); return MemRefDescriptor(descriptor); } @@ -291,7 +291,7 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - MemRefType type, ValuePtr memory) { + MemRefType type, Value memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); assert(type.getAffineMaps().empty() && "unexpected layout map"); @@ -316,37 +316,37 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, } /// Builds IR extracting the allocated pointer from the descriptor. -ValuePtr MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { +Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, - ValuePtr ptr) { + Value ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. 
-ValuePtr MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { +Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - ValuePtr ptr) { + Value ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. -static ValuePtr createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { return builder.create( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. -ValuePtr MemRefDescriptor::offset(OpBuilder &builder, Location loc) { +Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -354,7 +354,7 @@ ValuePtr MemRefDescriptor::offset(OpBuilder &builder, Location loc) { /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - ValuePtr offset) { + Value offset) { value = builder.create( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -368,8 +368,7 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -ValuePtr MemRefDescriptor::size(OpBuilder &builder, Location loc, - unsigned pos) { +Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -377,7 +376,7 @@ ValuePtr MemRefDescriptor::size(OpBuilder &builder, Location loc, /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, - ValuePtr size) { + Value size) { value = builder.create( loc, structType, value, size, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -391,8 +390,7 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -ValuePtr MemRefDescriptor::stride(OpBuilder &builder, Location loc, - unsigned pos) { +Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -400,7 +398,7 @@ ValuePtr MemRefDescriptor::stride(OpBuilder &builder, Location loc, /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, - ValuePtr stride) { + Value stride) { value = builder.create( loc, structType, value, stride, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -423,30 +421,30 @@ LLVM::LLVMType MemRefDescriptor::getElementType() { /*============================================================================*/ /// Construct a helper for the given descriptor value. 
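The MemRefDescriptor helpers above address fixed fields of the LLVM struct that the standard-to-LLVM conversion uses for a ranked memref. For reference, the layout those field positions (allocated pointer, aligned pointer, offset, sizes, strides) correspond to, written as a plain struct for a rank-N memref of element type T:

  #include <cstdint>

  // Layout addressed by kAllocatedPtrPosInMemRefDescriptor and friends:
  // { T* allocated, T* aligned, i64 offset, i64 sizes[N], i64 strides[N] }.
  template <typename T, int N>
  struct RankedMemRefDescriptor {
    T *allocatedPtr;   // field 0: pointer returned by the allocation
    T *alignedPtr;     // field 1: possibly realigned data pointer
    int64_t offset;    // field 2: offset into the aligned buffer, in elements
    int64_t sizes[N];  // field 3: extent of each dimension
    int64_t strides[N];// field 4: stride of each dimension, in elements
  };

  // Element (i0, ..., i_{N-1}) lives at
  //   alignedPtr[offset + i0*strides[0] + ... + i_{N-1}*strides[N-1]].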
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValuePtr descriptor) +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - ValuePtr descriptor = + Value descriptor = builder.create(loc, descriptorType.cast()); return UnrankedMemRefDescriptor(descriptor); } -ValuePtr UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { +Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, - ValuePtr v) { + Value v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } -ValuePtr UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, - Location loc) { +Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, - Location loc, ValuePtr v) { + Location loc, Value v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } namespace { @@ -487,8 +485,8 @@ public: } // Create an LLVM IR pseudo-operation defining the given index constant. - ValuePtr createIndexConstant(ConversionPatternRewriter &builder, Location loc, - uint64_t value) const { + Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, + uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } @@ -500,7 +498,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); FunctionType type = funcOp.getType(); @@ -548,8 +546,8 @@ struct FuncOpConversion : public LLVMLegalizationPattern { Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); for (unsigned idx : promotedArgIndices) { - BlockArgumentPtr arg = firstBlock->getArgument(idx); - ValuePtr loaded = rewriter.create(funcOp.getLoc(), arg); + BlockArgument arg = firstBlock->getArgument(idx); + Value loaded = rewriter.create(funcOp.getLoc(), arg); rewriter.replaceUsesOfBlockArgument(arg, loaded); } } @@ -648,7 +646,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); @@ -672,7 +670,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. 
- SmallVector results; + SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -713,7 +711,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ValidateOpCount(); static_assert( @@ -724,7 +722,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { "expected same operands and result type"); // Cannot convert ops if their operands are not of LLVM type. - for (ValuePtr operand : operands) { + for (Value operand : operands) { if (!operand || !operand->getType().isa()) return this->matchFailure(); } @@ -747,16 +745,16 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return this->matchFailure(); - ValuePtr desc = rewriter.create(loc, llvmArrayTy); + Value desc = rewriter.create(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors - SmallVector extractedOperands; + SmallVector extractedOperands; for (unsigned i = 0; i < OpCount; ++i) { extractedOperands.push_back(rewriter.create( loc, llvmVectorTy, operands[i], position)); } - ValuePtr newVal = rewriter.create( + Value newVal = rewriter.create( loc, llvmVectorTy, extractedOperands, op->getAttrs()); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); @@ -919,7 +917,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { return matchSuccess(); } - void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto allocOp = cast(op); @@ -928,7 +926,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). - SmallVector sizes; + SmallVector sizes; sizes.reserve(type.getRank()); unsigned i = 0; for (int64_t s : type.getShape()) @@ -938,10 +936,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { sizes.push_back(createIndexConstant(rewriter, loc, 1)); // Compute the total number of memref elements. - ValuePtr cumulativeSize = sizes.front(); + Value cumulativeSize = sizes.front(); for (unsigned i = 1, e = sizes.size(); i < e; ++i) cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); + loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) 
implementation in LLVM IR: @@ -954,17 +952,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern { auto nullPtr = rewriter.create(loc, convertedPtrType); auto one = createIndexConstant(rewriter, loc, 1); auto gep = rewriter.create(loc, convertedPtrType, - ArrayRef{nullPtr, one}); + ArrayRef{nullPtr, one}); auto elementSize = rewriter.create(loc, getIndexType(), gep); cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); + loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - ValuePtr allocated = nullptr; + Value allocated = nullptr; int alignment = 0; - ValuePtr alignmentValue = nullptr; + Value alignmentValue = nullptr; if (auto alignAttr = allocOp.alignment()) alignment = alignAttr.getValue().getSExtValue(); @@ -1000,8 +998,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern { auto structElementType = lowering.convertType(elementType); auto elementPtrType = structElementType.cast().getPointerTo( type.getMemorySpace()); - ValuePtr bitcastAllocated = rewriter.create( - loc, elementPtrType, ArrayRef(allocated)); + Value bitcastAllocated = rewriter.create( + loc, elementPtrType, ArrayRef(allocated)); int64_t offset; SmallVector strides; @@ -1023,22 +1021,21 @@ struct AllocOpLowering : public LLVMLegalizationPattern { memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); // Field 2: Actual aligned pointer to payload. - ValuePtr bitcastAligned = bitcastAllocated; + Value bitcastAligned = bitcastAllocated; if (!useAlloca && alignment != 0) { assert(alignmentValue); // offset = (align - (ptr % align))% align - ValuePtr intVal = rewriter.create( + Value intVal = rewriter.create( loc, this->getIndexType(), allocated); - ValuePtr ptrModAlign = + Value ptrModAlign = rewriter.create(loc, intVal, alignmentValue); - ValuePtr subbed = + Value subbed = rewriter.create(loc, alignmentValue, ptrModAlign); - ValuePtr offset = - rewriter.create(loc, subbed, alignmentValue); - ValuePtr aligned = rewriter.create(loc, allocated->getType(), - allocated, offset); + Value offset = rewriter.create(loc, subbed, alignmentValue); + Value aligned = rewriter.create(loc, allocated->getType(), + allocated, offset); bitcastAligned = rewriter.create( - loc, elementPtrType, ArrayRef(aligned)); + loc, elementPtrType, ArrayRef(aligned)); } memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); @@ -1053,10 +1050,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. - ValuePtr runningStride = nullptr; + Value runningStride = nullptr; // Iterate strides in reverse order, compute runningStride and strideValues. 
auto nStrides = strides.size(); - SmallVector strideValues(nStrides, nullptr); + SmallVector strideValues(nStrides, nullptr); for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) { int64_t index = nStrides - 1 - indexedStride.index(); if (strides[index] == MemRefType::getDynamicStrideOrOffset()) @@ -1093,7 +1090,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern { using Base = LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); auto callOp = cast(op); @@ -1131,7 +1128,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern { // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around // a particular interaction between MemRefType and CallOp lowering. Find a // way to avoid special casing. - SmallVector results; + SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -1165,7 +1162,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { useAlloca(useAlloca) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (useAlloca) return rewriter.eraseOp(op), matchSuccess(); @@ -1185,7 +1182,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { } MemRefDescriptor memref(transformed.memref()); - ValuePtr casted = rewriter.create( + Value casted = rewriter.create( op->getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( @@ -1201,7 +1198,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVMFuncOpT = LLVM::LLVMFuncOp; @@ -1275,7 +1272,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { : matchFailure(); } - void rewrite(Operation *op, ArrayRef operands, + void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); OperandAdaptor transformed(operands); @@ -1316,7 +1313,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); - rewriter.replaceOp(op, (ValuePtr)memRefDesc); + rewriter.replaceOp(op, (Value)memRefDesc); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. @@ -1347,7 +1344,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); OperandAdaptor transformed(operands); @@ -1389,45 +1386,42 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // by accumulating the running linearized value. // Note that `indices` and `allocSizes` are passed in the same order as they // appear in load/store operations and memref type declarations. 
- ValuePtr linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, - ArrayRef indices, - ArrayRef allocSizes) const { + Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, + ArrayRef indices, + ArrayRef allocSizes) const { assert(indices.size() == allocSizes.size() && "mismatching number of indices and allocation sizes"); assert(!indices.empty() && "cannot linearize a 0-dimensional access"); - ValuePtr linearized = indices.front(); + Value linearized = indices.front(); for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { linearized = builder.create( loc, this->getIndexType(), - ArrayRef{linearized, allocSizes[i]}); + ArrayRef{linearized, allocSizes[i]}); linearized = builder.create( - loc, this->getIndexType(), - ArrayRef{linearized, indices[i]}); + loc, this->getIndexType(), ArrayRef{linearized, indices[i]}); } return linearized; } // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - ValuePtr getStridedElementPtr(Location loc, Type elementTypePtr, - ValuePtr descriptor, ArrayRef indices, - ArrayRef strides, int64_t offset, - ConversionPatternRewriter &rewriter) const { + Value getStridedElementPtr(Location loc, Type elementTypePtr, + Value descriptor, ArrayRef indices, + ArrayRef strides, int64_t offset, + ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); - ValuePtr base = memRefDescriptor.alignedPtr(rewriter, loc); - ValuePtr offsetValue = - offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : this->createIndexConstant(rewriter, loc, offset); + Value base = memRefDescriptor.alignedPtr(rewriter, loc); + Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { - ValuePtr stride = - strides[i] == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.stride(rewriter, loc, i) - : this->createIndexConstant(rewriter, loc, strides[i]); - ValuePtr additionalOffset = + Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() + ? 
memRefDescriptor.stride(rewriter, loc, i) + : this->createIndexConstant(rewriter, loc, strides[i]); + Value additionalOffset = rewriter.create(loc, indices[i], stride); offsetValue = rewriter.create(loc, offsetValue, additionalOffset); @@ -1435,10 +1429,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { return rewriter.create(loc, elementTypePtr, base, offsetValue); } - ValuePtr getDataPtr(Location loc, MemRefType type, ValuePtr memRefDesc, - ArrayRef indices, - ConversionPatternRewriter &rewriter, - llvm::Module &module) const { + Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, + ArrayRef indices, ConversionPatternRewriter &rewriter, + llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); int64_t offset; SmallVector strides; @@ -1456,14 +1449,14 @@ struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, dataPtr); return matchSuccess(); } @@ -1475,13 +1468,13 @@ struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return matchSuccess(); @@ -1494,14 +1487,14 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); OperandAdaptor transformed(operands); auto type = prefetchOp.getMemRefType(); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. 
auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); @@ -1529,7 +1522,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); auto indexCastOp = cast(op); @@ -1564,7 +1557,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpOperandAdaptor transformed(operands); @@ -1583,7 +1576,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpOperandAdaptor transformed(operands); @@ -1635,9 +1628,9 @@ struct OneToOneLLVMTerminatorLowering using Super = OneToOneLLVMTerminatorLowering; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, + matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, - ArrayRef> operands, + ArrayRef> operands, ConversionPatternRewriter &rewriter) const override { SmallVector operandRanges(operands.begin(), operands.end()); rewriter.replaceOpWithNewOp(op, properOperands, destinations, @@ -1656,19 +1649,19 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp( - op, ArrayRef(), ArrayRef(), op->getAttrs()); + op, ArrayRef(), ArrayRef(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, ArrayRef(operands.front()), ArrayRef(), + op, ArrayRef(operands.front()), ArrayRef(), op->getAttrs()); return matchSuccess(); } @@ -1678,7 +1671,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { auto packedType = lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); - ValuePtr packed = rewriter.create(op->getLoc(), packedType); + Value packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( op->getLoc(), packedType, packed, operands[i], @@ -1706,7 +1699,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); @@ -1715,7 +1708,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern { // First insert it into an undef vector so we can shuffle it. 
auto vectorType = lowering.convertType(splatOp.getType()); - ValuePtr undef = rewriter.create(op->getLoc(), vectorType); + Value undef = rewriter.create(op->getLoc(), vectorType); auto zero = rewriter.create( op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); @@ -1740,7 +1733,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); OperandAdaptor adaptor(operands); @@ -1757,16 +1750,16 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Construct returned value. - ValuePtr desc = rewriter.create(loc, llvmArrayTy); + Value desc = rewriter.create(loc, llvmArrayTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. - ValuePtr vdesc = rewriter.create(loc, llvmVectorTy); + Value vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( loc, lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - ValuePtr v = rewriter.create( - loc, llvmVectorTy, vdesc, adaptor.input(), zero); + Value v = rewriter.create(loc, llvmVectorTy, vdesc, + adaptor.input(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); @@ -1794,21 +1787,21 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. - SmallVector dynamicOffsets( + SmallVector dynamicOffsets( std::next(operands.begin()), std::next(operands.begin(), 1 + viewOp.getNumOffsets())); - SmallVector dynamicSizes( + SmallVector dynamicSizes( std::next(operands.begin(), 1 + viewOp.getNumOffsets()), std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); - SmallVector dynamicStrides( + SmallVector dynamicStrides( std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), operands.end()); @@ -1845,8 +1838,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. - ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); - ValuePtr bitcastPtr = rewriter.create( + Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); + Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1856,7 +1849,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. 
- SmallVector strideValues; + SmallVector strideValues; strideValues.reserve(viewMemRefType.getRank()); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); @@ -1873,9 +1866,9 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { } // Offset. - ValuePtr baseOffset = sourceMemRef.offset(rewriter, loc); + Value baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - ValuePtr min = dynamicOffsets[i]; + Value min = dynamicOffsets[i]; baseOffset = rewriter.create( loc, baseOffset, rewriter.create(loc, min, strideValues[i])); @@ -1885,7 +1878,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); - ValuePtr newStride; + Value newStride; if (dynamicStrides.empty()) newStride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); @@ -1910,9 +1903,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. - ValuePtr getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef shape, ArrayRef dynamicSizes, - unsigned idx) const { + Value getSize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef shape, ArrayRef dynamicSizes, + unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); @@ -1927,9 +1920,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. - ValuePtr getStride(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef strides, ValuePtr nextSize, - ValuePtr runningStride, unsigned idx) const { + Value getStride(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef strides, Value nextSize, + Value runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); @@ -1942,7 +1935,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern { } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); @@ -1969,8 +1962,8 @@ struct ViewOpLowering : public LLVMLegalizationPattern { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. - ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); - ValuePtr bitcastPtr = rewriter.create( + Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); + Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1987,10 +1980,10 @@ struct ViewOpLowering : public LLVMLegalizationPattern { auto sizeAndOffsetOperands = adaptor.operands(); assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + (hasDynamicOffset ? 1 : 0)); - ValuePtr baseOffset = !hasDynamicOffset - ? 
createIndexConstant(rewriter, loc, offset) - // TODO(ntv): better adaptor. - : sizeAndOffsetOperands.front(); + Value baseOffset = !hasDynamicOffset + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.front(); targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. @@ -2001,14 +1994,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), matchFailure(); - ValuePtr stride = nullptr, nextSize = nullptr; + Value stride = nullptr, nextSize = nullptr; // Drop the dynamic stride from the operand list, if present. - ArrayRef sizeOperands(sizeAndOffsetOperands); + ArrayRef sizeOperands(sizeAndOffsetOperands); if (hasDynamicOffset) sizeOperands = sizeOperands.drop_front(); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. - ValuePtr size = + Value size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. @@ -2052,7 +2045,7 @@ static void ensureDistinctSuccessors(Block &bb) { auto *dummyBlock = new Block(); bb.getParent()->push_back(dummyBlock); auto builder = OpBuilder(dummyBlock); - SmallVector operands( + SmallVector operands( terminator->getSuccessorOperands(*position)); builder.create(terminator->getLoc(), successor.first, operands); terminator->setSuccessor(dummyBlock, *position); @@ -2173,29 +2166,28 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } -ValuePtr LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, - ValuePtr operand, - OpBuilder &builder) { +Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand->getType().cast().getPointerTo(); - ValuePtr one = builder.create( - loc, int64Ty, IntegerAttr::get(indexType, 1)); - ValuePtr allocated = + Value one = builder.create(loc, int64Ty, + IntegerAttr::get(indexType, 1)); + Value allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. 
builder.create(loc, operand, allocated); return allocated; } -SmallVector +SmallVector LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { - SmallVector promotedOperands; + SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index af1c92ef11d..a02dee4419a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -35,7 +35,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, + matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -45,7 +45,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -61,7 +61,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = this->typeConverter.convertType(operation.getResult()->getType()); @@ -80,7 +80,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(LoadOp loadOp, ArrayRef operands, + matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -91,7 +91,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -101,7 +101,7 @@ class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -114,7 +114,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(StoreOp storeOp, ArrayRef operands, + matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -132,8 +132,7 @@ public: spirv::AccessChainOp getElementPtr(OpBuilder &builder, SPIRVTypeConverter &typeConverter, Location loc, MemRefType origBaseType, - ValuePtr basePtr, - ArrayRef indices) { + Value basePtr, ArrayRef indices) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; SmallVector strides; @@ -144,18 +143,17 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, auto indexType = typeConverter.getIndexType(builder.getContext()); - ValuePtr ptrLoc = nullptr; + Value ptrLoc = nullptr; assert(indices.size() == strides.size()); for (auto index : enumerate(indices)) { - ValuePtr strideVal = builder.create( + Value strideVal = builder.create( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); - ValuePtr update = - builder.create(loc, strideVal, index.value()); + Value update = builder.create(loc, strideVal, index.value()); ptrLoc = (ptrLoc ? 
builder.create(loc, ptrLoc, update).getResult() : update); } - SmallVector linearizedIndices; + SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. linearizedIndices.push_back(builder.create( loc, indexType, IntegerAttr::get(indexType, 0))); @@ -168,7 +166,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, //===----------------------------------------------------------------------===// PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( - ConstantOp constIndexOp, ArrayRef operands, + ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!constIndexOp.getResult()->getType().isa()) { return matchFailure(); @@ -202,7 +200,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( //===----------------------------------------------------------------------===// PatternMatchResult -CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); @@ -234,7 +232,7 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, //===----------------------------------------------------------------------===// PatternMatchResult -LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, +LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), @@ -251,8 +249,7 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, //===----------------------------------------------------------------------===// PatternMatchResult -ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, - ArrayRef operands, +ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { return matchFailure(); @@ -266,7 +263,7 @@ ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, //===----------------------------------------------------------------------===// PatternMatchResult -SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, +SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), @@ -280,7 +277,7 @@ SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, //===----------------------------------------------------------------------===// PatternMatchResult -StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, +StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); auto storePtr = diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index c3937358c47..52456b6e46d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -28,7 +28,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -40,7 +40,7 @@ class 
ConvertStandardToSPIRVPass } // namespace PatternMatchResult -FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); if (fnType.getNumResults()) { diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 5d693336c3f..a658356f76c 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -60,7 +60,7 @@ public: static LogicalResult resolveSourceIndices(Location loc, PatternRewriter &rewriter, SubViewOp subViewOp, ValueRange indices, - SmallVectorImpl &sourceIndices) { + SmallVectorImpl &sourceIndices) { // TODO: Aborting when the offsets are static. There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. @@ -68,7 +68,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter, return failure(); ValueRange opOffsets = subViewOp.offsets(); - SmallVector opStrides; + SmallVector opStrides; if (subViewOp.getNumStrides()) { // If the strides are dynamic, get the stride operands. opStrides = llvm::to_vector<2>(subViewOp.strides()); @@ -115,7 +115,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, if (!subViewOp) { return matchFailure(); } - SmallVector sourceIndices; + SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) return matchFailure(); @@ -137,7 +137,7 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, if (!subViewOp) { return matchFailure(); } - SmallVector sourceIndices; + SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) return matchFailure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 56005220d3f..b48930c4dda 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -53,10 +53,9 @@ static VectorType reducedVectorTypeBack(VectorType tp) { } // Helper that picks the proper sequence for inserting. -static ValuePtr insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, - ValuePtr val1, ValuePtr val2, Type llvmType, - int64_t rank, int64_t pos) { +static Value insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val1, + Value val2, Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( @@ -70,10 +69,9 @@ static ValuePtr insertOne(ConversionPatternRewriter &rewriter, } // Helper that picks the proper sequence for extracting. 
-static ValuePtr extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, - ValuePtr val, Type llvmType, int64_t rank, - int64_t pos) { +static Value extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val, + Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( @@ -94,7 +92,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast(op); VectorType dstVectorType = broadcastOp.getVectorType(); @@ -122,9 +120,9 @@ private: // ops once all insert/extract/shuffle operations // are available with lowering implemention. // - ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, - ConversionPatternRewriter &rewriter) const { + Value expandRanks(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { assert((dstVectorType != nullptr) && "invalid result type in broadcast"); // Determine rank of source and destination. int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; @@ -161,24 +159,22 @@ private: // becomes: // x = [s,s] // v = [x,x,x,x] - ValuePtr duplicateOneRank(ValuePtr value, Location loc, - VectorType srcVectorType, VectorType dstVectorType, - int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { - ValuePtr undef = rewriter.create(loc, llvmType); - ValuePtr expand = + Value undef = rewriter.create(loc, llvmType); + Value expand = insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - ValuePtr expand = - expandRanks(value, loc, srcVectorType, - reducedVectorTypeFront(dstVectorType), rewriter); - ValuePtr result = rewriter.create(loc, llvmType); + Value expand = expandRanks(value, loc, srcVectorType, + reducedVectorTypeFront(dstVectorType), rewriter); + Value result = rewriter.create(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); @@ -203,20 +199,19 @@ private: // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> // a = [x, y] // etc. 
- ValuePtr stretchOneRank(ValuePtr value, Location loc, - VectorType srcVectorType, VectorType dstVectorType, - int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); - ValuePtr result = rewriter.create(loc, llvmType); + Value result = rewriter.create(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - ValuePtr one = + Value one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); - ValuePtr expand = + Value expand = insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( @@ -227,9 +222,9 @@ private: Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; - ValuePtr one = + Value one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); - ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } @@ -245,7 +240,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ShuffleOpOperandAdaptor(operands); @@ -269,23 +264,23 @@ public: // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { - ValuePtr shuffle = rewriter.create( + Value shuffle = rewriter.create( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); return matchSuccess(); } // For all other cases, insert the individual values individually. - ValuePtr insert = rewriter.create(loc, llvmType); + Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast().getInt(); - ValuePtr value = adaptor.v1(); + Value value = adaptor.v1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.v2(); } - ValuePtr extract = + Value extract = extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, rank, insPos++); @@ -303,7 +298,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); auto extractEltOp = cast(op); @@ -328,7 +323,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); @@ -344,7 +339,7 @@ public: // One-shot extraction of vector from array (only requires extractvalue). 
if (resultType.isa()) { - ValuePtr extracted = rewriter.create( + Value extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -352,7 +347,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - ValuePtr extracted = adaptor.vector(); + Value extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -383,7 +378,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpOperandAdaptor(operands); auto insertEltOp = cast(op); @@ -408,7 +403,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpOperandAdaptor(operands); @@ -424,7 +419,7 @@ public: // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { - ValuePtr inserted = rewriter.create( + Value inserted = rewriter.create( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); @@ -433,7 +428,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - ValuePtr extracted = adaptor.dest(); + Value extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); auto oneDVectorType = destVectorType; @@ -449,7 +444,7 @@ public: // Insertion of an element into a 1-D LLVM vector. auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); auto constant = rewriter.create(loc, i64Type, position); - ValuePtr inserted = rewriter.create( + Value inserted = rewriter.create( loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), constant); @@ -475,7 +470,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); @@ -486,10 +481,10 @@ public: auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast(op).getResult()->getType()); - ValuePtr desc = rewriter.create(loc, llvmArrayOfVectType); - ValuePtr a = adaptor.lhs(), b = adaptor.rhs(); - ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); - SmallVector lhs, accs; + Value desc = rewriter.create(loc, llvmArrayOfVectType); + Value a = adaptor.lhs(), b = adaptor.rhs(); + Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector lhs, accs; lhs.reserve(rankLHS); accs.reserve(rankLHS); for (unsigned d = 0, e = rankLHS; d < e; ++d) { @@ -497,7 +492,7 @@ public: auto attr = rewriter.getI32IntegerAttr(d); SmallVector bcastAttr(rankRHS, attr); auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); - ValuePtr aD = nullptr, accD = nullptr; + Value aD = nullptr, accD = nullptr; // 1. Broadcast the element a[d] into vector aD. aD = rewriter.create(loc, a, a, bcastArrayAttr); // 2. 
If acc is present, extract 1-d vector acc[d] into accD. @@ -505,7 +500,7 @@ public: accD = rewriter.create( loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). - ValuePtr aOuterbD = + Value aOuterbD = accD ? rewriter.create(loc, vRHS, aD, b, accD) .getResult() : rewriter.create(loc, aD, b).getResult(); @@ -527,7 +522,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); vector::TypeCastOp castOp = cast(op); @@ -576,12 +571,12 @@ public: auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc); + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create(loc, llvmTargetElementTy, allocated); desc.setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. - ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc); + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create(loc, llvmTargetElementTy, ptr); desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. @@ -627,7 +622,7 @@ public: // TODO(ajcbik): rely solely on libc in future? something else? // PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast(op); auto adaptor = vector::PrintOpOperandAdaptor(operands); @@ -657,7 +652,7 @@ public: private: void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, - ValuePtr value, VectorType vectorType, Operation *printer, + Value value, VectorType vectorType, Operation *printer, int64_t rank) const { Location loc = op->getLoc(); if (rank == 0) { @@ -673,7 +668,7 @@ private: rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; auto llvmType = lowering.convertType( rank > 1 ? reducedType : vectorType.getElementType()); - ValuePtr nestedVal = + Value nestedVal = extractOne(rewriter, lowering, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); if (d != dim - 1) diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index d80f9865ccb..5f4cc2e1060 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -106,7 +106,7 @@ static bool isFunctionRegion(Region *region) { /// A utility function to check if a value is defined at the top level of a /// function. A value of index type defined at the top level is always a valid /// symbol. -bool mlir::isTopLevelValue(ValuePtr value) { +bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast()) return isFunctionRegion(arg->getOwner()->getParent()); return isFunctionRegion(value->getDefiningOp()->getParentRegion()); @@ -115,7 +115,7 @@ bool mlir::isTopLevelValue(ValuePtr value) { // Value can be used as a dimension id if it is valid as a symbol, or // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. -bool mlir::isValidDim(ValuePtr value) { +bool mlir::isValidDim(Value value) { // The value must be an index type. 
if (!value->getType().isIndex()) return false; @@ -175,7 +175,7 @@ static bool isDimOpValidSymbol(DimOp dimOp) { // the top level, or it is a result of affine apply operation with symbol // arguments, or a result of the dim op on a memref satisfying certain // constraints. -bool mlir::isValidSymbol(ValuePtr value) { +bool mlir::isValidSymbol(Value value) { // The value must be an index type. if (!value->getType().isIndex()) return false; @@ -198,7 +198,7 @@ bool mlir::isValidSymbol(ValuePtr value) { // Returns true if 'value' is a valid index to an affine operation (e.g. // affine.load, affine.store, affine.dma_start, affine.dma_wait). // Returns false otherwise. -static bool isValidAffineIndexOperand(ValuePtr value) { +static bool isValidAffineIndexOperand(Value value) { return isValidDim(value) || isValidSymbol(value); } @@ -297,14 +297,14 @@ LogicalResult AffineApplyOp::verify() { // its operands are valid dimension ids. bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), - [](ValuePtr op) { return mlir::isValidDim(op); }); + [](Value op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. bool AffineApplyOp::isValidSymbol() { return llvm::all_of(getOperands(), - [](ValuePtr op) { return mlir::isValidSymbol(op); }); + [](Value op) { return mlir::isValidSymbol(op); }); } OpFoldResult AffineApplyOp::fold(ArrayRef operands) { @@ -324,8 +324,8 @@ OpFoldResult AffineApplyOp::fold(ArrayRef operands) { return result[0]; } -AffineDimExpr AffineApplyNormalizer::renumberOneDim(ValuePtr v) { - DenseMap::iterator iterPos; +AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) { + DenseMap::iterator iterPos; bool inserted = false; std::tie(iterPos, inserted) = dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); @@ -362,7 +362,7 @@ AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { // Gather the positions of the operands that are produced by an AffineApplyOp. static llvm::SetVector -indicesFromAffineApplyOp(ArrayRef operands) { +indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; for (auto en : llvm::enumerate(operands)) if (isa_and_nonnull(en.value()->getDefiningOp())) @@ -384,7 +384,7 @@ indicesFromAffineApplyOp(ArrayRef operands) { // results in better simplifications and foldings. But we should evaluate // whether this behavior is what we really want after using more. static AffineMap promoteComposedSymbolsAsDims(AffineMap map, - ArrayRef symbols) { + ArrayRef symbols) { if (symbols.empty()) { return map; } @@ -453,7 +453,7 @@ static AffineMap promoteComposedSymbolsAsDims(AffineMap map, /// benefit potentially big: simpler and more maintainable code for a /// non-trivial, recursive, procedure. 
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, - ArrayRef operands) + ArrayRef operands) : AffineApplyNormalizer() { static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); assert(map.getNumInputs() == operands.size() && @@ -509,7 +509,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, LLVM_DEBUG(affineApply.getOperation()->print( dbgs() << "\nCompose AffineApplyOp recursively: ")); AffineMap affineApplyMap = affineApply.getAffineMap(); - SmallVector affineApplyOperands( + SmallVector affineApplyOperands( affineApply.getOperands().begin(), affineApply.getOperands().end()); AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); @@ -560,8 +560,8 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, LLVM_DEBUG(dbgs() << "\n"); } -void AffineApplyNormalizer::normalize( - AffineMap *otherMap, SmallVectorImpl *otherOperands) { +void AffineApplyNormalizer::normalize(AffineMap *otherMap, + SmallVectorImpl *otherOperands) { AffineApplyNormalizer other(*otherMap, *otherOperands); *otherMap = renumber(other); @@ -575,7 +575,7 @@ void AffineApplyNormalizer::normalize( /// on `map` and `operands` without creating an AffineApplyOp that needs to be /// immediately deleted. static void composeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { AffineApplyNormalizer normalizer(*map, *operands); auto normalizedMap = normalizer.getAffineMap(); auto normalizedOperands = normalizer.getOperands(); @@ -585,9 +585,9 @@ static void composeAffineMapAndOperands(AffineMap *map, assert(*map); } -void mlir::fullyComposeAffineMapAndOperands( - AffineMap *map, SmallVectorImpl *operands) { - while (llvm::any_of(*operands, [](ValuePtr v) { +void mlir::fullyComposeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl *operands) { + while (llvm::any_of(*operands, [](Value v) { return isa_and_nonnull(v->getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); @@ -596,9 +596,9 @@ void mlir::fullyComposeAffineMapAndOperands( AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands) { + ArrayRef operands) { AffineMap normalizedMap = map; - SmallVector normalizedOperands(operands.begin(), operands.end()); + SmallVector normalizedOperands(operands.begin(), operands.end()); composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); assert(normalizedMap); return b.create(loc, normalizedMap, normalizedOperands); @@ -608,7 +608,7 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, // canonicalizes dims that are valid symbols into actual symbols. template static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { if (!mapOrSet || operands->empty()) return; @@ -616,9 +616,9 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, "map/set inputs must match number of operands"); auto *context = mapOrSet->getContext(); - SmallVector resultOperands; + SmallVector resultOperands; resultOperands.reserve(operands->size()); - SmallVector remappedSymbols; + SmallVector remappedSymbols; remappedSymbols.reserve(operands->size()); unsigned nextDim = 0; unsigned nextSym = 0; @@ -650,9 +650,8 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, // Works for either an affine map or an integer set. 
template -static void -canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, - SmallVectorImpl *operands) { +static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, + SmallVectorImpl *operands) { static_assert(std::is_same::value || std::is_same::value, "Argument must be either of AffineMap or IntegerSet type"); @@ -677,10 +676,10 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, auto *context = mapOrSet->getContext(); - SmallVector resultOperands; + SmallVector resultOperands; resultOperands.reserve(operands->size()); - llvm::SmallDenseMap seenDims; + llvm::SmallDenseMap seenDims; SmallVector dimRemapping(mapOrSet->getNumDims()); unsigned nextDim = 0; for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { @@ -696,7 +695,7 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, } } } - llvm::SmallDenseMap seenSymbols; + llvm::SmallDenseMap seenSymbols; SmallVector symRemapping(mapOrSet->getNumSymbols()); unsigned nextSym = 0; for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { @@ -729,12 +728,12 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, } void mlir::canonicalizeMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(map, operands); } void mlir::canonicalizeSetAndOperands(IntegerSet *set, - SmallVectorImpl *operands) { + SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(set, operands); } @@ -749,7 +748,7 @@ struct SimplifyAffineOp : public OpRewritePattern { /// Replace the affine op with another instance of it with the supplied /// map and mapOperands. void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, - AffineMap map, ArrayRef mapOperands) const; + AffineMap map, ArrayRef mapOperands) const; PatternMatchResult matchAndRewrite(AffineOpTy affineOp, PatternRewriter &rewriter) const override { @@ -761,7 +760,7 @@ struct SimplifyAffineOp : public OpRewritePattern { auto map = affineOp.getAffineMap(); AffineMap oldMap = map; auto oldOperands = affineOp.getMapOperands(); - SmallVector resultOperands(oldOperands); + SmallVector resultOperands(oldOperands); composeAffineMapAndOperands(&map, &resultOperands); if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), resultOperands.begin())) @@ -777,14 +776,14 @@ struct SimplifyAffineOp : public OpRewritePattern { template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(load, load.getMemRef(), map, mapOperands); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint().getZExtValue(), prefetch.isWrite(), @@ -793,14 +792,14 @@ void SimplifyAffineOp::replaceAffineOp( template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( store, store.getValueToStore(), store.getMemRef(), map, mapOperands); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map, - ArrayRef mapOperands) const { + ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(apply, map, mapOperands); } } // end anonymous namespace. 
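The hunks above and below migrate the affine canonicalization helpers and op builders from the ValuePtr alias to the value-semantic Value handle, which is why the signatures change but the function bodies stay untouched. The following is a minimal standalone sketch, in plain C++ with no MLIR headers, of the handle pattern these signatures rely on; the names ValueImpl, sumTypeIds, and typeId are illustrative stand-ins, not MLIR's actual classes.

#include <cassert>
#include <vector>

namespace sketch {

// Stand-in for the underlying IR object that the handle refers to.
struct ValueImpl {
  int typeId;
};

// A thin, copyable handle: pointer-sized, passed by value wherever a
// pointer alias (ValuePtr) used to appear.
class Value {
public:
  Value() : impl(nullptr) {}
  /*implicit*/ Value(ValueImpl *impl) : impl(impl) {}

  explicit operator bool() const { return impl != nullptr; }
  ValueImpl *operator->() const {
    assert(impl && "dereferencing a null handle");
    return impl;
  }
  bool operator==(Value other) const { return impl == other.impl; }

private:
  ValueImpl *impl; // the only state, so copies are cheap
};

// A function that previously took pointers can take the handles by value
// with no change to its body or its call sites.
int sumTypeIds(const std::vector<Value> &operands) {
  int sum = 0;
  for (Value v : operands)
    sum += v->typeId;
  return sum;
}

} // namespace sketch

int main() {
  sketch::ValueImpl a{1}, b{2};
  std::vector<sketch::Value> operands = {&a, &b};
  return sketch::sumTypeIds(operands) == 3 ? 0 : 1;
}
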
@@ -835,12 +834,12 @@ static LogicalResult foldMemRefCast(Operation *op) { // TODO(b/133776335) Check that map operands are loop IVs or symbols. void AffineDmaStartOp::build(Builder *builder, OperationState &result, - ValuePtr srcMemRef, AffineMap srcMap, - ValueRange srcIndices, ValuePtr destMemRef, + Value srcMemRef, AffineMap srcMap, + ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, - ValuePtr tagMemRef, AffineMap tagMap, - ValueRange tagIndices, ValuePtr numElements, - ValuePtr stride, ValuePtr elementsPerStride) { + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements, + Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); result.addOperands(srcIndices); @@ -1004,8 +1003,8 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef cstOperands, // TODO(b/133776335) Check that map operands are loop IVs or symbols. void AffineDmaWaitOp::build(Builder *builder, OperationState &result, - ValuePtr tagMemRef, AffineMap tagMap, - ValueRange tagIndices, ValuePtr numElements) { + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements) { result.addOperands(tagMemRef); result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); @@ -1014,7 +1013,7 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result, void AffineDmaWaitOp::print(OpAsmPrinter &p) { p << "affine.dma_wait " << *getTagMemRef() << '['; - SmallVector operands(getTagIndices()); + SmallVector operands(getTagIndices()); p.printAffineMapOfSSAIds(getTagMapAttr(), operands); p << "], "; p.printOperand(getNumElements()); @@ -1399,8 +1398,8 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) { /// Canonicalize the bounds of the given loop. static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { - SmallVector lbOperands(forOp.getLowerBoundOperands()); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); auto lbMap = forOp.getLowerBoundMap(); auto ubMap = forOp.getUpperBoundMap(); @@ -1465,7 +1464,7 @@ void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector newOperands(lbOperands.begin(), lbOperands.end()); + SmallVector newOperands(lbOperands.begin(), lbOperands.end()); auto ubOperands = getUpperBoundOperands(); newOperands.append(ubOperands.begin(), ubOperands.end()); @@ -1478,7 +1477,7 @@ void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); - SmallVector newOperands(getLowerBoundOperands()); + SmallVector newOperands(getLowerBoundOperands()); newOperands.append(ubOperands.begin(), ubOperands.end()); getOperation()->setOperands(newOperands); @@ -1544,7 +1543,7 @@ bool AffineForOp::matchingBoundOperandList() { unsigned numOperands = lbMap.getNumInputs(); for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare ValuePtr 's. + // Compare Value 's. 
if (getOperand(i) != getOperand(numOperands + i)) return false; } @@ -1553,7 +1552,7 @@ bool AffineForOp::matchingBoundOperandList() { Region &AffineForOp::getLoopBody() { return region(); } -bool AffineForOp::isDefinedOutsideOfLoop(ValuePtr value) { +bool AffineForOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value->getParentRegion()); } @@ -1564,13 +1563,13 @@ LogicalResult AffineForOp::moveOutOfLoop(ArrayRef ops) { } /// Returns if the provided value is the induction variable of a AffineForOp. -bool mlir::isForInductionVar(ValuePtr val) { +bool mlir::isForInductionVar(Value val) { return getForInductionVarOwner(val) != AffineForOp(); } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. -AffineForOp mlir::getForInductionVarOwner(ValuePtr val) { +AffineForOp mlir::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg || !ivArg->getOwner()) return AffineForOp(); @@ -1581,7 +1580,7 @@ AffineForOp mlir::getForInductionVarOwner(ValuePtr val) { /// Extracts the induction variables from a list of AffineForOps and returns /// them. void mlir::extractForInductionVars(ArrayRef forInsts, - SmallVectorImpl *ivs) { + SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) ivs->push_back(forInst.getInductionVar()); @@ -1720,7 +1719,7 @@ void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set, LogicalResult AffineIfOp::fold(ArrayRef, SmallVectorImpl &) { auto set = getIntegerSet(); - SmallVector operands(getOperands()); + SmallVector operands(getOperands()); canonicalizeSetAndOperands(&set, &operands); // Any canonicalization change always leads to either a reduction in the @@ -1749,9 +1748,8 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, result.types.push_back(memrefType.getElementType()); } -void AffineLoadOp::build(Builder *builder, OperationState &result, - ValuePtr memref, AffineMap map, - ValueRange mapOperands) { +void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref, + AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); @@ -1760,8 +1758,8 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, result.types.push_back(memrefType.getElementType()); } -void AffineLoadOp::build(Builder *builder, OperationState &result, - ValuePtr memref, ValueRange indices) { +void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref, + ValueRange indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () @@ -1843,7 +1841,7 @@ OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { //===----------------------------------------------------------------------===// void AffineStoreOp::build(Builder *builder, OperationState &result, - ValuePtr valueToStore, ValuePtr memref, AffineMap map, + Value valueToStore, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(valueToStore); @@ -1854,7 +1852,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result, // Use identity map. 
void AffineStoreOp::build(Builder *builder, OperationState &result, - ValuePtr valueToStore, ValuePtr memref, + Value valueToStore, Value memref, ValueRange indices) { auto memrefType = memref->getType().cast(); auto rank = memrefType.getRank(); @@ -2064,7 +2062,7 @@ void print(OpAsmPrinter &p, AffinePrefetchOp op) { p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '['; AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName()); if (mapAttr) { - SmallVector operands(op.getMapOperands()); + SmallVector operands(op.getMapOperands()); p.printAffineMapOfSSAIds(mapAttr, operands); } p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 725751eb6c1..df6015de1b9 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -37,9 +37,9 @@ struct LowerUniformCastsPass : public FunctionPass { // Dequantize //===----------------------------------------------------------------------===// -static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, - UniformQuantizedType elementType, - PatternRewriter &rewriter) { +static Value emitUniformPerLayerDequantize(Location loc, Value input, + UniformQuantizedType elementType, + PatternRewriter &rewriter) { // Pre-conditions. if (!elementType.isSigned()) { // TODO: Support unsigned storage type. @@ -62,7 +62,7 @@ static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, // Apply zero-point offset. if (elementType.getZeroPoint() != 0) { - ValuePtr negZeroPointConst = rewriter.create( + Value negZeroPointConst = rewriter.create( loc, broadcastScalarConstIntValue(intermediateType, -elementType.getZeroPoint())); input = rewriter.create(loc, input, negZeroPointConst); @@ -72,14 +72,14 @@ static ValuePtr emitUniformPerLayerDequantize(Location loc, ValuePtr input, input = rewriter.create(loc, realType, input); // Mul by scale. - ValuePtr scaleConst = rewriter.create( + Value scaleConst = rewriter.create( loc, broadcastScalarConstFloatValue(realType, APFloat(elementType.getScale()))); return rewriter.create(loc, input, scaleConst); } -static ValuePtr -emitUniformPerAxisDequantize(Location loc, ValuePtr input, +static Value +emitUniformPerAxisDequantize(Location loc, Value input, UniformQuantizedPerAxisType elementType, PatternRewriter &rewriter) { // TODO: Support per-axis dequantize. @@ -88,8 +88,8 @@ emitUniformPerAxisDequantize(Location loc, ValuePtr input, return nullptr; } -static ValuePtr emitDequantize(Location loc, ValuePtr input, - PatternRewriter &rewriter) { +static Value emitDequantize(Location loc, Value input, + PatternRewriter &rewriter) { Type inputType = input->getType(); QuantizedType qElementType = QuantizedType::getQuantizedElementType(inputType); @@ -124,7 +124,7 @@ struct UniformDequantizePattern : public OpRewritePattern { return matchFailure(); } - ValuePtr dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); + Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -161,14 +161,14 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. 
- ValuePtr lhsValue = rewriter - .create(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - ValuePtr rhsValue = rewriter - .create(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + Value lhsValue = rewriter + .create(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + Value rhsValue = rewriter + .create(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, @@ -177,7 +177,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, rhsValue); // Add. - ValuePtr resultValue = + Value resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Zero point offset adjustment. @@ -185,7 +185,7 @@ tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, // zpOffset = -zp int zpOffset = -1 * info.resultType.getZeroPoint(); if (zpOffset != 0) { - ValuePtr zpOffsetConst = rewriter.create( + Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, zpOffset)); resultValue = @@ -237,14 +237,14 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. - ValuePtr lhsValue = rewriter - .create(info.op->getLoc(), - info.lhsStorageType, info.lhs) - .getResult(); - ValuePtr rhsValue = rewriter - .create(info.op->getLoc(), - info.rhsStorageType, info.rhs) - .getResult(); + Value lhsValue = rewriter + .create(info.op->getLoc(), + info.lhsStorageType, info.lhs) + .getResult(); + Value rhsValue = rewriter + .create(info.op->getLoc(), + info.rhsStorageType, info.rhs) + .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, @@ -254,7 +254,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Apply argument zeroPoints. if (info.lhsType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create( + Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.lhsType.getZeroPoint())); lhsValue = @@ -262,7 +262,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } if (info.rhsType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create( + Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.rhsType.getZeroPoint())); rhsValue = @@ -270,7 +270,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, } // Mul. - ValuePtr resultValue = + Value resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Scale output. @@ -284,7 +284,7 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, // Zero point offset adjustment. 
if (info.resultType.getZeroPoint() != 0) { - ValuePtr zpOffsetConst = rewriter.create( + Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, info.resultType.getZeroPoint())); diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index bce5285a8b0..8cea97c693c 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -50,7 +50,7 @@ template bool integralLog2(F x, int &log2Result) { /// Helper class for operating on binary operations where all operands /// and the result are a UniformQuantizedType. struct UniformBinaryOpInfo { - UniformBinaryOpInfo(Operation *op, ValuePtr lhs, ValuePtr rhs, + UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs, Optional clampMin, Optional clampMax) : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax), lhsType(getUniformElementType(lhs->getType())), @@ -119,8 +119,8 @@ struct UniformBinaryOpInfo { } Operation *op; - ValuePtr lhs; - ValuePtr rhs; + Value lhs; + Value rhs; Optional clampMin; Optional clampMax; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 422597fe90d..bda8032fc21 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -204,15 +204,14 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { static SmallVector getValueTypes(ValueRange values) { SmallVector types; types.reserve(values.size()); - for (ValuePtr v : values) + for (Value v : values) types.push_back(v->getType()); return types; } -void LaunchOp::build(Builder *builder, OperationState &result, - ValuePtr gridSizeX, ValuePtr gridSizeY, ValuePtr gridSizeZ, - ValuePtr blockSizeX, ValuePtr blockSizeY, - ValuePtr blockSizeZ, ValueRange operands) { +void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX, + Value gridSizeY, Value gridSizeZ, Value blockSizeX, + Value blockSizeY, Value blockSizeZ, ValueRange operands) { // Add grid and block sizes as op operands, followed by the data operands. result.addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -519,10 +518,9 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState &result, - GPUFuncOp kernelFunc, ValuePtr gridSizeX, - ValuePtr gridSizeY, ValuePtr gridSizeZ, - ValuePtr blockSizeX, ValuePtr blockSizeY, - ValuePtr blockSizeZ, ValueRange kernelOperands) { + GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY, + Value gridSizeZ, Value blockSizeX, Value blockSizeY, + Value blockSizeZ, ValueRange kernelOperands) { // Add grid and block sizes as op operands, followed by the data operands. 
result.addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -555,7 +553,7 @@ StringRef LaunchFuncOp::getKernelModuleName() { .getRootReference(); } -ValuePtr LaunchFuncOp::getKernelOperand(unsigned i) { +Value LaunchFuncOp::getKernelOperand(unsigned i) { return getOperation()->getOperand(i + kNumConfigOperands); } @@ -718,14 +716,13 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { } static void printAttributions(OpAsmPrinter &p, StringRef keyword, - ArrayRef values) { + ArrayRef values) { if (values.empty()) return; p << ' ' << keyword << '('; - interleaveComma(values, p, [&p](BlockArgumentPtr v) { - p << *v << " : " << v->getType(); - }); + interleaveComma(values, p, + [&p](BlockArgument v) { p << *v << " : " << v->getType(); }); p << ')'; } @@ -772,9 +769,9 @@ LogicalResult GPUFuncOp::verifyType() { } static LogicalResult verifyAttributions(Operation *op, - ArrayRef attributions, + ArrayRef attributions, unsigned memorySpace) { - for (ValuePtr v : attributions) { + for (Value v : attributions) { auto type = v->getType().dyn_cast(); if (!type) return op->emitOpError() << "expected memref type in attribution"; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 6a7cd290dd2..2d00ac03d33 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -22,10 +22,10 @@ using namespace mlir; template static void createForAllDimensions(OpBuilder &builder, Location loc, - SmallVectorImpl &values) { + SmallVectorImpl &values) { for (StringRef dim : {"x", "y", "z"}) { - ValuePtr v = builder.create(loc, builder.getIndexType(), - builder.getStringAttr(dim)); + Value v = builder.create(loc, builder.getIndexType(), + builder.getStringAttr(dim)); values.push_back(v); } } @@ -37,7 +37,7 @@ static void injectGpuIndexOperations(Location loc, Region &body) { OpBuilder builder(loc->getContext()); Block &firstBlock = body.front(); builder.setInsertionPointToStart(&firstBlock); - SmallVector indexOps; + SmallVector indexOps; createForAllDimensions(builder, loc, indexOps); createForAllDimensions(builder, loc, indexOps); createForAllDimensions(builder, loc, indexOps); @@ -60,7 +60,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, gpu::LaunchFuncOp launch) { OpBuilder kernelBuilder(kernelFunc.getBody()); auto &firstBlock = kernelFunc.getBody().front(); - SmallVector newLaunchArgs; + SmallVector newLaunchArgs; BlockAndValueMapping map; for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) { map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i)); @@ -73,7 +73,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, } // Only inline operations that do not create new arguments. if (!llvm::all_of(operandOp->getOperands(), - [map](ValuePtr value) { return map.contains(value); })) { + [map](Value value) { return map.contains(value); })) { continue; } auto clone = kernelBuilder.clone(*operandOp, map); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index b8d2d242657..71b7064ac63 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -406,7 +406,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { // Expects vector to be of wrapped LLVM vector type and position to be of // wrapped LLVM i32 type. 
void LLVM::ExtractElementOp::build(Builder *b, OperationState &result, - ValuePtr vector, ValuePtr position, + Value vector, Value position, ArrayRef attrs) { auto wrappedVectorType = vector->getType().cast(); auto llvmType = wrappedVectorType.getVectorElementType(); @@ -672,7 +672,7 @@ static void printBrOp(OpAsmPrinter &p, BrOp &op) { // attribute-dict? static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) { Block *dest; - SmallVector operands; + SmallVector operands; if (parser.parseSuccessorAndUseList(dest, operands) || parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -699,8 +699,8 @@ static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) { static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) { Block *trueDest; Block *falseDest; - SmallVector trueOperands; - SmallVector falseOperands; + SmallVector trueOperands; + SmallVector falseOperands; OpAsmParser::OperandType condition; Builder &builder = parser.getBuilder(); @@ -1057,8 +1057,8 @@ static LogicalResult verify(GlobalOp op) { //===----------------------------------------------------------------------===// // Expects vector to be of wrapped LLVM vector type and position to be of // wrapped LLVM i32 type. -void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, - ValuePtr v1, ValuePtr v2, ArrayAttr mask, +void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1, + Value v2, ArrayAttr mask, ArrayRef attrs) { auto wrappedContainerType1 = v1->getType().cast(); auto vType = LLVMType::getVectorTy( @@ -1655,10 +1655,10 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { // Utility functions. //===----------------------------------------------------------------------===// -ValuePtr mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect) { +Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + LLVM::Linkage linkage, + LLVM::LLVMDialect *llvmDialect) { assert(builder.getInsertionBlock() && builder.getInsertionBlock()->getParentOp() && "expected builder to point to a block constrained in an op"); @@ -1675,13 +1675,13 @@ ValuePtr mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, builder.getStringAttr(value)); // Get the pointer to the first character in the global string. 
- ValuePtr globalPtr = builder.create(loc, global); - ValuePtr cst0 = builder.create( + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), builder.getIntegerAttr(builder.getIndexType(), 0)); - return builder.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, - ArrayRef({cst0, cst0})); + return builder.create(loc, + LLVM::LLVMType::getInt8PtrTy(llvmDialect), + globalPtr, ArrayRef({cst0, cst0})); } bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index be90b1ce5a6..e8667f07822 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -40,7 +40,7 @@ static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) { llvm_unreachable("Unexpected DependenceType"); } -ValuePtr Aliases::find(ValuePtr v) { +Value Aliases::find(Value v) { if (v.isa()) return v; @@ -185,14 +185,14 @@ LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, } SmallVector LinalgDependenceGraph::findCoveringWrites( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view) const { + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::WAW, DependenceType::WAR}); } SmallVector LinalgDependenceGraph::findCoveringReads( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view) const { + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::RAR, DependenceType::RAW}); @@ -200,7 +200,7 @@ SmallVector LinalgDependenceGraph::findCoveringReads( SmallVector LinalgDependenceGraph::findOperationsWithCoveringDependences( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, ValuePtr view, + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view, ArrayRef types) const { auto *src = srcLinalgOp.getOperation(); auto *dst = dstLinalgOp.getOperation(); diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index af5e576b290..37c63b74f14 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -35,8 +35,8 @@ static void getMaxDimIndex(ArrayRef structuredIndices, Operation *mlir::edsc::makeLinalgGenericOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - function_ref)> regionBuilder, - ArrayRef otherValues, ArrayRef otherAttributes) { + function_ref)> regionBuilder, + ArrayRef otherValues, ArrayRef otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); unsigned nInputs = inputs.size(); @@ -57,7 +57,7 @@ Operation *mlir::edsc::makeLinalgGenericOp( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs())); unsigned nViews = nInputs + nOutputs; - SmallVector values; + SmallVector values; values.reserve(nViews); values.append(inputs.begin(), inputs.end()); values.append(outputs.begin(), outputs.end()); @@ -100,7 +100,7 @@ Operation *mlir::edsc::makeLinalgGenericOp( return op; } -void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { +void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 3 && "expected 3 block arguments"); @@ -113,7 +113,7 @@ Operation 
*mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); - auto fun = [&unaryOp](ArrayRef args) { + auto fun = [&unaryOp](ArrayRef args) { assert(args.size() == 2 && "expected 2 block arguments"); ValueHandle a(args[0]); linalg_yield(unaryOp(a)); @@ -125,8 +125,7 @@ Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O) { ; using edsc::intrinsics::tanh; - UnaryPointwiseOpBuilder unOp( - [](ValueHandle a) -> ValuePtr { return tanh(a); }); + UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); }); return linalg_pointwise(unOp, I, O); } @@ -137,7 +136,7 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed O) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); - auto fun = [&binaryOp](ArrayRef args) { + auto fun = [&binaryOp](ArrayRef args) { assert(args.size() == 3 && "expected 3 block arguments"); ValueHandle a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); @@ -150,14 +149,14 @@ Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed O) { using edsc::op::operator+; BinaryPointwiseOpBuilder binOp( - [](ValueHandle a, ValueHandle b) -> ValuePtr { return a + b; }); + [](ValueHandle a, ValueHandle b) -> Value { return a + b; }); return linalg_pointwise(binOp, I1, I2, O); } Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O) { - BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> ValuePtr { + BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value { using edsc::intrinsics::select; using edsc::op::operator>; return select(a > b, a, b).getValue(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 10c37c0ec43..0f9f8f8d51f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -309,7 +309,7 @@ static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) { // SliceOp //===----------------------------------------------------------------------===// void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, - ValuePtr base, ValueRange indexings) { + Value base, ValueRange indexings) { result.addOperands(base); result.addOperands(indexings); @@ -385,7 +385,7 @@ static LogicalResult verify(SliceOp op) { // TransposeOp //===----------------------------------------------------------------------===// void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result, - ValuePtr view, AffineMapAttr permutation, + Value view, AffineMapAttr permutation, ArrayRef attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 27dcf663d23..9df7bce0879 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -68,16 +68,16 @@ static llvm::cl::list clTileSizes( static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { auto maps = loopToOperandRangesMaps(op); - SmallVector clonedViews; + SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. 
- SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "map: " << map << "\n"); - ValuePtr view = en.value(); + Value view = en.value(); SmallVector viewRanges(map.getNumResults()); for (auto en2 : llvm::enumerate(map.getResults())) { unsigned d = en2.index(); @@ -90,7 +90,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, } // Construct a new subview for the tile. unsigned rank = viewRanges.size(); - SmallVector offsets, sizes, strides; + SmallVector offsets, sizes, strides; offsets.reserve(rank); sizes.reserve(rank); strides.reserve(rank); @@ -108,7 +108,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, } struct ViewDimension { - ValuePtr view; + Value view; unsigned dimension; }; @@ -121,14 +121,14 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { auto maps = loopToOperandRangesMaps(op); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); - ValuePtr view = en.value(); - SmallVector viewRanges(map.getNumResults(), nullptr); + Value view = en.value(); + SmallVector viewRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { if (loopDepth == en2.value().cast().getPosition()) { LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth @@ -142,9 +142,9 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { llvm_unreachable("Expect to be able to extract a view defining loop range"); } -static LinalgOp fuse(ValuePtr producedView, LinalgOp producer, - LinalgOp consumer, unsigned consumerIdx, - unsigned producerIdx, OperationFolder *folder) { +static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, + unsigned consumerIdx, unsigned producerIdx, + OperationFolder *folder) { auto subView = dyn_cast_or_null( consumer.getInput(consumerIdx)->getDefiningOp()); auto slice = dyn_cast_or_null( @@ -196,8 +196,7 @@ static LinalgOp fuse(ValuePtr producedView, LinalgOp producer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. -static bool isStructurallyFusableProducer(LinalgOp producer, - ValuePtr consumedView, +static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); @@ -217,7 +216,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, LinalgOp consumer, - ValuePtr consumedView, + Value consumedView, LinalgOp producer) { // Make some simple structural checks that alleviate the need for more // complex analyses. 
@@ -236,7 +235,7 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, } bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, - LinalgOp consumer, ValuePtr consumedView, + LinalgOp consumer, Value consumedView, LinalgOp producer) { if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index 0f333791dd7..d7cc4a86d21 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -40,7 +40,7 @@ using edsc::op::operator==; static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals) { + ArrayRef vals) { assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); SmallVector res; @@ -48,35 +48,34 @@ makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, auto dims = map.getNumDims(); for (auto e : map.getResults()) { auto exprMap = AffineMap::get(dims, 0, e); - SmallVector operands(vals.begin(), vals.end()); + SmallVector operands(vals.begin(), vals.end()); canonicalizeMapAndOperands(&exprMap, &operands); res.push_back(affine_apply(exprMap, operands)); } return res; } -static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation) { +static SmallVector permuteIvs(ArrayRef ivs, + Optional permutation) { return permutation ? applyMapToValues(ScopedContext::getBuilder(), ScopedContext::getLocation(), permutation.getValue(), ivs) - : SmallVector(ivs.begin(), ivs.end()); + : SmallVector(ivs.begin(), ivs.end()); } // Creates a number of ranges equal to the number of results in `map`. // The returned ranges correspond to the loop ranges, in the proper order, for // which new loops will be created. -static SmallVector emitLoopRanges(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef allViewSizes); -SmallVector emitLoopRanges(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef allViewSizes) { +static SmallVector emitLoopRanges(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef allViewSizes); +SmallVector emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, + ArrayRef allViewSizes) { // Apply `map` to get view sizes in loop order. auto sizes = applyMapToValues(b, loc, map, allViewSizes); // Create a new range with the applied tile sizes. 
ScopedContext scope(b, loc); - SmallVector res; + SmallVector res; for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { res.push_back(range(constant_index(0), sizes[idx], constant_index(1))); } @@ -89,8 +88,7 @@ class LinalgScopedEmitter {}; template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, - CopyOp copyOp) { + static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -112,8 +110,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, - FillOp fillOp) { + static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = @@ -129,7 +126,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { + static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), @@ -142,7 +139,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); @@ -156,7 +153,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, MatmulOp matmulOp) { assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); @@ -170,8 +167,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, - ConvOp convOp) { + static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -220,14 +216,14 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, GenericOp genericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; unsigned nInputs = genericOp.getNumInputs(); unsigned nOutputs = genericOp.getNumOutputs(); - SmallVector indexedValues(nInputs + nOutputs); + SmallVector indexedValues(nInputs + nOutputs); // 1.a. Emit std_load from input views. 
for (unsigned i = 0; i < nInputs; ++i) { @@ -315,7 +311,7 @@ public: template class LinalgScopedEmitter { public: - static void emitScalarImplementation(ArrayRef allIvs, + static void emitScalarImplementation(ArrayRef allIvs, IndexedGenericOp indexedGenericOp) { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); @@ -323,7 +319,7 @@ public: unsigned nInputs = indexedGenericOp.getNumInputs(); unsigned nOutputs = indexedGenericOp.getNumOutputs(); unsigned nLoops = allIvs.size(); - SmallVector indexedValues(nLoops + nInputs + nOutputs); + SmallVector indexedValues(nLoops + nInputs + nOutputs); for (unsigned i = 0; i < nLoops; ++i) { indexedValues[i] = allIvs[i]; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 451803797f4..eb23a8ceb1a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -90,7 +90,7 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( } bool mlir::linalg::detail::isProducedByOpOfTypeImpl( - Operation *consumerOp, ValuePtr consumedView, + Operation *consumerOp, Value consumedView, function_ref isaOpType) { LinalgOp consumer = dyn_cast(consumerOp); if (!consumer) @@ -166,7 +166,7 @@ LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, return failure(); // TODO(ntv): non-identity layout. - auto isStaticMemRefWithIdentityLayout = [](ValuePtr v) { + auto isStaticMemRefWithIdentityLayout = [](Value v) { auto m = v->getType().dyn_cast(); if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) return false; @@ -226,7 +226,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op) { LinalgOp linOp = dyn_cast(op); - SetVector subViews; + SetVector subViews; for (auto it : linOp.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) subViews.insert(sv); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 08bc1518a19..b8b27958ff5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -46,15 +46,14 @@ static llvm::cl::opt clPromoteDynamic( llvm::cl::desc("Test generation of dynamic promoted buffers"), llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); -static ValuePtr allocBuffer(Type elementType, ValuePtr size, - bool dynamicBuffers) { +static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) { auto *ctx = size->getContext(); auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size->getDefiningOp())) return alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); - ValuePtr mul = muli(constant_index(width), size); + Value mul = muli(constant_index(width), size); return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); } @@ -84,14 +83,14 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, auto viewType = subView.getType(); auto rank = viewType.getRank(); - ValuePtr allocSize = one; - SmallVector fullRanges, partialRanges; + Value allocSize = one; + SmallVector fullRanges, partialRanges; fullRanges.reserve(rank); partialRanges.reserve(rank); for (auto en : llvm::enumerate(subView.getRanges())) { auto rank = en.index(); auto rangeValue = en.value(); - 
ValuePtr d = rangeValue.size; + Value d = rangeValue.size; allocSize = muli(folder, allocSize, d).getValue(); fullRanges.push_back(d); partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); @@ -107,7 +106,7 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, SmallVector mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, - ArrayRef subViews, bool dynamicBuffers, + ArrayRef subViews, bool dynamicBuffers, OperationFolder *folder) { if (subViews.empty()) return {}; @@ -115,7 +114,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, ScopedContext scope(b, loc); SmallVector res; res.reserve(subViews.size()); - DenseMap promotionInfoMap; + DenseMap promotionInfoMap; for (auto v : subViews) { SubViewOp subView = cast(v->getDefiningOp()); auto viewType = subView.getType(); @@ -136,7 +135,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, // TODO(ntv): value to fill with should be related to the operation. // For now, just use APFloat(0.0f). auto t = subView.getType().getElementType().cast(); - ValuePtr fillVal = constant_float(folder, APFloat(0.0f), t); + Value fillVal = constant_float(folder, APFloat(0.0f), t); // TODO(ntv): fill is only necessary if `promotionInfo` has a full local // view that is different from the partial local view and we are on the // boundary. @@ -153,16 +152,16 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, } LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, - SetVector subViews, + SetVector subViews, bool dynamicBuffers, OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); - SmallVector opViews; + SmallVector opViews; opViews.reserve(op.getNumInputsAndOutputs()); - SmallVector, 8> writebackViews; + SmallVector, 8> writebackViews; writebackViews.reserve(subViews.size()); unsigned promotedIdx = 0; for (auto view : op.getInputsAndOutputs()) { @@ -206,7 +205,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. - SetVector subViews; + SetVector subViews; OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 99645a23100..964f540c099 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -44,7 +44,7 @@ static llvm::cl::list llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)); -static bool isZero(ValuePtr v) { +static bool isZero(Value v) { return isa_and_nonnull(v->getDefiningOp()) && cast(v->getDefiningOp()).getValue() == 0; } @@ -62,12 +62,12 @@ using LoopIndexToRangeIndexMap = DenseMap; // indices of newly created loops. static std::tuple, LoopIndexToRangeIndexMap> makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, - ArrayRef allTileSizes, OperationFolder *folder) { + ArrayRef allViewSizes, ArrayRef allTileSizes, + OperationFolder *folder) { assert(allTileSizes.size() == map.getNumResults()); // Apply `map` to get view sizes in loop order. 
auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder); - SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); + SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); // Traverse the tile sizes, which are in loop order, erase zeros everywhere. LoopIndexToRangeIndexMap loopIndexToRangeIndex; @@ -101,8 +101,7 @@ namespace { // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] // struct TileCheck : public AffineExprVisitor { - TileCheck(ArrayRef tileSizes) - : isTiled(false), tileSizes(tileSizes) {} + TileCheck(ArrayRef tileSizes) : isTiled(false), tileSizes(tileSizes) {} void visitDimExpr(AffineDimExpr expr) { isTiled |= !isZero(tileSizes[expr.getPosition()]); @@ -115,7 +114,7 @@ struct TileCheck : public AffineExprVisitor { "nonpositive multiplying coefficient"); } bool isTiled; - ArrayRef tileSizes; + ArrayRef tileSizes; }; } // namespace @@ -197,11 +196,11 @@ void transformIndexedGenericOpIndices( auto rangeIndex = loopIndexToRangeIndex.find(i); if (rangeIndex == loopIndexToRangeIndex.end()) continue; - ValuePtr oldIndex = block.getArgument(i); + Value oldIndex = block.getArgument(i); // Offset the index argument `i` by the value of the corresponding induction // variable and replace all uses of the previous value. - ValuePtr newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, - pivs[rangeIndex->second]->getValue()); + Value newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, + pivs[rangeIndex->second]->getValue()); for (auto &use : oldIndex->getUses()) { if (use.getOwner() == newIndex->getDefiningOp()) continue; @@ -210,7 +209,7 @@ void transformIndexedGenericOpIndices( } } -static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { +static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { if (!expr) return false; TileCheck t(tileSizes); @@ -220,7 +219,7 @@ static bool isTiled(AffineExpr expr, ArrayRef tileSizes) { // Checks whether the view with index `viewIndex` within `linalgOp` varies with // respect to a non-zero `tileSize`. -static bool isTiled(AffineMap map, ArrayRef tileSizes) { +static bool isTiled(AffineMap map, ArrayRef tileSizes) { if (!map) return false; for (unsigned r = 0; r < map.getNumResults(); ++r) @@ -229,13 +228,13 @@ static bool isTiled(AffineMap map, ArrayRef tileSizes) { return false; } -static SmallVector +static SmallVector makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, - ArrayRef ivs, ArrayRef tileSizes, - ArrayRef viewSizes, OperationFolder *folder) { + ArrayRef ivs, ArrayRef tileSizes, + ArrayRef viewSizes, OperationFolder *folder) { assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](ValuePtr v) { return !isZero(v); })) && + [](Value v) { return !isZero(v); })) && "expected as many ivs as non-zero sizes"); using edsc::intrinsics::select; @@ -244,22 +243,21 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // Construct (potentially temporary) mins and maxes on which to apply maps // that define tile subviews. - SmallVector lbs, subViewSizes; + SmallVector lbs, subViewSizes; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { bool isTiled = !isZero(tileSizes[idx]); - lbs.push_back(isTiled ? ivs[idxIvs++] - : (ValuePtr)constant_index(folder, 0)); + lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)constant_index(folder, 0)); subViewSizes.push_back(isTiled ? 
tileSizes[idx] : viewSizes[idx]); } auto *op = linalgOp.getOperation(); - SmallVector res; + SmallVector res; res.reserve(op->getNumOperands()); auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { - ValuePtr view = *(viewIteratorBegin + viewIndex); + Value view = *(viewIteratorBegin + viewIndex); unsigned rank = view->getType().cast().getRank(); auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; // If the view is not tiled, we can use it as is. @@ -269,7 +267,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } // Construct a new subview for the tile. - SmallVector offsets, sizes, strides; + SmallVector offsets, sizes, strides; offsets.reserve(rank); sizes.reserve(rank); strides.reserve(rank); @@ -300,16 +298,17 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // This is a special type of folding that we only apply when `folder` is // defined. if (folder) - for (auto v : llvm::concat(lbs, subViewSizes)) + for (auto v : llvm::concat(lbs, subViewSizes)) if (v->use_empty()) v->getDefiningOp()->erase(); return res; } -Optional mlir::linalg::tileLinalgOp( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { +Optional +mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, + OperationFolder *folder) { // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. @@ -352,7 +351,7 @@ Optional mlir::linalg::tileLinalgOp( LoopNestRangeBuilder(pivs, loopRanges)([&] { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - SmallVector ivValues(ivs.begin(), ivs.end()); + SmallVector ivValues(ivs.begin(), ivs.end()); // If we have to apply a permutation to the tiled loop nest, we have to // reorder the induction variables This permutation is the right one @@ -403,7 +402,7 @@ Optional mlir::linalg::tileLinalgOp( ScopedContext scope(b, op.getLoc()); // Materialize concrete tile size values to pass the generic tiling function. 
- SmallVector tileSizeValues; + SmallVector tileSizeValues; tileSizeValues.reserve(tileSizes.size()); for (auto ts : tileSizes) tileSizeValues.push_back(constant_index(folder, ts)); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index ae02af0ecc8..560a0235a38 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -83,7 +83,7 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) + ArrayRef ivs, ArrayRef ranges) : LoopNestRangeBuilder( ivs, SmallVector(ranges.begin(), ranges.end())) {} @@ -97,22 +97,22 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( return ValueHandle::null(); } -static ValuePtr emitOrFoldComposedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef operandsRef, - OperationFolder *folder) { - SmallVector operands(operandsRef.begin(), operandsRef.end()); +static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef operandsRef, + OperationFolder *folder) { + SmallVector operands(operandsRef.begin(), operandsRef.end()); fullyComposeAffineMapAndOperands(&map, &operands); canonicalizeMapAndOperands(&map, &operands); return folder ? folder->create(b, loc, map, operands) : b.create(loc, map, operands); } -SmallVector -mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map, - ArrayRef values, - OperationFolder *folder) { - SmallVector res; +SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef values, + OperationFolder *folder) { + SmallVector res; res.reserve(map.getNumResults()); unsigned numDims = map.getNumDims(); // For each `expr` in `map`, applies the `expr` to the values extracted from @@ -128,12 +128,12 @@ mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map, /// Returns all the operands of `linalgOp` that are not views. /// Asserts that these operands are value types to allow transformations like /// tiling to just use the values when cloning `linalgOp`. 
-SmallVector +SmallVector mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) { auto *op = linalgOp.getOperation(); unsigned numViews = linalgOp.getNumInputsAndOutputs(); unsigned nOperands = op->getNumOperands() - numViews; - SmallVector res; + SmallVector res; res.reserve(nOperands); for (unsigned i = 0; i < nOperands; ++i) { res.push_back(op->getOperand(numViews + i)); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp index 8e19eba911a..acbab01df79 100644 --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -60,8 +60,8 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context) // ForOp //===----------------------------------------------------------------------===// -void ForOp::build(Builder *builder, OperationState &result, ValuePtr lb, - ValuePtr ub, ValuePtr step) { +void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub, + Value step) { result.addOperands({lb, ub, step}); Region *bodyRegion = result.addRegion(); ForOp::ensureTerminator(*bodyRegion, *builder, result.location); @@ -125,7 +125,7 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) { Region &ForOp::getLoopBody() { return region(); } -bool ForOp::isDefinedOutsideOfLoop(ValuePtr value) { +bool ForOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value->getParentRegion()); } @@ -135,7 +135,7 @@ LogicalResult ForOp::moveOutOfLoop(ArrayRef ops) { return success(); } -ForOp mlir::loop::getForInductionVarOwner(ValuePtr val) { +ForOp mlir::loop::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg) return ForOp(); @@ -148,7 +148,7 @@ ForOp mlir::loop::getForInductionVarOwner(ValuePtr val) { // IfOp //===----------------------------------------------------------------------===// -void IfOp::build(Builder *builder, OperationState &result, ValuePtr cond, +void IfOp::build(Builder *builder, OperationState &result, Value cond, bool withElseRegion) { result.addOperands(cond); Region *thenRegion = result.addRegion(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 4416e1e6b04..144252bb272 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -94,7 +94,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only spv.ReturnValue needs to be handled here. auto retValOp = dyn_cast(op); if (!retValOp) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 7b6c013f9ed..0d2348c2626 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -220,9 +220,9 @@ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, /// Gets the global variable associated with a builtin and add /// it if it doesn't exist. 
-ValuePtr mlir::spirv::getBuiltinVariableValue(Operation *op, - spirv::BuiltIn builtin, - OpBuilder &builder) { +Value mlir::spirv::getBuiltinVariableValue(Operation *op, + spirv::BuiltIn builtin, + OpBuilder &builder) { auto moduleOp = op->getParentOfType(); if (!moduleOp) { op->emitError("expected operation to be within a SPIR-V module"); @@ -230,7 +230,7 @@ ValuePtr mlir::spirv::getBuiltinVariableValue(Operation *op, } spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder); - ValuePtr ptr = builder.create(op->getLoc(), varOp); + Value ptr = builder.create(op->getLoc(), varOp); return builder.create(op->getLoc(), ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index e42dc10f55d..f42c077f77e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -264,8 +264,8 @@ static LogicalResult verifyMemorySemantics(BarrierOp op) { } template -static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, - ValuePtr ptr, ValuePtr val) { +static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, + Value val) { // ODS already checks ptr is spirv::PointerType. Just check that the pointee // type of the pointer and the type of the value are the same // @@ -655,8 +655,8 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { } static void printShiftOp(Operation *op, OpAsmPrinter &printer) { - ValuePtr base = op->getOperand(0); - ValuePtr shift = op->getOperand(1); + Value base = op->getOperand(0); + Value shift = op->getOperand(1); printer << op->getName() << ' ' << *base << ", " << *shift << " : " << base->getType() << ", " << shift->getType(); } @@ -733,7 +733,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { } void spirv::AccessChainOp::build(Builder *builder, OperationState &state, - ValuePtr basePtr, ValueRange indices) { + Value basePtr, ValueRange indices) { auto type = getElementPtrType(basePtr->getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); @@ -773,8 +773,8 @@ static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { - SmallVector indices(accessChainOp.indices().begin(), - accessChainOp.indices().end()); + SmallVector indices(accessChainOp.indices().begin(), + accessChainOp.indices().end()); auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(), indices, accessChainOp.getLoc()); if (!resultType) { @@ -815,7 +815,7 @@ struct CombineChainedAccessChain } // Combine indices. 
- SmallVector indices(parentAccessChainOp.indices()); + SmallVector indices(parentAccessChainOp.indices()); indices.append(accessChainOp.indices().begin(), accessChainOp.indices().end()); @@ -1051,7 +1051,7 @@ static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) { static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) { Block *dest; - SmallVector destOperands; + SmallVector destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); state.addSuccessor(dest, destOperands); @@ -1080,7 +1080,7 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, auto &builder = parser.getBuilder(); OpAsmParser::OperandType condInfo; Block *dest; - SmallVector destOperands; + SmallVector destOperands; // Parse the condition. Type boolTy = builder.getI1Type(); @@ -1205,7 +1205,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { auto cType = compositeConstructOp.getType().cast(); - SmallVector constituents(compositeConstructOp.constituents()); + SmallVector constituents(compositeConstructOp.constituents()); if (constituents.size() != cType.getNumElements()) { return compositeConstructOp.emitError( "has incorrect number of operands: expected ") @@ -1230,7 +1230,7 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { //===----------------------------------------------------------------------===// void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state, - ValuePtr composite, + Value composite, ArrayRef indices) { auto indexAttr = builder->getI32ArrayAttr(indices); auto elementType = @@ -1954,7 +1954,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// void spirv::LoadOp::build(Builder *builder, OperationState &state, - ValuePtr basePtr, IntegerAttr memory_access, + Value basePtr, IntegerAttr memory_access, IntegerAttr alignment) { auto ptrType = basePtr->getType().cast(); build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, @@ -2487,9 +2487,8 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { // spv.Select //===----------------------------------------------------------------------===// -void spirv::SelectOp::build(Builder *builder, OperationState &state, - ValuePtr cond, ValuePtr trueValue, - ValuePtr falseValue) { +void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond, + Value trueValue, Value falseValue) { build(builder, state, trueValue->getType(), cond, trueValue, falseValue); } @@ -2739,13 +2738,13 @@ private: } // Returns a soruce value for the given block. - ValuePtr getSrcValue(Block *block) const { + Value getSrcValue(Block *block) const { auto storeOp = cast(block->front()); return storeOp.value(); } // Returns a destination value for the given block. 
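// Illustrative sketch (hypothetical, not part of the patch): with the build
// methods above taking `Value` directly, creating such an op from a lowering
// looks like the fragment below; `rewriter`, `loc`, `cond`, `a`, and `b` are
// assumed to be in scope:
//
//   auto sel = rewriter.create<spirv::SelectOp>(loc, cond, a, b);
//
// which forwards to spirv::SelectOp::build(builder, state, cond, a, b).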
- ValuePtr getDstPtr(Block *block) const { + Value getDstPtr(Block *block) const { auto storeOp = cast(block->front()); return storeOp.ptr(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 9e820c6f42b..17ddc48573a 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -318,7 +318,7 @@ private: /// This method materializes normal constants and inserts "casting" ops /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA /// value for handling uses of module scope constants/variables in functions. - ValuePtr getValue(uint32_t id); + Value getValue(uint32_t id); /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if @@ -437,7 +437,7 @@ private: DenseMap blockPhiInfo; // Result to value mapping. - DenseMap valueMap; + DenseMap valueMap; // Mapping from result to undef value of a type. DenseMap undefMap; @@ -1522,8 +1522,8 @@ Deserializer::processBranchConditional(ArrayRef operands) { opBuilder.create( unknownLoc, condition, trueBlock, - /*trueArguments=*/ArrayRef(), falseBlock, - /*falseArguments=*/ArrayRef(), weights); + /*trueArguments=*/ArrayRef(), falseBlock, + /*falseArguments=*/ArrayRef(), weights); return success(); } @@ -1617,7 +1617,7 @@ LogicalResult Deserializer::processPhi(ArrayRef operands) { // Create a block argument for this OpPhi instruction. Type blockArgType = getType(operands[0]); - BlockArgumentPtr blockArg = curBlock->addArgument(blockArgType); + BlockArgument blockArg = curBlock->addArgument(blockArgType); valueMap[operands[1]] = blockArg; LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg << " id = " << operands[1] << " of type " @@ -1774,7 +1774,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock << " from block " << block << "\n"); if (!isFnEntryBlock(block)) { - for (BlockArgumentPtr blockArg : block->getArguments()) { + for (BlockArgument blockArg : block->getArguments()) { auto newArg = newBlock->addArgument(blockArg->getType()); mapper.map(blockArg, newArg); LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg @@ -1815,13 +1815,13 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // we place the selection/loop op inside the old merge block, we need to // make sure the old merge block has the same block argument list. assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); - for (BlockArgumentPtr blockArg : headerBlock->getArguments()) { + for (BlockArgument blockArg : headerBlock->getArguments()) { mergeBlock->addArgument(blockArg->getType()); } // If the loop header block has block arguments, make sure the spv.branch op // matches. - SmallVector blockArgs; + SmallVector blockArgs; if (!headerBlock->args_empty()) blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; @@ -1829,7 +1829,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // loop header block. 
builder.setInsertionPointToEnd(&body.front()); builder.create(location, mapper.lookupOrNull(headerBlock), - ArrayRef(blockArgs)); + ArrayRef(blockArgs)); } // All the blocks cloned into the SelectionOp/LoopOp's region can now be @@ -1915,10 +1915,10 @@ LogicalResult Deserializer::wireUpBlockArgument() { auto *op = block->getTerminator(); opBuilder.setInsertionPoint(op); - SmallVector blockArgs; + SmallVector blockArgs; blockArgs.reserve(phiInfo.size()); for (uint32_t valueId : phiInfo) { - if (ValuePtr value = getValue(valueId)) { + if (Value value = getValue(valueId)) { blockArgs.push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value << " id = " << valueId << '\n'); @@ -1987,7 +1987,7 @@ LogicalResult Deserializer::structurizeControlFlow() { // Instruction //===----------------------------------------------------------------------===// -ValuePtr Deserializer::getValue(uint32_t id) { +Value Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spv.constant` op at every use site. return opBuilder.create(unknownLoc, constInfo->second, @@ -2183,7 +2183,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef words) { } } valueID = words[wordIndex++]; - SmallVector operands; + SmallVector operands; SmallVector attributes; if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); @@ -2357,7 +2357,7 @@ Deserializer::processOp(ArrayRef operands) { auto functionName = getFunctionSymbol(functionID); - SmallVector arguments; + SmallVector arguments; for (auto operand : llvm::drop_begin(operands, 3)) { auto value = getValue(operand); if (!value) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 424c2e0427e..0cdcc25b77d 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -314,7 +314,7 @@ private: uint32_t opcode, ArrayRef operands); - uint32_t getValueID(ValuePtr val) const { return valueIDMap.lookup(val); } + uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); @@ -405,7 +405,7 @@ private: DenseMap undefValIDMap; /// Map from results of normal operations to their s. - DenseMap valueIDMap; + DenseMap valueIDMap; /// Map from extended instruction set name to s. llvm::StringMap extendedInstSetIDMap; @@ -448,7 +448,7 @@ private: /// placed inside `functions`) here. And then after emitting all blocks, we /// replace the dummy 0 with the real result by overwriting /// `functions[offset]`. - DenseMap> deferredPhiValues; + DenseMap> deferredPhiValues; }; } // namespace @@ -504,7 +504,7 @@ void Serializer::collect(SmallVectorImpl &binary) { void Serializer::printValueIDMap(raw_ostream &os) { os << "\n= Value Map =\n\n"; for (auto valueIDPair : valueIDMap) { - ValuePtr val = valueIDPair.first; + Value val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; if (auto *op = val->getDefiningOp()) { @@ -743,7 +743,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { // There might be OpPhi instructions who have value references needing to fix. 
for (auto deferredValue : deferredPhiValues) { - ValuePtr value = deferredValue.first; + Value value = deferredValue.first; uint32_t id = getValueID(value); LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value << " to id = " << id << '\n'); @@ -1393,7 +1393,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // Then create OpPhi instruction for each of the block argument. for (auto argIndex : llvm::seq(0, block->getNumArguments())) { - BlockArgumentPtr arg = block->getArgument(argIndex); + BlockArgument arg = block->getArgument(argIndex); // Get the type and result for this OpPhi instruction. uint32_t phiTypeID = 0; @@ -1409,7 +1409,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { phiArgs.push_back(phiID); for (auto predIndex : llvm::seq(0, predecessors.size())) { - ValuePtr value = *(predecessors[predIndex].second + argIndex); + Value value = *(predecessors[predIndex].second + argIndex); uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 0be24bf169c..d7194da0778 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -131,7 +131,7 @@ class FuncOpLowering final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -144,7 +144,7 @@ private: } // namespace PatternMatchResult -FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!funcOp.getAttrOfType( spirv::getEntryPointABIAttrName())) { @@ -174,7 +174,7 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, OpBuilder::InsertionGuard funcInsertionGuard(rewriter); rewriter.setInsertionPointToStart(&funcOp.front()); // Insert spirv::AddressOf and spirv::AccessChain operations. - ValuePtr replacement = + Value replacement = rewriter.create(funcOp.getLoc(), var); // Check if the arg is a scalar or vector type. In that case, the value // needs to be loaded into registers. diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 55da59a0c74..831c78a4521 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -72,7 +72,7 @@ struct StdInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only "std.return" needs to be handled here. auto returnOp = cast(op); @@ -175,7 +175,7 @@ void mlir::printDimAndSymbolList(Operation::operand_iterator begin, // dimension operands parsed. // Returns 'false' on success and 'true' on error. 
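// Illustrative sketch (hypothetical, not part of the patch): a custom-op
// parser using the updated signature below collects its dim/symbol operands
// directly into a vector of `Value`; `parser` is assumed to be in scope and
// the inline capacity of 8 is an arbitrary choice for the sketch:
//
//   SmallVector<Value, 8> operands;   // previously a vector of ValuePtr
//   unsigned numDims;
//   if (mlir::parseDimAndSymbolList(parser, operands, numDims))
//     return failure();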
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, + SmallVectorImpl &operands, unsigned &numDims) { SmallVector opInfos; if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) @@ -316,7 +316,7 @@ struct SimplifyAllocConst : public OpRewritePattern { PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. - if (llvm::none_of(alloc.getOperands(), [](ValuePtr operand) { + if (llvm::none_of(alloc.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); @@ -327,8 +327,8 @@ struct SimplifyAllocConst : public OpRewritePattern { // and keep track of the resultant memref type to build. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); - SmallVector newOperands; - SmallVector droppedOperands; + SmallVector newOperands; + SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { @@ -420,7 +420,7 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { Block *dest; - SmallVector destOperands; + SmallVector destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); @@ -614,7 +614,7 @@ static Type getI1SameShape(Builder *build, Type type) { //===----------------------------------------------------------------------===// static void buildCmpIOp(Builder *build, OperationState &result, - CmpIPredicate predicate, ValuePtr lhs, ValuePtr rhs) { + CmpIPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( @@ -768,7 +768,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { } static void buildCmpFOp(Builder *build, OperationState &result, - CmpFPredicate predicate, ValuePtr lhs, ValuePtr rhs) { + CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( @@ -937,7 +937,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern { static ParseResult parseCondBranchOp(OpAsmParser &parser, OperationState &result) { - SmallVector destOperands; + SmallVector destOperands; Block *dest; OpAsmParser::OperandType condInfo; @@ -1079,7 +1079,7 @@ OpFoldResult ConstantOp::fold(ArrayRef operands) { } void ConstantOp::getAsmResultNames( - function_ref setNameFn) { + function_ref setNameFn) { Type type = getType(); if (auto intCst = getValue().dyn_cast()) { IntegerType intTy = type.dyn_cast(); @@ -1174,7 +1174,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern { PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. 
- ValuePtr memref = dealloc.memref(); + Value memref = dealloc.memref(); if (!isa_and_nonnull(memref->getDefiningOp())) return matchFailure(); @@ -1353,11 +1353,10 @@ OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState &result, - ValuePtr srcMemRef, ValueRange srcIndices, - ValuePtr destMemRef, ValueRange destIndices, - ValuePtr numElements, ValuePtr tagMemRef, - ValueRange tagIndices, ValuePtr stride, - ValuePtr elementsPerStride) { + Value srcMemRef, ValueRange srcIndices, Value destMemRef, + ValueRange destIndices, Value numElements, + Value tagMemRef, ValueRange tagIndices, Value stride, + Value elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); @@ -1497,9 +1496,8 @@ LogicalResult DmaStartOp::fold(ArrayRef cstOperands, // DmaWaitOp // --------------------------------------------------------------------------- -void DmaWaitOp::build(Builder *builder, OperationState &result, - ValuePtr tagMemRef, ValueRange tagIndices, - ValuePtr numElements) { +void DmaWaitOp::build(Builder *builder, OperationState &result, Value tagMemRef, + ValueRange tagIndices, Value numElements) { result.addOperands(tagMemRef); result.addOperands(tagIndices); result.addOperands(numElements); @@ -2356,7 +2354,7 @@ static void print(OpAsmPrinter &p, ViewOp op) { p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } -ValuePtr ViewOp::getDynamicOffset() { +Value ViewOp::getDynamicOffset() { int64_t offset; SmallVector strides; auto result = @@ -2431,7 +2429,7 @@ struct ViewOpShapeFolder : public OpRewritePattern { PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { // Return if none of the operands are constants. - if (llvm::none_of(viewOp.getOperands(), [](ValuePtr operand) { + if (llvm::none_of(viewOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); @@ -2448,8 +2446,8 @@ struct ViewOpShapeFolder : public OpRewritePattern { if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return matchFailure(); - SmallVector newOperands; - SmallVector droppedOperands; + SmallVector newOperands; + SmallVector droppedOperands; // Fold dynamic offset operand if it is produced by a constant. auto dynamicOffset = viewOp.getDynamicOffset(); @@ -2567,7 +2565,7 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, ValuePtr source, +void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, Type resultType, ArrayRef attrs) { @@ -2581,7 +2579,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, ValuePtr source, } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - ValuePtr source) { + Value source) { build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, resultType); } @@ -2817,7 +2815,7 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. 
if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.sizes(), [](ValuePtr operand) { + llvm::any_of(subViewOp.sizes(), [](Value operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); @@ -2833,7 +2831,7 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - ArrayRef(), subViewOp.strides(), newMemRefType); + ArrayRef(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType()); @@ -2862,7 +2860,7 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.strides(), [](ValuePtr stride) { + llvm::any_of(subViewOp.strides(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); @@ -2883,7 +2881,7 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - subViewOp.sizes(), ArrayRef(), newMemRefType); + subViewOp.sizes(), ArrayRef(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType()); @@ -2913,7 +2911,7 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.offsets(), [](ValuePtr stride) { + llvm::any_of(subViewOp.offsets(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); @@ -2934,7 +2932,7 @@ public: MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), ArrayRef(), + subViewOp.getLoc(), subViewOp.source(), ArrayRef(), subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. 
rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 8ceff014029..a3904ef97a2 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -63,7 +63,7 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, //===----------------------------------------------------------------------===// void vector::ContractionOp::build(Builder *builder, OperationState &result, - ValuePtr lhs, ValuePtr rhs, ValuePtr acc, + Value lhs, Value rhs, Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes) { result.addOperands({lhs, rhs, acc}); @@ -395,7 +395,7 @@ static Type inferExtractOpResultType(VectorType vectorType, } void vector::ExtractOp::build(Builder *builder, OperationState &result, - ValuePtr source, ArrayRef position) { + Value source, ArrayRef position) { result.addOperands(source); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(inferExtractOpResultType(source->getType().cast(), @@ -462,7 +462,7 @@ static LogicalResult verify(vector::ExtractOp op) { //===----------------------------------------------------------------------===// void ExtractSlicesOp::build(Builder *builder, OperationState &result, - TupleType tupleType, ValuePtr vector, + TupleType tupleType, Value vector, ArrayRef sizes, ArrayRef strides) { result.addOperands(vector); @@ -638,8 +638,8 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, // ShuffleOp //===----------------------------------------------------------------------===// -void ShuffleOp::build(Builder *builder, OperationState &result, ValuePtr v1, - ValuePtr v2, ArrayRef mask) { +void ShuffleOp::build(Builder *builder, OperationState &result, Value v1, + Value v2, ArrayRef mask) { result.addOperands({v1, v2}); auto maskAttr = getVectorSubscriptAttr(*builder, mask); result.addTypes(v1->getType()); @@ -762,8 +762,8 @@ static LogicalResult verify(InsertElementOp op) { // InsertOp //===----------------------------------------------------------------------===// -void InsertOp::build(Builder *builder, OperationState &result, ValuePtr source, - ValuePtr dest, ArrayRef position) { +void InsertOp::build(Builder *builder, OperationState &result, Value source, + Value dest, ArrayRef position) { result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(dest->getType()); @@ -884,7 +884,7 @@ void InsertSlicesOp::getStrides(SmallVectorImpl &results) { //===----------------------------------------------------------------------===// void InsertStridedSliceOp::build(Builder *builder, OperationState &result, - ValuePtr source, ValuePtr dest, + Value source, Value dest, ArrayRef offsets, ArrayRef strides) { result.addOperands({source, dest}); @@ -1192,7 +1192,7 @@ static LogicalResult verify(ReshapeOp op) { // If all shape operands are produced by constant ops, verify that product // of dimensions for input/output shape match. 
- auto isDefByConstant = [](ValuePtr operand) { + auto isDefByConstant = [](Value operand) { return isa_and_nonnull(operand->getDefiningOp()); }; if (llvm::all_of(op.input_shape(), isDefByConstant) && @@ -1238,7 +1238,7 @@ static Type inferStridedSliceOpResultType(VectorType vectorType, } void StridedSliceOp::build(Builder *builder, OperationState &result, - ValuePtr source, ArrayRef offsets, + Value source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { result.addOperands(source); auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets); @@ -1593,8 +1593,7 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) { return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType())); } -void TypeCastOp::build(Builder *builder, OperationState &result, - ValuePtr source) { +void TypeCastOp::build(Builder *builder, OperationState &result, Value source) { result.addOperands(source); result.addTypes( inferVectorTypeCastResultType(source->getType().cast())); @@ -1784,7 +1783,7 @@ public: PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. - auto is_not_def_by_constant = [](ValuePtr operand) { + auto is_not_def_by_constant = [](Value operand) { return !isa_and_nonnull(operand->getDefiningOp()); }; if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant)) diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 927aeda4ecd..28b803f7cde 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -97,17 +97,17 @@ static SmallVector delinearize(int64_t linearIndex, // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder, Location loc, Operation *op, - ArrayRef operands, + ArrayRef operands, ArrayRef resultTypes) { OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, op->getAttrs()); return builder.createOperation(res); } -static ValuePtr makeSplatZero(Location loc, PatternRewriter &rewriter, - VectorType vt) { +static Value makeSplatZero(Location loc, PatternRewriter &rewriter, + VectorType vt) { auto t = vt.getElementType(); - ValuePtr f = nullptr; + Value f = nullptr; if (t.isBF16() || t.isF16()) f = rewriter.create(loc, t, rewriter.getF64FloatAttr(0.0f)); else if (t.isF32()) @@ -181,12 +181,12 @@ struct UnrolledVectorState { SmallVector unrollFactors; SmallVector basis; int64_t numInstances; - ValuePtr slicesTuple; + Value slicesTuple; }; // Populates 'state' with unrolled shape, unroll factors, basis and // num unrolled instances for 'vectorType'. -static void initUnrolledVectorState(VectorType vectorType, ValuePtr initValue, +static void initUnrolledVectorState(VectorType vectorType, Value initValue, const DenseMap &indexMap, ArrayRef targetShape, UnrolledVectorState &state, @@ -230,11 +230,10 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state, // Returns an unrolled vector at 'vectorOffsets' within the vector // represented by 'state'. The vector is created from a slice of 'initValue' // if not present in 'cache'. 
-static ValuePtr getOrCreateUnrolledVectorSlice( +static Value getOrCreateUnrolledVectorSlice( Location loc, UnrolledVectorState &state, ArrayRef vectorOffsets, ArrayRef offsets, DenseMap &indexMap, - ValuePtr initValue, SmallVectorImpl &cache, - PatternRewriter &builder) { + Value initValue, SmallVectorImpl &cache, PatternRewriter &builder) { // Compute slice offsets. SmallVector sliceOffsets(state.unrolledShape.size()); getMappedElements(indexMap, offsets, sliceOffsets); @@ -321,10 +320,12 @@ struct VectorState { // TODO(andydavis) Generalize this to support structured ops beyond // vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType' -static ValuePtr unrollSingleResultStructuredOp( - Operation *op, ArrayRef iterationBounds, - std::vector &vectors, unsigned resultIndex, - ArrayRef targetShape, PatternRewriter &builder) { +static Value unrollSingleResultStructuredOp(Operation *op, + ArrayRef iterationBounds, + std::vector &vectors, + unsigned resultIndex, + ArrayRef targetShape, + PatternRewriter &builder) { auto shapedType = op->getResult(0)->getType().dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) assert(false && "Expected a statically shaped result type"); @@ -353,7 +354,7 @@ static ValuePtr unrollSingleResultStructuredOp( shapedType.getElementType()); // Initialize caches for intermediate vector results. - std::vector> caches(numVectors); + std::vector> caches(numVectors); for (unsigned i = 0; i < numVectors; ++i) caches[i].resize(unrolledVectorState[i].numInstances); @@ -365,7 +366,7 @@ static ValuePtr unrollSingleResultStructuredOp( auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, targetShape); // Get cached slice (or create slice) for each operand at 'offsets'. - SmallVector operands; + SmallVector operands; operands.resize(op->getNumOperands()); for (unsigned i = 0; i < numVectors; ++i) { int64_t operandIndex = vectors[i].operandIndex; @@ -391,21 +392,21 @@ static ValuePtr unrollSingleResultStructuredOp( // Create TupleOp of unrolled result vectors. SmallVector vectorTupleTypes(resultValueState.numInstances); - SmallVector vectorTupleValues(resultValueState.numInstances); + SmallVector vectorTupleValues(resultValueState.numInstances); for (unsigned i = 0; i < resultValueState.numInstances; ++i) { vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast(); vectorTupleValues[i] = caches[resultIndex][i]; } TupleType tupleType = builder.getTupleType(vectorTupleTypes); - ValuePtr tupleOp = builder.create(op->getLoc(), tupleType, - vectorTupleValues); + Value tupleOp = builder.create(op->getLoc(), tupleType, + vectorTupleValues); // Create InsertSlicesOp(Tuple(result_vectors)). auto resultVectorType = op->getResult(0)->getType().cast(); SmallVector sizes(resultValueState.unrolledShape); SmallVector strides(resultValueState.unrollFactors.size(), 1); - ValuePtr insertSlicesOp = builder.create( + Value insertSlicesOp = builder.create( op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes), builder.getI64ArrayAttr(strides)); return insertSlicesOp; @@ -476,7 +477,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef targetShape, } // Entry point for unrolling declarative pattern rewrites. 
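// Illustrative sketch (hypothetical, not part of the patch): a rewrite pattern
// calling the entry point below now receives the unrolled result as a plain
// `Value`; `rewriter` and `op` are assumed to be in scope and the target shape
// is made up for the sketch:
//
//   Value unrolled = mlir::vector::unrollSingleResultOpMatchingType(
//       rewriter, op, /*targetShape=*/{2, 2});
//   rewriter.replaceOp(op, {unrolled});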
-ValuePtr mlir::vector::unrollSingleResultOpMatchingType( +Value mlir::vector::unrollSingleResultOpMatchingType( PatternRewriter &builder, Operation *op, ArrayRef targetShape) { assert(op->getNumResults() == 1 && "Expected single result operation"); @@ -505,8 +506,8 @@ ValuePtr mlir::vector::unrollSingleResultOpMatchingType( static void generateTransferOpSlices(VectorType vectorType, TupleType tupleType, ArrayRef sizes, ArrayRef strides, - ArrayRef indices, PatternRewriter &rewriter, - function_ref)> fn) { + ArrayRef indices, PatternRewriter &rewriter, + function_ref)> fn) { // Compute strides w.r.t. to slice counts in each dimension. auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes); assert(maybeDimSliceCounts.hasValue()); @@ -523,13 +524,13 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType, auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes); // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. - SmallVector sliceIndices(numSliceIndices); + SmallVector sliceIndices(numSliceIndices); for (auto it : llvm::enumerate(indices)) { auto expr = getAffineDimExpr(0, ctx) + getAffineConstantExpr(offsets[it.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); sliceIndices[it.index()] = rewriter.create( - it.value()->getLoc(), map, ArrayRef(it.value())); + it.value()->getLoc(), map, ArrayRef(it.value())); } // Call 'fn' to generate slice 'i' at 'sliceIndices'. fn(i, sliceIndices); @@ -548,7 +549,7 @@ struct SplitTransferReadOp : public OpRewritePattern { if (!xferReadOp.permutation_map().isIdentity()) return matchFailure(); // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. - ValuePtr xferReadResult = xferReadOp.getResult(); + Value xferReadResult = xferReadOp.getResult(); auto extractSlicesOp = dyn_cast(*xferReadResult->getUsers().begin()); if (!xferReadResult->hasOneUse() || !extractSlicesOp) @@ -565,10 +566,10 @@ struct SplitTransferReadOp : public OpRewritePattern { Location loc = xferReadOp.getLoc(); int64_t numSlices = resultTupleType.size(); - SmallVector vectorTupleValues(numSlices); - SmallVector indices(xferReadOp.indices().begin(), - xferReadOp.indices().end()); - auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + SmallVector vectorTupleValues(numSlices); + SmallVector indices(xferReadOp.indices().begin(), + xferReadOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Get VectorType for slice 'i'. auto sliceVectorType = resultTupleType.getType(index); // Create split TransferReadOp for 'sliceUser'. @@ -580,8 +581,8 @@ struct SplitTransferReadOp : public OpRewritePattern { indices, rewriter, createSlice); // Create tuple of splice xfer read operations. - ValuePtr tupleOp = rewriter.create(loc, resultTupleType, - vectorTupleValues); + Value tupleOp = rewriter.create(loc, resultTupleType, + vectorTupleValues); // Replace 'xferReadOp' with result 'insertSlicesResult'. 
rewriter.replaceOpWithNewOp( xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), @@ -621,9 +622,9 @@ struct SplitTransferWriteOp : public OpRewritePattern { insertSlicesOp.getStrides(strides); Location loc = xferWriteOp.getLoc(); - SmallVector indices(xferWriteOp.indices().begin(), - xferWriteOp.indices().end()); - auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + SmallVector indices(xferWriteOp.indices().begin(), + xferWriteOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'. rewriter.create( loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices, @@ -665,7 +666,7 @@ struct TupleGetFolderOp : public OpRewritePattern { return matchFailure(); // Forward Value from 'tupleOp' at 'tupleGetOp.index'. - ValuePtr tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); + Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); rewriter.replaceOp(tupleGetOp, tupleValue); return matchSuccess(); } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index b25eb987a9e..7d51cded0c5 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -79,8 +79,9 @@ ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { return *this; } -ValueHandle mlir::edsc::ValueHandle::createComposedAffineApply( - AffineMap map, ArrayRef operands) { +ValueHandle +mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, + ArrayRef operands) { Operation *op = makeComposedAffineApply(ScopedContext::getBuilder(), ScopedContext::getLocation(), map, operands) @@ -108,7 +109,7 @@ OperationHandle OperationHandle::create(StringRef name, ArrayRef resultTypes, ArrayRef attributes) { OperationState state(ScopedContext::getLocation(), name); - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); state.addOperands(ops); state.addTypes(resultTypes); for (const auto &attr : attributes) { @@ -159,8 +160,8 @@ mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine( if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { *iv = staticFor.getValue(); } else { - SmallVector lbs(lbHandles.begin(), lbHandles.end()); - SmallVector ubs(ubHandles.begin(), ubHandles.end()); + SmallVector lbs(lbHandles.begin(), lbHandles.end()); + SmallVector ubs(ubHandles.begin(), ubHandles.end()); *iv = ValueHandle::create( lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()), ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()), @@ -299,11 +300,11 @@ static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { return ValueHandle::create(lhs.getValue(), rhs.getValue()); } -static std::pair -categorizeValueByAffineType(MLIRContext *context, ValuePtr val, - unsigned &numDims, unsigned &numSymbols) { +static std::pair +categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, + unsigned &numSymbols) { AffineExpr d; - ValuePtr resultVal = nullptr; + Value resultVal = nullptr; if (auto constant = dyn_cast_or_null(val->getDefiningOp())) { d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { @@ -322,12 +323,12 @@ static ValueHandle createBinaryIndexHandle( MLIRContext *context = ScopedContext::getContext(); unsigned numDims = 0, numSymbols = 0; AffineExpr d0, d1; - ValuePtr v0, v1; + Value v0, v1; std::tie(d0, v0) = categorizeValueByAffineType(context, lhs.getValue(), numDims, 
numSymbols); std::tie(d1, v1) = categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); - SmallVector operands; + SmallVector operands; if (v0) { operands.push_back(v0); } diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index 79888334cd9..008948b202f 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -13,7 +13,7 @@ using namespace mlir; using namespace mlir::edsc; -static SmallVector getMemRefSizes(ValuePtr memRef) { +static SmallVector getMemRefSizes(Value memRef) { MemRefType memRefType = memRef->getType().cast(); assert(isStrided(memRefType) && "Expected strided MemRef type"); @@ -30,7 +30,7 @@ static SmallVector getMemRefSizes(ValuePtr memRef) { return res; } -mlir::edsc::MemRefView::MemRefView(ValuePtr v) : base(v) { +mlir::edsc::MemRefView::MemRefView(Value v) : base(v) { assert(v->getType().isa() && "MemRefType expected"); auto memrefSizeValues = getMemRefSizes(v); @@ -41,7 +41,7 @@ mlir::edsc::MemRefView::MemRefView(ValuePtr v) : base(v) { } } -mlir::edsc::VectorView::VectorView(ValuePtr v) : base(v) { +mlir::edsc::VectorView::VectorView(Value v) : base(v) { auto vectorType = v->getType().cast(); for (auto s : vectorType.getShape()) { diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index 1bb32b97867..d339ec06884 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -20,7 +20,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle bh, (void)o; assert(o && "Expected already captured ValueHandle"); } - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh.getBlock(), ops); } static void enforceEmptyCapturesMatchOperands(ArrayRef captures, @@ -43,7 +43,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh, assert(!*bh && "Unexpected already captured BlockHandle"); enforceEmptyCapturesMatchOperands(captures, operands); BlockBuilder(bh, captures)(/* no body */); - SmallVector ops(operands.begin(), operands.end()); + SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh->getBlock(), ops); } @@ -52,8 +52,8 @@ mlir::edsc::intrinsics::cond_br(ValueHandle cond, BlockHandle trueBranch, ArrayRef trueOperands, BlockHandle falseBranch, ArrayRef falseOperands) { - SmallVector trueOps(trueOperands.begin(), trueOperands.end()); - SmallVector falseOps(falseOperands.begin(), falseOperands.end()); + SmallVector trueOps(trueOperands.begin(), trueOperands.end()); + SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps); } @@ -69,8 +69,8 @@ OperationHandle mlir::edsc::intrinsics::cond_br( enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands); BlockBuilder(trueBranch, trueCaptures)(/* no body */); BlockBuilder(falseBranch, falseCaptures)(/* no body */); - SmallVector trueOps(trueOperands.begin(), trueOperands.end()); - SmallVector falseOps(falseOperands.begin(), falseOperands.end()); + SmallVector trueOps(trueOperands.begin(), trueOperands.end()); + SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4eeb5e4e95c..881a6365e20 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1428,7 +1428,7 @@ public: void 
printAttribute(Attribute attr) override { ModulePrinter::printAttribute(attr); } - void printOperand(ValuePtr value) override { printValueID(value); } + void printOperand(Value value) override { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { @@ -1510,7 +1510,7 @@ protected: void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); void numberValuesInOp(Operation &op); - void printValueID(ValuePtr value, bool printResultNo = true) const { + void printValueID(Value value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } @@ -1519,13 +1519,13 @@ private: /// 'lookupValue' and the result of 'result' within that group in /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group /// has more than 1 result. - void getResultIDAndNumber(OpResultPtr result, ValuePtr &lookupValue, + void getResultIDAndNumber(OpResult result, Value &lookupValue, int &lookupResultNo) const; - void printValueIDImpl(ValuePtr value, bool printResultNo, + void printValueIDImpl(Value value, bool printResultNo, raw_ostream &stream) const; /// Set a special value name for the given value. - void setValueName(ValuePtr value, StringRef name); + void setValueName(Value value, StringRef name); /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. @@ -1533,8 +1533,8 @@ private: /// This is the value ID for each SSA value. If this returns ~0, then the /// valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; + DenseMap valueIDs; + DenseMap valueNames; /// This is a map of operations that contain multiple named result groups, /// i.e. there may be multiple names for the results of the operation. The key @@ -1610,7 +1610,7 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { } void OperationPrinter::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](ValuePtr arg, StringRef name) { + auto setArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); assert(arg.cast()->getOwner() == &block && "arg not defined in 'block'"); @@ -1648,11 +1648,11 @@ void OperationPrinter::numberValuesInOp(Operation &op) { unsigned numResults = op.getNumResults(); if (numResults == 0) return; - ValuePtr resultBegin = op.getResult(0); + Value resultBegin = op.getResult(0); // Function used to set the special result names for the operation. SmallVector resultGroups(/*Size=*/1, /*Value=*/0); - auto setResultNameFn = [&](ValuePtr result, StringRef name) { + auto setResultNameFn = [&](Value result, StringRef name) { assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result->getDefiningOp() == &op && "result not defined by 'op'"); setValueName(result, name); @@ -1681,7 +1681,7 @@ void OperationPrinter::numberValuesInOp(Operation &op) { } /// Set a special value name for the given value. -void OperationPrinter::setValueName(ValuePtr value, StringRef name) { +void OperationPrinter::setValueName(Value value, StringRef name) { // If the name is empty, the value uses the default numbering. if (name.empty()) { valueIDs[value] = nextValueID++; @@ -1728,7 +1728,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, // Print the argument list if non-empty. 
if (!block->args_empty()) { os << '('; - interleaveComma(block->getArguments(), [&](BlockArgumentPtr arg) { + interleaveComma(block->getArguments(), [&](BlockArgument arg) { printValueID(arg); os << ": "; printType(arg->getType()); @@ -1779,8 +1779,7 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } -void OperationPrinter::getResultIDAndNumber(OpResultPtr result, - ValuePtr &lookupValue, +void OperationPrinter::getResultIDAndNumber(OpResult result, Value &lookupValue, int &lookupResultNo) const { Operation *owner = result->getOwner(); if (owner->getNumResults() == 1) @@ -1818,7 +1817,7 @@ void OperationPrinter::getResultIDAndNumber(OpResultPtr result, lookupValue = owner->getResult(groupResultNo); } -void OperationPrinter::printValueIDImpl(ValuePtr value, bool printResultNo, +void OperationPrinter::printValueIDImpl(Value value, bool printResultNo, raw_ostream &stream) const { if (!value) { stream << "<>"; @@ -1831,7 +1830,7 @@ void OperationPrinter::printValueIDImpl(ValuePtr value, bool printResultNo, // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. - if (OpResultPtr result = value.dyn_cast()) + if (OpResult result = value.dyn_cast()) getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); @@ -1942,10 +1941,10 @@ void OperationPrinter::printGenericOp(Operation *op) { for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( + SmallVector properOperands( op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - interleaveComma(properOperands, [&](ValuePtr value) { printValueID(value); }); + interleaveComma(properOperands, [&](Value value) { printValueID(value); }); os << ')'; @@ -1988,10 +1987,10 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term, os << '('; interleaveComma(succOperands, - [this](ValuePtr operand) { printValueID(operand); }); + [this](Value operand) { printValueID(operand); }); os << " : "; interleaveComma(succOperands, - [this](ValuePtr operand) { printType(operand->getType()); }); + [this](Value operand) { printType(operand->getType()); }); os << ')'; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 751ceb1bfb4..b0ada9981a8 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -143,7 +143,7 @@ void Block::recomputeOpOrder() { // Argument list management. //===----------------------------------------------------------------------===// -BlockArgumentPtr Block::addArgument(Type type) { +BlockArgument Block::addArgument(Type type) { BlockArgument arg = BlockArgument::create(type, this); arguments.push_back(arg); return arg; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2ef10b6e669..5567f873b5e 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -334,7 +334,7 @@ Operation *OpBuilder::createOperation(const OperationState &state) { /// 'results'. Returns success if the operation was folded, failure otherwise. /// Note: This function does not erase the operation on a successful fold. 
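// Illustrative sketch (hypothetical, not part of the patch): a caller of the
// updated tryFold below collects folded results as `Value`s and, per the note
// above, stays responsible for erasing the original op; `builder` and `op`
// are assumed to be in scope:
//
//   SmallVector<Value, 1> folded;
//   if (succeeded(builder.tryFold(op, folded))) {
//     op->replaceAllUsesWith(folded);
//     op->erase();
//   }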
LogicalResult OpBuilder::tryFold(Operation *op, - SmallVectorImpl &results) { + SmallVectorImpl &results) { results.reserve(op->getNumResults()); auto cleanupFailure = [&] { results.assign(op->result_begin(), op->result_end()); @@ -365,7 +365,7 @@ LogicalResult OpBuilder::tryFold(Operation *op, Dialect *dialect = op->getDialect(); for (auto &it : llvm::enumerate(foldResults)) { // Normal values get pushed back directly. - if (auto value = it.value().dyn_cast()) { + if (auto value = it.value().dyn_cast()) { results.push_back(value); continue; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 77288b228aa..c7baba840e0 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -111,7 +111,7 @@ template <> unsigned BlockOperand::getOperandNumber() { /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, ArrayRef attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { @@ -131,7 +131,7 @@ Operation *Operation::create(const OperationState &state) { /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, RegionRange regions, bool resizableOperandList) { @@ -148,7 +148,7 @@ Operation *Operation::create(Location location, OperationName name, /// unnecessarily uniquing a list of attributes. Operation *Operation::create(Location location, OperationName name, ArrayRef resultTypes, - ArrayRef operands, + ArrayRef operands, NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { @@ -311,7 +311,7 @@ bool Operation::isProperAncestor(Operation *other) { } /// Replace any uses of 'from' with 'to' within this operation. -void Operation::replaceUsesOfWith(ValuePtr from, ValuePtr to) { +void Operation::replaceUsesOfWith(Value from, Value to) { if (from == to) return; for (auto &operand : getOpOperands()) @@ -669,7 +669,7 @@ InFlightDiagnostic Operation::emitOpError(const Twine &message) { /// Operands are remapped using `mapper` (if present), and `mapper` is updated /// to contain the results. Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { - SmallVector operands; + SmallVector operands; SmallVector successors; operands.reserve(getNumOperands() + getNumSuccessors()); @@ -1089,8 +1089,8 @@ LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, // These functions are out-of-line implementations of the methods in BinaryOp, // which avoids them being template instantiated/duplicated. 
-void impl::buildBinaryOp(Builder *builder, OperationState &result, ValuePtr lhs, - ValuePtr rhs) { +void impl::buildBinaryOp(Builder *builder, OperationState &result, Value lhs, + Value rhs) { assert(lhs->getType() == rhs->getType()); result.addOperands({lhs, rhs}); result.types.push_back(lhs->getType()); @@ -1130,8 +1130,8 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { // CastOp implementation //===----------------------------------------------------------------------===// -void impl::buildCastOp(Builder *builder, OperationState &result, - ValuePtr source, Type destType) { +void impl::buildCastOp(Builder *builder, OperationState &result, Value source, + Type destType) { result.addOperands(source); result.addTypes(destType); } @@ -1154,7 +1154,7 @@ void impl::printCastOp(Operation *op, OpAsmPrinter &p) { << op->getResult(0)->getType(); } -ValuePtr impl::foldCastOp(Operation *op) { +Value impl::foldCastOp(Operation *op) { // Identity cast if (op->getOperand(0)->getType() == op->getResult(0)->getType()) return op->getOperand(0); diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 935854a5365..1e8abc884dd 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -134,7 +134,7 @@ static bool isIsolatedAbove(Region ®ion, Region &limit, while (!pendingRegions.empty()) { for (Block &block : *pendingRegions.pop_back_val()) { for (Operation &op : block) { - for (ValuePtr operand : op.getOperands()) { + for (Value operand : op.getOperands()) { // operand should be non-null here if the IR is well-formed. But // we don't assert here as this function is called from the verifier // and so could be called on invalid IR. diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 1fa13a85c51..0bf1627b9d5 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -24,7 +24,7 @@ Type mlir::getElementTypeOrSelf(Type type) { return type; } -Type mlir::getElementTypeOrSelf(ValuePtr val) { +Type mlir::getElementTypeOrSelf(Value val) { return getElementTypeOrSelf(val->getType()); } @@ -88,18 +88,18 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { OperandElementTypeIterator::OperandElementTypeIterator( Operation::operand_iterator it) - : llvm::mapped_iterator( + : llvm::mapped_iterator( it, &unwrap) {} -Type OperandElementTypeIterator::unwrap(ValuePtr value) { +Type OperandElementTypeIterator::unwrap(Value value) { return value->getType().cast().getElementType(); } ResultElementTypeIterator::ResultElementTypeIterator( Operation::result_iterator it) - : llvm::mapped_iterator( + : llvm::mapped_iterator( it, &unwrap) {} -Type ResultElementTypeIterator::unwrap(ValuePtr value) { +Type ResultElementTypeIterator::unwrap(Value value) { return value->getType().cast().getElementType(); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e25f4d19654..0198a45172b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3084,7 +3084,7 @@ public: ParseResult popSSANameScope(); /// Register a definition of a value with the symbol table. - ParseResult addDefinition(SSAUseInfo useInfo, ValuePtr value); + ParseResult addDefinition(SSAUseInfo useInfo, Value value); /// Parse an optional list of SSA uses into 'results'. ParseResult parseOptionalSSAUseList(SmallVectorImpl &results); @@ -3094,13 +3094,12 @@ public: /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. 
- ValuePtr resolveSSAUse(SSAUseInfo useInfo, Type type); + Value resolveSSAUse(SSAUseInfo useInfo, Type type); ParseResult parseSSADefOrUseAndType( const std::function &action); - ParseResult - parseOptionalSSAUseAndTypeList(SmallVectorImpl &results); + ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl &results); /// Return the location of the value identified by its name and number if it /// has been already reference. @@ -3122,12 +3121,11 @@ public: /// Parse a single operation successor and its operand list. ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands); + SmallVectorImpl &operands); /// Parse a comma-separated list of operation successors in brackets. - ParseResult - parseSuccessors(SmallVectorImpl &destinations, - SmallVectorImpl> &operands); + ParseResult parseSuccessors(SmallVectorImpl &destinations, + SmallVectorImpl> &operands); /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); @@ -3165,9 +3163,8 @@ public: ParseResult parseBlockBody(Block *block); /// Parse a (possibly empty) list of block arguments. - ParseResult - parseOptionalBlockArgList(SmallVectorImpl &results, - Block *owner); + ParseResult parseOptionalBlockArgList(SmallVectorImpl &results, + Block *owner); /// Get the block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows @@ -3196,14 +3193,14 @@ private: void recordDefinition(StringRef def); /// Get the value entry for the given SSA name. - SmallVectorImpl> &getSSAValueEntry(StringRef name); + SmallVectorImpl> &getSSAValueEntry(StringRef name); /// Create a forward reference placeholder value with the given location and /// result type. - ValuePtr createForwardRefPlaceholder(SMLoc loc, Type type); + Value createForwardRefPlaceholder(SMLoc loc, Type type); /// Return true if this is a forward reference. - bool isForwardRefPlaceholder(ValuePtr value) { + bool isForwardRefPlaceholder(Value value) { return forwardRefPlaceholders.count(value); } @@ -3228,7 +3225,7 @@ private: /// This keeps track of all of the SSA values we are tracking for each name /// scope, indexed by their name. This has one entry per result number. - llvm::StringMap, 1>> values; + llvm::StringMap, 1>> values; /// This keeps track of all of the values defined by a specific name scope. SmallVector, 2> definitionsPerScope; @@ -3245,7 +3242,7 @@ private: /// These are all of the placeholders we've made along with the location of /// their first reference, to allow checking for use of undefined values. - DenseMap forwardRefPlaceholders; + DenseMap forwardRefPlaceholders; /// The builder used when creating parsed operation instances. OpBuilder opBuilder; @@ -3270,7 +3267,7 @@ ParseResult OperationParser::finalize() { // Check for any forward references that are left. If we find any, error // out. if (!forwardRefPlaceholders.empty()) { - SmallVector, 4> errors; + SmallVector, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. for (auto entry : forwardRefPlaceholders) errors.push_back({entry.second.getPointer(), entry.first}); @@ -3334,7 +3331,7 @@ ParseResult OperationParser::popSSANameScope() { } /// Register a definition of a value with the symbol table. -ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, ValuePtr value) { +ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { auto &entries = getSSAValueEntry(useInfo.name); // Make sure there is a slot for this value. 
@@ -3408,7 +3405,7 @@ ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { /// Given an unbound reference to an SSA value and its type, return the value /// it specifies. This returns null on failure. -ValuePtr OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { +Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = getSSAValueEntry(useInfo.name); // If we have already seen a value of this name, return it. @@ -3469,7 +3466,7 @@ ParseResult OperationParser::parseSSADefOrUseAndType( /// ::= ssa-use-list ':' type-list-no-parens /// ParseResult OperationParser::parseOptionalSSAUseAndTypeList( - SmallVectorImpl &results) { + SmallVectorImpl &results) { SmallVector valueIDs; if (parseOptionalSSAUseList(valueIDs)) return failure(); @@ -3504,13 +3501,13 @@ void OperationParser::recordDefinition(StringRef def) { } /// Get the value entry for the given SSA name. -SmallVectorImpl> & +SmallVectorImpl> & OperationParser::getSSAValueEntry(StringRef name) { return isolatedNameScopes.back().values[name]; } /// Create and remember a new placeholder for a forward reference. -ValuePtr OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { +Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { // Forward references are always created as operations, because we just need // something with a def/use chain. // @@ -3624,7 +3621,7 @@ ParseResult OperationParser::parseOperation() { /// ParseResult OperationParser::parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) { + SmallVectorImpl &operands) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::caret_identifier)) return emitError("expected block name"); @@ -3647,13 +3644,13 @@ OperationParser::parseSuccessorAndUseList(Block *&dest, /// ParseResult OperationParser::parseSuccessors( SmallVectorImpl &destinations, - SmallVectorImpl> &operands) { + SmallVectorImpl> &operands) { if (parseToken(Token::l_square, "expected '['")) return failure(); auto parseElt = [this, &destinations, &operands]() { Block *dest; - SmallVector destOperands; + SmallVector destOperands; auto res = parseSuccessorAndUseList(dest, destOperands); destinations.push_back(dest); operands.push_back(destOperands); @@ -3710,7 +3707,7 @@ Operation *OperationParser::parseGenericOperation() { // Parse the successor list but don't add successors to the result yet to // avoid messing up with the argument order. SmallVector successors; - SmallVector, 2> successorOperands; + SmallVector, 2> successorOperands; if (getToken().is(Token::l_square)) { // Check if the operation is a known terminator. const AbstractOperation *abstractOp = result.name.getAbstractOperation(); @@ -3771,7 +3768,7 @@ Operation *OperationParser::parseGenericOperation() { // Add the successors, and their operands after the proper operands. for (const auto &succ : llvm::zip(successors, successorOperands)) { Block *successor = std::get<0>(succ); - const SmallVector &operands = std::get<1>(succ); + const SmallVector &operands = std::get<1>(succ); result.addSuccessor(successor, operands); } @@ -4121,7 +4118,7 @@ public: /// Resolve an operand to an SSA value, emitting an error on failure. 
ParseResult resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) override { + SmallVectorImpl &result) override { OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; if (auto value = parser.resolveSSAUse(operandInfo, type)) { @@ -4234,7 +4231,7 @@ public: /// Parse a single operation successor and its operand list. ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) override { + SmallVectorImpl &operands) override { return parser.parseSuccessorAndUseList(dest, operands); } @@ -4462,7 +4459,7 @@ ParseResult OperationParser::parseBlock(Block *&block) { // If an argument list is present, parse it. if (consumeIf(Token::l_paren)) { - SmallVector bbArgs; + SmallVector bbArgs; if (parseOptionalBlockArgList(bbArgs, block) || parseToken(Token::r_paren, "expected ')' to end argument list")) return failure(); @@ -4526,7 +4523,7 @@ Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* /// ParseResult OperationParser::parseOptionalBlockArgList( - SmallVectorImpl &results, Block *owner) { + SmallVectorImpl &results, Block *owner) { if (getToken().is(Token::r_brace)) return success(); diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 132a0bec4b7..75aadbdf5cb 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -39,14 +39,14 @@ public: for (Region ®ion : op->getRegions()) { for (Block &block : region) { addDataToHash(hasher, &block); - for (BlockArgumentPtr arg : block.getArguments()) + for (BlockArgument arg : block.getArguments()) addDataToHash(hasher, arg); } } // - Location addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); // - Operands - for (ValuePtr operand : op->getOperands()) + for (Value operand : op->getOperands()) addDataToHash(hasher, operand); // - Successors for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp index 38aa5dc811b..3c194bbd459 100644 --- a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -93,7 +93,7 @@ void CAGSlice::enumerateImpliedConnections( std::vector> impliedPairs; for (auto &resultAnchorPair : resultAnchors) { CAGResultAnchor *resultAnchor = resultAnchorPair.second; - ValuePtr resultValue = resultAnchor->getValue(); + Value resultValue = resultAnchor->getValue(); for (auto &use : resultValue->getUses()) { Operation *operandOp = use.getOwner(); unsigned operandIdx = use.getOperandNumber(); diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index c8569c2fe19..5ecb668ce55 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -172,17 +172,17 @@ void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, Type newType) { - ValuePtr inputValue = anchor->getValue(); + Value inputValue = anchor->getValue(); Operation *op = anchor->getOp(); OpBuilder b(op->getBlock(), Block::iterator(op)); - SmallVector removeValuesIfDead; + SmallVector removeValuesIfDead; // Because we've already run the result transforms at this phase, it is // very likely that inputValue points to a dcast op whose input 
matches // our type. We detect that situation and route around just to save some // bulk in the IR. - ValuePtr newTypedInputValue = inputValue; + Value newTypedInputValue = inputValue; auto inputDcastOp = dyn_cast_or_null(inputValue->getDefiningOp()); if (inputDcastOp && inputDcastOp.arg()->getType() == newType) { @@ -219,7 +219,7 @@ void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, break; } - for (ValuePtr removeValueIfDead : removeValuesIfDead) { + for (Value removeValueIfDead : removeValuesIfDead) { if (removeValueIfDead->use_empty()) { removeValueIfDead->getDefiningOp()->erase(); } @@ -228,12 +228,12 @@ void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor, void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor, Type newType) { - ValuePtr origResultValue = anchor->getValue(); + Value origResultValue = anchor->getValue(); Operation *op = origResultValue->getDefiningOp(); OpBuilder b(op->getBlock(), ++Block::iterator(op)); - ValuePtr replacedResultValue = nullptr; - ValuePtr newResultValue = nullptr; + Value replacedResultValue = nullptr; + Value newResultValue = nullptr; switch (anchor->getTypeTransformRule()) { case CAGAnchorNode::TypeTransformRule::Direct: origResultValue->setType(newType); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 1045b784ae2..ada2af8fb47 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -215,7 +215,7 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { return formatv("Operation::operand_range {0}(op0->getOperands());\n", name); } case Kind::Value: { - return formatv("ArrayRef {0};\n", name); + return formatv("ArrayRef {0};\n", name); } case Kind::Result: { // Use the op itself for captured results. diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 6f3e2ef21aa..4466fb5fe26 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -67,7 +67,7 @@ private: /// `value` is an SSA-use. Return the remapped version of `value` or a /// placeholder that will be remapped later if this is an instruction that /// has not yet been visited. - ValuePtr processValue(llvm::Value *value); + Value processValue(llvm::Value *value); /// Create the most accurate Location possible using a llvm::DebugLoc and /// possibly an llvm::Instruction to narrow the Location if debug information /// is unavailable. @@ -76,14 +76,14 @@ private: /// `br` branches to `target`. Return the block arguments to attach to the /// generated branch op. These should be in the same order as the PHIs in /// `target`. - SmallVector processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target); + SmallVector processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target); /// Return `value` as an attribute to attach to a GlobalOp. Attribute getConstantAsAttr(llvm::Constant *value); /// Return `c` as an MLIR Value. This could either be a ConstantOp, or /// an expanded sequence of ops in the current function's entry block (for /// ConstantExprs or ConstantGEPs). - ValuePtr processConstant(llvm::Constant *c); + Value processConstant(llvm::Constant *c); /// The current builder, pointing at where the next Instruction should be /// generated. @@ -111,7 +111,7 @@ private: /// Remapped blocks, for the current function. DenseMap blocks; /// Remapped values. These are function-local. 
- DenseMap instMap; + DenseMap instMap; /// Instructions that had not been defined when first encountered as a use. /// Maps to the dummy Operation that was created in processValue(). DenseMap unknownInstMap; @@ -254,13 +254,13 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { Region &r = op.getInitializerRegion(); currentEntryBlock = b.createBlock(&r); b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); - ValuePtr v = processConstant(GV->getInitializer()); - b.create(op.getLoc(), ArrayRef({v})); + Value v = processConstant(GV->getInitializer()); + b.create(op.getLoc(), ArrayRef({v})); } return globals[GV] = op; } -ValuePtr Importer::processConstant(llvm::Constant *c) { +Value Importer::processConstant(llvm::Constant *c) { if (Attribute attr = getConstantAsAttr(c)) { // These constants can be represented as attributes. OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); @@ -289,7 +289,7 @@ ValuePtr Importer::processConstant(llvm::Constant *c) { return nullptr; } -ValuePtr Importer::processValue(llvm::Value *value) { +Value Importer::processValue(llvm::Value *value) { auto it = instMap.find(value); if (it != instMap.end()) return it->second; @@ -398,9 +398,9 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. -SmallVector Importer::processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target) { - SmallVector v; +SmallVector Importer::processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target) { + SmallVector v; for (auto inst = target->begin(); isa(inst); ++inst) { auto *PN = cast(&*inst); v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent()))); @@ -412,7 +412,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math // flags and call / operand attributes are not supported. Location loc = processDebugLoc(inst->getDebugLoc(), inst); - ValuePtr &v = instMap[inst]; + Value &v = instMap[inst]; assert(!v && "processInstruction must be called only once per instruction!"); switch (inst->getOpcode()) { default: @@ -453,7 +453,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { case llvm::Instruction::AddrSpaceCast: case llvm::Instruction::BitCast: { OperationState state(loc, opcMap.lookup(inst->getOpcode())); - SmallVector ops; + SmallVector ops; ops.reserve(inst->getNumOperands()); for (auto *op : inst->operand_values()) ops.push_back(processValue(op)); @@ -475,7 +475,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { auto *brInst = cast(inst); OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); - SmallVector ops; + SmallVector ops; if (brInst->isConditional()) ops.push_back(processValue(brInst->getCondition())); state.addOperands(ops); @@ -491,7 +491,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { } case llvm::Instruction::Call: { llvm::CallInst *ci = cast(inst); - SmallVector ops; + SmallVector ops; ops.reserve(inst->getNumOperands()); for (auto &op : ci->arg_operands()) ops.push_back(processValue(op.get())); @@ -514,7 +514,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { case llvm::Instruction::GetElementPtr: { // FIXME: Support inbounds GEPs. 
llvm::GetElementPtrInst *gep = cast(inst); - SmallVector ops; + SmallVector ops; for (auto *op : gep->operand_values()) ops.push_back(processValue(op)); v = b.create(loc, processType(inst->getType()), ops, @@ -556,8 +556,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) { // any unknown uses we encountered are remapped. for (auto &llvmAndUnknown : unknownInstMap) { assert(instMap.count(llvmAndUnknown.first)); - ValuePtr newValue = instMap[llvmAndUnknown.first]; - ValuePtr oldValue = llvmAndUnknown.second->getResult(0); + Value newValue = instMap[llvmAndUnknown.first]; + Value oldValue = llvmAndUnknown.second->getResult(0); oldValue->replaceAllUsesWith(newValue); llvmAndUnknown.second->erase(); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index e8376364c41..e3c0768ef33 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -333,8 +333,8 @@ void ModuleTranslation::convertGlobals() { /// Get the SSA value passed to the current block from the terminator operation /// of its predecessor. -static ValuePtr getPHISourceValue(Block *current, Block *pred, - unsigned numArguments, unsigned index) { +static Value getPHISourceValue(Block *current, Block *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (isa(terminator)) { return terminator.getOperand(index); @@ -411,7 +411,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { unsigned int argIdx = 0; for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { llvm::Argument &llvmArg = std::get<1>(kvp); - BlockArgumentPtr mlirArg = std::get<0>(kvp); + BlockArgument mlirArg = std::get<0>(kvp); if (auto attr = func.getArgAttrOfType(argIdx, "llvm.noalias")) { // NB: Attribute already verified to be boolean, so check if we can indeed @@ -488,7 +488,7 @@ SmallVector ModuleTranslation::lookupValues(ValueRange values) { SmallVector remapped; remapped.reserve(values.size()); - for (ValuePtr v : values) + for (Value v : values) remapped.push_back(valueMapping.lookup(v)); return remapped; } diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp index 1e1b8775d32..902f5c3adcb 100644 --- a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -121,7 +121,7 @@ struct AffineDataCopyGeneration bool skipNonUnitStrideLoops; // Constant zero index to avoid too many duplicates. 
- ValuePtr zeroIndex = nullptr; + Value zeroIndex = nullptr; }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp index 1f33c0f5dca..24ec2d7c70b 100644 --- a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -49,15 +49,15 @@ struct LoopInvariantCodeMotion : public FunctionPass { } // end anonymous namespace static bool -checkInvarianceOfNestedIfOps(Operation *op, ValuePtr indVar, +checkInvarianceOfNestedIfOps(Operation *op, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); -static bool isOpLoopInvariant(Operation &op, ValuePtr indVar, +static bool isOpLoopInvariant(Operation &op, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); static bool -areAllOpsInTheBlockListInvariant(Region &blockList, ValuePtr indVar, +areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist); @@ -70,7 +70,7 @@ static bool isMemRefDereferencingOp(Operation &op) { } // Returns true if the individual op is loop invariant. -bool isOpLoopInvariant(Operation &op, ValuePtr indVar, +bool isOpLoopInvariant(Operation &op, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); @@ -88,9 +88,9 @@ bool isOpLoopInvariant(Operation &op, ValuePtr indVar, return false; } else if (!isa(op)) { if (isMemRefDereferencingOp(op)) { - ValuePtr memref = isa(op) - ? cast(op).getMemRef() - : cast(op).getMemRef(); + Value memref = isa(op) + ? cast(op).getMemRef() + : cast(op).getMemRef(); for (auto *user : memref->getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. @@ -154,8 +154,7 @@ bool isOpLoopInvariant(Operation &op, ValuePtr indVar, // Checks if all ops in a region (i.e. list of blocks) are loop invariant. bool areAllOpsInTheBlockListInvariant( - Region &blockList, ValuePtr indVar, - SmallPtrSetImpl &definedOps, + Region &blockList, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { for (auto &b : blockList) { @@ -170,7 +169,7 @@ bool areAllOpsInTheBlockListInvariant( } // Returns true if the affine.if op can be hoisted. -bool checkInvarianceOfNestedIfOps(Operation *op, ValuePtr indVar, +bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, SmallPtrSetImpl &definedOps, SmallPtrSetImpl &opsToHoist) { assert(isa(op)); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index c9fcb670180..5f7fb7a68c9 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -77,13 +77,13 @@ namespace { struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return the provided value. - ValuePtr lookupOrDefault(ValuePtr from) const; + Value lookupOrDefault(Value from) const; /// Map a value to the one provided. - void map(ValuePtr oldVal, ValuePtr newVal) { mapping.map(oldVal, newVal); } + void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } /// Drop the last mapping for the given value. - void erase(ValuePtr value) { mapping.erase(value); } + void erase(Value value) { mapping.erase(value); } private: /// Current value mappings. @@ -93,7 +93,7 @@ private: /// Lookup a mapped value within the map. 
If a mapping for the provided value /// does not exist then return the provided value. -ValuePtr ConversionValueMapping::lookupOrDefault(ValuePtr from) const { +Value ConversionValueMapping::lookupOrDefault(Value from) const { // If this value had a valid mapping, unmap that value as well in the case // that it was also replaced. while (auto mappedValue = mapping.lookupOrNull(from)) @@ -118,7 +118,7 @@ struct ArgConverter { /// been converted. struct ConvertedArgInfo { ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - ValuePtr castValue = nullptr) + Value castValue = nullptr) : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} /// The start index of in the new argument list that contains arguments that @@ -130,7 +130,7 @@ struct ArgConverter { /// The cast value that was created to cast from the new arguments to the /// old. This only used if 'newArgSize' > 1. - ValuePtr castValue; + Value castValue; }; /// This structure contains information pertaining to a block that has had its @@ -226,7 +226,7 @@ void ArgConverter::notifyOpRemoved(Operation *op) { // Drop all uses of the original arguments and delete the original block. Block *origBlock = it->second.origBlock; - for (BlockArgumentPtr arg : origBlock->getArguments()) + for (BlockArgument arg : origBlock->getArguments()) arg->dropAllUses(); conversionInfo.erase(it); } @@ -261,7 +261,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { // Process the remapping for each of the original arguments. for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { Optional &argInfo = blockInfo.argInfo[i]; - BlockArgumentPtr origArg = origBlock->getArgument(i); + BlockArgument origArg = origBlock->getArgument(i); // Handle the case of a 1->0 value mapping. if (!argInfo) { @@ -296,7 +296,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { } // Otherwise this is a 1->N value mapping. - ValuePtr castValue = argInfo->castValue; + Value castValue = argInfo->castValue; assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); // If the argument is still used, replace it with the generated cast. @@ -335,8 +335,8 @@ Block *ArgConverter::applySignatureConversion( Block *newBlock = block->splitBlock(block->begin()); block->replaceAllUsesWith(newBlock); - SmallVector newArgRange(newBlock->addArguments(convertedTypes)); - ArrayRef newArgs(newArgRange); + SmallVector newArgRange(newBlock->addArguments(convertedTypes)); + ArrayRef newArgs(newArgRange); // Remap each of the original arguments as determined by the signature // conversion. @@ -349,7 +349,7 @@ Block *ArgConverter::applySignatureConversion( auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap) continue; - BlockArgumentPtr origArg = block->getArgument(i); + BlockArgument origArg = block->getArgument(i); // If inputMap->replacementValue is not nullptr, then the argument is // dropped and a replacement value is provided to be the remappedValue. @@ -473,7 +473,7 @@ struct ConversionPatternRewriterImpl { : op(op), newValues(newValues.begin(), newValues.end()) {} Operation *op; - SmallVector newValues; + SmallVector newValues; }; /// The kind of the block action performed during the rewrite. Actions can be @@ -570,7 +570,7 @@ struct ConversionPatternRewriterImpl { /// Remap the given operands to those with potentially different types. 
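ConversionValueMapping::lookupOrDefault above keeps following the mapping until no further entry exists, so a value remapped several times resolves to its latest replacement, and an unmapped value comes back unchanged. The same idea on a plain std::map, just to make the loop's termination condition concrete (it assumes the mapping is acyclic):

#include <map>

// Follow the remapping chain a -> b -> c ... and return the last value;
// values that were never remapped are returned as-is.
int lookupOrDefault(const std::map<int, int> &mapping, int from) {
  auto it = mapping.find(from);
  while (it != mapping.end()) {
    from = it->second;
    it = mapping.find(from);
  }
  return from;
}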
void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped); + SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be /// converted. @@ -803,9 +803,9 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( } void ConversionPatternRewriterImpl::remapValues( - Operation::operand_range operands, SmallVectorImpl &remapped) { + Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); - for (ValuePtr operand : operands) + for (Value operand : operands) remapped.push_back(mapping.lookupOrDefault(operand)); } @@ -851,7 +851,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, void ConversionPatternRewriter::eraseOp(Operation *op) { LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() << "\n"); - SmallVector nullRepls(op->getNumResults(), nullptr); + SmallVector nullRepls(op->getNumResults(), nullptr); impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); } @@ -861,8 +861,8 @@ Block *ConversionPatternRewriter::applySignatureConversion( return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument( - BlockArgumentPtr from, ValuePtr to) { +void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, + Value to) { for (auto &u : from->getUses()) { if (u.getOwner() == to->getDefiningOp()) continue; @@ -873,7 +873,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument( /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. -ValuePtr ConversionPatternRewriter::getRemappedValue(ValuePtr key) { +Value ConversionPatternRewriter::getRemappedValue(Value key) { return impl->mapping.lookupOrDefault(key); } @@ -967,7 +967,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { PatternMatchResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - SmallVector operands; + SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.getImpl().remapValues(op->getOperands(), operands); @@ -979,7 +979,7 @@ ConversionPattern::matchAndRewrite(Operation *op, SmallVector destinations; destinations.reserve(op->getNumSuccessors()); - SmallVector, 2> operandsPerDestination; + SmallVector, 2> operandsPerDestination; unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { destinations.push_back(op->getSuccessor(i)); @@ -1130,7 +1130,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. - SmallVector replacementValues; + SmallVector replacementValues; rewriter.setInsertionPoint(op); if (failed(rewriter.tryFold(op, replacementValues))) return failure(); @@ -1554,7 +1554,7 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, /// Remap an input of the original signature to another `replacementValue` /// value. This would make the signature converter drop this argument. 
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, - ValuePtr replacementValue) { + Value replacementValue) { assert(!remappedInputs[origInputNo] && "input has already been remapped"); remappedInputs[origInputNo] = InputMapping{origInputNo, /*size=*/0, replacementValue}; @@ -1623,7 +1623,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern { /// Hook for derived classes to implement combined matching and rewriting. PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { FunctionType type = funcOp.getType(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 51e30ba7163..fcfc1d7ae52 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -163,7 +163,7 @@ public: Node(unsigned id, Operation *op) : id(id), op(op) {} // Returns the load op count for 'memref'. - unsigned getLoadOpCount(ValuePtr memref) { + unsigned getLoadOpCount(Value memref) { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) @@ -173,7 +173,7 @@ public: } // Returns the store op count for 'memref'. - unsigned getStoreOpCount(ValuePtr memref) { + unsigned getStoreOpCount(Value memref) { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) @@ -183,7 +183,7 @@ public: } // Returns all store ops in 'storeOps' which access 'memref'. - void getStoreOpsForMemref(ValuePtr memref, + void getStoreOpsForMemref(Value memref, SmallVectorImpl *storeOps) { for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) @@ -192,7 +192,7 @@ public: } // Returns all load ops in 'loadOps' which access 'memref'. - void getLoadOpsForMemref(ValuePtr memref, + void getLoadOpsForMemref(Value memref, SmallVectorImpl *loadOps) { for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) @@ -202,8 +202,8 @@ public: // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node // has at least one load and store operation. - void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { - llvm::SmallDenseSet loadMemrefs; + void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { + llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { loadMemrefs.insert(cast(loadOpInst).getMemRef()); } @@ -230,7 +230,7 @@ public: // defines an SSA value and another graph node which uses the SSA value // (e.g. a constant operation defining a value which is used inside a loop // nest). - ValuePtr value; + Value value; }; // Map from node id to Node. @@ -241,7 +241,7 @@ public: DenseMap> outEdges; // Map from memref to a count on the dependence edges associated with that // memref. - DenseMap memrefEdgeCount; + DenseMap memrefEdgeCount; // The next unique identifier to use for newly created graph nodes. unsigned nextNodeId = 0; @@ -372,7 +372,7 @@ public: // Returns true iff there is an edge from node 'srcId' to node 'dstId' which // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. - bool hasEdge(unsigned srcId, unsigned dstId, ValuePtr value = nullptr) { + bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) { if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { return false; } @@ -386,7 +386,7 @@ public: } // Adds an edge from node 'srcId' to node 'dstId' for 'value'. 
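hasEdge and addEdge above keep the dependence graph symmetric: every (src, dst, value) edge is recorded in both adjacency lists and bumps a per-memref counter so unused memrefs can be detected later. A stripped-down sketch of that bookkeeping with stand-in types (plain ints instead of Value, and no distinction between memref and SSA edges):

#include <map>
#include <vector>

struct Edge { unsigned id; int value; };           // value: memref/SSA id
struct Graph {
  std::map<unsigned, std::vector<Edge>> inEdges, outEdges;
  std::map<int, unsigned> memrefEdgeCount;

  bool hasEdge(unsigned src, unsigned dst, int value) {
    for (const Edge &e : outEdges[src])
      if (e.id == dst && e.value == value)
        return true;
    return false;
  }
  void addEdge(unsigned src, unsigned dst, int value) {
    if (hasEdge(src, dst, value))
      return;
    outEdges[src].push_back({dst, value});         // recorded on both ends
    inEdges[dst].push_back({src, value});
    ++memrefEdgeCount[value];
  }
};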
- void addEdge(unsigned srcId, unsigned dstId, ValuePtr value) { + void addEdge(unsigned srcId, unsigned dstId, Value value) { if (!hasEdge(srcId, dstId, value)) { outEdges[srcId].push_back({dstId, value}); inEdges[dstId].push_back({srcId, value}); @@ -396,7 +396,7 @@ public: } // Removes an edge from node 'srcId' to node 'dstId' for 'value'. - void removeEdge(unsigned srcId, unsigned dstId, ValuePtr value) { + void removeEdge(unsigned srcId, unsigned dstId, Value value) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); if (value->getType().isa()) { @@ -450,7 +450,7 @@ public: // Returns the input edge count for node 'id' and 'memref' from src nodes // which access 'memref' with a store operation. - unsigned getIncomingMemRefAccesses(unsigned id, ValuePtr memref) { + unsigned getIncomingMemRefAccesses(unsigned id, Value memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) for (auto &inEdge : inEdges[id]) @@ -465,7 +465,7 @@ public: // Returns the output edge count for node 'id' and 'memref' (if non-null), // otherwise returns the total output edge count from node 'id'. - unsigned getOutEdgeCount(unsigned id, ValuePtr memref = nullptr) { + unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) { unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) @@ -539,7 +539,7 @@ public: // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' // has been replaced in node at 'dstId' by a private memref depending // on the value of 'createPrivateMemRef'. - void updateEdges(unsigned srcId, unsigned dstId, ValuePtr oldMemRef, + void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef, bool createPrivateMemRef) { // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. if (inEdges.count(srcId) > 0) { @@ -672,7 +672,7 @@ public: // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(FuncOp f) { - DenseMap> memrefAccesses; + DenseMap> memrefAccesses; // TODO: support multi-block functions. if (f.getBlocks().size() != 1) @@ -768,7 +768,7 @@ bool MemRefDependenceGraph::init(FuncOp f) { // Removes load operations from 'srcLoads' which operate on 'memref', and // adds them to 'dstLoads'. -static void moveLoadsAccessingMemrefTo(ValuePtr memref, +static void moveLoadsAccessingMemrefTo(Value memref, SmallVectorImpl *srcLoads, SmallVectorImpl *dstLoads) { dstLoads->clear(); @@ -884,11 +884,10 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. -static ValuePtr createPrivateMemRef(AffineForOp forOp, - Operation *srcStoreOpInst, - unsigned dstLoopDepth, - Optional fastMemorySpace, - uint64_t localBufSizeThreshold) { +static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, + unsigned dstLoopDepth, + Optional fastMemorySpace, + uint64_t localBufSizeThreshold) { auto *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. @@ -920,7 +919,7 @@ static ValuePtr createPrivateMemRef(AffineForOp forOp, // 'outerIVs' holds the values that this memory region is symbolic/parametric // on; this would correspond to loop IVs surrounding the level at which the // slice is being materialized. 
- SmallVector outerIVs; + SmallVector outerIVs; cst->getIdValues(rank, cst->getNumIds(), &outerIVs); // Build 'rank' AffineExprs from MemRefRegion 'lbs' @@ -952,7 +951,7 @@ static ValuePtr createPrivateMemRef(AffineForOp forOp, auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), {}, newMemSpace); // Gather alloc operands for the dynamic dimensions of the memref. - SmallVector allocOperands; + SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) @@ -965,7 +964,7 @@ static ValuePtr createPrivateMemRef(AffineForOp forOp, // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the function, because loop nests can be reordered // during the fusion pass. - ValuePtr newMemRef = + Value newMemRef = top.create(forOp.getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. @@ -1008,7 +1007,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, MemRefDependenceGraph *mdg) { assert(srcLiveOutStoreOp && "Expected a valid store op"); auto *dstNode = mdg->getNode(dstId); - ValuePtr memref = srcLiveOutStoreOp.getMemRef(); + Value memref = srcLiveOutStoreOp.getMemRef(); // Return false if 'srcNode' has more than one output edge on 'memref'. if (mdg->getOutEdgeCount(srcId, memref) > 1) return false; @@ -1487,7 +1486,7 @@ public: SmallVector loads = dstNode->loads; SmallVector dstLoadOpInsts; - DenseSet visitedMemrefs; + DenseSet visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. auto memref = cast(loads.back()).getMemRef(); @@ -1729,10 +1728,10 @@ public: // Attempt to fuse 'dstNode' with sibling nodes in the graph. void fuseWithSiblingNodes(Node *dstNode) { DenseSet visitedSibNodeIds; - std::pair idAndMemref; + std::pair idAndMemref; while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { unsigned sibId = idAndMemref.first; - ValuePtr memref = idAndMemref.second; + Value memref = idAndMemref.second; // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other // stores to the same memref in 'sibNode' loop nest. auto *sibNode = mdg->getNode(sibId); @@ -1796,10 +1795,10 @@ public: // 'idAndMemrefToFuse' on success. Returns false otherwise. bool findSiblingNodeToFuse(Node *dstNode, DenseSet *visitedSibNodeIds, - std::pair *idAndMemrefToFuse) { + std::pair *idAndMemrefToFuse) { // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse // on 'memref'. - auto canFuseWithSibNode = [&](Node *sibNode, ValuePtr memref) { + auto canFuseWithSibNode = [&](Node *sibNode, Value memref) { // Skip if 'outEdge' is not a read-after-write dependence. // TODO(andydavis) Remove restrict to single load op restriction. if (sibNode->getLoadOpCount(memref) != 1) @@ -1811,15 +1810,15 @@ public: return false; // Skip sib node if it loads to (and stores from) the same memref on // which it also has an input dependence edge. - DenseSet loadAndStoreMemrefSet; + DenseSet loadAndStoreMemrefSet; sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); - if (llvm::any_of(loadAndStoreMemrefSet, [=](ValuePtr memref) { + if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) { return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; })) return false; // Check that all stores are to the same memref. 
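createPrivateMemRef above sizes the new buffer from the bounding box of the memref region the fused nest actually accesses, so the private copy can be much smaller than the original memref. A hedged arithmetic sketch of that sizing, ignoring symbolic bounds, layout maps, and the lower-bound offset remapping:

#include <cstdint>
#include <vector>

// One extent per dimension: the constant difference between the region's
// upper and lower bounds. Dynamic/symbolic bounds are ignored here.
std::vector<int64_t> privateShape(const std::vector<int64_t> &lbs,
                                  const std::vector<int64_t> &ubs) {
  std::vector<int64_t> shape;
  shape.reserve(lbs.size());
  for (size_t d = 0; d < lbs.size(); ++d)
    shape.push_back(ubs[d] - lbs[d]);
  return shape;
}
// e.g. a nest touching rows [32, 64) of a 1024-row memref only needs a
// private buffer with leading extent 32.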
- DenseSet storeMemrefs; + DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { storeMemrefs.insert(cast(storeOpInst).getMemRef()); } diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 93c80822fb3..fb3d0c0b45c 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -41,7 +41,7 @@ public: // - the op has no side-effects. If sideEffecting is Never, sideeffects of this // op and its nested ops are ignored. static bool canBeHoisted(Operation *op, - function_ref definedOutside, + function_ref definedOutside, SideEffecting sideEffecting, SideEffectsInterface &interface) { // Check that dependencies are defined outside of loop. @@ -83,7 +83,7 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, SmallVector opsToMove; // Helper to check whether an operation is loop invariant wrt. SSA properties. - auto isDefinedOutsideOfBody = [&](ValuePtr value) { + auto isDefinedOutsideOfBody = [&](Value value) { auto definingOp = value->getDefiningOp(); return (definingOp && !!willBeMovedSet.count(definingOp)) || looplike.isDefinedOutsideOfLoop(value); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 5389c7e4429..d3dc81760fc 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -111,8 +111,8 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, for (unsigned i = 0; i < width; i++) { auto lbOperands = origLoops[i].getLowerBoundOperands(); auto ubOperands = origLoops[i].getUpperBoundOperands(); - SmallVector newLbOperands(lbOperands); - SmallVector newUbOperands(ubOperands); + SmallVector newLbOperands(lbOperands); + SmallVector newUbOperands(ubOperands); newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap()); newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap()); newLoops[i].setStep(tileSizes[i]); @@ -138,7 +138,7 @@ constructTiledIndexSetHyperRect(MutableArrayRef origLoops, // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. auto ub = origLoops[i].getUpperBound(); - SmallVector ubOperands; + SmallVector ubOperands; ubOperands.reserve(ub.getNumOperands() + 1); auto origUbMap = ub.getMap(); // Add dim operands from original upper bound. @@ -226,10 +226,9 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector origLoopIVs; + SmallVector origLoopIVs; extractForInductionVars(band, &origLoopIVs); - SmallVector, 6> ids(origLoopIVs.begin(), - origLoopIVs.end()); + SmallVector, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 3cefcaacadc..6c74d545497 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -182,7 +182,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, // Adjust the lower bound of the cleanup loop; its upper bound is the same // as the original loop's upper bound. 
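constructTiledIndexSetHyperRect above builds, for each original loop, a tile-space loop that steps by the tile size over the original range and an intra-tile loop bounded by the tile IV plus the tile size (combined with the original upper bound). A plain-C++ sketch of the resulting iteration structure, with std::min standing in for the extra upper-bound expression:

#include <algorithm>
#include <cstdio>

int main() {
  const int lb = 0, ub = 10, tileSize = 4;
  for (int i = lb; i < ub; i += tileSize)                     // tile-space loop
    for (int ii = i; ii < std::min(i + tileSize, ub); ++ii)   // intra-tile loop
      std::printf("%d ", ii);
  return 0;
}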
AffineMap cleanupMap; - SmallVector cleanupOperands; + SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, &cleanupOperands, builder); cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 957f41a9d3e..e2514e12cc7 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -67,7 +67,7 @@ struct MemRefDataFlowOpt : public FunctionPass { void forwardStoreToLoad(AffineLoadOp loadOp); // A list of memref's that are potentially dead / could be eliminated. - SmallPtrSet memrefsToErase; + SmallPtrSet memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. SmallVector loadOpsToErase; @@ -171,7 +171,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { return; // Perform the actual store to load forwarding. - ValuePtr storeVal = cast(lastWriteStoreOp).getValueToStore(); + Value storeVal = cast(lastWriteStoreOp).getValueToStore(); loadOp.replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 12ce6c66abd..dce02737064 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -61,7 +61,7 @@ static unsigned getTagMemRefPos(Operation &dmaInst) { /// Replaces all uses of the old memref by the new one while indexing the newly /// added dimension by the loop IV of the specified 'affine.for' operation /// modulo 2. Returns false if such a replacement cannot be performed. -static bool doubleBuffer(ValuePtr oldMemRef, AffineForOp forOp) { +static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { auto *forBody = forOp.getBody(); OpBuilder bInner(forBody, forBody->begin()); @@ -85,7 +85,7 @@ static bool doubleBuffer(ValuePtr oldMemRef, AffineForOp forOp) { auto *forInst = forOp.getOperation(); OpBuilder bOuter(forInst); // Put together alloc operands for any dynamic dimensions of the memref. - SmallVector allocOperands; + SmallVector allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) @@ -94,7 +94,7 @@ static bool doubleBuffer(ValuePtr oldMemRef, AffineForOp forOp) { } // Create and place the alloc right before the 'affine.for' operation. - ValuePtr newMemRef = + Value newMemRef = bOuter.create(forInst->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. @@ -261,7 +261,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // dimension. for (auto &pair : startWaitPairs) { auto *dmaStartInst = pair.first; - ValuePtr oldMemRef = dmaStartInst->getOperand( + Value oldMemRef = dmaStartInst->getOperand( cast(dmaStartInst).getFasterMemPos()); if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked @@ -292,7 +292,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // Double the buffers for tag memrefs. 
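doubleBuffer above gives the memref an extra leading dimension of size 2 and rewrites accesses to index it with (roughly) the normalized loop IV modulo 2, so the DMA for the next iteration can fill one half while the current iteration computes out of the other. A toy illustration of that indexing, not the pass itself:

#include <cstdio>

int main() {
  float buf[2][8] = {};                  // [2 x original shape] replacement
  for (int iv = 0; iv < 6; ++iv) {
    float *compute = buf[iv % 2];        // half used by this iteration
    float *inFlight = buf[(iv + 1) % 2]; // half the next transfer can fill
    (void)compute;
    (void)inFlight;
    std::printf("iteration %d computes out of slot %d\n", iv, iv % 2);
  }
  return 0;
}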
for (auto &pair : startWaitPairs) { auto *dmaFinishInst = pair.second; - ValuePtr oldTagMemRef = + Value oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); @@ -333,7 +333,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { // If a slice wasn't created, the reachable affine.apply op's from its // operands are the ones that go with it. SmallVector affineApplyInsts; - SmallVector operands(dmaStartInst->getOperands()); + SmallVector operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); for (auto *op : affineApplyInsts) { instShiftMap[op] = 0; diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index ce39625831a..719c6fac731 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -81,7 +81,7 @@ LogicalResult OperationFolder::tryToFold( return failure(); // Try to fold the operation. - SmallVector results; + SmallVector results; if (failed(tryToFold(op, results, processGeneratedConstants))) return failure(); @@ -129,7 +129,7 @@ void OperationFolder::notifyRemoval(Operation *op) { /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult OperationFolder::tryToFold( - Operation *op, SmallVectorImpl &results, + Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants) { SmallVector operandConstants; SmallVector foldResults; @@ -172,7 +172,7 @@ LogicalResult OperationFolder::tryToFold( assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); // Check if the result was an SSA value. - if (auto repl = foldResults[i].dyn_cast()) { + if (auto repl = foldResults[i].dyn_cast()) { results.emplace_back(repl); continue; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 3ab4e287bb2..1eb9c57639a 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -109,7 +109,7 @@ private: // operation is modified or removed, as it may trigger further // simplifications. template void addToWorklist(Operands &&operands) { - for (ValuePtr operand : operands) { + for (Value operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. // TODO(riverriddle) This is based on the fact that zero use operations @@ -151,7 +151,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, region.walk(collectOps); // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; + SmallVector originalOperands, resultValues; changed = false; while (!worklist.empty()) { diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index e7b34bb3956..1ac286c67fb 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -89,7 +89,7 @@ void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. 
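addToWorklist above re-queues the defining op of any operand whose use count drops below two, since simplifying the current op may have removed that operand's last real use and made its producer trivially dead. A generic sketch of that pop-and-revisit loop with stand-in types, not the real rewrite driver:

#include <functional>
#include <vector>

struct Node {                            // stand-in for an operation
  std::vector<Node *> operands;
  int uses = 0;
};

// Pop-and-simplify loop: whenever an op is simplified, revisit the defining
// ops of its operands, whose use counts may have just dropped.
void drive(std::vector<Node *> worklist,
           const std::function<bool(Node *)> &trySimplify) {
  while (!worklist.empty()) {
    Node *op = worklist.back();
    worklist.pop_back();
    if (!trySimplify(op))
      continue;
    for (Node *operand : op->operands)
      if (operand->uses <= 1)            // mirrors the "< 2 uses" heuristic
        worklist.push_back(operand);
  }
}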
void InlinerInterface::handleTerminator(Operation *op, - ArrayRef valuesToRepl) const { + ArrayRef valuesToRepl) const { auto *handler = getInterfaceFor(op); assert(handler && "expected valid dialect handler"); handler->handleTerminator(op, valuesToRepl); @@ -128,7 +128,7 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, - ArrayRef resultsToReplace, + ArrayRef resultsToReplace, Optional inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. @@ -138,7 +138,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, // Check that all of the region arguments have been mapped. auto *srcEntryBlock = &src->front(); if (llvm::any_of(srcEntryBlock->getArguments(), - [&](BlockArgumentPtr arg) { return !mapper.contains(arg); })) + [&](BlockArgument arg) { return !mapper.contains(arg); })) return failure(); // The insertion point must be within a block. @@ -198,7 +198,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, } else { // Otherwise, there were multiple blocks inlined. Add arguments to the post // insertion block to represent the results to replace. - for (ValuePtr resultToRepl : resultsToReplace) { + for (Value resultToRepl : resultsToReplace) { resultToRepl->replaceAllUsesWith( postInsertBlock->addArgument(resultToRepl->getType())); } @@ -220,8 +220,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, /// in-favor of the region arguments when inlining. LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, - ArrayRef inlinedOperands, - ArrayRef resultsToReplace, + ArrayRef inlinedOperands, + ArrayRef resultsToReplace, Optional inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. @@ -237,7 +237,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { // Verify that the types of the provided values match the function argument // types. - BlockArgumentPtr regionArg = entryBlock->getArgument(i); + BlockArgument regionArg = entryBlock->getArgument(i); if (inlinedOperands[i]->getType() != regionArg->getType()) return failure(); mapper.map(regionArg, inlinedOperands[i]); @@ -250,10 +250,10 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, /// Utility function used to generate a cast operation from the given interface, /// or return nullptr if a cast could not be generated. -static ValuePtr materializeConversion(const DialectInlinerInterface *interface, - SmallVectorImpl &castOps, - OpBuilder &castBuilder, ValuePtr arg, - Type type, Location conversionLoc) { +static Value materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl &castOps, + OpBuilder &castBuilder, Value arg, Type type, + Location conversionLoc) { if (!interface) return nullptr; @@ -288,8 +288,8 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Make sure that the number of arguments and results matchup between the call // and the region. 
- SmallVector callOperands(call.getArgOperands()); - SmallVector callResults(call.getOperation()->getResults()); + SmallVector callOperands(call.getArgOperands()); + SmallVector callResults(call.getOperation()->getResults()); if (callOperands.size() != entryBlock->getNumArguments() || callResults.size() != callableResultTypes.size()) return failure(); @@ -316,8 +316,8 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Map the provided call operands to the arguments of the region. BlockAndValueMapping mapper; for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { - BlockArgumentPtr regionArg = entryBlock->getArgument(i); - ValuePtr operand = callOperands[i]; + BlockArgument regionArg = entryBlock->getArgument(i); + Value operand = callOperands[i]; // If the call operand doesn't match the expected region argument, try to // generate a cast. @@ -333,13 +333,13 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, // Ensure that the resultant values of the call, match the callable. castBuilder.setInsertionPointAfter(call); for (unsigned i = 0, e = callResults.size(); i != e; ++i) { - ValuePtr callResult = callResults[i]; + Value callResult = callResults[i]; if (callResult->getType() == callableResultTypes[i]) continue; // Generate a conversion that will produce the original type, so that the IR // is still valid after the original call gets replaced. - ValuePtr castResult = + Value castResult = materializeConversion(callInterface, castOps, castBuilder, callResult, callResult->getType(), castLoc); if (!castResult) diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 4745a26e168..b0d9fdf5fd8 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -36,7 +36,7 @@ using namespace mlir; // Gathers all load and store memref accesses in 'opA' into 'values', where // 'values[memref] == true' for each store operation. static void getLoadAndStoreMemRefAccesses(Operation *opA, - DenseMap &values) { + DenseMap &values) { opA->walk([&](Operation *op) { if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) @@ -51,7 +51,7 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA, // accessed 'values' and at least one of the access is a store operation. // Returns false otherwise. static bool isDependentLoadOrStoreOp(Operation *op, - DenseMap &values) { + DenseMap &values) { if (auto loadOp = dyn_cast(op)) { return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()] == true; @@ -66,7 +66,7 @@ static bool isDependentLoadOrStoreOp(Operation *op, static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opA'. // Map from memref value to bool which is true if store, false otherwise. - DenseMap values; + DenseMap values; getLoadAndStoreMemRefAccesses(opA, values); // For each 'opX' in block in range ('opA', 'opB'), check if there is a data @@ -92,7 +92,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opB'. // Map from memref value to bool which is true if store, false otherwise. 
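getLoadAndStoreMemRefAccesses and isDependentLoadOrStoreOp above encode the usual read/write conflict rule: an op sitting between the two nests blocks fusion if it loads a memref the first nest stores to, or stores to any memref the first nest touches at all. The same predicate on a plain map, with strings standing in for memref values:

#include <map>
#include <string>

// accessesOfA maps each memref touched by nest A to true if A stores to it.
bool isDependent(const std::map<std::string, bool> &accessesOfA,
                 const std::string &memref, bool opIsStore) {
  auto it = accessesOfA.find(memref);
  if (it == accessesOfA.end())
    return false;                 // memref never touched by A: no conflict
  return opIsStore || it->second; // store vs. any access, or load vs. store
}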
- DenseMap values; + DenseMap values; getLoadAndStoreMemRefAccesses(opB, values); // For each 'opX' in block in range ('opA', 'opB') in reverse order, @@ -434,7 +434,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, // Subtract from operation count the loads/store we expect load/store // forwarding to remove. unsigned storeCount = 0; - llvm::SmallDenseSet storeMemrefs; + llvm::SmallDenseSet storeMemrefs; srcForOp.walk([&](Operation *op) { if (auto storeOp = dyn_cast(op)) { storeMemrefs.insert(storeOp.getMemRef()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 3d4db22c866..0fece54132a 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -43,7 +43,7 @@ using llvm::SmallMapVector; /// expression. void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, AffineMap *map, - SmallVectorImpl *operands, + SmallVectorImpl *operands, OpBuilder &b) { auto lbMap = forOp.getLowerBoundMap(); @@ -54,7 +54,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, } AffineMap tripCountMap; - SmallVector tripCountOperands; + SmallVector tripCountOperands; buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands); // Sometimes the trip count cannot be expressed as an affine expression. @@ -73,7 +73,7 @@ void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all // these affine.apply's make up the cleanup loop lower bound. SmallVector bumpExprs(tripCountMap.getNumResults()); - SmallVector bumpValues(tripCountMap.getNumResults()); + SmallVector bumpValues(tripCountMap.getNumResults()); for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) { auto tripCountExpr = tripCountMap.getResult(i); bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step; @@ -128,7 +128,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { iv->replaceAllUsesWith(constOp); } else { AffineBound lb = forOp.getLowerBound(); - SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); + SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); OpBuilder builder(op->getBlock(), Block::iterator(op)); if (lb.getMap() == builder.getDimIdentityMap()) { // No need of generating an affine.apply. @@ -169,8 +169,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector>> &instGroupQueue, unsigned offset, AffineForOp srcForInst, OpBuilder b) { - SmallVector lbOperands(srcForInst.getLowerBoundOperands()); - SmallVector ubOperands(srcForInst.getUpperBoundOperands()); + SmallVector lbOperands(srcForInst.getLowerBoundOperands()); + SmallVector ubOperands(srcForInst.getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); @@ -440,7 +440,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, OpBuilder builder(op->getBlock(), ++Block::iterator(op)); auto cleanupForInst = cast(builder.clone(*op)); AffineMap cleanupMap; - SmallVector cleanupOperands; + SmallVector cleanupOperands; getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands, builder); assert(cleanupMap && @@ -660,8 +660,8 @@ void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { // ... 
// } // ``` -static void augmentMapAndBounds(OpBuilder &b, ValuePtr iv, AffineMap *map, - SmallVector *operands, +static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map, + SmallVector *operands, int64_t offset = 0) { auto bounds = llvm::to_vector<4>(map->getResults()); bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset); @@ -690,12 +690,12 @@ stripmineSink(AffineForOp forOp, uint64_t factor, // Lower-bound map creation. auto lbMap = forOp.getLowerBoundMap(); - SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector lbOperands(forOp.getLowerBoundOperands()); augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands); // Upper-bound map creation. auto ubMap = forOp.getUpperBoundMap(); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands, /*offset=*/scaledStep); @@ -720,7 +720,7 @@ stripmineSink(AffineForOp forOp, uint64_t factor, return innerLoops; } -static Loops stripmineSink(loop::ForOp forOp, ValuePtr factor, +static Loops stripmineSink(loop::ForOp forOp, Value factor, ArrayRef targets) { auto originalStep = forOp.step(); auto iv = forOp.getInductionVar(); @@ -736,10 +736,10 @@ static Loops stripmineSink(loop::ForOp forOp, ValuePtr factor, // Insert newForOp before the terminator of `t`. OpBuilder b(t.getBodyBuilder()); - ValuePtr stepped = b.create(t.getLoc(), iv, forOp.step()); - ValuePtr less = b.create(t.getLoc(), CmpIPredicate::slt, - forOp.upperBound(), stepped); - ValuePtr ub = + Value stepped = b.create(t.getLoc(), iv, forOp.step()); + Value less = b.create(t.getLoc(), CmpIPredicate::slt, + forOp.upperBound(), stepped); + Value ub = b.create(t.getLoc(), less, forOp.upperBound(), stepped); // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. @@ -790,7 +790,7 @@ mlir::tile(ArrayRef forOps, ArrayRef sizes, } SmallVector mlir::tile(ArrayRef forOps, - ArrayRef sizes, + ArrayRef sizes, ArrayRef targets) { return tileImpl(forOps, sizes, targets); } @@ -812,13 +812,12 @@ SmallVector mlir::tile(ArrayRef forOps, return tileImpl(forOps, sizes, target); } -Loops mlir::tile(ArrayRef forOps, ArrayRef sizes, +Loops mlir::tile(ArrayRef forOps, ArrayRef sizes, loop::ForOp target) { return tileImpl(forOps, sizes, target); } -Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, - ArrayRef sizes) { +Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef sizes) { // Collect perfectly nested loops. If more size values provided than nested // loops available, truncate `sizes`. SmallVector forOps; @@ -833,15 +832,14 @@ Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, // Build the IR that performs ceil division of a positive value by a constant: // ceildiv(a, B) = divis(a + (B-1), B) // where divis is rounding-to-zero division. 
-static ValuePtr ceilDivPositive(OpBuilder &builder, Location loc, - ValuePtr dividend, int64_t divisor) { +static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, + int64_t divisor) { assert(divisor > 0 && "expected positive divisor"); assert(dividend->getType().isIndex() && "expected index-typed value"); - ValuePtr divisorMinusOneCst = - builder.create(loc, divisor - 1); - ValuePtr divisorCst = builder.create(loc, divisor); - ValuePtr sum = builder.create(loc, dividend, divisorMinusOneCst); + Value divisorMinusOneCst = builder.create(loc, divisor - 1); + Value divisorCst = builder.create(loc, divisor); + Value sum = builder.create(loc, dividend, divisorMinusOneCst); return builder.create(loc, sum, divisorCst); } @@ -849,13 +847,13 @@ static ValuePtr ceilDivPositive(OpBuilder &builder, Location loc, // positive value: // ceildiv(a, b) = divis(a + (b - 1), b) // where divis is rounding-to-zero division. -static ValuePtr ceilDivPositive(OpBuilder &builder, Location loc, - ValuePtr dividend, ValuePtr divisor) { +static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, + Value divisor) { assert(dividend->getType().isIndex() && "expected index-typed value"); - ValuePtr cstOne = builder.create(loc, 1); - ValuePtr divisorMinusOne = builder.create(loc, divisor, cstOne); - ValuePtr sum = builder.create(loc, dividend, divisorMinusOne); + Value cstOne = builder.create(loc, 1); + Value divisorMinusOne = builder.create(loc, divisor, cstOne); + Value sum = builder.create(loc, dividend, divisorMinusOne); return builder.create(loc, sum, divisor); } @@ -937,7 +935,7 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, // iterations. Given that the loop current executes // numIterations = ceildiv((upperBound - lowerBound), step) // iterations, we need to tile with size ceildiv(numIterations, size[i]). - SmallVector tileSizes; + SmallVector tileSizes; tileSizes.reserve(sizes.size()); for (unsigned i = 0, e = sizes.size(); i < e; ++i) { assert(sizes[i] > 0 && "expected strictly positive size for strip-mining"); @@ -945,10 +943,10 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, auto forOp = forOps[i]; OpBuilder builder(forOp); auto loc = forOp.getLoc(); - ValuePtr diff = + Value diff = builder.create(loc, forOp.upperBound(), forOp.lowerBound()); - ValuePtr numIterations = ceilDivPositive(builder, loc, diff, forOp.step()); - ValuePtr iterationsPerBlock = + Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step()); + Value iterationsPerBlock = ceilDivPositive(builder, loc, numIterations, sizes[i]); tileSizes.push_back(iterationsPerBlock); } @@ -968,7 +966,7 @@ TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, // Replaces all uses of `orig` with `replacement` except if the user is listed // in `exceptions`. static void -replaceAllUsesExcept(ValuePtr orig, ValuePtr replacement, +replaceAllUsesExcept(Value orig, Value replacement, const SmallPtrSetImpl &exceptions) { for (auto &use : llvm::make_early_inc_range(orig->getUses())) { if (exceptions.count(use.getOwner()) == 0) @@ -1010,30 +1008,30 @@ static void normalizeLoop(loop::ForOp loop, loop::ForOp outer, // of the loop to go from 0 to the number of iterations, if necessary. // TODO(zinenko): introduce support for negative steps or emit dynamic asserts // on step positivity, whatever gets implemented first. 
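// Illustrative aside (hypothetical helpers, plain integers): the arithmetic
// the normalization below emits as ops. The loop is rewritten to iterate
// from 0 to ceildiv(ub - lb, step) with step 1, and the body recovers the
// original induction variable as iv * step + lb.
#include <cassert>
#include <cstdint>
static int64_t normalizedTripCount(int64_t lb, int64_t ub, int64_t step) {
  assert(step > 0 && ub >= lb && "sketch assumes a positive step and ub >= lb");
  return (ub - lb + step - 1) / step; // trip count of the normalized loop
}
static int64_t denormalizeIv(int64_t ivNew, int64_t lb, int64_t step) {
  return ivNew * step + lb; // value the original loop body observes
}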
- ValuePtr diff = + Value diff = builder.create(loc, loop.upperBound(), loop.lowerBound()); - ValuePtr numIterations = ceilDivPositive(builder, loc, diff, loop.step()); + Value numIterations = ceilDivPositive(builder, loc, diff, loop.step()); loop.setUpperBound(numIterations); - ValuePtr lb = loop.lowerBound(); + Value lb = loop.lowerBound(); if (!isZeroBased) { - ValuePtr cst0 = builder.create(loc, 0); + Value cst0 = builder.create(loc, 0); loop.setLowerBound(cst0); } - ValuePtr step = loop.step(); + Value step = loop.step(); if (!isStepOne) { - ValuePtr cst1 = builder.create(loc, 1); + Value cst1 = builder.create(loc, 1); loop.setStep(cst1); } // Insert code computing the value of the original loop induction variable // from the "normalized" one. builder.setInsertionPointToStart(inner.getBody()); - ValuePtr scaled = + Value scaled = isStepOne ? loop.getInductionVar() : builder.create(loc, loop.getInductionVar(), step); - ValuePtr shifted = + Value shifted = isZeroBased ? scaled : builder.create(loc, scaled, lb); SmallPtrSet preserve{scaled->getDefiningOp(), @@ -1057,7 +1055,7 @@ void mlir::coalesceLoops(MutableArrayRef loops) { // of the number of iterations of all loops. OpBuilder builder(outermost); Location loc = outermost.getLoc(); - ValuePtr upperBound = outermost.upperBound(); + Value upperBound = outermost.upperBound(); for (auto loop : loops.drop_front()) upperBound = builder.create(loc, upperBound, loop.upperBound()); outermost.setUpperBound(upperBound); @@ -1072,16 +1070,16 @@ void mlir::coalesceLoops(MutableArrayRef loops) { // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i. // Compute these iteratively from the innermost loop by creating a "running // quotient" of division by the range. - ValuePtr previous = outermost.getInductionVar(); + Value previous = outermost.getInductionVar(); for (unsigned i = 0, e = loops.size(); i < e; ++i) { unsigned idx = loops.size() - i - 1; if (i != 0) previous = builder.create(loc, previous, loops[idx + 1].upperBound()); - ValuePtr iv = (i == e - 1) ? previous - : builder.create( - loc, previous, loops[idx].upperBound()); + Value iv = (i == e - 1) ? 
previous + : builder.create( + loc, previous, loops[idx].upperBound()); replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv, loops.back().region()); } @@ -1096,24 +1094,23 @@ void mlir::coalesceLoops(MutableArrayRef loops) { second.erase(); } -void mlir::mapLoopToProcessorIds(loop::ForOp forOp, - ArrayRef processorId, - ArrayRef numProcessors) { +void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef processorId, + ArrayRef numProcessors) { assert(processorId.size() == numProcessors.size()); if (processorId.empty()) return; OpBuilder b(forOp); Location loc(forOp.getLoc()); - ValuePtr mul = processorId.front(); + Value mul = processorId.front(); for (unsigned i = 1, e = processorId.size(); i < e; ++i) mul = b.create(loc, b.create(loc, mul, numProcessors[i]), processorId[i]); - ValuePtr lb = b.create(loc, forOp.lowerBound(), - b.create(loc, forOp.step(), mul)); + Value lb = b.create(loc, forOp.lowerBound(), + b.create(loc, forOp.step(), mul)); forOp.setLowerBound(lb); - ValuePtr step = forOp.step(); + Value step = forOp.step(); for (auto numProcs : numProcessors) step = b.create(loc, step, numProcs); forOp.setStep(step); @@ -1131,7 +1128,7 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, Block::iterator *copyInPlacementStart, Block::iterator *copyOutPlacementStart) { const auto *cst = region.getConstraints(); - SmallVector symbols; + SmallVector symbols; cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols); SmallVector enclosingFors; @@ -1194,10 +1191,10 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart' /// holds the lower coordinates of the region in the original memref to copy /// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in. -static AffineForOp generatePointWiseCopy(Location loc, ValuePtr memref, - ValuePtr fastMemRef, +static AffineForOp generatePointWiseCopy(Location loc, Value memref, + Value fastMemRef, AffineMap memAffineMap, - ArrayRef memIndicesStart, + ArrayRef memIndicesStart, ArrayRef fastBufferShape, bool isCopyOut, OpBuilder b) { assert(!memIndicesStart.empty() && "only 1-d or more memrefs"); @@ -1207,7 +1204,7 @@ static AffineForOp generatePointWiseCopy(Location loc, ValuePtr memref, // for y = ... // fast_buf[x][y] = buf[mem_x + x][mem_y + y] - SmallVector fastBufIndices, memIndices; + SmallVector fastBufIndices, memIndices; AffineForOp copyNestRoot; for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) { auto forOp = b.create(loc, 0, fastBufferShape[d]); @@ -1216,7 +1213,7 @@ static AffineForOp generatePointWiseCopy(Location loc, ValuePtr memref, b = forOp.getBodyBuilder(); fastBufIndices.push_back(forOp.getInductionVar()); - ValuePtr memBase = + Value memBase = (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims())) ? 
memIndicesStart[d] : b.create( @@ -1269,7 +1266,7 @@ static LogicalResult generateCopy( const MemRefRegion ®ion, Block *block, Block::iterator begin, Block::iterator end, Block *copyPlacementBlock, Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart, - AffineCopyOptions copyOptions, DenseMap &fastBufferMap, + AffineCopyOptions copyOptions, DenseMap &fastBufferMap, DenseSet ©Nests, uint64_t *sizeInBytes, Block::iterator *nBegin, Block::iterator *nEnd) { *nBegin = begin; @@ -1277,7 +1274,7 @@ static LogicalResult generateCopy( FuncOp f = begin->getParentOfType(); OpBuilder topBuilder(f.getBody()); - ValuePtr zeroIndex = topBuilder.create(f.getLoc(), 0); + Value zeroIndex = topBuilder.create(f.getLoc(), 0); if (begin == end) return success(); @@ -1309,9 +1306,9 @@ static LogicalResult generateCopy( // Indices to use for the copying. // Indices for the original memref being copied from/to. - SmallVector memIndices; + SmallVector memIndices; // Indices for the faster buffer being copied into/from. - SmallVector bufIndices; + SmallVector bufIndices; unsigned rank = memRefType.getRank(); SmallVector fastBufferShape; @@ -1337,7 +1334,7 @@ static LogicalResult generateCopy( // 'regionSymbols' hold values that this memory region is symbolic/parametric // on; these typically include loop IVs surrounding the level at which the // copy generation is being done or other valid symbols in MLIR. - SmallVector regionSymbols; + SmallVector regionSymbols; cst->getIdValues(rank, cst->getNumIds(), ®ionSymbols); // Construct the index expressions for the fast memory buffer. The index @@ -1385,7 +1382,7 @@ static LogicalResult generateCopy( } // The faster memory space buffer. - ValuePtr fastMemRef; + Value fastMemRef; // Check if a buffer was already created. bool existingBuf = fastBufferMap.count(memref) > 0; @@ -1425,8 +1422,8 @@ static LogicalResult generateCopy( return failure(); } - ValuePtr stride = nullptr; - ValuePtr numEltPerStride = nullptr; + Value stride = nullptr; + Value numEltPerStride = nullptr; if (!strideInfos.empty()) { stride = top.create(loc, strideInfos[0].stride); numEltPerStride = @@ -1465,7 +1462,7 @@ static LogicalResult generateCopy( copyOptions.tagMemorySpace); auto tagMemRef = prologue.create(loc, tagMemRefType); - SmallVector tagIndices({zeroIndex}); + SmallVector tagIndices({zeroIndex}); auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { @@ -1574,7 +1571,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, SmallVector ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); - SmallVector symbols; + SmallVector symbols; extractForInductionVars(ivs, &symbols); regionCst->reset(rank, numParamLoopIVs, 0); regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols); @@ -1621,12 +1618,12 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, // List of memory regions to copy for. We need a map vector to have a // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here // since the alloc's for example are identical except for the SSA id. - SmallMapVector, 4> readRegions; - SmallMapVector, 4> writeRegions; + SmallMapVector, 4> readRegions; + SmallMapVector, 4> writeRegions; // Map from original memref's to the fast buffers that their accesses are // replaced with. - DenseMap fastBufferMap; + DenseMap fastBufferMap; // To check for errors when walking the block. 
bool error = false; @@ -1676,7 +1673,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, // Attempts to update; returns true if 'region' exists in targetRegions. auto updateRegion = - [&](const SmallMapVector, 4> + [&](const SmallMapVector, 4> &targetRegions) { auto it = targetRegions.find(region->memref); if (it == targetRegions.end()) @@ -1728,7 +1725,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin, uint64_t totalCopyBuffersSizeInBytes = 0; bool ret = true; auto processRegions = - [&](const SmallMapVector, 4> + [&](const SmallMapVector, 4> ®ions) { for (const auto ®ionEntry : regions) { // For each region, hoist copy in/out past all hoistable diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 569c5416edd..ca26074f288 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -18,7 +18,7 @@ using namespace mlir; -void mlir::replaceAllUsesInRegionWith(ValuePtr orig, ValuePtr replacement, +void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion) { for (auto &use : llvm::make_early_inc_range(orig->getUses())) { if (region.isAncestor(use.getOwner()->getParentRegion())) @@ -54,14 +54,14 @@ void mlir::visitUsedValuesDefinedAbove( } void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, - llvm::SetVector &values) { + llvm::SetVector &values) { visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { values.insert(operand->get()); }); } void mlir::getUsedValuesDefinedAbove(MutableArrayRef regions, - llvm::SetVector &values) { + llvm::SetVector &values) { for (Region ®ion : regions) getUsedValuesDefinedAbove(region, region, values); } @@ -137,8 +137,8 @@ namespace { class LiveMap { public: /// Value methods. - bool wasProvenLive(ValuePtr value) { return liveValues.count(value); } - void setProvedLive(ValuePtr value) { + bool wasProvenLive(Value value) { return liveValues.count(value); } + void setProvedLive(Value value) { changed |= liveValues.insert(value).second; } @@ -152,7 +152,7 @@ public: private: bool changed = false; - DenseSet liveValues; + DenseSet liveValues; DenseSet liveOps; }; } // namespace @@ -179,7 +179,7 @@ static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { return false; } -static void processValue(ValuePtr value, LiveMap &liveMap) { +static void processValue(Value value, LiveMap &liveMap) { bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) { if (isUseSpeciallyKnownDead(use, liveMap)) return false; @@ -213,9 +213,9 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) { liveMap.setProvedLive(op); return; } - for (ValuePtr value : op->getResults()) + for (Value value : op->getResults()) processValue(value, liveMap); - bool provedLive = llvm::any_of(op->getResults(), [&](ValuePtr value) { + bool provedLive = llvm::any_of(op->getResults(), [&](Value value) { return liveMap.wasProvenLive(value); }); if (provedLive) @@ -231,7 +231,7 @@ static void propagateLiveness(Region ®ion, LiveMap &liveMap) { // faster convergence to a fixed point (we try to visit uses before defs). 
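// Illustrative aside (toy types, not the pass itself): the backward
// fixed-point shape this liveness propagation follows, with values modelled
// as integers and side-effecting ops pinned live.
#include <set>
#include <vector>
struct ToyOp {
  int result;                // value the op defines
  std::vector<int> operands; // values the op uses
  bool hasSideEffects;       // such ops are considered live unconditionally
};
static std::set<int> computeLiveValues(const std::vector<ToyOp> &ops) {
  std::set<int> live;
  bool changed = true;
  while (changed) {          // iterate until no new value is proven live
    changed = false;
    for (auto it = ops.rbegin(); it != ops.rend(); ++it) { // uses before defs
      if (!it->hasSideEffects && !live.count(it->result))
        continue;            // op not (yet) proven live; nothing to propagate
      for (int operand : it->operands)
        changed |= live.insert(operand).second;
    }
  }
  return live;
}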
for (Operation &op : llvm::reverse(block->getOperations())) propagateLiveness(&op, liveMap); - for (ValuePtr value : block->getArguments()) + for (Value value : block->getArguments()) processValue(value, liveMap); } } @@ -250,7 +250,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, // Iterating args in reverse is needed for correctness, to avoid // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; - ValuePtr value = terminator->getSuccessor(succ)->getArgument(arg); + Value value = terminator->getSuccessor(succ)->getArgument(arg); if (!liveMap.wasProvenLive(value)) { terminator->eraseSuccessorOperand(succ, arg); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 409729a5f20..a6629183dee 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -38,8 +38,7 @@ static bool isMemRefDereferencingOp(Operation &op) { } /// Return the AffineMapAttr associated with memory 'op' on 'memref'. -static NamedAttribute getAffineMapAttrForMemRef(Operation *op, - ValuePtr memref) { +static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) { return TypeSwitch(op) .Case( @@ -47,10 +46,12 @@ static NamedAttribute getAffineMapAttrForMemRef(Operation *op, } // Perform the replacement in `op`. -LogicalResult mlir::replaceAllMemRefUsesWith( - ValuePtr oldMemRef, ValuePtr newMemRef, Operation *op, - ArrayRef extraIndices, AffineMap indexRemap, - ArrayRef extraOperands, ArrayRef symbolOperands) { +LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, + Operation *op, + ArrayRef extraIndices, + AffineMap indexRemap, + ArrayRef extraOperands, + ArrayRef symbolOperands) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -96,13 +97,13 @@ LogicalResult mlir::replaceAllMemRefUsesWith( NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); - SmallVector oldMapOperands( + SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. - SmallVector oldMemRefOperands; - SmallVector affineApplyOps; + SmallVector oldMemRefOperands; + SmallVector affineApplyOps; oldMemRefOperands.reserve(oldMemRefRank); if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { for (auto resultExpr : oldMap.getResults()) { @@ -120,14 +121,14 @@ LogicalResult mlir::replaceAllMemRefUsesWith( // Construct new indices as a remap of the old ones if a remapping has been // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. 
- SmallVector remapOperands; + SmallVector remapOperands; remapOperands.reserve(extraOperands.size() + oldMemRefRank + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); remapOperands.append(symbolOperands.begin(), symbolOperands.end()); - SmallVector remapOutputs; + SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); if (indexRemap && @@ -146,7 +147,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith( remapOutputs.append(remapOperands.begin(), remapOperands.end()); } - SmallVector newMapOperands; + SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. @@ -214,11 +215,13 @@ LogicalResult mlir::replaceAllMemRefUsesWith( return success(); } -LogicalResult mlir::replaceAllMemRefUsesWith( - ValuePtr oldMemRef, ValuePtr newMemRef, ArrayRef extraIndices, - AffineMap indexRemap, ArrayRef extraOperands, - ArrayRef symbolOperands, Operation *domInstFilter, - Operation *postDomInstFilter) { +LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, + ArrayRef extraIndices, + AffineMap indexRemap, + ArrayRef extraOperands, + ArrayRef symbolOperands, + Operation *domInstFilter, + Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); @@ -319,7 +322,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith( void mlir::createAffineComputationSlice( Operation *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. - SmallVector subOperands; + SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); for (auto operand : opInst->getOperands()) if (isa_and_nonnull(operand->getDefiningOp())) @@ -349,7 +352,7 @@ void mlir::createAffineComputationSlice( return; OpBuilder builder(opInst); - SmallVector composedOpOperands(subOperands); + SmallVector composedOpOperands(subOperands); auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); @@ -366,7 +369,7 @@ void mlir::createAffineComputationSlice( // affine apply op above instead of existing ones (subOperands). So, they // differ from opInst's operands only for those operands in 'subOperands', for // which they will be replaced by the corresponding one from 'sliceOps'. - SmallVector newOperands(opInst->getOperands()); + SmallVector newOperands(opInst->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. unsigned j, f; @@ -440,7 +443,7 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { } auto oldMemRef = allocOp.getResult(); - SmallVector symbolOperands(allocOp.getSymbolicOperands()); + SmallVector symbolOperands(allocOp.getSymbolicOperands()); auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 2dbac868cc0..6b2b3e1ee7e 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -696,7 +696,7 @@ struct VectorizationState { // Map of old scalar Operation to new vectorized Operation. DenseMap vectorizationMap; // Map of old scalar Value to new vectorized Value. 
- DenseMap replacementMap; + DenseMap replacementMap; // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; // Use-def roots. These represent the starting points for the worklist in the @@ -719,7 +719,7 @@ struct VectorizationState { OperationFolder *folder; private: - void registerReplacement(ValuePtr key, ValuePtr value); + void registerReplacement(Value key, Value value); }; } // end namespace @@ -759,7 +759,7 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(ValuePtr key, ValuePtr value) { +void VectorizationState::registerReplacement(Value key, Value value) { assert(replacementMap.count(key) == 0 && "replacement already registered"); replacementMap.insert(std::make_pair(key, value)); } @@ -767,7 +767,7 @@ void VectorizationState::registerReplacement(ValuePtr key, ValuePtr value) { // Apply 'map' with 'mapOperands' returning resulting values in 'results'. static void computeMemoryOpIndices(Operation *op, AffineMap map, ValueRange mapOperands, - SmallVectorImpl &results) { + SmallVectorImpl &results) { OpBuilder builder(op); for (auto resultExpr : map.getResults()) { auto singleResMap = @@ -794,7 +794,7 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, /// Such special cases force us to delay the vectorization of the stores until /// the last step. Here we merely register the store operation. template -static LogicalResult vectorizeRootOrTerminal(ValuePtr iv, +static LogicalResult vectorizeRootOrTerminal(Value iv, LoadOrStoreOpPointer memoryOp, VectorizationState *state) { auto memRefType = memoryOp.getMemRef()->getType().template cast(); @@ -814,7 +814,7 @@ static LogicalResult vectorizeRootOrTerminal(ValuePtr iv, if (auto load = dyn_cast(opInst)) { OpBuilder b(opInst); ValueRange mapOperands = load.getMapOperands(); - SmallVector indices; + SmallVector indices; indices.reserve(load.getMemRefType().getRank()); if (load.getAffineMap() != b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { @@ -941,8 +941,7 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. -static ValuePtr vectorizeConstant(Operation *op, ConstantOp constant, - Type type) { +static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) { if (!type || !type.isa() || !VectorType::isValidElementType(constant.getType())) { return nullptr; @@ -980,8 +979,8 @@ static ValuePtr vectorizeConstant(Operation *op, ConstantOp constant, /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static ValuePtr vectorizeOperand(ValuePtr operand, Operation *op, - VectorizationState *state) { +static Value vectorizeOperand(Value operand, Operation *op, + VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); // 1. If this value has already been vectorized this round, we are done. 
@@ -1043,7 +1042,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, auto vectorValue = vectorizeOperand(value, opInst, state); ValueRange mapOperands = store.getMapOperands(); - SmallVector indices; + SmallVector indices; indices.reserve(store.getMemRefType().getRank()); if (store.getAffineMap() != b.getMultiDimIdentityMap(store.getMemRefType().getRank())) { @@ -1076,12 +1075,12 @@ static Operation *vectorizeOneOperation(Operation *opInst, vectorTypes.push_back( VectorType::get(state->strategy->vectorSizes, v->getType())); } - SmallVector vectorOperands; + SmallVector vectorOperands; for (auto v : opInst->getOperands()) { vectorOperands.push_back(vectorizeOperand(v, opInst, state)); } // Check whether a single operand is null. If so, vectorization failed. - bool success = llvm::all_of(vectorOperands, [](ValuePtr op) { return op; }); + bool success = llvm::all_of(vectorOperands, [](Value op) { return op; }); if (!success) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize"); return nullptr; diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index cfe703281e2..c776ffe12bd 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -793,7 +793,7 @@ TEST_FUNC(affine_if_op) { }; auto intSet = IntegerSet::get(2, 2, affineExprs, isEq); - SmallVector affineIfArgs = {zero, zero, ten, ten}; + SmallVector affineIfArgs = {zero, zero, ten, ten}; intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false); intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true); diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 976a1976f01..21cf69ec1fa 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -91,7 +91,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + ArrayRef valuesToRepl) const final { // Only handle "test.return" here. auto returnOp = dyn_cast(op); if (!returnOp) @@ -108,7 +108,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, ValuePtr input, + Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { // Only allow conversion for i16/i32 types. @@ -222,7 +222,7 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser, // Create a return terminator in the inner region, pass as operand to the // terminator the returned values from the wrapped operation. 
- SmallVector return_operands(wrapped_op->getResults()); + SmallVector return_operands(wrapped_op->getResults()); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrapped_op->getLoc(), return_operands); @@ -288,7 +288,7 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { - for (ValuePtr input : this->operands()) { + for (Value input : this->operands()) { results.push_back(input); } return success(); diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index a8709ddca27..dacb796de18 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -635,7 +635,7 @@ def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> { let builders = [ OpBuilder< - "Builder *builder, OperationState &state, ValuePtr operand", + "Builder *builder, OperationState &state, Value operand", [{ state.types.assign({builder->getIntegerType(32)}); state.addOperands({operand}); diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index b886097202d..929c4a941a2 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -13,12 +13,11 @@ using namespace mlir; // Native function for testing NativeCodeCall -static ValuePtr chooseOperand(ValuePtr input1, ValuePtr input2, - BoolAttr choice) { +static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { return choice.getValue() ? input1 : input2; } -static void createOpI(PatternRewriter &rewriter, ValuePtr input) { +static void createOpI(PatternRewriter &rewriter, Value input) { rewriter.create(rewriter.getUnknownLoc(), input); } @@ -65,7 +64,7 @@ struct ReturnTypeOpMatch : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { if (auto retTypeFn = dyn_cast(op)) { - SmallVector values(op->getOperands()); + SmallVector values(op->getOperands()); SmallVector inferedReturnTypes; if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(), op->getRegions(), @@ -124,7 +123,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern { : ConversionPattern("test.region", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. auto &parentRegion = *op->getParentRegion(); @@ -157,7 +156,7 @@ struct TestRegionRewriteUndo : public RewritePattern { // Add an explicitly illegal operation to ensure the conversion fails. rewriter.create(op->getLoc(), rewriter.getIntegerType(32)); - rewriter.create(op->getLoc(), ArrayRef()); + rewriter.create(op->getLoc(), ArrayRef()); // Drop this operation. 
rewriter.eraseOp(op); @@ -174,7 +173,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern { : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Region ®ion = op->getRegion(0); Block *entry = ®ion.front(); @@ -200,7 +199,7 @@ struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, llvm::None, operands, llvm::None); @@ -212,7 +211,7 @@ struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) @@ -237,7 +236,7 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { TestChangeProducerTypeI32ToF32(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is I32, change the type to F32. if (!(*op->result_type_begin()).isInteger(32)) @@ -250,7 +249,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { TestChangeProducerTypeF32ToF64(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is F32, change the type to F64. if (!(*op->result_type_begin()).isF32()) @@ -263,7 +262,7 @@ struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) : ConversionPattern("test.type_producer", 10, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Always convert to B16, even though it is not a legal type. This tests // that values are unmapped correctly. @@ -275,7 +274,7 @@ struct TestUpdateConsumerType : public ConversionPattern { TestUpdateConsumerType(MLIRContext *ctx) : ConversionPattern("test.type_consumer", 1, ctx) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Verify that the incoming operand has been successfully remapped to F64. if (!operands[0]->getType().isF64()) @@ -336,7 +335,7 @@ struct TestTypeConverter : public TypeConverter { /// Override the hook to materialize a conversion. This is necessary because /// we generate 1->N type mappings. 
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, + ArrayRef inputs, Location loc) override { return rewriter.create(loc, resultType, inputs); } @@ -459,13 +458,13 @@ struct OneVResOneVOperandOp1Converter using OpConversionPattern::OpConversionPattern; PatternMatchResult - matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto origOps = op.getOperands(); assert(std::distance(origOps.begin(), origOps.end()) == 1 && "One operand expected"); - ValuePtr origOp = *origOps.begin(); - SmallVector remappedOperands; + Value origOp = *origOps.begin(); + SmallVector remappedOperands; // Replicate the remapped original operand twice. Note that we don't used // the remapped 'operand' since the goal is testing 'getRemappedValue'. remappedOperands.push_back(rewriter.getRemappedValue(origOp)); diff --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp index 5b1394d5996..86e5713eb03 100644 --- a/mlir/test/lib/Transforms/TestLoopMapping.cpp +++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp @@ -32,7 +32,7 @@ public: // SSA values for the transformation are created out of thin air by // unregistered "new_processor_id_and_range" operations. This is enough to // emulate mapping conditions. - SmallVector processorIds, numProcessors; + SmallVector processorIds, numProcessors; func.walk([&processorIds, &numProcessors](Operation *op) { if (op->getName().getStringRef() != "new_processor_id_and_range") return; diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp index e131f4803ef..6f4d948e55f 100644 --- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp @@ -236,7 +236,7 @@ void VectorizerTestPass::testNormalizeMaps() { for (auto m : matches) { auto app = cast(m.getMatchedOperation()); OpBuilder b(m.getMatchedOperation()); - SmallVector operands(app.getOperands()); + SmallVector operands(app.getOperands()); makeComposedAffineApply(b, app.getLoc(), app.getAffineMap(), operands); } } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 004e7662299..5e6d56ccfa2 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -216,9 +216,9 @@ def MixOperandsAndAttrs : NS_Op<"mix_operands_and_attrs", []> { } // DEF-LABEL: MixOperandsAndAttrs definitions -// DEF-DAG: ValuePtr MixOperandsAndAttrs::operand() -// DEF-DAG: ValuePtr MixOperandsAndAttrs::otherArg() -// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, ValuePtr operand, FloatAttr otherAttr, ValuePtr otherArg) +// DEF-DAG: Value MixOperandsAndAttrs::operand() +// DEF-DAG: Value MixOperandsAndAttrs::otherArg() +// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg) // DEF-DAG: APFloat MixOperandsAndAttrs::attr() // DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr() diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 55952236429..74da938bd67 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -26,7 +26,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { ); let regions = (region 
AnyRegion:$someRegion); - let builders = [OpBuilder<"ValuePtr val">]; + let builders = [OpBuilder<"Value val">]; let parser = [{ foo }]; let printer = [{ bar }]; let verifier = [{ baz }]; @@ -46,12 +46,12 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { // CHECK: class AOpOperandAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ArrayRef values); -// CHECK: ArrayRef getODSOperands(unsigned index); -// CHECK: ValuePtr a(); -// CHECK: ArrayRef b(); +// CHECK: AOpOperandAdaptor(ArrayRef values); +// CHECK: ArrayRef getODSOperands(unsigned index); +// CHECK: Value a(); +// CHECK: ArrayRef b(); // CHECK: private: -// CHECK: ArrayRef tblgen_operands; +// CHECK: ArrayRef tblgen_operands; // CHECK: }; // CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl @@ -60,18 +60,18 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> { // CHECK: using OperandAdaptor = AOpOperandAdaptor; // CHECK: static StringRef getOperationName(); // CHECK: Operation::operand_range getODSOperands(unsigned index); -// CHECK: ValuePtr a(); +// CHECK: Value a(); // CHECK: Operation::operand_range b(); // CHECK: Operation::result_range getODSResults(unsigned index); -// CHECK: ValuePtr r(); +// CHECK: Value r(); // CHECK: Region &someRegion(); // CHECK: IntegerAttr attr1Attr() // CHECK: APInt attr1(); // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); -// CHECK: static void build(ValuePtr val); -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, ValuePtr a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, ValuePtr a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Value val); +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) // CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); @@ -111,7 +111,7 @@ def NS_DOp : NS_Op<"op_with_two_operands", []> { def NS_SkipDefaultBuildersOp : NS_Op<"skip_default_builders", []> { let skipDefaultBuilders = 1; - let builders = [OpBuilder<"ValuePtr val">]; + let builders = [OpBuilder<"Value val">]; } // CHECK-LABEL: NS::SkipDefaultBuildersOp declarations diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index c592686ebd3..e2d5862e3ca 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -18,7 +18,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK-NEXT: tblgen_operands = values // CHECK: void OpA::build -// CHECK: ValuePtr input +// CHECK: Value input // CHECK: tblgen_state.addOperands(input); // CHECK: void OpA::build @@ -39,19 +39,19 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3); } -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input1 +// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input1 // CHECK-NEXT: return getODSOperands(0); 
-// CHECK-LABEL: ValuePtr OpDOperandAdaptor::input2 +// CHECK-LABEL: Value OpDOperandAdaptor::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 +// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); // CHECK-LABEL: Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); -// CHECK-LABEL: ValuePtr OpD::input2 +// CHECK-LABEL: Value OpD::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); // CHECK-LABEL: OpD::build diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index f9a77ea492e..a177f9ca305 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -23,9 +23,9 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { } // CHECK-LABEL: OpB definitions -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, ValuePtr x) +// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value x) // CHECK: tblgen_state.addTypes(y); -// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, ValuePtr x) +// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value x) // CHECK: tblgen_state.addTypes({x->getType()}); def OpC : NS_Op<"three_normal_result_op", []> { @@ -89,7 +89,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> // CHECK-LABEL: Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); -// CHECK-LABEL: ValuePtr OpI::output2 +// CHECK-LABEL: Value OpI::output2 // CHECK-NEXT: return *getODSResults(1).begin(); // CHECK-LABEL: OpI::build diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index fef1b139dc9..ecfe709aa1b 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,7 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> { } // CHECK-LABEL: OpA::verify -// CHECK: for (ValuePtr v : getODSOperands(0)) { +// CHECK: for (Value v : getODSOperands(0)) { // CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ @@ -90,5 +90,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> { } // CHECK-LABEL: OpK::verify -// CHECK: for (ValuePtr v : getODSOperands(0)) { +// CHECK: for (Value v : getODSOperands(0)) { // CHECK: if (!(((v->getType().isa())) && (((v->getType().cast().getElementType().isF32())) || ((v->getType().cast().getElementType().isInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 52cbc08c429..f5b3e0163a1 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -704,8 +704,8 @@ void OpEmitter::genAttrGetters() { // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that -// return a range of operands (individual operands are `ValuePtr ` and each -// element in the range must also be `ValuePtr `); use `rangeBeginCall` to get +// return a range of operands (individual operands are `Value ` and each +// element in the range must also be `Value `); use `rangeBeginCall` to get // an iterator to the beginning of the operand range; use `rangeSizeCall` to // obtain the number of operands. 
`getOperandCallPattern` contains the code // necessary to obtain a single operand whose position will be substituted @@ -782,7 +782,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, auto &m = opClass.newMethod(rangeType, operand.name); m.body() << " return getODSOperands(" << i << ");"; } else { - auto &m = opClass.newMethod("ValuePtr ", operand.name); + auto &m = opClass.newMethod("Value ", operand.name); m.body() << " return *getODSOperands(" << i << ").begin();"; } } @@ -860,7 +860,7 @@ void OpEmitter::genNamedResultGetters() { auto &m = opClass.newMethod("Operation::result_range", result.name); m.body() << " return getODSResults(" << i << ");"; } else { - auto &m = opClass.newMethod("ValuePtr ", result.name); + auto &m = opClass.newMethod("Value ", result.name); m.body() << " return *getODSResults(" << i << ").begin();"; } } @@ -1238,7 +1238,7 @@ void OpEmitter::buildParamList(std::string ¶mList, auto argument = op.getArg(i); if (argument.is()) { const auto &operand = op.getOperand(numOperands); - paramList.append(operand.isVariadic() ? ", ValueRange " : ", ValuePtr "); + paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value "); paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { @@ -1527,7 +1527,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body, continue; // Emit a loop to check all the dynamic values in the pack. - body << formatv(" for (ValuePtr v : getODS{0}{1}s({2})) {{\n", + body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), staticValue.index()); @@ -1682,7 +1682,7 @@ void OpEmitter::genOpAsmInterface() { namespace { // Helper class to emit Op operand adaptors to an output stream. Operand -// adaptors are wrappers around ArrayRef that provide named operand +// adaptors are wrappers around ArrayRef that provide named operand // getters identical to those defined in the Op. class OpOperandAdaptorEmitter { public: @@ -1698,12 +1698,12 @@ private: OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { - adapterClass.newField("ArrayRef", "tblgen_operands"); - auto &constructor = adapterClass.newConstructor("ArrayRef values"); + adapterClass.newField("ArrayRef", "tblgen_operands"); + auto &constructor = adapterClass.newConstructor("ArrayRef values"); constructor.body() << " tblgen_operands = values;\n"; generateNamedOperandGetters(op, adapterClass, - /*rangeType=*/"ArrayRef", + /*rangeType=*/"ArrayRef", /*rangeBeginCall=*/"tblgen_operands.begin()", /*rangeSizeCall=*/"tblgen_operands.size()", /*getOperandCallPattern=*/"tblgen_operands[{0}]"); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 8cfd454d629..824ddae85aa 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -567,7 +567,7 @@ void PatternEmitter::emitRewriteLogic() { os.indent(4) << "rewriter.eraseOp(op0);\n"; } else { // Process replacement result patterns. 
- os.indent(4) << "SmallVector tblgen_repl_values;\n"; + os.indent(4) << "SmallVector tblgen_repl_values;\n"; for (int i = replStartIndex; i < numResultPatterns; ++i) { DagNode resultTree = pattern.getResultPattern(i); auto val = handleResultPattern(resultTree, offsets[i], 0); @@ -842,7 +842,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( std::string varName; if (operand->isVariadic()) { varName = formatv("tblgen_values_{0}", valueIndex++); - os.indent(6) << formatv("SmallVector {0};\n", varName); + os.indent(6) << formatv("SmallVector {0};\n", varName); std::string range; if (node.isNestedDagArg(argIndex)) { range = childNodeNames[argIndex]; @@ -856,7 +856,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( varName); } else { varName = formatv("tblgen_value_{0}", valueIndex++); - os.indent(6) << formatv("ValuePtr {0} = ", varName); + os.indent(6) << formatv("Value {0} = ", varName); if (node.isNestedDagArg(argIndex)) { os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); } else { @@ -925,7 +925,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( Operator &resultOp = node.getDialectOp(opMap); os.indent(6) << formatv( - "SmallVector tblgen_values; (void)tblgen_values;\n"); + "SmallVector tblgen_values; (void)tblgen_values;\n"); os.indent(6) << formatv( "SmallVector tblgen_attrs; (void)tblgen_attrs;\n"); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 1aa7d5968d2..d65b216e109 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -461,7 +461,7 @@ static void emitDeserializationFunction(const Record *attrClass, emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex, resultTypes, valueID, os); - os << formatv(" SmallVector {0};\n", operands); + os << formatv(" SmallVector {0};\n", operands); os << formatv(" SmallVector {0};\n", attributes); // Operand deserialization emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 004a940ca6c..c5bc5179785 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -16,7 +16,7 @@ using namespace mlir::detail; namespace { Operation *createOp(MLIRContext *context, bool resizableOperands, - ArrayRef operands = llvm::None, + ArrayRef operands = llvm::None, ArrayRef resultTypes = llvm::None) { return Operation::create( UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes, @@ -30,7 +30,7 @@ TEST(OperandStorageTest, NonResizable) { Operation *useOp = createOp(&context, /*resizableOperands=*/false, /*operands=*/llvm::None, builder.getIntegerType(16)); - ValuePtr operand = useOp->getResult(0); + Value operand = useOp->getResult(0); // Create a non-resizable operation with one operand. Operation *user = createOp(&context, /*resizableOperands=*/false, operand, @@ -59,7 +59,7 @@ TEST(OperandStorageDeathTest, AddToNonResizable) { Operation *useOp = createOp(&context, /*resizableOperands=*/false, /*operands=*/llvm::None, builder.getIntegerType(16)); - ValuePtr operand = useOp->getResult(0); + Value operand = useOp->getResult(0); // Create a non-resizable operation with one operand. 
   Operation *user = createOp(&context, /*resizableOperands=*/false, operand,
@@ -79,7 +79,7 @@ TEST(OperandStorageTest, Resizable) {
   Operation *useOp = createOp(&context, /*resizableOperands=*/false,
                               /*operands=*/llvm::None,
                               builder.getIntegerType(16));
-  ValuePtr operand = useOp->getResult(0);
+  Value operand = useOp->getResult(0);

   // Create a resizable operation with one operand.
   Operation *user = createOp(&context, /*resizableOperands=*/true, operand,
-- 
cgit v1.2.3
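The hunks above all apply the same mechanical rewrite: the `ValuePtr` alias is retired and `Value` is passed by value everywhere, including as a template and container element type. A minimal sketch of the pattern, with hypothetical function names rather than code from this patch:

    #include "mlir/IR/Operation.h"
    #include "llvm/ADT/SmallVector.h"
    using namespace mlir;

    // Before the change, an operand would have been returned through the alias:
    //   static ValuePtr firstOperand(Operation *op);
    // After the change, Value is a lightweight value type and is used directly:
    static Value firstOperand(Operation *op) { return op->getOperand(0); }

    // Containers likewise name Value as their element type:
    static SmallVector<Value, 4> copyOperands(Operation *op) {
      return SmallVector<Value, 4>(op->operand_begin(), op->operand_end());
    }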